Compare commits
35 Commits
michaelkay … 1.7.2
| Author | SHA1 | Date |
|---|---|---|
| | f03c877b32 | |
| | b2db382841 | |
| | 138708608b | |
| | 29794c6bc8 | |
| | 4caab0e95e | |
| | 03d3792f3c | |
| | 4154637834 | |
| | e156c86a7f | |
| | d656e91fc2 | |
| | 5029422530 | |
| | 5ca2cfadc3 | |
| | 6316530366 | |
| | 7f8d8fd2e2 | |
| | a820069549 | |
| | 284e690aa0 | |
| | a29420f9f4 | |
| | 596dd9914d | |
| | bbb75ff6ee | |
| | ff773ff724 | |
| | 83895316d4 | |
| | 6577f2d887 | |
| | 35bbe4ca79 | |
| | 20d5ba16e6 | |
| | be294d6fde | |
| | 4c7b1bb6cc | |
| | 8b9f5c21c3 | |
| | ac87ddb242 | |
| | 487f1bb49d | |
| | bee1982d17 | |
| | ba57a45644 | |
| | bac1578013 | |
| | d046f367a4 | |
| | f5aeea9e89 | |
| | 58fdcbdfe7 | |
| | e2cce15af1 | |
@@ -34,8 +34,33 @@ Device list doesn't change if remote server is down
Remote servers cannot set power levels in rooms without existing powerlevels
Remote servers should reject attempts by non-creators to set the power levels

# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
# new failures as of https://github.com/matrix-org/sytest/pull/753
GET /rooms/:room_id/messages returns a message
GET /rooms/:room_id/messages lazy loads members correctly
Read receipts are sent as events
Only original members of the room can see messages from erased users
Device deletion propagates over federation
If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes
Changing user-signing key notifies local users
Newly updated tags appear in an incremental v2 /sync
Server correctly handles incoming m.device_list_update
Local device key changes get to remote servers with correct prev_id
AS-ghosted users can use rooms via AS
Ghost user must register before joining room
Test that a message is pushed
Invites are pushed
Rooms with aliases are correctly named in pushed
Rooms with names are correctly named in pushed
Rooms with canonical alias are correctly named in pushed
Rooms with many users are correctly pushed
Don't get pushed for rooms you've muted
Rejected events are not pushed
Test that rejected pushers are removed.
Events come down the correct room

# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state
# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603
Presence changes to UNAVAILABLE are reported to remote room members
If remote user leaves room, changes device and rejoins we see update in sync
uploading self-signing key notifies over federation
Inbound federation can receive redacted events
Outbound federation can request missing events
43 CHANGES.md
@@ -1,3 +1,44 @@
Synapse 1.7.2 (2019-12-20)
==========================

This release fixes some regressions introduced in Synapse 1.7.0 and 1.7.1.

Bugfixes
--------

- Fix a regression introduced in Synapse 1.7.1 which caused errors when attempting to backfill rooms over federation. ([\#6576](https://github.com/matrix-org/synapse/issues/6576))
- Fix a bug introduced in Synapse 1.7.0 which caused an error on startup when upgrading from versions before 1.3.0. ([\#6578](https://github.com/matrix-org/synapse/issues/6578))


Synapse 1.7.1 (2019-12-18)
==========================

This release includes several security fixes as well as a fix to a bug exposed by the security fixes. Administrators are encouraged to upgrade as soon as possible.

Security updates
----------------

- Fix a bug which could cause room events to be incorrectly authorized using events from a different room. ([\#6501](https://github.com/matrix-org/synapse/issues/6501), [\#6503](https://github.com/matrix-org/synapse/issues/6503), [\#6521](https://github.com/matrix-org/synapse/issues/6521), [\#6524](https://github.com/matrix-org/synapse/issues/6524), [\#6530](https://github.com/matrix-org/synapse/issues/6530), [\#6531](https://github.com/matrix-org/synapse/issues/6531))
- Fix a bug causing responses to the `/context` client endpoint to not use the pruned version of the event. ([\#6553](https://github.com/matrix-org/synapse/issues/6553))
- Fix a cause of state resets in room versions 2 onwards. ([\#6556](https://github.com/matrix-org/synapse/issues/6556), [\#6560](https://github.com/matrix-org/synapse/issues/6560))

Bugfixes
--------

- Fix a bug which could cause the federation server to incorrectly return errors when handling certain obscure event graphs. ([\#6526](https://github.com/matrix-org/synapse/issues/6526), [\#6527](https://github.com/matrix-org/synapse/issues/6527))


Synapse 1.7.0 (2019-12-13)
==========================

This release changes the default settings so that only local authenticated users can query the server's room directory. See the [upgrade notes](UPGRADE.rst#upgrading-to-v170) for details.

Support for SQLite versions before 3.11 is now deprecated. A future release will refuse to start if used with an SQLite version before 3.11.

Administrators are reminded that SQLite should not be used for production instances. Instructions for migrating to Postgres are available [here](docs/postgres.md). A future release of synapse will, by default, disable federation for servers using SQLite.

No significant changes since 1.7.0rc2.


Synapse 1.7.0rc2 (2019-12-11)
=============================

@@ -76,7 +117,7 @@ Internal Changes
- Add a test scenario to make sure room history purges don't break `/messages` in the future. ([\#6392](https://github.com/matrix-org/synapse/issues/6392))
- Clarifications for the email configuration settings. ([\#6423](https://github.com/matrix-org/synapse/issues/6423))
- Add more tests to the blacklist when running in worker mode. ([\#6429](https://github.com/matrix-org/synapse/issues/6429))
- Refactor data store layer to support multiple databases in the future. ([\#6454](https://github.com/matrix-org/synapse/issues/6454), [\#6464](https://github.com/matrix-org/synapse/issues/6464), [\#6469](https://github.com/matrix-org/synapse/issues/6469), [\#6487](https://github.com/matrix-org/synapse/issues/6487))
- Port synapse.rest.client.v1 to async/await. ([\#6482](https://github.com/matrix-org/synapse/issues/6482))
- Port synapse.rest.client.v2_alpha to async/await. ([\#6483](https://github.com/matrix-org/synapse/issues/6483))
- Port SyncHandler to async/await. ([\#6484](https://github.com/matrix-org/synapse/issues/6484))

@@ -393,4 +393,4 @@ something like the following in their logs::

    2019-09-11 19:32:04,271 - synapse.federation.transport.server - 288 - WARNING - GET-11752 - authenticate_request failed: 401: Invalid signature for server <server> with key ed25519:a_EqML: Unable to verify signature for <server>

This is normally caused by a misconfiguration in your reverse-proxy. See
`<docs/reverse_proxy.rst>`_ and double-check that your settings are correct.
`<docs/reverse_proxy.md>`_ and double-check that your settings are correct.
@@ -1 +0,0 @@
Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)).
@@ -1 +0,0 @@
Prevent redacted events from being returned during message search.
@@ -1 +0,0 @@
Prevent error on trying to search an upgraded room when the server is not in the predecessor room.
@@ -1 +0,0 @@
Add a develop script to generate full SQL schemas.
@@ -1 +0,0 @@
Allow custom SAML username mapping functionality through an external provider plugin.
@@ -1 +0,0 @@
Improve performance of looking up cross-signing keys.
@@ -1 +0,0 @@
Port synapse.handlers.initial_sync to async/await.
@@ -1 +0,0 @@
Refactor get_events_from_store_or_dest to return a dict.
@@ -1 +0,0 @@
Remove redundant code from event authorisation implementation.
@@ -1 +0,0 @@
Move get_state methods into FederationHandler.
@@ -1 +0,0 @@
Port handlers.account_data and handlers.account_validity to async/await.
@@ -1 +0,0 @@
Make `make_deferred_yieldable` work with async/await.
@@ -1 +0,0 @@
Remove `SnapshotCache` in favour of `ResponseCache`.
@@ -1 +0,0 @@
Change phone home stats to not assume there is a single database and report information about the database used by the main data store.
@@ -1 +0,0 @@
Move database config from apps into HomeServer object.
@@ -1 +0,0 @@
Silence mypy errors for files outside those specified.
@@ -1 +0,0 @@
Fix race which occasionally caused deleted devices to reappear.
@@ -1 +0,0 @@
Clean up some logging when handling incoming events over federation.
@@ -1 +0,0 @@
Port some of FederationHandler to async/await.
@@ -1 +0,0 @@
Refactor some code in the event authentication path for clarity.
@@ -1 +0,0 @@
Prevent redacted events from being returned during message search.
@@ -1,2 +0,0 @@
Improve sanity-checking when receiving events over federation.

@@ -1 +0,0 @@
Test more folders against mypy.
@@ -1 +0,0 @@
Adjust the sytest blacklist for worker mode.
18 debian/changelog vendored
@@ -1,3 +1,21 @@
matrix-synapse-py3 (1.7.2) stable; urgency=medium

  * New synapse release 1.7.2.

 -- Synapse Packaging team <packages@matrix.org>  Fri, 20 Dec 2019 10:56:50 +0000

matrix-synapse-py3 (1.7.1) stable; urgency=medium

  * New synapse release 1.7.1.

 -- Synapse Packaging team <packages@matrix.org>  Wed, 18 Dec 2019 09:37:59 +0000

matrix-synapse-py3 (1.7.0) stable; urgency=medium

  * New synapse release 1.7.0.

 -- Synapse Packaging team <packages@matrix.org>  Fri, 13 Dec 2019 10:19:38 +0000

matrix-synapse-py3 (1.6.1) stable; urgency=medium

  * New synapse release 1.6.1.
@@ -1,77 +0,0 @@
# SAML Mapping Providers

A SAML mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SAML response object to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to an SSO service.

As an example, an SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]`, or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.

## Enabling Providers

External mapping providers are provided to Synapse in the form of an external
Python module. Retrieve this module from [PyPI](https://pypi.org) or elsewhere,
then tell Synapse where to look for the handler class by editing the
`saml2_config.user_mapping_provider.module` config option.

`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.

## Building a Custom Mapping Provider

A custom mapping provider must specify the following methods:

* `__init__(self, parsed_config)`
  - Arguments:
    - `parsed_config` - A configuration object that is the return value of the
      `parse_config` method. You should set any configuration options needed by
      the module here.
* `saml_response_to_user_attributes(self, saml_response, failures)`
  - Arguments:
    - `saml_response` - A `saml2.response.AuthnResponse` object to extract user
      information from.
    - `failures` - An `int` that represents the number of times the returned
      mxid localpart mapping has failed. This should be used to create a
      deduplicated mxid localpart, which should be returned instead. For example,
      if this method returns `john.doe` as the value of `mxid_localpart` in the
      returned dict, and that is already taken on the homeserver, this method
      will be called again with the same parameters but with `failures=1`. The
      method should then return a different `mxid_localpart` value, such as
      `john.doe1`.
  - This method must return a dictionary, which will then be used by Synapse
    to build a new user. The following keys are allowed:
    * `mxid_localpart` - Required. The mxid localpart of the new user.
    * `displayname` - The displayname of the new user. If not provided, will
      default to the value of `mxid_localpart`.
* `parse_config(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - A `dict` representing the parsed content of the
      `saml2_config.user_mapping_provider.config` homeserver config option.
      Runs on homeserver startup. Providers should extract any option values
      they need here.
  - Whatever is returned will be passed back to the user mapping provider
    module's `__init__` method during construction.
* `get_saml_attributes(config)`
  - This method should have the `@staticmethod` decoration.
  - Arguments:
    - `config` - An object resulting from a call to `parse_config`.
  - Returns a tuple of two sets. The first set equates to the SAML auth
    response attributes that are required for the module to function, whereas
    the second set consists of those attributes which can be used if available,
    but are not necessary.

## Synapse's Default Provider

Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).
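For illustration, a minimal custom provider implementing the interface described above might look like the following sketch. The class name, the `source_attribute` config option, and the `uid`/`displayName` attribute names are hypothetical examples, not part of Synapse's API; `saml_response.ava` is pysaml2's attribute-value dict on an `AuthnResponse`.

```python
# Hypothetical minimal SAML mapping provider; config option and SAML
# attribute names below are illustrative only.
class SamlMappingProvider:
    def __init__(self, parsed_config):
        # parsed_config is whatever parse_config() returned below.
        self._source_attribute = parsed_config["source_attribute"]

    @staticmethod
    def parse_config(config):
        # Runs at homeserver startup; extract the options we need from
        # saml2_config.user_mapping_provider.config.
        return {"source_attribute": config.get("source_attribute", "uid")}

    @staticmethod
    def get_saml_attributes(config):
        # (required attributes, optional attributes)
        return {"uid"}, {"displayName"}

    def saml_response_to_user_attributes(self, saml_response, failures):
        # saml_response.ava is pysaml2's dict of attribute -> list of values.
        localpart = saml_response.ava[self._source_attribute][0].lower()
        if failures:
            # De-duplicate on retry: e.g. john.doe -> john.doe1.
            localpart += str(failures)
        result = {"mxid_localpart": localpart}
        displaynames = saml_response.ava.get("displayName", [])
        if displaynames:
            result["displayname"] = displaynames[0]
        return result
```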
@@ -1250,58 +1250,33 @@ saml2_config:
  #
  #config_path: "CONFDIR/sp_conf.py"

  # The lifetime of a SAML session. This defines how long a user has to
  # the lifetime of a SAML session. This defines how long a user has to
  # complete the authentication process, if allow_unsolicited is unset.
  # The default is 5 minutes.
  #
  #saml_session_lifetime: 5m

  # An external module can be provided here as a custom solution to
  # mapping attributes returned from a saml provider onto a matrix user.
  # The SAML attribute (after mapping via the attribute maps) to use to derive
  # the Matrix ID from. 'uid' by default.
  #
  user_mapping_provider:
    # The custom module's class. Uncomment to use a custom module.
    #
    #module: mapping_provider.SamlMappingProvider
  #mxid_source_attribute: displayName

    # Custom configuration values for the module. Below options are
    # intended for the built-in provider, they should be changed if
    # using a custom module. This section will be passed as a Python
    # dictionary to the module's `parse_config` method.
    #
    config:
      # The SAML attribute (after mapping via the attribute maps) to use
      # to derive the Matrix ID from. 'uid' by default.
      #
      # Note: This used to be configured by the
      # saml2_config.mxid_source_attribute option. If that is still
      # defined, its value will be used instead.
      #
      #mxid_source_attribute: displayName
  # The mapping system to use for mapping the saml attribute onto a matrix ID.
  # Options include:
  # * 'hexencode' (which maps unpermitted characters to '=xx')
  # * 'dotreplace' (which replaces unpermitted characters with '.').
  # The default is 'hexencode'.
  #
  #mxid_mapping: dotreplace

      # The mapping system to use for mapping the saml attribute onto a
      # matrix ID.
      #
      # Options include:
      # * 'hexencode' (which maps unpermitted characters to '=xx')
      # * 'dotreplace' (which replaces unpermitted characters with
      #   '.').
      # The default is 'hexencode'.
      #
      # Note: This used to be configured by the
      # saml2_config.mxid_mapping option. If that is still defined, its
      # value will be used instead.
      #
      #mxid_mapping: dotreplace

  # In previous versions of synapse, the mapping from SAML attribute to
  # MXID was always calculated dynamically rather than stored in a
  # table. For backwards- compatibility, we will look for user_ids
  # matching such a pattern before creating a new account.
  # In previous versions of synapse, the mapping from SAML attribute to MXID was
  # always calculated dynamically rather than stored in a table. For backwards-
  # compatibility, we will look for user_ids matching such a pattern before
  # creating a new account.
  #
  # This setting controls the SAML attribute which will be used for this
  # backwards-compatibility lookup. Typically it should be 'uid', but if
  # the attribute maps are changed, it may be necessary to change it.
  # backwards-compatibility lookup. Typically it should be 'uid', but if the
  # attribute maps are changed, it may be necessary to change it.
  #
  # The default is 'uid'.
  #
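To make the two `mxid_mapping` styles described above concrete, here is a small illustrative sketch. This is not Synapse's actual implementation; the allowed-character set is a simplification of what Matrix permits in a localpart.

```python
# Illustrative sketch of the 'hexencode' and 'dotreplace' mapping styles.
# Not Synapse's implementation; the allowed set below is a simplification.
import re

ALLOWED = "abcdefghijklmnopqrstuvwxyz0123456789._=-/"
DISALLOWED = re.compile("[^%s]" % re.escape(ALLOWED))

def hexencode(username: str) -> str:
    # Map unpermitted characters to '=xx' (lowercase hex of the byte).
    return DISALLOWED.sub(lambda m: "=%02x" % ord(m.group()), username.lower())

def dotreplace(username: str) -> str:
    # Replace unpermitted characters with '.'.
    return DISALLOWED.sub(".", username.lower())

print(hexencode("John Smith"))   # john=20smith
print(dotreplace("John Smith"))  # john.smith
```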
43 docs/sample_log_config.yaml Normal file
@@ -0,0 +1,43 @@
# Example log config file for synapse.
#
# This is a YAML file containing a standard Python logging configuration
# dictionary. See [1] for details on the valid settings.
#
# [1]: https://docs.python.org/3.7/library/logging.config.html#configuration-dictionary-schema

version: 1

formatters:
    precise:
        format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s - %(message)s'

filters:
    context:
        (): synapse.logging.context.LoggingContextFilter
        request: ""

handlers:
    file:
        class: logging.handlers.RotatingFileHandler
        formatter: precise
        filename: /home/rav/work/synapse/homeserver.log
        maxBytes: 104857600
        backupCount: 10
        filters: [context]
        encoding: utf8
    console:
        class: logging.StreamHandler
        formatter: precise
        filters: [context]

loggers:
    synapse.storage.SQL:
        # beware: increasing this to DEBUG will make synapse log sensitive
        # information such as access tokens.
        level: INFO

root:
    level: INFO
    handlers: [file, console]

disable_existing_loggers: false
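A file like the one above is a standard `logging.config` dictionary schema. As a minimal sketch of how it can be applied (assuming PyYAML and Synapse are installed, since the config references `synapse.logging.context.LoggingContextFilter`, and after adjusting the handler `filename` path for your machine):

```python
# Minimal sketch: load a YAML logging config like the one above and apply
# it with the standard library. Assumes PyYAML; the path is illustrative.
import logging
import logging.config

import yaml

with open("docs/sample_log_config.yaml") as f:
    log_config = yaml.safe_load(f)

logging.config.dictConfig(log_config)
logging.getLogger(__name__).info("logging configured")
```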
@@ -196,7 +196,7 @@ Handles the media repository. It can handle all endpoints starting with:

    /_matrix/media/

And the following regular expressions matching media-specific administration APIs:
... and the following regular expressions matching media-specific administration APIs:

    ^/_synapse/admin/v1/purge_media_cache$
    ^/_synapse/admin/v1/room/.*/media$

@@ -206,6 +206,18 @@ You should also set `enable_media_repo: False` in the shared configuration
file to stop the main synapse running background jobs related to managing the
media repository.

In the `media_repository` worker configuration file, configure the http listener to
expose the `media` resource. For example:

```yaml
worker_listeners:
 - type: http
   port: 8085
   resources:
     - names:
       - media
```

Note this worker cannot be load-balanced: only one instance should be active.

### `synapse.app.client_reader`
2 mypy.ini
@@ -1,7 +1,7 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
follow_imports = silent
follow_imports = normal
check_untyped_defs = True
show_error_codes = True
show_traceback = True
@@ -1,184 +0,0 @@
#!/bin/bash
#
# This script generates SQL files for creating a brand new Synapse DB with the latest
# schema, on both SQLite3 and Postgres.
#
# It does so by having Synapse generate an up-to-date SQLite DB, then running
# synapse_port_db to convert it to Postgres. It then dumps the contents of both.

POSTGRES_HOST="localhost"
POSTGRES_DB_NAME="synapse_full_schema.$$"

SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite"
POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres"

REQUIRED_DEPS=("matrix-synapse" "psycopg2")

usage() {
  echo
  echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]"
  echo
  echo "-p <postgres_username>"
  echo "  Username to connect to local postgres instance. The password will be requested"
  echo "  during script execution."
  echo "-c"
  echo "  CI mode. Enables coverage tracking and prints every command that the script runs."
  echo "-o <path>"
  echo "  Directory to output full schema files to."
  echo "-h"
  echo "  Display this help text."
}

while getopts "p:co:h" opt; do
  case $opt in
    p)
      POSTGRES_USERNAME=$OPTARG
      ;;
    c)
      # Print all commands that are being executed
      set -x

      # Modify required dependencies for coverage
      REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess")

      COVERAGE=1
      ;;
    o)
      command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1)
      OUTPUT_DIR="$(realpath "$OPTARG")"
      ;;
    h)
      usage
      exit
      ;;
    \?)
      echo "ERROR: Invalid option: -$OPTARG" >&2
      usage
      exit
      ;;
  esac
done

# Check that required dependencies are installed
unsatisfied_requirements=()
for dep in "${REQUIRED_DEPS[@]}"; do
  pip show "$dep" --quiet || unsatisfied_requirements+=("$dep")
done
if [ ${#unsatisfied_requirements} -ne 0 ]; then
  echo "Please install the following python packages: ${unsatisfied_requirements[*]}"
  exit 1
fi

if [ -z "$POSTGRES_USERNAME" ]; then
  echo "No postgres username supplied"
  usage
  exit 1
fi

if [ -z "$OUTPUT_DIR" ]; then
  echo "No output directory supplied"
  usage
  exit 1
fi

# Create the output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"

read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD
echo ""

# Exit immediately if a command fails
set -e

# cd to root of the synapse directory
cd "$(dirname "$0")/.."

# Create temporary SQLite and Postgres homeserver db configs and key file
TMPDIR=$(mktemp -d)
KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path
SQLITE_CONFIG=$TMPDIR/sqlite.conf
SQLITE_DB=$TMPDIR/homeserver.db
POSTGRES_CONFIG=$TMPDIR/postgres.conf

# Ensure these files are deleted on script exit
trap 'rm -rf $TMPDIR' EXIT

cat > "$SQLITE_CONFIG" <<EOF
server_name: "test"

signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"

report_stats: false

database:
  name: "sqlite3"
  args:
    database: "$SQLITE_DB"

# Suppress the key server warning.
trusted_key_servers: []
EOF

cat > "$POSTGRES_CONFIG" <<EOF
server_name: "test"

signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"

report_stats: false

database:
  name: "psycopg2"
  args:
    user: "$POSTGRES_USERNAME"
    host: "$POSTGRES_HOST"
    password: "$POSTGRES_PASSWORD"
    database: "$POSTGRES_DB_NAME"

# Suppress the key server warning.
trusted_key_servers: []
EOF

# Generate the server's signing key.
echo "Generating SQLite3 db schema..."
python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"

# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
scripts-dev/update_database --database-config "$SQLITE_CONFIG"

# Create the PostgreSQL database.
echo "Creating postgres database..."
createdb $POSTGRES_DB_NAME

echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
if [ -z "$COVERAGE" ]; then
  # No coverage needed
  scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
else
  # Coverage desired
  coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
fi

# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
# This needs to be done after synapse_port_db is run
echo "Dropping unwanted db tables..."
SQL="
DROP TABLE schema_version;
DROP TABLE applied_schema_deltas;
DROP TABLE applied_module_schemas;
"
sqlite3 "$SQLITE_DB" <<< "$SQL"
psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"

echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..."
sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE"

echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..."
pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE"

echo "Cleaning up temporary Postgres database..."
dropdb $POSTGRES_DB_NAME

echo "Done! Files dumped to: $OUTPUT_DIR"
@@ -26,6 +26,7 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database

logger = logging.getLogger("update_database")

@@ -34,11 +35,21 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
    DATASTORE_CLASS = DataStore

    def __init__(self, config, **kwargs):
    def __init__(self, config, database_engine, db_conn, **kwargs):
        super(MockHomeserver, self).__init__(
            config.server_name, reactor=reactor, config=config, **kwargs
            config.server_name,
            reactor=reactor,
            config=config,
            database_engine=database_engine,
            **kwargs
        )

        self.database_engine = database_engine
        self.db_conn = db_conn

    def get_db_conn(self):
        return self.db_conn


if __name__ == "__main__":
    parser = argparse.ArgumentParser(

@@ -74,14 +85,24 @@ if __name__ == "__main__":
    config = HomeServerConfig()
    config.parse_config_dict(hs_config, "", "")

    # Instantiate and initialise the homeserver object.
    hs = MockHomeserver(config)
    # Create the database engine and a connection to it.
    database_engine = create_engine(config.database_config)
    db_conn = database_engine.module.connect(
        **{
            k: v
            for k, v in config.database_config.get("args", {}).items()
            if not k.startswith("cp_")
        }
    )

    db_conn = hs.get_db_conn()
    # Update the database to the latest schema.
    prepare_database(db_conn, hs.database_engine, config=config)
    prepare_database(db_conn, database_engine, config=config)
    db_conn.commit()

    # Instantiate and initialise the homeserver object.
    hs = MockHomeserver(
        config, database_engine, db_conn, db_config=config.database_config,
    )
    # setup instantiates the store within the homeserver object.
    hs.setup()
    store = hs.get_datastore()
@@ -36,7 +36,7 @@ try:
except ImportError:
    pass

__version__ = "1.7.0rc2"
__version__ = "1.7.2"

if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
    # We import here so that we don't have to install a bunch of deps when
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import Dict, Tuple

from six import itervalues

@@ -25,13 +26,7 @@ from twisted.internet import defer
import synapse.logging.opentracing as opentracing
import synapse.types
from synapse import event_auth
from synapse.api.constants import (
    EventTypes,
    JoinRules,
    LimitBlockingTypes,
    Membership,
    UserTypes,
)
from synapse.api.constants import EventTypes, LimitBlockingTypes, Membership, UserTypes
from synapse.api.errors import (
    AuthError,
    Codes,

@@ -513,71 +508,43 @@ class Auth(object):
        """
        return self.store.is_server_admin(user)

    @defer.inlineCallbacks
    def compute_auth_events(self, event, current_state_ids, for_verification=False):
    def compute_auth_events(
        self,
        event,
        current_state_ids: Dict[Tuple[str, str], str],
        for_verification: bool = False,
    ):
        """Given an event and current state return the list of event IDs used
        to auth an event.

        If `for_verification` is False then only return auth events that
        should be added to the event's `auth_events`.

        Returns:
            defer.Deferred(list[str]): List of event IDs.
        """

        if event.type == EventTypes.Create:
            return []
            return defer.succeed([])

        # Currently we ignore the `for_verification` flag even though there are
        # some situations where we can drop particular auth events when adding
        # to the event's `auth_events` (e.g. joins pointing to previous joins
        # when room is publically joinable). Dropping event IDs has the
        # advantage that the auth chain for the room grows slower, but we use
        # the auth chain in state resolution v2 to order events, which means
        # care must be taken if dropping events to ensure that it doesn't
        # introduce undesirable "state reset" behaviour.
        #
        # All of which sounds a bit tricky so we don't bother for now.

        auth_ids = []
        for etype, state_key in event_auth.auth_types_for_event(event):
            auth_ev_id = current_state_ids.get((etype, state_key))
            if auth_ev_id:
                auth_ids.append(auth_ev_id)

        key = (EventTypes.PowerLevels, "")
        power_level_event_id = current_state_ids.get(key)

        if power_level_event_id:
            auth_ids.append(power_level_event_id)

        key = (EventTypes.JoinRules, "")
        join_rule_event_id = current_state_ids.get(key)

        key = (EventTypes.Member, event.sender)
        member_event_id = current_state_ids.get(key)

        key = (EventTypes.Create, "")
        create_event_id = current_state_ids.get(key)
        if create_event_id:
            auth_ids.append(create_event_id)

        if join_rule_event_id:
            join_rule_event = yield self.store.get_event(join_rule_event_id)
            join_rule = join_rule_event.content.get("join_rule")
            is_public = join_rule == JoinRules.PUBLIC if join_rule else False
        else:
            is_public = False

        if event.type == EventTypes.Member:
            e_type = event.content["membership"]
            if e_type in [Membership.JOIN, Membership.INVITE]:
                if join_rule_event_id:
                    auth_ids.append(join_rule_event_id)

            if e_type == Membership.JOIN:
                if member_event_id and not is_public:
                    auth_ids.append(member_event_id)
            else:
                if member_event_id:
                    auth_ids.append(member_event_id)

                if for_verification:
                    key = (EventTypes.Member, event.state_key)
                    existing_event_id = current_state_ids.get(key)
                    if existing_event_id:
                        auth_ids.append(existing_event_id)

            if e_type == Membership.INVITE:
                if "third_party_invite" in event.content:
                    key = (
                        EventTypes.ThirdPartyInvite,
                        event.content["third_party_invite"]["signed"]["token"],
                    )
                    third_party_invite_id = current_state_ids.get(key)
                    if third_party_invite_id:
                        auth_ids.append(third_party_invite_id)
                elif member_event_id:
                    member_event = yield self.store.get_event(member_event_id)
                    if member_event.content["membership"] == Membership.JOIN:
                        auth_ids.append(member_event.event_id)

        return auth_ids
        return defer.succeed(auth_ids)

    @defer.inlineCallbacks
    def check_can_change_room_list(self, room_id, user):
@@ -45,6 +45,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string

@@ -228,10 +229,14 @@ def start(config_options):

    synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = AdminCmdServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -34,6 +34,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -142,6 +143,8 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    if config.notify_appservices:
        sys.stderr.write(
            "\nThe appservices must be disabled in the main synapse process"

@@ -156,8 +159,10 @@ def start(config_options):

    ps = AppserviceServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ps, config, use_worker_options=True)

@@ -62,6 +62,7 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -180,10 +181,14 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = ClientReaderServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -57,6 +57,7 @@ from synapse.rest.client.v1.room import (
)
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -179,10 +180,14 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = EventCreatorServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -46,6 +46,7 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -161,10 +162,14 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = FederationReaderServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -41,6 +41,7 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree

@@ -173,6 +174,8 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    if config.send_federation:
        sys.stderr.write(
            "\nThe send_federation must be disabled in the main synapse process"

@@ -187,8 +190,10 @@ def start(config_options):

    ss = FederationSenderServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -233,10 +234,14 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = FrontendProxyServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree

@@ -328,10 +328,15 @@ def setup(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)
    config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection

    hs = SynapseHomeServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    synapse.config.logger.setup_logging(hs, config, use_worker_options=False)

@@ -514,10 +519,8 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
    # Database version
    #

    # This only reports info about the *main* database.
    stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
    stats["database_server_version"] = hs.get_datastore().db.engine.server_version

    stats["database_engine"] = hs.database_engine.module.__name__
    stats["database_server_version"] = hs.database_engine.server_version
    logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
    try:
        yield hs.get_proxied_http_client().put_json(

@@ -40,6 +40,7 @@ from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -156,10 +157,14 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = MediaRepositoryServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)

@@ -37,6 +37,7 @@ from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string

@@ -202,10 +203,14 @@ def start(config_options):
    # Force the pushers to start since they will be disabled in the main config
    config.start_pushers = True

    database_engine = create_engine(config.database_config)

    ps = PusherServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ps, config, use_worker_options=True)

@@ -55,6 +55,7 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string

@@ -436,10 +437,14 @@ def start(config_options):

    synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    ss = SynchrotronServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
        application_service_handler=SynchrotronApplicationService(),
    )


@@ -44,6 +44,7 @@ from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole

@@ -199,6 +200,8 @@ def start(config_options):

    events.USE_FROZEN_DICTS = config.use_frozen_dicts

    database_engine = create_engine(config.database_config)

    if config.update_user_directory:
        sys.stderr.write(
            "\nThe update_user_directory must be disabled in the main synapse process"

@@ -213,8 +216,10 @@ def start(config_options):

    ss = UserDirectoryServer(
        config.server_name,
        db_config=config.database_config,
        config=config,
        version_string="Synapse/" + get_version_string(synapse),
        database_engine=database_engine,
    )

    setup_logging(ss, config, use_worker_options=True)
@@ -14,19 +14,17 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
from synapse.python_dependencies import DependencyException, check_requirements
|
||||
from synapse.util.module_loader import load_module, load_python_module
|
||||
from synapse.types import (
|
||||
map_username_to_mxid_localpart,
|
||||
mxid_localpart_allowed_characters,
|
||||
)
|
||||
from synapse.util.module_loader import load_python_module
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_USER_MAPPING_PROVIDER = (
|
||||
"synapse.handlers.saml_handler.DefaultSamlMappingProvider"
|
||||
)
|
||||
|
||||
|
||||
def _dict_merge(merge_dict, into_dict):
|
||||
"""Do a deep merge of two dicts
|
||||
@@ -77,69 +75,15 @@ class SAML2Config(Config):
|
||||
|
||||
self.saml2_enabled = True
|
||||
|
||||
self.saml2_mxid_source_attribute = saml2_config.get(
|
||||
"mxid_source_attribute", "uid"
|
||||
)
|
||||
|
||||
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
|
||||
"grandfathered_mxid_source_attribute", "uid"
|
||||
)
|
||||
|
||||
# user_mapping_provider may be None if the key is present but has no value
|
||||
ump_dict = saml2_config.get("user_mapping_provider") or {}
|
||||
|
||||
# Use the default user mapping provider if not set
|
||||
ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
|
||||
|
||||
# Ensure a config is present
|
||||
ump_dict["config"] = ump_dict.get("config") or {}
|
||||
|
||||
if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
|
||||
# Load deprecated options for use by the default module
|
||||
old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
|
||||
if old_mxid_source_attribute:
|
||||
logger.warning(
|
||||
"The config option saml2_config.mxid_source_attribute is deprecated. "
|
||||
"Please use saml2_config.user_mapping_provider.config"
|
||||
".mxid_source_attribute instead."
|
||||
)
|
||||
ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
|
||||
|
||||
old_mxid_mapping = saml2_config.get("mxid_mapping")
|
||||
if old_mxid_mapping:
|
||||
logger.warning(
|
||||
"The config option saml2_config.mxid_mapping is deprecated. Please "
|
||||
"use saml2_config.user_mapping_provider.config.mxid_mapping instead."
|
||||
)
|
||||
ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
|
||||
|
||||
# Retrieve an instance of the module's class
|
||||
# Pass the config dictionary to the module for processing
|
||||
(
|
||||
self.saml2_user_mapping_provider_class,
|
||||
self.saml2_user_mapping_provider_config,
|
||||
) = load_module(ump_dict)
|
||||
|
||||
# Ensure loaded user mapping module has defined all necessary methods
|
||||
# Note parse_config() is already checked during the call to load_module
|
||||
required_methods = [
|
||||
"get_saml_attributes",
|
||||
"saml_response_to_user_attributes",
|
||||
]
|
||||
missing_methods = [
|
||||
method
|
||||
for method in required_methods
|
||||
if not hasattr(self.saml2_user_mapping_provider_class, method)
|
||||
]
|
||||
if missing_methods:
|
||||
raise ConfigError(
|
||||
"Class specified by saml2_config."
|
||||
"user_mapping_provider.module is missing required "
|
||||
"methods: %s" % (", ".join(missing_methods),)
|
||||
)
|
||||
|
||||
# Get the desired saml auth response attributes from the module
|
||||
saml2_config_dict = self._default_saml_config_dict(
|
||||
*self.saml2_user_mapping_provider_class.get_saml_attributes(
|
||||
self.saml2_user_mapping_provider_config
|
||||
)
|
||||
)
|
||||
saml2_config_dict = self._default_saml_config_dict()
|
||||
_dict_merge(
|
||||
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
|
||||
)
|
||||
@@ -159,27 +103,22 @@ class SAML2Config(Config):
|
||||
saml2_config.get("saml_session_lifetime", "5m")
|
||||
)
|
||||
|
||||
def _default_saml_config_dict(
|
||||
self, required_attributes: set, optional_attributes: set
|
||||
):
|
||||
"""Generate a configuration dictionary with required and optional attributes that
|
||||
will be needed to process new user registration
|
||||
mapping = saml2_config.get("mxid_mapping", "hexencode")
|
||||
try:
|
||||
self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
|
||||
except KeyError:
|
||||
raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
|
||||
|
||||
Args:
|
||||
required_attributes: SAML auth response attributes that are
|
||||
necessary to function
optional_attributes: SAML auth response attributes that can be used to add
additional information to Synapse user accounts, but are not required

Returns:
dict: A SAML configuration dictionary
"""

def _default_saml_config_dict(self):
import saml2

public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")

required_attributes = {"uid", self.saml2_mxid_source_attribute}

optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes

@@ -268,58 +207,33 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"

# The lifetime of a SAML session. This defines how long a user has to
# the lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m

# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
#
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
#mxid_source_attribute: displayName

# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use
# to derive the Matrix ID from. 'uid' by default.
#
# Note: This used to be configured by the
# saml2_config.mxid_source_attribute option. If that is still
# defined, its value will be used instead.
#
#mxid_source_attribute: displayName
# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace

# The mapping system to use for mapping the saml attribute onto a
# matrix ID.
#
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with
# '.').
# The default is 'hexencode'.
#
# Note: This used to be configured by the
# saml2_config.mxid_mapping option. If that is still defined, its
# value will be used instead.
#
#mxid_mapping: dotreplace

# In previous versions of synapse, the mapping from SAML attribute to
# MXID was always calculated dynamically rather than stored in a
# table. For backwards- compatibility, we will look for user_ids
# matching such a pattern before creating a new account.
# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if
# the attribute maps are changed, it may be necessary to change it.
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -327,3 +241,23 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}

DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)

def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)

# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username

MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}
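The two built-in mapping strategies registered above differ only in how they handle characters that are not permitted in a Matrix localpart. A minimal standalone sketch of the 'dotreplace' behaviour, assuming a simplified allowed-character set (the real one is `synapse.types.mxid_localpart_allowed_characters`):

```python
import re
import string

# Assumption: a stand-in for synapse.types.mxid_localpart_allowed_characters
ALLOWED = string.ascii_lowercase + string.digits + "_-./="

DOT_REPLACE = re.compile("[^%s]" % (re.escape(ALLOWED),))

def dot_replace_for_mxid(username: str) -> str:
    """Lower-case, replace unpermitted characters with '.', strip a leading '_'."""
    username = username.lower()
    username = DOT_REPLACE.sub(".", username)
    # regular mxids aren't allowed to start with an underscore either
    return re.sub("^_", "", username)

print(dot_replace_for_mxid("Alice Smith"))  # -> "alice.smith"
print(dot_replace_for_mxid("_bob"))         # -> "bob"
```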
@@ -14,6 +14,7 @@
# limitations under the License.

import logging
from typing import Set, Tuple

from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes

@@ -42,14 +43,24 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
assert isinstance(auth_events, dict)

if do_size_check:
_check_size_limits(event)

if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)

room_id = event.room_id

# I'm not really expecting to get auth events in the wrong room, but let's
# sanity-check it
for auth_event in auth_events.values():
if auth_event.room_id != room_id:
raise Exception(
"During auth for event %s in room %s, found event %s in the state "
"which is in room %s"
% (event.event_id, room_id, auth_event.event_id, auth_event.room_id)
)

if do_sig_check:
sender_domain = get_domain_from_id(event.sender)

@@ -76,6 +87,12 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")

if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warning("Trusting event: %s", event.event_id)
return

if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)

@@ -621,7 +638,7 @@ def get_public_keys(invite_event):
return public_keys


def auth_types_for_event(event):
def auth_types_for_event(event) -> Set[Tuple[str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.

@@ -630,20 +647,20 @@ def auth_types_for_event(event):
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
return set()

auth_types = [
auth_types = {
(EventTypes.PowerLevels, ""),
(EventTypes.Member, event.sender),
(EventTypes.Create, ""),
]
}

if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, ""))
auth_types.add((EventTypes.JoinRules, ""))

auth_types.append((EventTypes.Member, event.state_key))
auth_types.add((EventTypes.Member, event.state_key))

if membership == Membership.INVITE:
if "third_party_invite" in event.content:
@@ -651,6 +668,6 @@ def auth_types_for_event(event):
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
)
auth_types.append(key)
auth_types.add(key)

return auth_types
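The list-to-set change in `auth_types_for_event` above is worth a second look: the same `(event_type, state_key)` pair can be produced more than once while walking an event, and a set deduplicates for free. A standalone sketch with plain string constants standing in for synapse's `EventTypes`/`Membership` enums (an assumption for illustration):

```python
from typing import Set, Tuple

# Assumed stand-ins for synapse.api.constants.EventTypes
CREATE, MEMBER, POWER_LEVELS, JOIN_RULES = (
    "m.room.create", "m.room.member", "m.room.power_levels", "m.room.join_rules",
)

def auth_types_for_member_join(sender: str, state_key: str) -> Set[Tuple[str, str]]:
    """Auth (type, state_key) pairs for a join event; using a set means a
    self-join (sender == state_key) yields one member entry instead of two."""
    auth_types = {(POWER_LEVELS, ""), (MEMBER, sender), (CREATE, "")}
    auth_types.add((JOIN_RULES, ""))
    auth_types.add((MEMBER, state_key))
    return auth_types

# For a self-join, the two member entries collapse into one:
print(len(auth_types_for_member_join("@alice:example.com", "@alice:example.com")))  # 4
```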
@@ -526,7 +526,13 @@ class FederationClient(FederationBase):

@defer.inlineCallbacks
def send_request(destination):
content = yield self._do_send_join(destination, pdu)
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

logger.debug("Got content: %s", content)

@@ -593,44 +599,6 @@ class FederationClient(FederationBase):

return self._try_destination_list("send_join", destinations, send_request)

@defer.inlineCallbacks
def _do_send_join(self, destination, pdu):
time_now = self._clock.time_msec()

try:
content = yield self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()

# If we receive an error response that isn't a generic error, or an
# unrecognised endpoint error, we assume that the remote understands
# the v2 invite API and this is a legitimate error.
if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
raise err
else:
raise e.to_synapse_error()

logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")

resp = yield self.transport_layer.send_join_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

# We expect the v1 API to respond with [200, content], so we only return the
# content.
return resp[1]

@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
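The `_do_send_join` helper in the hunk above implements a common endpoint-negotiation pattern: try the v2 API first, and fall back to v1 only when the failure looks like "endpoint not implemented" rather than a real rejection. A minimal sketch of that pattern, with hypothetical `send_v2`/`send_v1` callables rather than Synapse's actual transport API:

```python
# Hypothetical error type mirroring the errcode check in _do_send_join
class ApiError(Exception):
    def __init__(self, code: int, errcode: str):
        super().__init__(errcode)
        self.code, self.errcode = code, errcode

def send_with_fallback(send_v2, send_v1, payload):
    """Try the v2 endpoint; fall back to v1 only for 400/404 responses whose
    errcode suggests the remote simply doesn't implement v2."""
    try:
        return send_v2(payload)
    except ApiError as e:
        if e.code not in (400, 404) or e.errcode not in ("M_UNKNOWN", "M_UNRECOGNIZED"):
            raise  # the remote understood v2 and genuinely rejected us
    # an unrecognised-endpoint style failure: retry on the old API
    return send_v1(payload)

def v2(_):
    raise ApiError(404, "M_UNRECOGNIZED")

def v1(payload):
    return {"ok": payload}

print(send_with_fallback(v2, v1, "join"))  # falls back to v1
```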
@@ -740,50 +708,18 @@ class FederationClient(FederationBase):

@defer.inlineCallbacks
def send_request(destination):
content = yield self._do_send_leave(destination, pdu)

logger.debug("Got content: %s", content)
return None

return self._try_destination_list("send_leave", destinations, send_request)

@defer.inlineCallbacks
def _do_send_leave(self, destination, pdu):
time_now = self._clock.time_msec()

try:
content = yield self.transport_layer.send_leave_v2(
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
logger.debug("Got content: %s", content)
return None

# If we receive an error response that isn't a generic error, or an
# unrecognised endpoint error, we assume that the remote understands
# the v2 invite API and this is a legitimate error.
if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
raise err
else:
raise e.to_synapse_error()

logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")

resp = yield self.transport_layer.send_leave_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)

# We expect the v1 API to respond with [200, content], so we only return the
# content.
return resp[1]
return self._try_destination_list("send_leave", destinations, send_request)

def get_public_rooms(
self,

@@ -384,10 +384,15 @@ class FederationServer(FederationBase):

res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
}
return (
200,
{
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
},
)

async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)

@@ -414,7 +419,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)

await self.handler.on_send_leave_request(origin, pdu)
return {}
return 200, {}

async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):
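The two server-side hunks above differ in whether a handler returns a bare JSON dict or an explicit `(code, body)` tuple. A sketch of why a transport layer that unpacks `(code, body)` forces the tuple shape, with hypothetical handler names:

```python
from typing import Dict, Tuple

# Hypothetical handlers illustrating the two return conventions
def on_send_leave_bare() -> Dict:
    return {}

def on_send_leave_tuple() -> Tuple[int, Dict]:
    return 200, {}

def respond(handler):
    # A caller that always unpacks (code, body) breaks on bare dicts:
    code, body = handler()
    return code, body

print(respond(on_send_leave_tuple))   # -> (200, {})
# respond(on_send_leave_bare) would raise ValueError: not enough values to unpack
```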
@@ -243,7 +243,7 @@ class TransportLayerClient(object):

@defer.inlineCallbacks
@log_function
def send_join_v1(self, destination, room_id, event_id, content):
def send_join(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)

response = yield self.client.put_json(

@@ -254,18 +254,7 @@ class TransportLayerClient(object):

@defer.inlineCallbacks
@log_function
def send_join_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)

response = yield self.client.put_json(
destination=destination, path=path, data=content
)

return response

@defer.inlineCallbacks
@log_function
def send_leave_v1(self, destination, room_id, event_id, content):
def send_leave(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)

response = yield self.client.put_json(

@@ -281,24 +270,6 @@ class TransportLayerClient(object):

return response

@defer.inlineCallbacks
@log_function
def send_leave_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)

response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
# we want to do our best to send this through. The problem is
# that if it fails, we won't retry it later, so if the remote
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
)

return response

@defer.inlineCallbacks
@log_function
def send_invite_v1(self, destination, room_id, event_id, content):

@@ -506,19 +506,9 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content


class FederationV1SendLeaveServlet(BaseFederationServlet):
class FederationSendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, (200, content)


class FederationV2SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"

PREFIX = FEDERATION_V2_PREFIX

async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content

@@ -531,21 +521,9 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id)


class FederationV1SendJoinServlet(BaseFederationServlet):
class FederationSendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"

async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
return 200, (200, content)


class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"

PREFIX = FEDERATION_V2_PREFIX

async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content

@@ -1389,10 +1367,8 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
FederationV1SendJoinServlet,
FederationV2SendJoinServlet,
FederationV1SendLeaveServlet,
FederationV2SendLeaveServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from twisted.internet import defer


class AccountDataEventSource(object):
def __init__(self, hs):

@@ -21,14 +23,15 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()

async def get_new_events(self, user, from_key, **kwargs):
@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key

current_stream_id = self.store.get_max_account_data_stream_id()
current_stream_id = yield self.store.get_max_account_data_stream_id()

results = []
tags = await self.store.get_updated_tags(user_id, last_stream_id)
tags = yield self.store.get_updated_tags(user_id, last_stream_id)

for room_id, room_tags in tags.items():
results.append(

@@ -38,7 +41,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)

for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})

@@ -51,5 +54,6 @@ class AccountDataEventSource(object):

return results, current_stream_id

async def get_pagination_rows(self, user, config, key):
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
return [], config.to_id
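Most hunks in this comparison follow the same mechanical recipe as the one above: each pairs an `async def`/`await` form with its `@defer.inlineCallbacks`/`yield` equivalent. A side-by-side sketch of the two equivalent styles, using a hypothetical `get_value` Deferred-returning call rather than a Synapse store API:

```python
from twisted.internet import defer

def get_value():
    # stand-in for a store call that returns a Deferred
    return defer.succeed(42)

# async/await style (Deferreds are awaitable in modern Twisted):
async def double_async():
    value = await get_value()
    return value * 2

# Older inlineCallbacks style used throughout 1.7.x:
@defer.inlineCallbacks
def double_inline_callbacks():
    value = yield get_value()
    return value * 2  # plain return works on py3; py2 needed defer.returnValue
```

Both return a result asynchronously; the async form must be wrapped (e.g. with `defer.ensureDeferred`) before Twisted code can treat it as a Deferred.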
@@ -18,7 +18,8 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from typing import List

from twisted.internet import defer

from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable

@@ -77,39 +78,42 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self._send_renewal_emails
"send_renewals", self.send_renewal_emails
)

self.clock.looping_call(send_emails, 30 * 60 * 1000)

async def _send_renewal_emails(self):
@defer.inlineCallbacks
def send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
expiring_users = await self.store.get_users_expiring_soon()
expiring_users = yield self.store.get_users_expiring_soon()

if expiring_users:
for user in expiring_users:
await self._send_renewal_email(
yield self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)

async def send_renewal_email_to_user(self, user_id: str):
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
await self._send_renewal_email(user_id, expiration_ts)
@defer.inlineCallbacks
def send_renewal_email_to_user(self, user_id):
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
yield self._send_renewal_email(user_id, expiration_ts)

async def _send_renewal_email(self, user_id: str, expiration_ts: int):
@defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.

Args:
user_id: ID of the user to send email(s) to.
expiration_ts: Timestamp in milliseconds for the expiration date of
user_id (str): ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
addresses = await self._get_email_addresses_for_user(user_id)
addresses = yield self._get_email_addresses_for_user(user_id)

# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their

@@ -121,7 +125,7 @@ class AccountValidityHandler(object):
return

try:
user_display_name = await self.store.get_profile_displayname(
user_display_name = yield self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:

@@ -129,7 +133,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id

renewal_token = await self._get_renewal_token(user_id)
renewal_token = yield self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,

@@ -161,7 +165,7 @@ class AccountValidityHandler(object):

logger.info("Sending renewal email to %s", address)

await make_deferred_yieldable(
yield make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,

@@ -176,18 +180,19 @@ class AccountValidityHandler(object):
)
)

await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)

async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
"""Retrieve the list of email addresses attached to a user's account.

Args:
user_id: ID of the user to lookup email addresses for.
user_id (str): ID of the user to lookup email addresses for.

Returns:
Email addresses for this account.
defer.Deferred[list[str]]: Email addresses for this account.
"""
threepids = await self.store.user_get_threepids(user_id)
threepids = yield self.store.user_get_threepids(user_id)

addresses = []
for threepid in threepids:

@@ -196,15 +201,16 @@ class AccountValidityHandler(object):

return addresses

async def _get_renewal_token(self, user_id: str) -> str:
@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.

Args:
user_id: ID of the user to generate a string for.
user_id (str): ID of the user to generate a string for.

Returns:
The generated string.
defer.Deferred[str]: The generated string.

Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.

@@ -213,52 +219,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
await self.store.set_renewal_token_for_user(user_id, renewal_token)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")

async def renew_account(self, renewal_token: str) -> bool:
@defer.inlineCallbacks
def renew_account(self, renewal_token):
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.

Args:
renewal_token: Token sent with the renewal request.
renewal_token (str): Token sent with the renewal request.
Returns:
Whether the provided token is valid.
bool: Whether the provided token is valid.
"""
try:
user_id = await self.store.get_user_from_renewal_token(renewal_token)
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
return False
defer.returnValue(False)

logger.debug("Renewing an account for user %s", user_id)
await self.renew_account_for_user(user_id)
yield self.renew_account_for_user(user_id)

return True
defer.returnValue(True)

async def renew_account_for_user(
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.

Args:
renewal_token: Token sent with the renewal request.
expiration_ts: New expiration date. Defaults to now + validity period.
email_sen: Whether an email has been sent for this validity period.
renewal_token (str): Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period.
Defaults to False.

Returns:
New expiration date for this account, as a timestamp in
milliseconds since epoch.
defer.Deferred[int]: New expiration date for this account, as a timestamp
in milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period

await self.store.set_account_validity_for_user(
yield self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)
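The `_get_renewal_token` hunk above shows a retry-until-unique pattern: generate a random token, attempt the insert, let a uniqueness violation trigger another attempt, and give up after five tries. A small synchronous sketch of the same idea (the `TokenStore` here is hypothetical, not Synapse's store API):

```python
import secrets

class TokenExists(Exception):
    pass

class TokenStore:
    """Hypothetical store whose insert fails on a duplicate token."""
    def __init__(self):
        self._tokens = set()

    def insert(self, token: str):
        if token in self._tokens:
            raise TokenExists(token)
        self._tokens.add(token)

def get_renewal_token(store: TokenStore, attempts: int = 5) -> str:
    for _ in range(attempts):
        token = secrets.token_urlsafe(32)
        try:
            store.insert(token)
            return token  # insert succeeded, so the token is unique
        except TokenExists:
            continue  # collision: roll a new token and retry
    raise RuntimeError("Couldn't generate a unique token after %d attempts" % attempts)

print(len(get_renewal_token(TokenStore())) > 0)  # True
```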
@@ -264,7 +264,6 @@ class E2eKeysHandler(object):

return ret

@defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database

@@ -284,32 +283,14 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}

user_ids = list(query)

keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)

for user_id, user_info in keys.items():
if user_info is None:
continue
if "master" in user_info:
master_keys[user_id] = user_info["master"]
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]

if (
from_user_id in keys
and keys[from_user_id] is not None
and "user_signing" in keys[from_user_id]
):
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]

return {
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}
# Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
return defer.succeed(
{
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}
)

@trace
@defer.inlineCallbacks
@@ -19,7 +19,7 @@

import itertools
import logging
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
from typing import Dict, Iterable, Optional, Sequence, Tuple

import six
from six import iteritems, itervalues

@@ -63,10 +63,8 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
from synapse.visibility import filter_events_for_server

@@ -165,7 +163,8 @@ class FederationHandler(BaseHandler):

self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages

async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events

@@ -175,15 +174,17 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.

Returns (Deferred): completes with None
"""

room_id = pdu.room_id
event_id = pdu.event_id

logger.info("handling received PDU: %s", pdu)
logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)

# We reprocess pdus when we have seen them only as outliers
existing = await self.store.get_event(
existing = yield self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)

@@ -227,7 +228,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",

@@ -238,17 +239,16 @@ class FederationHandler(BaseHandler):
return None

state = None
auth_chain = []

# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = await self.get_min_depth_for_context(pdu.room_id)
min_depth = yield self.get_min_depth_for_context(pdu.room_id)

logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)

prevs = set(pdu.prev_event_ids())
seen = await self.store.have_seen_events(prevs)
seen = yield self.store.have_seen_events(prevs)

if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this

@@ -268,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,

@@ -276,19 +276,13 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)

try:
await self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
except Exception as e:
raise Exception(
"Error fetching missing prev_events for %s: %s"
% (event_id, e)
)
yield self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)

# Update the set of things we've seen after trying to
# fetch the missing stuff
seen = await self.store.have_seen_events(prevs)
seen = yield self.store.have_seen_events(prevs)

if not prevs - seen:
logger.info(

@@ -296,6 +290,14 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
elif missing_prevs:
logger.info(
"[%s %s] Not recursively fetching %d missing prev_events: %s",
room_id,
event_id,
len(missing_prevs),
shortstr(missing_prevs),
)

if prevs - seen:
# We've still not been able to get all of the prev_events for this event.

@@ -340,19 +342,12 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)

logger.info(
"Event %s is missing prev_events: calculating state for a "
"backwards extremity",
event_id,
)

# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = await self.state_store.get_state_groups_ids(room_id, seen)
ours = yield self.state_store.get_state_groups_ids(room_id, seen)

# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(

@@ -366,27 +361,20 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
"Requesting state at missing prev_event %s", event_id,
"[%s %s] Requesting state at missing prev_event %s",
room_id,
event_id,
p,
)

room_version = await self.store.get_room_version(room_id)

with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
# auth events, the requests to fetch those events are deduped
# by the get_pdu_cache in federation_client.
(
remote_state,
got_auth_chain,
) = await self._get_state_for_room(
(remote_state, _,) = yield self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)

# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
auth_chains.update(got_auth_chain)

remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}

@@ -395,7 +383,9 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x

state_map = await resolve_events_with_store(
room_version = yield self.store.get_room_version(room_id)
state_map = yield resolve_events_with_store(
room_id,
room_version,
state_maps,
event_map,

@@ -407,15 +397,14 @@ class FederationHandler(BaseHandler):

# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self.store.get_events(
evs = yield self.store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
check_redacted=False,
)
event_map.update(evs)

state = [event_map[e] for e in six.itervalues(state_map)]
auth_chain = list(auth_chains)
except Exception:
logger.warning(
"[%s %s] Error attempting to resolve state at missing "

@@ -431,11 +420,10 @@ class FederationHandler(BaseHandler):
affected=event_id,
)

await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
yield self._process_received_pdu(origin, pdu, state=state)
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events

@@ -447,12 +435,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id

seen = await self.store.have_seen_events(prevs)
seen = yield self.store.have_seen_events(prevs)

if not prevs - seen:
return

latest = await self.store.get_latest_event_ids_in_room(room_id)
latest = yield self.store.get_latest_event_ids_in_room(room_id)

# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us

@@ -516,7 +504,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.

try:
missing_events = await self.federation_client.get_missing_events(
missing_events = yield self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),

@@ -555,7 +543,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(

@@ -567,30 +555,29 @@ class FederationHandler(BaseHandler):
else:
raise

async def _get_state_for_room(
self,
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
@defer.inlineCallbacks
@log_function
def _get_state_for_room(
self, destination, room_id, event_id, include_event_in_state
):
"""Requests all of the room state at a given event from a remote homeserver.

Args:
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.

Returns:
A list of events in the state, possibly including the event itself, and
a list of events in the auth chain for the given event.
Deferred[Tuple[List[EventBase], List[EventBase]]]:
A list of events in the state, and a list of events in the auth chain
for the given event.
"""
(
state_event_ids,
auth_event_ids,
) = await self.federation_client.get_room_state_ids(
) = yield self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)

@@ -599,15 +586,15 @@ class FederationHandler(BaseHandler):
if include_event_in_state:
desired_events.add(event_id)

event_map = await self._get_events_from_store_or_dest(
event_map = yield self._get_events_from_store_or_dest(
destination, room_id, desired_events
)

failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
event_id,
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch,
)

@@ -619,7 +606,7 @@ class FederationHandler(BaseHandler):
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state():
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)

auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]

@@ -627,24 +614,26 @@ class FederationHandler(BaseHandler):

return remote_state, auth_chain

async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
) -> Dict[str, EventBase]:
@defer.inlineCallbacks
def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
"""Fetch events from a remote destination, checking if we already have them.

Args:
destination
room_id
event_ids
destination (str)
room_id (str)
event_ids (Iterable[str])

Persists any events we don't already have as outliers.

If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.

Returns:
map from event_id to event
Deferred[dict[str, EventBase]]: A deferred resolving to a map
from event_id to event
"""
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)

missing_events = set(event_ids) - fetched_events.keys()

@@ -655,27 +644,15 @@ class FederationHandler(BaseHandler):
room_id,
)

room_version = await self.store.get_room_version(room_id)
yield self._get_events_and_persist(
destination=destination, room_id=room_id, events=missing_events
)

# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]

res = await make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)

for success, result in res:
if success and result:
fetched_events[result.event_id] = result
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
(yield self.store.get_events(missing_events, allow_rejected=True))
)

# check for events which were in the wrong room.
#

@@ -704,60 +681,35 @@ class FederationHandler(BaseHandler):

return fetched_events

async def _process_received_pdu(self, origin, event, state, auth_chain):
@defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.

Args:
origin: server sending the event

event: event to be persisted

state: Normally None, but if we are handling a gap in the graph
(ie, we are missing one or more prev_events), the resolved state at the
event
"""
room_id = event.room_id
event_id = event.event_id

logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)

event_ids = set()
if state:
event_ids |= {e.event_id for e in state}
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}

seen_ids = await self.store.have_seen_events(event_ids)

if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)

event_infos = []

for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = e.auth_event_ids()
auth = {
(e.type, e.state_key): e
for e in auth_chain
if e.event_id in auth_ids or e.type == EventTypes.Create
}
event_infos.append(_NewEventInfo(event=e, auth_events=auth))
seen_ids.add(e.event_id)

logger.info(
"[%s %s] persisting newly-received auth/state events %s",
room_id,
event_id,
[e.event.event_id for e in event_infos],
)
await self._handle_new_events(origin, event_infos)
try:
context = await self._handle_new_event(origin, event, state=state)
context = yield self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)

room = await self.store.get_room(room_id)
room = yield self.store.get_room(room_id)

if not room:
try:
await self.store.store_room(
yield self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:

@@ -770,11 +722,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True

prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_ids = yield context.get_prev_state_ids(self.store)

prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
prev_state = await self.store.get_event(
prev_state = yield self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:

@@ -782,10 +734,11 @@ class FederationHandler(BaseHandler):

if newly_joined:
user = UserID.from_string(event.state_key)
await self.user_joined_room(user, room_id)
yield self.user_joined_room(user, room_id)

@log_function
async def backfill(self, dest, room_id, limit, extremities):
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`

This will attempt to get more events from the remote. If the other side

@@ -802,9 +755,7 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")

room_version = await self.store.get_room_version(room_id)

events = await self.federation_client.backfill(
events = yield self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)

@@ -819,7 +770,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)

# Don't bother processing events we already have.
seen_events = await self.store.have_events_in_timeline(
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
)

@@ -832,6 +783,9 @@ class FederationHandler(BaseHandler):

event_ids = set(e.event_id for e in events)

# build a list of events whose prev_events weren't in the batch.
# (XXX: this will include events whose prev_events we already have; that doesn't
# sound right?)
edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]

logger.info("backfill: Got %d events with %d edges", len(events), len(edges))

@@ -842,8 +796,11 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
state, auth = yield self._get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id,
include_event_in_state=False,
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})

@@ -860,95 +817,11 @@ class FederationHandler(BaseHandler):
auth_events.update(
{e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
)
missing_auth = required_auth - set(auth_events)
failed_to_fetch = set()

# Try and fetch any missing auth events from both DB and remote servers.
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)

required_auth.update(
a_id for event in ret_events.values() for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

if missing_auth - failed_to_fetch:
logger.info(
"Fetching missing auth for backfill: %r",
missing_auth - failed_to_fetch,
)

results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
self.federation_client.get_pdu,
[dest],
event_id,
room_version=room_version,
outlier=True,
timeout=10000,
)
for event_id in missing_auth - failed_to_fetch
],
consumeErrors=True,
)
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results if a})
required_auth.update(
a_id
for event in results
if event
for a_id in event.auth_event_ids()
)
missing_auth = required_auth - set(auth_events)

failed_to_fetch = missing_auth - set(auth_events)

seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)

# We now have a chunk of events plus associated state and auth chain to
# persist. We do the persistence in two steps:
# 1. Auth events and state get persisted as outliers, plus the
# backward extremities get persisted (as non-outliers).
# 2. The rest of the events in the chunk get persisted one by one, as
# each one depends on the previous event for its state.
#
# The important thing is that events in the chunk get persisted as
# non-outliers, including when those events are also in the state or
# auth chain. Caution must therefore be taken to ensure that they are
# not accidentally marked as outliers.

# Step 1a: persist auth events that *don't* appear in the chunk
ev_infos = []
for a in auth_events.values():
# We only want to persist auth events as outliers that we haven't
# seen and aren't about to persist as part of the backfilled chunk.
if a.event_id in seen_events or a.event_id in event_map:
continue

a.internal_metadata.outlier = True
ev_infos.append(
_NewEventInfo(
event=a,
auth_events={
(
auth_events[a_id].type,
auth_events[a_id].state_key,
): auth_events[a_id]
for a_id in a.auth_event_ids()
if a_id in auth_events
},
)
)

# Step 1b: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities) as non-outliers.
# Step 1: persist the events in the chunk we fetched state for (i.e.
# the backwards extremities), with custom auth events and state
for e_id in events_to_state:
# For paranoia we ensure that these events are marked as
# non-outliers

@@ -970,7 +843,7 @@ class FederationHandler(BaseHandler):
)
)

await self._handle_new_events(dest, ev_infos, backfilled=True)
yield self._handle_new_events(dest, ev_infos, backfilled=True)

# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -986,15 +859,16 @@ class FederationHandler(BaseHandler):
|
||||
# We store these one at a time since each event depends on the
|
||||
# previous to work out the state.
|
||||
# TODO: We can probably do something more clever here.
|
||||
await self._handle_new_event(dest, event, backfilled=True)
|
||||
yield self._handle_new_event(dest, event, backfilled=True)
|
||||
|
||||
return events
|
||||
|
||||
async def maybe_backfill(self, room_id, current_depth):
|
||||
@defer.inlineCallbacks
|
||||
def maybe_backfill(self, room_id, current_depth):
|
||||
"""Checks the database to see if we should backfill before paginating,
|
||||
and if so do.
|
||||
"""
|
||||
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
|
||||
|
||||
if not extremities:
|
||||
logger.debug("Not backfilling as no extremeties found.")
|
||||
@@ -1026,17 +900,15 @@ class FederationHandler(BaseHandler):
|
||||
# state *before* the event, ignoring the special casing certain event
|
||||
# types have.
|
||||
|
||||
forward_events = await self.store.get_successor_events(list(extremities))
|
||||
forward_events = yield self.store.get_successor_events(list(extremities))
|
||||
|
||||
extremities_events = await self.store.get_events(
|
||||
forward_events,
|
||||
redact_behaviour=EventRedactBehaviour.AS_IS,
|
||||
get_prev_content=False,
|
||||
extremities_events = yield self.store.get_events(
|
||||
forward_events, check_redacted=False, get_prev_content=False
|
||||
)
|
||||
|
||||
# We set `check_history_visibility_only` as we might otherwise get false
|
||||
# positives from users having been erased.
|
||||
filtered_extremities = await filter_events_for_server(
|
||||
filtered_extremities = yield filter_events_for_server(
|
||||
self.storage,
|
||||
self.server_name,
|
||||
list(extremities_events.values()),
|
||||
@@ -1066,7 +938,7 @@ class FederationHandler(BaseHandler):
|
||||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = yield self.state_handler.get_current_state(room_id)
|
||||
|
||||
def get_domains_from_state(state):
|
||||
"""Get joined domains from state
|
||||
@@ -1105,11 +977,12 @@ class FederationHandler(BaseHandler):
|
||||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
]
|
||||
|
||||
async def try_backfill(domains):
|
||||
@defer.inlineCallbacks
|
||||
def try_backfill(domains):
|
||||
# TODO: Should we try multiple of these at a time?
|
||||
for dom in domains:
|
||||
try:
|
||||
await self.backfill(
|
||||
yield self.backfill(
|
||||
dom, room_id, limit=100, extremities=extremities
|
||||
)
|
||||
# If this succeeded then we probably already have the
|
||||
@@ -1140,7 +1013,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
return False
|
||||
|
||||
success = await try_backfill(likely_domains)
|
||||
success = yield try_backfill(likely_domains)
|
||||
if success:
|
||||
return True
|
||||
|
||||
@@ -1154,7 +1027,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
|
||||
states = await make_deferred_yieldable(
|
||||
states = yield make_deferred_yieldable(
|
||||
defer.gatherResults(
|
||||
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
|
||||
)
|
||||
@@ -1164,7 +1037,7 @@ class FederationHandler(BaseHandler):
|
||||
# event_ids.
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = await self.store.get_events(
|
||||
state_map = yield self.store.get_events(
|
||||
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
|
||||
get_prev_content=False,
|
||||
)
|
||||
@@ -1180,7 +1053,7 @@ class FederationHandler(BaseHandler):
|
||||
for e_id, _ in sorted_extremeties_tuple:
|
||||
likely_domains = get_domains_from_state(states[e_id])
|
||||
|
||||
success = await try_backfill(
|
||||
success = yield try_backfill(
|
||||
[dom for dom, _ in likely_domains if dom not in tried_domains]
|
||||
)
|
||||
if success:
|
||||
@@ -1190,6 +1063,57 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
return False
|
||||
|
||||
    @defer.inlineCallbacks
    def _get_events_and_persist(
        self, destination: str, room_id: str, events: Iterable[str]
    ):
        """Fetch the given events from a server, and persist them as outliers.

        Logs a warning if we can't find the given event.
        """

        room_version = yield self.store.get_room_version(room_id)

        event_infos = []

        async def get_event(event_id: str):
            with nested_logging_context(event_id):
                try:
                    event = await self.federation_client.get_pdu(
                        [destination], event_id, room_version, outlier=True,
                    )
                    if event is None:
                        logger.warning(
                            "Server %s didn't return event %s", destination, event_id,
                        )
                        return

                    # recursively fetch the auth events for this event
                    auth_events = await self._get_events_from_store_or_dest(
                        destination, room_id, event.auth_event_ids()
                    )
                    auth = {}
                    for auth_event_id in event.auth_event_ids():
                        ae = auth_events.get(auth_event_id)
                        if ae:
                            auth[(ae.type, ae.state_key)] = ae

                    event_infos.append(_NewEventInfo(event, None, auth))

                except Exception as e:
                    logger.warning(
                        "Error fetching missing state/auth event %s: %s %s",
                        event_id,
                        type(e),
                        e,
                    )

        yield concurrently_execute(get_event, events, 5)

        yield self._handle_new_events(
            destination, event_infos,
        )

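The `concurrently_execute(get_event, events, 5)` call above fans `get_event` out over the event IDs with at most five in flight at once. A rough sketch of what a bounded-concurrency helper of that shape does, assuming modern Twisted where `ensureDeferred` passes plain Deferreds through unchanged (illustrative, not Synapse's implementation):

```python
from twisted.internet import defer

@defer.inlineCallbacks
def bounded_concurrent(func, args, limit):
    it = iter(args)

    @defer.inlineCallbacks
    def worker():
        # each worker pulls the next argument from the shared iterator,
        # so at most `limit` calls are ever outstanding
        for arg in it:
            # ensureDeferred accepts both coroutines (like get_event above)
            # and Deferreds, so either kind of callable works here
            yield defer.ensureDeferred(func(arg))

    # spawn `limit` workers and wait for them all to drain the iterator
    yield defer.gatherResults([worker() for _ in range(limit)], consumeErrors=True)
```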
    def _sanity_check_event(self, ev):
        """
        Do some early sanity checks of a received event
@@ -1329,7 +1253,7 @@ class FederationHandler(BaseHandler):
        # Check whether this room is the result of an upgrade of a room we already know
        # about. If so, migrate over user information
        predecessor = yield self.store.get_room_predecessor(room_id)
        if not predecessor or not isinstance(predecessor.get("room_id"), str):
        if not predecessor:
            return
        old_room_id = predecessor["room_id"]
        logger.debug(
@@ -1357,7 +1281,8 @@ class FederationHandler(BaseHandler):

        return True

    async def _handle_queued_pdus(self, room_queue):
    @defer.inlineCallbacks
    def _handle_queued_pdus(self, room_queue):
        """Process PDUs which got queued up while we were busy send_joining.

        Args:
@@ -1373,7 +1298,7 @@ class FederationHandler(BaseHandler):
                    p.room_id,
                )
                with nested_logging_context(p.event_id):
                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
            except Exception as e:
                logger.warning(
                    "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1571,7 +1496,7 @@ class FederationHandler(BaseHandler):
    @defer.inlineCallbacks
    def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
        origin, event, event_format_version = yield self._make_and_verify_event(
            target_hosts, room_id, user_id, "leave", content=content
            target_hosts, room_id, user_id, "leave", content=content,
        )
        # Mark as outlier as we don't have any state for this event; we're not
        # even in the room.
@@ -2932,7 +2857,7 @@ class FederationHandler(BaseHandler):
                room_id=room_id, user_id=user.to_string(), change="joined"
            )
        else:
            return defer.succeed(user_joined_room(self.distributor, user, room_id))
            return user_joined_room(self.distributor, user, room_id)

    @defer.inlineCallbacks
    def get_room_complexity(self, remote_room_hosts, room_id):

@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.visibility import filter_events_for_client

from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
        self.state = hs.get_state_handler()
        self.clock = hs.get_clock()
        self.validator = EventValidator()
        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
        self.snapshot_cache = SnapshotCache()
        self._event_serializer = hs.get_event_client_serializer()
        self.storage = hs.get_storage()
        self.state_store = self.storage.state
@@ -79,17 +79,21 @@ class InitialSyncHandler(BaseHandler):
            as_client_event,
            include_archived,
        )
        now_ms = self.clock.time_msec()
        result = self.snapshot_cache.get(now_ms, key)
        if result is not None:
            return result

        return self.snapshot_cache.wrap(
        return self.snapshot_cache.set(
            now_ms,
            key,
            self._snapshot_all_rooms,
            user_id,
            pagin_config,
            as_client_event,
            include_archived,
            self._snapshot_all_rooms(
                user_id, pagin_config, as_client_event, include_archived
            ),
        )

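The hunk above swaps two caching idioms: the `SnapshotCache` side needs an explicit time-keyed `get()` followed by `set()`, while the `ResponseCache.wrap()` side folds lookup and population into one call. Simplified stand-ins for the two shapes (not the real Synapse classes):

```python
class TinyWrapCache:
    """Populate-on-miss cache in the shape of ResponseCache.wrap()."""

    def __init__(self):
        self._entries = {}

    def wrap(self, key, callback, *args):
        if key not in self._entries:
            self._entries[key] = callback(*args)  # one lookup-or-populate call
        return self._entries[key]


class TinyGetSetCache:
    """Explicit get()/set() pair in the shape of SnapshotCache, keyed on time."""

    def __init__(self, lifetime_ms=5 * 60 * 1000):
        self._entries = {}
        self._lifetime_ms = lifetime_ms

    def get(self, now_ms, key):
        entry = self._entries.get(key)
        if entry is not None and now_ms - entry[0] < self._lifetime_ms:
            return entry[1]
        return None

    def set(self, now_ms, key, value):
        self._entries[key] = (now_ms, value)
        return value
```

The wrap-style API also deduplicates concurrent callers for free, which is presumably why the call site shrinks from nine lines to one.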
    async def _snapshot_all_rooms(
    @defer.inlineCallbacks
    def _snapshot_all_rooms(
        self,
        user_id=None,
        pagin_config=None,
@@ -101,7 +105,7 @@ class InitialSyncHandler(BaseHandler):
        if include_archived:
            memberships.append(Membership.LEAVE)

        room_list = await self.store.get_rooms_for_user_where_membership_is(
        room_list = yield self.store.get_rooms_for_user_where_membership_is(
            user_id=user_id, membership_list=memberships
        )

@@ -109,32 +113,33 @@ class InitialSyncHandler(BaseHandler):

        rooms_ret = []

        now_token = await self.hs.get_event_sources().get_current_token()
        now_token = yield self.hs.get_event_sources().get_current_token()

        presence_stream = self.hs.get_event_sources().sources["presence"]
        pagination_config = PaginationConfig(from_token=now_token)
        presence, _ = await presence_stream.get_pagination_rows(
        presence, _ = yield presence_stream.get_pagination_rows(
            user, pagination_config.get_source_config("presence"), None
        )

        receipt_stream = self.hs.get_event_sources().sources["receipt"]
        receipt, _ = await receipt_stream.get_pagination_rows(
        receipt, _ = yield receipt_stream.get_pagination_rows(
            user, pagination_config.get_source_config("receipt"), None
        )

        tags_by_room = await self.store.get_tags_for_user(user_id)
        tags_by_room = yield self.store.get_tags_for_user(user_id)

        account_data, account_data_by_room = await self.store.get_account_data_for_user(
        account_data, account_data_by_room = yield self.store.get_account_data_for_user(
            user_id
        )

        public_room_ids = await self.store.get_public_room_ids()
        public_room_ids = yield self.store.get_public_room_ids()

        limit = pagin_config.limit
        if limit is None:
            limit = 10

        async def handle_room(event):
        @defer.inlineCallbacks
        def handle_room(event):
            d = {
                "room_id": event.room_id,
                "membership": event.membership,
@@ -147,8 +152,8 @@ class InitialSyncHandler(BaseHandler):
                time_now = self.clock.time_msec()
                d["inviter"] = event.sender

                invite_event = await self.store.get_event(event.event_id)
                d["invite"] = await self._event_serializer.serialize_event(
                invite_event = yield self.store.get_event(event.event_id)
                d["invite"] = yield self._event_serializer.serialize_event(
                    invite_event, time_now, as_client_event
                )

@@ -172,7 +177,7 @@ class InitialSyncHandler(BaseHandler):
                    lambda states: states[event.event_id]
                )

                (messages, token), current_state = await make_deferred_yieldable(
                (messages, token), current_state = yield make_deferred_yieldable(
                    defer.gatherResults(
                        [
                            run_in_background(
@@ -186,7 +191,7 @@ class InitialSyncHandler(BaseHandler):
                    )
                ).addErrback(unwrapFirstError)

                messages = await filter_events_for_client(
                messages = yield filter_events_for_client(
                    self.storage, user_id, messages
                )

@@ -196,7 +201,7 @@ class InitialSyncHandler(BaseHandler):

                d["messages"] = {
                    "chunk": (
                        await self._event_serializer.serialize_events(
                        yield self._event_serializer.serialize_events(
                            messages, time_now=time_now, as_client_event=as_client_event
                        )
                    ),
@@ -204,7 +209,7 @@ class InitialSyncHandler(BaseHandler):
                    "end": end_token.to_string(),
                }

                d["state"] = await self._event_serializer.serialize_events(
                d["state"] = yield self._event_serializer.serialize_events(
                    current_state.values(),
                    time_now=time_now,
                    as_client_event=as_client_event,
@@ -227,7 +232,7 @@ class InitialSyncHandler(BaseHandler):
            except Exception:
                logger.exception("Failed to get snapshot")

        await concurrently_execute(handle_room, room_list, 10)
        yield concurrently_execute(handle_room, room_list, 10)

        account_data_events = []
        for account_data_type, content in account_data.items():
@@ -251,7 +256,8 @@ class InitialSyncHandler(BaseHandler):

        return ret

    async def room_initial_sync(self, requester, room_id, pagin_config=None):
    @defer.inlineCallbacks
    def room_initial_sync(self, requester, room_id, pagin_config=None):
        """Capture a snapshot of a room. If user is currently a member of
        the room this will be what is currently in the room. If the user left
        the room this will be what was in the room when they left.
@@ -268,32 +274,32 @@ class InitialSyncHandler(BaseHandler):
            A JSON serialisable dict with the snapshot of the room.
        """

        blocked = await self.store.is_room_blocked(room_id)
        blocked = yield self.store.is_room_blocked(room_id)
        if blocked:
            raise SynapseError(403, "This room has been blocked on this server")

        user_id = requester.user.to_string()

        membership, member_event_id = await self._check_in_room_or_world_readable(
        membership, member_event_id = yield self._check_in_room_or_world_readable(
            room_id, user_id
        )
        is_peeking = member_event_id is None

        if membership == Membership.JOIN:
            result = await self._room_initial_sync_joined(
            result = yield self._room_initial_sync_joined(
                user_id, room_id, pagin_config, membership, is_peeking
            )
        elif membership == Membership.LEAVE:
            result = await self._room_initial_sync_parted(
            result = yield self._room_initial_sync_parted(
                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
            )

        account_data_events = []
        tags = await self.store.get_tags_for_room(user_id, room_id)
        tags = yield self.store.get_tags_for_room(user_id, room_id)
        if tags:
            account_data_events.append({"type": "m.tag", "content": {"tags": tags}})

        account_data = await self.store.get_account_data_for_room(user_id, room_id)
        account_data = yield self.store.get_account_data_for_room(user_id, room_id)
        for account_data_type, content in account_data.items():
            account_data_events.append({"type": account_data_type, "content": content})

@@ -301,10 +307,11 @@ class InitialSyncHandler(BaseHandler):

        return result

    async def _room_initial_sync_parted(
    @defer.inlineCallbacks
    def _room_initial_sync_parted(
        self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
    ):
        room_state = await self.state_store.get_state_for_events([member_event_id])
        room_state = yield self.state_store.get_state_for_events([member_event_id])

        room_state = room_state[member_event_id]

@@ -312,13 +319,13 @@ class InitialSyncHandler(BaseHandler):
        if limit is None:
            limit = 10

        stream_token = await self.store.get_stream_token_for_event(member_event_id)
        stream_token = yield self.store.get_stream_token_for_event(member_event_id)

        messages, token = await self.store.get_recent_events_for_room(
        messages, token = yield self.store.get_recent_events_for_room(
            room_id, limit=limit, end_token=stream_token
        )

        messages = await filter_events_for_client(
        messages = yield filter_events_for_client(
            self.storage, user_id, messages, is_peeking=is_peeking
        )

@@ -332,13 +339,13 @@ class InitialSyncHandler(BaseHandler):
            "room_id": room_id,
            "messages": {
                "chunk": (
                    await self._event_serializer.serialize_events(messages, time_now)
                    yield self._event_serializer.serialize_events(messages, time_now)
                ),
                "start": start_token.to_string(),
                "end": end_token.to_string(),
            },
            "state": (
                await self._event_serializer.serialize_events(
                yield self._event_serializer.serialize_events(
                    room_state.values(), time_now
                )
            ),
@@ -346,18 +353,19 @@ class InitialSyncHandler(BaseHandler):
            "receipts": [],
        }

    async def _room_initial_sync_joined(
    @defer.inlineCallbacks
    def _room_initial_sync_joined(
        self, user_id, room_id, pagin_config, membership, is_peeking
    ):
        current_state = await self.state.get_current_state(room_id=room_id)
        current_state = yield self.state.get_current_state(room_id=room_id)

        # TODO: These concurrently
        time_now = self.clock.time_msec()
        state = await self._event_serializer.serialize_events(
        state = yield self._event_serializer.serialize_events(
            current_state.values(), time_now
        )

        now_token = await self.hs.get_event_sources().get_current_token()
        now_token = yield self.hs.get_event_sources().get_current_token()

        limit = pagin_config.limit if pagin_config else None
        if limit is None:
@@ -372,26 +380,28 @@ class InitialSyncHandler(BaseHandler):

        presence_handler = self.hs.get_presence_handler()

        async def get_presence():
        @defer.inlineCallbacks
        def get_presence():
            # If presence is disabled, return an empty list
            if not self.hs.config.use_presence:
                return []

            states = await presence_handler.get_states(
            states = yield presence_handler.get_states(
                [m.user_id for m in room_members], as_event=True
            )

            return states

        async def get_receipts():
            receipts = await self.store.get_linearized_receipts_for_room(
        @defer.inlineCallbacks
        def get_receipts():
            receipts = yield self.store.get_linearized_receipts_for_room(
                room_id, to_key=now_token.receipt_key
            )
            if not receipts:
                receipts = []
            return receipts

        presence, receipts, (messages, token) = await make_deferred_yieldable(
        presence, receipts, (messages, token) = yield make_deferred_yieldable(
            defer.gatherResults(
                [
                    run_in_background(get_presence),
@@ -407,7 +417,7 @@ class InitialSyncHandler(BaseHandler):
            ).addErrback(unwrapFirstError)
        )

        messages = await filter_events_for_client(
        messages = yield filter_events_for_client(
            self.storage, user_id, messages, is_peeking=is_peeking
        )

@@ -420,7 +430,7 @@ class InitialSyncHandler(BaseHandler):
            "room_id": room_id,
            "messages": {
                "chunk": (
                    await self._event_serializer.serialize_events(messages, time_now)
                    yield self._event_serializer.serialize_events(messages, time_now)
                ),
                "start": start_token.to_string(),
                "end": end_token.to_string(),
@@ -434,17 +444,18 @@ class InitialSyncHandler(BaseHandler):

        return ret

    async def _check_in_room_or_world_readable(self, room_id, user_id):
    @defer.inlineCallbacks
    def _check_in_room_or_world_readable(self, room_id, user_id):
        try:
            # check_user_was_in_room will return the most recent membership
            # event for the user if:
            # * The user is a non-guest user, and was ever in the room
            # * The user is a guest user, and has joined the room
            # else it will throw.
            member_event = await self.auth.check_user_was_in_room(room_id, user_id)
            member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
            return member_event.membership, member_event.event_id
        except AuthError:
            visibility = await self.state_handler.get_current_state(
            visibility = yield self.state_handler.get_current_state(
                room_id, EventTypes.RoomHistoryVisibility, ""
            )
            if (

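The `make_deferred_yieldable(defer.gatherResults([...]))` idiom above starts several fetches in the background and waits for all of them at once. Reduced to plain Twisted, the fan-out looks roughly like this sketch, where the fetchers stand in for `get_presence`/`get_receipts` and are assumed to be synchronous or Deferred-returning callables:

```python
from twisted.internet import defer

def fan_out(*fetchers):
    # Start each fetcher and wait for the whole batch; gatherResults fires
    # with a list of results in the same order as the inputs.
    return defer.gatherResults(
        [defer.maybeDeferred(f) for f in fetchers], consumeErrors=True
    )
```

In Synapse the calls are additionally wrapped in `run_in_background` / `make_deferred_yieldable` so that the logcontext is handed off and restored correctly around the parallel section.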
@@ -46,7 +46,6 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -876,7 +875,7 @@ class EventCreationHandler(object):
        if event.type == EventTypes.Redaction:
            original_event = yield self.store.get_event(
                event.redacts,
                redact_behaviour=EventRedactBehaviour.AS_IS,
                check_redacted=False,
                get_prev_content=False,
                allow_rejected=False,
                allow_none=True,
@@ -953,7 +952,7 @@ class EventCreationHandler(object):
        if event.type == EventTypes.Redaction:
            original_event = yield self.store.get_event(
                event.redacts,
                redact_behaviour=EventRedactBehaviour.AS_IS,
                check_redacted=False,
                get_prev_content=False,
                allow_rejected=False,
                allow_none=True,

@@ -280,7 +280,8 @@ class PaginationHandler(object):

        await self.storage.purge_events.purge_room(room_id)

    async def get_messages(
    @defer.inlineCallbacks
    def get_messages(
        self,
        requester,
        room_id=None,
@@ -306,7 +307,7 @@ class PaginationHandler(object):
            room_token = pagin_config.from_token.room_key
        else:
            pagin_config.from_token = (
                await self.hs.get_event_sources().get_current_token_for_pagination()
                yield self.hs.get_event_sources().get_current_token_for_pagination()
            )
            room_token = pagin_config.from_token.room_key

@@ -318,11 +319,11 @@ class PaginationHandler(object):

        source_config = pagin_config.get_source_config("room")

        with (await self.pagination_lock.read(room_id)):
        with (yield self.pagination_lock.read(room_id)):
            (
                membership,
                member_event_id,
            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)

            if source_config.direction == "b":
                # if we're going backwards, we might need to backfill. This
@@ -330,7 +331,7 @@ class PaginationHandler(object):
                if room_token.topological:
                    max_topo = room_token.topological
                else:
                    max_topo = await self.store.get_max_topological_token(
                    max_topo = yield self.store.get_max_topological_token(
                        room_id, room_token.stream
                    )

@@ -338,18 +339,18 @@ class PaginationHandler(object):
                    # If they have left the room then clamp the token to be before
                    # they left the room, to save the effort of loading from the
                    # database.
                    leave_token = await self.store.get_topological_token_for_event(
                    leave_token = yield self.store.get_topological_token_for_event(
                        member_event_id
                    )
                    leave_token = RoomStreamToken.parse(leave_token)
                    if leave_token.topological < max_topo:
                        source_config.from_key = str(leave_token)

                await self.hs.get_handlers().federation_handler.maybe_backfill(
                yield self.hs.get_handlers().federation_handler.maybe_backfill(
                    room_id, max_topo
                )

            events, next_key = await self.store.paginate_room_events(
            events, next_key = yield self.store.paginate_room_events(
                room_id=room_id,
                from_key=source_config.from_key,
                to_key=source_config.to_key,
@@ -364,7 +365,7 @@ class PaginationHandler(object):
            if event_filter:
                events = event_filter.filter(events)

            events = await filter_events_for_client(
            events = yield filter_events_for_client(
                self.storage, user_id, events, is_peeking=(member_event_id is None)
            )

@@ -384,19 +385,19 @@ class PaginationHandler(object):
                (EventTypes.Member, event.sender) for event in events
            )

            state_ids = await self.state_store.get_state_ids_for_event(
            state_ids = yield self.state_store.get_state_ids_for_event(
                events[0].event_id, state_filter=state_filter
            )

            if state_ids:
                state = await self.store.get_events(list(state_ids.values()))
                state = yield self.store.get_events(list(state_ids.values()))
                state = state.values()

        time_now = self.clock.time_msec()

        chunk = {
            "chunk": (
                await self._event_serializer.serialize_events(
                yield self._event_serializer.serialize_events(
                    events, time_now, as_client_event=as_client_event
                )
            ),
@@ -405,7 +406,7 @@ class PaginationHandler(object):
        }

        if state:
            chunk["state"] = await self._event_serializer.serialize_events(
            chunk["state"] = yield self._event_serializer.serialize_events(
                state, time_now, as_client_event=as_client_event
            )


@@ -907,7 +907,10 @@ class RoomContextHandler(object):

        results["events_before"] = yield filter_evts(results["events_before"])
        results["events_after"] = yield filter_evts(results["events_after"])
        results["event"] = event
        # filter_evts can return a pruned event in case the user is allowed to see that
        # there's something there but not see the content, so use the event that's in
        # `filtered` rather than the event we retrieved from the datastore.
        results["event"] = filtered[0]

        if results["events_after"]:
            last_event_id = results["events_after"][-1].event_id
@@ -938,7 +941,7 @@ class RoomContextHandler(object):
        if event_filter:
            state_events = event_filter.filter(state_events)

        results["state"] = state_events
        results["state"] = yield filter_evts(state_events)

        # We use a dummy token here as we only care about the room portion of
        # the token, which we replace.

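The comment in the hunk above explains the fix: the visibility filter may hand back a content-stripped copy of an event, so the raw datastore object must not be returned to the client. A minimal sketch of that shape, with `filter_evts` standing in for any callable that returns the filtered list:

```python
def pick_context_event(raw_event, filter_evts):
    filtered = filter_evts([raw_event])
    # the filtered copy may have had its content pruned for this viewer,
    # so prefer it over the event fetched straight from the datastore
    return filtered[0] if filtered else None
```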
@@ -13,36 +13,20 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple

import attr
import saml2
import saml2.response
from saml2.client import Saml2Client

from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import (
    UserID,
    map_username_to_mxid_localpart,
    mxid_localpart_allowed_characters,
)
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util.async_helpers import Linearizer

logger = logging.getLogger(__name__)


@attr.s
class Saml2SessionData:
    """Data we track about SAML2 sessions"""

    # time the session was created, in milliseconds
    creation_time = attr.ib()


class SamlHandler:
    def __init__(self, hs):
        self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -53,14 +37,11 @@ class SamlHandler:
        self._datastore = hs.get_datastore()
        self._hostname = hs.hostname
        self._saml2_session_lifetime = hs.config.saml2_session_lifetime
        self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
        self._grandfathered_mxid_source_attribute = (
            hs.config.saml2_grandfathered_mxid_source_attribute
        )

        # plugin to do custom mapping from saml response to mxid
        self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
            hs.config.saml2_user_mapping_provider_config
        )
        self._mxid_mapper = hs.config.saml2_mxid_mapper

        # identifier for the external_ids table
        self._auth_provider_id = "saml"
@@ -137,10 +118,22 @@ class SamlHandler:
            remote_user_id = saml2_auth.ava["uid"][0]
        except KeyError:
            logger.warning("SAML2 response lacks a 'uid' attestation")
            raise SynapseError(400, "'uid' not in SAML2 response")
            raise SynapseError(400, "uid not in SAML2 response")

        try:
            mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
        except KeyError:
            logger.warning(
                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
            )
            raise SynapseError(
                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
            )

        self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)

        displayName = saml2_auth.ava.get("displayName", [None])[0]

        with (await self._mapping_lock.queue(self._auth_provider_id)):
            # first of all, check if we already have a mapping for this user
            logger.info(
@@ -180,46 +173,22 @@ class SamlHandler:
                )
                return registered_user_id

            # Map saml response to user attributes using the configured mapping provider
            for i in range(1000):
                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
                    saml2_auth, i
                )
            # figure out a new mxid for this user
            base_mxid_localpart = self._mxid_mapper(mxid_source)

                logger.debug(
                    "Retrieved SAML attributes from user mapping provider: %s "
                    "(attempt %d)",
                    attribute_dict,
                    i,
                )

                localpart = attribute_dict.get("mxid_localpart")
                if not localpart:
                    logger.error(
                        "SAML mapping provider plugin did not return a "
                        "mxid_localpart object"
                    )
                    raise SynapseError(500, "Error parsing SAML2 response")

                displayname = attribute_dict.get("displayname")

            # Check if this mxid already exists
            suffix = 0
            while True:
                localpart = base_mxid_localpart + (str(suffix) if suffix else "")
                if not await self._datastore.get_users_by_id_case_insensitive(
                    UserID(localpart, self._hostname).to_string()
                ):
                    # This mxid is free
                    break
            else:
                # Unable to generate a username in 1000 iterations
                # Break and return error to the user
                raise SynapseError(
                    500, "Unable to generate a Matrix ID from the SAML response"
                )
                suffix += 1
            logger.info("Allocating mxid for new user with localpart %s", localpart)

            registered_user_id = await self._registration_handler.register_user(
                localpart=localpart, default_display_name=displayname
                localpart=localpart, default_display_name=displayName
            )

            await self._datastore.record_user_external_id(
                self._auth_provider_id, remote_user_id, registered_user_id
            )
@@ -236,120 +205,9 @@ class SamlHandler:
            del self._outstanding_requests_dict[reqid]


DOT_REPLACE_PATTERN = re.compile(
    ("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)


def dot_replace_for_mxid(username: str) -> str:
    username = username.lower()
    username = DOT_REPLACE_PATTERN.sub(".", username)

    # regular mxids aren't allowed to start with an underscore either
    username = re.sub("^_", "", username)
    return username


MXID_MAPPER_MAP = {
    "hexencode": map_username_to_mxid_localpart,
    "dotreplace": dot_replace_for_mxid,
}

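Worked examples for the `dotreplace` mapper defined just above, assuming the standard Matrix localpart alphabet (lower-case letters, digits, and `.` `_` `=` `-` `/`); the inputs are made up for illustration:

```python
# lower-cases, replaces disallowed characters with ".", strips a leading "_"
assert dot_replace_for_mxid("John Doe") == "john.doe"   # space -> "."
assert dot_replace_for_mxid("_Admin!") == "admin."      # leading "_" dropped, "!" -> "."
```

The `hexencode` entry instead delegates to `map_username_to_mxid_localpart`, which escapes disallowed bytes rather than collapsing them, so distinct inputs can never alias to the same localpart.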
@attr.s
class SamlConfig(object):
    mxid_source_attribute = attr.ib()
    mxid_mapper = attr.ib()
class Saml2SessionData:
    """Data we track about SAML2 sessions"""


class DefaultSamlMappingProvider(object):
    __version__ = "0.0.1"

    def __init__(self, parsed_config: SamlConfig):
        """The default SAML user mapping provider

        Args:
            parsed_config: Module configuration
        """
        self._mxid_source_attribute = parsed_config.mxid_source_attribute
        self._mxid_mapper = parsed_config.mxid_mapper

    def saml_response_to_user_attributes(
        self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
    ) -> dict:
        """Maps some text from a SAML response to attributes of a new user

        Args:
            saml_response: A SAML auth response object

            failures: How many times a call to this function with this
                saml_response has resulted in a failure

        Returns:
            dict: A dict containing new user attributes. Possible keys:
                * mxid_localpart (str): Required. The localpart of the user's mxid
                * displayname (str): The displayname of the user
        """
        try:
            mxid_source = saml_response.ava[self._mxid_source_attribute][0]
        except KeyError:
            logger.warning(
                "SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
            )
            raise SynapseError(
                400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
            )

        # Use the configured mapper for this mxid_source
        base_mxid_localpart = self._mxid_mapper(mxid_source)

        # Append suffix integer if last call to this function failed to produce
        # a usable mxid
        localpart = base_mxid_localpart + (str(failures) if failures else "")

        # Retrieve the display name from the saml response
        # If displayname is None, the mxid_localpart will be used instead
        displayname = saml_response.ava.get("displayName", [None])[0]

        return {
            "mxid_localpart": localpart,
            "displayname": displayname,
        }

    @staticmethod
    def parse_config(config: dict) -> SamlConfig:
        """Parse the dict provided by the homeserver's config
        Args:
            config: A dictionary containing configuration options for this provider
        Returns:
            SamlConfig: A custom config object for this module
        """
        # Parse config options and use defaults where necessary
        mxid_source_attribute = config.get("mxid_source_attribute", "uid")
        mapping_type = config.get("mxid_mapping", "hexencode")

        # Retrieve the associating mapping function
        try:
            mxid_mapper = MXID_MAPPER_MAP[mapping_type]
        except KeyError:
            raise ConfigError(
                "saml2_config.user_mapping_provider.config: '%s' is not a valid "
                "mxid_mapping value" % (mapping_type,)
            )

        return SamlConfig(mxid_source_attribute, mxid_mapper)

    @staticmethod
    def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
        """Returns the required attributes of a SAML

        Args:
            config: A SamlConfig object containing configuration params for this provider

        Returns:
            tuple[set,set]: The first set equates to the saml auth response
                attributes that are required for the module to function, whereas the
                second set consists of those attributes which can be used if
                available, but are not necessary
        """
        return {"uid", config.mxid_source_attribute}, {"displayName"}
    # time the session was created, in milliseconds
    creation_time = attr.ib()

@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer

from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.errors import SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,7 +37,6 @@ class SearchHandler(BaseHandler):
        self._event_serializer = hs.get_event_client_serializer()
        self.storage = hs.get_storage()
        self.state_store = self.storage.state
        self.auth = hs.get_auth()

    @defer.inlineCallbacks
    def get_old_rooms_from_upgraded_room(self, room_id):
@@ -54,38 +53,23 @@ class SearchHandler(BaseHandler):
            room_id (str): id of the room to search through.

        Returns:
            Deferred[iterable[str]]: predecessor room ids
            Deferred[iterable[unicode]]: predecessor room ids
        """

        historical_room_ids = []

        # The initial room must have been known for us to get this far
        predecessor = yield self.store.get_room_predecessor(room_id)

        while True:
            predecessor = yield self.store.get_room_predecessor(room_id)

            # If no predecessor, assume we've hit a dead end
            if not predecessor:
                # We have reached the end of the chain of predecessors
                break

            if not isinstance(predecessor.get("room_id"), str):
                # This predecessor object is malformed. Exit here
                break
            # Add predecessor's room ID
            historical_room_ids.append(predecessor["room_id"])

            predecessor_room_id = predecessor["room_id"]

            # Don't add it to the list until we have checked that we are in the room
            try:
                next_predecessor_room = yield self.store.get_room_predecessor(
                    predecessor_room_id
                )
            except NotFoundError:
                # The predecessor is not a known room, so we are done here
                break

            historical_room_ids.append(predecessor_room_id)

            # And repeat
            predecessor = next_predecessor_room
            # Scan through the old room for further predecessors
            room_id = predecessor["room_id"]

        return historical_room_ids


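The rewrite above changes the predecessor walk so that a room only joins the result once its own record has been looked up successfully, with `NotFoundError` ending the chain cleanly. A minimal sketch of that control flow, where `lookup` stands in for `store.get_room_predecessor` and raises `KeyError` as the analogue of `NotFoundError` for unknown rooms:

```python
def walk_predecessors(room_id, lookup):
    seen = []
    predecessor = lookup(room_id)       # the initial room must be known
    while predecessor:
        prev_room_id = predecessor["room_id"]
        try:
            next_predecessor = lookup(prev_room_id)
        except KeyError:
            break                       # predecessor is not a room we know about
        # only record the room once we know it resolves
        seen.append(prev_room_id)
        predecessor = next_predecessor
    return seen
```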
@@ -76,9 +76,9 @@ def flatten_event(event: dict, metadata: dict, include_time: bool = False):
    # context only given in the log format (e.g. what is being logged) is
    # available.
    if "log_text" in event:
        new_event["message"] = event["log_text"]
        new_event["log"] = event["log_text"]
    else:
        new_event["message"] = event["log_format"]
        new_event["log"] = event["log_format"]

    # We want to include the timestamp when forwarding over the network, but
    # exclude it when we are writing to stdout. This is because the log ingester

@@ -23,7 +23,6 @@ them.
See doc/log_contexts.rst for details on how this works.
"""

import inspect
import logging
import threading
import types
@@ -613,8 +612,7 @@ def run_in_background(f, *args, **kwargs):


def make_deferred_yieldable(deferred):
    """Given a deferred (or coroutine), make it follow the Synapse logcontext
    rules:
    """Given a deferred, make it follow the Synapse logcontext rules:

    If the deferred has completed (or is not actually a Deferred), essentially
    does nothing (just returns another completed deferred with the
@@ -626,13 +624,6 @@ def make_deferred_yieldable(deferred):

    (This is more-or-less the opposite operation to run_in_background.)
    """
    if inspect.isawaitable(deferred):
        # If we're given a coroutine we convert it to a deferred so that we
        # run it and find out if it immediately finishes; if it does then we
        # don't need to fiddle with log contexts at all and can return
        # immediately.
        deferred = defer.ensureDeferred(deferred)

    if not isinstance(deferred, defer.Deferred):
        return deferred


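The block being removed here is the coroutine-acceptance path of `make_deferred_yieldable`: anything awaitable is first converted to a Deferred, after which the usual completed/pending logic applies. The normalisation step on its own, as a sketch:

```python
import inspect
from twisted.internet import defer

def to_deferred(value):
    if inspect.isawaitable(value):
        # ensureDeferred starts a coroutine and wraps it in a Deferred;
        # a value that is already a Deferred is passed through unchanged
        value = defer.ensureDeferred(value)
    if not isinstance(value, defer.Deferred):
        value = defer.succeed(value)  # plain value: already "completed"
    return value
```

Dropping this path is consistent with the rest of the diff, which removes the native-coroutine callers that would have handed coroutines to this function.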
@@ -34,7 +34,6 @@ from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
@@ -98,7 +97,6 @@ from synapse.server_notices.worker_server_notices_sender import (
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStores, Storage
from synapse.storage.engines import create_engine
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -211,18 +209,16 @@ class HomeServer(object):
    # instantiated during setup() for future return by get_datastore()
    DATASTORE_CLASS = abc.abstractproperty()

    def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
    def __init__(self, hostname, reactor=None, **kwargs):
        """
        Args:
            hostname : The hostname for the server.
            config: The full config for the homeserver.
        """
        if not reactor:
            from twisted.internet import reactor

        self._reactor = reactor
        self.hostname = hostname
        self.config = config
        self._building = {}
        self._listening_services = []
        self.start_time = None
@@ -233,12 +229,6 @@ class HomeServer(object):
        self.admin_redaction_ratelimiter = Ratelimiter()
        self.registration_ratelimiter = Ratelimiter()

        self.database_engine = create_engine(config.database_config)
        config.database_config.setdefault("args", {})[
            "cp_openfun"
        ] = self.database_engine.on_new_connection
        self.db_config = config.database_config

        self.datastores = None

        # Other kwargs are explicit dependencies

@@ -16,7 +16,7 @@

import logging
from collections import namedtuple
from typing import Iterable, Optional
from typing import Dict, Iterable, List, Optional, Tuple

from six import iteritems, itervalues

@@ -32,7 +32,6 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -417,6 +416,7 @@ class StateHandler(object):

        with Measure(self.clock, "state._resolve_events"):
            new_state = yield resolve_events_with_store(
                event.room_id,
                room_version,
                state_set_ids,
                event_map=state_map,
@@ -462,7 +462,7 @@ class StateResolutionHandler(object):
            not be called for a single state group

        Args:
            room_id (str): room we are resolving for (used for logging)
            room_id (str): room we are resolving for (used for logging and sanity checks)
            room_version (str): version of the room
            state_groups_ids (dict[int, dict[(str, str), str]]):
                map from state group id to the state in that state group
@@ -518,6 +518,7 @@ class StateResolutionHandler(object):
            logger.info("Resolving conflicted state for %r", room_id)
            with Measure(self.clock, "state._resolve_events"):
                new_state = yield resolve_events_with_store(
                    room_id,
                    room_version,
                    list(itervalues(state_groups_ids)),
                    event_map=event_map,
@@ -589,36 +590,44 @@ def _make_state_cache_entry(new_state, state_groups_ids):
    )


def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
def resolve_events_with_store(
    room_id: str,
    room_version: str,
    state_sets: List[Dict[Tuple[str, str], str]],
    event_map: Optional[Dict[str, EventBase]],
    state_res_store: "StateResolutionStore",
):
    """
    Args:
        room_version(str): Version of the room
        room_id: the room we are working in

        state_sets(list): List of dicts of (type, state_key) -> event_id,
        room_version: Version of the room

        state_sets: List of dicts of (type, state_key) -> event_id,
            which are the different state groups to resolve.

        event_map(dict[str,FrozenEvent]|None):
        event_map:
            a dict from event_id to event, for any events that we happen to
            have in flight (eg, those currently being persisted). This will be
            used as a starting point for finding the state we need; any missing
            events will be requested via state_map_factory.

            If None, all events will be fetched via state_map_factory.
            If None, all events will be fetched via state_res_store.

        state_res_store (StateResolutionStore)
        state_res_store: a place to fetch events from

    Returns
    Returns:
        Deferred[dict[(str, str), str]]:
            a map from (type, state_key) to event_id.
    """
    v = KNOWN_ROOM_VERSIONS[room_version]
    if v.state_res == StateResolutionVersions.V1:
        return v1.resolve_events_with_store(
            state_sets, event_map, state_res_store.get_events
            room_id, state_sets, event_map, state_res_store.get_events
        )
    else:
        return v2.resolve_events_with_store(
            room_version, state_sets, event_map, state_res_store
            room_id, room_version, state_sets, event_map, state_res_store
        )


@@ -646,7 +655,7 @@ class StateResolutionStore(object):

        return self.store.get_events(
            event_ids,
            redact_behaviour=EventRedactBehaviour.AS_IS,
            check_redacted=False,
            get_prev_content=False,
            allow_rejected=allow_rejected,
        )

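The point of threading `room_id` through the dispatcher above (and through v1 and v2 below) is that every event consulted during resolution can be checked against the room being resolved. The guard itself, reduced to a standalone sketch where `events` maps event IDs to objects carrying a `room_id` attribute:

```python
def check_events_in_room(room_id, events):
    # every event consulted during state resolution must belong to the
    # room being resolved; anything else indicates corrupted input
    for event_id, event in events.items():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (room_id, event_id, event.room_id)
            )
```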
@@ -15,6 +15,7 @@

import hashlib
import logging
from typing import Callable, Dict, List, Optional, Tuple

from six import iteritems, iterkeys, itervalues

@@ -24,6 +25,7 @@ from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase

logger = logging.getLogger(__name__)

@@ -32,13 +34,20 @@ POWER_KEY = (EventTypes.PowerLevels, "")


@defer.inlineCallbacks
def resolve_events_with_store(state_sets, event_map, state_map_factory):
def resolve_events_with_store(
    room_id: str,
    state_sets: List[Dict[Tuple[str, str], str]],
    event_map: Optional[Dict[str, EventBase]],
    state_map_factory: Callable,
):
    """
    Args:
        state_sets(list): List of dicts of (type, state_key) -> event_id,
        room_id: the room we are working in

        state_sets: List of dicts of (type, state_key) -> event_id,
            which are the different state groups to resolve.

        event_map(dict[str,FrozenEvent]|None):
        event_map:
            a dict from event_id to event, for any events that we happen to
            have in flight (eg, those currently being persisted). This will be
            used as a starting point for finding the state we need; any missing
@@ -46,11 +55,11 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):

            If None, all events will be fetched via state_map_factory.

        state_map_factory(func): will be called
        state_map_factory: will be called
            with a list of event_ids that are needed, and should return with
            a Deferred of dict of event_id to event.

    Returns
    Returns:
        Deferred[dict[(str, str), str]]:
            a map from (type, state_key) to event_id.
    """
@@ -76,6 +85,14 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
    if event_map is not None:
        state_map.update(event_map)

    # everything in the state map should be in the right room
    for event in state_map.values():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (room_id, event.event_id, event.room_id,)
            )

    # get the ids of the auth events which allow us to authenticate the
    # conflicted state, picking only from the unconflicting state.
    #
@@ -95,6 +112,13 @@ def resolve_events_with_store(state_sets, event_map, state_map_factory):
    )

    state_map_new = yield state_map_factory(new_needed_events)
    for event in state_map_new.values():
        if event.room_id != room_id:
            raise Exception(
                "Attempting to state-resolve for room %s with event %s which is in %s"
                % (room_id, event.event_id, event.room_id,)
            )

    state_map.update(state_map_new)

    return _resolve_with_state(

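The v1 docstring above describes `state_map_factory` as a callable taking a list of event IDs and returning a Deferred of `{event_id: event}`. A toy factory of that shape, assuming a simple dict-backed store purely for illustration:

```python
from twisted.internet import defer

def make_state_map_factory(store):
    def state_map_factory(event_ids):
        # in Synapse this would be a batched store.get_events(...) call;
        # here `store` is just a dict from event_id to event object
        return defer.succeed({eid: store[eid] for eid in event_ids})
    return state_map_factory
```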
@@ -16,29 +16,40 @@
|
||||
import heapq
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from six import iteritems, itervalues
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.state
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events import EventBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
|
||||
def resolve_events_with_store(
|
||||
room_id: str,
|
||||
room_version: str,
|
||||
state_sets: List[Dict[Tuple[str, str], str]],
|
||||
event_map: Optional[Dict[str, EventBase]],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
):
|
||||
"""Resolves the state using the v2 state resolution algorithm
|
||||
|
||||
Args:
|
||||
room_version (str): The room version
|
||||
room_id: the room we are working in
|
||||
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
room_version: The room version
|
||||
|
||||
state_sets: List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
event_map:
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
@@ -46,9 +57,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
|
||||
If None, all events will be fetched via state_res_store.
|
||||
|
||||
state_res_store (StateResolutionStore)
|
||||
state_res_store:
|
||||
|
||||
Returns
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
@@ -84,6 +95,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
)
|
||||
event_map.update(events)
|
||||
|
||||
# everything in the event map should be in the right room
|
||||
for event in event_map.values():
|
||||
if event.room_id != room_id:
|
||||
raise Exception(
|
||||
"Attempting to state-resolve for room %s with event %s which is in %s"
|
||||
% (room_id, event.event_id, event.room_id,)
|
||||
)
|
||||
|
||||
full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
|
||||
|
||||
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
|
||||
@@ -94,13 +113,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
)
|
||||
|
||||
sorted_power_events = yield _reverse_topological_power_sort(
|
||||
power_events, event_map, state_res_store, full_conflicted_set
|
||||
room_id, power_events, event_map, state_res_store, full_conflicted_set
|
||||
)
|
||||
|
||||
logger.debug("sorted %d power events", len(sorted_power_events))
|
||||
|
||||
# Now sequentially auth each one
|
||||
resolved_state = yield _iterative_auth_checks(
|
||||
room_id,
|
||||
room_version,
|
||||
sorted_power_events,
|
||||
unconflicted_state,
|
||||
@@ -121,13 +141,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
|
||||
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
|
||||
leftover_events = yield _mainline_sort(
|
||||
leftover_events, pl, event_map, state_res_store
|
||||
room_id, leftover_events, pl, event_map, state_res_store
|
||||
)
|
||||
|
||||
logger.debug("resolving remaining events")
|
||||
|
||||
resolved_state = yield _iterative_auth_checks(
|
||||
room_version, leftover_events, resolved_state, event_map, state_res_store
|
||||
room_id,
|
||||
room_version,
|
||||
leftover_events,
|
||||
resolved_state,
|
||||
event_map,
|
||||
state_res_store,
|
||||
)
|
||||
|
||||
logger.debug("resolved")
|
||||
@@ -141,11 +166,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_power_level_for_sender(event_id, event_map, state_res_store):
|
||||
def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
|
||||
"""Return the power level of the sender of the given event according to
|
||||
their auth events.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
@@ -153,11 +179,11 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
event = yield _get_event(event_id, event_map, state_res_store)
|
||||
event = yield _get_event(room_id, event_id, event_map, state_res_store)
|
||||
|
||||
pl = None
|
||||
for aid in event.auth_event_ids():
|
||||
aev = yield _get_event(aid, event_map, state_res_store)
|
||||
aev = yield _get_event(room_id, aid, event_map, state_res_store)
|
||||
if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
|
||||
pl = aev
|
||||
break
|
||||
@@ -165,7 +191,7 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
|
||||
if pl is None:
|
||||
# Couldn't find power level. Check if they're the creator of the room
|
||||
for aid in event.auth_event_ids():
|
||||
aev = yield _get_event(aid, event_map, state_res_store)
|
||||
aev = yield _get_event(room_id, aid, event_map, state_res_store)
|
||||
if (aev.type, aev.state_key) == (EventTypes.Create, ""):
|
||||
if aev.content.get("creator") == event.sender:
|
||||
return 100
|
||||
@@ -279,7 +305,7 @@ def _is_power_event(event):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _add_event_and_auth_chain_to_graph(
|
||||
graph, event_id, event_map, state_res_store, auth_diff
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
):
|
||||
"""Helper function for _reverse_topological_power_sort that add the event
|
||||
and its auth chain (that is in the auth diff) to the graph
|
||||
@@ -287,6 +313,7 @@ def _add_event_and_auth_chain_to_graph(
|
||||
Args:
|
||||
graph (dict[str, set[str]]): A map from event ID to the events auth
|
||||
event IDs
|
||||
room_id (str): the room we are working in
|
||||
event_id (str): Event to add to the graph
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
@@ -298,7 +325,7 @@ def _add_event_and_auth_chain_to_graph(
|
||||
eid = state.pop()
|
||||
graph.setdefault(eid, set())
|
||||
|
||||
event = yield _get_event(eid, event_map, state_res_store)
|
||||
event = yield _get_event(room_id, eid, event_map, state_res_store)
|
||||
for aid in event.auth_event_ids():
|
||||
if aid in auth_diff:
|
||||
if aid not in graph:
|
||||
@@ -308,11 +335,14 @@ def _add_event_and_auth_chain_to_graph(
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
|
||||
def _reverse_topological_power_sort(
|
||||
room_id, event_ids, event_map, state_res_store, auth_diff
|
||||
):
|
||||
"""Returns a list of the event_ids sorted by reverse topological ordering,
|
||||
and then by power level and origin_server_ts
|
||||
|
||||
Args:
|
||||
room_id (str): the room we are working in
|
||||
event_ids (list[str]): The events to sort
|
||||
event_map (dict[str,FrozenEvent])
|
||||
state_res_store (StateResolutionStore)
|
||||
@@ -325,12 +355,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
|
||||
graph = {}
|
||||
for event_id in event_ids:
|
||||
yield _add_event_and_auth_chain_to_graph(
|
||||
graph, event_id, event_map, state_res_store, auth_diff
|
||||
graph, room_id, event_id, event_map, state_res_store, auth_diff
|
||||
)
|
||||
|
||||
    event_to_pl = {}
    for event_id in graph:
        pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
        pl = yield _get_power_level_for_sender(
            room_id, event_id, event_map, state_res_store
        )
        event_to_pl[event_id] = pl

    def _get_power_order(event_id):
@@ -348,12 +380,13 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_

@defer.inlineCallbacks
def _iterative_auth_checks(
    room_version, event_ids, base_state, event_map, state_res_store
    room_id, room_version, event_ids, base_state, event_map, state_res_store
):
    """Sequentially apply auth checks to each event in given list, updating the
    state as it goes along.

    Args:
        room_id (str)
        room_version (str)
        event_ids (list[str]): Ordered list of events to apply auth checks to
        base_state (dict[tuple[str, str], str]): The set of state to start with
@@ -370,7 +403,7 @@ def _iterative_auth_checks(

        auth_events = {}
        for aid in event.auth_event_ids():
            ev = yield _get_event(aid, event_map, state_res_store)
            ev = yield _get_event(room_id, aid, event_map, state_res_store)

            if ev.rejected_reason is None:
                auth_events[(ev.type, ev.state_key)] = ev
@@ -378,7 +411,7 @@ def _iterative_auth_checks(
        for key in event_auth.auth_types_for_event(event):
            if key in resolved_state:
                ev_id = resolved_state[key]
                ev = yield _get_event(ev_id, event_map, state_res_store)
                ev = yield _get_event(room_id, ev_id, event_map, state_res_store)

                if ev.rejected_reason is None:
                    auth_events[key] = event_map[ev_id]
@@ -400,11 +433,14 @@ def _iterative_auth_checks(


@defer.inlineCallbacks
def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
def _mainline_sort(
    room_id, event_ids, resolved_power_event_id, event_map, state_res_store
):
    """Returns a sorted list of event_ids sorted by mainline ordering based on
    the given event resolved_power_event_id

    Args:
        room_id (str): room we're working in
        event_ids (list[str]): Events to sort
        resolved_power_event_id (str): The final resolved power level event ID
        event_map (dict[str,FrozenEvent])
@@ -417,11 +453,11 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
    pl = resolved_power_event_id
    while pl:
        mainline.append(pl)
        pl_ev = yield _get_event(pl, event_map, state_res_store)
        pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
        auth_events = pl_ev.auth_event_ids()
        pl = None
        for aid in auth_events:
            ev = yield _get_event(aid, event_map, state_res_store)
            ev = yield _get_event(room_id, aid, event_map, state_res_store)
            if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
                pl = aid
                break
@@ -457,6 +493,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
        Deferred[int]
    """

    room_id = event.room_id

    # We do an iterative search, replacing `event` with the power level in its
    # auth events (if any)
    while event:
@@ -468,7 +506,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
        event = None

        for aid in auth_events:
            aev = yield _get_event(aid, event_map, state_res_store)
            aev = yield _get_event(room_id, aid, event_map, state_res_store)
            if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
                event = aev
                break
@@ -478,11 +516,12 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor


@defer.inlineCallbacks
def _get_event(event_id, event_map, state_res_store):
def _get_event(room_id, event_id, event_map, state_res_store):
    """Helper function to look up event in event_map, falling back to looking
    it up in the store

    Args:
        room_id (str)
        event_id (str)
        event_map (dict[str,FrozenEvent])
        state_res_store (StateResolutionStore)
@@ -493,7 +532,14 @@ def _get_event(event_id, event_map, state_res_store):
    if event_id not in event_map:
        events = yield state_res_store.get_events([event_id], allow_rejected=True)
        event_map.update(events)
    return event_map[event_id]
    event = event_map[event_id]
    assert event is not None
    if event.room_id != room_id:
        raise Exception(
            "In state res for room %s, event %s is in %s"
            % (room_id, event_id, event.room_id)
        )
    return event


def lexicographical_topological_sort(graph, key):
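The thread running through these hunks is that every `_get_event` call in state resolution now carries the room ID, so an event from a different room can be rejected outright. A minimal, self-contained sketch of that guard (simplified names, synchronous for clarity):

```python
# Sketch only: a distilled, synchronous version of the room-ID guard that
# the hunks above thread through state resolution.
class FakeEvent(object):
    def __init__(self, event_id, room_id):
        self.event_id = event_id
        self.room_id = room_id


def get_event_checked(room_id, event_id, event_map):
    event = event_map[event_id]
    if event.room_id != room_id:
        # An event from another room leaking into this resolution would
        # corrupt the resolved state, so fail loudly instead.
        raise Exception(
            "In state res for room %s, event %s is in %s"
            % (room_id, event_id, event.room_id)
        )
    return event


event_map = {
    "$good": FakeEvent("$good", "!a:example.org"),
    "$bad": FakeEvent("$bad", "!b:example.org"),
}
get_event_checked("!a:example.org", "$good", event_map)  # returns the event
# get_event_checked("!a:example.org", "$bad", event_map) would raise
```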

@@ -451,18 +451,16 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
                # Technically an access token might not be associated with
                # a device so we need to check.
                if device_id:
                    # this is always an update rather than an upsert: the row should
                    # already exist, and if it doesn't, that may be because it has been
                    # deleted, and we don't want to re-create it.
                    self.db.simple_update_txn(
                    self.db.simple_upsert_txn(
                        txn,
                        table="devices",
                        keyvalues={"user_id": user_id, "device_id": device_id},
                        updatevalues={
                        values={
                            "user_agent": user_agent,
                            "last_seen": last_seen,
                            "ip": ip,
                        },
                        lock=False,
                    )
            except Exception as e:
                # Failed to upsert, log and continue
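One side of this hunk argues, in the comment, for a plain UPDATE rather than an upsert, so that a device row which has been deleted is not silently re-created. A small sketch of the difference the comment is pointing at, using the stdlib sqlite3 module rather than Synapse's database layer (table name and values are illustrative, not Synapse's real schema):

```python
# Sketch only: UPDATE vs upsert against a deleted row.
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE devices (user_id TEXT, device_id TEXT, ip TEXT)")

# The device row was deleted, so an UPDATE is a harmless no-op...
cur = conn.execute(
    "UPDATE devices SET ip = ? WHERE user_id = ? AND device_id = ?",
    ("10.0.0.1", "@user:test", "DEVICE"),
)
print(cur.rowcount)  # 0 -- nothing was re-created

# ...whereas an upsert-style statement resurrects the row.
conn.execute(
    "INSERT OR REPLACE INTO devices (user_id, device_id, ip) VALUES (?, ?, ?)",
    ("@user:test", "DEVICE", "10.0.0.1"),
)
print(conn.execute("SELECT COUNT(*) FROM devices").fetchone()[0])  # 1
```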

@@ -14,18 +14,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List

from six import iteritems

from canonicaljson import encode_canonical_json, json

from twisted.enterprise.adbapi import Connection
from twisted.internet import defer

from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.descriptors import cached


class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -274,7 +271,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_id (str): the user whose key is being requested
            key_type (str): the type of key that is being requested: either 'master'
            key_type (str): the type of key that is being set: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            from_user_id (str): if specified, signatures made by this user on
@@ -319,10 +316,8 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
        """Returns a user's cross-signing key.

        Args:
            user_id (str): the user whose key is being requested
            key_type (str): the type of key that is being requested: either 'master'
                for a master key, 'self_signing' for a self-signing key, or
                'user_signing' for a user-signing key
            user_id (str): the user whose self-signing key is being requested
            key_type (str): the type of cross-signing key to get
            from_user_id (str): if specified, signatures made by this user on
                the self-signing key will be included in the result

@@ -337,206 +332,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
            from_user_id,
        )

    @cached(num_args=1)
    def _get_bare_e2e_cross_signing_keys(self, user_id):
        """Dummy function. Only used to make a cache for
        _get_bare_e2e_cross_signing_keys_bulk.
        """
        raise NotImplementedError()

    @cachedList(
        cached_method_name="_get_bare_e2e_cross_signing_keys",
        list_name="user_ids",
        num_args=1,
    )
    def _get_bare_e2e_cross_signing_keys_bulk(
        self, user_ids: List[str]
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. If a user's cross-signing keys were not found, either
                their user ID will not be in the dict, or their user ID will map
                to None.

        """
        return self.db.runInteraction(
            "get_bare_e2e_cross_signing_keys_bulk",
            self._get_bare_e2e_cross_signing_keys_bulk_txn,
            user_ids,
        )

    def _get_bare_e2e_cross_signing_keys_bulk_txn(
        self, txn: Connection, user_ids: List[str],
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing keys for a set of users. The output of this
        function should be passed to _get_e2e_cross_signing_signatures_txn if
        the signatures for the calling user need to be fetched.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            user_ids (list[str]): the users whose keys are being requested

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. If a user's cross-signing keys were not found, their user
                ID will not be in the dict.

        """
        result = {}

        batch_size = 100
        chunks = [
            user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
        ]
        for user_chunk in chunks:
            sql = """
                SELECT k.user_id, k.keytype, k.keydata, k.stream_id
                    FROM e2e_cross_signing_keys k
                    INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
                        FROM e2e_cross_signing_keys
                        GROUP BY user_id, keytype) s
                    USING (user_id, stream_id, keytype)
                    WHERE k.user_id IN (%s)
            """ % (
                ",".join("?" for u in user_chunk),
            )
            query_params = []
            query_params.extend(user_chunk)

            txn.execute(sql, query_params)
            rows = self.db.cursor_to_dict(txn)

            for row in rows:
                user_id = row["user_id"]
                key_type = row["keytype"]
                key = json.loads(row["keydata"])
                user_info = result.setdefault(user_id, {})
                user_info[key_type] = key

        return result

    def _get_e2e_cross_signing_signatures_txn(
        self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
    ) -> Dict[str, Dict[str, dict]]:
        """Returns the cross-signing signatures made by a user on a set of keys.

        Args:
            txn (twisted.enterprise.adbapi.Connection): db connection
            keys (dict[str, dict[str, dict]]): a map of user ID to key type to
                key data. This dict will be modified to add signatures.
            from_user_id (str): fetch the signatures made by this user

        Returns:
            dict[str, dict[str, dict]]: mapping from user ID to key type to key
                data. The return value will be the same as the keys argument,
                with the modifications included.
        """

        # find out what cross-signing keys (a.k.a. devices) we need to get
        # signatures for. This is a map of (user_id, device_id) to key type
        # (device_id is the key's public part).
        devices = {}

        for user_id, user_info in keys.items():
            if user_info is None:
                continue
            for key_type, key in user_info.items():
                device_id = None
                for k in key["keys"].values():
                    device_id = k
                devices[(user_id, device_id)] = key_type

        device_list = list(devices)

        # split into batches
        batch_size = 100
        chunks = [
            device_list[i : i + batch_size]
            for i in range(0, len(device_list), batch_size)
        ]
        for user_chunk in chunks:
            sql = """
                SELECT target_user_id, target_device_id, key_id, signature
                    FROM e2e_cross_signing_signatures
                    WHERE user_id = ?
                    AND (%s)
            """ % (
                " OR ".join(
                    "(target_user_id = ? AND target_device_id = ?)" for d in devices
                )
            )
            query_params = [from_user_id]
            for item in devices:
                # item is a (user_id, device_id) tuple
                query_params.extend(item)

            txn.execute(sql, query_params)
            rows = self.db.cursor_to_dict(txn)

            # and add the signatures to the appropriate keys
            for row in rows:
                key_id = row["key_id"]
                target_user_id = row["target_user_id"]
                target_device_id = row["target_device_id"]
                key_type = devices[(target_user_id, target_device_id)]
                # We need to copy everything, because the result may have come
                # from the cache. dict.copy only does a shallow copy, so we
                # need to recursively copy the dicts that will be modified.
                user_info = keys[target_user_id] = keys[target_user_id].copy()
                target_user_key = user_info[key_type] = user_info[key_type].copy()
                if "signatures" in target_user_key:
                    signatures = target_user_key["signatures"] = target_user_key[
                        "signatures"
                    ].copy()
                    if from_user_id in signatures:
                        user_sigs = signatures[from_user_id] = signatures[from_user_id]
                        user_sigs[key_id] = row["signature"]
                    else:
                        signatures[from_user_id] = {key_id: row["signature"]}
                else:
                    target_user_key["signatures"] = {
                        from_user_id: {key_id: row["signature"]}
                    }

        return keys

    @defer.inlineCallbacks
    def get_e2e_cross_signing_keys_bulk(
        self, user_ids: List[str], from_user_id: str = None
    ) -> defer.Deferred:
        """Returns the cross-signing keys for a set of users.

        Args:
            user_ids (list[str]): the users whose keys are being requested
            from_user_id (str): if specified, signatures made by this user on
                the self-signing keys will be included in the result

        Returns:
            Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
                key data. If a user's cross-signing keys were not found, either
                their user ID will not be in the dict, or their user ID will map
                to None.
        """

        result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)

        if from_user_id:
            result = yield self.db.runInteraction(
                "get_e2e_cross_signing_signatures",
                self._get_e2e_cross_signing_signatures_txn,
                result,
                from_user_id,
            )

        return result

    def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
        """Return a list of changes from the user signature stream to notify remotes.
        Note that the user signature stream represents when a user signs their
@@ -725,10 +520,6 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
                },
            )

            self._invalidate_cache_and_stream(
                txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
            )

    def set_e2e_cross_signing_key(self, user_id, key_type, key):
        """Set a user's cross-signing key.
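Two patterns in the block above are worth calling out: the `@cached`/`@cachedList` pairing (a dummy single-key method exists purely so the bulk method has a cache to attach to), and the batching of user IDs into chunks of 100 `?` placeholders so the SQL `IN` clause stays bounded. The batching half, distilled into a standalone sketch:

```python
# Sketch only: the placeholder-batching pattern used by
# _get_bare_e2e_cross_signing_keys_bulk_txn above.
def batched_in_clauses(ids, batch_size=100):
    """Yield (where_fragment, params) pairs, one per batch of ids."""
    for i in range(0, len(ids), batch_size):
        chunk = ids[i : i + batch_size]
        placeholders = ",".join("?" for _ in chunk)
        yield "user_id IN (%s)" % (placeholders,), list(chunk)


for clause, params in batched_in_clauses(["@a:hs", "@b:hs", "@c:hs"], batch_size=2):
    print(clause, params)
# user_id IN (?,?) ['@a:hs', '@b:hs']
# user_id IN (?) ['@c:hs']
```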

@@ -19,10 +19,8 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional

from canonicaljson import json
from constantly import NamedConstant, Names

from twisted.internet import defer

@@ -57,16 +55,6 @@ EVENT_QUEUE_TIMEOUT_S = 0.1  # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))


class EventRedactBehaviour(Names):
    """
    What to do when retrieving a redacted event from the database.
    """

    AS_IS = NamedConstant()
    REDACT = NamedConstant()
    BLOCK = NamedConstant()


class EventsWorkerStore(SQLBaseStore):
    def __init__(self, database: Database, db_conn, hs):
        super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -137,27 +125,25 @@ class EventsWorkerStore(SQLBaseStore):
    @defer.inlineCallbacks
    def get_event(
        self,
        event_id: List[str],
        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
        get_prev_content: bool = False,
        allow_rejected: bool = False,
        allow_none: bool = False,
        check_room_id: Optional[str] = None,
        event_id,
        check_redacted=True,
        get_prev_content=False,
        allow_rejected=False,
        allow_none=False,
        check_room_id=None,
    ):
        """Get an event from the database by event_id.

        Args:
            event_id: The event_id of the event to fetch
            redact_behaviour: Determine what to do with a redacted event. Possible values:
                * AS_IS - Return the full event body with no redacted content
                * REDACT - Return the event but with a redacted body
                * DISALLOW - Do not return redacted events
            get_prev_content: If True and event is a state event,
            event_id (str): The event_id of the event to fetch
            check_redacted (bool): If True, check if event has been redacted
                and redact it.
            get_prev_content (bool): If True and event is a state event,
                include the previous states content in the unsigned field.
            allow_rejected: If True return rejected events.
            allow_none: If True, return None if no event found, if
            allow_rejected (bool): If True return rejected events.
            allow_none (bool): If True, return None if no event found, if
                False throw a NotFoundError
            check_room_id: if not None, check the room of the found event.
            check_room_id (str|None): if not None, check the room of the found event.
                If there is a mismatch, behave as per allow_none.

        Returns:
@@ -168,7 +154,7 @@ class EventsWorkerStore(SQLBaseStore):

        events = yield self.get_events_as_list(
            [event_id],
            redact_behaviour=redact_behaviour,
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            allow_rejected=allow_rejected,
        )
@@ -187,30 +173,27 @@ class EventsWorkerStore(SQLBaseStore):
    @defer.inlineCallbacks
    def get_events(
        self,
        event_ids: List[str],
        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
        get_prev_content: bool = False,
        allow_rejected: bool = False,
        event_ids,
        check_redacted=True,
        get_prev_content=False,
        allow_rejected=False,
    ):
        """Get events from the database

        Args:
            event_ids: The event_ids of the events to fetch
            redact_behaviour: Determine what to do with a redacted event. Possible
                values:
                * AS_IS - Return the full event body with no redacted content
                * REDACT - Return the event but with a redacted body
                * DISALLOW - Do not return redacted events
            get_prev_content: If True and event is a state event,
            event_ids (list): The event_ids of the events to fetch
            check_redacted (bool): If True, check if event has been redacted
                and redact it.
            get_prev_content (bool): If True and event is a state event,
                include the previous states content in the unsigned field.
            allow_rejected: If True return rejected events.
            allow_rejected (bool): If True return rejected events.

        Returns:
            Deferred : Dict from event_id to event.
        """
        events = yield self.get_events_as_list(
            event_ids,
            redact_behaviour=redact_behaviour,
            check_redacted=check_redacted,
            get_prev_content=get_prev_content,
            allow_rejected=allow_rejected,
        )
@@ -220,23 +203,21 @@ class EventsWorkerStore(SQLBaseStore):
    @defer.inlineCallbacks
    def get_events_as_list(
        self,
        event_ids: List[str],
        redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
        get_prev_content: bool = False,
        allow_rejected: bool = False,
        event_ids,
        check_redacted=True,
        get_prev_content=False,
        allow_rejected=False,
    ):
        """Get events from the database and return in a list in the same order
        as given by `event_ids` arg.

        Args:
            event_ids: The event_ids of the events to fetch
            redact_behaviour: Determine what to do with a redacted event. Possible values:
                * AS_IS - Return the full event body with no redacted content
                * REDACT - Return the event but with a redacted body
                * DISALLOW - Do not return redacted events
            get_prev_content: If True and event is a state event,
            event_ids (list): The event_ids of the events to fetch
            check_redacted (bool): If True, check if event has been redacted
                and redact it.
            get_prev_content (bool): If True and event is a state event,
                include the previous states content in the unsigned field.
            allow_rejected: If True, return rejected events.
            allow_rejected (bool): If True return rejected events.

        Returns:
            Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -338,14 +319,10 @@ class EventsWorkerStore(SQLBaseStore):
                # Update the cache to save doing the checks again.
                entry.event.internal_metadata.recheck_redaction = False

            event = entry.event

            if entry.redacted_event:
                if redact_behaviour == EventRedactBehaviour.BLOCK:
                    # Skip this event
                    continue
                elif redact_behaviour == EventRedactBehaviour.REDACT:
                    event = entry.redacted_event
            if check_redacted and entry.redacted_event:
                event = entry.redacted_event
            else:
                event = entry.event

            events.append(event)
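For callers, the difference between the two sides of these hunks is how a redacted event comes back: the enum flavour makes the three outcomes explicit, where the boolean flavour only distinguishes redacted-or-not. A hedged usage sketch of the enum side (the import path is the one shown in the search-store hunk further down; the store argument is any `EventsWorkerStore`):

```python
# Sketch only, assuming the EventRedactBehaviour side of the hunks above is
# in effect; the other side spells the same call as check_redacted=True/False.
from twisted.internet import defer

from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour


@defer.inlineCallbacks
def fetch_unredacted_only(store, event_ids):
    # BLOCK drops redacted events entirely instead of returning a pruned
    # body -- AS_IS and REDACT would keep them in the result list.
    events = yield store.get_events_as_list(
        event_ids, redact_behaviour=EventRedactBehaviour.BLOCK,
    )
    return events
```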

@@ -1,13 +0,0 @@
# Building full schema dumps

These schemas need to be made from a database that has had all background updates run.

To do so, use `scripts-dev/make_full_schema.sh`. This will produce
`full.sql.postgres` and `full.sql.sqlite` files.

Ensure postgres is installed and your user has the ability to run bash commands
such as `createdb`.

```
./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
```

@@ -0,0 +1,19 @@
Building full schema dumps
==========================

These schemas need to be made from a database that has had all background updates run.

Postgres
--------

$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres

SQLite
------

$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite

After
-----

Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".

@@ -25,7 +25,6 @@ from twisted.internet import defer

from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine

@@ -385,7 +384,7 @@ class SearchStore(SearchBackgroundUpdateStore):
        """
        clauses = []

        search_query = _parse_query(self.database_engine, search_term)
        search_query = search_query = _parse_query(self.database_engine, search_term)

        args = []

@@ -454,12 +453,7 @@ class SearchStore(SearchBackgroundUpdateStore):

        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
        # search results (which is a data leak)
        events = yield self.get_events_as_list(
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.BLOCK,
        )
        events = yield self.get_events_as_list([r["event_id"] for r in results])

        event_map = {ev.event_id: ev for ev in events}

@@ -501,7 +495,7 @@ class SearchStore(SearchBackgroundUpdateStore):
        """
        clauses = []

        search_query = _parse_query(self.database_engine, search_term)
        search_query = search_query = _parse_query(self.database_engine, search_term)

        args = []

@@ -606,12 +600,7 @@ class SearchStore(SearchBackgroundUpdateStore):

        results = list(filter(lambda row: row["room_id"] in room_ids, results))

        # We set redact_behaviour to BLOCK here to prevent redacted events being returned in
        # search results (which is a data leak)
        events = yield self.get_events_as_list(
            [r["event_id"] for r in results],
            redact_behaviour=EventRedactBehaviour.BLOCK,
        )
        events = yield self.get_events_as_list([r["event_id"] for r in results])

        event_map = {ev.event_id: ev for ev in events}

@@ -278,7 +278,7 @@ class StateGroupWorkerStore(

    @defer.inlineCallbacks
    def get_room_predecessor(self, room_id):
        """Get the predecessor of an upgraded room if it exists.
        """Get the predecessor room of an upgraded room if one exists.
        Otherwise return None.

        Args:
@@ -291,22 +291,14 @@ class StateGroupWorkerStore(
                * room_id (str): The room ID of the predecessor room
                * event_id (str): The ID of the tombstone event in the predecessor room

            None if a predecessor key is not found, or is not a dictionary.

        Raises:
            NotFoundError if the given room is unknown
            NotFoundError if the room is unknown
        """
        # Retrieve the room's create event
        create_event = yield self.get_create_event_for_room(room_id)

        # Retrieve the predecessor key of the create event
        predecessor = create_event.content.get("predecessor", None)

        # Ensure the key is a dictionary
        if not isinstance(predecessor, dict):
            return None

        return predecessor
        # Return predecessor if present
        return create_event.content.get("predecessor", None)

    @defer.inlineCallbacks
    def get_create_event_for_room(self, room_id):
@@ -326,7 +318,7 @@ class StateGroupWorkerStore(

        # If we can't find the create event, assume we've hit a dead end
        if not create_id:
            raise NotFoundError("Unknown room %s" % (room_id,))
            raise NotFoundError("Unknown room %s" % (room_id))

        # Retrieve the room's create event and return
        create_event = yield self.get_event(create_id)
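Both sides of the `get_room_predecessor` hunk read the same `predecessor` key from the room's create event; one side additionally insists it is a dictionary before returning it. The shape being read is the standard room-upgrade pointer (IDs below are made up for illustration):

```python
# Sketch only: the create-event content that get_room_predecessor inspects.
create_event_content = {
    "creator": "@admin:example.org",
    "room_version": "5",
    "predecessor": {
        "room_id": "!old_room:example.org",
        "event_id": "$tombstone_in_old_room",
    },
}

predecessor = create_event_content.get("predecessor", None)
if isinstance(predecessor, dict):  # the stricter side's extra check
    print(predecessor["room_id"])  # !old_room:example.org
```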

@@ -25,6 +25,9 @@ class Sqlite3Engine(object):
    def __init__(self, database_module, database_config):
        self.module = database_module

        database = database_config.get("args", {}).get("database")
        self._is_in_memory = database in (None, ":memory:",)

        # The current max state_group, or None if we haven't looked
        # in the DB yet.
        self._current_state_group_id = None
@@ -59,7 +62,12 @@ class Sqlite3Engine(object):
        return sql

    def on_new_connection(self, db_conn):
        prepare_database(db_conn, self, config=None)
        if self._is_in_memory:
            # In memory databases need to be rebuilt each time. Ideally we'd
            # reuse the same connection as we do when starting up, but that
            # would involve using adbapi before we have started the reactor.
            prepare_database(db_conn, self, config=None)

        db_conn.create_function("rank", 1, _rank)

    def is_deadlock(self, error):
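The in-memory special-casing above exists because a SQLite `:memory:` database is private to the connection that created it: each new connection starts empty, so the schema must be prepared again. The stdlib demonstrates this directly:

```python
# Sketch only: each sqlite3 connection to ":memory:" is a separate, empty
# database -- which is why on_new_connection must re-run prepare_database
# for in-memory databases.
import sqlite3

conn1 = sqlite3.connect(":memory:")
conn1.execute("CREATE TABLE t (x INTEGER)")

conn2 = sqlite3.connect(":memory:")
tables = conn2.execute(
    "SELECT name FROM sqlite_master WHERE type = 'table'"
).fetchall()
print(tables)  # [] -- conn2 cannot see conn1's table
```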

@@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):

    Args:
        func (func): Function to execute, should return a deferred or coroutine.
        args (list): List of arguments to pass to func, each invocation of func
            gets a signle argument.
        args (Iterable): List of arguments to pass to func, each invocation of func
            gets a single argument.
        limit (int): Maximum number of concurrent executions.

    Returns:
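For context, `concurrently_execute` fans one function out over an iterable of arguments while capping how many invocations run at once. A hedged usage sketch under the documented signature (the module path is assumed, and `do_something_with` is a hypothetical per-item helper):

```python
# Sketch only: driving concurrently_execute(func, args, limit) as documented
# above; do_something_with is a hypothetical per-item coroutine.
from twisted.internet import defer

from synapse.util.async_helpers import concurrently_execute


@defer.inlineCallbacks
def process_rooms(room_ids, do_something_with):
    @defer.inlineCallbacks
    def handle(room_id):
        # each invocation gets a single argument from room_ids
        yield do_something_with(room_id)

    # at most five rooms are processed concurrently
    yield concurrently_execute(handle, room_ids, 5)
```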

@@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
        else:
            self.function_to_call = orig

        arg_spec = inspect.getfullargspec(orig)
        arg_spec = inspect.getargspec(orig)
        all_args = arg_spec.args

        if "cache_context" in all_args:

synapse/util/caches/snapshot_cache.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from synapse.util.async_helpers import ObservableDeferred


class SnapshotCache(object):
    """Cache for snapshots like the response of /initialSync.
    The response of initialSync only has to be a recent snapshot of the
    server state. It shouldn't matter to clients if it is a few minutes out
    of date.

    This caches a deferred response. Until the deferred completes it will be
    returned from the cache. This means that if the client retries the request
    while the response is still being computed, that original response will be
    used rather than trying to compute a new response.

    Once the deferred completes it will be removed from the cache after 5 minutes.
    We delay removing it from the cache because a client retrying its request
    could race with us finishing computing the response.

    Rather than tracking precisely how long something has been in the cache we
    keep two generations of completed responses. Every 5 minutes discard the
    old generation, move the new generation to the old generation, and set the
    new generation to be empty. This means that a result will be in the cache
    somewhere between 5 and 10 minutes.
    """

    DURATION_MS = 5 * 60 * 1000  # Cache results for 5 minutes.

    def __init__(self):
        self.pending_result_cache = {}  # Requests that haven't finished yet.
        self.prev_result_cache = {}  # The older requests that have finished.
        self.next_result_cache = {}  # The newer requests that have finished.
        self.time_last_rotated_ms = 0

    def rotate(self, time_now_ms):
        # Rotate once if the cache duration has passed since the last rotation.
        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
            self.prev_result_cache = self.next_result_cache
            self.next_result_cache = {}
            self.time_last_rotated_ms += self.DURATION_MS

        # Rotate again if the cache duration has passed twice since the last
        # rotation.
        if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
            self.prev_result_cache = self.next_result_cache
            self.next_result_cache = {}
            self.time_last_rotated_ms = time_now_ms

    def get(self, time_now_ms, key):
        self.rotate(time_now_ms)
        # This cache is intended to deduplicate requests, so we expect it to be
        # missed most of the time. So we just lookup the key in all of the
        # dictionaries rather than trying to short circuit the lookup if the
        # key is found.
        result = self.prev_result_cache.get(key)
        result = self.next_result_cache.get(key, result)
        result = self.pending_result_cache.get(key, result)
        if result is not None:
            return result.observe()
        else:
            return None

    def set(self, time_now_ms, key, deferred):
        self.rotate(time_now_ms)

        result = ObservableDeferred(deferred)

        self.pending_result_cache[key] = result

        def shuffle_along(r):
            # When the deferred completes we shuffle it along to the first
            # generation of the result cache. So that it will eventually
            # expire from the rotation of that cache.
            self.next_result_cache[key] = result
            self.pending_result_cache.pop(key, None)
            return r

        result.addBoth(shuffle_along)

        return result.observe()
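The class above is easiest to read through a usage pass: a pending deferred is shared between the first request and any retries, and once resolved it survives between one and two rotations before expiring. A sketch (the cache key is illustrative):

```python
# Sketch only: exercising the SnapshotCache defined above.
from twisted.internet.defer import Deferred

from synapse.util.caches.snapshot_cache import SnapshotCache

cache = SnapshotCache()

d = Deferred()
cache.set(0, "initial_sync:@user:test", d)

# A retry arriving while the response is still being computed observes the
# same pending deferred instead of starting a second computation.
retry = cache.get(1000, "initial_sync:@user:test")
assert retry is not None

d.callback({"rooms": []})  # moves the entry into the "new" generation

# After two full rotations the entry has aged out of both generations.
later = 2 * SnapshotCache.DURATION_MS + 1
assert cache.get(later, "initial_sync:@user:test") is None
```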

@@ -52,7 +52,8 @@ def filter_events_for_client(
    apply_retention_policies=True,
):
    """
    Check which events a user is allowed to see
    Check which events a user is allowed to see. If the user can see the event but its
    sender asked for their data to be erased, prune the content of the event.

    Args:
        storage
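Concretely, "prune" here means the event is still returned but its content is emptied; the `ContextTestCase` added further down asserts exactly that (`content == {}`) for messages from an erased user. A toy rendering of the visible effect (real pruning is done by Synapse's event-pruning code, not this):

```python
# Sketch only: what pruning amounts to for a message from an erased sender.
event = {"type": "m.room.message", "content": {"body": "message 1"}}

def prune_for_new_members(evt):
    pruned = dict(evt)
    pruned["content"] = {}  # the event survives, the body does not
    return pruned

print(prune_for_new_members(event))  # {'type': 'm.room.message', 'content': {}}
```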

@@ -183,6 +183,10 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
        )
        self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})

    test_replace_master_key.skip = (
        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
    )

    @defer.inlineCallbacks
    def test_reupload_signatures(self):
        """re-uploading a signature should not fail"""
@@ -503,3 +507,7 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
            ],
            other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
        )

    test_upload_signatures.skip = (
        "Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
    )

@@ -29,6 +29,7 @@ import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes, Membership
from synapse.handlers.pagination import PurgeStatus
from synapse.rest.client.v1 import login, profile, room
from synapse.rest.client.v2_alpha import account
from synapse.util.stringutils import random_string

from tests import unittest
@@ -1597,3 +1598,129 @@ class LabelsTestCase(unittest.HomeserverTestCase):
        )

        return event_id


class ContextTestCase(unittest.HomeserverTestCase):

    servlets = [
        synapse.rest.admin.register_servlets_for_client_rest_resource,
        room.register_servlets,
        login.register_servlets,
        account.register_servlets,
    ]

    def prepare(self, reactor, clock, homeserver):
        self.user_id = self.register_user("user", "password")
        self.tok = self.login("user", "password")
        self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)

        self.other_user_id = self.register_user("user2", "password")
        self.other_tok = self.login("user2", "password")

        self.helper.invite(self.room_id, self.user_id, self.other_user_id, tok=self.tok)
        self.helper.join(self.room_id, self.other_user_id, tok=self.other_tok)

    def test_erased_sender(self):
        """Test that an erasure request results in the requester's events being hidden
        from any new member of the room.
        """

        # Send a bunch of events in the room.

        self.helper.send(self.room_id, "message 1", tok=self.tok)
        self.helper.send(self.room_id, "message 2", tok=self.tok)
        event_id = self.helper.send(self.room_id, "message 3", tok=self.tok)["event_id"]
        self.helper.send(self.room_id, "message 4", tok=self.tok)
        self.helper.send(self.room_id, "message 5", tok=self.tok)

        # Check that we can still see the messages before the erasure request.

        request, channel = self.make_request(
            "GET",
            '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
            % (self.room_id, event_id),
            access_token=self.tok,
        )
        self.render(request)
        self.assertEqual(channel.code, 200, channel.result)

        events_before = channel.json_body["events_before"]

        self.assertEqual(len(events_before), 2, events_before)
        self.assertEqual(
            events_before[0].get("content", {}).get("body"),
            "message 2",
            events_before[0],
        )
        self.assertEqual(
            events_before[1].get("content", {}).get("body"),
            "message 1",
            events_before[1],
        )

        self.assertEqual(
            channel.json_body["event"].get("content", {}).get("body"),
            "message 3",
            channel.json_body["event"],
        )

        events_after = channel.json_body["events_after"]

        self.assertEqual(len(events_after), 2, events_after)
        self.assertEqual(
            events_after[0].get("content", {}).get("body"),
            "message 4",
            events_after[0],
        )
        self.assertEqual(
            events_after[1].get("content", {}).get("body"),
            "message 5",
            events_after[1],
        )

        # Deactivate the first account and erase the user's data.

        deactivate_account_handler = self.hs.get_deactivate_account_handler()
        self.get_success(
            deactivate_account_handler.deactivate_account(self.user_id, erase_data=True)
        )

        # Invite another user in the room. This is needed because messages will be
        # pruned only if the user wasn't a member of the room when the messages were
        # sent.

        invited_user_id = self.register_user("user3", "password")
        invited_tok = self.login("user3", "password")

        self.helper.invite(
            self.room_id, self.other_user_id, invited_user_id, tok=self.other_tok
        )
        self.helper.join(self.room_id, invited_user_id, tok=invited_tok)

        # Check that a user that joined the room after the erasure request can't see
        # the messages anymore.

        request, channel = self.make_request(
            "GET",
            '/rooms/%s/context/%s?filter={"types":["m.room.message"]}'
            % (self.room_id, event_id),
            access_token=invited_tok,
        )
        self.render(request)
        self.assertEqual(channel.code, 200, channel.result)

        events_before = channel.json_body["events_before"]

        self.assertEqual(len(events_before), 2, events_before)
        self.assertDictEqual(events_before[0].get("content"), {}, events_before[0])
        self.assertDictEqual(events_before[1].get("content"), {}, events_before[1])

        self.assertDictEqual(
            channel.json_body["event"].get("content"), {}, channel.json_body["event"]
        )

        events_after = channel.json_body["events_after"]

        self.assertEqual(len(events_after), 2, events_after)
        self.assertDictEqual(events_after[0].get("content"), {}, events_after[0])
        self.assertEqual(events_after[1].get("content"), {}, events_after[1])

@@ -391,8 +391,9 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
        # Email config.
        self.email_attempts = []

        async def sendmail(*args, **kwargs):
        def sendmail(*args, **kwargs):
            self.email_attempts.append((args, kwargs))
            return

        config["email"] = {
            "enable_notifs": True,

@@ -58,6 +58,7 @@ class FakeEvent(object):
        self.type = type
        self.state_key = state_key
        self.content = content
        self.room_id = ROOM_ID

    def to_event(self, auth_events, prev_events):
        """Given the auth_events and prev_events, convert to a Frozen Event
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
            state_before = dict(state_at_event[prev_events[0]])
        else:
            state_d = resolve_events_with_store(
                ROOM_ID,
                RoomVersions.V2.identifier,
                [state_at_event[n] for n in prev_events],
                event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
        # Test that we correctly handle passing `None` as the event_map

        state_d = resolve_events_with_store(
            ROOM_ID,
            RoomVersions.V2.identifier,
            [self.state_at_bob, self.state_at_charlie],
            event_map=None,

@@ -37,13 +37,9 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
        self.reactor.advance(12345678)

        user_id = "@user:id"
        device_id = "MY_DEVICE"

        # Insert a user IP
        self.get_success(self.store.store_device(user_id, device_id, "display name",))
        self.get_success(
            self.store.insert_client_ip(
                user_id, "access_token", "ip", "user_agent", device_id
                user_id, "access_token", "ip", "user_agent", "device_id"
            )
        )

@@ -51,14 +47,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
        self.reactor.advance(10)

        result = self.get_success(
            self.store.get_last_client_ip_by_device(user_id, device_id)
            self.store.get_last_client_ip_by_device(user_id, "device_id")
        )

        r = result[(user_id, device_id)]
        r = result[(user_id, "device_id")]
        self.assertDictContainsSubset(
            {
                "user_id": user_id,
                "device_id": device_id,
                "device_id": "device_id",
                "ip": "ip",
                "user_agent": "user_agent",
                "last_seen": 12345678000,
@@ -213,16 +209,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
            self.store.db.updates.do_next_background_update(100), by=0.1
        )

        user_id = "@user:id"
        device_id = "MY_DEVICE"

        # Insert a user IP
        self.get_success(self.store.store_device(user_id, device_id, "display name",))
        user_id = "@user:id"
        self.get_success(
            self.store.insert_client_ip(
                user_id, "access_token", "ip", "user_agent", device_id
                user_id, "access_token", "ip", "user_agent", "device_id"
            )
        )

        # Force persisting to disk
        self.reactor.advance(200)

@@ -230,7 +224,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
        self.get_success(
            self.store.db.simple_update(
                table="devices",
                keyvalues={"user_id": user_id, "device_id": device_id},
                keyvalues={"user_id": user_id, "device_id": "device_id"},
                updatevalues={"last_seen": None, "ip": None, "user_agent": None},
                desc="test_devices_last_seen_bg_update",
            )
@@ -238,14 +232,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):

        # We should now get nulls when querying
        result = self.get_success(
            self.store.get_last_client_ip_by_device(user_id, device_id)
            self.store.get_last_client_ip_by_device(user_id, "device_id")
        )

        r = result[(user_id, device_id)]
        r = result[(user_id, "device_id")]
        self.assertDictContainsSubset(
            {
                "user_id": user_id,
                "device_id": device_id,
                "device_id": "device_id",
                "ip": None,
                "user_agent": None,
                "last_seen": None,
@@ -278,14 +272,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):

        # We should now get the correct result again
        result = self.get_success(
            self.store.get_last_client_ip_by_device(user_id, device_id)
            self.store.get_last_client_ip_by_device(user_id, "device_id")
        )

        r = result[(user_id, device_id)]
        r = result[(user_id, "device_id")]
        self.assertDictContainsSubset(
            {
                "user_id": user_id,
                "device_id": device_id,
                "device_id": "device_id",
                "ip": "ip",
                "user_agent": "user_agent",
                "last_seen": 0,
@@ -302,14 +296,11 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
            self.store.db.updates.do_next_background_update(100), by=0.1
        )

        user_id = "@user:id"
        device_id = "MY_DEVICE"

        # Insert a user IP
        self.get_success(self.store.store_device(user_id, device_id, "display name",))
        user_id = "@user:id"
        self.get_success(
            self.store.insert_client_ip(
                user_id, "access_token", "ip", "user_agent", device_id
                user_id, "access_token", "ip", "user_agent", "device_id"
            )
        )

@@ -333,7 +324,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
                "access_token": "access_token",
                "ip": "ip",
                "user_agent": "user_agent",
                "device_id": device_id,
                "device_id": "device_id",
                "last_seen": 0,
            }
        ],
@@ -356,14 +347,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):

        # But we should still get the correct values for the device
        result = self.get_success(
            self.store.get_last_client_ip_by_device(user_id, device_id)
            self.store.get_last_client_ip_by_device(user_id, "device_id")
        )

        r = result[(user_id, device_id)]
        r = result[(user_id, "device_id")]
        self.assertDictContainsSubset(
            {
                "user_id": user_id,
                "device_id": device_id,
                "device_id": "device_id",
                "ip": "ip",
                "user_agent": "user_agent",
                "last_seen": 0,

@@ -1,6 +1,6 @@
from mock import Mock

from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from twisted.internet.defer import maybeDeferred, succeed

from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext
@@ -70,10 +70,8 @@ class MessageAcceptTests(unittest.TestCase):
        )

        # Send the join, it should return None (which is not an error)
        d = ensureDeferred(
            self.handler.on_receive_pdu(
                "test.serv", join_event, sent_to_us_directly=True
            )
        )
        d = self.handler.on_receive_pdu(
            "test.serv", join_event, sent_to_us_directly=True
        )
        self.reactor.advance(1)
        self.assertEqual(self.successResultOf(d), None)
@@ -121,10 +119,8 @@ class MessageAcceptTests(unittest.TestCase):
        )

        with LoggingContext(request="lying_event"):
            d = ensureDeferred(
                self.handler.on_receive_pdu(
                    "test.serv", lying_event, sent_to_us_directly=True
                )
            )
            d = self.handler.on_receive_pdu(
                "test.serv", lying_event, sent_to_us_directly=True
            )

        # Step the reactor, so the database fetches come back

@@ -17,15 +17,18 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart

from tests import unittest
from tests.utils import TestHomeServer

mock_homeserver = TestHomeServer(hostname="my.domain")


class UserIDTestCase(unittest.HomeserverTestCase):
class UserIDTestCase(unittest.TestCase):
    def test_parse(self):
        user = UserID.from_string("@1234abcd:test")
        user = UserID.from_string("@1234abcd:my.domain")

        self.assertEquals("1234abcd", user.localpart)
        self.assertEquals("test", user.domain)
        self.assertEquals(True, self.hs.is_mine(user))
        self.assertEquals("my.domain", user.domain)
        self.assertEquals(True, mock_homeserver.is_mine(user))

    def test_pase_empty(self):
        with self.assertRaises(SynapseError):
@@ -45,13 +48,13 @@ class UserIDTestCase(unittest.HomeserverTestCase):
        self.assertTrue(userA != userB)


class RoomAliasTestCase(unittest.HomeserverTestCase):
class RoomAliasTestCase(unittest.TestCase):
    def test_parse(self):
        room = RoomAlias.from_string("#channel:test")
        room = RoomAlias.from_string("#channel:my.domain")

        self.assertEquals("channel", room.localpart)
        self.assertEquals("test", room.domain)
        self.assertEquals(True, self.hs.is_mine(room))
        self.assertEquals("my.domain", room.domain)
        self.assertEquals(True, mock_homeserver.is_mine(room))

    def test_build(self):
        room = RoomAlias("channel", "my.domain")

@@ -179,30 +179,6 @@ class LoggingContextTestCase(unittest.TestCase):
        nested_context = nested_logging_context(suffix="bar")
        self.assertEqual(nested_context.request, "foo-bar")

    @defer.inlineCallbacks
    def test_make_deferred_yieldable_with_await(self):
        # an async function which returns an incomplete coroutine, but doesn't
        # follow the synapse rules.

        async def blocking_function():
            d = defer.Deferred()
            reactor.callLater(0, d.callback, None)
            await d

        sentinel_context = LoggingContext.current_context()

        with LoggingContext() as context_one:
            context_one.request = "one"

            d1 = make_deferred_yieldable(blocking_function())
            # make sure that the context was reset by make_deferred_yieldable
            self.assertIs(LoggingContext.current_context(), sentinel_context)

            yield d1

            # now it should be restored
            self._check_test_key("one")


        # a function which returns a deferred which has been "called", but
        # which had a function which returned another incomplete deferred on

tests/util/test_snapshot_cache.py (new file, 63 lines)
@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from twisted.internet.defer import Deferred

from synapse.util.caches.snapshot_cache import SnapshotCache

from .. import unittest


class SnapshotCacheTestCase(unittest.TestCase):
    def setUp(self):
        self.cache = SnapshotCache()
        self.cache.DURATION_MS = 1

    def test_get_set(self):
        # Check that getting a missing key returns None
        self.assertEquals(self.cache.get(0, "key"), None)

        # Check that setting a key with a deferred returns
        # a deferred that resolves when the initial deferred does
        d = Deferred()
        set_result = self.cache.set(0, "key", d)
        self.assertIsNotNone(set_result)
        self.assertFalse(set_result.called)

        # Check that getting the key before the deferred has resolved
        # returns a deferred that resolves when the initial deferred does.
        get_result_at_10 = self.cache.get(10, "key")
        self.assertIsNotNone(get_result_at_10)
        self.assertFalse(get_result_at_10.called)

        # Check that the returned deferreds resolve when the initial deferred
        # does.
        d.callback("v")
        self.assertTrue(set_result.called)
        self.assertTrue(get_result_at_10.called)

        # Check that getting the key after the deferred has resolved
        # before the cache expires returns a resolved deferred.
        get_result_at_11 = self.cache.get(11, "key")
        self.assertIsNotNone(get_result_at_11)
        if isinstance(get_result_at_11, Deferred):
            # The cache may return the actual result rather than a deferred
            self.assertTrue(get_result_at_11.called)

        # Check that getting the key after the deferred has resolved
        # after the cache expires returns None
        get_result_at_12 = self.cache.get(12, "key")
        self.assertIsNone(get_result_at_12)

@@ -260,7 +260,9 @@ def setup_test_homeserver(
        hs = homeserverToUse(
            name,
            config=config,
            db_config=config.database_config,
            version_string="Synapse/tests",
            database_engine=db_engine,
            tls_server_context_factory=Mock(),
            tls_client_options_factory=Mock(),
            reactor=reactor,

tox.ini
@@ -177,17 +177,5 @@ env =
    MYPYPATH = stubs/
extras = all
commands = mypy \
            synapse/config/ \
            synapse/handlers/ui_auth \
            synapse/logging/ \
            synapse/module_api \
            synapse/rest/consent \
            synapse/rest/media/v0 \
            synapse/rest/saml2 \
            synapse/spam_checker_api \
            synapse/storage/engines \
            synapse/streams

# To find all folders that pass mypy you run:
#
# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print
            synapse/config/