1
0

Merge commit '3950ae51e' into anoa/dinsic_release_1_21_x

* commit '3950ae51e':
  Ensure that remove_pusher is always async (#7981)
  Ensure the msg property of HttpResponseException is a string. (#7979)
  Remove from the event_relations table when purging historical events. (#7978)
  Add additional logging for SAML sessions. (#7971)
  Add MSC reference to changelog for #7736
  Re-implement unread counts (#7736)
  Various improvements to the docs (#7899)
  Convert storage layer to async/await. (#7963)
  Add an option to disable purge in delete room admin API (#7964)
  Move some log lines from default logger to sql/transaction loggers (#7952)
  Use the JSON module from the std library instead of simplejson. (#7936)
  Fix exit code for `check_line_terminators.sh` (#7970)
  Option to allow server admins to join complex rooms (#7902)
  Fix typo in metrics docs (#7966)
  Add script for finding files with unix line terminators (#7965)
  Convert the remaining media repo code to async / await. (#7947)
  Convert a synapse.events to async/await. (#7949)
  Convert groups and visibility code to async / await. (#7951)
  Convert push to async/await. (#7948)
This commit is contained in:
Andrew Morgan
2020-10-16 17:06:25 +01:00
85 changed files with 1325 additions and 940 deletions
+95 -14
View File
@@ -1,10 +1,12 @@
- [Choosing your server name](#choosing-your-server-name)
- [Picking a database engine](#picking-a-database-engine)
- [Installing Synapse](#installing-synapse)
- [Installing from source](#installing-from-source)
- [Platform-Specific Instructions](#platform-specific-instructions)
- [Prebuilt packages](#prebuilt-packages)
- [Setting up Synapse](#setting-up-synapse)
- [TLS certificates](#tls-certificates)
- [Client Well-Known URI](#client-well-known-uri)
- [Email](#email)
- [Registering a user](#registering-a-user)
- [Setting up a TURN server](#setting-up-a-turn-server)
@@ -27,6 +29,25 @@ that your email address is probably `user@example.com` rather than
`user@email.example.com`) - but doing so may require more advanced setup: see
[Setting up Federation](docs/federate.md).
# Picking a database engine
Synapse offers two database engines:
* [PostgreSQL](https://www.postgresql.org)
* [SQLite](https://sqlite.org/)
Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
For information on how to install and use PostgreSQL, please see
[docs/postgres.md](docs/postgres.md)
By default Synapse uses SQLite and in doing so trades performance for convenience.
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
# Installing Synapse
## Installing from source
@@ -234,9 +255,9 @@ for a number of platforms.
There is an offical synapse image available at
https://hub.docker.com/r/matrixdotorg/synapse which can be used with
the docker-compose file available at [contrib/docker](contrib/docker). Further information on
this including configuration options is available in the README on
hub.docker.com.
the docker-compose file available at [contrib/docker](contrib/docker). Further
information on this including configuration options is available in the README
on hub.docker.com.
Alternatively, Andreas Peters (previously Silvio Fricke) has contributed a
Dockerfile to automate a synapse server in a single Docker image, at
@@ -244,7 +265,8 @@ https://hub.docker.com/r/avhost/docker-matrix/tags/
Slavi Pantaleev has created an Ansible playbook,
which installs the offical Docker image of Matrix Synapse
along with many other Matrix-related services (Postgres database, riot-web, coturn, mxisd, SSL support, etc.).
along with many other Matrix-related services (Postgres database, Element, coturn,
ma1sd, SSL support, etc.).
For more details, see
https://github.com/spantaleev/matrix-docker-ansible-deploy
@@ -277,22 +299,27 @@ The fingerprint of the repository signing key (as shown by `gpg
/usr/share/keyrings/matrix-org-archive-keyring.gpg`) is
`AAF9AE843A7584B5A3E4CD2BCF45A512DE2DA058`.
#### Downstream Debian/Ubuntu packages
#### Downstream Debian packages
For `buster` and `sid`, Synapse is available in the Debian repositories and
it should be possible to install it with simply:
We do not recommend using the packages from the default Debian `buster`
repository at this time, as they are old and suffer from known security
vulnerabilities. You can install the latest version of Synapse from
[our repository](#matrixorg-packages) or from `buster-backports`. Please
see the [Debian documentation](https://backports.debian.org/Instructions/)
for information on how to use backports.
If you are using Debian `sid` or testing, Synapse is available in the default
repositories and it should be possible to install it simply with:
```
sudo apt install matrix-synapse
```
There is also a version of `matrix-synapse` in `stretch-backports`. Please see
the [Debian documentation on
backports](https://backports.debian.org/Instructions/) for information on how
to use them.
#### Downstream Ubuntu packages
We do not recommend using the packages in downstream Ubuntu at this time, as
they are old and suffer from known security vulnerabilities.
We do not recommend using the packages in the default Ubuntu repository
at this time, as they are old and suffer from known security vulnerabilities.
The latest version of Synapse can be installed from [our repository](#matrixorg-packages).
### Fedora
@@ -419,6 +446,60 @@ so, you will need to edit `homeserver.yaml`, as follows:
For a more detailed guide to configuring your server for federation, see
[federate.md](docs/federate.md).
## Client Well-Known URI
Setting up the client Well-Known URI is optional but if you set it up, it will
allow users to enter their full username (e.g. `@user:<server_name>`) into clients
which support well-known lookup to automatically configure the homeserver and
identity server URLs. This is useful so that users don't have to memorize or think
about the actual homeserver URL you are using.
The URL `https://<server_name>/.well-known/matrix/client` should return JSON in
the following format.
```
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
}
}
```
It can optionally contain identity server information as well.
```
{
"m.homeserver": {
"base_url": "https://<matrix.example.com>"
},
"m.identity_server": {
"base_url": "https://<identity.example.com>"
}
}
```
To work in browser based clients, the file must be served with the appropriate
Cross-Origin Resource Sharing (CORS) headers. A recommended value would be
`Access-Control-Allow-Origin: *` which would allow all browser based clients to
view it.
In nginx this would be something like:
```
location /.well-known/matrix/client {
return 200 '{"m.homeserver": {"base_url": "https://<matrix.example.com>"}}';
add_header Content-Type application/json;
add_header Access-Control-Allow-Origin *;
}
```
You should also ensure the `public_baseurl` option in `homeserver.yaml` is set
correctly. `public_baseurl` should be set to the URL that clients will use to
connect to your server. This is the same URL you put for the `m.homeserver`
`base_url` above.
```
public_baseurl: "https://<matrix.example.com>"
```
## Email
@@ -437,7 +518,7 @@ email will be disabled.
## Registering a user
The easiest way to create a new user is to do so from a client like [Riot](https://riot.im).
The easiest way to create a new user is to do so from a client like [Element](https://element.io/).
Alternatively you can do so from the command line if you have installed via pip.
+7 -36
View File
@@ -45,7 +45,7 @@ which handle:
- Eventually-consistent cryptographically secure synchronisation of room
state across a global open network of federated servers and services
- Sending and receiving extensible messages in a room with (optional)
end-to-end encryption[1]
end-to-end encryption
- Inviting, joining, leaving, kicking, banning room members
- Managing user accounts (registration, login, logout)
- Using 3rd Party IDs (3PIDs) such as email addresses, phone numbers,
@@ -82,9 +82,6 @@ at the `Matrix spec <https://matrix.org/docs/spec>`_, and experiment with the
Thanks for using Matrix!
[1] End-to-end encryption is currently in beta: `blog post <https://matrix.org/blog/2016/11/21/matrixs-olm-end-to-end-encryption-security-assessment-released-and-implemented-cross-platform-on-riot-at-last>`_.
Support
=======
@@ -115,12 +112,11 @@ Unless you are running a test instance of Synapse on your local machine, in
general, you will need to enable TLS support before you can successfully
connect from a client: see `<INSTALL.md#tls-certificates>`_.
An easy way to get started is to login or register via Riot at
https://riot.im/app/#/login or https://riot.im/app/#/register respectively.
An easy way to get started is to login or register via Element at
https://app.element.io/#/login or https://app.element.io/#/register respectively.
You will need to change the server you are logging into from ``matrix.org``
and instead specify a Homeserver URL of ``https://<server_name>:8448``
(or just ``https://<server_name>`` if you are using a reverse proxy).
(Leave the identity server as the default - see `Identity servers`_.)
If you prefer to use another client, refer to our
`client breakdown <https://matrix.org/docs/projects/clients-matrix>`_.
@@ -137,7 +133,7 @@ it, specify ``enable_registration: true`` in ``homeserver.yaml``. (It is then
recommended to also set up CAPTCHA - see `<docs/CAPTCHA_SETUP.md>`_.)
Once ``enable_registration`` is set to ``true``, it is possible to register a
user via `riot.im <https://riot.im/app/#/register>`_ or other Matrix clients.
user via a Matrix client.
Your new user name will be formed partly from the ``server_name``, and partly
from a localpart you specify when you create the account. Your name will take
@@ -183,30 +179,6 @@ versions of synapse.
.. _UPGRADE.rst: UPGRADE.rst
Using PostgreSQL
================
Synapse offers two database engines:
* `PostgreSQL <https://www.postgresql.org>`_
* `SQLite <https://sqlite.org/>`_
Almost all installations should opt to use PostgreSQL. Advantages include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
For information on how to install and use PostgreSQL, please see
`docs/postgres.md <docs/postgres.md>`_.
By default Synapse uses SQLite and in doing so trades performance for convenience.
SQLite is only recommended in Synapse for testing purposes or for servers with
light workloads.
.. _reverse-proxy:
Using a reverse proxy with Synapse
@@ -255,10 +227,9 @@ email address.
Password reset
==============
If a user has registered an email address to their account using an identity
server, they can request a password-reset token via clients such as Riot.
A manual password reset can be done via direct database access as follows.
Users can reset their password through their client. Alternatively, a server admin
can reset a users password using the `admin API <docs/admin_api/user_admin_api.rst#reset-password>`_
or by directly editing the database as shown below.
First calculate the hash of the new password::
+1
View File
@@ -0,0 +1 @@
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).
+1
View File
@@ -0,0 +1 @@
Document how to set up a Client Well-Known file and fix several pieces of outdated documentation.
+1
View File
@@ -0,0 +1 @@
Add option to allow server admins to join rooms which fail complexity checks. Contributed by @lugino-emeritus.
+1
View File
@@ -0,0 +1 @@
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Move some database-related log lines from the default logger to the database/transaction loggers.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1
View File
@@ -0,0 +1 @@
Add an option to purge room or not with delete room admin endpoint (`POST /_synapse/admin/v1/rooms/<room_id>/delete`). Contributed by @dklimpel.
+1
View File
@@ -0,0 +1 @@
Add a script to detect source code files using non-unix line terminators.
+1
View File
@@ -0,0 +1 @@
Add a script to detect source code files using non-unix line terminators.
+1
View File
@@ -0,0 +1 @@
Log the SAML session ID during creation.
+1
View File
@@ -0,0 +1 @@
Fix a long standing bug: 'Duplicate key value violates unique constraint "event_relations_id"' when message retention is configured.
+1
View File
@@ -0,0 +1 @@
Switch to the JSON implementation from the standard library and bump the minimum version of the canonicaljson library to 1.2.0.
+1
View File
@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.
+1 -1
View File
@@ -1,2 +1,2 @@
# Specify environment variables used when running Synapse
# SYNAPSE_CACHE_FACTOR=1 (default)
# SYNAPSE_CACHE_FACTOR=0.5 (default)
+14 -13
View File
@@ -46,19 +46,20 @@ Configuration file may be generated as follows:
## ENVIRONMENT
* `SYNAPSE_CACHE_FACTOR`:
Synapse's architecture is quite RAM hungry currently - a lot of
recent room data and metadata is deliberately cached in RAM in
order to speed up common requests. This will be improved in
future, but for now the easiest way to either reduce the RAM usage
(at the risk of slowing things down) is to set the
SYNAPSE_CACHE_FACTOR environment variable. Roughly speaking, a
SYNAPSE_CACHE_FACTOR of 1.0 will max out at around 3-4GB of
resident memory - this is what we currently run the matrix.org
on. The default setting is currently 0.1, which is probably around
a ~700MB footprint. You can dial it down further to 0.02 if
desired, which targets roughly ~512MB. Conversely you can dial it
up if you need performance for lots of users and have a box with a
lot of RAM.
Synapse's architecture is quite RAM hungry currently - we deliberately
cache a lot of recent room data and metadata in RAM in order to speed up
common requests. We'll improve this in the future, but for now the easiest
way to either reduce the RAM usage (at the risk of slowing things down)
is to set the almost-undocumented ``SYNAPSE_CACHE_FACTOR`` environment
variable. The default is 0.5, which can be decreased to reduce RAM usage
in memory constrained enviroments, or increased if performance starts to
degrade.
However, degraded performance due to a low cache factor, common on
machines with slow disks, often leads to explosions in memory use due
backlogged requests. In this case, reducing the cache factor will make
things worse. Instead, try increasing it drastically. 2.0 is a good
starting value.
## COPYRIGHT
+11
View File
@@ -10,5 +10,16 @@
# homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md.
# Configuration options that take a time period can be set using a number
# followed by a letter. Letters have the following meanings:
# s = second
# m = minute
# h = hour
# d = day
# w = week
# y = year
# For example, setting redaction_retention_period: 5m would remove redacted
# messages from the database after 5 minutes, rather than 5 months.
################################################################################
+9 -4
View File
@@ -369,7 +369,9 @@ to the new room will have power level `-10` by default, and thus be unable to sp
If `block` is `True` it prevents new joins to the old room.
This API will remove all trace of the old room from your database after removing
all local users.
all local users. If `purge` is `true` (the default), all traces of the old room will
be removed from your database after removing all local users. If you do not want
this to happen, set `purge` to `false`.
Depending on the amount of history being purged a call to the API may take
several minutes or longer.
@@ -388,7 +390,8 @@ with a body of:
"new_room_user_id": "@someuser:example.com",
"room_name": "Content Violation Notification",
"message": "Bad Room has been shutdown due to content violations on this server. Please review our Terms of Service.",
"block": true
"block": true,
"purge": true
}
```
@@ -430,8 +433,10 @@ The following JSON body parameters are available:
`new_room_user_id` in the new room. Ideally this will clearly convey why the
original room was shut down. Defaults to `Sharing illegal content on this server
is not permitted and rooms in violation will be blocked.`
* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing future attempts to
join the room. Defaults to `false`.
* `block` - Optional. If set to `true`, this room will be added to a blocking list, preventing
future attempts to join the room. Defaults to `false`.
* `purge` - Optional. If set to `true`, it will remove all traces of the room from your database.
Defaults to `true`.
The JSON body must not be empty. The body must be at least `{}`.
+1 -1
View File
@@ -27,7 +27,7 @@
different thread to Synapse. This can make it more resilient to
heavy load meaning metrics cannot be retrieved, and can be exposed
to just internal networks easier. The served metrics are available
over HTTP only, and will be available at `/`.
over HTTP only, and will be available at `/_synapse/metrics`.
Add a new listener to homeserver.yaml:
+3
View File
@@ -188,6 +188,9 @@ to do step 2.
It is safe to at any time kill the port script and restart it.
Note that the database may take up significantly more (25% - 100% more)
space on disk after porting to Postgres.
### Using the port script
Firstly, shut down the currently running synapse server and copy its
+15 -270
View File
@@ -10,6 +10,17 @@
# homeserver.yaml. Instead, if you are starting from scratch, please generate
# a fresh config using Synapse by following the instructions in INSTALL.md.
# Configuration options that take a time period can be set using a number
# followed by a letter. Letters have the following meanings:
# s = second
# m = minute
# h = hour
# d = day
# w = week
# y = year
# For example, setting redaction_retention_period: 5m would remove redacted
# messages from the database after 5 minutes, rather than 5 months.
################################################################################
# Configuration file for Synapse.
@@ -314,6 +325,10 @@ limit_remote_rooms:
#
#complexity_error: "This room is too complex."
# allow server admins to join complex rooms. Default is false.
#
#admins_can_join: true
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
#
@@ -325,74 +340,6 @@ limit_remote_rooms:
#
#allow_per_room_profiles: false
# Whether to show the users on this homeserver in the user directory. Defaults to
# 'true'.
#
#show_users_in_user_directory: false
# Message retention policy at the server level.
#
# Room admins and mods can define a retention period for their rooms using the
# 'm.room.retention' state event, and server admins can cap this period by setting
# the 'allowed_lifetime_min' and 'allowed_lifetime_max' config options.
#
# If this feature is enabled, Synapse will regularly look for and purge events
# which are older than the room's maximum retention period. Synapse will also
# filter events received over federation so that events that should have been
# purged are ignored and not stored again.
#
retention:
# The message retention policies feature is disabled by default. Uncomment the
# following line to enable it.
#
#enabled: true
# Default retention policy. If set, Synapse will apply it to rooms that lack the
# 'm.room.retention' state event. Currently, the value of 'min_lifetime' doesn't
# matter much because Synapse doesn't take it into account yet.
#
#default_policy:
# min_lifetime: 1d
# max_lifetime: 1y
# Retention policy limits. If set, a user won't be able to send a
# 'm.room.retention' event which features a 'min_lifetime' or a 'max_lifetime'
# that's not within this range. This is especially useful in closed federations,
# in which server admins can make sure every federating server applies the same
# rules.
#
#allowed_lifetime_min: 1d
#allowed_lifetime_max: 1y
# Server admins can define the settings of the background jobs purging the
# events which lifetime has expired under the 'purge_jobs' section.
#
# If no configuration is provided, a single job will be set up to delete expired
# events in every room daily.
#
# Each job's configuration defines which range of message lifetimes the job
# takes care of. For example, if 'shortest_max_lifetime' is '2d' and
# 'longest_max_lifetime' is '3d', the job will handle purging expired events in
# rooms whose state defines a 'max_lifetime' that's both higher than 2 days, and
# lower than or equal to 3 days. Both the minimum and the maximum value of a
# range are optional, e.g. a job with no 'shortest_max_lifetime' and a
# 'longest_max_lifetime' of '3d' will handle every room with a retention policy
# which 'max_lifetime' is lower than or equal to three days.
#
# The rationale for this per-job configuration is that some rooms might have a
# retention policy with a low 'max_lifetime', where history needs to be purged
# of outdated messages on a very frequent basis (e.g. every 5min), but not want
# that purge to be performed by a job that's iterating over every room it knows,
# which would be quite heavy on the server.
#
#purge_jobs:
# - shortest_max_lifetime: 1d
# longest_max_lifetime: 3d
# interval: 5m:
# - shortest_max_lifetime: 3d
# longest_max_lifetime: 1y
# interval: 24h
# How long to keep redacted events in unredacted form in the database. After
# this period redacted events get replaced with their redacted form in the DB.
#
@@ -479,24 +426,6 @@ retention:
#
#request_token_inhibit_3pid_errors: true
# A list of domains that the domain portion of 'next_link' parameters
# must match.
#
# This parameter is optionally provided by clients while requesting
# validation of an email or phone number, and maps to a link that
# users will be automatically redirected to after validation
# succeeds. Clients can make use this parameter to aid the validation
# process.
#
# The whitelist is applied whether the homeserver or an
# identity server is handling validation.
#
# The default value is no whitelist functionality; all domains are
# allowed. Setting this value to an empty list will instead disallow
# all domains.
#
#next_link_domain_whitelist: ["matrix.org"]
## TLS ##
@@ -814,8 +743,6 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# - one for login that ratelimits login requests based on the account the
# client is attempting to log into, based on the amount of failed login
# attempts for this account.
# - one that ratelimits third-party invites requests based on the account
# that's making the requests.
# - one for ratelimiting redactions by room admins. If this is not explicitly
# set then it uses the same ratelimiting as per rc_message. This is useful
# to allow room admins to deal with abuse quickly.
@@ -841,10 +768,6 @@ log_config: "CONFDIR/SERVERNAME.log.config"
# per_second: 0.17
# burst_count: 3
#
#rc_third_party_invite:
# per_second: 0.2
# burst_count: 10
#
#rc_admin_redaction:
# per_second: 1
# burst_count: 50
@@ -911,30 +834,6 @@ media_store_path: "DATADIR/media_store"
#
#max_upload_size: 10M
# The largest allowed size for a user avatar. If not defined, no
# restriction will be imposed.
#
# Note that this only applies when an avatar is changed globally.
# Per-room avatar changes are not affected. See allow_per_room_profiles
# for disabling that functionality.
#
# Note that user avatar changes will not work if this is set without
# using Synapse's local media repo.
#
#max_avatar_size: 10M
# Allow mimetypes for a user avatar. If not defined, no restriction will
# be imposed.
#
# Note that this only applies when an avatar is changed globally.
# Per-room avatar changes are not affected. See allow_per_room_profiles
# for disabling that functionality.
#
# Note that user avatar changes will not work if this is set without
# using Synapse's local media repo.
#
#allowed_avatar_mimetypes: ["image/png", "image/jpeg", "image/gif"]
# Maximum number of pixels that will be thumbnailed
#
#max_image_pixels: 32M
@@ -1219,32 +1118,9 @@ account_validity:
#
#disable_msisdn_registration: true
# Derive the user's matrix ID from a type of 3PID used when registering.
# This overrides any matrix ID the user proposes when calling /register
# The 3PID type should be present in registrations_require_3pid to avoid
# users failing to register if they don't specify the right kind of 3pid.
#
#register_mxid_from_3pid: email
# Uncomment to set the display name of new users to their email address,
# rather than using the default heuristic.
#
#register_just_use_email_for_display_name: true
# Mandate that users are only allowed to associate certain formats of
# 3PIDs with accounts on this server.
#
# Use an Identity Server to establish which 3PIDs are allowed to register?
# Overrides allowed_local_3pids below.
#
#check_is_for_allowed_local_3pids: matrix.org
#
# If you are using an IS you can also check whether that IS registers
# pending invites for the given 3PID (and then allow it to sign up on
# the platform):
#
#allow_invited_3pids: false
#
#allowed_local_3pids:
# - medium: email
# pattern: '.*@matrix\.org'
@@ -1253,11 +1129,6 @@ account_validity:
# - medium: msisdn
# pattern: '\+44'
# If true, stop users from trying to change the 3PIDs associated with
# their accounts.
#
#disable_3pid_changes: false
# Enable 3PIDs lookup requests to identity servers from this server.
#
#enable_3pid_lookup: true
@@ -1289,48 +1160,6 @@ account_validity:
#
#default_identity_server: https://matrix.org
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
# background migration script, informing itself that the identity server all of its
# 3PIDs have been bound to is likely one of the below.
#
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
# it is now solely used for the purposes of the background migration script, and can be
# removed once it has run.
#trusted_third_party_id_servers:
# - matrix.org
# - vector.im
# If enabled, user IDs, display names and avatar URLs will be replicated
# to this server whenever they change.
# This is an experimental API currently implemented by sydent to support
# cross-homeserver user directories.
#
#replicate_user_profiles_to: example.com
# If specified, attempt to replay registrations, profile changes & 3pid
# bindings on the given target homeserver via the AS API. The HS is authed
# via a given AS token.
#
#shadow_server:
# hs_url: https://shadow.example.com
# hs: shadow.example.com
# as_token: 12u394refgbdhivsia
# If enabled, don't let users set their own display names/avatars
# other than for the very first time (unless they are a server admin).
# Useful when provisioning users based on the contents of a 3rd party
# directory and to avoid ambiguities.
#
#disable_set_displayname: false
#disable_set_avatar_url: false
# Handle threepid (email/phone etc) registration and password resets through a set of
# *trusted* identity servers. Note that this allows the configured identity server to
# reset passwords for accounts!
@@ -1457,31 +1286,6 @@ account_threepid_delegates:
#
#auto_join_rooms_for_guests: false
# Rewrite identity server URLs with a map from one URL to another. Applies to URLs
# provided by clients (which have https:// prepended) and those specified
# in `account_threepid_delegates`. URLs should not feature a trailing slash.
#
#rewrite_identity_server_urls:
# "https://somewhere.example.com": "https://somewhereelse.example.com"
# When a user registers an account with an email address, it can be useful to
# bind that email address to their mxid on an identity server. Typically, this
# requires the user to validate their email address with the identity server.
# However if Synapse itself is handling email validation on registration, the
# user ends up needing to validate their email twice, which leads to poor UX.
#
# It is possible to force Sydent, one identity server implementation, to bind
# threepids using its internal, unauthenticated bind API:
# https://github.com/matrix-org/sydent/#internal-bind-and-unbind-api
#
# Configure the address of a Sydent server here to have Synapse attempt
# to automatically bind users' emails following registration. The
# internal bind API must be reachable from Synapse, but should NOT be
# exposed to any third party, as it allows the creation of bindings
# without validation.
#
#bind_new_user_emails_to_sydent: https://example.com:8091
## Metrics ###
@@ -2386,11 +2190,6 @@ spam_checker:
#user_directory:
# enabled: true
# search_all_users: false
#
# # If this is set, user search will be delegated to this ID server instead
# # of synapse performing the search itself.
# # This is an experimental API.
# defer_to_id_server: https://id.example.com
# User Consent configuration
@@ -2596,57 +2395,3 @@ opentracing:
#
# logging:
# false
## Workers ##
# Disables sending of outbound federation transactions on the main process.
# Uncomment if using a federation sender worker.
#
#send_federation: false
# It is possible to run multiple federation sender workers, in which case the
# work is balanced across them.
#
# This configuration must be shared between all federation sender workers, and if
# changed all federation sender workers must be stopped at the same time and then
# started, to ensure that all instances are running with the same config (otherwise
# events may be dropped).
#
#federation_sender_instances:
# - federation_sender1
# When using workers this should be a map from `worker_name` to the
# HTTP replication listener of the worker, if configured.
#
#instance_map:
# worker1:
# host: localhost
# port: 8034
# Experimental: When using workers you can define which workers should
# handle event persistence and typing notifications. Any worker
# specified here must also be in the `instance_map`.
#
#stream_writers:
# events: worker1
# typing: worker1
# Configuration for Redis when using workers. This *must* be enabled when
# using workers (unless using old style direct TCP configuration).
#
redis:
# Uncomment the below to enable Redis support.
#
#enabled: true
# Optional host and port to use to connect to redis. Defaults to
# localhost and 6379
#
#host: localhost
#port: 6379
# Optional password if configured on the Redis instance
#
#password: <secret_password>
+34
View File
@@ -0,0 +1,34 @@
#!/bin/bash
#
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This script checks that line terminators in all repository files (excluding
# those in the .git directory) feature unix line terminators.
#
# Usage:
#
# ./check_line_terminators.sh
#
# The script will emit exit code 1 if any files that do not use unix line
# terminators are found, 0 otherwise.
# cd to the root of the repository
cd `dirname $0`/..
# Find and print files with non-unix line terminators
if find . -path './.git/*' -prune -o -type f -print0 | xargs -0 grep -I -l $'\r$'; then
echo -e '\e[31mERROR: found files with CRLF line endings. See above.\e[39m'
exit 1
fi
+1 -1
View File
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier", "contains_url"],
"events": ["processed", "outlier", "contains_url", "count_as_unread"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
+12
View File
@@ -17,6 +17,7 @@
""" This is a reference implementation of a Matrix homeserver.
"""
import json
import os
import sys
@@ -25,6 +26,9 @@ if sys.version_info < (3, 5):
print("Synapse requires Python 3.5 or above.")
sys.exit(1)
# Twisted and canonicaljson will fail to import when this file is executed to
# get the __version__ during a fresh install. That's OK and subsequent calls to
# actually start Synapse will import these libraries fine.
try:
from twisted.internet import protocol
from twisted.internet.protocol import Factory
@@ -36,6 +40,14 @@ try:
except ImportError:
pass
# Use the standard library json implementation instead of simplejson.
try:
from canonicaljson import set_json_library
set_json_library(json)
except ImportError:
pass
__version__ = "1.18.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
+1 -1
View File
@@ -82,7 +82,7 @@ class Auth(object):
@defer.inlineCallbacks
def check_from_context(self, room_version: str, event, context, do_sig_check=True):
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
auth_events_ids = yield self.compute_auth_events(
event, prev_state_ids, for_verification=True
)
+1 -1
View File
@@ -628,7 +628,7 @@ class GenericWorkerServer(HomeServer):
self.get_tcp_replication().start_replication(self)
def remove_pusher(self, app_id, push_key, user_id):
async def remove_pusher(self, app_id, push_key, user_id):
self.get_tcp_replication().send_remove_pusher(app_id, push_key, user_id)
def build_replication_data_handler(self):
-18
View File
@@ -412,24 +412,6 @@ class RegistrationConfig(Config):
#
#default_identity_server: https://matrix.org
# The list of identity servers trusted to verify third party
# identifiers by this server.
#
# Also defines the ID server which will be called when an account is
# deactivated (one will be picked arbitrarily).
#
# Note: This option is deprecated. Since v0.99.4, Synapse has tracked which identity
# server a 3PID has been bound to. For 3PIDs bound before then, Synapse runs a
# background migration script, informing itself that the identity server all of its
# 3PIDs have been bound to is likely one of the below.
#
# As of Synapse v1.4.0, all other functionality of this option has been deprecated, and
# it is now solely used for the purposes of the background migration script, and can be
# removed once it has run.
#trusted_third_party_id_servers:
# - matrix.org
# - vector.im
# If enabled, user IDs, display names and avatar URLs will be replicated
# to this server whenever they change.
# This is an experimental API currently implemented by sydent to support
+7
View File
@@ -445,6 +445,9 @@ class ServerConfig(Config):
validator=attr.validators.instance_of(str),
default=ROOM_COMPLEXITY_TOO_GREAT,
)
admins_can_join = attr.ib(
validator=attr.validators.instance_of(bool), default=False
)
self.limit_remote_rooms = LimitRemoteRoomsConfig(
**(config.get("limit_remote_rooms") or {})
@@ -927,6 +930,10 @@ class ServerConfig(Config):
#
#complexity_error: "This room is too complex."
# allow server admins to join complex rooms. Default is false.
#
#admins_can_join: true
# Whether to require a user to be in the room to add an alias to it.
# Defaults to 'true'.
#
+8 -11
View File
@@ -17,8 +17,6 @@ from typing import Optional
import attr
from nacl.signing import SigningKey
from twisted.internet import defer
from synapse.api.constants import MAX_DEPTH
from synapse.api.errors import UnsupportedRoomVersionError
from synapse.api.room_versions import (
@@ -95,31 +93,30 @@ class EventBuilder(object):
def is_state(self):
return self._state_key is not None
@defer.inlineCallbacks
def build(self, prev_event_ids):
async def build(self, prev_event_ids):
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
Returns:
Deferred[FrozenEvent]
FrozenEvent
"""
state_ids = yield defer.ensureDeferred(
self._state.get_current_state_ids(self.room_id, prev_event_ids)
state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids
)
auth_ids = yield self._auth.compute_auth_events(self, state_ids)
auth_ids = await self._auth.compute_auth_events(self, state_ids)
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
auth_events = yield self._store.add_event_hashes(auth_ids)
prev_events = yield self._store.add_event_hashes(prev_event_ids)
auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
else:
auth_events = auth_ids
prev_events = prev_event_ids
old_depth = yield self._store.get_max_depth_of(prev_event_ids)
old_depth = await self._store.get_max_depth_of(prev_event_ids)
depth = old_depth + 1
# we cap depth of generated events, to ensure that they are not
+22 -24
View File
@@ -12,17 +12,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import TYPE_CHECKING, Optional, Union
import attr
from frozendict import frozendict
from twisted.internet import defer
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.types import StateMap
if TYPE_CHECKING:
from synapse.storage.data_stores.main import DataStore
@attr.s(slots=True)
class EventContext:
@@ -129,8 +131,7 @@ class EventContext:
delta_ids=delta_ids,
)
@defer.inlineCallbacks
def serialize(self, event, store):
async def serialize(self, event: EventBase, store: "DataStore") -> dict:
"""Converts self to a type that can be serialized as JSON, and then
deserialized by `deserialize`
@@ -146,7 +147,7 @@ class EventContext:
# the prev_state_ids, so if we're a state event we include the event
# id that we replaced in the state.
if event.is_state():
prev_state_ids = yield self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids()
prev_state_id = prev_state_ids.get((event.type, event.state_key))
else:
prev_state_id = None
@@ -214,8 +215,7 @@ class EventContext:
return self._state_group
@defer.inlineCallbacks
def get_current_state_ids(self):
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -224,32 +224,31 @@ class EventContext:
``rejected`` is set.
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
is None, which happens when the associated event is an outlier.
Returns None if state_group is None, which happens when the associated
event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
if self.rejected:
raise RuntimeError("Attempt to access state_ids of rejected event")
yield self._ensure_fetched()
await self._ensure_fetched()
return self._current_state_ids
@defer.inlineCallbacks
def get_prev_state_ids(self):
async def get_prev_state_ids(self):
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Returns:
Deferred[dict[(str, str), str]|None]: Returns None if state_group
dict[(str, str), str]|None: Returns None if state_group
is None, which happens when the associated event is an outlier.
Maps a (type, state_key) to the event ID of the state event matching
this tuple.
"""
yield self._ensure_fetched()
await self._ensure_fetched()
return self._prev_state_ids
def get_cached_current_state_ids(self):
@@ -269,8 +268,8 @@ class EventContext:
return self._current_state_ids
def _ensure_fetched(self):
return defer.succeed(None)
async def _ensure_fetched(self):
return None
@attr.s(slots=True)
@@ -303,21 +302,20 @@ class _AsyncEventContextImpl(EventContext):
_event_state_key = attr.ib(default=None)
_fetching_state_deferred = attr.ib(default=None)
def _ensure_fetched(self):
async def _ensure_fetched(self):
if not self._fetching_state_deferred:
self._fetching_state_deferred = run_in_background(self._fill_out_state)
return make_deferred_yieldable(self._fetching_state_deferred)
return await make_deferred_yieldable(self._fetching_state_deferred)
@defer.inlineCallbacks
def _fill_out_state(self):
async def _fill_out_state(self):
"""Called to populate the _current_state_ids and _prev_state_ids
attributes by loading from the database.
"""
if self.state_group is None:
return
self._current_state_ids = yield self._storage.state.get_state_ids_for_group(
self._current_state_ids = await self._storage.state.get_state_ids_for_group(
self.state_group
)
if self._event_state_key is not None:
+22 -15
View File
@@ -15,8 +15,9 @@
from typing import Callable
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.module_api import ModuleApi
from synapse.types import StateMap
from synapse.types import Requester, StateMap
class ThirdPartyEventRules(object):
@@ -42,15 +43,17 @@ class ThirdPartyEventRules(object):
config=config, module_api=ModuleApi(hs, hs.get_auth_handler()),
)
async def check_event_allowed(self, event, context):
async def check_event_allowed(
self, event: EventBase, context: EventContext
) -> bool:
"""Check if a provided event should be allowed in the given context.
Args:
event (synapse.events.EventBase): The event to be checked.
context (synapse.events.snapshot.EventContext): The context of the event.
event: The event to be checked.
context: The context of the event.
Returns:
defer.Deferred[bool]: True if the event should be allowed, False if not.
True if the event should be allowed, False if not.
"""
if self.third_party_rules is None:
return True
@@ -65,17 +68,19 @@ class ThirdPartyEventRules(object):
ret = await self.third_party_rules.check_event_allowed(event, state_events)
return ret
async def on_create_room(self, requester, config, is_requester_admin):
async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool
) -> bool:
"""Intercept requests to create room to allow, deny or update the
request config.
Args:
requester (Requester)
config (dict): The creation config from the client.
is_requester_admin (bool): If the requester is an admin
requester
config: The creation config from the client.
is_requester_admin: If the requester is an admin
Returns:
defer.Deferred[bool]: Whether room creation is allowed or denied.
Whether room creation is allowed or denied.
"""
if self.third_party_rules is None:
@@ -86,16 +91,18 @@ class ThirdPartyEventRules(object):
)
return ret
async def check_threepid_can_be_invited(self, medium, address, room_id):
async def check_threepid_can_be_invited(
self, medium: str, address: str, room_id: str
) -> bool:
"""Check if a provided 3PID can be invited in the given room.
Args:
medium (str): The 3PID's medium.
address (str): The 3PID's address.
room_id (str): The room we want to invite the threepid to.
medium: The 3PID's medium.
address: The 3PID's address.
room_id: The room we want to invite the threepid to.
Returns:
defer.Deferred[bool], True if the 3PID can be invited, False if not.
True if the 3PID can be invited, False if not.
"""
if self.third_party_rules is None:
+7 -8
View File
@@ -18,8 +18,6 @@ from typing import Any, Mapping, Union
from frozendict import frozendict
from twisted.internet import defer
from synapse.api.constants import EventTypes, RelationTypes
from synapse.api.errors import Codes, SynapseError
from synapse.api.room_versions import RoomVersion
@@ -337,8 +335,9 @@ class EventClientSerializer(object):
hs.config.experimental_msc1849_support_enabled
)
@defer.inlineCallbacks
def serialize_event(self, event, time_now, bundle_aggregations=True, **kwargs):
async def serialize_event(
self, event, time_now, bundle_aggregations=True, **kwargs
):
"""Serializes a single event.
Args:
@@ -348,7 +347,7 @@ class EventClientSerializer(object):
**kwargs: Arguments to pass to `serialize_event`
Returns:
Deferred[dict]: The serialized event
dict: The serialized event
"""
# To handle the case of presence events and the like
if not isinstance(event, EventBase):
@@ -363,8 +362,8 @@ class EventClientSerializer(object):
if not event.internal_metadata.is_redacted() and (
self.experimental_msc1849_support_enabled and bundle_aggregations
):
annotations = yield self.store.get_aggregation_groups_for_event(event_id)
references = yield self.store.get_relations_for_event(
annotations = await self.store.get_aggregation_groups_for_event(event_id)
references = await self.store.get_relations_for_event(
event_id, RelationTypes.REFERENCE, direction="f"
)
@@ -378,7 +377,7 @@ class EventClientSerializer(object):
edit = None
if event.type == EventTypes.Message:
edit = yield self.store.get_applicable_edit(event_id)
edit = await self.store.get_applicable_edit(event_id)
if edit:
# If there is an edit replace the content, preserving existing
+11 -14
View File
@@ -41,8 +41,6 @@ from typing import Tuple
from signedjson.sign import sign_json
from twisted.internet import defer
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import get_domain_from_id
@@ -72,8 +70,9 @@ class GroupAttestationSigning(object):
self.server_name = hs.hostname
self.signing_key = hs.signing_key
@defer.inlineCallbacks
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
async def verify_attestation(
self, attestation, group_id, user_id, server_name=None
):
"""Verifies that the given attestation matches the given parameters.
An optional server_name can be supplied to explicitly set which server's
@@ -102,7 +101,7 @@ class GroupAttestationSigning(object):
if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
yield self.keyring.verify_json_for_server(
await self.keyring.verify_json_for_server(
server_name, attestation, now, "Group attestation"
)
@@ -142,8 +141,7 @@ class GroupAttestionRenewer(object):
self._start_renew_attestations, 30 * 60 * 1000
)
@defer.inlineCallbacks
def on_renew_attestation(self, group_id, user_id, content):
async def on_renew_attestation(self, group_id, user_id, content):
"""When a remote updates an attestation
"""
attestation = content["attestation"]
@@ -151,11 +149,11 @@ class GroupAttestionRenewer(object):
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
raise SynapseError(400, "Neither user not group are on this server")
yield self.attestations.verify_attestation(
await self.attestations.verify_attestation(
attestation, user_id=user_id, group_id=group_id
)
yield self.store.update_remote_attestion(group_id, user_id, attestation)
await self.store.update_remote_attestion(group_id, user_id, attestation)
return {}
@@ -172,8 +170,7 @@ class GroupAttestionRenewer(object):
now + UPDATE_ATTESTATION_TIME_MS
)
@defer.inlineCallbacks
def _renew_attestation(group_user: Tuple[str, str]):
async def _renew_attestation(group_user: Tuple[str, str]):
group_id, user_id = group_user
try:
if not self.is_mine_id(group_id):
@@ -186,16 +183,16 @@ class GroupAttestionRenewer(object):
user_id,
group_id,
)
yield self.store.remove_attestation_renewal(group_id, user_id)
await self.store.remove_attestation_renewal(group_id, user_id)
return
attestation = self.attestations.create_attestation(group_id, user_id)
yield self.transport_client.renew_group_attestation(
await self.transport_client.renew_group_attestation(
destination, group_id, user_id, content={"attestation": attestation}
)
yield self.store.update_attestation_renewal(
await self.store.update_attestation_renewal(
group_id, user_id, attestation
)
except (RequestSendFailed, HttpResponseException) as e:
+1 -1
View File
@@ -2480,7 +2480,7 @@ class FederationHandler(BaseHandler):
}
current_state_ids = await context.get_current_state_ids()
current_state_ids = dict(current_state_ids)
current_state_ids = dict(current_state_ids) # type: ignore
current_state_ids.update(state_updates)
+6 -2
View File
@@ -1007,7 +1007,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
if self.hs.config.limit_remote_rooms.enabled:
check_complexity = self.hs.config.limit_remote_rooms.enabled
if check_complexity and self.hs.config.limit_remote_rooms.admins_can_join:
check_complexity = not await self.hs.auth.is_server_admin(user)
if check_complexity:
# Fetch the room complexity
too_complex = await self._is_remote_room_too_complex(
room_id, remote_room_hosts
@@ -1030,7 +1034,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
# Check the room we just joined wasn't too large, if we didn't fetch the
# complexity of it before.
if self.hs.config.limit_remote_rooms.enabled:
if check_complexity:
if too_complex is False:
# We checked, and we're under the limit.
return event_id, stream_id
+3
View File
@@ -96,6 +96,9 @@ class SamlHandler:
relay_state=client_redirect_url
)
# Since SAML sessions timeout it is useful to log when they were created.
logger.info("Initiating a new SAML session: %s" % (reqid,))
now = self._clock.time_msec()
self._outstanding_requests_dict[reqid] = Saml2SessionData(
creation_time=now, ui_auth_session_id=ui_auth_session_id,
+6
View File
@@ -103,6 +103,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -1886,6 +1887,10 @@ class SyncHandler(object):
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str]
unread_count = await self.store.get_unread_message_count_for_user(
room_id, sync_config.user.to_string(),
)
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1894,6 +1899,7 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
unread_count=unread_count,
)
if room_sync or always_include:
+12 -4
View File
@@ -395,7 +395,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks
def post_json_get_json(self, uri, post_json, headers=None):
@@ -436,7 +438,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks
def get_json(self, uri, args={}, headers=None):
@@ -509,7 +513,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300:
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(response.code, response.phrase, body)
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
@defer.inlineCallbacks
def get_raw(self, uri, args={}, headers=None):
@@ -544,7 +550,9 @@ class SimpleHttpClient(object):
if 200 <= response.code < 300:
return body
else:
raise HttpResponseException(response.code, response.phrase, body)
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
+4 -3
View File
@@ -447,6 +447,7 @@ class MatrixFederationHttpClient(object):
).inc()
set_tag(tags.HTTP_STATUS_CODE, response.code)
response_phrase = response.phrase.decode("ascii", errors="replace")
if 200 <= response.code < 300:
logger.debug(
@@ -454,7 +455,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
response_phrase,
)
pass
else:
@@ -463,7 +464,7 @@ class MatrixFederationHttpClient(object):
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
response_phrase,
)
# :'(
# Update transactions table?
@@ -487,7 +488,7 @@ class MatrixFederationHttpClient(object):
)
body = None
e = HttpResponseException(response.code, response.phrase, body)
e = HttpResponseException(response.code, response_phrase, body)
# Retry if the error is a 429 (Too Many Requests),
# otherwise just raise a standard HttpResponseException
+2 -5
View File
@@ -15,8 +15,6 @@
import logging
from twisted.internet import defer
from synapse.util.metrics import Measure
from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
@@ -37,7 +35,6 @@ class ActionGenerator(object):
# event stream, so we just run the rules for a client with no profile
# tag (ie. we just need all the users).
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context):
async def handle_push_actions_for_event(self, event, context):
with Measure(self.clock, "action_for_event_by_user"):
yield self.bulk_evaluator.action_for_event_by_user(event, context)
await self.bulk_evaluator.action_for_event_by_user(event, context)
+25 -37
View File
@@ -19,8 +19,6 @@ from collections import namedtuple
from prometheus_client import Counter
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.event_auth import get_user_power_level
from synapse.state import POWER_KEY
@@ -70,8 +68,7 @@ class BulkPushRuleEvaluator(object):
resizable=False,
)
@defer.inlineCallbacks
def _get_rules_for_event(self, event, context):
async def _get_rules_for_event(self, event, context):
"""This gets the rules for all users in the room at the time of the event,
as well as the push rules for the invitee if the event is an invite.
@@ -79,19 +76,19 @@ class BulkPushRuleEvaluator(object):
dict of user_id -> push_rules
"""
room_id = event.room_id
rules_for_room = yield self._get_rules_for_room(room_id)
rules_for_room = await self._get_rules_for_room(room_id)
rules_by_user = yield rules_for_room.get_rules(event, context)
rules_by_user = await rules_for_room.get_rules(event, context)
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == "m.room.member" and event.content["membership"] == "invite":
invited = event.state_key
if invited and self.hs.is_mine_id(invited):
has_pusher = yield self.store.user_has_pusher(invited)
has_pusher = await self.store.user_has_pusher(invited)
if has_pusher:
rules_by_user = dict(rules_by_user)
rules_by_user[invited] = yield self.store.get_push_rules_for_user(
rules_by_user[invited] = await self.store.get_push_rules_for_user(
invited
)
@@ -114,20 +111,19 @@ class BulkPushRuleEvaluator(object):
self.room_push_rule_cache_metrics,
)
@defer.inlineCallbacks
def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = yield context.get_prev_state_ids()
async def _get_power_levels_and_sender_level(self, event, context):
prev_state_ids = await context.get_prev_state_ids()
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
pl_event = yield self.store.get_event(pl_event_id)
pl_event = await self.store.get_event(pl_event_id)
auth_events = {POWER_KEY: pl_event}
else:
auth_events_ids = yield self.auth.compute_auth_events(
auth_events_ids = await self.auth.compute_auth_events(
event, prev_state_ids, for_verification=False
)
auth_events = yield self.store.get_events(auth_events_ids)
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
sender_level = get_user_power_level(event.sender, auth_events)
@@ -136,23 +132,19 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
@defer.inlineCallbacks
def action_for_event_by_user(self, event, context):
async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table.
Returns:
Deferred
"""
rules_by_user = yield self._get_rules_for_event(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
room_members = yield self.store.get_joined_users_from_context(event, context)
room_members = await self.store.get_joined_users_from_context(event, context)
(
power_levels,
sender_power_level,
) = yield self._get_power_levels_and_sender_level(event, context)
) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
@@ -165,7 +157,7 @@ class BulkPushRuleEvaluator(object):
continue
if not event.is_state():
is_ignored = yield self.store.is_ignored_by(event.sender, uid)
is_ignored = await self.store.is_ignored_by(event.sender, uid)
if is_ignored:
continue
@@ -197,7 +189,7 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
yield self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -274,8 +266,7 @@ class RulesForRoom(object):
# to self around in the callback.
self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
@defer.inlineCallbacks
def get_rules(self, event, context):
async def get_rules(self, event, context):
"""Given an event context return the rules for all users who are
currently in the room.
"""
@@ -286,7 +277,7 @@ class RulesForRoom(object):
self.room_push_rule_cache_metrics.inc_hits()
return self.rules_by_user
with (yield self.linearizer.queue(())):
with (await self.linearizer.queue(())):
if state_group and self.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
@@ -304,9 +295,7 @@ class RulesForRoom(object):
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = yield defer.ensureDeferred(
context.get_current_state_ids()
)
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
push_rules_state_size_counter.inc(len(current_state_ids))
@@ -353,7 +342,7 @@ class RulesForRoom(object):
# If we have some memebr events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
yield self._update_rules_with_member_event_ids(
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
@@ -371,8 +360,7 @@ class RulesForRoom(object):
)
return ret_rules_by_user
@defer.inlineCallbacks
def _update_rules_with_member_event_ids(
async def _update_rules_with_member_event_ids(
self, ret_rules_by_user, member_event_ids, state_group, event
):
"""Update the partially filled rules_by_user dict by fetching rules for
@@ -388,7 +376,7 @@ class RulesForRoom(object):
"""
sequence = self.sequence
rows = yield self.store.get_membership_from_event_ids(member_event_ids.values())
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
@@ -410,7 +398,7 @@ class RulesForRoom(object):
logger.debug("Joined: %r", interested_in_user_ids)
if_users_with_pushers = yield self.store.get_if_users_have_pushers(
if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
@@ -420,7 +408,7 @@ class RulesForRoom(object):
logger.debug("With pushers: %r", user_ids)
users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
@@ -431,7 +419,7 @@ class RulesForRoom(object):
if uid in interested_in_user_ids:
user_ids.add(uid)
rules_by_user = yield self.store.bulk_get_push_rules(
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)
+25 -33
View File
@@ -17,7 +17,6 @@ import logging
from prometheus_client import Counter
from twisted.internet import defer
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from synapse.api.constants import EventTypes
@@ -128,12 +127,11 @@ class HttpPusher(object):
# but currently that's the only type of receipt anyway...
run_as_background_process("http_pusher.on_new_receipts", self._update_badge)
@defer.inlineCallbacks
def _update_badge(self):
async def _update_badge(self):
# XXX as per https://github.com/matrix-org/matrix-doc/issues/2627, this seems
# to be largely redundant. perhaps we can remove it.
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
yield self._send_badge(badge)
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
await self._send_badge(badge)
def on_timer(self):
self._start_processing()
@@ -152,8 +150,7 @@ class HttpPusher(object):
run_as_background_process("httppush.process", self._process)
@defer.inlineCallbacks
def _process(self):
async def _process(self):
# we should never get here if we are already processing
assert not self._is_processing
@@ -164,7 +161,7 @@ class HttpPusher(object):
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
await self._unsafe_process()
except Exception:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
@@ -172,8 +169,7 @@ class HttpPusher(object):
finally:
self._is_processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
async def _unsafe_process(self):
"""
Looks for unset notifications and dispatch them, in order
Never call this directly: use _process which will only allow this to
@@ -181,7 +177,7 @@ class HttpPusher(object):
"""
fn = self.store.get_unread_push_actions_for_user_in_range_for_http
unprocessed = yield fn(
unprocessed = await fn(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
@@ -203,13 +199,13 @@ class HttpPusher(object):
"app_display_name": self.app_display_name,
},
):
processed = yield self._process_one(push_action)
processed = await self._process_one(push_action)
if processed:
http_push_processed_counter.inc()
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering_and_success(
pusher_still_exists = await self.store.update_pusher_last_stream_ordering_and_success(
self.app_id,
self.pushkey,
self.user_id,
@@ -224,14 +220,14 @@ class HttpPusher(object):
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
http_push_failed_counter.inc()
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
@@ -250,7 +246,7 @@ class HttpPusher(object):
)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action["stream_ordering"]
pusher_still_exists = yield self.store.update_pusher_last_stream_ordering(
pusher_still_exists = await self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
@@ -263,7 +259,7 @@ class HttpPusher(object):
return
self.failing_since = None
yield self.store.update_pusher_failing_since(
await self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id, self.failing_since
)
else:
@@ -276,18 +272,17 @@ class HttpPusher(object):
)
break
@defer.inlineCallbacks
def _process_one(self, push_action):
async def _process_one(self, push_action):
if "notify" not in push_action["actions"]:
return True
tweaks = push_rule_evaluator.tweaks_for_actions(push_action["actions"])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
badge = await push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action["event_id"], allow_none=True)
event = await self.store.get_event(push_action["event_id"], allow_none=True)
if event is None:
return True # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
rejected = await self.dispatch_push(event, tweaks, badge)
if rejected is False:
return False
@@ -301,11 +296,10 @@ class HttpPusher(object):
)
else:
logger.info("Pushkey %s was rejected: removing", pk)
yield self.hs.remove_pusher(self.app_id, pk, self.user_id)
await self.hs.remove_pusher(self.app_id, pk, self.user_id)
return True
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
async def _build_notification_dict(self, event, tweaks, badge):
priority = "low"
if (
event.type == EventTypes.Encrypted
@@ -335,7 +329,7 @@ class HttpPusher(object):
}
return d
ctx = yield push_tools.get_context_for_event(
ctx = await push_tools.get_context_for_event(
self.storage, self.state_handler, event, self.user_id
)
@@ -377,13 +371,12 @@ class HttpPusher(object):
return d
@defer.inlineCallbacks
def dispatch_push(self, event, tweaks, badge):
notification_dict = yield self._build_notification_dict(event, tweaks, badge)
async def dispatch_push(self, event, tweaks, badge):
notification_dict = await self._build_notification_dict(event, tweaks, badge)
if not notification_dict:
return []
try:
resp = yield self.http_client.post_json_get_json(
resp = await self.http_client.post_json_get_json(
self.url, notification_dict
)
except Exception as e:
@@ -400,8 +393,7 @@ class HttpPusher(object):
rejected = resp["rejected"]
return rejected
@defer.inlineCallbacks
def _send_badge(self, badge):
async def _send_badge(self, badge):
"""
Args:
badge (int): number of unread messages
@@ -424,7 +416,7 @@ class HttpPusher(object):
}
}
try:
yield self.http_client.post_json_get_json(self.url, d)
await self.http_client.post_json_get_json(self.url, d)
http_badges_processed_counter.inc()
except Exception as e:
logger.warning(
+6 -9
View File
@@ -16,8 +16,6 @@
import logging
import re
from twisted.internet import defer
from synapse.api.constants import EventTypes
logger = logging.getLogger(__name__)
@@ -29,8 +27,7 @@ ALIAS_RE = re.compile(r"^#.*:.+$")
ALL_ALONE = "Empty Room"
@defer.inlineCallbacks
def calculate_room_name(
async def calculate_room_name(
store,
room_state_ids,
user_id,
@@ -53,7 +50,7 @@ def calculate_room_name(
"""
# does it have a name?
if (EventTypes.Name, "") in room_state_ids:
m_room_name = yield store.get_event(
m_room_name = await store.get_event(
room_state_ids[(EventTypes.Name, "")], allow_none=True
)
if m_room_name and m_room_name.content and m_room_name.content["name"]:
@@ -61,7 +58,7 @@ def calculate_room_name(
# does it have a canonical alias?
if (EventTypes.CanonicalAlias, "") in room_state_ids:
canon_alias = yield store.get_event(
canon_alias = await store.get_event(
room_state_ids[(EventTypes.CanonicalAlias, "")], allow_none=True
)
if (
@@ -81,7 +78,7 @@ def calculate_room_name(
my_member_event = None
if (EventTypes.Member, user_id) in room_state_ids:
my_member_event = yield store.get_event(
my_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, user_id)], allow_none=True
)
@@ -90,7 +87,7 @@ def calculate_room_name(
and my_member_event.content["membership"] == "invite"
):
if (EventTypes.Member, my_member_event.sender) in room_state_ids:
inviter_member_event = yield store.get_event(
inviter_member_event = await store.get_event(
room_state_ids[(EventTypes.Member, my_member_event.sender)],
allow_none=True,
)
@@ -107,7 +104,7 @@ def calculate_room_name(
# we're going to have to generate a name based on who's in the room,
# so find out who is in the room that isn't the user.
if EventTypes.Member in room_state_bytype_ids:
member_events = yield store.get_events(
member_events = await store.get_events(
list(room_state_bytype_ids[EventTypes.Member].values())
)
all_members = [
+11 -24
View File
@@ -13,53 +13,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.push.presentable_names import calculate_room_name, name_from_member_event
from synapse.storage import Storage
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites = yield store.get_invited_rooms_for_local_user(user_id)
joins = yield store.get_rooms_for_user(user_id)
my_receipts_by_room = yield store.get_receipts_for_user(user_id, "m.read")
async def get_badge_count(store, user_id):
invites = await store.get_invited_rooms_for_local_user(user_id)
joins = await store.get_rooms_for_user(user_id)
badge = len(invites)
for room_id in joins:
if room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[room_id]
notifs = yield (
store.get_unread_event_push_actions_by_room_for_user(
room_id, user_id, last_unread_event_id
)
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
unread_count = await store.get_unread_message_count_for_user(room_id, user_id)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if unread_count else 0
return badge
@defer.inlineCallbacks
def get_context_for_event(storage: Storage, state_handler, ev, user_id):
async def get_context_for_event(storage: Storage, state_handler, ev, user_id):
ctx = {}
room_state_ids = yield storage.state.get_state_ids_for_event(ev.event_id)
room_state_ids = await storage.state.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or
# a list of people in the room
name = yield calculate_room_name(
name = await calculate_room_name(
storage.main, room_state_ids, user_id, fallback_to_single_member=False
)
if name:
ctx["name"] = name
sender_state_event_id = room_state_ids[("m.room.member", ev.sender)]
sender_state_event = yield storage.main.get_event(sender_state_event_id)
sender_state_event = await storage.main.get_event(sender_state_event_id)
ctx["sender_display_name"] = name_from_member_event(sender_state_event)
return ctx
+29 -43
View File
@@ -19,8 +19,6 @@ from typing import TYPE_CHECKING, Dict, Union
from prometheus_client import Gauge
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.push import PusherConfigException
from synapse.push.emailpusher import EmailPusher
@@ -52,7 +50,7 @@ class PusherPool:
Note that it is expected that each pusher will have its own 'processing' loop which
will send out the notifications in the background, rather than blocking until the
notifications are sent; accordingly Pusher.on_started, Pusher.on_new_notifications and
Pusher.on_new_receipts are not expected to return deferreds.
Pusher.on_new_receipts are not expected to return awaitables.
"""
def __init__(self, hs: "HomeServer"):
@@ -79,8 +77,7 @@ class PusherPool:
return
run_as_background_process("start_pushers", self._start_pushers)
@defer.inlineCallbacks
def add_pusher(
async def add_pusher(
self,
user_id,
access_token,
@@ -96,7 +93,7 @@ class PusherPool:
"""Creates a new pusher and adds it to the pool
Returns:
Deferred[EmailPusher|HttpPusher]
EmailPusher|HttpPusher
"""
time_now_msec = self.clock.time_msec()
@@ -126,9 +123,9 @@ class PusherPool:
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = yield self.store.get_latest_push_action_stream_ordering()
last_stream_ordering = await self.store.get_latest_push_action_stream_ordering()
yield self.store.add_pusher(
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
kind=kind,
@@ -142,15 +139,14 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
pusher = yield self.start_pusher_by_id(app_id, pushkey, user_id)
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
return pusher
@defer.inlineCallbacks
def remove_pushers_by_app_id_and_pushkey_not_user(
async def remove_pushers_by_app_id_and_pushkey_not_user(
self, app_id, pushkey, not_user_id
):
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
to_remove = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
for p in to_remove:
if p["user_name"] != not_user_id:
logger.info(
@@ -159,10 +155,9 @@ class PusherPool:
pushkey,
p["user_name"],
)
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def remove_pushers_by_access_token(self, user_id, access_tokens):
async def remove_pushers_by_access_token(self, user_id, access_tokens):
"""Remove the pushers for a given user corresponding to a set of
access_tokens.
@@ -175,7 +170,7 @@ class PusherPool:
return
tokens = set(access_tokens)
for p in (yield self.store.get_pushers_by_user_id(user_id)):
for p in await self.store.get_pushers_by_user_id(user_id):
if p["access_token"] in tokens:
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
@@ -183,16 +178,15 @@ class PusherPool:
p["pushkey"],
p["user_name"],
)
yield self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
await self.remove_pusher(p["app_id"], p["pushkey"], p["user_name"])
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
async def on_new_notifications(self, min_stream_id, max_stream_id):
if not self.pushers:
# nothing to do here.
return
try:
users_affected = yield self.store.get_push_action_users_in_range(
users_affected = await self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
)
@@ -200,7 +194,7 @@ class PusherPool:
if u in self.pushers:
# Don't push if the user account has expired
if self._account_validity.enabled:
expired = yield self.store.is_account_expired(
expired = await self.store.is_account_expired(
u, self.clock.time_msec()
)
if expired:
@@ -212,8 +206,7 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
async def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
if not self.pushers:
# nothing to do here.
return
@@ -221,7 +214,7 @@ class PusherPool:
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
users_affected = yield self.store.get_users_sent_receipts_between(
users_affected = await self.store.get_users_sent_receipts_between(
min_stream_id - 1, max_stream_id
)
@@ -241,12 +234,11 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
def start_pusher_by_id(self, app_id, pushkey, user_id):
async def start_pusher_by_id(self, app_id, pushkey, user_id):
"""Look up the details for the given pusher, and start it
Returns:
Deferred[EmailPusher|HttpPusher|None]: The pusher started, if any
EmailPusher|HttpPusher|None: The pusher started, if any
"""
if not self._should_start_pushers:
return
@@ -254,7 +246,7 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
pusher_dict = None
for r in resultlist:
@@ -263,34 +255,29 @@ class PusherPool:
pusher = None
if pusher_dict:
pusher = yield self._start_pusher(pusher_dict)
pusher = await self._start_pusher(pusher_dict)
return pusher
@defer.inlineCallbacks
def _start_pushers(self):
async def _start_pushers(self) -> None:
"""Start all the pushers
Returns:
Deferred
"""
pushers = yield self.store.get_all_pushers()
pushers = await self.store.get_all_pushers()
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
yield concurrently_execute(self._start_pusher, pushers, 10)
await concurrently_execute(self._start_pusher, pushers, 10)
logger.info("Started pushers")
@defer.inlineCallbacks
def _start_pusher(self, pusherdict):
async def _start_pusher(self, pusherdict):
"""Start the given pusher
Args:
pusherdict (dict): dict with the values pulled from the db table
Returns:
Deferred[EmailPusher|HttpPusher]
EmailPusher|HttpPusher
"""
if not self._pusher_shard_config.should_handle(
self._instance_name, pusherdict["user_name"]
@@ -333,7 +320,7 @@ class PusherPool:
user_id = pusherdict["user_name"]
last_stream_ordering = pusherdict["last_stream_ordering"]
if last_stream_ordering:
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
have_notifs = await self.store.get_if_maybe_push_in_range_for_user(
user_id, last_stream_ordering
)
else:
@@ -345,8 +332,7 @@ class PusherPool:
return p
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
async def remove_pusher(self, app_id, pushkey, user_id):
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
@@ -358,6 +344,6 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)
+1 -1
View File
@@ -43,7 +43,7 @@ REQUIREMENTS = [
"jsonschema>=2.5.1",
"frozendict>=1",
"unpaddedbase64>=1.1.0",
"canonicaljson>=1.1.3",
"canonicaljson>=1.2.0",
# we use the type definitions added in signedjson 1.1.
"signedjson>=1.1.0",
"pynacl>=1.2.1",
+3 -1
View File
@@ -78,7 +78,9 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
"""
event_payloads = []
for event, context in event_and_contexts:
serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(
context.serialize(event, store)
)
event_payloads.append(
{
+1 -1
View File
@@ -77,7 +77,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
extra_users (list(UserID)): Any extra users to notify about event
"""
serialized_context = yield context.serialize(event, store)
serialized_context = yield defer.ensureDeferred(context.serialize(event, store))
payload = {
"event": event.get_pdu_json(),
+10 -1
View File
@@ -103,6 +103,14 @@ class DeleteRoomRestServlet(RestServlet):
Codes.BAD_JSON,
)
purge = content.get("purge", True)
if not isinstance(purge, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"Param 'purge' must be a boolean, if given",
Codes.BAD_JSON,
)
ret = await self.room_shutdown_handler.shutdown_room(
room_id=room_id,
new_room_user_id=content.get("new_room_user_id"),
@@ -113,7 +121,8 @@ class DeleteRoomRestServlet(RestServlet):
)
# Purge room
await self.pagination_handler.purge_room(room_id)
if purge:
await self.pagination_handler.purge_room(room_id)
return (200, ret)
+1
View File
@@ -426,6 +426,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
result["org.matrix.msc2654.unread_count"] = room.unread_count
return result
+5 -3
View File
@@ -17,7 +17,9 @@
import logging
import os
import urllib
from typing import Awaitable
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from synapse.api.errors import Codes, SynapseError, cs_error
@@ -240,14 +242,14 @@ class Responder(object):
held can be cleaned up.
"""
def write_to_consumer(self, consumer):
def write_to_consumer(self, consumer: IConsumer) -> Awaitable:
"""Stream response into consumer
Args:
consumer (IConsumer)
consumer: The consumer to stream into.
Returns:
Deferred: Resolves once the response has finished being written
Resolves once the response has finished being written
"""
pass
+61 -44
View File
@@ -18,10 +18,11 @@ import errno
import logging
import os
import shutil
from typing import Dict, Tuple
from typing import IO, Dict, Optional, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource
from synapse.api.errors import (
@@ -40,6 +41,7 @@ from synapse.util.stringutils import random_string
from ._base import (
FileInfo,
Responder,
get_filename_from_headers,
respond_404,
respond_with_responder,
@@ -135,19 +137,24 @@ class MediaRepository(object):
self.recently_accessed_locals.add(media_id)
async def create_content(
self, media_type, upload_name, content, content_length, auth_user
):
self,
media_type: str,
upload_name: str,
content: IO,
content_length: int,
auth_user: str,
) -> str:
"""Store uploaded content for a local user and return the mxc URL
Args:
media_type(str): The content type of the file
upload_name(str): The name of the file
media_type: The content type of the file
upload_name: The name of the file
content: A file like object that is the content to store
content_length(int): The length of the content
auth_user(str): The user_id of the uploader
content_length: The length of the content
auth_user: The user_id of the uploader
Returns:
Deferred[str]: The mxc url of the stored content
The mxc url of the stored content
"""
media_id = random_string(24)
@@ -170,19 +177,20 @@ class MediaRepository(object):
return "mxc://%s/%s" % (self.server_name, media_id)
async def get_local_media(self, request, media_id, name):
async def get_local_media(
self, request: Request, media_id: str, name: Optional[str]
) -> None:
"""Responds to reqests for local media, if exists, or returns 404.
Args:
request(twisted.web.http.Request)
media_id (str): The media ID of the content. (This is the same as
request: The incoming request.
media_id: The media ID of the content. (This is the same as
the file_id for local content.)
name (str|None): Optional name that, if specified, will be used as
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
media_info = await self.store.get_local_media(media_id)
if not media_info or media_info["quarantined_by"]:
@@ -203,20 +211,20 @@ class MediaRepository(object):
request, responder, media_type, media_length, upload_name
)
async def get_remote_media(self, request, server_name, media_id, name):
async def get_remote_media(
self, request: Request, server_name: str, media_id: str, name: Optional[str]
) -> None:
"""Respond to requests for remote media.
Args:
request(twisted.web.http.Request)
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
name (str|None): Optional name that, if specified, will be used as
request: The incoming request.
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
name: Optional name that, if specified, will be used as
the filename in the Content-Disposition header of the response.
Returns:
Deferred: Resolves once a response has successfully been written
to request
Resolves once a response has successfully been written to request
"""
if (
self.federation_domain_whitelist is not None
@@ -245,17 +253,16 @@ class MediaRepository(object):
else:
respond_404(request)
async def get_remote_media_info(self, server_name, media_id):
async def get_remote_media_info(self, server_name: str, media_id: str) -> dict:
"""Gets the media info associated with the remote file, downloading
if necessary.
Args:
server_name (str): Remote server_name where the media originated.
media_id (str): The media ID of the content (as defined by the
remote server).
server_name: Remote server_name where the media originated.
media_id: The media ID of the content (as defined by the remote server).
Returns:
Deferred[dict]: The media_info of the file
The media info of the file
"""
if (
self.federation_domain_whitelist is not None
@@ -278,7 +285,9 @@ class MediaRepository(object):
return media_info
async def _get_remote_media_impl(self, server_name, media_id):
async def _get_remote_media_impl(
self, server_name: str, media_id: str
) -> Tuple[Optional[Responder], dict]:
"""Looks for media in local cache, if not there then attempt to
download from remote server.
@@ -288,7 +297,7 @@ class MediaRepository(object):
remote server).
Returns:
Deferred[(Responder, media_info)]
A tuple of responder and the media info of the file.
"""
media_info = await self.store.get_cached_remote_media(server_name, media_id)
@@ -319,19 +328,21 @@ class MediaRepository(object):
responder = await self.media_storage.fetch_media(file_info)
return responder, media_info
async def _download_remote_file(self, server_name, media_id, file_id):
async def _download_remote_file(
self, server_name: str, media_id: str, file_id: str
) -> dict:
"""Attempt to download the remote file from the given server name,
using the given file_id as the local id.
Args:
server_name (str): Originating server
media_id (str): The media ID of the content (as defined by the
server_name: Originating server
media_id: The media ID of the content (as defined by the
remote server). This is different than the file_id, which is
locally generated.
file_id (str): Local file ID
file_id: Local file ID
Returns:
Deferred[MediaInfo]
The media info of the file.
"""
file_info = FileInfo(server_name=server_name, file_id=file_id)
@@ -549,25 +560,31 @@ class MediaRepository(object):
return output_path
async def _generate_thumbnails(
self, server_name, media_id, file_id, media_type, url_cache=False
):
self,
server_name: Optional[str],
media_id: str,
file_id: str,
media_type: str,
url_cache: bool = False,
) -> Optional[dict]:
"""Generate and store thumbnails for an image.
Args:
server_name (str|None): The server name if remote media, else None if local
media_id (str): The media ID of the content. (This is the same as
server_name: The server name if remote media, else None if local
media_id: The media ID of the content. (This is the same as
the file_id for local content)
file_id (str): Local file ID
media_type (str): The content type of the file
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
file_id: Local file ID
media_type: The content type of the file
url_cache: If we are thumbnailing images downloaded for the URL cache,
used exclusively by the url previewer
Returns:
Deferred[dict]: Dict with "width" and "height" keys of original image
Dict with "width" and "height" keys of original image or None if the
media cannot be thumbnailed.
"""
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
return None
input_path = await self.media_storage.ensure_media_is_in_local_cache(
FileInfo(server_name, file_id, url_cache=url_cache)
@@ -584,7 +601,7 @@ class MediaRepository(object):
m_height,
self.max_image_pixels,
)
return
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
+28 -24
View File
@@ -12,13 +12,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
import inspect
import logging
import os
import shutil
from typing import Optional
from typing import IO, TYPE_CHECKING, Any, Optional, Sequence
from twisted.protocols.basic import FileSender
@@ -26,6 +25,12 @@ from synapse.logging.context import defer_to_thread, make_deferred_yieldable
from synapse.util.file_consumer import BackgroundFileConsumer
from ._base import FileInfo, Responder
from .filepath import MediaFilePaths
if TYPE_CHECKING:
from synapse.server import HomeServer
from .storage_provider import StorageProvider
logger = logging.getLogger(__name__)
@@ -34,20 +39,25 @@ class MediaStorage(object):
"""Responsible for storing/fetching files from local sources.
Args:
hs (synapse.server.Homeserver)
local_media_directory (str): Base path where we store media on disk
filepaths (MediaFilePaths)
storage_providers ([StorageProvider]): List of StorageProvider that are
used to fetch and store files.
hs
local_media_directory: Base path where we store media on disk
filepaths
storage_providers: List of StorageProvider that are used to fetch and store files.
"""
def __init__(self, hs, local_media_directory, filepaths, storage_providers):
def __init__(
self,
hs: "HomeServer",
local_media_directory: str,
filepaths: MediaFilePaths,
storage_providers: Sequence["StorageProvider"],
):
self.hs = hs
self.local_media_directory = local_media_directory
self.filepaths = filepaths
self.storage_providers = storage_providers
async def store_file(self, source, file_info: FileInfo) -> str:
async def store_file(self, source: IO, file_info: FileInfo) -> str:
"""Write `source` to the on disk media store, and also any other
configured storage providers
@@ -69,7 +79,7 @@ class MediaStorage(object):
return fname
@contextlib.contextmanager
def store_into_file(self, file_info):
def store_into_file(self, file_info: FileInfo):
"""Context manager used to get a file like object to write into, as
described by file_info.
@@ -85,7 +95,7 @@ class MediaStorage(object):
error.
Args:
file_info (FileInfo): Info about the file to store
file_info: Info about the file to store
Example:
@@ -143,9 +153,9 @@ class MediaStorage(object):
return FileResponder(open(local_path, "rb"))
for provider in self.storage_providers:
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@@ -174,9 +184,9 @@ class MediaStorage(object):
os.makedirs(dirname)
for provider in self.storage_providers:
res = provider.fetch(path, file_info)
# Fetch is supposed to return an Awaitable, but guard against
# improper implementations.
res = provider.fetch(path, file_info) # type: Any
# Fetch is supposed to return an Awaitable[Responder], but guard
# against improper implementations.
if inspect.isawaitable(res):
res = await res
if res:
@@ -190,17 +200,11 @@ class MediaStorage(object):
raise Exception("file could not be found")
def _file_info_to_path(self, file_info):
def _file_info_to_path(self, file_info: FileInfo) -> str:
"""Converts file_info into a relative path.
The path is suitable for storing files under a directory, e.g. used to
store files on local FS under the base media repository directory.
Args:
file_info (FileInfo)
Returns:
str
"""
if file_info.url_cache:
if file_info.thumbnail:
@@ -231,16 +231,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og = await make_deferred_yieldable(defer.maybeDeferred(observable.observe))
respond_with_json_bytes(request, 200, og, send_cors=True)
async def _do_preview(self, url, user, ts):
async def _do_preview(self, url: str, user: str, ts: int) -> bytes:
"""Check the db, and download the URL and build a preview
Args:
url (str):
user (str):
ts (int):
url: The URL to preview.
user: The user requesting the preview.
ts: The timestamp requested for the preview.
Returns:
Deferred[bytes]: json-encoded og data
json-encoded og data
"""
# check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
+31 -31
View File
@@ -16,62 +16,62 @@
import logging
import os
import shutil
from twisted.internet import defer
from typing import Optional
from synapse.config._base import Config
from synapse.logging.context import defer_to_thread, run_in_background
from ._base import FileInfo, Responder
from .media_storage import FileResponder
logger = logging.getLogger(__name__)
class StorageProvider(object):
class StorageProvider:
"""A storage provider is a service that can store uploaded media and
retrieve them.
"""
def store_file(self, path, file_info):
async def store_file(self, path: str, file_info: FileInfo):
"""Store the file described by file_info. The actual contents can be
retrieved by reading the file in file_info.upload_path.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
Returns:
Deferred
path: Relative path of file in local cache
file_info: The metadata of the file.
"""
pass
def fetch(self, path, file_info):
async def fetch(self, path: str, file_info: FileInfo) -> Optional[Responder]:
"""Attempt to fetch the file described by file_info and stream it
into writer.
Args:
path (str): Relative path of file in local cache
file_info (FileInfo)
path: Relative path of file in local cache
file_info: The metadata of the file.
Returns:
Deferred(Responder): Returns a Responder if the provider has the file,
otherwise returns None.
Returns a Responder if the provider has the file, otherwise returns None.
"""
pass
class StorageProviderWrapper(StorageProvider):
"""Wraps a storage provider and provides various config options
Args:
backend (StorageProvider)
store_local (bool): Whether to store new local files or not.
store_synchronous (bool): Whether to wait for file to be successfully
backend: The storage provider to wrap.
store_local: Whether to store new local files or not.
store_synchronous: Whether to wait for file to be successfully
uploaded, or todo the upload in the background.
store_remote (bool): Whether remote media should be uploaded
store_remote: Whether remote media should be uploaded
"""
def __init__(self, backend, store_local, store_synchronous, store_remote):
def __init__(
self,
backend: StorageProvider,
store_local: bool,
store_synchronous: bool,
store_remote: bool,
):
self.backend = backend
self.store_local = store_local
self.store_synchronous = store_synchronous
@@ -80,15 +80,15 @@ class StorageProviderWrapper(StorageProvider):
def __str__(self):
return "StorageProviderWrapper[%s]" % (self.backend,)
def store_file(self, path, file_info):
async def store_file(self, path, file_info):
if not file_info.server_name and not self.store_local:
return defer.succeed(None)
return None
if file_info.server_name and not self.store_remote:
return defer.succeed(None)
return None
if self.store_synchronous:
return self.backend.store_file(path, file_info)
return await self.backend.store_file(path, file_info)
else:
# TODO: Handle errors.
def store():
@@ -98,10 +98,10 @@ class StorageProviderWrapper(StorageProvider):
logger.exception("Error storing file")
run_in_background(store)
return defer.succeed(None)
return None
def fetch(self, path, file_info):
return self.backend.fetch(path, file_info)
async def fetch(self, path, file_info):
return await self.backend.fetch(path, file_info)
class FileStorageProviderBackend(StorageProvider):
@@ -120,7 +120,7 @@ class FileStorageProviderBackend(StorageProvider):
def __str__(self):
return "FileStorageProviderBackend[%s]" % (self.base_directory,)
def store_file(self, path, file_info):
async def store_file(self, path, file_info):
"""See StorageProvider.store_file"""
primary_fname = os.path.join(self.cache_directory, path)
@@ -130,11 +130,11 @@ class FileStorageProviderBackend(StorageProvider):
if not os.path.exists(dirname):
os.makedirs(dirname)
return defer_to_thread(
return await defer_to_thread(
self.hs.get_reactor(), shutil.copyfile, primary_fname, backup_fname
)
def fetch(self, path, file_info):
async def fetch(self, path, file_info):
"""See StorageProvider.fetch"""
backup_fname = os.path.join(self.base_directory, path)
@@ -172,6 +172,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self.get_latest_event_ids_in_room.invalidate((room_id,))
self.get_unread_message_count_for_user.invalidate_many((room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many((room_id,))
if not backfilled:
@@ -411,7 +411,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
_get_if_maybe_push_in_range_for_user_txn,
)
def add_push_actions_to_staging(self, event_id, user_id_actions):
async def add_push_actions_to_staging(self, event_id, user_id_actions):
"""Add the push actions for the event to the push action staging area.
Args:
@@ -457,7 +457,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
),
)
return self.db.runInteraction(
return await self.db.runInteraction(
"add_push_actions_to_staging", _add_push_actions_to_staging_txn
)
+47 -1
View File
@@ -53,6 +53,47 @@ event_counter = Counter(
["type", "origin_type", "origin_entity"],
)
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
def encode_json(json_object):
"""
@@ -196,6 +237,10 @@ class PersistEventsStore:
event_counter.labels(event.type, origin_type, origin_entity).inc()
self.store.get_unread_message_count_for_user.invalidate_many(
(event.room_id,),
)
for room_id, new_state in current_state_for_room.items():
self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -817,8 +862,9 @@ class PersistEventsStore:
"contains_url": (
"url" in event.content and isinstance(event.content["url"], str)
),
"count_as_unread": should_count_as_unread(event, context),
}
for event, _ in events_and_contexts
for event, context in events_and_contexts
],
)
@@ -41,9 +41,15 @@ from synapse.replication.tcp.streams import BackfillStream
from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import Database
from synapse.storage.types import Cursor
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import (
Cache,
_CacheContext,
cached,
cachedInlineCallbacks,
)
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -1358,6 +1364,84 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
@cached(tree=True, cache_context=True)
async def get_unread_message_count_for_user(
self, room_id: str, user_id: str, cache_context: _CacheContext,
) -> int:
"""Retrieve the count of unread messages for the given room and user.
Args:
room_id: The ID of the room to count unread messages in.
user_id: The ID of the user to count unread messages for.
Returns:
The number of unread messages for the given user in the given room.
"""
with Measure(self._clock, "get_unread_message_count_for_user"):
last_read_event_id = await self.get_last_receipt_event_id_for_user(
user_id=user_id,
room_id=room_id,
receipt_type="m.read",
on_invalidate=cache_context.invalidate,
)
return await self.db.runInteraction(
"get_unread_message_count_for_user",
self._get_unread_message_count_for_user_txn,
user_id,
room_id,
last_read_event_id,
)
def _get_unread_message_count_for_user_txn(
self,
txn: Cursor,
user_id: str,
room_id: str,
last_read_event_id: Optional[str],
) -> int:
if last_read_event_id:
# Get the stream ordering for the last read event.
stream_ordering = self.db.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"room_id": room_id, "event_id": last_read_event_id},
retcol="stream_ordering",
)
else:
# If there's no read receipt for that room, it probably means the user hasn't
# opened it yet, in which case use the stream ID of their join event.
# We can't just set it to 0 otherwise messages from other local users from
# before this user joined will be counted as well.
txn.execute(
"""
SELECT stream_ordering FROM local_current_membership
LEFT JOIN events USING (event_id, room_id)
WHERE membership = 'join'
AND user_id = ?
AND room_id = ?
""",
(user_id, room_id),
)
row = txn.fetchone()
if row is None:
return 0
stream_ordering = row[0]
# Count the messages that qualify as unread after the stream ordering we've just
# retrieved.
sql = """
SELECT COUNT(*) FROM events
WHERE sender != ? AND room_id = ? AND stream_ordering > ? AND count_as_unread
"""
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
return row[0] if row else 0
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
@@ -62,6 +62,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
# event_json
# event_push_actions
# event_reference_hashes
# event_relations
# event_search
# event_to_state_groups
# events
@@ -209,6 +210,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_edges",
"event_forward_extremities",
"event_reference_hashes",
"event_relations",
"event_search",
"rejections",
):
@@ -0,0 +1,18 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Store a boolean value in the events table for whether the event should be counted in
-- the unread_count property of sync responses.
ALTER TABLE events ADD COLUMN count_as_unread BOOLEAN;
+10 -8
View File
@@ -49,11 +49,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
from synapse.storage.types import Connection, Cursor
from synapse.types import Collection
logger = logging.getLogger(__name__)
# python 3 does not have a maximum int value
MAX_TXN_ID = 2 ** 63 - 1
logger = logging.getLogger(__name__)
sql_logger = logging.getLogger("synapse.storage.SQL")
transaction_logger = logging.getLogger("synapse.storage.txn")
perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -233,7 +233,7 @@ class LoggingTransaction:
try:
return func(sql, *args)
except Exception as e:
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
raise
finally:
secs = time.time() - start
@@ -419,7 +419,7 @@ class Database(object):
except self.engine.module.OperationalError as e:
# This can happen if the database disappears mid
# transaction.
logger.warning(
transaction_logger.warning(
"[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
)
if i < N:
@@ -427,18 +427,20 @@ class Database(object):
try:
conn.rollback()
except self.engine.module.Error as e1:
logger.warning("[TXN EROLL] {%s} %s", name, e1)
transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
continue
raise
except self.engine.module.DatabaseError as e:
if self.engine.is_deadlock(e):
logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
transaction_logger.warning(
"[TXN DEADLOCK] {%s} %d/%d", name, i, N
)
if i < N:
i += 1
try:
conn.rollback()
except self.engine.module.Error as e1:
logger.warning(
transaction_logger.warning(
"[TXN EROLL] {%s} %s", name, e1,
)
continue
@@ -478,7 +480,7 @@ class Database(object):
# [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
cursor.close()
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", name, e)
transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
raise
finally:
end = monotonic_time()
+18 -22
View File
@@ -25,7 +25,7 @@ from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import FrozenEvent
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -192,12 +192,11 @@ class EventsPersistenceStorage(object):
self._event_persist_queue = _EventPeristenceQueue()
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(
async def persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
) -> int:
"""
Write events to the database
Args:
@@ -207,7 +206,7 @@ class EventsPersistenceStorage(object):
which might update the current state etc.
Returns:
Deferred[int]: the stream ordering of the latest persisted event
the stream ordering of the latest persisted event
"""
partitioned = {}
for event, ctx in events_and_contexts:
@@ -223,22 +222,19 @@ class EventsPersistenceStorage(object):
for room_id in partitioned:
self._maybe_start_persisting(room_id)
yield make_deferred_yieldable(
await make_deferred_yieldable(
defer.gatherResults(deferreds, consumeErrors=True)
)
max_persisted_id = yield self.main_store.get_current_events_token()
return self.main_store.get_current_events_token()
return max_persisted_id
@defer.inlineCallbacks
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
async def persist_event(
self, event: EventBase, context: EventContext, backfilled: bool = False
) -> Tuple[int, int]:
"""
Returns:
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
and the stream ordering of the latest persisted event
The stream ordering of `event`, and the stream ordering of the
latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)], backfilled=backfilled
@@ -246,9 +242,9 @@ class EventsPersistenceStorage(object):
self._maybe_start_persisting(event.room_id)
yield make_deferred_yieldable(deferred)
await make_deferred_yieldable(deferred)
max_persisted_id = yield self.main_store.get_current_events_token()
max_persisted_id = self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id: str):
@@ -262,7 +258,7 @@ class EventsPersistenceStorage(object):
async def _persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
@@ -439,7 +435,7 @@ class EventsPersistenceStorage(object):
async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]],
event_contexts: List[Tuple[EventBase, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
@@ -497,7 +493,7 @@ class EventsPersistenceStorage(object):
async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]],
events_context: List[Tuple[EventBase, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
@@ -683,7 +679,7 @@ class EventsPersistenceStorage(object):
async def _is_server_still_joined(
self,
room_id: str,
ev_ctx_rm: List[Tuple[FrozenEvent, EventContext]],
ev_ctx_rm: List[Tuple[EventBase, EventContext]],
delta: DeltaState,
current_state: Optional[StateMap[str]],
potentially_left_users: Set[str],
+18 -20
View File
@@ -15,8 +15,7 @@
import itertools
import logging
from twisted.internet import defer
from typing import Set
logger = logging.getLogger(__name__)
@@ -28,49 +27,48 @@ class PurgeEventsStorage(object):
def __init__(self, hs, stores):
self.stores = stores
@defer.inlineCallbacks
def purge_room(self, room_id: str):
async def purge_room(self, room_id: str):
"""Deletes all record of a room
"""
state_groups_to_delete = yield self.stores.main.purge_room(room_id)
yield self.stores.state.purge_room_state(room_id, state_groups_to_delete)
state_groups_to_delete = await self.stores.main.purge_room(room_id)
await self.stores.state.purge_room_state(room_id, state_groups_to_delete)
@defer.inlineCallbacks
def purge_history(self, room_id, token, delete_local_events):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> None:
"""Deletes room history before a certain point
Args:
room_id (str):
room_id: The room ID
token (str): A topological token to delete events before
token: A topological token to delete events before
delete_local_events (bool):
delete_local_events:
if True, we will delete local events as well as remote ones
(instead of just marking them as outliers and deleting their
state groups).
"""
state_groups = yield self.stores.main.purge_history(
state_groups = await self.stores.main.purge_history(
room_id, token, delete_local_events
)
logger.info("[purge] finding state groups that can be deleted")
sg_to_delete = yield self._find_unreferenced_groups(state_groups)
sg_to_delete = await self._find_unreferenced_groups(state_groups)
yield self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
await self.stores.state.purge_unreferenced_state_groups(room_id, sg_to_delete)
@defer.inlineCallbacks
def _find_unreferenced_groups(self, state_groups):
async def _find_unreferenced_groups(self, state_groups: Set[int]) -> Set[int]:
"""Used when purging history to figure out which state groups can be
deleted.
Args:
state_groups (set[int]): Set of state groups referenced by events
state_groups: Set of state groups referenced by events
that are going to be deleted.
Returns:
Deferred[set[int]] The set of state groups that can be deleted.
The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
@@ -93,7 +91,7 @@ class PurgeEventsStorage(object):
current_search = set(itertools.islice(next_to_search, 100))
next_to_search -= current_search
referenced = yield self.stores.main.get_referenced_state_groups(
referenced = await self.stores.main.get_referenced_state_groups(
current_search
)
referenced_groups |= referenced
@@ -102,7 +100,7 @@ class PurgeEventsStorage(object):
# groups that are referenced.
current_search -= referenced
edges = yield self.stores.state.get_previous_state_groups(current_search)
edges = await self.stores.state.get_previous_state_groups(current_search)
prevs = set(edges.values())
# We don't bother re-handling groups we've already seen
+109 -98
View File
@@ -14,13 +14,12 @@
# limitations under the License.
import logging
from typing import Iterable, List, TypeVar
from typing import Dict, Iterable, List, Optional, Set, Tuple, TypeVar
import attr
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.types import StateMap
logger = logging.getLogger(__name__)
@@ -34,16 +33,16 @@ class StateFilter(object):
"""A filter used when querying for state.
Attributes:
types (dict[str, set[str]|None]): Map from type to set of state keys (or
None). This specifies which state_keys for the given type to fetch
from the DB. If None then all events with that type are fetched. If
the set is empty then no events with that type are fetched.
include_others (bool): Whether to fetch events with types that do not
types: Map from type to set of state keys (or None). This specifies
which state_keys for the given type to fetch from the DB. If None
then all events with that type are fetched. If the set is empty
then no events with that type are fetched.
include_others: Whether to fetch events with types that do not
appear in `types`.
"""
types = attr.ib()
include_others = attr.ib(default=False)
types = attr.ib(type=Dict[str, Optional[Set[str]]])
include_others = attr.ib(default=False, type=bool)
def __attrs_post_init__(self):
# If `include_others` is set we canonicalise the filter by removing
@@ -52,36 +51,35 @@ class StateFilter(object):
self.types = {k: v for k, v in self.types.items() if v is not None}
@staticmethod
def all():
def all() -> "StateFilter":
"""Creates a filter that fetches everything.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=True)
@staticmethod
def none():
def none() -> "StateFilter":
"""Creates a filter that fetches nothing.
Returns:
StateFilter
The new state filter.
"""
return StateFilter(types={}, include_others=False)
@staticmethod
def from_types(types):
def from_types(types: Iterable[Tuple[str, Optional[str]]]) -> "StateFilter":
"""Creates a filter that only fetches the given types
Args:
types (Iterable[tuple[str, str|None]]): A list of type and state
keys to fetch. A state_key of None fetches everything for
that type
types: A list of type and state keys to fetch. A state_key of None
fetches everything for that type
Returns:
StateFilter
The new state filter.
"""
type_dict = {}
type_dict = {} # type: Dict[str, Optional[Set[str]]]
for typ, s in types:
if typ in type_dict:
if type_dict[typ] is None:
@@ -91,24 +89,24 @@ class StateFilter(object):
type_dict[typ] = None
continue
type_dict.setdefault(typ, set()).add(s)
type_dict.setdefault(typ, set()).add(s) # type: ignore
return StateFilter(types=type_dict)
@staticmethod
def from_lazy_load_member_list(members):
def from_lazy_load_member_list(members: Iterable[str]) -> "StateFilter":
"""Creates a filter that returns all non-member events, plus the member
events for the given users
Args:
members (iterable[str]): Set of user IDs
members: Set of user IDs
Returns:
StateFilter
The new state filter
"""
return StateFilter(types={EventTypes.Member: set(members)}, include_others=True)
def return_expanded(self):
def return_expanded(self) -> "StateFilter":
"""Creates a new StateFilter where type wild cards have been removed
(except for memberships). The returned filter is a superset of the
current one, i.e. anything that passes the current filter will pass
@@ -130,7 +128,7 @@ class StateFilter(object):
return all non-member events
Returns:
StateFilter
The new state filter.
"""
if self.is_full():
@@ -167,7 +165,7 @@ class StateFilter(object):
include_others=True,
)
def make_sql_filter_clause(self):
def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
"""Converts the filter to an SQL clause.
For example:
@@ -179,13 +177,12 @@ class StateFilter(object):
Returns:
tuple[str, list]: The SQL string (may be empty) and arguments. An
empty SQL string is returned when the filter matches everything
(i.e. is "full").
The SQL string (may be empty) and arguments. An empty SQL string is
returned when the filter matches everything (i.e. is "full").
"""
where_clause = ""
where_args = []
where_args = [] # type: List[str]
if self.is_full():
return where_clause, where_args
@@ -221,7 +218,7 @@ class StateFilter(object):
return where_clause, where_args
def max_entries_returned(self):
def max_entries_returned(self) -> Optional[int]:
"""Returns the maximum number of entries this filter will return if
known, otherwise returns None.
@@ -260,33 +257,33 @@ class StateFilter(object):
return filtered_state
def is_full(self):
def is_full(self) -> bool:
"""Whether this filter fetches everything or not
Returns:
bool
True if the filter fetches everything.
"""
return self.include_others and not self.types
def has_wildcards(self):
def has_wildcards(self) -> bool:
"""Whether the filter includes wildcards or is attempting to fetch
specific state.
Returns:
bool
True if the filter includes wildcards.
"""
return self.include_others or any(
state_keys is None for state_keys in self.types.values()
)
def concrete_types(self):
def concrete_types(self) -> List[Tuple[str, str]]:
"""Returns a list of concrete type/state_keys (i.e. not None) that
will be fetched. This will be a complete list if `has_wildcards`
returns False, but otherwise will be a subset (or even empty).
Returns:
list[tuple[str,str]]
A list of type/state_keys tuples.
"""
return [
(t, s)
@@ -295,7 +292,7 @@ class StateFilter(object):
for s in state_keys
]
def get_member_split(self):
def get_member_split(self) -> Tuple["StateFilter", "StateFilter"]:
"""Return the filter split into two: one which assumes it's exclusively
matching against member state, and one which assumes it's matching
against non member state.
@@ -307,7 +304,7 @@ class StateFilter(object):
state caches).
Returns:
tuple[StateFilter, StateFilter]: The member and non member filters
The member and non member filters
"""
if EventTypes.Member in self.types:
@@ -340,6 +337,9 @@ class StateGroupStorage(object):
"""Given a state group try to return a previous group and a delta between
the old and the new.
Args:
state_group: The state group used to retrieve state deltas.
Returns:
Deferred[Tuple[Optional[int], Optional[StateMap[str]]]]:
(prev_group, delta_ids)
@@ -347,55 +347,59 @@ class StateGroupStorage(object):
return self.stores.state.get_state_group_delta(state_group)
@defer.inlineCallbacks
def get_state_groups_ids(self, _room_id, event_ids):
async def get_state_groups_ids(
self, _room_id: str, event_ids: Iterable[str]
) -> Dict[int, StateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id (str): id of the room for these events
event_ids (iterable[str]): ids of the events
_room_id: id of the room for these events
event_ids: ids of the events
Returns:
Deferred[dict[int, StateMap[str]]]:
dict of state_group_id -> (dict of (type, state_key) -> event id)
dict of state_group_id -> (dict of (type, state_key) -> event id)
"""
if not event_ids:
return {}
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(groups)
group_to_state = await self.stores.state._get_state_for_groups(groups)
return group_to_state
@defer.inlineCallbacks
def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group (int)
state_group: A state group for which we want to get the state IDs.
Returns:
Deferred[dict]: Resolves to a map of (type, state_key) -> event_id
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = yield self._get_state_for_groups((state_group,))
group_to_state = await self._get_state_for_groups((state_group,))
return group_to_state[state_group]
@defer.inlineCallbacks
def get_state_groups(self, room_id, event_ids):
async def get_state_groups(
self, room_id: str, event_ids: Iterable[str]
) -> Dict[int, List[EventBase]]:
""" Get the state groups for the given list of event_ids
Args:
room_id: ID of the room for these events.
event_ids: The event IDs to retrieve state for.
Returns:
Deferred[dict[int, list[EventBase]]]:
dict of state_group_id -> list of state events.
dict of state_group_id -> list of state events.
"""
if not event_ids:
return {}
group_to_ids = yield self.get_state_groups_ids(room_id, event_ids)
group_to_ids = await self.get_state_groups_ids(room_id, event_ids)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[
ev_id
for group_ids in group_to_ids.values()
@@ -423,31 +427,34 @@ class StateGroupStorage(object):
groups: list of state group IDs to query
state_filter: The state filter used to fetch state
from the database.
Returns:
Deferred[Dict[int, StateMap[str]]]: Dict of state group to state map.
"""
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
Args:
event_ids (list[string])
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
Returns:
deferred: A dict of (event_id) -> (type, state_key) -> [state_events]
A dict of (event_id) -> (type, state_key) -> [state_events]
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
state_event_map = yield self.stores.main.get_events(
state_event_map = await self.stores.main.get_events(
[ev_id for sd in group_to_state.values() for ev_id in sd.values()],
get_prev_content=False,
)
@@ -463,24 +470,24 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_ids_for_events(self, event_ids, state_filter=StateFilter.all()):
async def get_state_ids_for_events(
self, event_ids: List[str], state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dicts corresponding to a list of events, containing the event_ids
of the state events (as opposed to the events themselves)
Args:
event_ids(list(str)): events whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from event_id -> (type, state_key) -> event_id
A dict from event_id -> (type, state_key) -> event_id
"""
event_to_groups = yield self.stores.main._get_state_group_for_events(event_ids)
event_to_groups = await self.stores.main._get_state_group_for_events(event_ids)
groups = set(event_to_groups.values())
group_to_state = yield self.stores.state._get_state_for_groups(
group_to_state = await self.stores.state._get_state_for_groups(
groups, state_filter
)
@@ -491,36 +498,36 @@ class StateGroupStorage(object):
return {event: event_to_state[event] for event in event_ids}
@defer.inlineCallbacks
def get_state_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
A dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], state_filter)
state_map = await self.get_state_for_events([event_id], state_filter)
return state_map[event_id]
@defer.inlineCallbacks
def get_state_ids_for_event(self, event_id, state_filter=StateFilter.all()):
async def get_state_ids_for_event(
self, event_id: str, state_filter: StateFilter = StateFilter.all()
):
"""
Get the state dict corresponding to a particular event
Args:
event_id(str): event whose state should be returned
state_filter (StateFilter): The state filter used to fetch state
from the database.
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_ids_for_events([event_id], state_filter)
state_map = await self.get_state_ids_for_events([event_id], state_filter)
return state_map[event_id]
def _get_state_for_groups(
@@ -530,9 +537,8 @@ class StateGroupStorage(object):
filtering by type/state_key
Args:
groups (iterable[int]): list of state groups for which we want
to get the state.
state_filter (StateFilter): The state filter used to fetch state
groups: list of state groups for which we want to get the state.
state_filter: The state filter used to fetch state.
from the database.
Returns:
Deferred[dict[int, StateMap[str]]]: Dict of state group to state map.
@@ -540,18 +546,23 @@ class StateGroupStorage(object):
return self.stores.state._get_state_for_groups(groups, state_filter)
def store_state_group(
self, event_id, room_id, prev_group, delta_ids, current_state_ids
self,
event_id: str,
room_id: str,
prev_group: Optional[int],
delta_ids: Optional[dict],
current_state_ids: dict,
):
"""Store a new set of state, returning a newly assigned state group.
Args:
event_id (str): The event ID for which the state was calculated
room_id (str)
prev_group (int|None): A previous state group for the room, optional.
delta_ids (dict|None): The delta between state at `prev_group` and
event_id: The event ID for which the state was calculated.
room_id: ID of the room for which the state was calculated.
prev_group: A previous state group for the room, optional.
delta_ids: The delta between state at `prev_group` and
`current_state_ids`, if `prev_group` was given. Same format as
`current_state_ids`.
current_state_ids (dict): The state to store. Map of (type, state_key)
current_state_ids: The state to store. Map of (type, state_key)
to event_id.
Returns:
+13 -17
View File
@@ -16,8 +16,6 @@
import logging
import operator
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events.utils import prune_event
from synapse.storage import Storage
@@ -39,8 +37,7 @@ MEMBERSHIP_PRIORITY = (
)
@defer.inlineCallbacks
def filter_events_for_client(
async def filter_events_for_client(
storage: Storage,
user_id,
events,
@@ -67,19 +64,19 @@ def filter_events_for_client(
also be called to check whether a user can see the state at a given point.
Returns:
Deferred[list[synapse.events.EventBase]]
list[synapse.events.EventBase]
"""
# Filter out events that have been soft failed so that we don't relay them
# to clients.
events = [e for e in events if not e.internal_metadata.is_soft_failed()]
types = ((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, user_id))
event_id_to_state = yield storage.state.get_state_for_events(
event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(types),
)
ignore_dict_content = yield storage.main.get_global_account_data_by_type_for_user(
ignore_dict_content = await storage.main.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id
)
@@ -90,7 +87,7 @@ def filter_events_for_client(
else []
)
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
if filter_send_to_client:
room_ids = {e.room_id for e in events}
@@ -99,7 +96,7 @@ def filter_events_for_client(
for room_id in room_ids:
retention_policies[
room_id
] = yield storage.main.get_retention_policy_for_room(room_id)
] = await storage.main.get_retention_policy_for_room(room_id)
def allowed(event):
"""
@@ -254,8 +251,7 @@ def filter_events_for_client(
return list(filtered_events)
@defer.inlineCallbacks
def filter_events_for_server(
async def filter_events_for_server(
storage: Storage,
server_name,
events,
@@ -277,7 +273,7 @@ def filter_events_for_server(
backfill or not.
Returns
Deferred[list[FrozenEvent]]
list[FrozenEvent]
"""
def is_sender_erased(event, erased_senders):
@@ -321,7 +317,7 @@ def filter_events_for_server(
# Lets check to see if all the events have a history visibility
# of "shared" or "world_readable". If that's the case then we don't
# need to check membership (as we know the server is in the room).
event_to_state_ids = yield storage.state.get_state_ids_for_events(
event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""),)
@@ -339,14 +335,14 @@ def filter_events_for_server(
if not visibility_ids:
all_open = True
else:
event_map = yield storage.main.get_events(visibility_ids)
event_map = await storage.main.get_events(visibility_ids)
all_open = all(
e.content.get("history_visibility") in (None, "shared", "world_readable")
for e in event_map.values()
)
if not check_history_visibility_only:
erased_senders = yield storage.main.are_users_erased((e.sender for e in events))
erased_senders = await storage.main.are_users_erased((e.sender for e in events))
else:
# We don't want to check whether users are erased, which is equivalent
# to no users having been erased.
@@ -375,7 +371,7 @@ def filter_events_for_server(
# first, for each event we're wanting to return, get the event_ids
# of the history vis and membership state at those events.
event_to_state_ids = yield storage.state.get_state_ids_for_events(
event_to_state_ids = await storage.state.get_state_ids_for_events(
frozenset(e.event_id for e in events),
state_filter=StateFilter.from_types(
types=((EventTypes.RoomHistoryVisibility, ""), (EventTypes.Member, None))
@@ -405,7 +401,7 @@ def filter_events_for_server(
return False
return state_key[idx + 1 :] == server_name
event_map = yield storage.main.get_events(
event_map = await storage.main.get_events(
[e_id for e_id, key in event_id_to_state_key.items() if include(key[0], key[1])]
)
+109
View File
@@ -99,6 +99,37 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self):
# Check whether an admin can join if option "admins_can_join" is undefined,
# this option defaults to false, so the join should fail.
u1 = self.register_user("u1", "pass", admin=True)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_once_joined(self):
u1 = self.register_user("u1", "pass")
@@ -141,3 +172,81 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
# Test the behavior of joining rooms which exceed the complexity if option
# limit_remote_rooms.admins_can_join is True.
servlets = [
admin.register_servlets,
room.register_servlets,
login.register_servlets,
]
def default_config(self):
config = super().default_config()
config["limit_remote_rooms"] = {
"enabled": True,
"complexity": 0.05,
"admins_can_join": True,
}
return config
def test_join_too_large_no_admin(self):
# A user which is not an admin should not be able to join a remote room
# which is too complex.
u1 = self.register_user("u1", "pass")
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request failed with a SynapseError saying the resource limit was
# exceeded.
f = self.get_failure(d, SynapseError)
self.assertEqual(f.value.code, 400, f.value)
self.assertEqual(f.value.errcode, Codes.RESOURCE_LIMIT_EXCEEDED)
def test_join_too_large_admin(self):
# An admin should be able to join rooms where a complexity check fails.
u1 = self.register_user("u1", "pass", admin=True)
handler = self.hs.get_room_member_handler()
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999}))
handler.federation_handler.do_invite_join = Mock(
return_value=defer.succeed(("", 1))
)
d = handler._remote_join(
None,
["other.example.com"],
"roomid",
UserID.from_string(u1),
{"membership": "join"},
)
self.pump()
# The request success since the user is an admin
self.get_success(d)
@@ -366,7 +366,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
self.get_success(
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
)
)
return event, context
+55 -2
View File
@@ -283,6 +283,23 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_is_not_bool(self):
"""
If parameter `purge` is not boolean, return an error
"""
body = json.dumps({"purge": "NotBool"})
request, channel = self.make_request(
"POST",
self.url,
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.BAD_JSON, channel.json_body["errcode"])
def test_purge_room_and_block(self):
"""Test to purge a room and block it.
Members will not be moved to a new room and will not receive a message.
@@ -297,7 +314,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
body = json.dumps({"block": True})
body = json.dumps({"block": True, "purge": True})
request, channel = self.make_request(
"POST",
@@ -331,7 +348,7 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
body = json.dumps({"block": False})
body = json.dumps({"block": False, "purge": True})
request, channel = self.make_request(
"POST",
@@ -351,6 +368,42 @@ class DeleteRoomTestCase(unittest.HomeserverTestCase):
self._is_blocked(self.room_id, expect=False)
self._has_no_members(self.room_id)
def test_block_room_and_not_purge(self):
"""Test to block a room without purging it.
Members will not be moved to a new room and will not receive a message.
The room will not be purged.
"""
# Test that room is not purged
with self.assertRaises(AssertionError):
self._is_purged(self.room_id)
# Test that room is not blocked
self._is_blocked(self.room_id, expect=False)
# Assert one user in room
self._is_member(room_id=self.room_id, user_id=self.other_user)
body = json.dumps({"block": False, "purge": False})
request, channel = self.make_request(
"POST",
self.url.encode("ascii"),
content=body.encode(encoding="utf_8"),
access_token=self.admin_user_tok,
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(None, channel.json_body["new_room_id"])
self.assertEqual(self.other_user, channel.json_body["kicked_users"][0])
self.assertIn("failed_to_kick_users", channel.json_body)
self.assertIn("local_aliases", channel.json_body)
with self.assertRaises(AssertionError):
self._is_purged(self.room_id)
self._is_blocked(self.room_id, expect=False)
self._has_no_members(self.room_id)
def test_shutdown_room_consent(self):
"""Test that we can shutdown rooms with local users who have not
yet accepted the privacy policy. This used to fail when we tried to
+20
View File
@@ -143,6 +143,26 @@ class RestHelper(object):
return channel.json_body
def redact(self, room_id, event_id, txn_id=None, tok=None, expect_code=200):
if txn_id is None:
txn_id = "m%s" % (str(time.time()))
path = "/_matrix/client/r0/rooms/%s/redact/%s/%s" % (room_id, event_id, txn_id)
if tok:
path = path + "?access_token=%s" % tok
request, channel = make_request(
self.hs.get_reactor(), "PUT", path, json.dumps({}).encode("utf8")
)
render(request, self.resource, self.hs.get_reactor())
assert int(channel.result["code"]) == expect_code, (
"Expected: %d, got: %d, resp: %r"
% (expect_code, int(channel.result["code"]), channel.result["body"])
)
return channel.json_body
def _read_write_state(
self,
room_id: str,
+155 -2
View File
@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
read_marker.register_servlets,
room.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.url = "/sync?since=%s"
self.next_batch = "s0"
# Register the first user (used to check the unread counts).
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
# Create the room we'll check unread counts for.
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
# Register the second user (used to send events to the room).
self.user2 = self.register_user("kermit2", "monkey")
self.tok2 = self.login("kermit2", "monkey")
# Change the power levels of the room so that the second user can send state
# events.
self.helper.send_state(
self.room_id,
EventTypes.PowerLevels,
{
"users": {self.user_id: 100, self.user2: 100},
"users_default": 0,
"events": {
"m.room.name": 50,
"m.room.power_levels": 100,
"m.room.history_visibility": 100,
"m.room.canonical_alias": 50,
"m.room.avatar": 50,
"m.room.tombstone": 100,
"m.room.server_acl": 100,
"m.room.encryption": 100,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50,
"invite": 0,
},
tok=self.tok,
)
def test_unread_counts(self):
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count.
self.helper.send(self.room_id, "hello", tok=self.tok)
self._check_unread_count(0)
# Join the new user and check that this doesn't increase the unread count.
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
self._check_unread_count(0)
# Check that the new user sending a message increases our unread count.
res = self.helper.send(self.room_id, "hello", tok=self.tok2)
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
request, channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the unread counter is back to 0.
self._check_unread_count(0)
# Check that room name changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
)
self._check_unread_count(2)
# Check that encrypted messages increase the unread counter.
self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
)
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"body": "hello",
"msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that notices don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"body": "hello", "msgtype": "m.notice"},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that tombstone events changes increase the unread counter.
self.helper.send_state(
self.room_id,
EventTypes.Tombstone,
{"replacement_room": "!someroom:test"},
tok=self.tok2,
)
self._check_unread_count(5)
def _check_unread_count(self, expected_count: True):
"""Syncs and compares the unread count with the expected value."""
request, channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
)
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]
+4 -2
View File
@@ -72,8 +72,10 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
event.internal_metadata.stream_ordering = stream
event.depth = stream
yield self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
)
)
yield self.store.db.runInteraction(
"",
+6 -2
View File
@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.rest.client.v1 import room
from tests.unittest import HomeserverTestCase
@@ -49,7 +51,9 @@ class PurgeTests(HomeserverTestCase):
event = self.successResultOf(event)
# Purge everything before this topological token
purge = storage.purge_events.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(
storage.purge_events.purge_history(self.room_id, event, True)
)
self.pump()
self.assertEqual(self.successResultOf(purge), None)
@@ -88,7 +92,7 @@ class PurgeTests(HomeserverTestCase):
)
# Purge everything before this topological token
purge = storage.purge_history(self.room_id, event, True)
purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
self.pump()
f = self.failureResultOf(purge)
self.assertIn("greater than forward", f.value.args[0])
+3 -1
View File
@@ -237,7 +237,9 @@ class RedactionTestCase(unittest.HomeserverTestCase):
@defer.inlineCallbacks
def build(self, prev_event_ids):
built_event = yield self._base_builder.build(prev_event_ids)
built_event = yield defer.ensureDeferred(
self._base_builder.build(prev_event_ids)
)
built_event._event_id = self._event_id
built_event._dict["event_id"] = self._event_id
+4 -2
View File
@@ -97,8 +97,10 @@ class RoomEventsStoreTestCase(unittest.TestCase):
@defer.inlineCallbacks
def inject_room_event(self, **kwargs):
yield self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(
self.event_factory.create_event(room_id=self.room.to_string(), **kwargs)
)
)
@defer.inlineCallbacks
+39 -25
View File
@@ -68,7 +68,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@@ -87,8 +89,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield self.storage.state.get_state_groups_ids(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_map = list(state_group_map.values())[0]
@@ -106,8 +108,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
self.room, self.u_alice, EventTypes.Name, "", {"name": "test room"}
)
state_group_map = yield self.storage.state.get_state_groups(
self.room, [e2.event_id]
state_group_map = yield defer.ensureDeferred(
self.storage.state.get_state_groups(self.room, [e2.event_id])
)
self.assertEqual(len(state_group_map), 1)
state_list = list(state_group_map.values())[0]
@@ -148,7 +150,9 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we get the full state as of the final event
state = yield self.storage.state.get_state_for_event(e5.event_id)
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(e5.event_id)
)
self.assertIsNotNone(e4)
@@ -164,22 +168,28 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check we can filter to the m.room.name event (with a '' state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, "")])
)
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can filter to the m.room.name event (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Name, None)])
)
)
self.assertStateMapEqual({(e2.type, e2.state_key): e2}, state)
# check we can grab the m.room.member events (with a wildcard None state key)
state = yield self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id, StateFilter.from_types([(EventTypes.Member, None)])
)
)
self.assertStateMapEqual(
@@ -188,12 +198,14 @@ class StateStoreTestCase(tests.unittest.TestCase):
# check we can grab a specific room member without filtering out the
# other event types
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: {self.u_alice.to_string()}},
include_others=True,
),
)
)
self.assertStateMapEqual(
@@ -206,11 +218,13 @@ class StateStoreTestCase(tests.unittest.TestCase):
)
# check that we can grab everything except members
state = yield self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
state = yield defer.ensureDeferred(
self.storage.state.get_state_for_event(
e5.event_id,
state_filter=StateFilter(
types={EventTypes.Member: set()}, include_others=True
),
)
)
self.assertStateMapEqual(
@@ -222,8 +236,8 @@ class StateStoreTestCase(tests.unittest.TestCase):
#######################################################
room_id = self.room.to_string()
group_ids = yield self.storage.state.get_state_groups_ids(
room_id, [e5.event_id]
group_ids = yield defer.ensureDeferred(
self.storage.state.get_state_groups_ids(room_id, [e5.event_id])
)
group = list(group_ids.keys())[0]
+7 -7
View File
@@ -213,7 +213,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertEqual(2, len(prev_state_ids))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -259,7 +259,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
@@ -318,7 +318,7 @@ class StateTestCase(unittest.TestCase):
ctx_c = context_store["C"]
ctx_e = context_store["E"]
prev_state_ids = yield ctx_e.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(ctx_e.get_prev_state_ids())
self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@@ -393,7 +393,7 @@ class StateTestCase(unittest.TestCase):
ctx_b = context_store["B"]
ctx_d = context_store["D"]
prev_state_ids = yield ctx_d.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(ctx_d.get_prev_state_ids())
self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
@@ -425,7 +425,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -450,7 +450,7 @@ class StateTestCase(unittest.TestCase):
self.state.compute_event_context(event, old_state=old_state)
)
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
current_state_ids = yield defer.ensureDeferred(context.get_current_state_ids())
@@ -519,7 +519,7 @@ class StateTestCase(unittest.TestCase):
context = yield defer.ensureDeferred(self.state.compute_event_context(event))
prev_state_ids = yield context.get_prev_state_ids()
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
+16 -10
View File
@@ -40,7 +40,7 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.store = self.hs.get_datastore()
self.storage = self.hs.get_storage()
yield create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM")
yield defer.ensureDeferred(create_room(self.hs, TEST_ROOM_ID, "@someone:ROOM"))
@defer.inlineCallbacks
def test_filtering(self):
@@ -64,8 +64,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
evt = yield self.inject_room_member(user, extra_content={"a": "b"})
events_to_filter.append(evt)
filtered = yield filter_events_for_server(
self.storage, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)
# the result should be 5 redacted events, and 5 unredacted events.
@@ -102,8 +102,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
yield self.hs.get_datastore().mark_user_erased("@erased:local_hs")
# ... and the filtering happens.
filtered = yield filter_events_for_server(
self.storage, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(self.storage, "test_server", events_to_filter)
)
for i in range(0, len(events_to_filter)):
@@ -140,7 +140,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
event, context = yield defer.ensureDeferred(
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks
@@ -162,7 +164,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks
@@ -183,7 +187,9 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
self.event_creation_handler.create_new_client_event(builder)
)
yield self.storage.persistence.persist_event(event, context)
yield defer.ensureDeferred(
self.storage.persistence.persist_event(event, context)
)
return event
@defer.inlineCallbacks
@@ -265,8 +271,8 @@ class FilterEventsForServerTestCase(tests.unittest.TestCase):
storage.main = test_store
storage.state = test_store
filtered = yield filter_events_for_server(
test_store, "test_server", events_to_filter
filtered = yield defer.ensureDeferred(
filter_events_for_server(test_store, "test_server", events_to_filter)
)
logger.info("Filtering took %f seconds", time.time() - start)
+4 -12
View File
@@ -640,14 +640,8 @@ class DeferredMockCallable(object):
)
@defer.inlineCallbacks
def create_room(hs, room_id, creator_id):
async def create_room(hs, room_id: str, creator_id: str):
"""Creates and persist a creation event for the given room
Args:
hs
room_id (str)
creator_id (str)
"""
persistence_store = hs.get_storage().persistence
@@ -655,7 +649,7 @@ def create_room(hs, room_id, creator_id):
event_builder_factory = hs.get_event_builder_factory()
event_creation_handler = hs.get_event_creation_handler()
yield store.store_room(
await store.store_room(
room_id=room_id,
room_creator_user_id=creator_id,
is_public=False,
@@ -673,8 +667,6 @@ def create_room(hs, room_id, creator_id):
},
)
event, context = yield defer.ensureDeferred(
event_creation_handler.create_new_client_event(builder)
)
event, context = await event_creation_handler.create_new_client_event(builder)
yield persistence_store.persist_event(event, context)
await persistence_store.persist_event(event, context)
+1
View File
@@ -207,6 +207,7 @@ commands = mypy \
synapse/storage/data_stores/main/ui_auth.py \
synapse/storage/database.py \
synapse/storage/engines \
synapse/storage/state.py \
synapse/storage/util \
synapse/streams \
synapse/util/caches/stream_change_cache.py \