1
0

Merge commit 'de119063f' into anoa/dinsic_release_1_18_x

* commit 'de119063f': (31 commits)
  Convert room list handler to async/await. (#7912)
  Element CSS and logo in email templates (#7919)
  Lint the contrib/ directory in CI and linting scripts, add synctl to linting script (#7914)
  Remove unused code from synapse.logging.utils. (#7897)
  Fix a typo in the sample config. (#7890)
  Fix deprecation warning: import ABC from collections.abc (#7892)
  Change sample config's postgres user to synapse_user (#7889)
  Fix deprecation warning due to invalid escape sequences (#7895)
  Remove Ubuntu Eoan that is now EOL (#7888)
  Fix the trace function for async functions. (#7872)
  Add help for creating a user via docker (#7885)
  Switch to Debian:Slim from Alpine for the docker image (#7839)
  Stop using 'device_max_stream_id' (#7882)
  Fix TypeError in synapse.notifier (#7880)
  Add a default limit (of 100) to get/sync operations. (#7858)
  Change "unknown room ver" logging to warning. (#7881)
  Convert device handler to async/await (#7871)
  Convert synapse.app to async/await. (#7868)
  Convert _base, profile, and _receipts handlers to async/await (#7860)
  Add admin endpoint to get members in a room. (#7842)
  ...
This commit is contained in:
Andrew Morgan
2020-08-03 17:38:45 -07:00
122 changed files with 1701 additions and 1143 deletions
+1
View File
@@ -0,0 +1 @@
Base docker image on Debian Buster rather than Alpine Linux. Contributed by @maquis196.
+1
View File
@@ -0,0 +1 @@
Add an admin API to list the users in a room. Contributed by Awesome Technologies Innovationslabor GmbH.
+1
View File
@@ -0,0 +1 @@
Consistently use `db_to_json` to convert from database values to JSON objects.
+1
View File
@@ -0,0 +1 @@
Add experimental support for running multiple pusher workers.
+1
View File
@@ -0,0 +1 @@
The default value of `filter_timeline_limit` was changed from -1 (no limit) to 100.
+1
View File
@@ -0,0 +1 @@
Fix a bug which allowed empty rooms to be rejoined over federation.
+1
View File
@@ -0,0 +1 @@
Convert _base, profile, and _receipts handlers to async/await.
+1
View File
@@ -0,0 +1 @@
Optimise queueing of inbound replication commands.
+1
View File
@@ -0,0 +1 @@
Fix 'Unable to find a suitable guest user ID' error when using multiple client_reader workers.
+1
View File
@@ -0,0 +1 @@
Convert synapse.app and federation client to async/await.
+1
View File
@@ -0,0 +1 @@
Add experimental support for moving typing off master.
+1
View File
@@ -0,0 +1 @@
Convert device handler to async/await.
+1
View File
@@ -0,0 +1 @@
Fix a long standing bug where the tracing of async functions with opentracing was broken.
+1
View File
@@ -0,0 +1 @@
Fix "TypeError in `synapse.notifier`" exceptions.
+1
View File
@@ -0,0 +1 @@
Change "unknown room version" logging from 'error' to 'warning'.
+1
View File
@@ -0,0 +1 @@
Stop using `device_max_stream_id` table and just use `device_inbox.stream_id`.
+1
View File
@@ -0,0 +1 @@
Provide instructions on using `register_new_matrix_user` via docker.
+1
View File
@@ -0,0 +1 @@
Remove Ubuntu Eoan from the list of `.deb` packages that we build as it is now end-of-life. Contributed by @gary-kim.
+1
View File
@@ -0,0 +1 @@
Change the sample config postgres user section to use `synapse_user` instead of `synapse` to align with the documentation.
+1
View File
@@ -0,0 +1 @@
Fix typo in generated config file. Contributed by @ThiefMaster.
+1
View File
@@ -0,0 +1 @@
Import ABC from `collections.abc` for Python 3.10 compatibility.
+1
View File
@@ -0,0 +1 @@
Fix deprecation warning due to invalid escape sequences.
+2
View File
@@ -0,0 +1,2 @@
Remove unused functions `time_function`, `trace_function`, `get_previous_frames`
and `get_previous_frame` from `synapse.logging.utils` module.
+1
View File
@@ -0,0 +1 @@
Convert `RoomListHandler` to async/await.
+1
View File
@@ -0,0 +1 @@
Lint the `contrib/` directory in CI and linting scripts, add `synctl` to the linting script for consistency with CI.
+1
View File
@@ -0,0 +1 @@
Use Element CSS and logo in notification emails when app name is Element.
+10 -11
View File
@@ -17,9 +17,6 @@
""" Starts a synapse client console. """
from __future__ import print_function
from twisted.internet import reactor, defer, threads
from http import TwistedHttpClient
import argparse
import cmd
import getpass
@@ -28,12 +25,14 @@ import shlex
import sys
import time
import urllib
import urlparse
from http import TwistedHttpClient
import nacl.signing
import nacl.encoding
import nacl.signing
import urlparse
from signedjson.sign import SignatureVerifyException, verify_signed_json
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer, reactor, threads
CONFIG_JSON = "cmdclient_config.json"
@@ -493,7 +492,7 @@ class SynapseCmd(cmd.Cmd):
"list messages <roomid> from=END&to=START&limit=3"
"""
args = self._parse(line, ["type", "roomid", "qp"])
if not "type" in args or not "roomid" in args:
if "type" not in args or "roomid" not in args:
print("Must specify type and room ID.")
return
if args["type"] not in ["members", "messages"]:
@@ -508,7 +507,7 @@ class SynapseCmd(cmd.Cmd):
try:
key_value = key_value_str.split("=")
qp[key_value[0]] = key_value[1]
except:
except Exception:
print("Bad query param: %s" % key_value)
return
@@ -585,7 +584,7 @@ class SynapseCmd(cmd.Cmd):
parsed_url = urlparse.urlparse(args["path"])
qp.update(urlparse.parse_qs(parsed_url.query))
args["path"] = parsed_url.path
except:
except Exception:
pass
reactor.callFromThread(
@@ -772,10 +771,10 @@ def main(server_url, identity_server_url, username, token, config_path):
syn_cmd.config = json.load(config)
try:
http_client.verbose = "on" == syn_cmd.config["verbose"]
except:
except Exception:
pass
print("Loaded config from %s" % config_path)
except:
except Exception:
pass
# Twisted-specific: Runs the command processor in Twisted's event loop
+5 -5
View File
@@ -14,14 +14,14 @@
# limitations under the License.
from __future__ import print_function
from twisted.web.client import Agent, readBody
from twisted.web.http_headers import Headers
from twisted.internet import defer, reactor
from pprint import pformat
import json
import urllib
from pprint import pformat
from twisted.internet import defer, reactor
from twisted.web.client import Agent, readBody
from twisted.web.http_headers import Headers
class HttpClient(object):
+21 -34
View File
@@ -28,27 +28,24 @@ Currently assumes the local address is localhost:<port>
"""
from synapse.federation import ReplicationHandler
from synapse.federation.units import Pdu
from synapse.util import origin_from_ucid
from synapse.app.homeserver import SynapseHomeServer
# from synapse.logging.utils import log_function
from twisted.internet import reactor, defer
from twisted.python import log
import argparse
import curses.wrapper
import json
import logging
import os
import re
import cursesio
import curses.wrapper
from twisted.internet import defer, reactor
from twisted.python import log
from synapse.app.homeserver import SynapseHomeServer
from synapse.federation import ReplicationHandler
from synapse.federation.units import Pdu
from synapse.util import origin_from_ucid
# from synapse.logging.utils import log_function
logger = logging.getLogger("example")
@@ -75,7 +72,7 @@ class InputOutput(object):
"""
try:
m = re.match("^join (\S+)$", line)
m = re.match(r"^join (\S+)$", line)
if m:
# The `sender` wants to join a room.
(room_name,) = m.groups()
@@ -84,7 +81,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
m = re.match("^invite (\S+) (\S+)$", line)
m = re.match(r"^invite (\S+) (\S+)$", line)
if m:
# `sender` wants to invite someone to a room
room_name, invitee = m.groups()
@@ -93,7 +90,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
m = re.match("^send (\S+) (.*)$", line)
m = re.match(r"^send (\S+) (.*)$", line)
if m:
# `sender` wants to message a room
room_name, body = m.groups()
@@ -102,7 +99,7 @@ class InputOutput(object):
# self.print_line("OK.")
return
m = re.match("^backfill (\S+)$", line)
m = re.match(r"^backfill (\S+)$", line)
if m:
# we want to backfill a room
(room_name,) = m.groups()
@@ -201,16 +198,6 @@ class HomeServer(ReplicationHandler):
% (pdu.context, pdu.pdu_type, json.dumps(pdu.content))
)
# def on_state_change(self, pdu):
##self.output.print_line("#%s (state) %s *** %s" %
##(pdu.context, pdu.state_key, pdu.pdu_type)
##)
# if "joinee" in pdu.content:
# self._on_join(pdu.context, pdu.content["joinee"])
# elif "invitee" in pdu.content:
# self._on_invite(pdu.origin, pdu.context, pdu.content["invitee"])
def _on_message(self, pdu):
""" We received a message
"""
@@ -314,7 +301,7 @@ class HomeServer(ReplicationHandler):
return self.replication_layer.backfill(dest, room_name, limit)
def _get_room_remote_servers(self, room_name):
return [i for i in self.joined_rooms.setdefault(room_name).servers]
return list(self.joined_rooms.setdefault(room_name).servers)
def _get_or_create_room(self, room_name):
return self.joined_rooms.setdefault(room_name, Room(room_name))
@@ -334,7 +321,7 @@ def main(stdscr):
user = args.user
server_name = origin_from_ucid(user)
## Set up logging ##
# Set up logging
root_logger = logging.getLogger()
@@ -354,7 +341,7 @@ def main(stdscr):
observer = log.PythonLoggingObserver()
observer.start()
## Set up synapse server
# Set up synapse server
curses_stdio = cursesio.CursesStdIO(stdscr)
input_output = InputOutput(curses_stdio, user)
@@ -368,16 +355,16 @@ def main(stdscr):
input_output.set_home_server(hs)
## Add input_output logger
# Add input_output logger
io_logger = IOLoggerHandler(input_output)
io_logger.setFormatter(formatter)
root_logger.addHandler(io_logger)
## Start! ##
# Start!
try:
port = int(server_name.split(":")[1])
except:
except Exception:
port = 12345
app_hs.get_http_server().start_listening(port)
+10 -11
View File
@@ -1,5 +1,13 @@
from __future__ import print_function
import argparse
import cgi
import datetime
import json
import pydot
import urllib2
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,15 +23,6 @@ from __future__ import print_function
# limitations under the License.
import sqlite3
import pydot
import cgi
import json
import datetime
import argparse
import urllib2
def make_name(pdu_id, origin):
return "%s@%s" % (pdu_id, origin)
@@ -33,7 +32,7 @@ def make_graph(pdus, room, filename_prefix):
node_map = {}
origins = set()
colors = set(("red", "green", "blue", "yellow", "purple"))
colors = {"red", "green", "blue", "yellow", "purple"}
for pdu in pdus:
origins.add(pdu.get("origin"))
@@ -49,7 +48,7 @@ def make_graph(pdus, room, filename_prefix):
try:
c = colors.pop()
color_map[o] = c
except:
except Exception:
print("Run out of colours!")
color_map[o] = "black"
+7 -6
View File
@@ -13,12 +13,13 @@
# limitations under the License.
import sqlite3
import pydot
import cgi
import json
import datetime
import argparse
import cgi
import datetime
import json
import sqlite3
import pydot
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
@@ -98,7 +99,7 @@ def make_graph(db_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
+11 -11
View File
@@ -1,5 +1,15 @@
from __future__ import print_function
import argparse
import cgi
import datetime
import pydot
import simplejson as json
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,16 +25,6 @@ from __future__ import print_function
# limitations under the License.
import pydot
import cgi
import simplejson as json
import datetime
import argparse
from synapse.events import FrozenEvent
from synapse.util.frozenutils import unfreeze
def make_graph(file_name, room_id, file_prefix, limit):
print("Reading lines")
with open(file_name) as f:
@@ -106,7 +106,7 @@ def make_graph(file_name, room_id, file_prefix, limit):
for prev_id, _ in event.prev_events:
try:
end_node = node_map[prev_id]
except:
except Exception:
end_node = pydot.Node(name=prev_id, label="<<b>%s</b>>" % (prev_id,))
node_map[prev_id] = end_node
+5 -5
View File
@@ -12,15 +12,15 @@ npm install jquery jsdom
"""
from __future__ import print_function
import gevent
import grequests
from BeautifulSoup import BeautifulSoup
import json
import urllib
import subprocess
import time
# ACCESS_TOKEN="" #
import gevent
import grequests
from BeautifulSoup import BeautifulSoup
ACCESS_TOKEN = ""
MATRIXBASE = "https://matrix.org/_matrix/client/api/v1/"
MYUSERNAME = "@davetest:matrix.org"
+4 -2
View File
@@ -1,10 +1,12 @@
#!/usr/bin/env python
from __future__ import print_function
from argparse import ArgumentParser
import json
import requests
import sys
import urllib
from argparse import ArgumentParser
import requests
try:
raw_input
+23 -34
View File
@@ -16,35 +16,31 @@ ARG PYTHON_VERSION=3.7
###
### Stage 0: builder
###
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11 as builder
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
RUN apk add \
build-base \
libffi-dev \
libjpeg-turbo-dev \
libwebp-dev \
libressl-dev \
libxslt-dev \
linux-headers \
postgresql-dev \
zlib-dev
# build things which have slow build steps, before we copy synapse, so that
# the layer can be cached.
#
# (we really just care about caching a wheel here, as the "pip install" below
# will install them again.)
RUN apt-get update && apt-get install -y \
build-essential \
libpq-dev \
&& rm -rf /var/lib/apt/lists/*
# Build dependencies that are not available as wheels, to speed up rebuilds
RUN pip install --prefix="/install" --no-warn-script-location \
cryptography \
msgpack-python \
pillow \
pynacl
frozendict \
jaeger-client \
opentracing \
prometheus-client \
psycopg2 \
pycparser \
pyrsistent \
pyyaml \
simplejson \
threadloop \
thrift
# now install synapse and all of the python deps to /install.
COPY synapse /synapse/synapse/
COPY scripts /synapse/scripts/
COPY MANIFEST.in README.rst setup.py synctl /synapse/
@@ -56,20 +52,13 @@ RUN pip install --prefix="/install" --no-warn-script-location \
### Stage 1: runtime
###
FROM docker.io/python:${PYTHON_VERSION}-alpine3.11
FROM docker.io/python:${PYTHON_VERSION}-slim
# xmlsec is required for saml support
RUN apk add --no-cache --virtual .runtime_deps \
libffi \
libjpeg-turbo \
libwebp \
libressl \
libxslt \
libpq \
zlib \
su-exec \
tzdata \
xmlsec
RUN apt-get update && apt-get install -y \
libpq5 \
xmlsec1 \
gosu \
&& rm -rf /var/lib/apt/lists/*
COPY --from=builder /install /usr/local
COPY ./docker/start.py /start.py
+15
View File
@@ -94,6 +94,21 @@ The following environment variables are supported in run mode:
* `UID`, `GID`: the user and group id to run Synapse as. Defaults to `991`, `991`.
* `TZ`: the [timezone](https://en.wikipedia.org/wiki/List_of_tz_database_time_zones) the container will run with. Defaults to `UTC`.
## Generating an (admin) user
After synapse is running, you may wish to create a user via `register_new_matrix_user`.
This requires a `registration_shared_secret` to be set in your config file. Synapse
must be restarted to pick up this change.
You can then call the script:
```
docker exec -it synapse register_new_matrix_user http://localhost:8008 -c /data/homeserver.yaml --help
```
Remember to remove the `registration_shared_secret` and restart if you no-longer need it.
## TLS support
The default configuration exposes a single HTTP port: http://localhost:8008. It
+6 -6
View File
@@ -120,7 +120,7 @@ def generate_config_from_template(config_dir, config_path, environ, ownership):
if ownership is not None:
subprocess.check_output(["chown", "-R", ownership, "/data"])
args = ["su-exec", ownership] + args
args = ["gosu", ownership] + args
subprocess.check_output(args)
@@ -172,8 +172,8 @@ def run_generate_config(environ, ownership):
# make sure that synapse has perms to write to the data dir.
subprocess.check_output(["chown", ownership, data_dir])
args = ["su-exec", ownership] + args
os.execv("/sbin/su-exec", args)
args = ["gosu", ownership] + args
os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
@@ -189,7 +189,7 @@ def main(args, environ):
ownership = "{}:{}".format(desired_uid, desired_gid)
if ownership is None:
log("Will not perform chmod/su-exec as UserID already matches request")
log("Will not perform chmod/gosu as UserID already matches request")
# In generate mode, generate a configuration and missing keys, then exit
if mode == "generate":
@@ -236,8 +236,8 @@ running with 'migrate_config'. See the README for more details.
args = ["python", "-m", synapse_worker, "--config-path", config_path]
if ownership is not None:
args = ["su-exec", ownership] + args
os.execv("/sbin/su-exec", args)
args = ["gosu", ownership] + args
os.execv("/usr/sbin/gosu", args)
else:
os.execv("/usr/local/bin/python", args)
+33 -1
View File
@@ -319,11 +319,43 @@ Response:
}
```
# Room Members API
The Room Members admin API allows server admins to get a list of all members of a room.
The response includes the following fields:
* `members` - A list of all the members that are present in the room, represented by their ids.
* `total` - Total number of members in the room.
## Usage
A standard request:
```
GET /_synapse/admin/v1/rooms/<room_id>/members
{}
```
Response:
```
{
"members": [
"@foo:matrix.org",
"@bar:matrix.org",
"@foobar:matrix.org
],
"total": 3
}
```
# Delete Room API
The Delete Room admin API allows server admins to remove rooms from server
and block these rooms.
It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
It is a combination and improvement of "[Shutdown room](shutdown_room.md)"
and "[Purge room](purge_room.md)" API.
Shuts down a room. Moves all local users and room aliases automatically to a
+5 -11
View File
@@ -38,6 +38,11 @@ the reverse proxy and the homeserver.
server {
listen 443 ssl;
listen [::]:443 ssl;
# For the federation port
listen 8448 ssl default_server;
listen [::]:8448 ssl default_server;
server_name matrix.example.com;
location /_matrix {
@@ -48,17 +53,6 @@ server {
client_max_body_size 10M;
}
}
server {
listen 8448 ssl default_server;
listen [::]:8448 ssl default_server;
server_name example.com;
location / {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
}
}
```
**NOTE**: Do not add a path after the port in `proxy_pass`, otherwise nginx will
+5 -3
View File
@@ -102,7 +102,9 @@ pid_file: DATADIR/homeserver.pid
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is -1, means no upper limit.
# and sync operations. The default value is 100. -1 means no upper limit.
#
# Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -146,7 +148,7 @@ pid_file: DATADIR/homeserver.pid
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
# compress: set to true to enable HTTP comression for this resource.
# compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
@@ -751,7 +753,7 @@ caches:
#database:
# name: psycopg2
# args:
# user: synapse
# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
-1
View File
@@ -24,7 +24,6 @@ DISTS = (
"debian:sid",
"ubuntu:xenial",
"ubuntu:bionic",
"ubuntu:eoan",
"ubuntu:focal",
)
+1 -1
View File
@@ -11,7 +11,7 @@ if [ $# -ge 1 ]
then
files=$*
else
files="synapse tests scripts-dev scripts"
files="synapse tests scripts-dev scripts contrib synctl"
fi
echo "Linting these locations: $files"
+11 -1
View File
@@ -49,6 +49,7 @@ from synapse.storage.data_stores.main.media_repository import (
from synapse.storage.data_stores.main.profile import ProfileStore
from synapse.storage.data_stores.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
)
from synapse.storage.data_stores.main.room import RoomBackgroundUpdateStore
from synapse.storage.data_stores.main.roommember import RoomMemberBackgroundUpdateStore
@@ -624,8 +625,10 @@ class Porter(object):
)
)
# Step 5. Do final post-processing
# Step 5. Set up sequences
self.progress.set_state("Setting up sequence generators")
await self._setup_state_group_id_seq()
await self._setup_user_id_seq()
self.progress.done()
except Exception as e:
@@ -795,6 +798,13 @@ class Porter(object):
return self.postgres_store.db.runInteraction("setup_state_group_id_seq", r)
def _setup_user_id_seq(self):
def r(txn):
next_id = find_max_generated_user_id_localpart(txn) + 1
txn.execute("ALTER SEQUENCE user_id_seq RESTART WITH %s", (next_id,))
return self.postgres_store.db.runInteraction("setup_user_id_seq", r)
##############################################
# The following is simply UI stuff
+7 -41
View File
@@ -21,7 +21,7 @@ from typing import Dict, Iterable, Optional, Set
from typing_extensions import ContextManager
from twisted.internet import address, defer, reactor
from twisted.internet import address, reactor
import synapse
import synapse.events
@@ -111,6 +111,7 @@ from synapse.rest.client.v1.room import (
RoomSendEventRestServlet,
RoomStateEventRestServlet,
RoomStateRestServlet,
RoomTypingRestServlet,
)
from synapse.rest.client.v1.voip import VoipRestServlet
from synapse.rest.client.v2_alpha import groups, sync, user_directory
@@ -374,9 +375,8 @@ class GenericWorkerPresence(BasePresenceHandler):
return _user_syncing()
@defer.inlineCallbacks
def notify_from_replication(self, states, stream_id):
parties = yield get_interested_parties(self.store, states)
async def notify_from_replication(self, states, stream_id):
parties = await get_interested_parties(self.store, states)
room_ids_to_states, users_to_states = parties
self.notifier.on_new_event(
@@ -386,8 +386,7 @@ class GenericWorkerPresence(BasePresenceHandler):
users=users_to_states.keys(),
)
@defer.inlineCallbacks
def process_replication_rows(self, token, rows):
async def process_replication_rows(self, token, rows):
states = [
UserPresenceState(
row.user_id,
@@ -405,7 +404,7 @@ class GenericWorkerPresence(BasePresenceHandler):
self.user_to_current_state[state.user_id] = state
stream_id = token
yield self.notify_from_replication(states, stream_id)
await self.notify_from_replication(states, stream_id)
def get_currently_syncing_users_for_replication(self) -> Iterable[str]:
return [
@@ -451,37 +450,6 @@ class GenericWorkerPresence(BasePresenceHandler):
await self._bump_active_client(user_id=user_id)
class GenericWorkerTyping(object):
def __init__(self, hs):
self._latest_room_serial = 0
self._reset()
def _reset(self):
"""
Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
def process_replication_rows(self, token, rows):
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()
# Set the latest serial token to whatever the server gave us.
self._latest_room_serial = token
for row in rows:
self._room_serials[row.room_id] = token
self._room_typing[row.room_id] = row.user_ids
def get_current_token(self) -> int:
return self._latest_room_serial
class GenericWorkerSlavedStore(
# FIXME(#3714): We need to add UserDirectoryStore as we write directly
# rather than going via the correct worker.
@@ -558,6 +526,7 @@ class GenericWorkerServer(HomeServer):
KeyUploadServlet(self).register(resource)
AccountDataServlet(self).register(resource)
RoomAccountDataServlet(self).register(resource)
RoomTypingRestServlet(self).register(resource)
sync.register_servlets(self, resource)
events.register_servlets(self, resource)
@@ -669,9 +638,6 @@ class GenericWorkerServer(HomeServer):
def build_presence_handler(self):
return GenericWorkerPresence(self)
def build_typing_handler(self):
return GenericWorkerTyping(self)
class GenericWorkerReplicationHandler(ReplicationDataHandler):
def __init__(self, hs):
+12 -13
View File
@@ -483,8 +483,7 @@ class SynapseService(service.Service):
_stats_process = []
@defer.inlineCallbacks
def phone_stats_home(hs, stats, stats_process=_stats_process):
async def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Gathering stats for reporting")
now = int(hs.get_clock().time())
uptime = int(now - hs.start_time)
@@ -522,28 +521,28 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
stats["python_version"] = "{}.{}.{}".format(
version.major, version.minor, version.micro
)
stats["total_users"] = yield hs.get_datastore().count_all_users()
stats["total_users"] = await hs.get_datastore().count_all_users()
total_nonbridged_users = yield hs.get_datastore().count_nonbridged_users()
total_nonbridged_users = await hs.get_datastore().count_nonbridged_users()
stats["total_nonbridged_users"] = total_nonbridged_users
daily_user_type_results = yield hs.get_datastore().count_daily_user_type()
daily_user_type_results = await hs.get_datastore().count_daily_user_type()
for name, count in daily_user_type_results.items():
stats["daily_user_type_" + name] = count
room_count = yield hs.get_datastore().get_room_count()
room_count = await hs.get_datastore().get_room_count()
stats["total_room_count"] = room_count
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = yield hs.get_datastore().count_monthly_users()
stats["daily_active_rooms"] = yield hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = yield hs.get_datastore().count_daily_messages()
stats["daily_active_users"] = await hs.get_datastore().count_daily_users()
stats["monthly_active_users"] = await hs.get_datastore().count_monthly_users()
stats["daily_active_rooms"] = await hs.get_datastore().count_daily_active_rooms()
stats["daily_messages"] = await hs.get_datastore().count_daily_messages()
r30_results = yield hs.get_datastore().count_r30_users()
r30_results = await hs.get_datastore().count_r30_users()
for name, count in r30_results.items():
stats["r30_users_" + name] = count
daily_sent_messages = yield hs.get_datastore().count_daily_sent_messages()
daily_sent_messages = await hs.get_datastore().count_daily_sent_messages()
stats["daily_sent_messages"] = daily_sent_messages
stats["cache_factor"] = hs.config.caches.global_factor
stats["event_cache_size"] = hs.config.caches.event_cache_size
@@ -558,7 +557,7 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(
await hs.get_proxied_http_client().put_json(
hs.config.report_stats_endpoint, stats
)
except Exception as e:
+36 -2
View File
@@ -19,10 +19,12 @@ import argparse
import errno
import os
from collections import OrderedDict
from hashlib import sha256
from io import open as io_open
from textwrap import dedent
from typing import Any, MutableMapping, Optional
from typing import Any, List, MutableMapping, Optional
import attr
import yaml
@@ -718,4 +720,36 @@ def find_config_files(search_paths):
return config_files
__all__ = ["Config", "RootConfig"]
@attr.s
class ShardedWorkerHandlingConfig:
"""Algorithm for choosing which instance is responsible for handling some
sharded work.
For example, the federation senders use this to determine which instances
handles sending stuff to a given destination (which is used as the `key`
below).
"""
instances = attr.ib(type=List[str])
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
#
# (Technically this introduces some bias and is not entirely uniform,
# but since the hash is so large the bias is ridiculously small).
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+5
View File
@@ -137,3 +137,8 @@ class Config:
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...
class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
+1 -1
View File
@@ -55,7 +55,7 @@ DEFAULT_CONFIG = """\
#database:
# name: psycopg2
# args:
# user: synapse
# user: synapse_user
# password: secretpassword
# database: synapse
# host: localhost
+3 -34
View File
@@ -13,42 +13,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from hashlib import sha256
from typing import List, Optional
from typing import Optional
import attr
from netaddr import IPSet
from ._base import Config, ConfigError
@attr.s
class ShardedFederationSendingConfig:
"""Algorithm for choosing which federation sender instance is responsible
for which destionation host.
"""
instances = attr.ib(type=List[str])
def should_send_to(self, instance_name: str, destination: str) -> bool:
"""Whether this instance is responsible for sending transcations for
the given host.
"""
# If multiple federation senders are not defined we always return true.
if not self.instances or len(self.instances) == 1:
return True
# We shard by taking the hash, modulo it by the number of federation
# senders and then checking whether this instance matches the instance
# at that index.
#
# (Technically this introduces some bias and is not entirely uniform, but
# since the hash is so large the bias is ridiculously small).
dest_hash = sha256(destination.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
class FederationConfig(Config):
@@ -61,7 +30,7 @@ class FederationConfig(Config):
self.send_federation = config.get("send_federation", True)
federation_sender_instances = config.get("federation_sender_instances") or []
self.federation_shard_config = ShardedFederationSendingConfig(
self.federation_shard_config = ShardedWorkerHandlingConfig(
federation_sender_instances
)
+4 -1
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ShardedWorkerHandlingConfig
class PushConfig(Config):
@@ -24,6 +24,9 @@ class PushConfig(Config):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)
pusher_instances = config.get("pusher_instances") or []
self.pusher_shard_config = ShardedWorkerHandlingConfig(pusher_instances)
# There was a a 'redact_content' setting but mistakenly read from the
# 'email'section'. Check for the flag in the 'push' section, and log,
# but do not honour it to avoid nasty surprises when people upgrade.
+5 -3
View File
@@ -207,7 +207,7 @@ class ServerConfig(Config):
# errors when attempting to search for messages.
self.enable_search = config.get("enable_search", True)
self.filter_timeline_limit = config.get("filter_timeline_limit", -1)
self.filter_timeline_limit = config.get("filter_timeline_limit", 100)
# Whether we should block invites sent to users on this server
# (other than those sent by local server admins)
@@ -699,7 +699,9 @@ class ServerConfig(Config):
#gc_thresholds: [700, 10, 10]
# Set the limit on the returned events in the timeline in the get
# and sync operations. The default value is -1, means no upper limit.
# and sync operations. The default value is 100. -1 means no upper limit.
#
# Uncomment the following to increase the limit to 5000.
#
#filter_timeline_limit: 5000
@@ -743,7 +745,7 @@ class ServerConfig(Config):
# names: a list of names of HTTP resources. See below for a list of
# valid resource names.
#
# compress: set to true to enable HTTP comression for this resource.
# compress: set to true to enable HTTP compression for this resource.
#
# additional_resources: Only valid for an 'http' listener. A map of
# additional endpoints which should be loaded via dynamic modules.
+10 -9
View File
@@ -34,9 +34,11 @@ class WriterLocations:
Attributes:
events: The instance that writes to the event and backfill streams.
events: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
typing = attr.ib(default="master", type=str)
class WorkerConfig(Config):
@@ -93,16 +95,15 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events also appears in
# Check that the configured writer for events and typing also appears in
# `instance_map`.
if (
self.writers.events != "master"
and self.writers.events not in self.instance_map
):
raise ConfigError(
"Instance %r is configured to write events but does not appear in `instance_map` config."
% (self.writers.events,)
)
for stream in ("events", "typing"):
instance = getattr(self.writers, stream)
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
def read_arguments(self, args):
# We support a bunch of command line arguments that override options in
+3 -3
View File
@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import collections.abc
import re
from typing import Any, Mapping, Union
@@ -424,7 +424,7 @@ def copy_power_levels_contents(
Raises:
TypeError if the input does not look like a valid power levels event content
"""
if not isinstance(old_power_levels, collections.Mapping):
if not isinstance(old_power_levels, collections.abc.Mapping):
raise TypeError("Not a valid power-levels content: %r" % (old_power_levels,))
power_levels = {}
@@ -434,7 +434,7 @@ def copy_power_levels_contents(
power_levels[k] = v
continue
if isinstance(v, collections.Mapping):
if isinstance(v, collections.abc.Mapping):
power_levels[k] = h = {}
for k1, v1 in v.items():
# we should only have one level of nesting
+19 -21
View File
@@ -374,29 +374,26 @@ class FederationClient(FederationBase):
"""
deferreds = self._check_sigs_and_hashes(room_version, pdus)
@defer.inlineCallbacks
def handle_check_result(pdu: EventBase, deferred: Deferred):
async def handle_check_result(pdu: EventBase, deferred: Deferred):
try:
res = yield make_deferred_yieldable(deferred)
res = await make_deferred_yieldable(deferred)
except SynapseError:
res = None
if not res:
# Check local db.
res = yield self.store.get_event(
res = await self.store.get_event(
pdu.event_id, allow_rejected=True, allow_none=True
)
if not res and pdu.origin != origin:
try:
res = yield defer.ensureDeferred(
self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
res = await self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
room_version=room_version,
outlier=outlier,
timeout=10000,
)
except SynapseError:
pass
@@ -995,24 +992,25 @@ class FederationClient(FederationBase):
raise RuntimeError("Failed to send to any server.")
@defer.inlineCallbacks
def get_room_complexity(self, destination, room_id):
async def get_room_complexity(
self, destination: str, room_id: str
) -> Optional[dict]:
"""
Fetch the complexity of a remote room from another server.
Args:
destination (str): The remote server
room_id (str): The room ID to ask about.
destination: The remote server
room_id: The room ID to ask about.
Returns:
Deferred[dict] or Deferred[None]: Dict contains the complexity
metric versions, while None means we could not fetch the complexity.
Dict contains the complexity metric versions, while None means we
could not fetch the complexity.
"""
try:
complexity = yield self.transport_layer.get_room_complexity(
complexity = await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id
)
defer.returnValue(complexity)
return complexity
except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us.
@@ -1029,4 +1027,4 @@ class FederationClient(FederationBase):
# If we don't manage to find it, return None. It's not an error if a
# server doesn't give it to us.
defer.returnValue(None)
return None
+75 -52
View File
@@ -15,7 +15,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Callable, Dict, List, Match, Optional, Tuple, Union
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
List,
Match,
Optional,
Tuple,
Union,
)
from canonicaljson import json
from prometheus_client import Counter, Histogram
@@ -56,6 +67,9 @@ from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
if TYPE_CHECKING:
from synapse.server import HomeServer
# when processing incoming transactions, we try to handle multiple rooms in
# parallel, up to this limit.
TRANSACTION_CONCURRENCY_LIMIT = 10
@@ -768,11 +782,30 @@ class FederationHandlerRegistry(object):
query type for incoming federation traffic.
"""
def __init__(self):
self.edu_handlers = {}
self.query_handlers = {}
def __init__(self, hs: "HomeServer"):
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
def register_edu_handler(self, edu_type: str, handler: Callable[[str, dict], None]):
# These are safe to load in monolith mode, but will explode if we try
# and use them. However we have guards before we use them to ensure that
# we don't route to ourselves, and in monolith mode that will always be
# the case.
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
self.edu_handlers = (
{}
) # type: Dict[str, Callable[[str, dict], Awaitable[None]]]
self.query_handlers = {} # type: Dict[str, Callable[[dict], Awaitable[None]]]
# Map from type to instance name that we should route EDU handling to.
self._edu_type_to_instance = {} # type: Dict[str, str]
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, dict], Awaitable[None]]
):
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -809,66 +842,56 @@ class FederationHandlerRegistry(object):
self.query_handlers[query_type] = handler
def register_instance_for_edu(self, edu_type: str, instance_name: str):
"""Register that the EDU handler is on a different instance than master.
"""
self._edu_type_to_instance[edu_type] = instance_name
async def on_edu(self, edu_type: str, origin: str, content: dict):
handler = self.edu_handlers.get(edu_type)
if not handler:
logger.warning("No handler registered for EDU type %s", edu_type)
if not self.config.use_presence and edu_type == "m.presence":
return
with start_active_span_from_edu(content, "handle_edu"):
# Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
if handler:
with start_active_span_from_edu(content, "handle_edu"):
try:
await handler(origin, content)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
return
# Check if we can route it somewhere else that isn't us
route_to = self._edu_type_to_instance.get(edu_type, "master")
if route_to != self._instance_name:
try:
await handler(origin, content)
await self._send_edu(
instance_name=route_to,
edu_type=edu_type,
origin=origin,
content=content,
)
except SynapseError as e:
logger.info("Failed to handle edu %r: %r", edu_type, e)
except Exception:
logger.exception("Failed to handle edu %r", edu_type)
def on_query(self, query_type: str, args: dict) -> defer.Deferred:
handler = self.query_handlers.get(query_type)
if not handler:
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
return handler(args)
class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
"""A FederationHandlerRegistry for worker processes.
When receiving EDU or queries it will check if an appropriate handler has
been registered on the worker, if there isn't one then it calls off to the
master process.
"""
def __init__(self, hs):
self.config = hs.config
self.http_client = hs.get_simple_http_client()
self.clock = hs.get_clock()
self._get_query_client = ReplicationGetQueryRestServlet.make_client(hs)
self._send_edu = ReplicationFederationSendEduRestServlet.make_client(hs)
super(ReplicationFederationHandlerRegistry, self).__init__()
async def on_edu(self, edu_type: str, origin: str, content: dict):
"""Overrides FederationHandlerRegistry
"""
if not self.config.use_presence and edu_type == "m.presence":
return
handler = self.edu_handlers.get(edu_type)
if handler:
return await super(ReplicationFederationHandlerRegistry, self).on_edu(
edu_type, origin, content
)
return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
# Oh well, let's just log and move on.
logger.warning("No handler registered for EDU type %s", edu_type)
async def on_query(self, query_type: str, args: dict):
"""Overrides FederationHandlerRegistry
"""
handler = self.query_handlers.get(query_type)
if handler:
return await handler(args)
return await self._get_query_client(query_type=query_type, args=args)
# Check if we can route it somewhere else that isn't us
if self._instance_name == "master":
return await self._get_query_client(query_type=query_type, args=args)
# Uh oh, no handler! Let's raise an exception so the request returns an
# error.
logger.warning("No handler registered for query type %s", query_type)
raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+8 -8
View File
@@ -197,7 +197,7 @@ class FederationSender(object):
destinations = {
d
for d in destinations
if self._federation_shard_config.should_send_to(
if self._federation_shard_config.should_handle(
self._instance_name, d
)
}
@@ -335,7 +335,7 @@ class FederationSender(object):
d
for d in domains
if d != self.server_name
and self._federation_shard_config.should_send_to(self._instance_name, d)
and self._federation_shard_config.should_handle(self._instance_name, d)
]
if not domains:
return
@@ -441,7 +441,7 @@ class FederationSender(object):
for destination in destinations:
if destination == self.server_name:
continue
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@@ -460,7 +460,7 @@ class FederationSender(object):
if destination == self.server_name:
continue
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
continue
@@ -486,7 +486,7 @@ class FederationSender(object):
logger.info("Not sending EDU to ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@@ -507,7 +507,7 @@ class FederationSender(object):
edu: edu to send
key: clobbering key for this edu
"""
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, edu.destination
):
return
@@ -523,7 +523,7 @@ class FederationSender(object):
logger.warning("Not sending device update to ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@@ -541,7 +541,7 @@ class FederationSender(object):
logger.warning("Not waking up ourselves")
return
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
return
@@ -78,7 +78,7 @@ class PerDestinationQueue(object):
self._federation_shard_config = hs.config.federation.federation_shard_config
self._should_send_on_this_instance = True
if not self._federation_shard_config.should_send_to(
if not self._federation_shard_config.should_handle(
self._instance_name, destination
):
# We don't raise an exception here to avoid taking out any other
+2 -8
View File
@@ -20,8 +20,6 @@ import logging
import re
from typing import Optional, Tuple, Type
from twisted.internet.defer import maybeDeferred
import synapse
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
from synapse.api.room_versions import RoomVersions
@@ -796,12 +794,8 @@ class PublicRoomList(BaseFederationServlet):
# zero is a special value which corresponds to no limit.
limit = None
data = await maybeDeferred(
self.handler.get_local_public_room_list,
limit,
since_token,
network_tuple=network_tuple,
from_federation=True,
data = await self.handler.get_local_public_room_list(
limit, since_token, network_tuple=network_tuple, from_federation=True
)
return 200, data
+2 -5
View File
@@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
import synapse.state
import synapse.storage
import synapse.types
@@ -66,8 +64,7 @@ class BaseHandler(object):
self.event_builder_factory = hs.get_event_builder_factory()
@defer.inlineCallbacks
def ratelimit(self, requester, update=True, is_admin_redaction=False):
async def ratelimit(self, requester, update=True, is_admin_redaction=False):
"""Ratelimits requests.
Args:
@@ -99,7 +96,7 @@ class BaseHandler(object):
burst_count = self._rc_message.burst_count
# Check if there is a per user override in the DB.
override = yield self.store.get_ratelimit_for_user(user_id)
override = await self.store.get_ratelimit_for_user(user_id)
if override:
# If overridden with a null Hz then ratelimiting has been entirely
# disabled for the user
+113 -136
View File
@@ -15,9 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Optional
from twisted.internet import defer
from typing import Any, Dict, List, Optional
from synapse.api import errors
from synapse.api.constants import EventTypes
@@ -57,21 +55,20 @@ class DeviceWorkerHandler(BaseHandler):
self._auth_handler = hs.get_auth_handler()
@trace
@defer.inlineCallbacks
def get_devices_by_user(self, user_id):
async def get_devices_by_user(self, user_id: str) -> List[Dict[str, Any]]:
"""
Retrieve the given user's devices
Args:
user_id (str):
user_id: The user ID to query for devices.
Returns:
defer.Deferred: list[dict[str, X]]: info on each device
info on each device
"""
set_tag("user_id", user_id)
device_map = yield self.store.get_devices_by_user(user_id)
device_map = await self.store.get_devices_by_user(user_id)
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id=None)
ips = await self.store.get_last_client_ip_by_device(user_id, device_id=None)
devices = list(device_map.values())
for device in devices:
@@ -81,24 +78,23 @@ class DeviceWorkerHandler(BaseHandler):
return devices
@trace
@defer.inlineCallbacks
def get_device(self, user_id, device_id):
async def get_device(self, user_id: str, device_id: str) -> Dict[str, Any]:
""" Retrieve the given device
Args:
user_id (str):
device_id (str):
user_id: The user to get the device from
device_id: The device to fetch.
Returns:
defer.Deferred: dict[str, X]: info on the device
info on the device
Raises:
errors.NotFoundError: if the device was not found
"""
try:
device = yield self.store.get_device(user_id, device_id)
device = await self.store.get_device(user_id, device_id)
except errors.StoreError:
raise errors.NotFoundError
ips = yield self.store.get_last_client_ip_by_device(user_id, device_id)
ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
_update_device_from_client_ips(device, ips)
set_tag("device", device)
@@ -106,10 +102,9 @@ class DeviceWorkerHandler(BaseHandler):
return device
@measure_func("device.get_user_ids_changed")
@trace
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
@@ -120,13 +115,13 @@ class DeviceWorkerHandler(BaseHandler):
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_key = yield self.store.get_room_events_max_id()
now_room_key = await self.store.get_room_events_max_id()
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
# First we check if any devices have changed for users that we share
# rooms with.
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -135,14 +130,14 @@ class DeviceWorkerHandler(BaseHandler):
# Always tell the user about their own devices
tracked_users.add(user_id)
changed = yield self.store.get_users_whose_devices_changed(
changed = await self.store.get_users_whose_devices_changed(
from_token.device_list_key, tracked_users
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
member_events = yield self.store.get_membership_changes_for_user(
member_events = await self.store.get_membership_changes_for_user(
user_id, from_token.room_key, now_room_key
)
rooms_changed.update(event.room_id for event in member_events)
@@ -152,7 +147,7 @@ class DeviceWorkerHandler(BaseHandler):
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = yield self.store.get_current_state_ids(room_id)
current_state_ids = await self.store.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
@@ -166,7 +161,7 @@ class DeviceWorkerHandler(BaseHandler):
# Fetch the current state at the time.
try:
event_ids = yield self.store.get_forward_extremeties_for_room(
event_ids = await self.store.get_forward_extremeties_for_room(
room_id, stream_ordering=stream_ordering
)
except errors.StoreError:
@@ -192,7 +187,7 @@ class DeviceWorkerHandler(BaseHandler):
continue
# mapping from event_id -> state_dict
prev_state_ids = yield self.state_store.get_state_ids_for_events(event_ids)
prev_state_ids = await self.state_store.get_state_ids_for_events(event_ids)
# Check if we've joined the room? If so we just blindly add all the users to
# the "possibly changed" users.
@@ -238,11 +233,10 @@ class DeviceWorkerHandler(BaseHandler):
return result
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
master_key = yield self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = yield self.store.get_e2e_cross_signing_key(
async def on_federation_query_user_devices(self, user_id):
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"
)
@@ -271,8 +265,7 @@ class DeviceHandler(DeviceWorkerHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
def check_device_registered(
async def check_device_registered(
self, user_id, device_id, initial_device_display_name=None
):
"""
@@ -290,13 +283,13 @@ class DeviceHandler(DeviceWorkerHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
new_device = yield self.store.store_device(
new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
return device_id
# if the device id is not specified, we'll autogen one, but loop a few
@@ -304,33 +297,29 @@ class DeviceHandler(DeviceWorkerHandler):
attempts = 0
while attempts < 5:
device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device(
new_device = await self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
return device_id
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@trace
@defer.inlineCallbacks
def delete_device(self, user_id, device_id):
async def delete_device(self, user_id: str, device_id: str) -> None:
""" Delete the given device
Args:
user_id (str):
device_id (str):
Returns:
defer.Deferred:
user_id: The user to delete the device from.
device_id: The device to delete.
"""
try:
yield self.store.delete_device(user_id, device_id)
await self.store.delete_device(user_id, device_id)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -342,49 +331,40 @@ class DeviceHandler(DeviceWorkerHandler):
else:
raise
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
await self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
@trace
@defer.inlineCallbacks
def delete_all_devices_for_user(self, user_id, except_device_id=None):
async def delete_all_devices_for_user(
self, user_id: str, except_device_id: Optional[str] = None
) -> None:
"""Delete all of the user's devices
Args:
user_id (str):
except_device_id (str|None): optional device id which should not
be deleted
Returns:
defer.Deferred:
user_id: The user to remove all devices from
except_device_id: optional device id which should not be deleted
"""
device_map = yield self.store.get_devices_by_user(user_id)
device_map = await self.store.get_devices_by_user(user_id)
device_ids = list(device_map)
if except_device_id is not None:
device_ids = [d for d in device_ids if d != except_device_id]
yield self.delete_devices(user_id, device_ids)
await self.delete_devices(user_id, device_ids)
@defer.inlineCallbacks
def delete_devices(self, user_id, device_ids):
async def delete_devices(self, user_id: str, device_ids: List[str]) -> None:
""" Delete several devices
Args:
user_id (str):
device_ids (List[str]): The list of device IDs to delete
Returns:
defer.Deferred:
user_id: The user to delete devices from.
device_ids: The list of device IDs to delete
"""
try:
yield self.store.delete_devices(user_id, device_ids)
await self.store.delete_devices(user_id, device_ids)
except errors.StoreError as e:
if e.code == 404:
# no match
@@ -397,28 +377,22 @@ class DeviceHandler(DeviceWorkerHandler):
# Delete access tokens and e2e keys for each device. Not optimised as it is not
# considered as part of a critical path.
for device_id in device_ids:
yield defer.ensureDeferred(
self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
await self._auth_handler.delete_access_tokens_for_user(
user_id, device_id=device_id
)
yield self.store.delete_e2e_keys_by_device(
await self.store.delete_e2e_keys_by_device(
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, device_ids)
await self.notify_device_update(user_id, device_ids)
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
async def update_device(self, user_id: str, device_id: str, content: dict) -> None:
""" Update the given device
Args:
user_id (str):
device_id (str):
content (dict): body of update request
Returns:
defer.Deferred:
user_id: The user to update devices of.
device_id: The device to update.
content: body of update request
"""
# Reject a new displayname which is too long.
@@ -431,10 +405,10 @@ class DeviceHandler(DeviceWorkerHandler):
)
try:
yield self.store.update_device(
await self.store.update_device(
user_id, device_id, new_display_name=new_display_name
)
yield self.notify_device_update(user_id, [device_id])
await self.notify_device_update(user_id, [device_id])
except errors.StoreError as e:
if e.code == 404:
raise errors.NotFoundError()
@@ -443,12 +417,15 @@ class DeviceHandler(DeviceWorkerHandler):
@trace
@measure_func("notify_device_update")
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
async def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
if not device_ids:
# No changes to notify about, so this is a no-op.
return
users_who_share_room = await self.store.get_users_who_share_room_with_user(
user_id
)
@@ -459,20 +436,24 @@ class DeviceHandler(DeviceWorkerHandler):
set_tag("target_hosts", hosts)
position = yield self.store.add_device_change_to_streams(
position = await self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
if not position:
# This should only happen if there are no updates, so we bail.
return
for device_id in device_ids:
logger.debug(
"Notifying about update %r/%r, ID: %r", user_id, device_id, position
)
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
# specify the user ID too since the user should always get their own device list
# updates, even if they aren't in any rooms.
yield self.notifier.on_new_event(
self.notifier.on_new_event(
"device_list_key", position, users=[user_id], rooms=room_ids
)
@@ -484,29 +465,29 @@ class DeviceHandler(DeviceWorkerHandler):
self.federation_sender.send_device_messages(host)
log_kv({"message": "sent device update to host", "host": host})
@defer.inlineCallbacks
def notify_user_signature_update(self, from_user_id, user_ids):
async def notify_user_signature_update(
self, from_user_id: str, user_ids: List[str]
) -> None:
"""Notify a user that they have made new signatures of other users.
Args:
from_user_id (str): the user who made the signature
user_ids (list[str]): the users IDs that have new signatures
from_user_id: the user who made the signature
user_ids: the users IDs that have new signatures
"""
position = yield self.store.add_user_signature_change_to_streams(
position = await self.store.add_user_signature_change_to_streams(
from_user_id, user_ids
)
self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
async def user_left_room(self, user, room_id):
user_id = user.to_string()
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
@@ -549,8 +530,7 @@ class DeviceListUpdater(object):
)
@trace
@defer.inlineCallbacks
def incoming_device_list_update(self, origin, edu_content):
async def incoming_device_list_update(self, origin, edu_content):
"""Called on incoming device list update from federation. Responsible
for parsing the EDU and adding to pending updates list.
"""
@@ -583,7 +563,7 @@ class DeviceListUpdater(object):
)
return
room_ids = yield self.store.get_rooms_for_user(user_id)
room_ids = await self.store.get_rooms_for_user(user_id)
if not room_ids:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
@@ -608,14 +588,13 @@ class DeviceListUpdater(object):
(device_id, stream_id, prev_ids, edu_content)
)
yield self._handle_device_updates(user_id)
await self._handle_device_updates(user_id)
@measure_func("_incoming_device_list_update")
@defer.inlineCallbacks
def _handle_device_updates(self, user_id):
async def _handle_device_updates(self, user_id):
"Actually handle pending updates."
with (yield self._remote_edu_linearizer.queue(user_id)):
with (await self._remote_edu_linearizer.queue(user_id)):
pending_updates = self._pending_updates.pop(user_id, [])
if not pending_updates:
# This can happen since we batch updates
@@ -632,7 +611,7 @@ class DeviceListUpdater(object):
# Given a list of updates we check if we need to resync. This
# happens if we've missed updates.
resync = yield self._need_to_do_resync(user_id, pending_updates)
resync = await self._need_to_do_resync(user_id, pending_updates)
if logger.isEnabledFor(logging.INFO):
logger.info(
@@ -643,16 +622,16 @@ class DeviceListUpdater(object):
)
if resync:
yield self.user_device_resync(user_id)
await self.user_device_resync(user_id)
else:
# Simply update the single device, since we know that is the only
# change (because of the single prev_id matching the current cache)
for device_id, stream_id, prev_ids, content in pending_updates:
yield self.store.update_remote_device_list_cache_entry(
await self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id
)
yield self.device_handler.notify_device_update(
await self.device_handler.notify_device_update(
user_id, [device_id for device_id, _, _, _ in pending_updates]
)
@@ -660,14 +639,13 @@ class DeviceListUpdater(object):
stream_id for _, stream_id, _, _ in pending_updates
)
@defer.inlineCallbacks
def _need_to_do_resync(self, user_id, updates):
async def _need_to_do_resync(self, user_id, updates):
"""Given a list of updates for a user figure out if we need to do a full
resync, or whether we have enough data that we can just apply the delta.
"""
seen_updates = self._seen_updates.get(user_id, set())
extremity = yield self.store.get_device_list_last_stream_id_for_remote(user_id)
extremity = await self.store.get_device_list_last_stream_id_for_remote(user_id)
logger.debug("Current extremity for %r: %r", user_id, extremity)
@@ -692,8 +670,7 @@ class DeviceListUpdater(object):
return False
@trace
@defer.inlineCallbacks
def _maybe_retry_device_resync(self):
async def _maybe_retry_device_resync(self):
"""Retry to resync device lists that are out of sync, except if another retry is
in progress.
"""
@@ -705,12 +682,12 @@ class DeviceListUpdater(object):
# we don't send too many requests.
self._resync_retry_in_progress = True
# Get all of the users that need resyncing.
need_resync = yield self.store.get_user_ids_requiring_device_list_resync()
need_resync = await self.store.get_user_ids_requiring_device_list_resync()
# Iterate over the set of user IDs.
for user_id in need_resync:
try:
# Try to resync the current user's devices list.
result = yield self.user_device_resync(
result = await self.user_device_resync(
user_id=user_id, mark_failed_as_stale=False,
)
@@ -734,16 +711,17 @@ class DeviceListUpdater(object):
# Allow future calls to retry resyncinc out of sync device lists.
self._resync_retry_in_progress = False
@defer.inlineCallbacks
def user_device_resync(self, user_id, mark_failed_as_stale=True):
async def user_device_resync(
self, user_id: str, mark_failed_as_stale: bool = True
) -> Optional[dict]:
"""Fetches all devices for a user and updates the device cache with them.
Args:
user_id (str): The user's id whose device_list will be updated.
mark_failed_as_stale (bool): Whether to mark the user's device list as stale
user_id: The user's id whose device_list will be updated.
mark_failed_as_stale: Whether to mark the user's device list as stale
if the attempt to resync failed.
Returns:
Deferred[dict]: a dict with device info as under the "devices" in the result of this
A dict with device info as under the "devices" in the result of this
request:
https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
"""
@@ -752,12 +730,12 @@ class DeviceListUpdater(object):
# Fetch all devices for the user.
origin = get_domain_from_id(user_id)
try:
result = yield self.federation.query_user_devices(origin, user_id)
result = await self.federation.query_user_devices(origin, user_id)
except NotRetryingDestination:
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
except (RequestSendFailed, HttpResponseException) as e:
@@ -768,7 +746,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
# We abort on exceptions rather than accepting the update
# as otherwise synapse will 'forget' that its device list
@@ -792,7 +770,7 @@ class DeviceListUpdater(object):
if mark_failed_as_stale:
# Mark the remote user's device list as stale so we know we need to retry
# it later.
yield self.store.mark_remote_user_device_cache_as_stale(user_id)
await self.store.mark_remote_user_device_cache_as_stale(user_id)
return
log_kv({"result": result})
@@ -833,25 +811,24 @@ class DeviceListUpdater(object):
stream_id,
)
yield self.store.update_remote_device_list_cache(user_id, devices, stream_id)
await self.store.update_remote_device_list_cache(user_id, devices, stream_id)
device_ids = [device["device_id"] for device in devices]
# Handle cross-signing keys.
cross_signing_device_ids = yield self.process_cross_signing_key_update(
cross_signing_device_ids = await self.process_cross_signing_key_update(
user_id, master_key, self_signing_key,
)
device_ids = device_ids + cross_signing_device_ids
yield self.device_handler.notify_device_update(user_id, device_ids)
await self.device_handler.notify_device_update(user_id, device_ids)
# We clobber the seen updates since we've re-synced from a given
# point.
self._seen_updates[user_id] = {stream_id}
defer.returnValue(result)
return result
@defer.inlineCallbacks
def process_cross_signing_key_update(
async def process_cross_signing_key_update(
self,
user_id: str,
master_key: Optional[Dict[str, Any]],
@@ -872,14 +849,14 @@ class DeviceListUpdater(object):
device_ids = []
if master_key:
yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
await self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
_, verify_key = get_verify_key_from_cross_signing_key(master_key)
# verify_key is a VerifyKey from signedjson, which uses
# .version to denote the portion of the key ID after the
# algorithm and colon, which is the device ID
device_ids.append(verify_key.version)
if self_signing_key:
yield self.store.set_e2e_cross_signing_key(
await self.store.set_e2e_cross_signing_key(
user_id, "self_signing", self_signing_key
)
_, verify_key = get_verify_key_from_cross_signing_key(self_signing_key)
+14 -3
View File
@@ -19,7 +19,7 @@
import itertools
import logging
from collections import Container
from collections.abc import Container
from http import HTTPStatus
from typing import Dict, Iterable, List, Optional, Sequence, Tuple, Union
@@ -44,6 +44,7 @@ from synapse.api.errors import (
FederationDeniedError,
FederationError,
HttpResponseException,
NotFoundError,
RequestSendFailed,
SynapseError,
)
@@ -1442,10 +1443,20 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
event_content = {"membership": Membership.JOIN}
# checking the room version will check that we've actually heard of the room
# (and return a 404 otherwise)
room_version = await self.store.get_room_version_id(room_id)
# now check that we are *still* in the room
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"Got /make_join request for room %s we are no longer in", room_id,
)
raise NotFoundError("Not an active room on this server")
event_content = {"membership": Membership.JOIN}
builder = self.event_builder_factory.new(
room_version,
{
+6 -2
View File
@@ -488,11 +488,15 @@ class EventCreationHandler(object):
try:
if "displayname" not in content:
displayname = yield profile.get_displayname(target)
displayname = yield defer.ensureDeferred(
profile.get_displayname(target)
)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = yield profile.get_avatar_url(target)
avatar_url = yield defer.ensureDeferred(
profile.get_avatar_url(target)
)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
+28 -37
View File
@@ -15,10 +15,8 @@
# limitations under the License.
import logging
from typing import List
from six.moves import range
from signedjson.sign import sign_json
from twisted.internet import defer, reactor
@@ -145,16 +143,15 @@ class BaseProfileHandler(BaseHandler):
)
raise
@defer.inlineCallbacks
def get_profile(self, user_id):
async def get_profile(self, user_id):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -165,7 +162,7 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": user_id},
@@ -177,8 +174,7 @@ class BaseProfileHandler(BaseHandler):
except HttpResponseException as e:
raise e.to_synapse_error()
@defer.inlineCallbacks
def get_profile_from_cache(self, user_id):
async def get_profile_from_cache(self, user_id):
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be corect. Otherwise,
it may be out of date/missing.
@@ -186,10 +182,10 @@ class BaseProfileHandler(BaseHandler):
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -199,14 +195,13 @@ class BaseProfileHandler(BaseHandler):
return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = yield self.store.get_from_remote_profile_cache(user_id)
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
@defer.inlineCallbacks
def get_displayname(self, target_user):
async def get_displayname(self, target_user):
if self.hs.is_mine(target_user):
try:
displayname = yield self.store.get_profile_displayname(
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
except StoreError as e:
@@ -217,7 +212,7 @@ class BaseProfileHandler(BaseHandler):
return displayname
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "displayname"},
@@ -334,11 +329,10 @@ class BaseProfileHandler(BaseHandler):
# start a profile replication push
run_in_background(self._replicate_profiles)
@defer.inlineCallbacks
def get_avatar_url(self, target_user):
async def get_avatar_url(self, target_user):
if self.hs.is_mine(target_user):
try:
avatar_url = yield self.store.get_profile_avatar_url(
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
@@ -348,7 +342,7 @@ class BaseProfileHandler(BaseHandler):
return avatar_url
else:
try:
result = yield self.federation.make_query(
result = await self.federation.make_query(
destination=target_user.domain,
query_type="profile",
args={"user_id": target_user.to_string(), "field": "avatar_url"},
@@ -455,8 +449,7 @@ class BaseProfileHandler(BaseHandler):
raise SynapseError(400, "Invalid avatar URL '%s' supplied" % mxc)
return avatar_pieces[-1]
@defer.inlineCallbacks
def on_profile_query(self, args):
async def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])
if not self.hs.is_mine(user):
raise SynapseError(400, "User is not hosted on this homeserver")
@@ -466,12 +459,12 @@ class BaseProfileHandler(BaseHandler):
response = {}
try:
if just_field is None or just_field == "displayname":
response["displayname"] = yield self.store.get_profile_displayname(
response["displayname"] = await self.store.get_profile_displayname(
user.localpart
)
if just_field is None or just_field == "avatar_url":
response["avatar_url"] = yield self.store.get_profile_avatar_url(
response["avatar_url"] = await self.store.get_profile_avatar_url(
user.localpart
)
except StoreError as e:
@@ -506,8 +499,7 @@ class BaseProfileHandler(BaseHandler):
"Failed to update join event for room %s - %s", room_id, str(e)
)
@defer.inlineCallbacks
def check_profile_query_allowed(self, target_user, requester=None):
async def check_profile_query_allowed(self, target_user, requester=None):
"""Checks whether a profile query is allowed. If the
'require_auth_for_profile_requests' config flag is set to True and a
'requester' is provided, the query is only allowed if the two users
@@ -539,8 +531,8 @@ class BaseProfileHandler(BaseHandler):
return
try:
requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = yield self.store.get_rooms_for_user(
requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
target_user_rooms = await self.store.get_rooms_for_user(
target_user.to_string()
)
@@ -573,25 +565,24 @@ class MasterProfileHandler(BaseProfileHandler):
"Update remote profile", self._update_remote_profile_cache
)
@defer.inlineCallbacks
def _update_remote_profile_cache(self):
async def _update_remote_profile_cache(self):
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id
)
if not is_subscribed:
yield self.store.maybe_delete_remote_profile_cache(user_id)
await self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
profile = yield self.federation.make_query(
profile = await self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={"user_id": user_id},
@@ -600,7 +591,7 @@ class MasterProfileHandler(BaseProfileHandler):
except Exception:
logger.exception("Failed to get avatar_url")
yield self.store.update_remote_profile_cache(
await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
@@ -609,4 +600,4 @@ class MasterProfileHandler(BaseProfileHandler):
new_avatar = profile.get("avatar_url")
# We always hit update to update the last_check timestamp
yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
+6 -10
View File
@@ -14,8 +14,6 @@
# limitations under the License.
import logging
from twisted.internet import defer
from synapse.handlers._base import BaseHandler
from synapse.types import ReadReceipt, get_domain_from_id
from synapse.util.async_helpers import maybe_awaitable
@@ -129,15 +127,14 @@ class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events(self, from_key, room_ids, **kwargs):
async def get_new_events(self, from_key, room_ids, **kwargs):
from_key = int(from_key)
to_key = yield self.get_current_key()
to_key = self.get_current_key()
if from_key == to_key:
return [], to_key
events = yield self.store.get_linearized_receipts_for_rooms(
events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
@@ -146,8 +143,7 @@ class ReceiptEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
async def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
@@ -155,8 +151,8 @@ class ReceiptEventSource(object):
else:
from_key = None
room_ids = yield self.store.get_rooms_for_user(user.to_string())
events = yield self.store.get_linearized_receipts_for_rooms(
room_ids = await self.store.get_rooms_for_user(user.to_string())
events = await self.store.get_linearized_receipts_for_rooms(
room_ids, from_key=from_key, to_key=to_key
)
+1 -9
View File
@@ -28,7 +28,6 @@ from synapse.replication.http.register import (
)
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
from ._base import BaseHandler
@@ -51,14 +50,7 @@ class RegistrationHandler(BaseHandler):
self.http_client = hs.get_simple_http_client()
self.identity_handler = self.hs.get_handlers().identity_handler
self.ratelimiter = hs.get_registration_ratelimiter()
self._next_generated_user_id = None
self.macaroon_gen = hs.get_macaroon_generator()
self._generate_user_id_linearizer = Linearizer(
name="_generate_user_id_linearizer"
)
self._server_notices_mxid = hs.config.server_notices_mxid
self._show_in_user_directory = self.hs.config.show_users_in_user_directory
@@ -239,7 +231,7 @@ class RegistrationHandler(BaseHandler):
if fail_count > 10:
raise SynapseError(500, "Unable to find a suitable guest user ID")
localpart = await self._generate_user_id()
localpart = await self.store.generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
self.check_user_id_not_appservice_exclusive(user_id)
+29 -33
View File
@@ -20,12 +20,10 @@ from typing import Any, Dict, Optional
import msgpack
from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, HttpResponseException
from synapse.types import ThirdPartyInstanceID
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
from synapse.util.caches.response_cache import ResponseCache
from ._base import BaseHandler
@@ -47,7 +45,7 @@ class RoomListHandler(BaseHandler):
hs, "remote_room_list", timeout_ms=30 * 1000
)
def get_local_public_room_list(
async def get_local_public_room_list(
self,
limit=None,
since_token=None,
@@ -72,7 +70,7 @@ class RoomListHandler(BaseHandler):
API
"""
if not self.enable_room_list_search:
return defer.succeed({"chunk": [], "total_room_count_estimate": 0})
return {"chunk": [], "total_room_count_estimate": 0}
logger.info(
"Getting public room list: limit=%r, since=%r, search=%r, network=%r",
@@ -87,7 +85,7 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
return self._get_public_room_list(
return await self._get_public_room_list(
limit,
since_token,
search_filter,
@@ -96,7 +94,7 @@ class RoomListHandler(BaseHandler):
)
key = (limit, since_token, network_tuple)
return self.response_cache.wrap(
return await self.response_cache.wrap(
key,
self._get_public_room_list,
limit,
@@ -105,8 +103,7 @@ class RoomListHandler(BaseHandler):
from_federation=from_federation,
)
@defer.inlineCallbacks
def _get_public_room_list(
async def _get_public_room_list(
self,
limit: Optional[int] = None,
since_token: Optional[str] = None,
@@ -145,7 +142,7 @@ class RoomListHandler(BaseHandler):
# we request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
results = yield self.store.get_largest_public_rooms(
results = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
@@ -221,44 +218,44 @@ class RoomListHandler(BaseHandler):
response["chunk"] = results
response["total_room_count_estimate"] = yield self.store.count_public_rooms(
response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple, ignore_non_federatable=from_federation
)
return response
@cachedInlineCallbacks(num_args=1, cache_context=True)
def generate_room_entry(
@cached(num_args=1, cache_context=True)
async def generate_room_entry(
self,
room_id,
num_joined_users,
room_id: str,
num_joined_users: int,
cache_context,
with_alias=True,
allow_private=False,
):
with_alias: bool = True,
allow_private: bool = False,
) -> Optional[dict]:
"""Returns the entry for a room
Args:
room_id (str): The room's ID.
num_joined_users (int): Number of users in the room.
room_id: The room's ID.
num_joined_users: Number of users in the room.
cache_context: Information for cached responses.
with_alias (bool): Whether to return the room's aliases in the result.
allow_private (bool): Whether invite-only rooms should be shown.
with_alias: Whether to return the room's aliases in the result.
allow_private: Whether invite-only rooms should be shown.
Returns:
Deferred[dict|None]: Returns a room entry as a dictionary, or None if this
Returns a room entry as a dictionary, or None if this
room was determined not to be shown publicly.
"""
result = {"room_id": room_id, "num_joined_members": num_joined_users}
if with_alias:
aliases = yield self.store.get_aliases_for_room(
aliases = await self.store.get_aliases_for_room(
room_id, on_invalidate=cache_context.invalidate
)
if aliases:
result["aliases"] = aliases
current_state_ids = yield self.store.get_current_state_ids(
current_state_ids = await self.store.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
@@ -266,7 +263,7 @@ class RoomListHandler(BaseHandler):
# We're not in the room, so may as well bail out here.
return result
event_map = yield self.store.get_events(
event_map = await self.store.get_events(
[
event_id
for key, event_id in current_state_ids.items()
@@ -336,8 +333,7 @@ class RoomListHandler(BaseHandler):
return result
@defer.inlineCallbacks
def get_remote_public_room_list(
async def get_remote_public_room_list(
self,
server_name,
limit=None,
@@ -356,7 +352,7 @@ class RoomListHandler(BaseHandler):
# to a locally-filtered search if we must.
try:
res = yield self._get_remote_list_cached(
res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -381,7 +377,7 @@ class RoomListHandler(BaseHandler):
limit = None
since_token = None
res = yield self._get_remote_list_cached(
res = await self._get_remote_list_cached(
server_name,
limit=limit,
since_token=since_token,
@@ -400,7 +396,7 @@ class RoomListHandler(BaseHandler):
return res
def _get_remote_list_cached(
async def _get_remote_list_cached(
self,
server_name,
limit=None,
@@ -412,7 +408,7 @@ class RoomListHandler(BaseHandler):
repl_layer = self.hs.get_federation_client()
if search_filter:
# We can't cache when asking for search
return repl_layer.get_public_rooms(
return await repl_layer.get_public_rooms(
server_name,
limit=limit,
since_token=since_token,
@@ -428,7 +424,7 @@ class RoomListHandler(BaseHandler):
include_all_networks,
third_party_instance_id,
)
return self.remote_response_cache.wrap(
return await self.remote_response_cache.wrap(
key,
repl_layer.get_public_rooms,
server_name,
+164 -71
View File
@@ -15,15 +15,19 @@
import logging
from collections import namedtuple
from typing import List, Tuple
from typing import TYPE_CHECKING, List, Set, Tuple
from synapse.api.errors import AuthError, SynapseError
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -39,48 +43,48 @@ FEDERATION_TIMEOUT = 60 * 1000
FEDERATION_PING_INTERVAL = 40 * 1000
class TypingHandler(object):
def __init__(self, hs):
class FollowerTypingHandler:
"""A typing handler on a different process than the writer that is updated
via replication.
"""
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.server_name = hs.config.server_name
self.auth = hs.get_auth()
self.is_mine_id = hs.is_mine_id
self.notifier = hs.get_notifier()
self.state = hs.get_state_handler()
self.hs = hs
self.clock = hs.get_clock()
self.wheel_timer = WheelTimer(bucket_size=5000)
self.is_mine_id = hs.is_mine_id
self.federation = hs.get_federation_sender()
self.federation = None
if hs.should_send_federation():
self.federation = hs.get_federation_sender()
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
if hs.config.worker.writers.typing != hs.get_instance_name():
hs.get_federation_registry().register_instance_for_edu(
"m.typing", hs.config.worker.writers.typing,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
self._member_typing_until = {} # clock time we expect to stop
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self._reset()
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
"""
Reset the typing handler's data caches.
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
self._room_serials = {}
# map room IDs to sets of users currently typing
self._room_typing = {}
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
def _handle_timeouts(self):
logger.debug("Checking for typing timeouts")
@@ -89,30 +93,140 @@ class TypingHandler(object):
members = set(self.wheel_timer.fetch(now))
for member in members:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
continue
self._handle_timeout_for_member(now, member)
until = self._member_typing_until.get(member, None)
if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id)
self._stopped_typing(member)
continue
def _handle_timeout_for_member(self, now: int, member: RoomMember):
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
# Check if we need to resend a keep alive over federation for this
# user.
if self.hs.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
run_in_background(self._push_remote, member=member, typing=True)
# Check if we need to resend a keep alive over federation for this
# user.
if self.federation and self.is_mine_id(member.user_id):
last_fed_poke = self._member_last_federation_poke.get(member, None)
if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now:
run_as_background_process(
"typing._push_remote", self._push_remote, member=member, typing=True
)
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
# Add a paranoia timer to ensure that we always have a timer for
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
return member.user_id in self._room_typing.get(member.room_id, [])
async def _push_remote(self, member, typing):
if not self.federation:
return
try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
self.wheel_timer.insert(
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
for domain in {get_domain_from_id(u) for u in users}:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": member.room_id,
"user_id": member.user_id,
"typing": typing,
},
key=member,
)
except Exception:
logger.exception("Error pushing typing notif to remotes")
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
"""Should be called whenever we receive updates for typing stream.
"""
if self._latest_room_serial > token:
# The master has gone backwards. To prevent inconsistent data, just
# clear everything.
self._reset()
# Set the latest serial token to whatever the server gave us.
self._latest_room_serial = token
for row in rows:
self._room_serials[row.room_id] = token
prev_typing = set(self._room_typing.get(row.room_id, []))
now_typing = set(row.user_ids)
self._room_typing[row.room_id] = row.user_ids
run_as_background_process(
"_handle_change_in_typing",
self._handle_change_in_typing,
row.room_id,
prev_typing,
now_typing,
)
async def _handle_change_in_typing(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
):
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
for user_id in now_typing - prev_typing:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), True)
for user_id in prev_typing - now_typing:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
def get_current_token(self):
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
def __init__(self, hs):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.hs = hs
hs.get_federation_registry().register_edu_handler("m.typing", self._recv_edu)
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
def _handle_timeout_for_member(self, now: int, member: RoomMember):
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
until = self._member_typing_until.get(member, None)
if not until or until <= now:
logger.info("Timing out typing for: %s", member.user_id)
self._stopped_typing(member)
return
async def started_typing(self, target_user, auth_user, room_id, timeout):
target_user_id = target_user.to_string()
auth_user_id = auth_user.to_string()
@@ -179,35 +293,11 @@ class TypingHandler(object):
def _push_update(self, member, typing):
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_in_background(self._push_remote, member, typing)
self._push_update_local(member=member, typing=typing)
async def _push_remote(self, member, typing):
try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
self.wheel_timer.insert(
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
run_as_background_process(
"typing._push_remote", self._push_remote, member, typing
)
for domain in {get_domain_from_id(u) for u in users}:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
destination=domain,
edu_type="m.typing",
content={
"room_id": member.room_id,
"user_id": member.user_id,
"typing": typing,
},
key=member,
)
except Exception:
logger.exception("Error pushing typing notif to remotes")
self._push_update_local(member=member, typing=typing)
async def _recv_edu(self, origin, content):
room_id = content["room_id"]
@@ -304,8 +394,11 @@ class TypingHandler(object):
return rows, current_id, limited
def get_current_token(self):
return self._latest_room_serial
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource(object):
+42 -25
View File
@@ -733,37 +733,54 @@ def trace(func=None, opname=None):
_opname = opname if opname else func.__name__
@wraps(func)
def _trace_inner(*args, **kwargs):
if opentracing is None:
return func(*args, **kwargs)
if inspect.iscoroutinefunction(func):
scope = start_active_span(_opname)
scope.__enter__()
@wraps(func)
async def _trace_inner(*args, **kwargs):
if opentracing is None:
return await func(*args, **kwargs)
try:
result = func(*args, **kwargs)
if isinstance(result, defer.Deferred):
def call_back(result):
scope.__exit__(None, None, None)
return result
def err_back(result):
with start_active_span(_opname) as scope:
try:
return await func(*args, **kwargs)
except Exception:
scope.span.set_tag(tags.ERROR, True)
raise
else:
# The other case here handles both sync functions and those
# decorated with inlineDeferred.
@wraps(func)
def _trace_inner(*args, **kwargs):
if opentracing is None:
return func(*args, **kwargs)
scope = start_active_span(_opname)
scope.__enter__()
try:
result = func(*args, **kwargs)
if isinstance(result, defer.Deferred):
def call_back(result):
scope.__exit__(None, None, None)
return result
def err_back(result):
scope.span.set_tag(tags.ERROR, True)
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
else:
scope.__exit__(None, None, None)
return result
result.addCallbacks(call_back, err_back)
return result
else:
scope.__exit__(None, None, None)
return result
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise
except Exception as e:
scope.__exit__(type(e), None, e.__traceback__)
raise
return _trace_inner
-126
View File
@@ -14,9 +14,7 @@
# limitations under the License.
import inspect
import logging
import time
from functools import wraps
from inspect import getcallargs
@@ -74,127 +72,3 @@ def log_function(f):
wrapped.__name__ = func_name
return wrapped
def time_function(f):
func_name = f.__name__
@wraps(f)
def wrapped(*args, **kwargs):
global _TIME_FUNC_ID
id = _TIME_FUNC_ID
_TIME_FUNC_ID += 1
start = time.clock()
try:
_log_debug_as_f(f, "[FUNC START] {%s-%d}", (func_name, id))
r = f(*args, **kwargs)
finally:
end = time.clock()
_log_debug_as_f(
f, "[FUNC END] {%s-%d} %.3f sec", (func_name, id, end - start)
)
return r
return wrapped
def trace_function(f):
func_name = f.__name__
linenum = f.func_code.co_firstlineno
pathname = f.func_code.co_filename
@wraps(f)
def wrapped(*args, **kwargs):
name = f.__module__
logger = logging.getLogger(name)
level = logging.DEBUG
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back
to_print = [
"\t%s:%s %s. Args: args=%s, kwargs=%s"
% (pathname, linenum, func_name, args, kwargs)
]
while s:
if True or s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_print.append(
"\t%s:%d %s. Args: %s" % (filename, lineno, function, args_string)
)
s = s.f_back
msg = "\nTraceback for %s:\n" % (func_name,) + "\n".join(to_print)
record = logging.LogRecord(
name=name,
level=level,
pathname=pathname,
lineno=lineno,
msg=msg,
args=(),
exc_info=None,
)
logger.handle(record)
return f(*args, **kwargs)
wrapped.__name__ = func_name
return wrapped
def get_previous_frames():
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
to_return = []
while s:
if s.f_globals["__name__"].startswith("synapse"):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
to_return.append(
"{{ %s:%d %s - Args: %s }}" % (filename, lineno, function, args_string)
)
s = s.f_back
return ", ".join(to_return)
def get_previous_frame(ignore=[]):
frame = inspect.currentframe()
if frame is None:
raise Exception("Can't get current frame!")
s = frame.f_back.f_back
while s:
if s.f_globals["__name__"].startswith("synapse"):
if not any(s.f_globals["__name__"].startswith(ig) for ig in ignore):
filename, lineno, function, _, _ = inspect.getframeinfo(s)
args_string = inspect.formatargvalues(*inspect.getargvalues(s))
return "{{ %s:%d %s - Args: %s }}" % (
filename,
lineno,
function,
args_string,
)
s = s.f_back
return None
+42 -36
View File
@@ -15,13 +15,12 @@
# limitations under the License.
import logging
from collections import defaultdict
from threading import Lock
from typing import Dict, Tuple, Union
from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@@ -29,9 +28,18 @@ from synapse.push.httppusher import HttpPusher
from synapse.push.pusher import PusherFactory
from synapse.util.async_helpers import concurrently_execute
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
synapse_pushers = Gauge(
"synapse_pushers", "Number of active synapse pushers", ["kind", "app_id"]
)
class PusherPool:
"""
The pusher pool. This is responsible for dispatching notifications of new events to
@@ -47,36 +55,20 @@ class PusherPool:
Pusher.on_new_receipts are not expected to return deferreds.
"""
def __init__(self, _hs):
self.hs = _hs
self.pusher_factory = PusherFactory(_hs)
self._should_start_pushers = _hs.config.start_pushers
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.pusher_factory = PusherFactory(hs)
self._should_start_pushers = hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
# We shard the handling of push notifications by user ID.
self._pusher_shard_config = hs.config.push.pusher_shard_config
self._instance_name = hs.get_instance_name()
# map from user id to app_id:pushkey to pusher
self.pushers = {} # type: Dict[str, Dict[str, Union[HttpPusher, EmailPusher]]]
# a lock for the pushers dict, since `count_pushers` is called from an different
# and we otherwise get concurrent modification errors
self._pushers_lock = Lock()
def count_pushers():
results = defaultdict(int) # type: Dict[Tuple[str, str], int]
with self._pushers_lock:
for pushers in self.pushers.values():
for pusher in pushers.values():
k = (type(pusher).__name__, pusher.app_id)
results[k] += 1
return results
LaterGauge(
name="synapse_pushers",
desc="the number of active pushers",
labels=["kind", "app_id"],
caller=count_pushers,
)
def start(self):
"""Starts the pushers off in a background process.
"""
@@ -104,6 +96,7 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
time_now_msec = self.clock.time_msec()
# we try to create the pusher just to validate the config: it
@@ -176,6 +169,9 @@ class PusherPool:
access_tokens (Iterable[int]): access token *ids* to remove pushers
for
"""
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
if p["access_token"] in tokens:
@@ -237,6 +233,9 @@ class PusherPool:
if not self._should_start_pushers:
return
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
@@ -275,6 +274,11 @@ class PusherPool:
Returns:
Deferred[EmailPusher|HttpPusher]
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
):
return
try:
p = self.pusher_factory.create_pusher(pusherdict)
except PusherConfigException as e:
@@ -298,11 +302,12 @@ class PusherPool:
appid_pushkey = "%s:%s" % (pusherdict["app_id"], pusherdict["pushkey"])
with self._pushers_lock:
byuser = self.pushers.setdefault(pusherdict["user_name"], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
byuser = self.pushers.setdefault(pusherdict["user_name"], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
synapse_pushers.labels(type(p).__name__, p.app_id).inc()
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
@@ -330,9 +335,10 @@ class PusherPool:
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop()
with self._pushers_lock:
del byuser[appid_pushkey]
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
@@ -26,7 +26,7 @@ class SlavedDeviceInboxStore(DeviceInboxWorkerStore, BaseSlavedStore):
def __init__(self, database: Database, db_conn, hs):
super(SlavedDeviceInboxStore, self).__init__(database, db_conn, hs)
self._device_inbox_id_gen = SlavedIdTracker(
db_conn, "device_max_stream_id", "stream_id"
db_conn, "device_inbox", "stream_id"
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache",
+220 -112
View File
@@ -14,9 +14,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, Iterable, Iterator, List, Optional, Set, Tuple, TypeVar
from typing import (
Any,
Dict,
Iterable,
Iterator,
List,
Optional,
Set,
Tuple,
TypeVar,
Union,
)
from prometheus_client import Counter
from typing_extensions import Deque
from twisted.internet.protocol import ReconnectingClientFactory
@@ -42,8 +54,8 @@ from synapse.replication.tcp.streams import (
EventsStream,
FederationStream,
Stream,
TypingStream,
)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@@ -61,6 +73,12 @@ invalidate_cache_counter = Counter(
user_ip_cache_counter = Counter("synapse_replication_tcp_resource_user_ip_cache", "")
# the type of the entries in _command_queues_by_stream
_StreamCommandQueue = Deque[
Tuple[Union[RdataCommand, PositionCommand], AbstractConnection]
]
class ReplicationCommandHandler:
"""Handles incoming commands from replication as well as sending commands
back out to connections.
@@ -96,6 +114,14 @@ class ReplicationCommandHandler:
continue
if isinstance(stream, TypingStream):
# Only add TypingStream as a source on the instance in charge of
# typing.
if hs.config.worker.writers.typing == hs.get_instance_name():
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master.
if hs.config.worker_app is not None:
continue
@@ -107,10 +133,6 @@ class ReplicationCommandHandler:
self._streams_to_replicate.append(stream)
self._position_linearizer = Linearizer(
"replication_position", clock=self._clock
)
# Map of stream name to batched updates. See RdataCommand for info on
# how batching works.
self._pending_batches = {} # type: Dict[str, List[Any]]
@@ -122,10 +144,6 @@ class ReplicationCommandHandler:
# outgoing replication commands to.)
self._connections = [] # type: List[AbstractConnection]
# For each connection, the incoming stream names that are coming from
# that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_resource_total_connections",
"",
@@ -133,6 +151,32 @@ class ReplicationCommandHandler:
lambda: len(self._connections),
)
# When POSITION or RDATA commands arrive, we stick them in a queue and process
# them in order in a separate background process.
# the streams which are currently being processed by _unsafe_process_stream
self._processing_streams = set() # type: Set[str]
# for each stream, a queue of commands that are awaiting processing, and the
# connection that they arrived on.
self._command_queues_by_stream = {
stream_name: _StreamCommandQueue() for stream_name in self._streams
}
# For each connection, the incoming stream names that have received a POSITION
# from that connection.
self._streams_by_connection = {} # type: Dict[AbstractConnection, Set[str]]
LaterGauge(
"synapse_replication_tcp_command_queue",
"Number of inbound RDATA/POSITION commands queued for processing",
["stream_name"],
lambda: {
(stream_name,): len(queue)
for stream_name, queue in self._command_queues_by_stream.items()
},
)
self._is_master = hs.config.worker_app is None
self._federation_sender = None
@@ -143,6 +187,64 @@ class ReplicationCommandHandler:
if self._is_master:
self._server_notices_sender = hs.get_server_notices_sender()
async def _add_command_to_stream_queue(
self, conn: AbstractConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
"""Queue the given received command for processing
Adds the given command to the per-stream queue, and processes the queue if
necessary
"""
stream_name = cmd.stream_name
queue = self._command_queues_by_stream.get(stream_name)
if queue is None:
logger.error("Got %s for unknown stream: %s", cmd.NAME, stream_name)
return
# if we're already processing this stream, stick the new command in the
# queue, and we're done.
if stream_name in self._processing_streams:
queue.append((cmd, conn))
return
# otherwise, process the new command.
# arguably we should start off a new background process here, but nothing
# will be too upset if we don't return for ages, so let's save the overhead
# and use the existing logcontext.
self._processing_streams.add(stream_name)
try:
# might as well skip the queue for this one, since it must be empty
assert not queue
await self._process_command(cmd, conn, stream_name)
# now process any other commands that have built up while we were
# dealing with that one.
while queue:
cmd, conn = queue.popleft()
try:
await self._process_command(cmd, conn, stream_name)
except Exception:
logger.exception("Failed to handle command %s", cmd)
finally:
self._processing_streams.discard(stream_name)
async def _process_command(
self,
cmd: Union[PositionCommand, RdataCommand],
conn: AbstractConnection,
stream_name: str,
) -> None:
if isinstance(cmd, PositionCommand):
await self._process_position(stream_name, conn, cmd)
elif isinstance(cmd, RdataCommand):
await self._process_rdata(stream_name, conn, cmd)
else:
# This shouldn't be possible
raise Exception("Unrecognised command %s in stream queue", cmd.NAME)
def start_replication(self, hs):
"""Helper method to start a replication connection to the remote server
using TCP.
@@ -276,63 +378,71 @@ class ReplicationCommandHandler:
stream_name = cmd.stream_name
inbound_rdata_count.labels(stream_name).inc()
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception:
logger.exception("Failed to parse RDATA: %r %r", stream_name, cmd.row)
raise
# We linearize here for two reasons:
# We put the received command into a queue here for two reasons:
# 1. so we don't try and concurrently handle multiple rows for the
# same stream, and
# 2. so we don't race with getting a POSITION command and fetching
# missing RDATA.
with await self._position_linearizer.queue(cmd.stream_name):
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
else:
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
await self._add_command_to_stream_queue(conn, cmd)
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got RDATA for unknown stream: %s", stream_name)
return
async def _process_rdata(
self, stream_name: str, conn: AbstractConnection, cmd: RdataCommand
) -> None:
"""Process an RDATA command
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
Called after the command has been popped off the queue of inbound commands
"""
try:
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
except Exception as e:
raise Exception(
"Failed to parse RDATA: %r %r" % (stream_name, cmd.row)
) from e
# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
# make sure that we've processed a POSITION for this stream *on this
# connection*. (A POSITION on another connection is no good, as there
# is no guarantee that we have seen all the intermediate updates.)
sbc = self._streams_by_connection.get(conn)
if not sbc or stream_name not in sbc:
# Let's drop the row for now, on the assumption we'll receive a
# `POSITION` soon and we'll catch up correctly then.
logger.debug(
"Discarding RDATA for unconnected stream %s -> %s",
stream_name,
cmd.token,
)
return
if cmd.token is None:
# I.e. this is part of a batch of updates for this stream (in
# which case batch until we get an update for the stream with a non
# None token).
self._pending_batches.setdefault(stream_name, []).append(row)
return
# Check if this is the last of a batch of updates
rows = self._pending_batches.pop(stream_name, [])
rows.append(row)
stream = self._streams[stream_name]
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# Discard this data if this token is earlier than the current
# position. Note that streams can be reset (in which case you
# expect an earlier token), but that must be preceded by a
# POSITION command.
if cmd.token <= current_token:
logger.debug(
"Discarding RDATA from stream %s at position %s before previous position %s",
stream_name,
cmd.token,
current_token,
)
else:
await self.on_rdata(stream_name, cmd.instance_name, cmd.token, rows)
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -358,67 +468,65 @@ class ReplicationCommandHandler:
logger.info("Handling '%s %s'", cmd.NAME, cmd.to_line())
stream_name = cmd.stream_name
stream = self._streams.get(stream_name)
if not stream:
logger.error("Got POSITION for unknown stream: %s", stream_name)
return
await self._add_command_to_stream_queue(conn, cmd)
# We protect catching up with a linearizer in case the replication
# connection reconnects under us.
with await self._position_linearizer.queue(stream_name):
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)
async def _process_position(
self, stream_name: str, conn: AbstractConnection, cmd: PositionCommand
) -> None:
"""Process a POSITION command
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])
Called after the command has been popped off the queue of inbound commands
"""
stream = self._streams[stream_name]
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# We're about to go and catch up with the stream, so remove from set
# of connected streams.
for streams in self._streams_by_connection.values():
streams.discard(stream_name)
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(
updates,
current_token,
missing_updates,
) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
# We clear the pending batches for the stream as the fetching of the
# missing updates below will fetch all rows in the batch.
self._pending_batches.pop(stream_name, [])
# TODO: add some tests for this
# Find where we previously streamed up to.
current_token = stream.current_token(cmd.instance_name)
# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
# If the position token matches our current token then we're up to
# date and there's nothing to do. Otherwise, fetch all updates
# between then and now.
missing_updates = cmd.token != current_token
while missing_updates:
logger.info(
"Fetching replication rows for '%s' between %i and %i",
stream_name,
current_token,
cmd.token,
)
(updates, current_token, missing_updates) = await stream.get_updates_since(
cmd.instance_name, current_token, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
# TODO: add some tests for this
# Some streams return multiple rows with the same stream IDs,
# which need to be processed in batches.
for token, rows in _batch_updates(updates):
await self.on_rdata(
stream_name,
cmd.instance_name,
token,
[stream.parse_row(row) for row in rows],
)
logger.info("Caught up with stream '%s' to %i", stream_name, cmd.token)
# We've now caught up to position sent to us, notify handler.
await self._replication_data_handler.on_position(
cmd.stream_name, cmd.instance_name, cmd.token
)
self._streams_by_connection.setdefault(conn, set()).add(stream_name)
async def on_REMOTE_SERVER_UP(
self, conn: AbstractConnection, cmd: RemoteServerUpCommand
+4 -3
View File
@@ -294,11 +294,12 @@ class TypingStream(Stream):
def __init__(self, hs):
typing_handler = hs.get_typing_handler()
if hs.config.worker_app is None:
# on the master, query the typing handler
writer_instance = hs.config.worker.writers.typing
if writer_instance == hs.get_instance_name():
# On the writer, query the typing handler
update_function = typing_handler.get_all_typing_updates
else:
# Query master process
# Query the typing writer process
update_function = make_http_update_function(hs, self.NAME)
super().__init__(
+1 -1
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import heapq
from collections import Iterable
from collections.abc import Iterable
from typing import List, Tuple, Type
import attr
+7
View File
@@ -0,0 +1,7 @@
.header {
border-bottom: 4px solid #e4f7ed ! important;
}
.notif_link a, .footer a {
color: #76CFA6 ! important;
}
+2
View File
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% elif app_name == "Element" %}
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
+2
View File
@@ -22,6 +22,8 @@
<img src="http://riot.im/img/external/riot-logo-email.png" width="83" height="83" alt="[Riot]"/>
{% elif app_name == "Vector" %}
<img src="http://matrix.org/img/vector-logo-email.png" width="64" height="83" alt="[Vector]"/>
{% elif app_name == "Element" %}
<img src="https://static.element.io/images/email-logo.png" width="83" height="83" alt="[Element]"/>
{% else %}
<img src="http://matrix.org/img/matrix-120x51.png" width="120" height="51" alt="[matrix]"/>
{% endif %}
+2
View File
@@ -38,6 +38,7 @@ from synapse.rest.admin.rooms import (
DeleteRoomRestServlet,
JoinRoomAliasServlet,
ListRoomRestServlet,
RoomMembersRestServlet,
RoomRestServlet,
ShutdownRoomRestServlet,
)
@@ -201,6 +202,7 @@ def register_servlets(hs, http_server):
register_servlets_for_client_rest_resource(hs, http_server)
ListRoomRestServlet(hs).register(http_server)
RoomRestServlet(hs).register(http_server)
RoomMembersRestServlet(hs).register(http_server)
DeleteRoomRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
PurgeRoomServlet(hs).register(http_server)
+25
View File
@@ -231,6 +231,31 @@ class RoomRestServlet(RestServlet):
return 200, ret
class RoomMembersRestServlet(RestServlet):
"""
Get members list of a room.
"""
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]+)/members")
def __init__(self, hs):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, room_id):
await assert_requester_is_admin(self.auth, request)
ret = await self.store.get_room(room_id)
if not ret:
raise NotFoundError("Room not found")
members = await self.store.get_users_in_room(room_id)
ret = {"members": members, "total": len(members)}
return 200, ret
class JoinRoomAliasServlet(RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
+9
View File
@@ -818,9 +818,18 @@ class RoomTypingRestServlet(RestServlet):
self.typing_handler = hs.get_typing_handler()
self.auth = hs.get_auth()
# If we're not on the typing writer instance we should scream if we get
# requests.
self._is_typing_writer = (
hs.config.worker.writers.typing == hs.get_instance_name()
)
async def on_PUT(self, request, room_id, user_id):
requester = await self.auth.get_user_by_req(request)
if not self._is_typing_writer:
raise Exception("Got /typing request on instance that is not typing writer")
room_id = urlparse.unquote(room_id)
target_user = UserID.from_string(urlparse.unquote(user_id))
+10 -1
View File
@@ -22,6 +22,7 @@ from twisted.internet import defer
from synapse.api.errors import InteractiveAuthIncompleteError
from synapse.api.urls import CLIENT_API_PREFIX
from synapse.types import JsonDict
logger = logging.getLogger(__name__)
@@ -51,7 +52,15 @@ def client_patterns(path_regex, releases=(0,), unstable=True, v1=False):
return patterns
def set_timeline_upper_limit(filter_json, filter_timeline_limit):
def set_timeline_upper_limit(filter_json: JsonDict, filter_timeline_limit: int) -> None:
"""
Enforces a maximum limit of a timeline query.
Params:
filter_json: The timeline query to modify.
filter_timeline_limit: The maximum limit to allow, passing -1 will
disable enforcing a maximum limit.
"""
if filter_timeline_limit < 0:
return # no upper limits
timeline = filter_json.get("room", {}).get("timeline", {})
+6 -7
View File
@@ -44,7 +44,6 @@ from synapse.federation.federation_client import FederationClient
from synapse.federation.federation_server import (
FederationHandlerRegistry,
FederationServer,
ReplicationFederationHandlerRegistry,
)
from synapse.federation.send_queue import FederationRemoteSendQueue
from synapse.federation.sender import FederationSender
@@ -84,7 +83,7 @@ from synapse.handlers.room_member_worker import RoomMemberWorkerHandler
from synapse.handlers.set_password import SetPasswordHandler
from synapse.handlers.stats import StatsHandler
from synapse.handlers.sync import SyncHandler
from synapse.handlers.typing import TypingHandler
from synapse.handlers.typing import FollowerTypingHandler, TypingWriterHandler
from synapse.handlers.user_directory import UserDirectoryHandler
from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpClient
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
@@ -380,7 +379,10 @@ class HomeServer(object):
return PresenceHandler(self)
def build_typing_handler(self):
return TypingHandler(self)
if self.config.worker.writers.typing == self.get_instance_name():
return TypingWriterHandler(self)
else:
return FollowerTypingHandler(self)
def build_sync_handler(self):
return SyncHandler(self)
@@ -536,10 +538,7 @@ class HomeServer(object):
return RoomMemberMasterHandler(self)
def build_federation_registry(self):
if self.config.worker_app:
return ReplicationFederationHandlerRegistry(self)
else:
return FederationHandlerRegistry()
return FederationHandlerRegistry(self)
def build_server_notices_manager(self):
if self.config.worker_app:
+2
View File
@@ -148,3 +148,5 @@ class HomeServer(object):
self,
) -> synapse.http.matrixfederationclient.MatrixFederationHttpClient:
pass
def should_send_federation(self) -> bool:
pass
+2 -2
View File
@@ -100,8 +100,8 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# Decode it to a Unicode string before feeding it to json.loads, so we
# consistenty get a Unicode-containing object out.
# Decode it to a Unicode string before feeding it to json.loads, since
# Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
+4 -1
View File
@@ -249,7 +249,10 @@ class BackgroundUpdater(object):
retcol="progress_json",
)
progress = json.loads(progress_json)
# Avoid a circular import.
from synapse.storage._base import db_to_json
progress = db_to_json(progress_json)
time_start = self._clock.time_msec()
items_updated = await update_handler(progress, batch_size)
+1 -1
View File
@@ -128,7 +128,7 @@ class DataStore(
db_conn, "presence_stream", "stream_id"
)
self._device_inbox_id_gen = StreamIdGenerator(
db_conn, "device_max_stream_id", "stream_id"
db_conn, "device_inbox", "stream_id"
)
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
@@ -77,7 +77,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
global_account_data = {
row["account_data_type"]: json.loads(row["content"]) for row in rows
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
rows = self.db.simple_select_list_txn(
@@ -90,7 +90,7 @@ class AccountDataWorkerStore(SQLBaseStore):
by_room = {}
for row in rows:
room_data = by_room.setdefault(row["room_id"], {})
room_data[row["account_data_type"]] = json.loads(row["content"])
room_data[row["account_data_type"]] = db_to_json(row["content"])
return global_account_data, by_room
@@ -113,7 +113,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
if result:
return json.loads(result)
return db_to_json(result)
else:
return None
@@ -137,7 +137,7 @@ class AccountDataWorkerStore(SQLBaseStore):
)
return {
row["account_data_type"]: json.loads(row["content"]) for row in rows
row["account_data_type"]: db_to_json(row["content"]) for row in rows
}
return self.db.runInteraction(
@@ -170,7 +170,7 @@ class AccountDataWorkerStore(SQLBaseStore):
allow_none=True,
)
return json.loads(content_json) if content_json else None
return db_to_json(content_json) if content_json else None
return self.db.runInteraction(
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
@@ -255,7 +255,7 @@ class AccountDataWorkerStore(SQLBaseStore):
txn.execute(sql, (user_id, stream_id))
global_account_data = {row[0]: json.loads(row[1]) for row in txn}
global_account_data = {row[0]: db_to_json(row[1]) for row in txn}
sql = (
"SELECT room_id, account_data_type, content FROM room_account_data"
@@ -267,7 +267,7 @@ class AccountDataWorkerStore(SQLBaseStore):
account_data_by_room = {}
for row in txn:
room_account_data = account_data_by_room.setdefault(row[0], {})
room_account_data[row[1]] = json.loads(row[2])
room_account_data[row[1]] = db_to_json(row[2])
return global_account_data, account_data_by_room
@@ -22,7 +22,7 @@ from twisted.internet import defer
from synapse.appservice import AppServiceTransaction
from synapse.config.appservice import load_appservices
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.database import Database
@@ -303,7 +303,7 @@ class ApplicationServiceTransactionWorkerStore(
if not entry:
return None
event_ids = json.loads(entry["event_ids"])
event_ids = db_to_json(entry["event_ids"])
events = yield self.get_events_as_list(event_ids)
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.util.caches.expiringcache import ExpiringCache
@@ -65,7 +65,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
messages.append(db_to_json(row[1]))
if len(messages) < limit:
stream_pos = current_stream_id
return messages, stream_pos
@@ -173,7 +173,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
messages = []
for row in txn:
stream_pos = row[0]
messages.append(json.loads(row[1]))
messages.append(db_to_json(row[1]))
if len(messages) < limit:
log_kv({"message": "Set stream position to current position"})
stream_pos = current_stream_id
@@ -424,9 +424,6 @@ class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore)
def _add_messages_to_local_device_inbox_txn(
self, txn, stream_id, messages_by_user_then_device
):
sql = "UPDATE device_max_stream_id" " SET stream_id = ?" " WHERE stream_id < ?"
txn.execute(sql, (stream_id, stream_id))
local_by_user_then_device = {}
for user_id, messages_by_device in messages_by_user_then_device.items():
messages_json_for_user = {}
+1 -1
View File
@@ -577,7 +577,7 @@ class DeviceWorkerStore(SQLBaseStore):
rows = yield self.db.execute(
"get_users_whose_signatures_changed", None, sql, user_id, from_key
)
return {user for row in rows for user in json.loads(row[0])}
return {user for row in rows for user in db_to_json(row[0])}
else:
return set()
@@ -14,13 +14,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.opentracing import log_kv, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
class EndToEndRoomKeyStore(SQLBaseStore):
@@ -148,7 +148,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"forwarded_count": row["forwarded_count"],
# is_verified must be returned to the client as a boolean
"is_verified": bool(row["is_verified"]),
"session_data": json.loads(row["session_data"]),
"session_data": db_to_json(row["session_data"]),
}
return sessions
@@ -222,7 +222,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
"first_message_index": row[2],
"forwarded_count": row[3],
"is_verified": row[4],
"session_data": json.loads(row[5]),
"session_data": db_to_json(row[5]),
}
return ret
@@ -319,7 +319,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version, "deleted": 0},
retcols=("version", "algorithm", "auth_data", "etag"),
)
result["auth_data"] = json.loads(result["auth_data"])
result["auth_data"] = db_to_json(result["auth_data"])
result["version"] = str(result["version"])
if result["etag"] is None:
result["etag"] = 0
@@ -366,7 +366,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = json.loads(row["keydata"])
key = db_to_json(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import Database
from synapse.util.caches.descriptors import cachedInlineCallbacks
@@ -58,7 +58,7 @@ def _deserialize_action(actions, is_highlight):
"""Custom deserializer for actions. This allows us to "compress" common actions
"""
if actions:
return json.loads(actions)
return db_to_json(actions)
if is_highlight:
return DEFAULT_HIGHLIGHT_ACTION
+4 -5
View File
@@ -20,7 +20,6 @@ from collections import OrderedDict, namedtuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
import attr
from canonicaljson import json
from prometheus_client import Counter
from twisted.internet import defer
@@ -32,7 +31,7 @@ from synapse.crypto.event_signing import compute_event_reference_hash
from synapse.events import EventBase # noqa: F401
from synapse.events.snapshot import EventContext # noqa: F401
from synapse.logging.utils import log_function
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.data_stores.main.search import SearchEntry
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.util.id_generators import StreamIdGenerator
@@ -236,7 +235,7 @@ class PersistEventsStore:
)
txn.execute(sql + clause, args)
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
results.extend(r[0] for r in txn if not db_to_json(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
yield self.db.runInteraction(
@@ -297,7 +296,7 @@ class PersistEventsStore:
if prev_event_id in existing_prevs:
continue
soft_failed = json.loads(metadata).get("soft_failed")
soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
to_recursively_check.append(prev_event_id)
existing_prevs.add(prev_event_id)
@@ -583,7 +582,7 @@ class PersistEventsStore:
txn.execute(sql, (room_id, EventTypes.Create, ""))
row = txn.fetchone()
if row:
event_json = json.loads(row[0])
event_json = db_to_json(row[0])
content = event_json.get("content", {})
creator = content.get("creator")
room_version_id = content.get("room_version", RoomVersions.V1.identifier)
@@ -15,12 +15,10 @@
import logging
from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import EventContentFields
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
logger = logging.getLogger(__name__)
@@ -125,7 +123,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in rows:
try:
event_id = row[1]
event_json = json.loads(row[2])
event_json = db_to_json(row[2])
sender = event_json["sender"]
content = event_json["content"]
@@ -208,7 +206,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
for row in ev_rows:
event_id = row["event_id"]
event_json = json.loads(row["json"])
event_json = db_to_json(row["json"])
try:
origin_server_ts = event_json["origin_server_ts"]
except (KeyError, AttributeError):
@@ -317,7 +315,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
soft_failed = False
if metadata:
soft_failed = json.loads(metadata).get("soft_failed")
soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
@@ -358,7 +356,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
graph[event_id] = {prev_event_id}
soft_failed = json.loads(metadata).get("soft_failed")
soft_failed = db_to_json(metadata).get("soft_failed")
if soft_failed or rejected:
soft_failed_events_to_lookup.add(event_id)
else:
@@ -543,7 +541,7 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
last_row_event_id = ""
for (event_id, event_json_raw) in results:
try:
event_json = json.loads(event_json_raw)
event_json = db_to_json(event_json_raw)
self.db.simple_insert_many_txn(
txn=txn,
@@ -21,7 +21,6 @@ import threading
from collections import namedtuple
from typing import List, Optional, Tuple
from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -40,7 +39,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
@@ -611,8 +610,8 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
d = json.loads(row["json"])
internal_metadata = json.loads(row["internal_metadata"])
d = db_to_json(row["json"])
internal_metadata = db_to_json(row["internal_metadata"])
format_version = row["format_version"]
if format_version is None:
@@ -640,7 +639,7 @@ class EventsWorkerStore(SQLBaseStore):
else:
room_version = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not room_version:
logger.error(
logger.warning(
"Event %s in room %s has unknown room version %s",
event_id,
d["room_id"],
@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
# The category ID for the "default" category. We don't store as null in the
# database to avoid the fun of null != null
@@ -197,7 +197,7 @@ class GroupServerWorkerStore(SQLBaseStore):
categories = {
row[0]: {
"is_public": row[1],
"profile": json.loads(row[2]),
"profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -221,7 +221,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["category_id"]: {
"is_public": row["is_public"],
"profile": json.loads(row["profile"]),
"profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -235,7 +235,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_category",
)
category["profile"] = json.loads(category["profile"])
category["profile"] = db_to_json(category["profile"])
return category
@@ -251,7 +251,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {
row["role_id"]: {
"is_public": row["is_public"],
"profile": json.loads(row["profile"]),
"profile": db_to_json(row["profile"]),
}
for row in rows
}
@@ -265,7 +265,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_group_role",
)
role["profile"] = json.loads(role["profile"])
role["profile"] = db_to_json(role["profile"])
return role
@@ -333,7 +333,7 @@ class GroupServerWorkerStore(SQLBaseStore):
roles = {
row[0]: {
"is_public": row[1],
"profile": json.loads(row[2]),
"profile": db_to_json(row[2]),
"order": row[3],
}
for row in txn
@@ -462,7 +462,7 @@ class GroupServerWorkerStore(SQLBaseStore):
now = int(self._clock.time_msec())
if row and now < row["valid_until_ms"]:
return json.loads(row["attestation_json"])
return db_to_json(row["attestation_json"])
return None
@@ -489,7 +489,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": row[0],
"type": row[1],
"membership": row[2],
"content": json.loads(row[3]),
"content": db_to_json(row[3]),
}
for row in txn
]
@@ -519,7 +519,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"group_id": group_id,
"membership": membership,
"type": gtype,
"content": json.loads(content_json),
"content": db_to_json(content_json),
}
for group_id, membership, gtype, content_json in txn
]
@@ -567,7 +567,7 @@ class GroupServerWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [
(stream_id, (group_id, user_id, gtype, json.loads(content_json)))
(stream_id, (group_id, user_id, gtype, db_to_json(content_json)))
for stream_id, group_id, user_id, gtype, content_json in txn
]
@@ -24,7 +24,7 @@ from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.data_stores.main.appservice import ApplicationServiceWorkerStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.pusher import PusherWorkerStore
@@ -43,8 +43,8 @@ def _load_rules(rawrules, enabled_map):
ruleslist = []
for rawrule in rawrules:
rule = dict(rawrule)
rule["conditions"] = json.loads(rawrule["conditions"])
rule["actions"] = json.loads(rawrule["actions"])
rule["conditions"] = db_to_json(rawrule["conditions"])
rule["actions"] = db_to_json(rawrule["actions"])
rule["default"] = False
ruleslist.append(rule)
+3 -3
View File
@@ -17,11 +17,11 @@
import logging
from typing import Iterable, Iterator, List, Tuple
from canonicaljson import encode_canonical_json, json
from canonicaljson import encode_canonical_json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cachedInlineCallbacks, cachedList
logger = logging.getLogger(__name__)
@@ -36,7 +36,7 @@ class PusherWorkerStore(SQLBaseStore):
for r in rows:
dataJson = r["data"]
try:
r["data"] = json.loads(dataJson)
r["data"] = db_to_json(dataJson)
except Exception as e:
logger.warning(
"Invalid JSON in data for pusher %d: %s, %s",
+4 -4
View File
@@ -22,7 +22,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.async_helpers import ObservableDeferred
@@ -203,7 +203,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
for row in rows:
content.setdefault(row["event_id"], {}).setdefault(row["receipt_type"], {})[
row["user_id"]
] = json.loads(row["data"])
] = db_to_json(row["data"])
return [{"type": "m.receipt", "room_id": room_id, "content": content}]
@@ -260,7 +260,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
event_entry = room_event["content"].setdefault(row["event_id"], {})
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = json.loads(row["data"])
receipt_type[row["user_id"]] = db_to_json(row["data"])
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -329,7 +329,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""
txn.execute(sql, (last_id, current_id, limit))
updates = [(r[0], r[1:5] + (json.loads(r[5]),)) for r in txn]
updates = [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn]
limited = False
upper_bound = current_id

Some files were not shown because too many files have changed in this diff Show More