1
0

Compare commits

...

63 Commits

Author SHA1 Message Date
Michael Kaye
6543296467 Emit log messages as 'message' not 'log'. 2019-12-13 10:50:18 +00:00
Richard van der Hoff
4ce05ec171 Adjust the sytest blacklist for worker mode (#6538)
Remove tests that got blacklisted while torturing was enabled, and add one that
fails.
2019-12-13 10:15:20 +00:00
Erik Johnston
caa52836e4 Merge pull request #6496 from matrix-org/erikj/initial_sync_asnyc
Port synapse.handlers.initial_sync to async/await
2019-12-13 10:01:51 +00:00
Erik Johnston
7e67cfc87d Merge pull request #6534 from matrix-org/erikj/extend_mypy
Include more folders in mypy
2019-12-12 17:21:42 +00:00
Hubert Chathi
cb2db17994 look up cross-signing keys from the DB in bulk (#6486) 2019-12-12 12:03:28 -05:00
Andrew Morgan
5bfd8855d6 Fix redacted events being returned in search results ordered by "recent" (#6522) 2019-12-12 15:53:49 +00:00
Erik Johnston
c965253e4b Newsfile 2019-12-12 14:54:03 +00:00
Erik Johnston
324d4f61b8 Include more folders in mypy 2019-12-12 14:52:11 +00:00
Richard van der Hoff
25f1244329 Check the room_id of events when fetching room state/auth (#6524)
When we request the state/auth_events to populate a backwards extremity (on
backfill or in the case of missing events in a transaction push), we should
check that the returned events are in the right room rather than blindly using
them in the room state or auth chain.

Given that _get_events_from_store_or_dest takes a room_id, it seems clear that
it should be sanity-checking the room_id of the requested events, so let's do
it there.
2019-12-12 12:57:45 +00:00
Erik Johnston
b8e4b39b69 Merge pull request #6511 from matrix-org/erikj/remove_db_config_from_apps
Move database config from apps into HomeServer object
2019-12-12 10:37:56 +00:00
Erik Johnston
cfcfb57e58 Add new config param to docstring and add types 2019-12-11 17:27:46 +00:00
Erik Johnston
6828b47c45 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/initial_sync_asnyc 2019-12-11 17:01:41 +00:00
Richard van der Hoff
2045356517 Add include_event_in_state to _get_state_for_room (#6521)
Make it return the state *after* the requested event, rather than the one
before it. This is a bit easier and requires fewer calls to
get_events_from_store_or_dest.
2019-12-11 16:37:51 +00:00
Richard van der Hoff
894d2addac Merge pull request #6517 from matrix-org/rav/event_auth/13
Port some of FederationHandler to async/await
2019-12-11 16:36:06 +00:00
Erik Johnston
31905a518e Merge pull request #6504 from matrix-org/erikj/account_validity_async_await
Port handlers.account_validity to async/await.
2019-12-11 14:49:26 +00:00
Richard van der Hoff
5324bc20a6 changelog 2019-12-11 14:39:26 +00:00
Richard van der Hoff
6637d90d77 convert to async: FederationHandler._process_received_pdu
also fix user_joined_room to consistently return deferreds
2019-12-11 14:39:26 +00:00
Richard van der Hoff
4db394a4b3 convert to async: FederationHandler._get_state_for_room
... and _get_events_from_store_or_dest
2019-12-11 14:39:26 +00:00
Richard van der Hoff
e77237b935 convert to async: FederationHandler.on_receive_pdu
and associated functions:
 * on_receive_pdu
 * handle_queued_pdus
 * get_missing_events_for_pdu
2019-12-11 14:39:26 +00:00
Richard van der Hoff
7712e751b8 Convert federation backfill to async
PaginationHandler.get_messages is only called by RoomMessageListRestServlet,
which is async.

Chase the code path down from there:
 - FederationHandler.maybe_backfill (and nested try_backfill)
 - FederationHandler.backfill
2019-12-11 14:39:25 +00:00
Richard van der Hoff
7c429f92d6 Clean up some logging (#6515)
This just makes some of the logging easier to follow when things start going
wrong.
2019-12-11 14:32:25 +00:00
Erik Johnston
adb3a873fd Merge tag 'v1.7.0rc2' into develop
Synapse 1.7.0rc2 (2019-12-11)
=============================

Bugfixes
--------

- Fix incorrect error message for invalid requests when setting user's avatar URL. ([\#6497](https://github.com/matrix-org/synapse/issues/6497))
- Fix support for SQLite 3.7. ([\#6499](https://github.com/matrix-org/synapse/issues/6499))
- Fix regression where sending email push would not work when using a pusher worker. ([\#6507](https://github.com/matrix-org/synapse/issues/6507), [\#6509](https://github.com/matrix-org/synapse/issues/6509))
2019-12-11 14:14:30 +00:00
Andrew Morgan
fc316a4894 Prevent redacted events from appearing in message search (#6377) 2019-12-11 13:39:47 +00:00
Andrew Morgan
6676ee9c4a Add dev script to generate full SQL schema files (#6394) 2019-12-11 13:16:01 +00:00
Andrew Morgan
ea0f0ad414 Prevent message search in upgraded rooms we're not in (#6385) 2019-12-11 13:07:25 +00:00
Brendan Abolivier
54ae52ba96 Merge pull request #6349 from matrix-org/babolivier/msc1802
Implement v2 APIs for send_join and send_leave
2019-12-11 11:41:47 +00:00
Erik Johnston
239d86a134 Merge pull request #6512 from matrix-org/erikj/silence_mypy
Silence mypy errors for files outside those specified
2019-12-11 10:39:57 +00:00
Richard van der Hoff
f8bc2ae883 Move get_state methods into FederationHandler (#6503)
This is a non-functional refactor as a precursor to some other work.
2019-12-10 17:42:46 +00:00
Andrew Morgan
4947de5a14 Allow SAML username provider plugins (#6411) 2019-12-10 17:30:16 +00:00
Erik Johnston
b2dcddc413 Merge pull request #6510 from matrix-org/erikj/phone_home_stats_db
Phone home stats DB reporting should not assume a single DB.
2019-12-10 16:31:24 +00:00
Richard van der Hoff
40eda84933 Fix race which caused deleted devices to reappear (#6514)
Stop the `update_client_ips` background job from recreating deleted devices.
2019-12-10 16:22:29 +00:00
Richard van der Hoff
c3dda2874d Refactor get_events_from_store_or_dest to return a dict (#6501)
There was a bunch of unnecessary conversion back and forth between dict and
list going on here. We can simplify a bunch of the code.
2019-12-10 16:22:00 +00:00
Richard van der Hoff
424fd58237 Remove redundant code from event authorisation implementation. (#6502) 2019-12-10 15:09:45 +00:00
Erik Johnston
3f97b4c16b Newsfile 2019-12-10 14:40:15 +00:00
Erik Johnston
257ef2c727 Port handlers.account_validity to async/await. 2019-12-10 14:40:15 +00:00
Erik Johnston
d630c82349 Newsfile 2019-12-10 14:34:17 +00:00
Erik Johnston
67c991b78f Fix upgrade db script 2019-12-10 14:34:17 +00:00
Erik Johnston
bc5cb8bfe8 Remove database config parsing from apps. 2019-12-10 14:34:17 +00:00
Erik Johnston
35f3c366ef Merge pull request #6505 from matrix-org/erikj/make_deferred_yiedable
Fix `make_deferred_yieldable` to work with coroutines
2019-12-10 14:20:26 +00:00
Erik Johnston
e726e18737 Merge pull request #6499 from matrix-org/erikj/fix_sqlite_7
Fix support for SQLite 3.7.
2019-12-10 13:43:52 +00:00
Erik Johnston
4643bb2a37 Newsfile 2019-12-10 13:36:00 +00:00
Erik Johnston
28b758fa0f Silence mypy errors for files outside those specified 2019-12-10 13:34:56 +00:00
Erik Johnston
accd343f91 Newsfile 2019-12-10 13:22:42 +00:00
Erik Johnston
663238aeb4 Phone home stats DB reporting should not assume a single DB. 2019-12-10 13:21:04 +00:00
Erik Johnston
ffeafade48 Update comment 2019-12-10 13:17:39 +00:00
Erik Johnston
e3f528c544 Merge pull request #6506 from matrix-org/erikj/remove_snapshot_cache
Remove SnapshotCache in favour of ResponseCache
2019-12-10 13:04:50 +00:00
Erik Johnston
b1e7012dee Newsfile 2019-12-10 11:29:44 +00:00
Erik Johnston
f5bb1531b7 Newsfile 2019-12-10 11:23:52 +00:00
Erik Johnston
9a2223d4c8 Fix make_deferred_yieldable to work with coroutines 2019-12-10 11:22:12 +00:00
Erik Johnston
353396e3a7 Port handlers.account_data to async/await. 2019-12-10 11:12:56 +00:00
Erik Johnston
aeaeb72ee4 Newsfile 2019-12-09 13:48:14 +00:00
Erik Johnston
a1f8ea9051 Port synapse.handlers.initial_sync to async/await 2019-12-09 13:46:45 +00:00
Erik Johnston
f166a8d1f5 Remove SnapshotCache in favour of ResponseCache 2019-12-09 13:42:49 +00:00
Brendan Abolivier
e126d83f74 Merge branch 'develop' into babolivier/msc1802 2019-12-05 21:00:43 +00:00
Brendan Abolivier
c8444b4152 Merge branch 'develop' into babolivier/msc1802 2019-11-14 10:12:11 +00:00
Brendan Abolivier
94cdd6fffe Lint 2019-11-11 16:56:55 +00:00
Brendan Abolivier
21056ad12a Changelog 2019-11-11 16:53:29 +00:00
Brendan Abolivier
edc4c7d4c5 Lint 2019-11-11 16:51:54 +00:00
Brendan Abolivier
5e18dc7955 Fix prefix for v2/send_leave 2019-11-11 16:46:09 +00:00
Brendan Abolivier
74897de01f Add server-side support to the v2 API 2019-11-11 16:40:45 +00:00
Brendan Abolivier
1e202a90f1 Implement v2 API for send_leave 2019-11-11 16:26:53 +00:00
Brendan Abolivier
92527d7b21 Add missing yield 2019-11-11 16:20:53 +00:00
Brendan Abolivier
4c131b2c78 Implement v2 API for send_join 2019-11-11 15:47:47 +00:00
79 changed files with 1599 additions and 882 deletions

View File

@@ -34,33 +34,8 @@ Device list doesn't change if remote server is down
Remote servers cannot set power levels in rooms without existing powerlevels
Remote servers should reject attempts by non-creators to set the power levels
# new failures as of https://github.com/matrix-org/sytest/pull/753
GET /rooms/:room_id/messages returns a message
GET /rooms/:room_id/messages lazy loads members correctly
Read receipts are sent as events
Only original members of the room can see messages from erased users
Device deletion propagates over federation
If user leaves room, remote user changes device and rejoins we see update in /sync and /keys/changes
Changing user-signing key notifies local users
Newly updated tags appear in an incremental v2 /sync
# https://buildkite.com/matrix-dot-org/synapse/builds/6134#6f67bf47-e234-474d-80e8-c6e1868b15c5
Server correctly handles incoming m.device_list_update
Local device key changes get to remote servers with correct prev_id
AS-ghosted users can use rooms via AS
Ghost user must register before joining room
Test that a message is pushed
Invites are pushed
Rooms with aliases are correctly named in pushed
Rooms with names are correctly named in pushed
Rooms with canonical alias are correctly named in pushed
Rooms with many users are correctly pushed
Don't get pushed for rooms you've muted
Rejected events are not pushed
Test that rejected pushers are removed.
Events come down the correct room
# https://buildkite.com/matrix-dot-org/sytest/builds/326#cca62404-a88a-4fcb-ad41-175fd3377603
Presence changes to UNAVAILABLE are reported to remote room members
If remote user leaves room, changes device and rejoins we see update in sync
uploading self-signing key notifies over federation
Inbound federation can receive redacted events
Outbound federation can request missing events
# this fails reliably with a torture level of 100 due to https://github.com/matrix-org/synapse/issues/6536
Outbound federation requests missing prev_events and then asks for /state_ids and resolves the state

1
changelog.d/6349.feature Normal file
View File

@@ -0,0 +1 @@
Implement v2 APIs for the `send_join` and `send_leave` federation endpoints (as described in [MSC1802](https://github.com/matrix-org/matrix-doc/pull/1802)).

1
changelog.d/6377.bugfix Normal file
View File

@@ -0,0 +1 @@
Prevent redacted events from being returned during message search.

1
changelog.d/6385.bugfix Normal file
View File

@@ -0,0 +1 @@
Prevent error on trying to search a upgraded room when the server is not in the predecessor room.

1
changelog.d/6394.feature Normal file
View File

@@ -0,0 +1 @@
Add a develop script to generate full SQL schemas.

1
changelog.d/6411.feature Normal file
View File

@@ -0,0 +1 @@
Allow custom SAML username mapping functinality through an external provider plugin.

1
changelog.d/6486.bugfix Normal file
View File

@@ -0,0 +1 @@
Improve performance of looking up cross-signing keys.

1
changelog.d/6496.misc Normal file
View File

@@ -0,0 +1 @@
Port synapse.handlers.initial_sync to async/await.

1
changelog.d/6501.misc Normal file
View File

@@ -0,0 +1 @@
Refactor get_events_from_store_or_dest to return a dict.

1
changelog.d/6502.removal Normal file
View File

@@ -0,0 +1 @@
Remove redundant code from event authorisation implementation.

1
changelog.d/6503.misc Normal file
View File

@@ -0,0 +1 @@
Move get_state methods into FederationHandler.

1
changelog.d/6504.misc Normal file
View File

@@ -0,0 +1 @@
Port handlers.account_data and handlers.account_validity to async/await.

1
changelog.d/6505.misc Normal file
View File

@@ -0,0 +1 @@
Make `make_deferred_yieldable` to work with async/await.

1
changelog.d/6506.misc Normal file
View File

@@ -0,0 +1 @@
Remove `SnapshotCache` in favour of `ResponseCache`.

1
changelog.d/6510.misc Normal file
View File

@@ -0,0 +1 @@
Change phone home stats to not assume there is a single database and report information about the database used by the main data store.

1
changelog.d/6511.misc Normal file
View File

@@ -0,0 +1 @@
Move database config from apps into HomeServer object.

1
changelog.d/6512.misc Normal file
View File

@@ -0,0 +1 @@
Silence mypy errors for files outside those specified.

1
changelog.d/6514.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix race which occasionally caused deleted devices to reappear.

1
changelog.d/6515.misc Normal file
View File

@@ -0,0 +1 @@
Clean up some logging when handling incoming events over federation.

1
changelog.d/6517.misc Normal file
View File

@@ -0,0 +1 @@
Port some of FederationHandler to async/await.

1
changelog.d/6521.misc Normal file
View File

@@ -0,0 +1 @@
Refactor some code in the event authentication path for clarity.

1
changelog.d/6522.bugfix Normal file
View File

@@ -0,0 +1 @@
Prevent redacted events from being returned during message search.

2
changelog.d/6524.misc Normal file
View File

@@ -0,0 +1,2 @@
Improve sanity-checking when receiving events over federation.

1
changelog.d/6534.misc Normal file
View File

@@ -0,0 +1 @@
Test more folders against mypy.

1
changelog.d/6538.misc Normal file
View File

@@ -0,0 +1 @@
Adjust the sytest blacklist for worker mode.

View File

@@ -0,0 +1,77 @@
# SAML Mapping Providers
A SAML mapping provider is a Python class (loaded via a Python module) that
works out how to map attributes of a SAML response object to Matrix-specific
user attributes. Details such as user ID localpart, displayname, and even avatar
URLs are all things that can be mapped from talking to a SSO service.
As an example, a SSO service may return the email address
"john.smith@example.com" for a user, whereas Synapse will need to figure out how
to turn that into a displayname when creating a Matrix user for this individual.
It may choose `John Smith`, or `Smith, John [Example.com]` or any number of
variations. As each Synapse configuration may want something different, this is
where SAML mapping providers come into play.
## Enabling Providers
External mapping providers are provided to Synapse in the form of an external
Python module. Retrieve this module from [PyPi](https://pypi.org) or elsewhere,
then tell Synapse where to look for the handler class by editing the
`saml2_config.user_mapping_provider.module` config option.
`saml2_config.user_mapping_provider.config` allows you to provide custom
configuration options to the module. Check with the module's documentation for
what options it provides (if any). The options listed by default are for the
user mapping provider built in to Synapse. If using a custom module, you should
comment these options out and use those specified by the module instead.
## Building a Custom Mapping Provider
A custom mapping provider must specify the following methods:
* `__init__(self, parsed_config)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
* `saml_response_to_user_attributes(self, saml_response, failures)`
- Arguments:
- `saml_response` - A `saml2.response.AuthnResponse` object to extract user
information from.
- `failures` - An `int` that represents the amount of times the returned
mxid localpart mapping has failed. This should be used
to create a deduplicated mxid localpart which should be
returned instead. For example, if this method returns
`john.doe` as the value of `mxid_localpart` in the returned
dict, and that is already taken on the homeserver, this
method will be called again with the same parameters but
with failures=1. The method should then return a different
`mxid_localpart` value, such as `john.doe1`.
- This method must return a dictionary, which will then be used by Synapse
to build a new user. The following keys are allowed:
* `mxid_localpart` - Required. The mxid localpart of the new user.
* `displayname` - The displayname of the new user. If not provided, will default to
the value of `mxid_localpart`.
* `parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A `dict` representing the parsed content of the
`saml2_config.user_mapping_provider.config` homeserver config option.
Runs on homeserver startup. Providers should extract any option values
they need here.
- Whatever is returned will be passed back to the user mapping provider module's
`__init__` method during construction.
* `get_saml_attributes(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
- `config` - A object resulting from a call to `parse_config`.
- Returns a tuple of two sets. The first set equates to the saml auth
response attributes that are required for the module to function, whereas
the second set consists of those attributes which can be used if available,
but are not necessary.
## Synapse's Default Provider
Synapse has a built-in SAML mapping provider if a custom provider isn't
specified in the config. It is located at
[`synapse.handlers.saml_handler.DefaultSamlMappingProvider`](../synapse/handlers/saml_handler.py).

View File

@@ -1250,33 +1250,58 @@ saml2_config:
#
#config_path: "CONFDIR/sp_conf.py"
# the lifetime of a SAML session. This defines how long a user has to
# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
#
#mxid_source_attribute: displayName
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use
# to derive the Matrix ID from. 'uid' by default.
#
# Note: This used to be configured by the
# saml2_config.mxid_source_attribute option. If that is still
# defined, its value will be used instead.
#
#mxid_source_attribute: displayName
# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
# The mapping system to use for mapping the saml attribute onto a
# matrix ID.
#
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with
# '.').
# The default is 'hexencode'.
#
# Note: This used to be configured by the
# saml2_config.mxid_mapping option. If that is still defined, its
# value will be used instead.
#
#mxid_mapping: dotreplace
# In previous versions of synapse, the mapping from SAML attribute to
# MXID was always calculated dynamically rather than stored in a
# table. For backwards- compatibility, we will look for user_ids
# matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
# backwards-compatibility lookup. Typically it should be 'uid', but if
# the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#

View File

@@ -1,7 +1,7 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
follow_imports = normal
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
show_traceback = True

184
scripts-dev/make_full_schema.sh Executable file
View File

@@ -0,0 +1,184 @@
#!/bin/bash
#
# This script generates SQL files for creating a brand new Synapse DB with the latest
# schema, on both SQLite3 and Postgres.
#
# It does so by having Synapse generate an up-to-date SQLite DB, then running
# synapse_port_db to convert it to Postgres. It then dumps the contents of both.
POSTGRES_HOST="localhost"
POSTGRES_DB_NAME="synapse_full_schema.$$"
SQLITE_FULL_SCHEMA_OUTPUT_FILE="full.sql.sqlite"
POSTGRES_FULL_SCHEMA_OUTPUT_FILE="full.sql.postgres"
REQUIRED_DEPS=("matrix-synapse" "psycopg2")
usage() {
echo
echo "Usage: $0 -p <postgres_username> -o <path> [-c] [-n] [-h]"
echo
echo "-p <postgres_username>"
echo " Username to connect to local postgres instance. The password will be requested"
echo " during script execution."
echo "-c"
echo " CI mode. Enables coverage tracking and prints every command that the script runs."
echo "-o <path>"
echo " Directory to output full schema files to."
echo "-h"
echo " Display this help text."
}
while getopts "p:co:h" opt; do
case $opt in
p)
POSTGRES_USERNAME=$OPTARG
;;
c)
# Print all commands that are being executed
set -x
# Modify required dependencies for coverage
REQUIRED_DEPS+=("coverage" "coverage-enable-subprocess")
COVERAGE=1
;;
o)
command -v realpath > /dev/null || (echo "The -o flag requires the 'realpath' binary to be installed" && exit 1)
OUTPUT_DIR="$(realpath "$OPTARG")"
;;
h)
usage
exit
;;
\?)
echo "ERROR: Invalid option: -$OPTARG" >&2
usage
exit
;;
esac
done
# Check that required dependencies are installed
unsatisfied_requirements=()
for dep in "${REQUIRED_DEPS[@]}"; do
pip show "$dep" --quiet || unsatisfied_requirements+=("$dep")
done
if [ ${#unsatisfied_requirements} -ne 0 ]; then
echo "Please install the following python packages: ${unsatisfied_requirements[*]}"
exit 1
fi
if [ -z "$POSTGRES_USERNAME" ]; then
echo "No postgres username supplied"
usage
exit 1
fi
if [ -z "$OUTPUT_DIR" ]; then
echo "No output directory supplied"
usage
exit 1
fi
# Create the output directory if it doesn't exist
mkdir -p "$OUTPUT_DIR"
read -rsp "Postgres password for '$POSTGRES_USERNAME': " POSTGRES_PASSWORD
echo ""
# Exit immediately if a command fails
set -e
# cd to root of the synapse directory
cd "$(dirname "$0")/.."
# Create temporary SQLite and Postgres homeserver db configs and key file
TMPDIR=$(mktemp -d)
KEY_FILE=$TMPDIR/test.signing.key # default Synapse signing key path
SQLITE_CONFIG=$TMPDIR/sqlite.conf
SQLITE_DB=$TMPDIR/homeserver.db
POSTGRES_CONFIG=$TMPDIR/postgres.conf
# Ensure these files are delete on script exit
trap 'rm -rf $TMPDIR' EXIT
cat > "$SQLITE_CONFIG" <<EOF
server_name: "test"
signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"
report_stats: false
database:
name: "sqlite3"
args:
database: "$SQLITE_DB"
# Suppress the key server warning.
trusted_key_servers: []
EOF
cat > "$POSTGRES_CONFIG" <<EOF
server_name: "test"
signing_key_path: "$KEY_FILE"
macaroon_secret_key: "abcde"
report_stats: false
database:
name: "psycopg2"
args:
user: "$POSTGRES_USERNAME"
host: "$POSTGRES_HOST"
password: "$POSTGRES_PASSWORD"
database: "$POSTGRES_DB_NAME"
# Suppress the key server warning.
trusted_key_servers: []
EOF
# Generate the server's signing key.
echo "Generating SQLite3 db schema..."
python -m synapse.app.homeserver --generate-keys -c "$SQLITE_CONFIG"
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
echo "Running db background jobs..."
scripts-dev/update_database --database-config "$SQLITE_CONFIG"
# Create the PostgreSQL database.
echo "Creating postgres database..."
createdb $POSTGRES_DB_NAME
echo "Copying data from SQLite3 to Postgres with synapse_port_db..."
if [ -z "$COVERAGE" ]; then
# No coverage needed
scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
else
# Coverage desired
coverage run scripts/synapse_port_db --sqlite-database "$SQLITE_DB" --postgres-config "$POSTGRES_CONFIG"
fi
# Delete schema_version, applied_schema_deltas and applied_module_schemas tables
# This needs to be done after synapse_port_db is run
echo "Dropping unwanted db tables..."
SQL="
DROP TABLE schema_version;
DROP TABLE applied_schema_deltas;
DROP TABLE applied_module_schemas;
"
sqlite3 "$SQLITE_DB" <<< "$SQL"
psql $POSTGRES_DB_NAME -U "$POSTGRES_USERNAME" -w <<< "$SQL"
echo "Dumping SQLite3 schema to '$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE'..."
sqlite3 "$SQLITE_DB" ".dump" > "$OUTPUT_DIR/$SQLITE_FULL_SCHEMA_OUTPUT_FILE"
echo "Dumping Postgres schema to '$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE'..."
pg_dump --format=plain --no-tablespaces --no-acl --no-owner $POSTGRES_DB_NAME | sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > "$OUTPUT_DIR/$POSTGRES_FULL_SCHEMA_OUTPUT_FILE"
echo "Cleaning up temporary Postgres database..."
dropdb $POSTGRES_DB_NAME
echo "Done! Files dumped to: $OUTPUT_DIR"

View File

@@ -26,7 +26,6 @@ from synapse.config.homeserver import HomeServerConfig
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
logger = logging.getLogger("update_database")
@@ -35,21 +34,11 @@ logger = logging.getLogger("update_database")
class MockHomeserver(HomeServer):
DATASTORE_CLASS = DataStore
def __init__(self, config, database_engine, db_conn, **kwargs):
def __init__(self, config, **kwargs):
super(MockHomeserver, self).__init__(
config.server_name,
reactor=reactor,
config=config,
database_engine=database_engine,
**kwargs
config.server_name, reactor=reactor, config=config, **kwargs
)
self.database_engine = database_engine
self.db_conn = db_conn
def get_db_conn(self):
return self.db_conn
if __name__ == "__main__":
parser = argparse.ArgumentParser(
@@ -85,24 +74,14 @@ if __name__ == "__main__":
config = HomeServerConfig()
config.parse_config_dict(hs_config, "", "")
# Create the database engine and a connection to it.
database_engine = create_engine(config.database_config)
db_conn = database_engine.module.connect(
**{
k: v
for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
# Instantiate and initialise the homeserver object.
hs = MockHomeserver(config)
db_conn = hs.get_db_conn()
# Update the database to the latest schema.
prepare_database(db_conn, database_engine, config=config)
prepare_database(db_conn, hs.database_engine, config=config)
db_conn.commit()
# Instantiate and initialise the homeserver object.
hs = MockHomeserver(
config, database_engine, db_conn, db_config=config.database_config,
)
# setup instantiates the store within the homeserver object.
hs.setup()
store = hs.get_datastore()

View File

@@ -45,7 +45,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.logcontext import LoggingContext
from synapse.util.versionstring import get_version_string
@@ -229,14 +228,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = AdminCmdServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -34,7 +34,6 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -143,8 +142,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.notify_appservices:
sys.stderr.write(
"\nThe appservices must be disabled in the main synapse process"
@@ -159,10 +156,8 @@ def start(config_options):
ps = AppserviceServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)

View File

@@ -62,7 +62,6 @@ from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
from synapse.rest.client.v2_alpha.register import RegisterRestServlet
from synapse.rest.client.versions import VersionsRestServlet
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -181,14 +180,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = ClientReaderServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -57,7 +57,6 @@ from synapse.rest.client.v1.room import (
)
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -180,14 +179,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = EventCreatorServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -46,7 +46,6 @@ from synapse.replication.slave.storage.transactions import SlavedTransactionStor
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -162,14 +161,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = FederationReaderServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -41,7 +41,6 @@ from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.replication.tcp.streams._base import ReceiptsStream
from synapse.server import HomeServer
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.types import ReadReceipt
from synapse.util.async_helpers import Linearizer
from synapse.util.httpresourcetree import create_resource_tree
@@ -174,8 +173,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.send_federation:
sys.stderr.write(
"\nThe send_federation must be disabled in the main synapse process"
@@ -190,10 +187,8 @@ def start(config_options):
ss = FederationSenderServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -39,7 +39,6 @@ from synapse.replication.slave.storage.registration import SlavedRegistrationSto
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.server import HomeServer
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -234,14 +233,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = FrontendProxyServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -69,7 +69,7 @@ from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.well_known import WellKnownResource
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import IncorrectDatabaseSetup, create_engine
from synapse.storage.engines import IncorrectDatabaseSetup
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.util.caches import CACHE_SIZE_FACTOR
from synapse.util.httpresourcetree import create_resource_tree
@@ -328,15 +328,10 @@ def setup(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
synapse.config.logger.setup_logging(hs, config, use_worker_options=False)
@@ -519,8 +514,10 @@ def phone_stats_home(hs, stats, stats_process=_stats_process):
# Database version
#
stats["database_engine"] = hs.database_engine.module.__name__
stats["database_server_version"] = hs.database_engine.server_version
# This only reports info about the *main* database.
stats["database_engine"] = hs.get_datastore().db.engine.module.__name__
stats["database_server_version"] = hs.get_datastore().db.engine.server_version
logger.info("Reporting stats to %s: %s" % (hs.config.report_stats_endpoint, stats))
try:
yield hs.get_proxied_http_client().put_json(

View File

@@ -40,7 +40,6 @@ from synapse.rest.admin import register_servlets_for_media_repo
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.server import HomeServer
from synapse.storage.data_stores.main.media_repository import MediaRepositoryStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -157,14 +156,10 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = MediaRepositoryServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -37,7 +37,6 @@ from synapse.replication.slave.storage.room import RoomStore
from synapse.replication.tcp.client import ReplicationClientHandler
from synapse.server import HomeServer
from synapse.storage import DataStore
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.versionstring import get_version_string
@@ -203,14 +202,10 @@ def start(config_options):
# Force the pushers to start since they will be disabled in the main config
config.start_pushers = True
database_engine = create_engine(config.database_config)
ps = PusherServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ps, config, use_worker_options=True)

View File

@@ -55,7 +55,6 @@ from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
from synapse.rest.client.v2_alpha import sync
from synapse.server import HomeServer
from synapse.storage.data_stores.main.presence import UserPresenceState
from synapse.storage.engines import create_engine
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
from synapse.util.stringutils import random_string
@@ -437,14 +436,10 @@ def start(config_options):
synapse.events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
ss = SynchrotronServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
application_service_handler=SynchrotronApplicationService(),
)

View File

@@ -44,7 +44,6 @@ from synapse.rest.client.v2_alpha import user_directory
from synapse.server import HomeServer
from synapse.storage.data_stores.main.user_directory import UserDirectoryStore
from synapse.storage.database import Database
from synapse.storage.engines import create_engine
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.manhole import manhole
@@ -200,8 +199,6 @@ def start(config_options):
events.USE_FROZEN_DICTS = config.use_frozen_dicts
database_engine = create_engine(config.database_config)
if config.update_user_directory:
sys.stderr.write(
"\nThe update_user_directory must be disabled in the main synapse process"
@@ -216,10 +213,8 @@ def start(config_options):
ss = UserDirectoryServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string="Synapse/" + get_version_string(synapse),
database_engine=database_engine,
)
setup_logging(ss, config, use_worker_options=True)

View File

@@ -14,17 +14,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import logging
from synapse.python_dependencies import DependencyException, check_requirements
from synapse.types import (
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util.module_loader import load_python_module
from synapse.util.module_loader import load_module, load_python_module
from ._base import Config, ConfigError
logger = logging.getLogger(__name__)
DEFAULT_USER_MAPPING_PROVIDER = (
"synapse.handlers.saml_handler.DefaultSamlMappingProvider"
)
def _dict_merge(merge_dict, into_dict):
"""Do a deep merge of two dicts
@@ -75,15 +77,69 @@ class SAML2Config(Config):
self.saml2_enabled = True
self.saml2_mxid_source_attribute = saml2_config.get(
"mxid_source_attribute", "uid"
)
self.saml2_grandfathered_mxid_source_attribute = saml2_config.get(
"grandfathered_mxid_source_attribute", "uid"
)
saml2_config_dict = self._default_saml_config_dict()
# user_mapping_provider may be None if the key is present but has no value
ump_dict = saml2_config.get("user_mapping_provider") or {}
# Use the default user mapping provider if not set
ump_dict.setdefault("module", DEFAULT_USER_MAPPING_PROVIDER)
# Ensure a config is present
ump_dict["config"] = ump_dict.get("config") or {}
if ump_dict["module"] == DEFAULT_USER_MAPPING_PROVIDER:
# Load deprecated options for use by the default module
old_mxid_source_attribute = saml2_config.get("mxid_source_attribute")
if old_mxid_source_attribute:
logger.warning(
"The config option saml2_config.mxid_source_attribute is deprecated. "
"Please use saml2_config.user_mapping_provider.config"
".mxid_source_attribute instead."
)
ump_dict["config"]["mxid_source_attribute"] = old_mxid_source_attribute
old_mxid_mapping = saml2_config.get("mxid_mapping")
if old_mxid_mapping:
logger.warning(
"The config option saml2_config.mxid_mapping is deprecated. Please "
"use saml2_config.user_mapping_provider.config.mxid_mapping instead."
)
ump_dict["config"]["mxid_mapping"] = old_mxid_mapping
# Retrieve an instance of the module's class
# Pass the config dictionary to the module for processing
(
self.saml2_user_mapping_provider_class,
self.saml2_user_mapping_provider_config,
) = load_module(ump_dict)
# Ensure loaded user mapping module has defined all necessary methods
# Note parse_config() is already checked during the call to load_module
required_methods = [
"get_saml_attributes",
"saml_response_to_user_attributes",
]
missing_methods = [
method
for method in required_methods
if not hasattr(self.saml2_user_mapping_provider_class, method)
]
if missing_methods:
raise ConfigError(
"Class specified by saml2_config."
"user_mapping_provider.module is missing required "
"methods: %s" % (", ".join(missing_methods),)
)
# Get the desired saml auth response attributes from the module
saml2_config_dict = self._default_saml_config_dict(
*self.saml2_user_mapping_provider_class.get_saml_attributes(
self.saml2_user_mapping_provider_config
)
)
_dict_merge(
merge_dict=saml2_config.get("sp_config", {}), into_dict=saml2_config_dict
)
@@ -103,22 +159,27 @@ class SAML2Config(Config):
saml2_config.get("saml_session_lifetime", "5m")
)
mapping = saml2_config.get("mxid_mapping", "hexencode")
try:
self.saml2_mxid_mapper = MXID_MAPPER_MAP[mapping]
except KeyError:
raise ConfigError("%s is not a known mxid_mapping" % (mapping,))
def _default_saml_config_dict(
self, required_attributes: set, optional_attributes: set
):
"""Generate a configuration dictionary with required and optional attributes that
will be needed to process new user registration
def _default_saml_config_dict(self):
Args:
required_attributes: SAML auth response attributes that are
necessary to function
optional_attributes: SAML auth response attributes that can be used to add
additional information to Synapse user accounts, but are not required
Returns:
dict: A SAML configuration dictionary
"""
import saml2
public_baseurl = self.public_baseurl
if public_baseurl is None:
raise ConfigError("saml2_config requires a public_baseurl to be set")
required_attributes = {"uid", self.saml2_mxid_source_attribute}
optional_attributes = {"displayName"}
if self.saml2_grandfathered_mxid_source_attribute:
optional_attributes.add(self.saml2_grandfathered_mxid_source_attribute)
optional_attributes -= required_attributes
@@ -207,33 +268,58 @@ class SAML2Config(Config):
#
#config_path: "%(config_dir_path)s/sp_conf.py"
# the lifetime of a SAML session. This defines how long a user has to
# The lifetime of a SAML session. This defines how long a user has to
# complete the authentication process, if allow_unsolicited is unset.
# The default is 5 minutes.
#
#saml_session_lifetime: 5m
# The SAML attribute (after mapping via the attribute maps) to use to derive
# the Matrix ID from. 'uid' by default.
# An external module can be provided here as a custom solution to
# mapping attributes returned from a saml provider onto a matrix user.
#
#mxid_source_attribute: displayName
user_mapping_provider:
# The custom module's class. Uncomment to use a custom module.
#
#module: mapping_provider.SamlMappingProvider
# The mapping system to use for mapping the saml attribute onto a matrix ID.
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with '.').
# The default is 'hexencode'.
#
#mxid_mapping: dotreplace
# Custom configuration values for the module. Below options are
# intended for the built-in provider, they should be changed if
# using a custom module. This section will be passed as a Python
# dictionary to the module's `parse_config` method.
#
config:
# The SAML attribute (after mapping via the attribute maps) to use
# to derive the Matrix ID from. 'uid' by default.
#
# Note: This used to be configured by the
# saml2_config.mxid_source_attribute option. If that is still
# defined, its value will be used instead.
#
#mxid_source_attribute: displayName
# In previous versions of synapse, the mapping from SAML attribute to MXID was
# always calculated dynamically rather than stored in a table. For backwards-
# compatibility, we will look for user_ids matching such a pattern before
# creating a new account.
# The mapping system to use for mapping the saml attribute onto a
# matrix ID.
#
# Options include:
# * 'hexencode' (which maps unpermitted characters to '=xx')
# * 'dotreplace' (which replaces unpermitted characters with
# '.').
# The default is 'hexencode'.
#
# Note: This used to be configured by the
# saml2_config.mxid_mapping option. If that is still defined, its
# value will be used instead.
#
#mxid_mapping: dotreplace
# In previous versions of synapse, the mapping from SAML attribute to
# MXID was always calculated dynamically rather than stored in a
# table. For backwards- compatibility, we will look for user_ids
# matching such a pattern before creating a new account.
#
# This setting controls the SAML attribute which will be used for this
# backwards-compatibility lookup. Typically it should be 'uid', but if the
# attribute maps are changed, it may be necessary to change it.
# backwards-compatibility lookup. Typically it should be 'uid', but if
# the attribute maps are changed, it may be necessary to change it.
#
# The default is 'uid'.
#
@@ -241,23 +327,3 @@ class SAML2Config(Config):
""" % {
"config_dir_path": config_dir_path
}
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}

View File

@@ -42,6 +42,8 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
Returns:
if the auth checks pass.
"""
assert isinstance(auth_events, dict)
if do_size_check:
_check_size_limits(event)
@@ -74,12 +76,6 @@ def check(room_version, event, auth_events, do_sig_check=True, do_size_check=Tru
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warning("Trusting event: %s", event.event_id)
return
if event.type == EventTypes.Create:
sender_domain = get_domain_from_id(event.sender)
room_id_domain = get_domain_from_id(event.room_id)

View File

@@ -18,8 +18,6 @@ import copy
import itertools
import logging
from six.moves import range
from prometheus_client import Counter
from twisted.internet import defer
@@ -39,7 +37,7 @@ from synapse.api.room_versions import (
)
from synapse.events import builder, room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.utils import log_function
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
@@ -310,19 +308,12 @@ class FederationClient(FederationBase):
return signed_pdu
@defer.inlineCallbacks
@log_function
def get_state_for_room(self, destination, room_id, event_id):
"""Requests all of the room state at a given event from a remote homeserver.
Args:
destination (str): The remote homeserver to query for the state.
room_id (str): The id of the room we're interested in.
event_id (str): The id of the event we want the state at.
def get_room_state_ids(self, destination: str, room_id: str, event_id: str):
"""Calls the /state_ids endpoint to fetch the state at a particular point
in the room, and the auth events for the given event
Returns:
Deferred[Tuple[List[EventBase], List[EventBase]]]:
A list of events in the state, and a list of events in the auth chain
for the given event.
Tuple[List[str], List[str]]: a tuple of (state event_ids, auth event_ids)
"""
result = yield self.transport_layer.get_room_state_ids(
destination, room_id, event_id=event_id
@@ -331,86 +322,12 @@ class FederationClient(FederationBase):
state_event_ids = result["pdu_ids"]
auth_event_ids = result.get("auth_chain_ids", [])
fetched_events, failed_to_fetch = yield self.get_events_from_store_or_dest(
destination, room_id, set(state_event_ids + auth_event_ids)
)
if not isinstance(state_event_ids, list) or not isinstance(
auth_event_ids, list
):
raise Exception("invalid response from /state_ids")
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s: %s",
room_id,
failed_to_fetch,
)
event_map = {ev.event_id: ev for ev in fetched_events}
pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map]
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
return pdus, auth_chain
@defer.inlineCallbacks
def get_events_from_store_or_dest(self, destination, room_id, event_ids):
"""Fetch events from a remote destination, checking if we already have them.
Args:
destination (str)
room_id (str)
event_ids (list)
Returns:
Deferred: A deferred resolving to a 2-tuple where the first is a list of
events and the second is a list of event ids that we failed to fetch.
"""
seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = list(seen_events.values())
failed_to_fetch = set()
missing_events = set(event_ids)
for k in seen_events:
missing_events.discard(k)
if not missing_events:
return signed_events, failed_to_fetch
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
event_ids,
)
room_version = yield self.store.get_room_version(room_id)
batch_size = 20
missing_events = list(missing_events)
for i in range(0, len(missing_events), batch_size):
batch = set(missing_events[i : i + batch_size])
deferreds = [
run_in_background(
self.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
res = yield make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
if success and result:
signed_events.append(result)
batch.discard(result.event_id)
# We removed all events we successfully fetched from `batch`
failed_to_fetch.update(batch)
return signed_events, failed_to_fetch
return state_event_ids, auth_event_ids
@defer.inlineCallbacks
@log_function
@@ -609,13 +526,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
content = yield self._do_send_join(destination, pdu)
logger.debug("Got content: %s", content)
@@ -682,6 +593,44 @@ class FederationClient(FederationBase):
return self._try_destination_list("send_join", destinations, send_request)
@defer.inlineCallbacks
def _do_send_join(self, destination, pdu):
time_now = self._clock.time_msec()
try:
content = yield self.transport_layer.send_join_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
# If we receive an error response that isn't a generic error, or an
# unrecognised endpoint error, we assume that the remote understands
# the v2 invite API and this is a legitimate error.
if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
raise err
else:
raise e.to_synapse_error()
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
resp = yield self.transport_layer.send_join_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
# We expect the v1 API to respond with [200, content], so we only return the
# content.
return resp[1]
@defer.inlineCallbacks
def send_invite(self, destination, room_id, event_id, pdu):
room_version = yield self.store.get_room_version(room_id)
@@ -791,18 +740,50 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_request(destination):
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_leave(
content = yield self._do_send_leave(destination, pdu)
logger.debug("Got content: %s", content)
return None
return self._try_destination_list("send_leave", destinations, send_request)
@defer.inlineCallbacks
def _do_send_leave(self, destination, pdu):
time_now = self._clock.time_msec()
try:
content = yield self.transport_layer.send_leave_v2(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
logger.debug("Got content: %s", content)
return None
return content
except HttpResponseException as e:
if e.code in [400, 404]:
err = e.to_synapse_error()
return self._try_destination_list("send_leave", destinations, send_request)
# If we receive an error response that isn't a generic error, or an
# unrecognised endpoint error, we assume that the remote understands
# the v2 invite API and this is a legitimate error.
if err.errcode not in [Codes.UNKNOWN, Codes.UNRECOGNIZED]:
raise err
else:
raise e.to_synapse_error()
logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API")
resp = yield self.transport_layer.send_leave_v1(
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
# We expect the v1 API to respond with [200, content], so we only return the
# content.
return resp[1]
def get_public_rooms(
self,

View File

@@ -384,15 +384,10 @@ class FederationServer(FederationBase):
res_pdus = await self.handler.on_send_join_request(origin, pdu)
time_now = self._clock.time_msec()
return (
200,
{
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [
p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]
],
},
)
return {
"state": [p.get_pdu_json(time_now) for p in res_pdus["state"]],
"auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]],
}
async def on_make_leave_request(self, origin, room_id, user_id):
origin_host, _ = parse_server_name(origin)
@@ -419,7 +414,7 @@ class FederationServer(FederationBase):
pdu = await self._check_sigs_and_hash(room_version, pdu)
await self.handler.on_send_leave_request(origin, pdu)
return 200, {}
return {}
async def on_event_auth(self, origin, room_id, event_id):
with (await self._server_linearizer.queue((origin, room_id))):

View File

@@ -243,7 +243,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_join(self, destination, room_id, event_id, content):
def send_join_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -254,7 +254,18 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def send_leave(self, destination, room_id, event_id, content):
def send_join_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination, path=path, data=content
)
return response
@defer.inlineCallbacks
@log_function
def send_leave_v1(self, destination, room_id, event_id, content):
path = _create_v1_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
@@ -270,6 +281,24 @@ class TransportLayerClient(object):
return response
@defer.inlineCallbacks
@log_function
def send_leave_v2(self, destination, room_id, event_id, content):
path = _create_v2_path("/send_leave/%s/%s", room_id, event_id)
response = yield self.client.put_json(
destination=destination,
path=path,
data=content,
# we want to do our best to send this through. The problem is
# that if it fails, we won't retry it later, so if the remote
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
)
return response
@defer.inlineCallbacks
@log_function
def send_invite_v1(self, destination, room_id, event_id, content):

View File

@@ -506,9 +506,19 @@ class FederationMakeLeaveServlet(BaseFederationServlet):
return 200, content
class FederationSendLeaveServlet(BaseFederationServlet):
class FederationV1SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, (200, content)
class FederationV2SendLeaveServlet(BaseFederationServlet):
PATH = "/send_leave/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, room_id, event_id):
content = await self.handler.on_send_leave_request(origin, content, room_id)
return 200, content
@@ -521,9 +531,21 @@ class FederationEventAuthServlet(BaseFederationServlet):
return await self.handler.on_event_auth(origin, context, event_id)
class FederationSendJoinServlet(BaseFederationServlet):
class FederationV1SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
content = await self.handler.on_send_join_request(origin, content, context)
return 200, (200, content)
class FederationV2SendJoinServlet(BaseFederationServlet):
PATH = "/send_join/(?P<context>[^/]*)/(?P<event_id>[^/]*)"
PREFIX = FEDERATION_V2_PREFIX
async def on_PUT(self, origin, content, query, context, event_id):
# TODO(paul): assert that context/event_id parsed from path actually
# match those given in content
@@ -1367,8 +1389,10 @@ FEDERATION_SERVLET_CLASSES = (
FederationMakeJoinServlet,
FederationMakeLeaveServlet,
FederationEventServlet,
FederationSendJoinServlet,
FederationSendLeaveServlet,
FederationV1SendJoinServlet,
FederationV2SendJoinServlet,
FederationV1SendLeaveServlet,
FederationV2SendLeaveServlet,
FederationV1InviteServlet,
FederationV2InviteServlet,
FederationQueryAuthServlet,

View File

@@ -13,8 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
class AccountDataEventSource(object):
def __init__(self, hs):
@@ -23,15 +21,14 @@ class AccountDataEventSource(object):
def get_current_key(self, direction="f"):
return self.store.get_max_account_data_stream_id()
@defer.inlineCallbacks
def get_new_events(self, user, from_key, **kwargs):
async def get_new_events(self, user, from_key, **kwargs):
user_id = user.to_string()
last_stream_id = from_key
current_stream_id = yield self.store.get_max_account_data_stream_id()
current_stream_id = self.store.get_max_account_data_stream_id()
results = []
tags = yield self.store.get_updated_tags(user_id, last_stream_id)
tags = await self.store.get_updated_tags(user_id, last_stream_id)
for room_id, room_tags in tags.items():
results.append(
@@ -41,7 +38,7 @@ class AccountDataEventSource(object):
(
account_data,
room_account_data,
) = yield self.store.get_updated_account_data_for_user(user_id, last_stream_id)
) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
for account_data_type, content in account_data.items():
results.append({"type": account_data_type, "content": content})
@@ -54,6 +51,5 @@ class AccountDataEventSource(object):
return results, current_stream_id
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
async def get_pagination_rows(self, user, config, key):
return [], config.to_id

View File

@@ -18,8 +18,7 @@ import email.utils
import logging
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from twisted.internet import defer
from typing import List
from synapse.api.errors import StoreError
from synapse.logging.context import make_deferred_yieldable
@@ -78,42 +77,39 @@ class AccountValidityHandler(object):
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"send_renewals", self.send_renewal_emails
"send_renewals", self._send_renewal_emails
)
self.clock.looping_call(send_emails, 30 * 60 * 1000)
@defer.inlineCallbacks
def send_renewal_emails(self):
async def _send_renewal_emails(self):
"""Gets the list of users whose account is expiring in the amount of time
configured in the ``renew_at`` parameter from the ``account_validity``
configuration, and sends renewal emails to all of these users as long as they
have an email 3PID attached to their account.
"""
expiring_users = yield self.store.get_users_expiring_soon()
expiring_users = await self.store.get_users_expiring_soon()
if expiring_users:
for user in expiring_users:
yield self._send_renewal_email(
await self._send_renewal_email(
user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
)
@defer.inlineCallbacks
def send_renewal_email_to_user(self, user_id):
expiration_ts = yield self.store.get_expiration_ts_for_user(user_id)
yield self._send_renewal_email(user_id, expiration_ts)
async def send_renewal_email_to_user(self, user_id: str):
expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
await self._send_renewal_email(user_id, expiration_ts)
@defer.inlineCallbacks
def _send_renewal_email(self, user_id, expiration_ts):
async def _send_renewal_email(self, user_id: str, expiration_ts: int):
"""Sends out a renewal email to every email address attached to the given user
with a unique link allowing them to renew their account.
Args:
user_id (str): ID of the user to send email(s) to.
expiration_ts (int): Timestamp in milliseconds for the expiration date of
user_id: ID of the user to send email(s) to.
expiration_ts: Timestamp in milliseconds for the expiration date of
this user's account (used in the email templates).
"""
addresses = yield self._get_email_addresses_for_user(user_id)
addresses = await self._get_email_addresses_for_user(user_id)
# Stop right here if the user doesn't have at least one email address.
# In this case, they will have to ask their server admin to renew their
@@ -125,7 +121,7 @@ class AccountValidityHandler(object):
return
try:
user_display_name = yield self.store.get_profile_displayname(
user_display_name = await self.store.get_profile_displayname(
UserID.from_string(user_id).localpart
)
if user_display_name is None:
@@ -133,7 +129,7 @@ class AccountValidityHandler(object):
except StoreError:
user_display_name = user_id
renewal_token = yield self._get_renewal_token(user_id)
renewal_token = await self._get_renewal_token(user_id)
url = "%s_matrix/client/unstable/account_validity/renew?token=%s" % (
self.hs.config.public_baseurl,
renewal_token,
@@ -165,7 +161,7 @@ class AccountValidityHandler(object):
logger.info("Sending renewal email to %s", address)
yield make_deferred_yieldable(
await make_deferred_yieldable(
self.sendmail(
self.hs.config.email_smtp_host,
self._raw_from,
@@ -180,19 +176,18 @@ class AccountValidityHandler(object):
)
)
yield self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
await self.store.set_renewal_mail_status(user_id=user_id, email_sent=True)
@defer.inlineCallbacks
def _get_email_addresses_for_user(self, user_id):
async def _get_email_addresses_for_user(self, user_id: str) -> List[str]:
"""Retrieve the list of email addresses attached to a user's account.
Args:
user_id (str): ID of the user to lookup email addresses for.
user_id: ID of the user to lookup email addresses for.
Returns:
defer.Deferred[list[str]]: Email addresses for this account.
Email addresses for this account.
"""
threepids = yield self.store.user_get_threepids(user_id)
threepids = await self.store.user_get_threepids(user_id)
addresses = []
for threepid in threepids:
@@ -201,16 +196,15 @@ class AccountValidityHandler(object):
return addresses
@defer.inlineCallbacks
def _get_renewal_token(self, user_id):
async def _get_renewal_token(self, user_id: str) -> str:
"""Generates a 32-byte long random string that will be inserted into the
user's renewal email's unique link, then saves it into the database.
Args:
user_id (str): ID of the user to generate a string for.
user_id: ID of the user to generate a string for.
Returns:
defer.Deferred[str]: The generated string.
The generated string.
Raises:
StoreError(500): Couldn't generate a unique string after 5 attempts.
@@ -219,52 +213,52 @@ class AccountValidityHandler(object):
while attempts < 5:
try:
renewal_token = stringutils.random_string(32)
yield self.store.set_renewal_token_for_user(user_id, renewal_token)
await self.store.set_renewal_token_for_user(user_id, renewal_token)
return renewal_token
except StoreError:
attempts += 1
raise StoreError(500, "Couldn't generate a unique string as refresh string.")
@defer.inlineCallbacks
def renew_account(self, renewal_token):
async def renew_account(self, renewal_token: str) -> bool:
"""Renews the account attached to a given renewal token by pushing back the
expiration date by the current validity period in the server's configuration.
Args:
renewal_token (str): Token sent with the renewal request.
renewal_token: Token sent with the renewal request.
Returns:
bool: Whether the provided token is valid.
Whether the provided token is valid.
"""
try:
user_id = yield self.store.get_user_from_renewal_token(renewal_token)
user_id = await self.store.get_user_from_renewal_token(renewal_token)
except StoreError:
defer.returnValue(False)
return False
logger.debug("Renewing an account for user %s", user_id)
yield self.renew_account_for_user(user_id)
await self.renew_account_for_user(user_id)
defer.returnValue(True)
return True
@defer.inlineCallbacks
def renew_account_for_user(self, user_id, expiration_ts=None, email_sent=False):
async def renew_account_for_user(
self, user_id: str, expiration_ts: int = None, email_sent: bool = False
) -> int:
"""Renews the account attached to a given user by pushing back the
expiration date by the current validity period in the server's
configuration.
Args:
renewal_token (str): Token sent with the renewal request.
expiration_ts (int): New expiration date. Defaults to now + validity period.
email_sent (bool): Whether an email has been sent for this validity period.
renewal_token: Token sent with the renewal request.
expiration_ts: New expiration date. Defaults to now + validity period.
email_sen: Whether an email has been sent for this validity period.
Defaults to False.
Returns:
defer.Deferred[int]: New expiration date for this account, as a timestamp
in milliseconds since epoch.
New expiration date for this account, as a timestamp in
milliseconds since epoch.
"""
if expiration_ts is None:
expiration_ts = self.clock.time_msec() + self._account_validity.period
yield self.store.set_account_validity_for_user(
await self.store.set_account_validity_for_user(
user_id=user_id, expiration_ts=expiration_ts, email_sent=email_sent
)

View File

@@ -264,6 +264,7 @@ class E2eKeysHandler(object):
return ret
@defer.inlineCallbacks
def get_cross_signing_keys_from_cache(self, query, from_user_id):
"""Get cross-signing keys for users from the database
@@ -283,14 +284,32 @@ class E2eKeysHandler(object):
self_signing_keys = {}
user_signing_keys = {}
# Currently a stub, implementation coming in https://github.com/matrix-org/synapse/pull/6486
return defer.succeed(
{
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}
)
user_ids = list(query)
keys = yield self.store.get_e2e_cross_signing_keys_bulk(user_ids, from_user_id)
for user_id, user_info in keys.items():
if user_info is None:
continue
if "master" in user_info:
master_keys[user_id] = user_info["master"]
if "self_signing" in user_info:
self_signing_keys[user_id] = user_info["self_signing"]
if (
from_user_id in keys
and keys[from_user_id] is not None
and "user_signing" in keys[from_user_id]
):
# users can see other users' master and self-signing keys, but can
# only see their own user-signing keys
user_signing_keys[from_user_id] = keys[from_user_id]["user_signing"]
return {
"master_keys": master_keys,
"self_signing_keys": self_signing_keys,
"user_signing_keys": user_signing_keys,
}
@trace
@defer.inlineCallbacks

View File

@@ -19,7 +19,7 @@
import itertools
import logging
from typing import Dict, Iterable, Optional, Sequence, Tuple
from typing import Dict, Iterable, List, Optional, Sequence, Tuple
import six
from six import iteritems, itervalues
@@ -63,8 +63,9 @@ from synapse.replication.http.federation import (
)
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.types import UserID, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util import batch_iter, unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -164,8 +165,7 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
""" Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
@@ -175,17 +175,15 @@ class FederationHandler(BaseHandler):
pdu (FrozenEvent): received PDU
sent_to_us_directly (bool): True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
Returns (Deferred): completes with None
"""
room_id = pdu.room_id
event_id = pdu.event_id
logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
logger.info("handling received PDU: %s", pdu)
# We reprocess pdus when we have seen them only as outliers
existing = yield self.store.get_event(
existing = await self.store.get_event(
event_id, allow_none=True, allow_rejected=True
)
@@ -229,7 +227,7 @@ class FederationHandler(BaseHandler):
#
# Note that if we were never in the room then we would have already
# dropped the event, since we wouldn't know the room version.
is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
if not is_in_room:
logger.info(
"[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -245,12 +243,12 @@ class FederationHandler(BaseHandler):
# Get missing pdus if necessary.
if not pdu.internal_metadata.is_outlier():
# We only backfill backwards to the min depth.
min_depth = yield self.get_min_depth_for_context(pdu.room_id)
min_depth = await self.get_min_depth_for_context(pdu.room_id)
logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
prevs = set(pdu.prev_event_ids())
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if min_depth and pdu.depth < min_depth:
# This is so that we don't notify the user about this
@@ -270,7 +268,7 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
shortstr(missing_prevs),
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
with (await self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"[%s %s] Acquired room lock to fetch %d missing prev_events",
room_id,
@@ -278,13 +276,19 @@ class FederationHandler(BaseHandler):
len(missing_prevs),
)
yield self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
try:
await self._get_missing_events_for_pdu(
origin, pdu, prevs, min_depth
)
except Exception as e:
raise Exception(
"Error fetching missing prev_events for %s: %s"
% (event_id, e)
)
# Update the set of things we've seen after trying to
# fetch the missing stuff
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
logger.info(
@@ -292,14 +296,6 @@ class FederationHandler(BaseHandler):
room_id,
event_id,
)
elif missing_prevs:
logger.info(
"[%s %s] Not recursively fetching %d missing prev_events: %s",
room_id,
event_id,
len(missing_prevs),
shortstr(missing_prevs),
)
if prevs - seen:
# We've still not been able to get all of the prev_events for this event.
@@ -344,13 +340,19 @@ class FederationHandler(BaseHandler):
affected=pdu.event_id,
)
logger.info(
"Event %s is missing prev_events: calculating state for a "
"backwards extremity",
event_id,
)
# Calculate the state after each of the previous events, and
# resolve them to find the correct state at the current event.
auth_chains = set()
event_map = {event_id: pdu}
try:
# Get the state of the events we know about
ours = yield self.state_store.get_state_groups_ids(room_id, seen)
ours = await self.state_store.get_state_groups_ids(room_id, seen)
# state_maps is a list of mappings from (type, state_key) to event_id
state_maps = list(
@@ -364,13 +366,10 @@ class FederationHandler(BaseHandler):
# know about
for p in prevs - seen:
logger.info(
"[%s %s] Requesting state at missing prev_event %s",
room_id,
event_id,
p,
"Requesting state at missing prev_event %s", event_id,
)
room_version = yield self.store.get_room_version(room_id)
room_version = await self.store.get_room_version(room_id)
with nested_logging_context(p):
# note that if any of the missing prevs share missing state or
@@ -379,24 +378,10 @@ class FederationHandler(BaseHandler):
(
remote_state,
got_auth_chain,
) = yield self.federation_client.get_state_for_room(
origin, room_id, p
) = await self._get_state_for_room(
origin, room_id, p, include_event_in_state=True
)
# we want the state *after* p; get_state_for_room returns the
# state *before* p.
remote_event = yield self.federation_client.get_pdu(
[origin], p, room_version, outlier=True
)
if remote_event is None:
raise Exception(
"Unable to get missing prev_event %s" % (p,)
)
if remote_event.is_state():
remote_state.append(remote_event)
# XXX hrm I'm not convinced that duplicate events will compare
# for equality, so I'm not sure this does what the author
# hoped.
@@ -410,7 +395,7 @@ class FederationHandler(BaseHandler):
for x in remote_state:
event_map[x.event_id] = x
state_map = yield resolve_events_with_store(
state_map = await resolve_events_with_store(
room_version,
state_maps,
event_map,
@@ -422,10 +407,10 @@ class FederationHandler(BaseHandler):
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = yield self.store.get_events(
evs = await self.store.get_events(
list(state_map.values()),
get_prev_content=False,
check_redacted=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
)
event_map.update(evs)
@@ -446,12 +431,11 @@ class FederationHandler(BaseHandler):
affected=event_id,
)
yield self._process_received_pdu(
await self._process_received_pdu(
origin, pdu, state=state, auth_chain=auth_chain
)
@defer.inlineCallbacks
def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
"""
Args:
origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +447,12 @@ class FederationHandler(BaseHandler):
room_id = pdu.room_id
event_id = pdu.event_id
seen = yield self.store.have_seen_events(prevs)
seen = await self.store.have_seen_events(prevs)
if not prevs - seen:
return
latest = yield self.store.get_latest_event_ids_in_room(room_id)
latest = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
@@ -532,7 +516,7 @@ class FederationHandler(BaseHandler):
# All that said: Let's try increasing the timout to 60s and see what happens.
try:
missing_events = yield self.federation_client.get_missing_events(
missing_events = await self.federation_client.get_missing_events(
origin,
room_id,
earliest_events_ids=list(latest),
@@ -571,7 +555,7 @@ class FederationHandler(BaseHandler):
)
with nested_logging_context(ev.event_id):
try:
yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
except FederationError as e:
if e.code == 403:
logger.warning(
@@ -583,8 +567,144 @@ class FederationHandler(BaseHandler):
else:
raise
@defer.inlineCallbacks
def _process_received_pdu(self, origin, event, state, auth_chain):
async def _get_state_for_room(
self,
destination: str,
room_id: str,
event_id: str,
include_event_in_state: bool = False,
) -> Tuple[List[EventBase], List[EventBase]]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
destination: The remote homeserver to query for the state.
room_id: The id of the room we're interested in.
event_id: The id of the event we want the state at.
include_event_in_state: if true, the event itself will be included in the
returned state event list.
Returns:
A list of events in the state, possibly including the event itself, and
a list of events in the auth chain for the given event.
"""
(
state_event_ids,
auth_event_ids,
) = await self.federation_client.get_room_state_ids(
destination, room_id, event_id=event_id
)
desired_events = set(state_event_ids + auth_event_ids)
if include_event_in_state:
desired_events.add(event_id)
event_map = await self._get_events_from_store_or_dest(
destination, room_id, desired_events
)
failed_to_fetch = desired_events - event_map.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state/auth events for %s %s",
event_id,
failed_to_fetch,
)
remote_state = [
event_map[e_id] for e_id in state_event_ids if e_id in event_map
]
if include_event_in_state:
remote_event = event_map.get(event_id)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
if remote_event.is_state():
remote_state.append(remote_event)
auth_chain = [event_map[e_id] for e_id in auth_event_ids if e_id in event_map]
auth_chain.sort(key=lambda e: e.depth)
return remote_state, auth_chain
async def _get_events_from_store_or_dest(
self, destination: str, room_id: str, event_ids: Iterable[str]
) -> Dict[str, EventBase]:
"""Fetch events from a remote destination, checking if we already have them.
Args:
destination
room_id
event_ids
If we fail to fetch any of the events, a warning will be logged, and the event
will be omitted from the result. Likewise, any events which turn out not to
be in the given room.
Returns:
map from event_id to event
"""
fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
missing_events = set(event_ids) - fetched_events.keys()
if missing_events:
logger.debug(
"Fetching unknown state/auth events %s for room %s",
missing_events,
room_id,
)
room_version = await self.store.get_room_version(room_id)
# XXX 20 requests at once? really?
for batch in batch_iter(missing_events, 20):
deferreds = [
run_in_background(
self.federation_client.get_pdu,
destinations=[destination],
event_id=e_id,
room_version=room_version,
)
for e_id in batch
]
res = await make_deferred_yieldable(
defer.DeferredList(deferreds, consumeErrors=True)
)
for success, result in res:
if success and result:
fetched_events[result.event_id] = result
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = list(
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
)
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned auth/state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
del fetched_events[bad_event_id]
return fetched_events
async def _process_received_pdu(self, origin, event, state, auth_chain):
""" Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
"""
@@ -599,7 +719,7 @@ class FederationHandler(BaseHandler):
if auth_chain:
event_ids |= {e.event_id for e in auth_chain}
seen_ids = yield self.store.have_seen_events(event_ids)
seen_ids = await self.store.have_seen_events(event_ids)
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
@@ -626,18 +746,18 @@ class FederationHandler(BaseHandler):
event_id,
[e.event.event_id for e in event_infos],
)
yield self._handle_new_events(origin, event_infos)
await self._handle_new_events(origin, event_infos)
try:
context = yield self._handle_new_event(origin, event, state=state)
context = await self._handle_new_event(origin, event, state=state)
except AuthError as e:
raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
room = yield self.store.get_room(room_id)
room = await self.store.get_room(room_id)
if not room:
try:
yield self.store.store_room(
await self.store.store_room(
room_id=room_id, room_creator_user_id="", is_public=False
)
except StoreError:
@@ -650,11 +770,11 @@ class FederationHandler(BaseHandler):
# changing their profile info.
newly_joined = True
prev_state_ids = yield context.get_prev_state_ids(self.store)
prev_state_ids = await context.get_prev_state_ids(self.store)
prev_state_id = prev_state_ids.get((event.type, event.state_key))
if prev_state_id:
prev_state = yield self.store.get_event(
prev_state = await self.store.get_event(
prev_state_id, allow_none=True
)
if prev_state and prev_state.membership == Membership.JOIN:
@@ -662,11 +782,10 @@ class FederationHandler(BaseHandler):
if newly_joined:
user = UserID.from_string(event.state_key)
yield self.user_joined_room(user, room_id)
await self.user_joined_room(user, room_id)
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities):
async def backfill(self, dest, room_id, limit, extremities):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -683,9 +802,9 @@ class FederationHandler(BaseHandler):
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
room_version = yield self.store.get_room_version(room_id)
room_version = await self.store.get_room_version(room_id)
events = yield self.federation_client.backfill(
events = await self.federation_client.backfill(
dest, room_id, limit=limit, extremities=extremities
)
@@ -700,7 +819,7 @@ class FederationHandler(BaseHandler):
# self._sanity_check_event(ev)
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
seen_events = await self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
@@ -723,7 +842,7 @@ class FederationHandler(BaseHandler):
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.federation_client.get_state_for_room(
state, auth = await self._get_state_for_room(
destination=dest, room_id=room_id, event_id=e_id
)
auth_events.update({a.event_id: a for a in auth})
@@ -748,7 +867,7 @@ class FederationHandler(BaseHandler):
# We repeatedly do this until we stop finding new auth events.
while missing_auth - failed_to_fetch:
logger.info("Missing auth for backfill: %r", missing_auth)
ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
auth_events.update(ret_events)
required_auth.update(
@@ -762,7 +881,7 @@ class FederationHandler(BaseHandler):
missing_auth - failed_to_fetch,
)
results = yield make_deferred_yieldable(
results = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -789,7 +908,7 @@ class FederationHandler(BaseHandler):
failed_to_fetch = missing_auth - set(auth_events)
seen_events = yield self.store.have_seen_events(
seen_events = await self.store.have_seen_events(
set(auth_events.keys()) | set(state_events.keys())
)
@@ -851,7 +970,7 @@ class FederationHandler(BaseHandler):
)
)
yield self._handle_new_events(dest, ev_infos, backfilled=True)
await self._handle_new_events(dest, ev_infos, backfilled=True)
# Step 2: Persist the rest of the events in the chunk one by one
events.sort(key=lambda e: e.depth)
@@ -867,16 +986,15 @@ class FederationHandler(BaseHandler):
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
yield self._handle_new_event(dest, event, backfilled=True)
await self._handle_new_event(dest, event, backfilled=True)
return events
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
async def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
@@ -908,15 +1026,17 @@ class FederationHandler(BaseHandler):
# state *before* the event, ignoring the special casing certain event
# types have.
forward_events = yield self.store.get_successor_events(list(extremities))
forward_events = await self.store.get_successor_events(list(extremities))
extremities_events = yield self.store.get_events(
forward_events, check_redacted=False, get_prev_content=False
extremities_events = await self.store.get_events(
forward_events,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
)
# We set `check_history_visibility_only` as we might otherwise get false
# positives from users having been erased.
filtered_extremities = yield filter_events_for_server(
filtered_extremities = await filter_events_for_server(
self.storage,
self.server_name,
list(extremities_events.values()),
@@ -946,7 +1066,7 @@ class FederationHandler(BaseHandler):
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
curr_state = await self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
"""Get joined domains from state
@@ -985,12 +1105,11 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
@defer.inlineCallbacks
def try_backfill(domains):
async def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
yield self.backfill(
await self.backfill(
dom, room_id, limit=100, extremities=extremities
)
# If this succeeded then we probably already have the
@@ -1021,7 +1140,7 @@ class FederationHandler(BaseHandler):
return False
success = yield try_backfill(likely_domains)
success = await try_backfill(likely_domains)
if success:
return True
@@ -1035,7 +1154,7 @@ class FederationHandler(BaseHandler):
logger.debug("calling resolve_state_groups in _maybe_backfill")
resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
states = yield make_deferred_yieldable(
states = await make_deferred_yieldable(
defer.gatherResults(
[resolve(room_id, [e]) for e in event_ids], consumeErrors=True
)
@@ -1045,7 +1164,7 @@ class FederationHandler(BaseHandler):
# event_ids.
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
state_map = await self.store.get_events(
[e_id for ids in itervalues(states) for e_id in itervalues(ids)],
get_prev_content=False,
)
@@ -1061,7 +1180,7 @@ class FederationHandler(BaseHandler):
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill(
success = await try_backfill(
[dom for dom, _ in likely_domains if dom not in tried_domains]
)
if success:
@@ -1210,7 +1329,7 @@ class FederationHandler(BaseHandler):
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = yield self.store.get_room_predecessor(room_id)
if not predecessor:
if not predecessor or not isinstance(predecessor.get("room_id"), str):
return
old_room_id = predecessor["room_id"]
logger.debug(
@@ -1238,8 +1357,7 @@ class FederationHandler(BaseHandler):
return True
@defer.inlineCallbacks
def _handle_queued_pdus(self, room_queue):
async def _handle_queued_pdus(self, room_queue):
"""Process PDUs which got queued up while we were busy send_joining.
Args:
@@ -1255,7 +1373,7 @@ class FederationHandler(BaseHandler):
p.room_id,
)
with nested_logging_context(p.event_id):
yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
except Exception as e:
logger.warning(
"Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -1453,7 +1571,7 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
origin, event, event_format_version = yield self._make_and_verify_event(
target_hosts, room_id, user_id, "leave", content=content,
target_hosts, room_id, user_id, "leave", content=content
)
# Mark as outlier as we don't have any state for this event; we're not
# even in the room.
@@ -2814,7 +2932,7 @@ class FederationHandler(BaseHandler):
room_id=room_id, user_id=user.to_string(), change="joined"
)
else:
return user_joined_room(self.distributor, user, room_id)
return defer.succeed(user_joined_room(self.distributor, user, room_id))
@defer.inlineCallbacks
def get_room_complexity(self, remote_room_hosts, room_id):

View File

@@ -26,7 +26,7 @@ from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken, UserID
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.util.caches.response_cache import ResponseCache
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
@@ -41,7 +41,7 @@ class InitialSyncHandler(BaseHandler):
self.state = hs.get_state_handler()
self.clock = hs.get_clock()
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
@@ -79,21 +79,17 @@ class InitialSyncHandler(BaseHandler):
as_client_event,
include_archived,
)
now_ms = self.clock.time_msec()
result = self.snapshot_cache.get(now_ms, key)
if result is not None:
return result
return self.snapshot_cache.set(
now_ms,
return self.snapshot_cache.wrap(
key,
self._snapshot_all_rooms(
user_id, pagin_config, as_client_event, include_archived
),
self._snapshot_all_rooms,
user_id,
pagin_config,
as_client_event,
include_archived,
)
@defer.inlineCallbacks
def _snapshot_all_rooms(
async def _snapshot_all_rooms(
self,
user_id=None,
pagin_config=None,
@@ -105,7 +101,7 @@ class InitialSyncHandler(BaseHandler):
if include_archived:
memberships.append(Membership.LEAVE)
room_list = yield self.store.get_rooms_for_user_where_membership_is(
room_list = await self.store.get_rooms_for_user_where_membership_is(
user_id=user_id, membership_list=memberships
)
@@ -113,33 +109,32 @@ class InitialSyncHandler(BaseHandler):
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
now_token = await self.hs.get_event_sources().get_current_token()
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
presence, _ = await presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
receipt, _ = await receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
tags_by_room = yield self.store.get_tags_for_user(user_id)
tags_by_room = await self.store.get_tags_for_user(user_id)
account_data, account_data_by_room = yield self.store.get_account_data_for_user(
account_data, account_data_by_room = await self.store.get_account_data_for_user(
user_id
)
public_room_ids = yield self.store.get_public_room_ids()
public_room_ids = await self.store.get_public_room_ids()
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
async def handle_room(event):
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -152,8 +147,8 @@ class InitialSyncHandler(BaseHandler):
time_now = self.clock.time_msec()
d["inviter"] = event.sender
invite_event = yield self.store.get_event(event.event_id)
d["invite"] = yield self._event_serializer.serialize_event(
invite_event = await self.store.get_event(event.event_id)
d["invite"] = await self._event_serializer.serialize_event(
invite_event, time_now, as_client_event
)
@@ -177,7 +172,7 @@ class InitialSyncHandler(BaseHandler):
lambda states: states[event.event_id]
)
(messages, token), current_state = yield make_deferred_yieldable(
(messages, token), current_state = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -191,7 +186,7 @@ class InitialSyncHandler(BaseHandler):
)
).addErrback(unwrapFirstError)
messages = yield filter_events_for_client(
messages = await filter_events_for_client(
self.storage, user_id, messages
)
@@ -201,7 +196,7 @@ class InitialSyncHandler(BaseHandler):
d["messages"] = {
"chunk": (
yield self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
messages, time_now=time_now, as_client_event=as_client_event
)
),
@@ -209,7 +204,7 @@ class InitialSyncHandler(BaseHandler):
"end": end_token.to_string(),
}
d["state"] = yield self._event_serializer.serialize_events(
d["state"] = await self._event_serializer.serialize_events(
current_state.values(),
time_now=time_now,
as_client_event=as_client_event,
@@ -232,7 +227,7 @@ class InitialSyncHandler(BaseHandler):
except Exception:
logger.exception("Failed to get snapshot")
yield concurrently_execute(handle_room, room_list, 10)
await concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():
@@ -256,8 +251,7 @@ class InitialSyncHandler(BaseHandler):
return ret
@defer.inlineCallbacks
def room_initial_sync(self, requester, room_id, pagin_config=None):
async def room_initial_sync(self, requester, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
@@ -274,32 +268,32 @@ class InitialSyncHandler(BaseHandler):
A JSON serialisable dict with the snapshot of the room.
"""
blocked = yield self.store.is_room_blocked(room_id)
blocked = await self.store.is_room_blocked(room_id)
if blocked:
raise SynapseError(403, "This room has been blocked on this server")
user_id = requester.user.to_string()
membership, member_event_id = yield self._check_in_room_or_world_readable(
membership, member_event_id = await self._check_in_room_or_world_readable(
room_id, user_id
)
is_peeking = member_event_id is None
if membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
result = await self._room_initial_sync_joined(
user_id, room_id, pagin_config, membership, is_peeking
)
elif membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
result = await self._room_initial_sync_parted(
user_id, room_id, pagin_config, membership, member_event_id, is_peeking
)
account_data_events = []
tags = yield self.store.get_tags_for_room(user_id, room_id)
tags = await self.store.get_tags_for_room(user_id, room_id)
if tags:
account_data_events.append({"type": "m.tag", "content": {"tags": tags}})
account_data = yield self.store.get_account_data_for_room(user_id, room_id)
account_data = await self.store.get_account_data_for_room(user_id, room_id)
for account_data_type, content in account_data.items():
account_data_events.append({"type": account_data_type, "content": content})
@@ -307,11 +301,10 @@ class InitialSyncHandler(BaseHandler):
return result
@defer.inlineCallbacks
def _room_initial_sync_parted(
async def _room_initial_sync_parted(
self, user_id, room_id, pagin_config, membership, member_event_id, is_peeking
):
room_state = yield self.state_store.get_state_for_events([member_event_id])
room_state = await self.state_store.get_state_for_events([member_event_id])
room_state = room_state[member_event_id]
@@ -319,13 +312,13 @@ class InitialSyncHandler(BaseHandler):
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(member_event_id)
stream_token = await self.store.get_stream_token_for_event(member_event_id)
messages, token = yield self.store.get_recent_events_for_room(
messages, token = await self.store.get_recent_events_for_room(
room_id, limit=limit, end_token=stream_token
)
messages = yield filter_events_for_client(
messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -339,13 +332,13 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
yield self._event_serializer.serialize_events(messages, time_now)
await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": (
yield self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
room_state.values(), time_now
)
),
@@ -353,19 +346,18 @@ class InitialSyncHandler(BaseHandler):
"receipts": [],
}
@defer.inlineCallbacks
def _room_initial_sync_joined(
async def _room_initial_sync_joined(
self, user_id, room_id, pagin_config, membership, is_peeking
):
current_state = yield self.state.get_current_state(room_id=room_id)
current_state = await self.state.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()
state = yield self._event_serializer.serialize_events(
state = await self._event_serializer.serialize_events(
current_state.values(), time_now
)
now_token = yield self.hs.get_event_sources().get_current_token()
now_token = await self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
@@ -380,28 +372,26 @@ class InitialSyncHandler(BaseHandler):
presence_handler = self.hs.get_presence_handler()
@defer.inlineCallbacks
def get_presence():
async def get_presence():
# If presence is disabled, return an empty list
if not self.hs.config.use_presence:
return []
states = yield presence_handler.get_states(
states = await presence_handler.get_states(
[m.user_id for m in room_members], as_event=True
)
return states
@defer.inlineCallbacks
def get_receipts():
receipts = yield self.store.get_linearized_receipts_for_room(
async def get_receipts():
receipts = await self.store.get_linearized_receipts_for_room(
room_id, to_key=now_token.receipt_key
)
if not receipts:
receipts = []
return receipts
presence, receipts, (messages, token) = yield make_deferred_yieldable(
presence, receipts, (messages, token) = await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(get_presence),
@@ -417,7 +407,7 @@ class InitialSyncHandler(BaseHandler):
).addErrback(unwrapFirstError)
)
messages = yield filter_events_for_client(
messages = await filter_events_for_client(
self.storage, user_id, messages, is_peeking=is_peeking
)
@@ -430,7 +420,7 @@ class InitialSyncHandler(BaseHandler):
"room_id": room_id,
"messages": {
"chunk": (
yield self._event_serializer.serialize_events(messages, time_now)
await self._event_serializer.serialize_events(messages, time_now)
),
"start": start_token.to_string(),
"end": end_token.to_string(),
@@ -444,18 +434,17 @@ class InitialSyncHandler(BaseHandler):
return ret
@defer.inlineCallbacks
def _check_in_room_or_world_readable(self, room_id, user_id):
async def _check_in_room_or_world_readable(self, room_id, user_id):
try:
# check_user_was_in_room will return the most recent membership
# event for the user if:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
member_event = await self.auth.check_user_was_in_room(room_id, user_id)
return member_event.membership, member_event.event_id
except AuthError:
visibility = yield self.state_handler.get_current_state(
visibility = await self.state_handler.get_current_state(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (

View File

@@ -46,6 +46,7 @@ from synapse.events.validator import EventValidator
from synapse.logging.context import run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, UserID, create_requester
from synapse.util.async_helpers import Linearizer
@@ -875,7 +876,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,
@@ -952,7 +953,7 @@ class EventCreationHandler(object):
if event.type == EventTypes.Redaction:
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=False,
allow_none=True,

View File

@@ -280,8 +280,7 @@ class PaginationHandler(object):
await self.storage.purge_events.purge_room(room_id)
@defer.inlineCallbacks
def get_messages(
async def get_messages(
self,
requester,
room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
room_token = pagin_config.from_token.room_key
else:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token_for_pagination()
await self.hs.get_event_sources().get_current_token_for_pagination()
)
room_token = pagin_config.from_token.room_key
@@ -319,11 +318,11 @@ class PaginationHandler(object):
source_config = pagin_config.get_source_config("room")
with (yield self.pagination_lock.read(room_id)):
with (await self.pagination_lock.read(room_id)):
(
membership,
member_event_id,
) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
if source_config.direction == "b":
# if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
if room_token.topological:
max_topo = room_token.topological
else:
max_topo = yield self.store.get_max_topological_token(
max_topo = await self.store.get_max_topological_token(
room_id, room_token.stream
)
@@ -339,18 +338,18 @@ class PaginationHandler(object):
# If they have left the room then clamp the token to be before
# they left the room, to save the effort of loading from the
# database.
leave_token = yield self.store.get_topological_token_for_event(
leave_token = await self.store.get_topological_token_for_event(
member_event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < max_topo:
source_config.from_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
await self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, max_topo
)
events, next_key = yield self.store.paginate_room_events(
events, next_key = await self.store.paginate_room_events(
room_id=room_id,
from_key=source_config.from_key,
to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
if event_filter:
events = event_filter.filter(events)
events = yield filter_events_for_client(
events = await filter_events_for_client(
self.storage, user_id, events, is_peeking=(member_event_id is None)
)
@@ -385,19 +384,19 @@ class PaginationHandler(object):
(EventTypes.Member, event.sender) for event in events
)
state_ids = yield self.state_store.get_state_ids_for_event(
state_ids = await self.state_store.get_state_ids_for_event(
events[0].event_id, state_filter=state_filter
)
if state_ids:
state = yield self.store.get_events(list(state_ids.values()))
state = await self.store.get_events(list(state_ids.values()))
state = state.values()
time_now = self.clock.time_msec()
chunk = {
"chunk": (
yield self._event_serializer.serialize_events(
await self._event_serializer.serialize_events(
events, time_now, as_client_event=as_client_event
)
),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
}
if state:
chunk["state"] = yield self._event_serializer.serialize_events(
chunk["state"] = await self._event_serializer.serialize_events(
state, time_now, as_client_event=as_client_event
)

View File

@@ -13,20 +13,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import re
from typing import Tuple
import attr
import saml2
import saml2.response
from saml2.client import Saml2Client
from synapse.api.errors import SynapseError
from synapse.config import ConfigError
from synapse.http.servlet import parse_string
from synapse.rest.client.v1.login import SSOAuthHandler
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.types import (
UserID,
map_username_to_mxid_localpart,
mxid_localpart_allowed_characters,
)
from synapse.util.async_helpers import Linearizer
logger = logging.getLogger(__name__)
@attr.s
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
# time the session was created, in milliseconds
creation_time = attr.ib()
class SamlHandler:
def __init__(self, hs):
self._saml_client = Saml2Client(hs.config.saml2_sp_config)
@@ -37,11 +53,14 @@ class SamlHandler:
self._datastore = hs.get_datastore()
self._hostname = hs.hostname
self._saml2_session_lifetime = hs.config.saml2_session_lifetime
self._mxid_source_attribute = hs.config.saml2_mxid_source_attribute
self._grandfathered_mxid_source_attribute = (
hs.config.saml2_grandfathered_mxid_source_attribute
)
self._mxid_mapper = hs.config.saml2_mxid_mapper
# plugin to do custom mapping from saml response to mxid
self._user_mapping_provider = hs.config.saml2_user_mapping_provider_class(
hs.config.saml2_user_mapping_provider_config
)
# identifier for the external_ids table
self._auth_provider_id = "saml"
@@ -118,22 +137,10 @@ class SamlHandler:
remote_user_id = saml2_auth.ava["uid"][0]
except KeyError:
logger.warning("SAML2 response lacks a 'uid' attestation")
raise SynapseError(400, "uid not in SAML2 response")
try:
mxid_source = saml2_auth.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
raise SynapseError(400, "'uid' not in SAML2 response")
self._outstanding_requests_dict.pop(saml2_auth.in_response_to, None)
displayName = saml2_auth.ava.get("displayName", [None])[0]
with (await self._mapping_lock.queue(self._auth_provider_id)):
# first of all, check if we already have a mapping for this user
logger.info(
@@ -173,22 +180,46 @@ class SamlHandler:
)
return registered_user_id
# figure out a new mxid for this user
base_mxid_localpart = self._mxid_mapper(mxid_source)
# Map saml response to user attributes using the configured mapping provider
for i in range(1000):
attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
saml2_auth, i
)
suffix = 0
while True:
localpart = base_mxid_localpart + (str(suffix) if suffix else "")
logger.debug(
"Retrieved SAML attributes from user mapping provider: %s "
"(attempt %d)",
attribute_dict,
i,
)
localpart = attribute_dict.get("mxid_localpart")
if not localpart:
logger.error(
"SAML mapping provider plugin did not return a "
"mxid_localpart object"
)
raise SynapseError(500, "Error parsing SAML2 response")
displayname = attribute_dict.get("displayname")
# Check if this mxid already exists
if not await self._datastore.get_users_by_id_case_insensitive(
UserID(localpart, self._hostname).to_string()
):
# This mxid is free
break
suffix += 1
logger.info("Allocating mxid for new user with localpart %s", localpart)
else:
# Unable to generate a username in 1000 iterations
# Break and return error to the user
raise SynapseError(
500, "Unable to generate a Matrix ID from the SAML response"
)
registered_user_id = await self._registration_handler.register_user(
localpart=localpart, default_display_name=displayName
localpart=localpart, default_display_name=displayname
)
await self._datastore.record_user_external_id(
self._auth_provider_id, remote_user_id, registered_user_id
)
@@ -205,9 +236,120 @@ class SamlHandler:
del self._outstanding_requests_dict[reqid]
@attr.s
class Saml2SessionData:
"""Data we track about SAML2 sessions"""
DOT_REPLACE_PATTERN = re.compile(
("[^%s]" % (re.escape("".join(mxid_localpart_allowed_characters)),))
)
# time the session was created, in milliseconds
creation_time = attr.ib()
def dot_replace_for_mxid(username: str) -> str:
username = username.lower()
username = DOT_REPLACE_PATTERN.sub(".", username)
# regular mxids aren't allowed to start with an underscore either
username = re.sub("^_", "", username)
return username
MXID_MAPPER_MAP = {
"hexencode": map_username_to_mxid_localpart,
"dotreplace": dot_replace_for_mxid,
}
@attr.s
class SamlConfig(object):
mxid_source_attribute = attr.ib()
mxid_mapper = attr.ib()
class DefaultSamlMappingProvider(object):
__version__ = "0.0.1"
def __init__(self, parsed_config: SamlConfig):
"""The default SAML user mapping provider
Args:
parsed_config: Module configuration
"""
self._mxid_source_attribute = parsed_config.mxid_source_attribute
self._mxid_mapper = parsed_config.mxid_mapper
def saml_response_to_user_attributes(
self, saml_response: saml2.response.AuthnResponse, failures: int = 0,
) -> dict:
"""Maps some text from a SAML response to attributes of a new user
Args:
saml_response: A SAML auth response object
failures: How many times a call to this function with this
saml_response has resulted in a failure
Returns:
dict: A dict containing new user attributes. Possible keys:
* mxid_localpart (str): Required. The localpart of the user's mxid
* displayname (str): The displayname of the user
"""
try:
mxid_source = saml_response.ava[self._mxid_source_attribute][0]
except KeyError:
logger.warning(
"SAML2 response lacks a '%s' attestation", self._mxid_source_attribute,
)
raise SynapseError(
400, "%s not in SAML2 response" % (self._mxid_source_attribute,)
)
# Use the configured mapper for this mxid_source
base_mxid_localpart = self._mxid_mapper(mxid_source)
# Append suffix integer if last call to this function failed to produce
# a usable mxid
localpart = base_mxid_localpart + (str(failures) if failures else "")
# Retrieve the display name from the saml response
# If displayname is None, the mxid_localpart will be used instead
displayname = saml_response.ava.get("displayName", [None])[0]
return {
"mxid_localpart": localpart,
"displayname": displayname,
}
@staticmethod
def parse_config(config: dict) -> SamlConfig:
"""Parse the dict provided by the homeserver's config
Args:
config: A dictionary containing configuration options for this provider
Returns:
SamlConfig: A custom config object for this module
"""
# Parse config options and use defaults where necessary
mxid_source_attribute = config.get("mxid_source_attribute", "uid")
mapping_type = config.get("mxid_mapping", "hexencode")
# Retrieve the associating mapping function
try:
mxid_mapper = MXID_MAPPER_MAP[mapping_type]
except KeyError:
raise ConfigError(
"saml2_config.user_mapping_provider.config: '%s' is not a valid "
"mxid_mapping value" % (mapping_type,)
)
return SamlConfig(mxid_source_attribute, mxid_mapper)
@staticmethod
def get_saml_attributes(config: SamlConfig) -> Tuple[set, set]:
"""Returns the required attributes of a SAML
Args:
config: A SamlConfig object containing configuration params for this provider
Returns:
tuple[set,set]: The first set equates to the saml auth response
attributes that are required for the module to function, whereas the
second set consists of those attributes which can be used if
available, but are not necessary
"""
return {"uid", config.mxid_source_attribute}, {"displayName"}

View File

@@ -21,7 +21,7 @@ from unpaddedbase64 import decode_base64, encode_base64
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.storage.state import StateFilter
from synapse.visibility import filter_events_for_client
@@ -37,6 +37,7 @@ class SearchHandler(BaseHandler):
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
self.state_store = self.storage.state
self.auth = hs.get_auth()
@defer.inlineCallbacks
def get_old_rooms_from_upgraded_room(self, room_id):
@@ -53,23 +54,38 @@ class SearchHandler(BaseHandler):
room_id (str): id of the room to search through.
Returns:
Deferred[iterable[unicode]]: predecessor room ids
Deferred[iterable[str]]: predecessor room ids
"""
historical_room_ids = []
while True:
predecessor = yield self.store.get_room_predecessor(room_id)
# The initial room must have been known for us to get this far
predecessor = yield self.store.get_room_predecessor(room_id)
# If no predecessor, assume we've hit a dead end
while True:
if not predecessor:
# We have reached the end of the chain of predecessors
break
# Add predecessor's room ID
historical_room_ids.append(predecessor["room_id"])
if not isinstance(predecessor.get("room_id"), str):
# This predecessor object is malformed. Exit here
break
# Scan through the old room for further predecessors
room_id = predecessor["room_id"]
predecessor_room_id = predecessor["room_id"]
# Don't add it to the list until we have checked that we are in the room
try:
next_predecessor_room = yield self.store.get_room_predecessor(
predecessor_room_id
)
except NotFoundError:
# The predecessor is not a known room, so we are done here
break
historical_room_ids.append(predecessor_room_id)
# And repeat
predecessor = next_predecessor_room
return historical_room_ids

View File

@@ -76,9 +76,9 @@ def flatten_event(event: dict, metadata: dict, include_time: bool = False):
# context only given in the log format (e.g. what is being logged) is
# available.
if "log_text" in event:
new_event["log"] = event["log_text"]
new_event["message"] = event["log_text"]
else:
new_event["log"] = event["log_format"]
new_event["message"] = event["log_format"]
# We want to include the timestamp when forwarding over the network, but
# exclude it when we are writing to stdout. This is because the log ingester

View File

@@ -23,6 +23,7 @@ them.
See doc/log_contexts.rst for details on how this works.
"""
import inspect
import logging
import threading
import types
@@ -612,7 +613,8 @@ def run_in_background(f, *args, **kwargs):
def make_deferred_yieldable(deferred):
"""Given a deferred, make it follow the Synapse logcontext rules:
"""Given a deferred (or coroutine), make it follow the Synapse logcontext
rules:
If the deferred has completed (or is not actually a Deferred), essentially
does nothing (just returns another completed deferred with the
@@ -624,6 +626,13 @@ def make_deferred_yieldable(deferred):
(This is more-or-less the opposite operation to run_in_background.)
"""
if inspect.isawaitable(deferred):
# If we're given a coroutine we convert it to a deferred so that we
# run it and find out if it immediately finishes, it it does then we
# don't need to fiddle with log contexts at all and can return
# immediately.
deferred = defer.ensureDeferred(deferred)
if not isinstance(deferred, defer.Deferred):
return deferred

View File

@@ -34,6 +34,7 @@ from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
from synapse.appservice.api import ApplicationServiceApi
from synapse.appservice.scheduler import ApplicationServiceScheduler
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.crypto.keyring import Keyring
from synapse.events.builder import EventBuilderFactory
@@ -97,6 +98,7 @@ from synapse.server_notices.worker_server_notices_sender import (
)
from synapse.state import StateHandler, StateResolutionHandler
from synapse.storage import DataStores, Storage
from synapse.storage.engines import create_engine
from synapse.streams.events import EventSources
from synapse.util import Clock
from synapse.util.distributor import Distributor
@@ -209,16 +211,18 @@ class HomeServer(object):
# instantiated during setup() for future return by get_datastore()
DATASTORE_CLASS = abc.abstractproperty()
def __init__(self, hostname, reactor=None, **kwargs):
def __init__(self, hostname: str, config: HomeServerConfig, reactor=None, **kwargs):
"""
Args:
hostname : The hostname for the server.
config: The full config for the homeserver.
"""
if not reactor:
from twisted.internet import reactor
self._reactor = reactor
self.hostname = hostname
self.config = config
self._building = {}
self._listening_services = []
self.start_time = None
@@ -229,6 +233,12 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
self.database_engine = create_engine(config.database_config)
config.database_config.setdefault("args", {})[
"cp_openfun"
] = self.database_engine.on_new_connection
self.db_config = config.database_config
self.datastores = None
# Other kwargs are explicit dependencies

View File

@@ -32,6 +32,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import get_cache_factor_for
from synapse.util.caches.expiringcache import ExpiringCache
@@ -645,7 +646,7 @@ class StateResolutionStore(object):
return self.store.get_events(
event_ids,
check_redacted=False,
redact_behaviour=EventRedactBehaviour.AS_IS,
get_prev_content=False,
allow_rejected=allow_rejected,
)

View File

@@ -451,16 +451,18 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
# Technically an access token might not be associated with
# a device so we need to check.
if device_id:
self.db.simple_upsert_txn(
# this is always an update rather than an upsert: the row should
# already exist, and if it doesn't, that may be because it has been
# deleted, and we don't want to re-create it.
self.db.simple_update_txn(
txn,
table="devices",
keyvalues={"user_id": user_id, "device_id": device_id},
values={
updatevalues={
"user_agent": user_agent,
"last_seen": last_seen,
"ip": ip,
},
lock=False,
)
except Exception as e:
# Failed to upsert, log and continue

View File

@@ -14,15 +14,18 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List
from six import iteritems
from canonicaljson import encode_canonical_json, json
from twisted.enterprise.adbapi import Connection
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.util.caches.descriptors import cached
from synapse.util.caches.descriptors import cached, cachedList
class EndToEndKeyWorkerStore(SQLBaseStore):
@@ -271,7 +274,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being set: either 'master'
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
@@ -316,8 +319,10 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""Returns a user's cross-signing key.
Args:
user_id (str): the user whose self-signing key is being requested
key_type (str): the type of cross-signing key to get
user_id (str): the user whose key is being requested
key_type (str): the type of key that is being requested: either 'master'
for a master key, 'self_signing' for a self-signing key, or
'user_signing' for a user-signing key
from_user_id (str): if specified, signatures made by this user on
the self-signing key will be included in the result
@@ -332,6 +337,206 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
from_user_id,
)
@cached(num_args=1)
def _get_bare_e2e_cross_signing_keys(self, user_id):
"""Dummy function. Only used to make a cache for
_get_bare_e2e_cross_signing_keys_bulk.
"""
raise NotImplementedError()
@cachedList(
cached_method_name="_get_bare_e2e_cross_signing_keys",
list_name="user_ids",
num_args=1,
)
def _get_bare_e2e_cross_signing_keys_bulk(
self, user_ids: List[str]
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
"""
return self.db.runInteraction(
"get_bare_e2e_cross_signing_keys_bulk",
self._get_bare_e2e_cross_signing_keys_bulk_txn,
user_ids,
)
def _get_bare_e2e_cross_signing_keys_bulk_txn(
self, txn: Connection, user_ids: List[str],
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing keys for a set of users. The output of this
function should be passed to _get_e2e_cross_signing_signatures_txn if
the signatures for the calling user need to be fetched.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
user_ids (list[str]): the users whose keys are being requested
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. If a user's cross-signing keys were not found, their user
ID will not be in the dict.
"""
result = {}
batch_size = 100
chunks = [
user_ids[i : i + batch_size] for i in range(0, len(user_ids), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT k.user_id, k.keytype, k.keydata, k.stream_id
FROM e2e_cross_signing_keys k
INNER JOIN (SELECT user_id, keytype, MAX(stream_id) AS stream_id
FROM e2e_cross_signing_keys
GROUP BY user_id, keytype) s
USING (user_id, stream_id, keytype)
WHERE k.user_id IN (%s)
""" % (
",".join("?" for u in user_chunk),
)
query_params = []
query_params.extend(user_chunk)
txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)
for row in rows:
user_id = row["user_id"]
key_type = row["keytype"]
key = json.loads(row["keydata"])
user_info = result.setdefault(user_id, {})
user_info[key_type] = key
return result
def _get_e2e_cross_signing_signatures_txn(
self, txn: Connection, keys: Dict[str, Dict[str, dict]], from_user_id: str,
) -> Dict[str, Dict[str, dict]]:
"""Returns the cross-signing signatures made by a user on a set of keys.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
keys (dict[str, dict[str, dict]]): a map of user ID to key type to
key data. This dict will be modified to add signatures.
from_user_id (str): fetch the signatures made by this user
Returns:
dict[str, dict[str, dict]]: mapping from user ID to key type to key
data. The return value will be the same as the keys argument,
with the modifications included.
"""
# find out what cross-signing keys (a.k.a. devices) we need to get
# signatures for. This is a map of (user_id, device_id) to key type
# (device_id is the key's public part).
devices = {}
for user_id, user_info in keys.items():
if user_info is None:
continue
for key_type, key in user_info.items():
device_id = None
for k in key["keys"].values():
device_id = k
devices[(user_id, device_id)] = key_type
device_list = list(devices)
# split into batches
batch_size = 100
chunks = [
device_list[i : i + batch_size]
for i in range(0, len(device_list), batch_size)
]
for user_chunk in chunks:
sql = """
SELECT target_user_id, target_device_id, key_id, signature
FROM e2e_cross_signing_signatures
WHERE user_id = ?
AND (%s)
""" % (
" OR ".join(
"(target_user_id = ? AND target_device_id = ?)" for d in devices
)
)
query_params = [from_user_id]
for item in devices:
# item is a (user_id, device_id) tuple
query_params.extend(item)
txn.execute(sql, query_params)
rows = self.db.cursor_to_dict(txn)
# and add the signatures to the appropriate keys
for row in rows:
key_id = row["key_id"]
target_user_id = row["target_user_id"]
target_device_id = row["target_device_id"]
key_type = devices[(target_user_id, target_device_id)]
# We need to copy everything, because the result may have come
# from the cache. dict.copy only does a shallow copy, so we
# need to recursively copy the dicts that will be modified.
user_info = keys[target_user_id] = keys[target_user_id].copy()
target_user_key = user_info[key_type] = user_info[key_type].copy()
if "signatures" in target_user_key:
signatures = target_user_key["signatures"] = target_user_key[
"signatures"
].copy()
if from_user_id in signatures:
user_sigs = signatures[from_user_id] = signatures[from_user_id]
user_sigs[key_id] = row["signature"]
else:
signatures[from_user_id] = {key_id: row["signature"]}
else:
target_user_key["signatures"] = {
from_user_id: {key_id: row["signature"]}
}
return keys
@defer.inlineCallbacks
def get_e2e_cross_signing_keys_bulk(
self, user_ids: List[str], from_user_id: str = None
) -> defer.Deferred:
"""Returns the cross-signing keys for a set of users.
Args:
user_ids (list[str]): the users whose keys are being requested
from_user_id (str): if specified, signatures made by this user on
the self-signing keys will be included in the result
Returns:
Deferred[dict[str, dict[str, dict]]]: map of user ID to key type to
key data. If a user's cross-signing keys were not found, either
their user ID will not be in the dict, or their user ID will map
to None.
"""
result = yield self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
if from_user_id:
result = yield self.db.runInteraction(
"get_e2e_cross_signing_signatures",
self._get_e2e_cross_signing_signatures_txn,
result,
from_user_id,
)
return result
def get_all_user_signature_changes_for_remotes(self, from_key, to_key):
"""Return a list of changes from the user signature stream to notify remotes.
Note that the user signature stream represents when a user signs their
@@ -520,6 +725,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
},
)
self._invalidate_cache_and_stream(
txn, self._get_bare_e2e_cross_signing_keys, (user_id,)
)
def set_e2e_cross_signing_key(self, user_id, key_type, key):
"""Set a user's cross-signing key.

View File

@@ -19,8 +19,10 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import List, Optional
from canonicaljson import json
from constantly import NamedConstant, Names
from twisted.internet import defer
@@ -55,6 +57,16 @@ EVENT_QUEUE_TIMEOUT_S = 0.1 # Timeout when waiting for requests for events
_EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event"))
class EventRedactBehaviour(Names):
"""
What to do when retrieving a redacted event from the database.
"""
AS_IS = NamedConstant()
REDACT = NamedConstant()
BLOCK = NamedConstant()
class EventsWorkerStore(SQLBaseStore):
def __init__(self, database: Database, db_conn, hs):
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
@@ -125,25 +137,27 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_event(
self,
event_id,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
allow_none=False,
check_room_id=None,
event_id: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
):
"""Get an event from the database by event_id.
Args:
event_id (str): The event_id of the event to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_id: The event_id of the event to fetch
redact_behaviour: Determine what to do with a redacted event. Possible values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* DISALLOW - Do not return redacted events
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_none (bool): If True, return None if no event found, if
allow_rejected: If True return rejected events.
allow_none: If True, return None if no event found, if
False throw a NotFoundError
check_room_id (str|None): if not None, check the room of the found event.
check_room_id: if not None, check the room of the found event.
If there is a mismatch, behave as per allow_none.
Returns:
@@ -154,7 +168,7 @@ class EventsWorkerStore(SQLBaseStore):
events = yield self.get_events_as_list(
[event_id],
check_redacted=check_redacted,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -173,27 +187,30 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
"""Get events from the database
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_ids: The event_ids of the events to fetch
redact_behaviour: Determine what to do with a redacted event. Possible
values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* DISALLOW - Do not return redacted events
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_rejected: If True return rejected events.
Returns:
Deferred : Dict from event_id to event.
"""
events = yield self.get_events_as_list(
event_ids,
check_redacted=check_redacted,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
allow_rejected=allow_rejected,
)
@@ -203,21 +220,23 @@ class EventsWorkerStore(SQLBaseStore):
@defer.inlineCallbacks
def get_events_as_list(
self,
event_ids,
check_redacted=True,
get_prev_content=False,
allow_rejected=False,
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
Args:
event_ids (list): The event_ids of the events to fetch
check_redacted (bool): If True, check if event has been redacted
and redact it.
get_prev_content (bool): If True and event is a state event,
event_ids: The event_ids of the events to fetch
redact_behaviour: Determine what to do with a redacted event. Possible values:
* AS_IS - Return the full event body with no redacted content
* REDACT - Return the event but with a redacted body
* DISALLOW - Do not return redacted events
get_prev_content: If True and event is a state event,
include the previous states content in the unsigned field.
allow_rejected (bool): If True return rejected events.
allow_rejected: If True, return rejected events.
Returns:
Deferred[list[EventBase]]: List of events fetched from the database. The
@@ -319,10 +338,14 @@ class EventsWorkerStore(SQLBaseStore):
# Update the cache to save doing the checks again.
entry.event.internal_metadata.recheck_redaction = False
if check_redacted and entry.redacted_event:
event = entry.redacted_event
else:
event = entry.event
event = entry.event
if entry.redacted_event:
if redact_behaviour == EventRedactBehaviour.BLOCK:
# Skip this event
continue
elif redact_behaviour == EventRedactBehaviour.REDACT:
event = entry.redacted_event
events.append(event)

View File

@@ -0,0 +1,13 @@
# Building full schema dumps
These schemas need to be made from a database that has had all background updates run.
To do so, use `scripts-dev/make_full_schema.sh`. This will produce
`full.sql.postgres ` and `full.sql.sqlite` files.
Ensure postgres is installed and your user has the ability to run bash commands
such as `createdb`.
```
./scripts-dev/make_full_schema.sh -p postgres_username -o output_dir/
```

View File

@@ -1,19 +0,0 @@
Building full schema dumps
==========================
These schemas need to be made from a database that has had all background updates run.
Postgres
--------
$ pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner $DATABASE_NAME| sed -e '/^--/d' -e 's/public\.//g' -e '/^SET /d' -e '/^SELECT /d' > full.sql.postgres
SQLite
------
$ sqlite3 $DATABASE_FILE ".schema" > full.sql.sqlite
After
-----
Delete the CREATE statements for "sqlite_stat1", "schema_version", "applied_schema_deltas", and "applied_module_schemas".

View File

@@ -25,6 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
from synapse.storage.database import Database
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
@@ -384,7 +385,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
search_query = search_query = _parse_query(self.database_engine, search_term)
search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -453,7 +454,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self.get_events_as_list([r["event_id"] for r in results])
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
event_map = {ev.event_id: ev for ev in events}
@@ -495,7 +501,7 @@ class SearchStore(SearchBackgroundUpdateStore):
"""
clauses = []
search_query = search_query = _parse_query(self.database_engine, search_term)
search_query = _parse_query(self.database_engine, search_term)
args = []
@@ -600,7 +606,12 @@ class SearchStore(SearchBackgroundUpdateStore):
results = list(filter(lambda row: row["room_id"] in room_ids, results))
events = yield self.get_events_as_list([r["event_id"] for r in results])
# We set redact_behaviour to BLOCK here to prevent redacted events being returned in
# search results (which is a data leak)
events = yield self.get_events_as_list(
[r["event_id"] for r in results],
redact_behaviour=EventRedactBehaviour.BLOCK,
)
event_map = {ev.event_id: ev for ev in events}

View File

@@ -278,7 +278,7 @@ class StateGroupWorkerStore(
@defer.inlineCallbacks
def get_room_predecessor(self, room_id):
"""Get the predecessor room of an upgraded room if one exists.
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
Args:
@@ -291,14 +291,22 @@ class StateGroupWorkerStore(
* room_id (str): The room ID of the predecessor room
* event_id (str): The ID of the tombstone event in the predecessor room
None if a predecessor key is not found, or is not a dictionary.
Raises:
NotFoundError if the room is unknown
NotFoundError if the given room is unknown
"""
# Retrieve the room's create event
create_event = yield self.get_create_event_for_room(room_id)
# Return predecessor if present
return create_event.content.get("predecessor", None)
# Retrieve the predecessor key of the create event
predecessor = create_event.content.get("predecessor", None)
# Ensure the key is a dictionary
if not isinstance(predecessor, dict):
return None
return predecessor
@defer.inlineCallbacks
def get_create_event_for_room(self, room_id):
@@ -318,7 +326,7 @@ class StateGroupWorkerStore(
# If we can't find the create event, assume we've hit a dead end
if not create_id:
raise NotFoundError("Unknown room %s" % (room_id))
raise NotFoundError("Unknown room %s" % (room_id,))
# Retrieve the room's create event and return
create_event = yield self.get_event(create_id)

View File

@@ -271,7 +271,7 @@ class _CacheDescriptorBase(object):
else:
self.function_to_call = orig
arg_spec = inspect.getargspec(orig)
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
if "cache_context" in all_args:

View File

@@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.util.async_helpers import ObservableDeferred
class SnapshotCache(object):
"""Cache for snapshots like the response of /initialSync.
The response of initialSync only has to be a recent snapshot of the
server state. It shouldn't matter to clients if it is a few minutes out
of date.
This caches a deferred response. Until the deferred completes it will be
returned from the cache. This means that if the client retries the request
while the response is still being computed, that original response will be
used rather than trying to compute a new response.
Once the deferred completes it will removed from the cache after 5 minutes.
We delay removing it from the cache because a client retrying its request
could race with us finishing computing the response.
Rather than tracking precisely how long something has been in the cache we
keep two generations of completed responses. Every 5 minutes discard the
old generation, move the new generation to the old generation, and set the
new generation to be empty. This means that a result will be in the cache
somewhere between 5 and 10 minutes.
"""
DURATION_MS = 5 * 60 * 1000 # Cache results for 5 minutes.
def __init__(self):
self.pending_result_cache = {} # Request that haven't finished yet.
self.prev_result_cache = {} # The older requests that have finished.
self.next_result_cache = {} # The newer requests that have finished.
self.time_last_rotated_ms = 0
def rotate(self, time_now_ms):
# Rotate once if the cache duration has passed since the last rotation.
if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
self.prev_result_cache = self.next_result_cache
self.next_result_cache = {}
self.time_last_rotated_ms += self.DURATION_MS
# Rotate again if the cache duration has passed twice since the last
# rotation.
if time_now_ms - self.time_last_rotated_ms >= self.DURATION_MS:
self.prev_result_cache = self.next_result_cache
self.next_result_cache = {}
self.time_last_rotated_ms = time_now_ms
def get(self, time_now_ms, key):
self.rotate(time_now_ms)
# This cache is intended to deduplicate requests, so we expect it to be
# missed most of the time. So we just lookup the key in all of the
# dictionaries rather than trying to short circuit the lookup if the
# key is found.
result = self.prev_result_cache.get(key)
result = self.next_result_cache.get(key, result)
result = self.pending_result_cache.get(key, result)
if result is not None:
return result.observe()
else:
return None
def set(self, time_now_ms, key, deferred):
self.rotate(time_now_ms)
result = ObservableDeferred(deferred)
self.pending_result_cache[key] = result
def shuffle_along(r):
# When the deferred completes we shuffle it along to the first
# generation of the result cache. So that it will eventually
# expire from the rotation of that cache.
self.next_result_cache[key] = result
self.pending_result_cache.pop(key, None)
return r
result.addBoth(shuffle_along)
return result.observe()

View File

@@ -183,10 +183,6 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
test_replace_master_key.skip = (
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
)
@defer.inlineCallbacks
def test_reupload_signatures(self):
"""re-uploading a signature should not fail"""
@@ -507,7 +503,3 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
],
other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
)
test_upload_signatures.skip = (
"Disabled waiting on #https://github.com/matrix-org/synapse/pull/6486"
)

View File

@@ -391,9 +391,8 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
# Email config.
self.email_attempts = []
def sendmail(*args, **kwargs):
async def sendmail(*args, **kwargs):
self.email_attempts.append((args, kwargs))
return
config["email"] = {
"enable_notifs": True,

View File

@@ -37,9 +37,13 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(12345678)
user_id = "@user:id"
device_id = "MY_DEVICE"
# Insert a user IP
self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -47,14 +51,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.reactor.advance(10)
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
self.store.get_last_client_ip_by_device(user_id, device_id)
)
r = result[(user_id, "device_id")]
r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
"device_id": "device_id",
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 12345678000,
@@ -209,14 +213,16 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Insert a user IP
user_id = "@user:id"
device_id = "MY_DEVICE"
# Insert a user IP
self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
user_id, "access_token", "ip", "user_agent", device_id
)
)
# Force persisting to disk
self.reactor.advance(200)
@@ -224,7 +230,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.get_success(
self.store.db.simple_update(
table="devices",
keyvalues={"user_id": user_id, "device_id": "device_id"},
keyvalues={"user_id": user_id, "device_id": device_id},
updatevalues={"last_seen": None, "ip": None, "user_agent": None},
desc="test_devices_last_seen_bg_update",
)
@@ -232,14 +238,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get nulls when querying
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
self.store.get_last_client_ip_by_device(user_id, device_id)
)
r = result[(user_id, "device_id")]
r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
"device_id": "device_id",
"device_id": device_id,
"ip": None,
"user_agent": None,
"last_seen": None,
@@ -272,14 +278,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# We should now get the correct result again
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
self.store.get_last_client_ip_by_device(user_id, device_id)
)
r = result[(user_id, "device_id")]
r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
"device_id": "device_id",
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,
@@ -296,11 +302,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
self.store.db.updates.do_next_background_update(100), by=0.1
)
# Insert a user IP
user_id = "@user:id"
device_id = "MY_DEVICE"
# Insert a user IP
self.get_success(self.store.store_device(user_id, device_id, "display name",))
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
user_id, "access_token", "ip", "user_agent", device_id
)
)
@@ -324,7 +333,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
"access_token": "access_token",
"ip": "ip",
"user_agent": "user_agent",
"device_id": "device_id",
"device_id": device_id,
"last_seen": 0,
}
],
@@ -347,14 +356,14 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
# But we should still get the correct values for the device
result = self.get_success(
self.store.get_last_client_ip_by_device(user_id, "device_id")
self.store.get_last_client_ip_by_device(user_id, device_id)
)
r = result[(user_id, "device_id")]
r = result[(user_id, device_id)]
self.assertDictContainsSubset(
{
"user_id": user_id,
"device_id": "device_id",
"device_id": device_id,
"ip": "ip",
"user_agent": "user_agent",
"last_seen": 0,

View File

@@ -1,6 +1,6 @@
from mock import Mock
from twisted.internet.defer import maybeDeferred, succeed
from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
from synapse.events import FrozenEvent
from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
)
# Send the join, it should return None (which is not an error)
d = self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
d = ensureDeferred(
self.handler.on_receive_pdu(
"test.serv", join_event, sent_to_us_directly=True
)
)
self.reactor.advance(1)
self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
)
with LoggingContext(request="lying_event"):
d = self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
d = ensureDeferred(
self.handler.on_receive_pdu(
"test.serv", lying_event, sent_to_us_directly=True
)
)
# Step the reactor, so the database fetches come back

View File

@@ -17,18 +17,15 @@ from synapse.api.errors import SynapseError
from synapse.types import GroupID, RoomAlias, UserID, map_username_to_mxid_localpart
from tests import unittest
from tests.utils import TestHomeServer
mock_homeserver = TestHomeServer(hostname="my.domain")
class UserIDTestCase(unittest.TestCase):
class UserIDTestCase(unittest.HomeserverTestCase):
def test_parse(self):
user = UserID.from_string("@1234abcd:my.domain")
user = UserID.from_string("@1234abcd:test")
self.assertEquals("1234abcd", user.localpart)
self.assertEquals("my.domain", user.domain)
self.assertEquals(True, mock_homeserver.is_mine(user))
self.assertEquals("test", user.domain)
self.assertEquals(True, self.hs.is_mine(user))
def test_pase_empty(self):
with self.assertRaises(SynapseError):
@@ -48,13 +45,13 @@ class UserIDTestCase(unittest.TestCase):
self.assertTrue(userA != userB)
class RoomAliasTestCase(unittest.TestCase):
class RoomAliasTestCase(unittest.HomeserverTestCase):
def test_parse(self):
room = RoomAlias.from_string("#channel:my.domain")
room = RoomAlias.from_string("#channel:test")
self.assertEquals("channel", room.localpart)
self.assertEquals("my.domain", room.domain)
self.assertEquals(True, mock_homeserver.is_mine(room))
self.assertEquals("test", room.domain)
self.assertEquals(True, self.hs.is_mine(room))
def test_build(self):
room = RoomAlias("channel", "my.domain")

View File

@@ -179,6 +179,30 @@ class LoggingContextTestCase(unittest.TestCase):
nested_context = nested_logging_context(suffix="bar")
self.assertEqual(nested_context.request, "foo-bar")
@defer.inlineCallbacks
def test_make_deferred_yieldable_with_await(self):
# an async function which retuns an incomplete coroutine, but doesn't
# follow the synapse rules.
async def blocking_function():
d = defer.Deferred()
reactor.callLater(0, d.callback, None)
await d
sentinel_context = LoggingContext.current_context()
with LoggingContext() as context_one:
context_one.request = "one"
d1 = make_deferred_yieldable(blocking_function())
# make sure that the context was reset by make_deferred_yieldable
self.assertIs(LoggingContext.current_context(), sentinel_context)
yield d1
# now it should be restored
self._check_test_key("one")
# a function which returns a deferred which has been "called", but
# which had a function which returned another incomplete deferred on

View File

@@ -1,63 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet.defer import Deferred
from synapse.util.caches.snapshot_cache import SnapshotCache
from .. import unittest
class SnapshotCacheTestCase(unittest.TestCase):
def setUp(self):
self.cache = SnapshotCache()
self.cache.DURATION_MS = 1
def test_get_set(self):
# Check that getting a missing key returns None
self.assertEquals(self.cache.get(0, "key"), None)
# Check that setting a key with a deferred returns
# a deferred that resolves when the initial deferred does
d = Deferred()
set_result = self.cache.set(0, "key", d)
self.assertIsNotNone(set_result)
self.assertFalse(set_result.called)
# Check that getting the key before the deferred has resolved
# returns a deferred that resolves when the initial deferred does.
get_result_at_10 = self.cache.get(10, "key")
self.assertIsNotNone(get_result_at_10)
self.assertFalse(get_result_at_10.called)
# Check that the returned deferreds resolve when the initial deferred
# does.
d.callback("v")
self.assertTrue(set_result.called)
self.assertTrue(get_result_at_10.called)
# Check that getting the key after the deferred has resolved
# before the cache expires returns a resolved deferred.
get_result_at_11 = self.cache.get(11, "key")
self.assertIsNotNone(get_result_at_11)
if isinstance(get_result_at_11, Deferred):
# The cache may return the actual result rather than a deferred
self.assertTrue(get_result_at_11.called)
# Check that getting the key after the deferred has resolved
# after the cache expires returns None
get_result_at_12 = self.cache.get(12, "key")
self.assertIsNone(get_result_at_12)

View File

@@ -260,9 +260,7 @@ def setup_test_homeserver(
hs = homeserverToUse(
name,
config=config,
db_config=config.database_config,
version_string="Synapse/tests",
database_engine=db_engine,
tls_server_context_factory=Mock(),
tls_client_options_factory=Mock(),
reactor=reactor,

14
tox.ini
View File

@@ -177,5 +177,17 @@ env =
MYPYPATH = stubs/
extras = all
commands = mypy \
synapse/config/ \
synapse/handlers/ui_auth \
synapse/logging/ \
synapse/config/
synapse/module_api \
synapse/rest/consent \
synapse/rest/media/v0 \
synapse/rest/saml2 \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/streams
# To find all folders that pass mypy you run:
#
# find synapse/* -type d -not -name __pycache__ -exec bash -c "mypy '{}' > /dev/null" \; -print