
Merge commit '837293c31' into anoa/dinsic_release_1_21_x

* commit '837293c31':
  Remove obsolete __future__ imports (#8337)
  Use admin_patterns for all admin APIs. (#8331)
  Fix a potential bug of UnboundLocalError (#8329)
  Switch metaclass initialization to python 3-compatible syntax (#8326)
  Catch-up after Federation Outage (split, 4): catch-up loop (#8272)
  Use slots in attrs classes where possible (#8296)
  Fix typos in comments.
  Add the topic and avatar to the room details admin API (#8305)
  Improve SAML error messages (#8248)
  Add experimental support for sharding event persister. Again. (#8294)
  Make `StreamToken.room_key` be a `RoomStreamToken` instance. (#8281)
  Use TLSv1.2 for fake servers in tests (#8208)
  Add /_synapse/client to the reverse proxy docs (#8227)
  Clean up `Notifier.on_new_room_event` code path (#8288)
Andrew Morgan
2020-10-20 19:51:28 +01:00
113 changed files with 990 additions and 593 deletions
-32
@@ -128,38 +128,6 @@ template. These templates are similar, but the parameters are slightly different
* A string ``error`` parameter is available that includes a short hint of why a
user is seeing the error page.
ThirdPartyEventRules breaking changes
-------------------------------------
This release introduces a backwards-incompatible change to modules making use of
`ThirdPartyEventRules` in Synapse.
The `http_client` argument is no longer passed to modules as they are initialised. Instead,
modules are expected to make use of the `http_client` property on the `ModuleApi` class.
Modules are now passed a `module_api` argument during initialisation, which is an instance of
`ModuleApi`.
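As a rough sketch of what this means for module code (the class and attribute names here are illustrative, not taken from this commit), a module's initialiser now looks something like:

```python
class ExampleRulesModule:
    def __init__(self, config, module_api):
        # Modules now receive a ModuleApi instance at initialisation,
        # rather than a bare http_client argument.
        self.module_api = module_api
        # The HTTP client is available as a property on ModuleApi.
        self.http_client = module_api.http_client
```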
New HTML templates
------------------
A new HTML template,
`password_reset_confirmation.html <https://github.com/matrix-org/synapse/blob/develop/synapse/res/templates/password_reset_confirmation.html>`_,
has been added to the ``synapse/res/templates`` directory. If you are using a
custom template directory, you may want to copy the template over and modify it.
Note that as of v1.20.0, templates do not need to be included in custom template
directories for Synapse to start. The default templates will be used if a custom
template cannot be found.
This page will appear to the user after clicking a password reset link that has
been emailed to them.
To complete password reset, the page must include a way to make a `POST`
request to
``/_synapse/client/password_reset/{medium}/submit_token``
with the query parameters from the original link, presented as a URL-encoded form. See the file
itself for more details.
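For illustration only (this snippet is not part of the template), the request the page needs to produce is equivalent to something like the following, assuming the emailed link carried ``token``, ``client_secret`` and ``sid`` query parameters:

```python
import requests

# Placeholder values: in the real template these come from the query
# string of the password reset link the user clicked.
params = {"token": "...", "client_secret": "...", "sid": "..."}

requests.post(
    "https://matrix.example.com/_synapse/client/password_reset/email/submit_token",
    data=params,  # presented as a URL-encoded form body
)
```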
Upgrading to v1.18.0
====================
+1
@@ -0,0 +1 @@
Fix tests on distros which disable TLSv1.0. Contributed by @danc86.
+1
@@ -0,0 +1 @@
Add `/_synapse/client` to the reverse proxy documentation.
+1
@@ -0,0 +1 @@
Consolidate the SSO error template across all configurations.
+1
@@ -0,0 +1 @@
Fix messages over federation being lost until an event is sent into the same room.
+1
@@ -0,0 +1 @@
Change `StreamToken.room_key` to be a `RoomStreamToken` instance.
+1
@@ -0,0 +1 @@
Refactor notifier code to correctly use the max event stream position.
+1
@@ -0,0 +1 @@
Add experimental support for sharding event persister.
+1
@@ -0,0 +1 @@
Use slotted classes where possible.
+1
@@ -0,0 +1 @@
Add the room topic and avatar to the room details admin API.
+1
@@ -0,0 +1 @@
Fix fetching events from remote servers when those events are malformed.
+1
@@ -0,0 +1 @@
Update outdated usages of `metaclass` to python 3 syntax.
+1
@@ -0,0 +1 @@
Fix UnboundLocalError from occurring when appservices send a malformed register request.
+1
@@ -0,0 +1 @@
Use the `admin_patterns` helper in additional locations.
+1
@@ -0,0 +1 @@
Remove `__future__` imports related to Python 2 compatibility.
-2
@@ -15,8 +15,6 @@
# limitations under the License.
""" Starts a synapse client console. """
from __future__ import print_function
import argparse
import cmd
import getpass
-2
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import json
import urllib
from pprint import pformat
-2
@@ -1,5 +1,3 @@
from __future__ import print_function
import argparse
import cgi
import datetime
-2
@@ -1,5 +1,3 @@
from __future__ import print_function
import argparse
import cgi
import datetime
@@ -10,8 +10,6 @@ the bridge.
Requires:
npm install jquery jsdom
"""
from __future__ import print_function
import json
import subprocess
import time
+1 -7
@@ -1,5 +1,4 @@
#!/usr/bin/env python
from __future__ import print_function
import json
import sys
@@ -8,11 +7,6 @@ from argparse import ArgumentParser
import requests
try:
raw_input
except NameError: # Python 3
raw_input = input
def _mkurl(template, kws):
for key in kws:
@@ -58,7 +52,7 @@ def main(hs, room_id, access_token, user_id_prefix, why):
print("The following user IDs will be kicked from %s" % room_name)
for uid in kick_list:
print(uid)
doit = raw_input("Continue? [Y]es\n")
doit = input("Continue? [Y]es\n")
if len(doit) > 0 and doit.lower() == "y":
print("Kicking members...")
# encode them all
+4
@@ -275,6 +275,8 @@ The following fields are possible in the JSON response body:
* `room_id` - The ID of the room.
* `name` - The name of the room.
* `topic` - The topic of the room.
* `avatar` - The `mxc` URI to the avatar of the room.
* `canonical_alias` - The canonical (main) alias address of the room.
* `joined_members` - How many users are currently in the room.
* `joined_local_members` - How many local users are currently in the room.
@@ -304,6 +306,8 @@ Response:
{
"room_id": "!mscvqgqpHYjBGDxNym:matrix.org",
"name": "Music Theory",
"avatar": "mxc://matrix.org/AQDaVFlbkQoErdOgqWRgiGSV",
"topic": "Theory, Composition, Notation, Analysis",
"canonical_alias": "#musictheory:matrix.org",
"joined_members": 127
"joined_local_members": 2,
+21 -2
@@ -11,7 +11,7 @@ privileges.
**NOTE**: Your reverse proxy must not `canonicalise` or `normalise`
the requested URI in any way (for example, by decoding `%xx` escapes).
Beware that Apache *will* canonicalise URIs unless you specifify
Beware that Apache *will* canonicalise URIs unless you specify
`nocanon`.
When setting up a reverse proxy, remember that Matrix clients and other
@@ -23,6 +23,10 @@ specification](https://matrix.org/docs/spec/server_server/latest#resolving-serve
for more details of the algorithm used for federation connections, and
[delegate.md](<delegate.md>) for instructions on setting up delegation.
Endpoints that are part of the standardised Matrix specification are
located under `/_matrix`, whereas endpoints specific to Synapse are
located under `/_synapse/client`.
Let's assume that we expect clients to connect to our server at
`https://matrix.example.com`, and other servers to connect at
`https://example.com:8448`. The following sections detail the configuration of
@@ -45,7 +49,7 @@ server {
server_name matrix.example.com;
location /_matrix {
location ~* ^(\/_matrix|\/_synapse\/client) {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
# Nginx by default only allows file uploads up to 1M in size
@@ -65,6 +69,10 @@ matrix.example.com {
proxy /_matrix http://localhost:8008 {
transparent
}
proxy /_synapse/client http://localhost:8008 {
transparent
}
}
example.com:8448 {
@@ -79,6 +87,7 @@ example.com:8448 {
```
matrix.example.com {
reverse_proxy /_matrix/* http://localhost:8008
reverse_proxy /_synapse/client/* http://localhost:8008
}
example.com:8448 {
@@ -96,6 +105,8 @@ example.com:8448 {
AllowEncodedSlashes NoDecode
ProxyPass /_matrix http://127.0.0.1:8008/_matrix nocanon
ProxyPassReverse /_matrix http://127.0.0.1:8008/_matrix
ProxyPass /_synapse/client http://127.0.0.1:8008/_synapse/client nocanon
ProxyPassReverse /_synapse/client http://127.0.0.1:8008/_synapse/client
</VirtualHost>
<VirtualHost *:8448>
@@ -119,6 +130,7 @@ frontend https
# Matrix client traffic
acl matrix-host hdr(host) -i matrix.example.com
acl matrix-path path_beg /_matrix
acl matrix-path path_beg /_synapse/client
use_backend matrix if matrix-host matrix-path
@@ -146,3 +158,10 @@ connecting to Synapse from a client.
Synapse exposes a health check endpoint for use by reverse proxies.
Each configured HTTP listener has a `/health` endpoint which always returns
200 OK (and doesn't get logged).
## Synapse administration endpoints
Endpoints for administering your Synapse instance are placed under
`/_synapse/admin`. These require authentication through an access token of an
admin user. However, as access to these endpoints grants the caller a lot of power,
we do not recommend exposing them to the public internet without good reason.
+4 -26
@@ -1660,11 +1660,14 @@ trusted_key_servers:
# At least one of `sp_config` or `config_path` must be set in this section to
# enable SAML login.
#
# (You will probably also want to set the following options to `false` to
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
# * enable_registration
# * password_config.enabled
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
@@ -1787,31 +1790,6 @@ saml2_config:
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# When rendering, this template is given the following variables:
# * code: an HTML error code corresponding to the error that is being
# returned (typically 400 or 500)
#
# * msg: a textual message describing the error.
#
# The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
# OpenID Connect integration. The following settings can be used to make Synapse
# use an OpenID Connect Provider for authentication, instead of its internal
+1
@@ -217,6 +217,7 @@ expressions:
^/_matrix/client/(api/v1|r0|unstable)/joined_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups$
^/_matrix/client/(api/v1|r0|unstable)/publicised_groups/
^/_synapse/client/password_reset/email/submit_token$
# Registration/login requests
^/_matrix/client/(api/v1|r0|unstable)/login$
+2
@@ -46,10 +46,12 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/events.py,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,
synapse/storage/persist_events.py,
synapse/storage/state.py,
synapse/storage/util,
synapse/streams,
-2
@@ -1,7 +1,5 @@
#! /usr/bin/python
from __future__ import print_function
import argparse
import ast
import os
-2
@@ -1,7 +1,5 @@
#!/usr/bin/env python2
from __future__ import print_function
import sys
import pymacaroons
-2
@@ -15,8 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import base64
import json
-2
@@ -1,5 +1,3 @@
from __future__ import print_function
import sqlite3
import sys
@@ -32,8 +32,6 @@ To use, pipe the above into::
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
"""
from __future__ import print_function
import argparse
import logging
import os
-2
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from synapse._scripts.register_new_matrix_user import main
if __name__ == "__main__":
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import getpass
import hashlib
-2
@@ -15,8 +15,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import gc
import logging
import math
+18 -3
@@ -833,11 +833,26 @@ class ShardedWorkerHandlingConfig:
def should_handle(self, instance_name: str, key: str) -> bool:
"""Whether this instance is responsible for handling the given key.
"""
# If multiple instances are not defined we always return true.
# If multiple instances are not defined we always return true
if not self.instances or len(self.instances) == 1:
return True
return self.get_instance(key) == instance_name
def get_instance(self, key: str) -> str:
"""Get the instance responsible for handling the given key.
Note: For things like federation sending the config for which instance
is sending is known only to the sender instance if there is only one.
Therefore `should_handle` should be used where possible.
"""
if not self.instances:
return "master"
if len(self.instances) == 1:
return self.instances[0]
# We shard by taking the hash, modulo it by the number of instances and
# then checking whether this instance matches the instance at that
# index.
@@ -847,7 +862,7 @@ class ShardedWorkerHandlingConfig:
dest_hash = sha256(key.encode("utf8")).digest()
dest_int = int.from_bytes(dest_hash, byteorder="little")
remainder = dest_int % (len(self.instances))
return self.instances[remainder] == instance_name
return self.instances[remainder]
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
+1
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
instances: List[str]
def __init__(self, instances: List[str]) -> None: ...
def should_handle(self, instance_name: str, key: str) -> bool: ...
def get_instance(self, key: str) -> str: ...
-1
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
# This file can't be called email.py because if it is, we cannot:
import email.utils
+4 -30
@@ -169,10 +169,6 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "15m")
)
self.saml2_error_html_template = self.read_templates(
["saml_error.html"], saml2_config.get("template_dir")
)[0]
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
@@ -225,11 +221,14 @@ class SAML2Config(Config):
# At least one of `sp_config` or `config_path` must be set in this section to
# enable SAML login.
#
# (You will probably also want to set the following options to `false` to
# You will probably also want to set the following options to `false` to
# disable the regular login/registration flows:
# * enable_registration
# * password_config.enabled
#
# You will also want to investigate the settings under the "sso" configuration
# section below.
#
# Once SAML support is enabled, a metadata file will be exposed at
# https://<server>:<port>/_matrix/saml2/metadata.xml, which you may be able to
# use to configure your SAML IdP with. Alternatively, you can manually configure
@@ -351,31 +350,6 @@ class SAML2Config(Config):
# value: "staff"
# - attribute: department
# value: "sales"
# Directory in which Synapse will try to find the template files below.
# If not set, default templates from within the Synapse package will be used.
#
# DO NOT UNCOMMENT THIS SETTING unless you want to customise the templates.
# If you *do* uncomment it, you will need to make sure that all the templates
# below are in the directory.
#
# Synapse will look for the following templates in this directory:
#
# * HTML page to display to users if something goes wrong during the
# authentication process: 'saml_error.html'.
#
# When rendering, this template is given the following variables:
# * code: an HTML error code corresponding to the error that is being
# returned (typically 400 or 500)
#
# * msg: a textual message describing the error.
#
# The variables will automatically be HTML-escaped.
#
# You can see the default templates at:
# https://github.com/matrix-org/synapse/tree/master/synapse/res/templates
#
#template_dir: "res/templates"
""" % {
"config_dir_path": config_dir_path
}
-2
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import sys
from ._base import Config
+27 -10
@@ -13,12 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Union
import attr
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
from .server import ListenerConfig, parse_listener_def
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
"""Helper for allowing parsing a string or list of strings to a config
option expecting a list of strings.
"""
if isinstance(obj, str):
return [obj]
return obj
@attr.s
class InstanceLocationConfig:
"""The host and port to talk to an instance via HTTP replication.
@@ -33,11 +45,13 @@ class WriterLocations:
"""Specifies the instances that write various streams.
Attributes:
events: The instance that writes to the event and backfill streams.
events: The instance that writes to the typing stream.
events: The instances that write to the event and backfill streams.
typing: The instance that writes to the typing stream.
"""
events = attr.ib(default="master", type=str)
events = attr.ib(
default=["master"], type=List[str], converter=_instance_to_list_converter
)
typing = attr.ib(default="master", type=str)
@@ -105,15 +119,18 @@ class WorkerConfig(Config):
writers = config.get("stream_writers") or {}
self.writers = WriterLocations(**writers)
# Check that the configured writer for events and typing also appears in
# Check that the configured writers for events and typing also appears in
# `instance_map`.
for stream in ("events", "typing"):
instance = getattr(self.writers, stream)
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances:
if instance != "master" and instance not in self.instance_map:
raise ConfigError(
"Instance %r is configured to write %s but does not appear in `instance_map` config."
% (instance, stream)
)
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
def generate_config_section(self, config_dir_path, server_name, **kwargs):
return """\
@@ -15,7 +15,7 @@
# limitations under the License.
import datetime
import logging
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Tuple
from typing import TYPE_CHECKING, Dict, Hashable, Iterable, List, Optional, Tuple, cast
from prometheus_client import Counter
@@ -92,6 +92,21 @@ class PerDestinationQueue:
self._destination = destination
self.transmission_loop_running = False
# True whilst we are sending events that the remote homeserver missed
# because it was unreachable. We start in this state so we can perform
# catch-up at startup.
# New events will only be sent once this is finished, at which point
# _catching_up is flipped to False.
self._catching_up = True # type: bool
# The stream_ordering of the most recent PDU that was discarded due to
# being in catch-up mode.
self._catchup_last_skipped = 0 # type: int
# Cache of the last successfully-transmitted stream ordering for this
# destination (we are the only updater so this is safe)
self._last_successful_stream_ordering = None # type: Optional[int]
# a list of pending PDUs
self._pending_pdus = [] # type: List[EventBase]
@@ -138,7 +153,13 @@ class PerDestinationQueue:
Args:
pdu: pdu to send
"""
self._pending_pdus.append(pdu)
if not self._catching_up or self._last_successful_stream_ordering is None:
# only enqueue the PDU if we are not catching up (False) or do not
# yet know if we have anything to catch up (None)
self._pending_pdus.append(pdu)
else:
self._catchup_last_skipped = pdu.internal_metadata.stream_ordering
self.attempt_new_transaction()
def send_presence(self, states: Iterable[UserPresenceState]) -> None:
@@ -218,6 +239,13 @@ class PerDestinationQueue:
# hence why we throw the result away.
await get_retry_limiter(self._destination, self._clock, self._store)
if self._catching_up:
# we potentially need to catch-up first
await self._catch_up_transmission_loop()
if self._catching_up:
# not caught up yet
return
pending_pdus = []
while True:
# We have to keep 2 free slots for presence and rr_edus
@@ -351,8 +379,9 @@ class PerDestinationQueue:
if e.retry_interval > 60 * 60 * 1000:
# we won't retry for another hour!
# (this suggests a significant outage)
# We drop pending PDUs and EDUs because otherwise they will
# We drop pending EDUs because otherwise they will
# rack up indefinitely.
# (Dropping PDUs is already performed by `_start_catching_up`.)
# Note that:
# - the EDUs that are being dropped here are those that we can
# afford to drop (specifically, only typing notifications,
@@ -364,11 +393,12 @@ class PerDestinationQueue:
# dropping read receipts is a bit sad but should be solved
# through another mechanism, because this is all volatile!
self._pending_pdus = []
self._pending_edus = []
self._pending_edus_keyed = {}
self._pending_presence = {}
self._pending_rrs = {}
self._start_catching_up()
except FederationDeniedError as e:
logger.info(e)
except HttpResponseException as e:
@@ -378,6 +408,8 @@ class PerDestinationQueue:
e.code,
e,
)
self._start_catching_up()
except RequestSendFailed as e:
logger.warning(
"TX [%s] Failed to send transaction: %s", self._destination, e
@@ -387,16 +419,96 @@ class PerDestinationQueue:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
self._start_catching_up()
except Exception:
logger.exception("TX [%s] Failed to send transaction", self._destination)
for p in pending_pdus:
logger.info(
"Failed to send event %s to %s", p.event_id, self._destination
)
self._start_catching_up()
finally:
# We want to be *very* sure we clear this after we stop processing
self.transmission_loop_running = False
async def _catch_up_transmission_loop(self) -> None:
first_catch_up_check = self._last_successful_stream_ordering is None
if first_catch_up_check:
# first catchup so get last_successful_stream_ordering from database
self._last_successful_stream_ordering = await self._store.get_destination_last_successful_stream_ordering(
self._destination
)
if self._last_successful_stream_ordering is None:
# if it's still None, then this means we don't have the information
# in our database - we haven't successfully sent a PDU to this server
# (at least since the introduction of the feature tracking
# last_successful_stream_ordering).
# Sadly, this means we can't do anything here as we don't know what
# needs catching up — so catching up is futile; let's stop.
self._catching_up = False
return
# get at most 50 catchup room/PDUs
while True:
event_ids = await self._store.get_catch_up_room_event_ids(
self._destination, self._last_successful_stream_ordering,
)
if not event_ids:
# No more events to catch up on, but we can't ignore the chance
# of a race condition, so we check that no new events have been
# skipped due to us being in catch-up mode
if self._catchup_last_skipped > self._last_successful_stream_ordering:
# another event has been skipped because we were in catch-up mode
continue
# we are done catching up!
self._catching_up = False
break
if first_catch_up_check:
# as this is our check for needing catch-up, we may have PDUs in
# the queue from before we *knew* we had to do catch-up, so
# clear those out now.
self._start_catching_up()
# fetch the relevant events from the event store
# - redacted behaviour of REDACT is fine, since we only send metadata
# of redacted events to the destination.
# - don't need to worry about rejected events as we do not actively
# forward received events over federation.
catchup_pdus = await self._store.get_events_as_list(event_ids)
if not catchup_pdus:
raise AssertionError(
"No events retrieved when we asked for %r. "
"This should not happen." % event_ids
)
if logger.isEnabledFor(logging.INFO):
rooms = (p.room_id for p in catchup_pdus)
logger.info("Catching up rooms to %s: %r", self._destination, rooms)
success = await self._transaction_manager.send_new_transaction(
self._destination, catchup_pdus, []
)
if not success:
return
sent_transactions_counter.inc()
final_pdu = catchup_pdus[-1]
self._last_successful_stream_ordering = cast(
int, final_pdu.internal_metadata.stream_ordering
)
await self._store.set_destination_last_successful_stream_ordering(
self._destination, self._last_successful_stream_ordering
)
def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
if not self._pending_rrs:
return
@@ -457,3 +569,12 @@ class PerDestinationQueue:
]
return (edus, stream_id)
def _start_catching_up(self) -> None:
"""
Marks this destination as being in catch-up mode.
This throws away the PDU queue.
"""
self._catching_up = True
self._pending_pdus = []
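Condensed into a standalone sketch (``queue``, ``store`` and ``send_transaction`` are stand-ins for the objects and calls shown above), the catch-up loop behaves roughly like this:

```python
async def catch_up(queue, store, destination, send_transaction):
    # Replay events the destination missed, in stream order, until none
    # remain and nothing new was skipped while we were looping.
    while True:
        event_ids = await store.get_catch_up_room_event_ids(
            destination, queue.last_successful_stream_ordering
        )
        if not event_ids:
            if queue.catchup_last_skipped > queue.last_successful_stream_ordering:
                # A PDU was skipped mid-loop; re-check for missed events.
                continue
            queue.catching_up = False  # fully caught up
            return
        pdus = await store.get_events_as_list(event_ids)
        if not await send_transaction(destination, pdus):
            return  # transmission failed; the retry logic will re-enter
        queue.last_successful_stream_ordering = (
            pdus[-1].internal_metadata.stream_ordering
        )
        await store.set_destination_last_successful_stream_ordering(
            destination, queue.last_successful_stream_ordering
        )
```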
+1 -1
@@ -76,7 +76,7 @@ def create_issuing_service(reactor, acme_url, account_key_file, well_known_resou
)
@attr.s
@attr.s(slots=True)
@implementer(ICertificateStore)
class ErsatzStore:
"""
+3 -3
@@ -125,8 +125,8 @@ class AdminHandler(BaseHandler):
else:
stream_ordering = room.stream_ordering
from_key = str(RoomStreamToken(0, 0))
to_key = str(RoomStreamToken(None, stream_ordering))
from_key = RoomStreamToken(0, 0)
to_key = RoomStreamToken(None, stream_ordering)
written_events = set() # Events that we've processed in this room
@@ -153,7 +153,7 @@ class AdminHandler(BaseHandler):
if not events:
break
from_key = events[-1].internal_metadata.after
from_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
events = await filter_events_for_client(self.storage, user_id, events)
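For orientation, a hedged sketch of the token type being threaded through here (the serialisation shown in comments follows Synapse's token format; treat the exact strings as illustrative):

```python
from synapse.types import RoomStreamToken

# A live, stream-only position serialises as "s<stream>".
live_token = RoomStreamToken(None, 42)
# A historic (backfill) position also has a topological part, "t<topo>-<stream>".
backfill_token = RoomStreamToken(5, 42)

# Handlers now pass RoomStreamToken objects around, parsing strings only
# at the boundaries:
from_key = RoomStreamToken.parse("s42")
assert from_key.stream == 42 and from_key.topological is None
```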
+1 -1
@@ -1235,7 +1235,7 @@ class AuthHandler(BaseHandler):
return urllib.parse.urlunparse(url_parts)
@attr.s
@attr.s(slots=True)
class MacaroonGenerator:
hs = attr.ib()
+5 -7
@@ -29,6 +29,7 @@ from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import (
RoomStreamToken,
StreamToken,
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
@@ -104,18 +105,15 @@ class DeviceWorkerHandler(BaseHandler):
@trace
@measure_func("device.get_user_ids_changed")
async def get_user_ids_changed(self, user_id, from_token):
async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
set_tag("user_id", user_id)
set_tag("from_token", from_token)
now_room_key = await self.store.get_room_events_max_id()
now_room_id = self.store.get_room_max_stream_ordering()
now_room_key = RoomStreamToken(None, now_room_id)
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -142,7 +140,7 @@ class DeviceWorkerHandler(BaseHandler):
)
rooms_changed.update(event.room_id for event in member_events)
stream_ordering = RoomStreamToken.parse_stream_token(from_token.room_key).stream
stream_ordering = from_token.room_key.stream
possibly_changed = set(changed)
possibly_left = set()
+1 -1
@@ -1201,7 +1201,7 @@ def _one_time_keys_match(old_key_json, new_key):
return old_key == new_key_copy
@attr.s
@attr.s(slots=True)
class SignatureListItem:
"""An item in the signature list as used by upload_signatures_for_device_keys.
"""
+31 -18
View File
@@ -86,7 +86,7 @@ from synapse.visibility import filter_events_for_server
logger = logging.getLogger(__name__)
@attr.s
@attr.s(slots=True)
class _NewEventInfo:
"""Holds information about a received event, ready for passing to _handle_new_events
@@ -128,7 +128,6 @@ class FederationHandler(BaseHandler):
self.keyring = hs.get_keyring()
self.action_generator = hs.get_action_generator()
self.is_mine_id = hs.is_mine_id
self.pusher_pool = hs.get_pusherpool()
self.spam_checker = hs.get_spam_checker()
self.event_creation_handler = hs.get_event_creation_handler()
self._message_handler = hs.get_message_handler()
@@ -900,7 +899,8 @@ class FederationHandler(BaseHandler):
)
)
await self._handle_new_events(dest, ev_infos, backfilled=True)
if ev_infos:
await self._handle_new_events(dest, room_id, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -1193,7 +1193,7 @@ class FederationHandler(BaseHandler):
event_infos.append(_NewEventInfo(event, None, auth))
await self._handle_new_events(
destination, event_infos,
destination, room_id, event_infos,
)
def _sanity_check_event(self, ev):
@@ -1340,15 +1340,15 @@ class FederationHandler(BaseHandler):
)
max_stream_id = await self._persist_auth_tree(
origin, auth_chain, state, event, room_version_obj
origin, room_id, auth_chain, state, event, room_version_obj
)
# We wait here until this instance has seen the events come down
# replication (if we're using replication) as the below uses caches.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.config.worker.writers.events, "events", max_stream_id
self.config.worker.events_shard_config.get_instance(room_id),
"events",
max_stream_id,
)
# Check whether this room is the result of an upgrade of a room we already know
@@ -1604,7 +1604,7 @@ class FederationHandler(BaseHandler):
)
context = await self.state_handler.compute_event_context(event)
await self.persist_events_and_notify([(event, context)])
await self.persist_events_and_notify(event.room_id, [(event, context)])
return event
@@ -1631,7 +1631,9 @@ class FederationHandler(BaseHandler):
await self.federation_client.send_leave(host_list, event)
context = await self.state_handler.compute_event_context(event)
stream_id = await self.persist_events_and_notify([(event, context)])
stream_id = await self.persist_events_and_notify(
event.room_id, [(event, context)]
)
return event, stream_id
@@ -1879,7 +1881,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
[(event, context)], backfilled=backfilled
event.room_id, [(event, context)], backfilled=backfilled
)
except Exception:
run_in_background(
@@ -1892,6 +1894,7 @@ class FederationHandler(BaseHandler):
async def _handle_new_events(
self,
origin: str,
room_id: str,
event_infos: Iterable[_NewEventInfo],
backfilled: bool = False,
) -> None:
@@ -1923,6 +1926,7 @@ class FederationHandler(BaseHandler):
)
await self.persist_events_and_notify(
room_id,
[
(ev_info.event, context)
for ev_info, context in zip(event_infos, contexts)
@@ -1933,6 +1937,7 @@ class FederationHandler(BaseHandler):
async def _persist_auth_tree(
self,
origin: str,
room_id: str,
auth_events: List[EventBase],
state: List[EventBase],
event: EventBase,
@@ -1947,6 +1952,7 @@ class FederationHandler(BaseHandler):
Args:
origin: Where the events came from
room_id,
auth_events
state
event
@@ -2021,17 +2027,20 @@ class FederationHandler(BaseHandler):
events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
await self.persist_events_and_notify(
room_id,
[
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
]
],
)
new_event_context = await self.state_handler.compute_event_context(
event, old_state=state
)
return await self.persist_events_and_notify([(event, new_event_context)])
return await self.persist_events_and_notify(
room_id, [(event, new_event_context)]
)
async def _prep_event(
self,
@@ -2882,6 +2891,7 @@ class FederationHandler(BaseHandler):
async def persist_events_and_notify(
self,
room_id: str,
event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
backfilled: bool = False,
) -> int:
@@ -2889,14 +2899,19 @@ class FederationHandler(BaseHandler):
necessary.
Args:
event_and_contexts:
room_id: The room ID of events being persisted.
event_and_contexts: Sequence of events with their associated
context that should be persisted. All events must belong to
the same room.
backfilled: Whether these events are a result of
backfilling or not
"""
if self.config.worker.writers.events != self._instance_name:
instance = self.config.worker.events_shard_config.get_instance(room_id)
if instance != self._instance_name:
result = await self._send_events(
instance_name=self.config.worker.writers.events,
instance_name=instance,
store=self.store,
room_id=room_id,
event_and_contexts=event_and_contexts,
backfilled=backfilled,
)
@@ -2949,8 +2964,6 @@ class FederationHandler(BaseHandler):
event, event_stream_id, max_stream_id, extra_users=extra_users
)
await self.pusher_pool.on_new_notifications(max_stream_id)
async def _clean_room_for_join(self, room_id: str) -> None:
"""Called to clean up any data in DB for a given room, ready for the
server to join the room.
+2 -2
@@ -25,7 +25,7 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage.roommember import RoomsForUser
from synapse.streams.config import PaginationConfig
from synapse.types import JsonDict, Requester, StreamToken, UserID
from synapse.types import JsonDict, Requester, RoomStreamToken, StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -167,7 +167,7 @@ class InitialSyncHandler(BaseHandler):
self.state_handler.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
room_end_token = RoomStreamToken(None, event.stream_ordering,)
deferred_room_state = run_in_background(
self.state_store.get_state_for_events, [event.event_id]
)
+9 -10
@@ -377,9 +377,8 @@ class EventCreationHandler:
self.notifier = hs.get_notifier()
self.config = hs.config
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
self._is_event_writer = (
self.config.worker.writers.events == hs.get_instance_name()
)
self._events_shard_config = self.config.worker.events_shard_config
self._instance_name = hs.get_instance_name()
self.room_invite_state_types = self.hs.config.room_invite_state_types
@@ -388,8 +387,6 @@ class EventCreationHandler:
# This is only used to get at ratelimit function, and maybe_kick_guest_users
self.base_handler = BaseHandler(hs)
self.pusher_pool = hs.get_pusherpool()
# We arbitrarily limit concurrent event creation for a room to 5.
# This is to stop us from diverging history *too* much.
self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
@@ -907,9 +904,10 @@ class EventCreationHandler:
try:
# If we're a worker we need to hit out to the master.
if not self._is_event_writer:
writer_instance = self._events_shard_config.get_instance(event.room_id)
if writer_instance != self._instance_name:
result = await self.send_event(
instance_name=self.config.worker.writers.events,
instance_name=writer_instance,
event_id=event.event_id,
store=self.store,
requester=requester,
@@ -977,7 +975,10 @@ class EventCreationHandler:
This should only be run on the instance in charge of persisting events.
"""
assert self._is_event_writer
assert self.storage.persistence is not None
assert self._events_shard_config.should_handle(
self._instance_name, event.room_id
)
if ratelimit:
# We check if this is a room admin redacting an event so that we
@@ -1148,8 +1149,6 @@ class EventCreationHandler:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
await self.pusher_pool.on_new_notifications(max_stream_id)
def _notify():
try:
self.notifier.on_new_room_event(
+2 -2
@@ -131,10 +131,10 @@ class OidcHandler:
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Renders the error template and respond with it.
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under ``synapse/res/templates/sso_error.html``.
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
+2 -2
@@ -344,7 +344,7 @@ class PaginationHandler:
# gets called.
raise Exception("limit not set")
room_token = RoomStreamToken.parse(from_token.room_key)
room_token = from_token.room_key
with await self.pagination_lock.read(room_id):
(
@@ -381,7 +381,7 @@ class PaginationHandler:
if leave_token.topological < max_topo:
from_token = from_token.copy_and_replace(
"room_key", leave_token_str
"room_key", leave_token
)
await self.hs.get_handlers().federation_handler.maybe_backfill(
+16 -13
@@ -833,7 +833,9 @@ class RoomCreationHandler(BaseHandler):
# Always wait for room creation to propagate before returning
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", last_stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
last_stream_id,
)
return result, last_stream_id
@@ -1121,20 +1123,19 @@ class RoomEventSource:
async def get_new_events(
self,
user: UserID,
from_key: str,
from_key: RoomStreamToken,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
# We just ignore the key for now.
to_key = self.get_current_key()
from_token = RoomStreamToken.parse(from_key)
if from_token.topological:
if from_key.topological:
logger.warning("Stream has topological part!!!! %r", from_key)
from_key = "s%s" % (from_token.stream,)
from_key = RoomStreamToken(None, from_key.stream)
app_service = self.store.get_app_service_by_user_id(user.to_string())
if app_service:
@@ -1163,14 +1164,14 @@ class RoomEventSource:
events[:] = events[:limit]
if events:
end_key = events[-1].internal_metadata.after
end_key = RoomStreamToken.parse(events[-1].internal_metadata.after)
else:
end_key = to_key
return (events, end_key)
def get_current_key(self) -> str:
return "s%d" % (self.store.get_room_max_stream_ordering(),)
def get_current_key(self) -> RoomStreamToken:
return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
return self.store.get_room_events_max_id(room_id)
@@ -1290,10 +1291,10 @@ class RoomShutdownHandler:
# We now wait for the create room to come back in via replication so
# that we can assume that all the joins/invites have propagated before
# we try and auto join below.
#
# TODO: Currently the events stream is written to from master
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
"events",
stream_id,
)
else:
new_room_id = None
@@ -1323,7 +1324,9 @@ class RoomShutdownHandler:
# Wait for leave to come in over replication before trying to forget.
await self._replication.wait_for_stream_position(
self.hs.config.worker.writers.events, "events", stream_id
self.hs.config.worker.events_shard_config.get_instance(room_id),
"events",
stream_id,
)
await self.room_member_handler.forget(target_requester.user, room_id)
+1 -10
@@ -51,14 +51,12 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class RoomMemberHandler:
class RoomMemberHandler(metaclass=abc.ABCMeta):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
__metaclass__ = abc.ABCMeta
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
@@ -83,13 +81,6 @@ class RoomMemberHandler:
self._enable_lookup = hs.config.enable_3pid_lookup
self.allow_per_room_profiles = self.config.allow_per_room_profiles
self._event_stream_writer_instance = hs.config.worker.writers.events
self._is_on_event_persistence_instance = (
self._event_stream_writer_instance == hs.get_instance_name()
)
if self._is_on_event_persistence_instance:
self.persist_event_storage = hs.get_storage().persistence
self._join_rate_limiter_local = Ratelimiter(
clock=self.clock,
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
+111 -60
@@ -21,9 +21,10 @@ import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import AuthError, SynapseError
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.config.saml2_config import SamlAttributeRequirement
from synapse.http.server import respond_with_html
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.module_api import ModuleApi
@@ -41,7 +42,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
@attr.s
class MappingException(Exception):
"""Used to catch errors when mapping the SAML2 response to a user."""
@attr.s(slots=True)
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
@@ -68,6 +73,7 @@ class SamlHandler:
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._saml2_attribute_requirements = hs.config.saml2.attribute_requirements
self._error_template = hs.config.sso_error_template
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
@@ -84,6 +90,25 @@ class SamlHandler:
# a lock on the mappings
self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
def _render_error(
self, request, error: str, error_description: Optional[str] = None
) -> None:
"""Render the error template and respond to the request with it.
This is used to show errors to the user. The template of this page can
be found under `synapse/res/templates/sso_error.html`.
Args:
request: The incoming request from the browser.
We'll respond with an HTML page describing the error.
error: A technical identifier for this error.
error_description: A human-readable description of the error.
"""
html = self._error_template.render(
error=error, error_description=error_description
)
respond_with_html(request, 400, html)
def handle_redirect_request(
self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
) -> bytes:
@@ -134,49 +159,6 @@ class SamlHandler:
# the dict.
self.expire_sessions()
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
user_id, current_session = await self._map_saml_response_to_user(
resp_bytes, relay_state, user_agent, ip_address
)
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
resp_bytes: str,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> Tuple[str, Optional[Saml2SessionData]]:
"""
Given a SAML response, retrieve the cached session and user for it.
Args:
resp_bytes: The SAML response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
Tuple of the user ID and SAML session associated with this response.
Raises:
SynapseError if there was a problem with the response.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
try:
saml2_auth = self._saml_client.parse_authn_request_response(
resp_bytes,
@@ -189,12 +171,23 @@ class SamlHandler:
# in the (user-visible) exception message, so let's log the exception here
# so we can track down the session IDs later.
logger.warning(str(e))
raise SynapseError(400, "Unexpected SAML2 login.")
self._render_error(
request, "unsolicited_response", "Unexpected SAML2 login."
)
return
except Exception as e:
raise SynapseError(400, "Unable to parse SAML2 response: %s." % (e,))
self._render_error(
request,
"invalid_response",
"Unable to parse SAML2 response: %s." % (e,),
)
return
if saml2_auth.not_signed:
raise SynapseError(400, "SAML2 response was not signed.")
self._render_error(
request, "unsigned_respond", "SAML2 response was not signed."
)
return
logger.debug("SAML2 response: %s", saml2_auth.origxml)
for assertion in saml2_auth.assertions:
@@ -213,15 +206,73 @@ class SamlHandler:
saml2_auth.in_response_to, None
)
# Ensure that the attributes of the logged in user meet the required
# attributes.
for requirement in self._saml2_attribute_requirements:
_check_attribute_requirement(saml2_auth.ava, requirement)
if not _check_attribute_requirement(saml2_auth.ava, requirement):
self._render_error(
request, "unauthorised", "You are not authorised to log in here."
)
return
# Pull out the user-agent and IP from the request.
user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
0
].decode("ascii", "surrogateescape")
ip_address = self.hs.get_ip_from_request(request)
# Call the mapper to register/login the user
try:
user_id = await self._map_saml_response_to_user(
saml2_auth, relay_state, user_agent, ip_address
)
except MappingException as e:
logger.exception("Could not map user")
self._render_error(request, "mapping_error", str(e))
return
# Complete the interactive auth session or the login.
if current_session and current_session.ui_auth_session_id:
await self._auth_handler.complete_sso_ui_auth(
user_id, current_session.ui_auth_session_id, request
)
else:
await self._auth_handler.complete_sso_login(user_id, request, relay_state)
async def _map_saml_response_to_user(
self,
saml2_auth: saml2.response.AuthnResponse,
client_redirect_url: str,
user_agent: str,
ip_address: str,
) -> str:
"""
Given a SAML response, retrieve the user ID for it and possibly register the user.
Args:
saml2_auth: The parsed SAML2 response.
client_redirect_url: The redirect URL passed in by the client.
user_agent: The user agent of the client making the request.
ip_address: The IP address of the client making the request.
Returns:
The user ID associated with this response.
Raises:
MappingException if there was a problem mapping the response to a user.
RedirectException: some mapping providers may raise this if they need
to redirect to an interstitial page.
"""
remote_user_id = self._user_mapping_provider.get_remote_user_id(
saml2_auth, client_redirect_url
)
if not remote_user_id:
raise Exception("Failed to extract remote user id from SAML response")
raise MappingException(
"Failed to extract remote user id from SAML response"
)
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
@@ -235,7 +286,7 @@ class SamlHandler:
)
if registered_user_id is not None:
logger.info("Found existing mapping %s", registered_user_id)
return registered_user_id, current_session
return registered_user_id
# backwards-compatibility hack: see if there is an existing user with a
# suitable mapping from the uid
@@ -260,7 +311,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
@@ -277,7 +328,7 @@ class SamlHandler:
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
raise Exception(
raise MappingException(
"Error parsing SAML2 response: SAML mapping provider plugin "
"did not return a mxid_localpart value"
)
@@ -294,8 +345,8 @@ class SamlHandler:
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise SynapseError(
500, "Unable to generate a Matrix ID from the SAML response"
raise MappingException(
"Unable to generate a Matrix ID from the SAML response"
)
logger.info("Mapped SAML user to local part %s", localpart)
@@ -310,7 +361,7 @@ class SamlHandler:
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
return registered_user_id, current_session
return registered_user_id
def expire_sessions(self):
expire_before = self._clock.time_msec() - self._saml2_session_lifetime
@@ -323,11 +374,11 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement) -> bool:
values = ava.get(req.attribute, [])
for v in values:
if v == req.value:
return
return True
logger.info(
"SAML2 attribute %s did not match required value '%s' (was '%s')",
@@ -335,7 +386,7 @@ def _check_attribute_requirement(ava: dict, req: SamlAttributeRequirement):
req.value,
values,
)
raise AuthError(403, "You are not authorized to log in here.")
return False
DOT_REPLACE_PATTERN = re.compile(
@@ -390,7 +441,7 @@ class DefaultSamlMappingProvider:
return saml_response.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "'uid' not in SAML2 response")
raise MappingException("'uid' not in SAML2 response")
def saml_response_to_user_attributes(
self,
+16 -29
@@ -89,14 +89,12 @@ class TimelineBatch:
events = attr.ib(type=List[EventBase])
limited = attr.ib(bool)
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.events)
__bool__ = __nonzero__ # python3
# We can't freeze this class, because we need to update it after it's instantiated to
# update its unread count. This is because we calculate the unread count for a room only
@@ -114,7 +112,7 @@ class JoinedSyncResult:
summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
@@ -127,8 +125,6 @@ class JoinedSyncResult:
# else in the result, we don't need to send it.
)
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
class ArchivedSyncResult:
@@ -137,26 +133,22 @@ class ArchivedSyncResult:
state = attr.ib(type=StateMap[EventBase])
account_data = attr.ib(type=List[JsonDict])
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if room needs to be part of the sync result.
"""
return bool(self.timeline or self.state or self.account_data)
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
class InvitedSyncResult:
room_id = attr.ib(type=str)
invite = attr.ib(type=EventBase)
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
"""Invited rooms should always be reported to the client"""
return True
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
class GroupsSyncResult:
@@ -164,11 +156,9 @@ class GroupsSyncResult:
invite = attr.ib(type=JsonDict)
leave = attr.ib(type=JsonDict)
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
return bool(self.join or self.invite or self.leave)
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
class DeviceLists:
@@ -181,13 +171,11 @@ class DeviceLists:
changed = attr.ib(type=Collection[str])
left = attr.ib(type=Collection[str])
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
return bool(self.changed or self.left)
__bool__ = __nonzero__ # python3
@attr.s
@attr.s(slots=True)
class _RoomChanges:
"""The set of room entries to include in the sync, plus the set of joined
and left room IDs since last sync.
@@ -227,7 +215,7 @@ class SyncResult:
device_one_time_keys_count = attr.ib(type=JsonDict)
groups = attr.ib(type=Optional[GroupsSyncResult])
def __nonzero__(self) -> bool:
def __bool__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
to tell if the notifier needs to wait for more events when polling for
events.
@@ -243,8 +231,6 @@ class SyncResult:
or self.groups
)
__bool__ = __nonzero__ # python3
class SyncHandler:
def __init__(self, hs: "HomeServer"):
@@ -378,7 +364,7 @@ class SyncHandler:
sync_config = sync_result_builder.sync_config
with Measure(self.clock, "ephemeral_by_room"):
typing_key = since_token.typing_key if since_token else "0"
typing_key = since_token.typing_key if since_token else 0
room_ids = sync_result_builder.joined_room_ids
@@ -402,7 +388,7 @@ class SyncHandler:
event_copy = {k: v for (k, v) in event.items() if k != "room_id"}
ephemeral_by_room.setdefault(room_id, []).append(event_copy)
receipt_key = since_token.receipt_key if since_token else "0"
receipt_key = since_token.receipt_key if since_token else 0
receipt_source = self.event_sources.sources["receipt"]
receipts, receipt_key = await receipt_source.get_new_events(
@@ -533,7 +519,7 @@ class SyncHandler:
if len(recents) > timeline_limit:
limited = True
recents = recents[-timeline_limit:]
room_key = recents[0].internal_metadata.before
room_key = RoomStreamToken.parse(recents[0].internal_metadata.before)
prev_batch_token = now_token.copy_and_replace("room_key", room_key)
@@ -1322,6 +1308,7 @@ class SyncHandler:
is_guest=sync_config.is_guest,
include_offline=include_offline,
)
assert presence_key
sync_result_builder.now_token = now_token.copy_and_replace(
"presence_key", presence_key
)
@@ -1484,7 +1471,7 @@ class SyncHandler:
if rooms_changed:
return True
stream_id = RoomStreamToken.parse_stream_token(since_token.room_key).stream
stream_id = since_token.room_key.stream
for room_id in sync_result_builder.joined_room_ids:
if self.store.has_room_changed_since(room_id, stream_id):
return True
@@ -1750,7 +1737,7 @@ class SyncHandler:
continue
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
"room_key", RoomStreamToken(None, event.stream_ordering)
)
room_entries.append(
RoomSyncResultBuilder(
@@ -2037,7 +2024,7 @@ def _calculate_state(
return {event_id_to_key[e]: e for e in state_ids}
@attr.s
@attr.s(slots=True)
class SyncResultBuilder:
"""Used to help build up a new SyncResult for a user
@@ -2073,7 +2060,7 @@ class SyncResultBuilder:
to_device = attr.ib(type=List[JsonDict], default=attr.Factory(list))
@attr.s
@attr.s(slots=True)
class RoomSyncResultBuilder:
"""Stores information needed to create either a `JoinedSyncResult` or
`ArchivedSyncResult`.
@@ -311,7 +311,7 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
return cache_controls
@attr.s()
@attr.s(slots=True)
class _FetchWellKnownFailure(Exception):
# True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
# a temporary failure.
+1 -1
@@ -76,7 +76,7 @@ MAXINT = sys.maxsize
_next_id = 1
@attr.s(frozen=True)
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
method = attr.ib()
"""HTTP method
+1 -3
@@ -217,11 +217,9 @@ class _Sentinel:
def record_event_fetch(self, event_count):
pass
def __nonzero__(self):
def __bool__(self):
return False
__bool__ = __nonzero__ # python3
SENTINEL_CONTEXT = _Sentinel()
+1 -1
@@ -509,7 +509,7 @@ def start_active_span_from_edu(
]
# For some reason jaeger decided not to support the visualization of multiple parent
# spans or explicitely show references. I include the span context as a tag here as
# spans or explicitly show references. I include the span context as a tag here as
# an aid to people debugging but it's really not an ideal solution.
references += _references
+2 -2
@@ -59,7 +59,7 @@ class RegistryProxy:
yield metric
@attr.s(hash=True)
@attr.s(slots=True, hash=True)
class LaterGauge:
name = attr.ib(type=str)
@@ -205,7 +205,7 @@ class InFlightGauge:
all_gauges[self.name] = self
@attr.s(hash=True)
@attr.s(slots=True, hash=True)
class BucketCollector:
"""
Like a Histogram, but allows buckets to be point-in-time instead of
+50 -28
@@ -42,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.streams.config import PaginationConfig
from synapse.types import Collection, StreamToken, UserID
from synapse.types import Collection, RoomStreamToken, StreamToken, UserID
from synapse.util.async_helpers import ObservableDeferred, timeout_deferred
from synapse.util.metrics import Measure
from synapse.visibility import filter_events_for_client
@@ -112,7 +112,9 @@ class _NotifierUserStream:
with PreserveLoggingContext():
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notify(self, stream_key: str, stream_id: int, time_now_ms: int):
def notify(
self, stream_key: str, stream_id: Union[int, RoomStreamToken], time_now_ms: int,
):
"""Notify any listeners for this user of a new event from an
event source.
Args:
@@ -162,11 +164,9 @@ class _NotifierUserStream:
class EventStreamResult(namedtuple("EventStreamResult", ("events", "tokens"))):
def __nonzero__(self):
def __bool__(self):
return bool(self.events)
__bool__ = __nonzero__ # python3
class Notifier:
""" This class is responsible for notifying any listeners when there are
@@ -187,7 +187,7 @@ class Notifier:
self.store = hs.get_datastore()
self.pending_new_room_events = (
[]
) # type: List[Tuple[int, EventBase, Collection[Union[str, UserID]]]]
) # type: List[Tuple[int, EventBase, Collection[UserID]]]
# Called when there are new things to stream over replication
self.replication_callbacks = [] # type: List[Callable[[], None]]
@@ -198,6 +198,7 @@ class Notifier:
self.clock = hs.get_clock()
self.appservice_handler = hs.get_application_service_handler()
self._pusher_pool = hs.get_pusherpool()
self.federation_sender = None
if hs.should_send_federation():
@@ -247,7 +248,7 @@ class Notifier:
event: EventBase,
room_stream_id: int,
max_room_stream_id: int,
extra_users: Collection[Union[str, UserID]] = [],
extra_users: Collection[UserID] = [],
):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
@@ -274,47 +275,68 @@ class Notifier:
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
users = set() # type: Set[UserID]
rooms = set() # type: Set[str]
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append(
(room_stream_id, event, extra_users)
)
else:
self._on_new_room_event(event, room_stream_id, extra_users)
if (
event.type == EventTypes.Member
and event.membership == Membership.JOIN
):
self._user_joined_room(event.state_key, event.room_id)
users.update(extra_users)
rooms.add(event.room_id)
if users or rooms:
self.on_new_event(
"room_key",
RoomStreamToken(None, max_room_stream_id),
users=users,
rooms=rooms,
)
self._on_updated_room_token(max_room_stream_id)
def _on_updated_room_token(self, max_room_stream_id: int):
"""Poke services that might care that the room position has been
updated.
"""
def _on_new_room_event(
self,
event: EventBase,
room_stream_id: int,
extra_users: Collection[Union[str, UserID]] = [],
):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
run_as_background_process(
"notify_app_services", self._notify_app_services, room_stream_id
"_notify_app_services", self._notify_app_services, max_room_stream_id
)
run_as_background_process(
"_notify_pusher_pool", self._notify_pusher_pool, max_room_stream_id
)
if self.federation_sender:
self.federation_sender.notify_new_events(room_stream_id)
self.federation_sender.notify_new_events(max_room_stream_id)
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
self._user_joined_room(event.state_key, event.room_id)
self.on_new_event(
"room_key", room_stream_id, users=extra_users, rooms=[event.room_id]
)
async def _notify_app_services(self, room_stream_id: int):
async def _notify_app_services(self, max_room_stream_id: int):
try:
await self.appservice_handler.notify_interested_services(room_stream_id)
await self.appservice_handler.notify_interested_services(max_room_stream_id)
except Exception:
logger.exception("Error notifying application services of event")
async def _notify_pusher_pool(self, max_room_stream_id: int):
try:
await self._pusher_pool.on_new_notifications(max_room_stream_id)
except Exception:
logger.exception("Error pusher pool of event")
def on_new_event(
self,
stream_key: str,
new_token: int,
users: Collection[Union[str, UserID]] = [],
new_token: Union[int, RoomStreamToken],
users: Collection[UserID] = [],
rooms: Collection[str] = [],
):
""" Used to inform listeners that something has happened event wise.
+1 -1
@@ -184,7 +184,7 @@ class PusherPool:
)
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
async def on_new_notifications(self, max_stream_id):
async def on_new_notifications(self, max_stream_id: int):
if not self.pushers:
# nothing to do here.
return
+1 -3
@@ -33,7 +33,7 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
class ReplicationEndpoint:
class ReplicationEndpoint(metaclass=abc.ABCMeta):
"""Helper base class for defining new replication HTTP endpoints.
This creates an endpoint under `/_synapse/replication/:NAME/:PATH_ARGS..`
@@ -72,8 +72,6 @@ class ReplicationEndpoint:
is received.
"""
__metaclass__ = abc.ABCMeta
NAME = abc.abstractproperty() # type: str # type: ignore
PATH_ARGS = abc.abstractproperty() # type: Tuple[str, ...] # type: ignore
METHOD = "POST"
+9 -3
@@ -65,10 +65,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
self.federation_handler = hs.get_handlers().federation_handler
@staticmethod
async def _serialize_payload(store, event_and_contexts, backfilled):
async def _serialize_payload(store, room_id, event_and_contexts, backfilled):
"""
Args:
store
room_id (str)
event_and_contexts (list[tuple[FrozenEvent, EventContext]])
backfilled (bool): Whether or not the events are the result of
backfilling
@@ -88,7 +89,11 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
}
)
payload = {"events": event_payloads, "backfilled": backfilled}
payload = {
"events": event_payloads,
"backfilled": backfilled,
"room_id": room_id,
}
return payload
@@ -96,6 +101,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
with Measure(self.clock, "repl_fed_send_events_parse"):
content = parse_json_object_from_request(request)
room_id = content["room_id"]
backfilled = content["backfilled"]
event_payloads = content["events"]
@@ -120,7 +126,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
logger.info("Got %d events from federation", len(event_and_contexts))
max_stream_id = await self.federation_handler.persist_events_and_notify(
event_and_contexts, backfilled
room_id, event_and_contexts, backfilled
)
return 200, {"max_stream_id": max_stream_id}
+3 -6
@@ -29,6 +29,7 @@ from synapse.replication.tcp.streams.events import (
EventsStreamEventRow,
EventsStreamRow,
)
from synapse.types import UserID
from synapse.util.async_helpers import timeout_deferred
from synapse.util.metrics import Measure
@@ -98,7 +99,6 @@ class ReplicationDataHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.pusher_pool = hs.get_pusherpool()
self.notifier = hs.get_notifier()
self._reactor = hs.get_reactor()
self._clock = hs.get_clock()
@@ -148,15 +148,12 @@ class ReplicationDataHandler:
if event.rejected_reason:
continue
extra_users = () # type: Tuple[str, ...]
extra_users = () # type: Tuple[UserID, ...]
if event.type == EventTypes.Member:
extra_users = (event.state_key,)
extra_users = (UserID.from_string(event.state_key),)
max_token = self.store.get_room_max_stream_ordering()
self.notifier.on_new_room_event(event, token, max_token, extra_users)
max_token = self.store.get_room_max_stream_ordering()
await self.pusher_pool.on_new_notifications(max_token)
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
+1 -1
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
if isinstance(stream, (EventsStream, BackfillStream)):
# Only add EventStream and BackfillStream as a source on the
# instance in charge of event persistence.
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
self._streams_to_replicate.append(stream)
continue
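This membership test (repeated in the storage layer later in this diff) reflects the config change behind event-persister sharding: `writers.events` is now a list of instance names rather than a single name. A trivial sketch with illustrative instance names:

```python
writers_events = ["event_persister1", "event_persister2"]

# Before sharding: config.writers.events == hs.get_instance_name()
# After sharding:  hs.get_instance_name() in config.writers.events
assert "event_persister2" in writers_events
```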
+1 -1
@@ -93,7 +93,7 @@ class ReplicationStreamer:
"""
if not self.command_handler.connected():
# Don't bother if nothing is listening. We still need to advance
# the stream tokens otherwise they'll fall beihind forever
# the stream tokens otherwise they'll fall behind forever
for stream in self.streams:
stream.discard_updates_and_advance()
return
+2 -2
@@ -383,7 +383,7 @@ class CachesStream(Stream):
the cache on the workers
"""
@attr.s
@attr.s(slots=True)
class CachesStreamRow:
"""Stream to inform workers they should invalidate their cache.
@@ -441,7 +441,7 @@ class DeviceListsStream(Stream):
told about a device update.
"""
@attr.s
@attr.s(slots=True)
class DeviceListsStreamRow:
entity = attr.ib(type=str)
+2 -2
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
import attr
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
from ._base import Stream, StreamUpdateResult, Token
"""Handling of the 'events' replication stream
@@ -117,7 +117,7 @@ class EventsStream(Stream):
self._store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
current_token_without_instance(self._store.get_current_events_token),
self._store._stream_id_gen.get_current_token_for_writer,
self._update_function,
)
-52
@@ -1,52 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>SSO login error</title>
</head>
<body>
{# a 403 means we have actively rejected their login #}
{% if code == 403 %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ msg }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
<p>
Try logging in again from your Matrix client and if the problem persists
please contact the server's administrator.
</p>
<script type="text/javascript">
// Error handling to support Auth0 errors that we might get through a GET request
// to the validation endpoint. If an error is provided, it's either going to be
// located in the query string or in a query string-like URI fragment.
// We try to locate the error from any of these two locations, but if we can't
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
// window.location.searchParams isn't always defined when
// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
// Replace the # with a ? so that URLSearchParams does the right thing and
// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
// We might end up with no error in the URL, so we need to check if we have one
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerText = errorDesc;
}
</script>
{% endif %}
</body>
</html>
+39 -4
@@ -5,14 +5,49 @@
<title>SSO error</title>
</head>
<body>
<p>Oops! Something went wrong during authentication.</p>
{# If an error of unauthorised is returned it means we have actively rejected their login #}
{% if error == "unauthorised" %}
<p>You are not allowed to log in here.</p>
{% else %}
<p>
There was an error during authentication:
</p>
<div id="errormsg" style="margin:20px 80px">{{ error_description }}</div>
<p>
If you are seeing this page after clicking a link sent to you via email, make
sure you only click the confirmation link once, and that you open the
validation link in the same client you're logging in from.
</p>
<p>
Try logging in again from your Matrix client and if the problem persists
please contact the server's administrator.
</p>
<p>Error: <code>{{ error }}</code></p>
{% if error_description %}
<pre><code>{{ error_description }}</code></pre>
{% endif %}
<script type="text/javascript">
// Error handling to support Auth0 errors that we might get through a GET request
// to the validation endpoint. If an error is provided, it's either going to be
// located in the query string or in a query string-like URI fragment.
// We try to locate the error from any of these two locations, but if we can't
// we just don't print anything specific.
let searchStr = "";
if (window.location.search) {
// window.location.searchParams isn't always defined when
// window.location.search is, so it's more reliable to parse the latter.
searchStr = window.location.search;
} else if (window.location.hash) {
// Replace the # with a ? so that URLSearchParams does the right thing and
// doesn't parse the first parameter incorrectly.
searchStr = window.location.hash.replace("#", "?");
}
// We might end up with no error in the URL, so we need to check if we have one
// to print one.
let errorDesc = new URLSearchParams(searchStr).get("error_description")
if (errorDesc) {
document.getElementById("errormsg").innerText = errorDesc;
}
</script>
{% endif %}
</body>
</html>
+2 -2
@@ -16,13 +16,13 @@
import logging
import platform
import re
import synapse
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.server import JsonResource
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
historical_admin_path_patterns,
)
@@ -61,7 +61,7 @@ logger = logging.getLogger(__name__)
class VersionServlet(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v1/server_version$"),)
PATTERNS = admin_patterns("/server_version$")
def __init__(self, hs):
self.res = {
+2 -2
@@ -44,7 +44,7 @@ def historical_admin_path_patterns(path_regex):
]
def admin_patterns(path_regex: str):
def admin_patterns(path_regex: str, version: str = "v1"):
"""Returns the list of patterns for an admin endpoint
Args:
@@ -54,7 +54,7 @@ def admin_patterns(path_regex: str):
Returns:
A list of regex patterns.
"""
admin_prefix = "^/_synapse/admin/v1"
admin_prefix = "^/_synapse/admin/" + version
patterns = [re.compile(admin_prefix + path_regex)]
return patterns
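With the optional `version` argument, one helper now serves both admin API versions. A short usage sketch of the resulting patterns, mirroring the helper above:

```python
import re

def admin_patterns(path_regex: str, version: str = "v1"):
    return [re.compile("^/_synapse/admin/" + version + path_regex)]

assert admin_patterns("/server_version$")[0].match(
    "/_synapse/admin/v1/server_version"
)
assert admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")[0].match(
    "/_synapse/admin/v2/users/@alice:example.com/devices"
)
```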
+5 -10
@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -21,7 +20,7 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin._base import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
logger = logging.getLogger(__name__)
@@ -32,10 +31,8 @@ class DeviceRestServlet(RestServlet):
Get, update or delete the given user's device
"""
PATTERNS = (
re.compile(
"^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$"
),
PATTERNS = admin_patterns(
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
@@ -98,7 +95,7 @@ class DevicesRestServlet(RestServlet):
Retrieve the given user's devices
"""
PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/devices$"),)
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
"""
@@ -131,9 +128,7 @@ class DeleteDevicesRestServlet(RestServlet):
key which lists the device_ids to delete.
"""
PATTERNS = (
re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]*)/delete_devices$"),
)
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
self.hs = hs
+2 -3
@@ -12,14 +12,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from synapse.http.servlet import (
RestServlet,
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
class PurgeRoomServlet(RestServlet):
@@ -35,7 +34,7 @@ class PurgeRoomServlet(RestServlet):
{}
"""
PATTERNS = (re.compile("^/_synapse/admin/v1/purge_room$"),)
PATTERNS = admin_patterns("/purge_room$")
def __init__(self, hs):
"""
+4 -5
@@ -12,8 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.http.servlet import (
@@ -22,6 +20,7 @@ from synapse.http.servlet import (
parse_json_object_from_request,
)
from synapse.rest.admin import assert_requester_is_admin
from synapse.rest.admin._base import admin_patterns
from synapse.rest.client.transactions import HttpTransactionCache
from synapse.types import UserID
@@ -56,13 +55,13 @@ class SendServerNoticeServlet(RestServlet):
self.snm = hs.get_server_notices_manager()
def register(self, json_resource):
PATTERN = "^/_synapse/admin/v1/send_server_notice"
PATTERN = "/send_server_notice"
json_resource.register_paths(
"POST", (re.compile(PATTERN + "$"),), self.on_POST, self.__class__.__name__
"POST", admin_patterns(PATTERN + "$"), self.on_POST, self.__class__.__name__
)
json_resource.register_paths(
"PUT",
(re.compile(PATTERN + "/(?P<txn_id>[^/]*)$"),),
admin_patterns(PATTERN + "/(?P<txn_id>[^/]*)$"),
self.on_PUT,
self.__class__.__name__,
)
+4 -4
@@ -15,7 +15,6 @@
import hashlib
import hmac
import logging
import re
from http import HTTPStatus
from synapse.api.constants import UserTypes
@@ -29,6 +28,7 @@ from synapse.http.servlet import (
parse_string,
)
from synapse.rest.admin._base import (
admin_patterns,
assert_requester_is_admin,
assert_user_is_admin,
historical_admin_path_patterns,
@@ -60,7 +60,7 @@ class UsersRestServlet(RestServlet):
class UsersRestServletV2(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v2/users$"),)
PATTERNS = admin_patterns("/users$", "v2")
"""Get request to list all local users.
This needs user to have administrator access in Synapse.
@@ -105,7 +105,7 @@ class UsersRestServletV2(RestServlet):
class UserRestServletV2(RestServlet):
PATTERNS = (re.compile("^/_synapse/admin/v2/users/(?P<user_id>[^/]+)$"),)
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)$", "v2")
"""Get request to list user details.
This needs user to have administrator access in Synapse.
@@ -642,7 +642,7 @@ class UserAdminServlet(RestServlet):
{}
"""
PATTERNS = (re.compile("^/_synapse/admin/v1/users/(?P<user_id>[^/]*)/admin$"),)
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
self.hs = hs
+1 -1
@@ -329,7 +329,7 @@ class DeactivateAccountRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
# allow ASes to dectivate their own users
# allow ASes to deactivate their own users
if requester.app_service:
await self._deactivate_account_handler.deactivate_account(
requester.user.to_string(), erase
+8 -5
@@ -438,11 +438,14 @@ class RegisterRestServlet(RestServlet):
access_token = self.auth.get_access_token_from_request(request)
if isinstance(desired_username, str):
result = await self._do_appservice_registration(
desired_username, password, desired_display_name, access_token, body
)
return 200, result # we throw for non 200 responses
if not isinstance(desired_username, str):
raise SynapseError(400, "Desired Username is missing or not a string")
result = await self._do_appservice_registration(
desired_username, password, desired_display_name, access_token, body
)
return 200, result
# == Normal User Registration == (everyone else)
if not self._registration_enabled:
+1 -1
@@ -35,7 +35,7 @@ class RemoteKey(DirectServeJsonResource):
Supports individual GET APIs and a bulk query POST API.
Requsts:
Requests:
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
@@ -102,7 +102,7 @@ for endpoint, globs in _oembed_globs.items():
_oembed_patterns[re.compile(pattern)] = endpoint
@attr.s
@attr.s(slots=True)
class OEmbedResult:
# Either HTML content or URL must be provided.
html = attr.ib(type=Optional[str])
+2 -2
@@ -83,7 +83,7 @@ class Thumbnailer:
Args:
max_width: The largest possible width.
max_height: The larget possible height.
max_height: The largest possible height.
"""
if max_width * self.height < max_height * self.width:
@@ -117,7 +117,7 @@ class Thumbnailer:
Args:
max_width: The largest possible width.
max_height: The larget possible height.
max_height: The largest possible height.
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
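For context on the comparison used by the thumbnailer: `max_width * self.height < max_height * self.width` is the cross-multiplied form of `max_width / self.width < max_height / self.height`, i.e. "does the width limit bite first?", which sidesteps floating-point division. A sketch of the scaling decision (function name illustrative):

```python
def fit_within(width, height, max_width, max_height):
    if max_width * height < max_height * width:
        # Width is the tighter constraint: clamp it, derive the height.
        return max_width, (height * max_width) // width
    # Otherwise clamp the height and derive the width.
    return (width * max_height) // height, max_height

assert fit_within(1000, 500, 100, 100) == (100, 50)
```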
+4 -12
@@ -13,10 +13,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.python import failure
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeHtmlResource, return_html_error
from synapse.http.server import DirectServeHtmlResource
class SAML2ResponseResource(DirectServeHtmlResource):
@@ -27,21 +25,15 @@ class SAML2ResponseResource(DirectServeHtmlResource):
def __init__(self, hs):
super().__init__()
self._saml_handler = hs.get_saml_handler()
self._error_html_template = hs.config.saml2.saml2_error_html_template
async def _async_render_GET(self, request):
# We're not expecting any GET request on that resource if everything goes right,
# but some IdPs sometimes end up responding with a 302 redirect on this endpoint.
# In this case, just tell the user that something went wrong and they should
# try to authenticate again.
f = failure.Failure(
SynapseError(400, "Unexpected GET request on /saml2/authn_response")
self._saml_handler._render_error(
request, "unexpected_get", "Unexpected GET request on /saml2/authn_response"
)
return_html_error(f, request, self._error_html_template)
async def _async_render_POST(self, request):
try:
await self._saml_handler.handle_saml_response(request)
except Exception:
f = failure.Failure()
return_html_error(f, request, self._error_html_template)
await self._saml_handler.handle_saml_response(request)
+1 -1
@@ -678,7 +678,7 @@ def resolve_events_with_store(
)
@attr.s
@attr.s(slots=True)
class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database
in well defined way.
+4 -1
@@ -47,6 +47,9 @@ class Storage:
# interfaces.
self.main = stores.main
self.persistence = EventsPersistenceStorage(hs, stores)
self.purge_events = PurgeEventsStorage(hs, stores)
self.state = StateGroupStorage(hs, stores)
self.persistence = None
if stores.persist_events:
self.persistence = EventsPersistenceStorage(hs, stores)
+1 -1
@@ -75,7 +75,7 @@ class Databases:
# If we're on a process that can persist events also
# instantiate a `PersistEventsStore`
if hs.config.worker.writers.events == hs.get_instance_name():
if hs.get_instance_name() in hs.config.worker.writers.events:
persist_events = PersistEventsStore(hs, database, main)
if "state" in database_config.databases:
@@ -29,15 +29,13 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
class AccountDataWorkerStore(SQLBaseStore):
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class AccountDataWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_account_data_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
account_max = self.get_max_account_data_stream_id()
self._account_data_stream_cache = StreamChangeCache(
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
from synapse.handlers.e2e_keys import SignatureListItem
@attr.s
@attr.s(slots=True)
class DeviceKeyLookupResult:
"""The type returned by get_e2e_device_keys_and_signatures"""
@@ -438,7 +438,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"""
if stream_ordering <= self.stream_ordering_month_ago:
raise StoreError(400, "stream_ordering too old")
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
sql = """
SELECT event_id FROM stream_ordering_to_exterm
@@ -969,7 +969,7 @@ def _action_has_highlight(actions):
return False
@attr.s
@attr.s(slots=True)
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
+23 -10
@@ -32,7 +32,7 @@ from synapse.logging.utils import log_function
from synapse.storage._base import db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool, LoggingTransaction
from synapse.storage.databases.main.search import SearchEntry
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import StateMap, get_domain_from_id
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
@@ -97,18 +97,21 @@ class PersistEventsStore:
self.store = main_data_store
self.database_engine = db.engine
self._clock = hs.get_clock()
self._instance_name = hs.get_instance_name()
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
self.is_mine_id = hs.is_mine_id
# Ideally we'd move these ID gens here, unfortunately some other ID
# generators are chained off them so doing so is a bit of a PITA.
self._backfill_id_gen = self.store._backfill_id_gen # type: StreamIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: StreamIdGenerator
self._backfill_id_gen = (
self.store._backfill_id_gen
) # type: MultiWriterIdGenerator
self._stream_id_gen = self.store._stream_id_gen # type: MultiWriterIdGenerator
# This should only exist on instances that are configured to write
assert (
hs.config.worker.writers.events == hs.get_instance_name()
hs.get_instance_name() in hs.config.worker.writers.events
), "Can only instantiate EventsStore on master"
async def _persist_events_and_state_updates(
@@ -213,7 +216,7 @@ class PersistEventsStore:
Returns:
Filtered event ids
"""
results = []
results = [] # type: List[str]
def _get_events_which_are_prevs_txn(txn, batch):
sql = """
@@ -631,7 +634,9 @@ class PersistEventsStore:
)
@classmethod
def _filter_events_and_contexts_for_duplicates(cls, events_and_contexts):
def _filter_events_and_contexts_for_duplicates(
cls, events_and_contexts: List[Tuple[EventBase, EventContext]]
) -> List[Tuple[EventBase, EventContext]]:
"""Ensure that we don't have the same event twice.
Pick the earliest non-outlier if there is one, else the earliest one.
@@ -641,7 +646,9 @@ class PersistEventsStore:
Returns:
list[(EventBase, EventContext)]: filtered list
"""
new_events_and_contexts = OrderedDict()
new_events_and_contexts = (
OrderedDict()
) # type: OrderedDict[str, Tuple[EventBase, EventContext]]
for event, context in events_and_contexts:
prev_event_context = new_events_and_contexts.get(event.event_id)
if prev_event_context:
@@ -655,7 +662,12 @@ class PersistEventsStore:
new_events_and_contexts[event.event_id] = (event, context)
return list(new_events_and_contexts.values())
def _update_room_depths_txn(self, txn, events_and_contexts, backfilled):
def _update_room_depths_txn(
self,
txn,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
):
"""Update min_depth for each room
Args:
@@ -664,7 +676,7 @@ class PersistEventsStore:
we are persisting
backfilled (bool): True if the events were backfilled
"""
depth_updates = {}
depth_updates = {} # type: Dict[str, int]
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
txn.call_after(self.store._invalidate_get_event_cache, event.event_id)
@@ -800,6 +812,7 @@ class PersistEventsStore:
table="events",
values=[
{
"instance_name": self._instance_name,
"stream_ordering": event.internal_metadata.stream_ordering,
"topological_ordering": event.depth,
"depth": event.depth,
@@ -1436,7 +1449,7 @@ class PersistEventsStore:
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
events_by_room = {} # type: Dict[str, List[EventBase]]
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
+47 -21
@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import itertools
import logging
import threading
@@ -42,7 +40,8 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.util.iterutils import batch_iter
@@ -78,27 +77,54 @@ class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.writers.events == hs.get_instance_name():
# We are the process in charge of generating stream ids for events,
# so instantiate ID generators based on the database
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
if isinstance(database.engine, PostgresEngine):
# If we're using Postgres then we can use `MultiWriterIdGenerator`
# regardless of whether this process writes to the streams or not.
self._stream_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_stream_seq",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
self._backfill_id_gen = MultiWriterIdGenerator(
db_conn=db_conn,
db=database,
instance_name=hs.get_instance_name(),
table="events",
instance_column="instance_name",
id_column="stream_ordering",
sequence_name="events_backfill_stream_seq",
positive=False,
)
else:
# Another process is in charge of persisting events and generating
# stream IDs: rely on the replication streams to let us know which
# IDs we can process.
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
# We shouldn't be running in worker mode with SQLite, but it's useful
# to support it for unit tests.
#
# If this process is the writer then we need to use
# `StreamIdGenerator`, otherwise we use `SlavedIdTracker` which gets
# updated over replication. (Multiple writers are not supported for
# SQLite).
if hs.get_instance_name() in hs.config.worker.writers.events:
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = StreamIdGenerator(
db_conn,
"events",
"stream_ordering",
step=-1,
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
)
else:
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering"
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
self._get_event_cache = Cache(
"*getEvent*",
+3 -4
@@ -61,6 +61,8 @@ def _load_rules(rawrules, enabled_map, use_new_defaults=False):
return rules
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class PushRulesWorkerStore(
ApplicationServiceWorkerStore,
ReceiptsWorkerStore,
@@ -68,15 +70,12 @@ class PushRulesWorkerStore(
RoomMemberWorkerStore,
EventsWorkerStore,
SQLBaseStore,
metaclass=abc.ABCMeta,
):
"""This is an abstract base class where subclasses must implement
`get_max_push_rules_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
+3 -5
@@ -31,15 +31,13 @@ from synapse.util.caches.stream_change_cache import StreamChangeCache
logger = logging.getLogger(__name__)
class ReceiptsWorkerStore(SQLBaseStore):
# The ABCMeta metaclass ensures that it cannot be instantiated without
# the abstract methods being implemented.
class ReceiptsWorkerStore(SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_max_receipt_stream_id` which can be called in the initializer.
"""
# This ABCMeta metaclass ensures that we cannot be instantiated without
# the abstract methods being implemented.
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
super(ReceiptsWorkerStore, self).__init__(database, db_conn, hs)
+2 -1
@@ -104,7 +104,8 @@ class RoomWorkerStore(SQLBaseStore):
curr.local_users_in_room AS joined_local_members, rooms.room_version AS version,
rooms.creator, state.encryption, state.is_federatable AS federatable,
rooms.is_public AS public, state.join_rules, state.guest_access,
state.history_visibility, curr.current_state_events AS state_events
state.history_visibility, curr.current_state_events AS state_events,
state.avatar, state.topic
FROM rooms
LEFT JOIN room_stats_state state USING (room_id)
LEFT JOIN room_stats_current curr USING (room_id)
@@ -0,0 +1,16 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE events ADD COLUMN instance_name TEXT;
@@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
SELECT setval('events_stream_seq', (
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
));
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
SELECT setval('events_backfill_stream_seq', (
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
));
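Backfilled events carry negative stream orderings, which is why the two sequences are seeded differently: the forward sequence from the largest `stream_ordering`, the backfill sequence from the negated smallest one, with `COALESCE(..., 1)` covering an empty table. A sketch of the equivalent computation:

```python
def seed_sequences(stream_orderings):
    # In the spirit of the two setval() calls above.
    forward = max(stream_orderings, default=1)
    backfill = -min(stream_orderings) if stream_orderings else 1
    return forward, backfill

assert seed_sequences([-3, -1, 5, 7]) == (7, 3)
```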
+33 -46
@@ -259,14 +259,12 @@ def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
return " AND ".join(clauses), args
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
"""This is an abstract base class where subclasses must implement
`get_room_max_stream_ordering` and `get_room_min_stream_ordering`
which can be called in the initializer.
"""
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
@@ -310,11 +308,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_rooms(
self,
room_ids: Collection[str],
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Dict[str, Tuple[List[EventBase], str]]:
) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -333,9 +331,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
- list of recent events in the room
- stream ordering key for the start of the chunk of events returned.
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
room_ids = self._events_stream_cache.get_entities_changed(room_ids, from_id)
room_ids = self._events_stream_cache.get_entities_changed(
room_ids, from_key.stream
)
if not room_ids:
return {}
@@ -364,16 +362,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
self, room_ids: Collection[str], from_key: RoomStreamToken
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids
from_key: The room_key portion of a StreamToken
"""
from_id = RoomStreamToken.parse_stream_token(from_key).stream
from_id = from_key.stream
return {
room_id
for room_id in room_ids
@@ -383,11 +377,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
async def get_room_events_stream_for_room(
self,
room_id: str,
from_key: str,
to_key: str,
from_key: RoomStreamToken,
to_key: RoomStreamToken,
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get new room events in stream ordering since `from_key`.
Args:
@@ -408,8 +402,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if from_key == to_key:
return [], from_key
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
@@ -441,7 +435,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
ret.reverse()
if rows:
key = "s%d" % min(r.stream_ordering for r in rows)
key = RoomStreamToken(None, min(r.stream_ordering for r in rows))
else:
# Assume we didn't get anything because there was nothing to
# get.
@@ -450,10 +444,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
from_id = from_key.stream
to_id = to_key.stream
if from_key == to_key:
return []
@@ -491,8 +485,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret
async def get_recent_events_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[EventBase], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -518,8 +512,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return (events, token)
async def get_recent_event_ids_for_room(
self, room_id: str, limit: int, end_token: str
) -> Tuple[List[_EventDictReturn], str]:
self, room_id: str, limit: int, end_token: RoomStreamToken
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Get the most recent events in the room in topological ordering.
Args:
@@ -535,13 +529,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if limit == 0:
return [], end_token
parsed_end_token = RoomStreamToken.parse(end_token)
rows, token = await self.db_pool.runInteraction(
"get_recent_event_ids_for_room",
self._paginate_room_events_txn,
room_id,
from_token=parsed_end_token,
from_token=end_token,
limit=limit,
)
@@ -619,17 +611,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
async def get_stream_token_for_event(self, event_id: str) -> RoomStreamToken:
"""The stream token for an event
Args:
event_id: The id of the event to look up a stream token for.
Raises:
StoreError if the event wasn't in the database.
Returns:
A "s%d" stream token.
A stream token.
"""
stream_id = await self.get_stream_id_for_event(event_id)
return "s%d" % (stream_id,)
return RoomStreamToken(None, stream_id)
async def get_topological_token_for_event(self, event_id: str) -> str:
"""The stream token for an event
@@ -954,7 +946,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[_EventDictReturn], str]:
) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1054,17 +1046,17 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token
return rows, str(next_token)
return rows, next_token
async def paginate_room_events(
self,
room_id: str,
from_key: str,
to_key: Optional[str] = None,
from_key: RoomStreamToken,
to_key: Optional[RoomStreamToken] = None,
direction: str = "b",
limit: int = -1,
event_filter: Optional[Filter] = None,
) -> Tuple[List[EventBase], str]:
) -> Tuple[List[EventBase], RoomStreamToken]:
"""Returns list of events before or after a given token.
Args:
@@ -1083,17 +1075,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
and `to_key`).
"""
parsed_from_key = RoomStreamToken.parse(from_key)
parsed_to_key = None
if to_key:
parsed_to_key = RoomStreamToken.parse(to_key)
rows, token = await self.db_pool.runInteraction(
"paginate_room_events",
self._paginate_room_events_txn,
room_id,
parsed_from_key,
parsed_to_key,
from_key,
to_key,
direction,
limit,
event_filter,
+42 -1
@@ -15,7 +15,7 @@
import logging
from collections import namedtuple
from typing import Iterable, Optional, Tuple
from typing import Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -371,3 +371,44 @@ class TransactionStore(SQLBaseStore):
values={"last_successful_stream_ordering": last_successful_stream_ordering},
desc="set_last_successful_stream_ordering",
)
async def get_catch_up_room_event_ids(
self, destination: str, last_successful_stream_ordering: int,
) -> List[str]:
"""
Returns at most 50 event IDs, in ascending stream_ordering, corresponding
to the oldest events that have not yet been sent to the destination.
Args:
destination: the destination in question
last_successful_stream_ordering: the stream_ordering of the
most-recently successfully-transmitted event to the destination
Returns:
list of event_ids
"""
return await self.db_pool.runInteraction(
"get_catch_up_room_event_ids",
self._get_catch_up_room_event_ids_txn,
destination,
last_successful_stream_ordering,
)
@staticmethod
def _get_catch_up_room_event_ids_txn(
txn, destination: str, last_successful_stream_ordering: int,
) -> List[str]:
q = """
SELECT event_id FROM destination_rooms
JOIN events USING (stream_ordering)
WHERE destination = ?
AND stream_ordering > ?
ORDER BY stream_ordering
LIMIT 50
"""
txn.execute(
q, (destination, last_successful_stream_ordering),
)
event_ids = [row[0] for row in txn]
return event_ids
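This query backs the new federation catch-up loop: after an outage, the sender repeatedly fetches the oldest (at most 50) unsent events for a destination and advances its high-water mark until the query comes back empty. A hypothetical caller sketch; `send_catch_up_events` is an assumed helper, not Synapse's API:

```python
async def catch_up_destination(store, destination, last_stream_ordering):
    while True:
        event_ids = await store.get_catch_up_room_event_ids(
            destination, last_stream_ordering
        )
        if not event_ids:
            break  # nothing left to send: fully caught up
        # Assumed helper: transmits the events and returns the
        # stream_ordering of the newest event successfully sent.
        last_stream_ordering = await send_catch_up_events(
            destination, event_ids
        )
```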
+1 -1
@@ -23,7 +23,7 @@ from synapse.types import JsonDict
from synapse.util import json_encoder, stringutils
@attr.s
@attr.s(slots=True)
class UIAuthSessionData:
session_id = attr.ib(type=str)
# The dictionary from the client root level, not the 'auth' key.

Some files were not shown because too many files have changed in this diff.