Compare commits

3 Commits

Author         SHA1        Message  Date
Andrew Morgan  8acd2c01bc  lil fix  2022-09-26 16:13:53 +01:00
Andrew Morgan  f1d98d3b70  wip2     2022-09-22 15:54:30 +01:00
Andrew Morgan  6ff8ba5fc6  wip      2022-09-21 17:37:38 +01:00
59 changed files with 634 additions and 1592 deletions

.ci/scripts/postgres_exec.py Executable file

@@ -0,0 +1,31 @@
#!/usr/bin/env python
# Copyright 2019 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import psycopg2
# a very simple replacement for `psql`, to make up for the lack of the postgres client
# libraries in the synapse docker image.
# We use "postgres" as a database because it's bound to exist and the "synapse" one
# doesn't exist yet.
db_conn = psycopg2.connect(
user="postgres", host="localhost", password="postgres", dbname="postgres"
)
db_conn.autocommit = True
cur = db_conn.cursor()
for c in sys.argv[1:]:
cur.execute(c)
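
As a usage note (not part of the diff): each command-line argument to the helper is executed as its own SQL statement under autocommit, so one invocation can replace a multi-statement `psql` call — e.g. the database re-creation step that appears later in this diff:

```python
# Illustrative invocation only: each argv entry after the script name is
# executed as a standalone SQL statement, with autocommit enabled.
import subprocess

subprocess.run(
    [
        "poetry", "run", ".ci/scripts/postgres_exec.py",
        "DROP DATABASE synapse",
        "CREATE DATABASE synapse",
    ],
    check=True,  # fail the CI step if any statement errors
)
```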

.ci/scripts/test_export_data_command.sh

@@ -32,7 +32,7 @@ else
fi
# Create the PostgreSQL database.
psql -c "CREATE DATABASE synapse"
poetry run .ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
# Port the SQLite database to postgres so we can check the command works against postgres
echo "+++ Port SQLite3 database to postgres"

.ci/scripts/test_synapse_port_db.sh

@@ -2,27 +2,27 @@
#
# Test script for 'synapse_port_db'.
# - configures synapse and a postgres server.
# - runs the port script on a prepopulated test sqlite db. Checks that the
# return code is zero.
# - reruns the port script on the same sqlite db, targeting the same postgres db.
# Checks that the return code is zero.
# - runs the port script against a new sqlite db. Checks the return code is zero.
# - runs the port script on a prepopulated test sqlite db
# - also runs it against a new sqlite db
#
# Expects Synapse to have been already installed with `poetry install --extras postgres`.
# Expects `poetry` to be available on the `PATH`.
set -xe -o pipefail
set -xe
cd "$(dirname "$0")/../.."
echo "--- Generate the signing key"
# Generate the server's signing key.
poetry run synapse_homeserver --generate-keys -c .ci/sqlite-config.yaml
echo "--- Prepare test database"
# Make sure the SQLite3 database is using the latest schema and has no pending background updates.
# Make sure the SQLite3 database is using the latest schema and has no pending background update.
poetry run update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# Create the PostgreSQL database.
psql -c "CREATE DATABASE synapse"
poetry run .ci/scripts/postgres_exec.py "CREATE DATABASE synapse"
echo "+++ Run synapse_port_db against test database"
# TODO: this invocation of synapse_port_db (and others below) used to be prepended with `coverage run`,
@@ -45,23 +45,9 @@ rm .ci/test_db.db
poetry run update_synapse_database --database-config .ci/sqlite-config.yaml --run-background-updates
# re-create the PostgreSQL database.
psql \
-c "DROP DATABASE synapse" \
-c "CREATE DATABASE synapse"
poetry run .ci/scripts/postgres_exec.py \
"DROP DATABASE synapse" \
"CREATE DATABASE synapse"
echo "+++ Run synapse_port_db against empty database"
poetry run synapse_port_db --sqlite-database .ci/test_db.db --postgres-config .ci/postgres-config.yaml
echo "--- Create a brand new postgres database from schema"
cp .ci/postgres-config.yaml .ci/postgres-config-unported.yaml
sed -i -e 's/database: synapse/database: synapse_unported/' .ci/postgres-config-unported.yaml
psql -c "CREATE DATABASE synapse_unported"
poetry run update_synapse_database --database-config .ci/postgres-config-unported.yaml --run-background-updates
echo "+++ Comparing ported schema with unported schema"
# Ignore the tables that portdb creates. (Should it tidy them up when the porting is completed?)
psql synapse -c "DROP TABLE port_from_sqlite3;"
pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner synapse_unported > unported.sql
pg_dump --format=plain --schema-only --no-tablespaces --no-acl --no-owner synapse > ported.sql
# By default, `diff` returns zero if there are no changes and nonzero otherwise
diff -u unported.sql ported.sql | tee schema_diff

.dockerignore

@@ -11,6 +11,5 @@
!build_rust.py
rust/target
synapse/*.so
**/__pycache__

.github/workflows/tests.yml

@@ -32,11 +32,9 @@ jobs:
steps:
- uses: actions/checkout@v2
- uses: actions/setup-python@v2
- uses: matrix-org/setup-python-poetry@v1
with:
extras: "all"
- run: poetry run scripts-dev/generate_sample_config.sh --check
- run: poetry run scripts-dev/config-lint.sh
- run: pip install .
- run: scripts-dev/generate_sample_config.sh --check
- run: scripts-dev/config-lint.sh
check-schema-delta:
runs-on: ubuntu-latest
@@ -78,6 +76,7 @@ jobs:
- uses: actions/checkout@v2
with:
ref: ${{ github.event.pull_request.head.sha }}
fetch-depth: 0
- uses: matrix-org/setup-python-poetry@v1
with:
extras: "all"
@@ -362,22 +361,18 @@ jobs:
steps:
- uses: actions/checkout@v2
- run: sudo apt-get -qq install xmlsec1 postgresql-client
- run: sudo apt-get -qq install xmlsec1
- uses: matrix-org/setup-python-poetry@v1
with:
extras: "postgres"
- run: .ci/scripts/test_export_data_command.sh
env:
PGHOST: localhost
PGUSER: postgres
PGPASSWORD: postgres
PGDATABASE: postgres
portdb:
if: ${{ !failure() && !cancelled() }} # Allow previous steps to be skipped, but not fail
needs: linting-done
runs-on: ubuntu-latest
env:
TOP: ${{ github.workspace }}
strategy:
matrix:
include:
@@ -403,27 +398,12 @@ jobs:
steps:
- uses: actions/checkout@v2
- run: sudo apt-get -qq install xmlsec1 postgresql-client
- run: sudo apt-get -qq install xmlsec1
- uses: matrix-org/setup-python-poetry@v1
with:
python-version: ${{ matrix.python-version }}
extras: "postgres"
- run: .ci/scripts/test_synapse_port_db.sh
id: run_tester_script
env:
PGHOST: localhost
PGUSER: postgres
PGPASSWORD: postgres
PGDATABASE: postgres
- name: "Upload schema differences"
uses: actions/upload-artifact@v3
if: ${{ failure() && !cancelled() && steps.run_tester_script.outcome == 'failure' }}
with:
name: Schema dumps
path: |
unported.sql
ported.sql
schema_diff
complement:
if: "${{ !failure() && !cancelled() }}"

changelog.d entry

@@ -1 +0,0 @@
Refactor ` _send_events_for_new_room` to separate creating and sending events.

changelog.d entry

@@ -1 +1 @@
Keep track of when we fail to process a pulled event over federation so we can intelligently back-off in the future.
Keep track of when we attempt to backfill an event but fail so we can intelligently back-off in the future.

changelog.d entry

@@ -1 +0,0 @@
Fix a long-standing bug where previously rejected events could end up in room state because they pass auth checks given the current state of the room.

changelog.d entry

@@ -1 +0,0 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).

changelog.d entry

@@ -1 +0,0 @@
Add docs for the common fix of deleting the `matrix_synapse.egg-info/` directory to resolve Python dependency problems.

changelog.d entry

@@ -1 +0,0 @@
Update request log format documentation to mention the format used when the authenticated user is controlling another user.

changelog.d entry

@@ -1 +0,0 @@
Add `listeners[x].request_id_header` config to specify which request header to extract and use as the request ID in order to correlate requests from a reverse-proxy.

changelog.d entry

@@ -1 +0,0 @@
Check that portdb generates the same postgres schema as that in the source tree.

changelog.d entry

@@ -1 +0,0 @@
Add an admin API endpoint to find a user based on its external ID in an auth provider.

changelog.d entry

@@ -1 +0,0 @@
Fix Docker build when Rust .so has been built locally first.

changelog.d entry

@@ -1 +0,0 @@
Keep track of when we fail to process a pulled event over federation so we can intelligently back-off in the future.

changelog.d entry

@@ -1 +0,0 @@
complement: init postgres DB directly inside the target image instead of the base postgres image to fix building using Buildah.

changelog.d entry

@@ -1 +0,0 @@
Support providing an index predicate clause when doing upserts.
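
To unpack the entry above: a hedged sketch of an upsert against a partial unique index, where the `ON CONFLICT` clause must repeat the index predicate for Postgres to match it. Table, columns, and index here are illustrative, not taken from the diff.

```python
# Illustrative only: assumes a partial unique index such as
#   CREATE UNIQUE INDEX ... ON devices (user_id, device_id) WHERE NOT hidden;
# The WHERE clause after ON CONFLICT is the "index predicate".
import psycopg2

conn = psycopg2.connect(dbname="synapse", user="postgres", password="postgres")
with conn, conn.cursor() as cur:
    cur.execute(
        """
        INSERT INTO devices (user_id, device_id, hidden)
        VALUES (%s, %s, FALSE)
        ON CONFLICT (user_id, device_id) WHERE NOT hidden
        DO UPDATE SET hidden = EXCLUDED.hidden
        """,
        ("@alice:example.org", "DEVICE1"),
    )
```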

changelog.d entry

@@ -1 +0,0 @@
Delete associated data from `event_failed_pull_attempts`, `insertion_events`, `insertion_event_extremities` when purging the room.

changelog.d entry

@@ -1 +0,0 @@
Fix a long-standing bug where device lists would remain cached when remote users left and rejoined the last room shared with the local homeserver.

changelog.d entry

@@ -1 +0,0 @@
Minor speedups to linting in CI.

docker/Dockerfile

@@ -31,9 +31,7 @@ ARG PYTHON_VERSION=3.9
###
### Stage 0: generate requirements.txt
###
# We hardcode the use of Debian bullseye here because this could change upstream
# and other Dockerfiles used for testing are expecting bullseye.
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as requirements
FROM docker.io/python:${PYTHON_VERSION}-slim as requirements
# RUN --mount is specific to buildkit and is documented at
# https://github.com/moby/buildkit/blob/master/frontend/dockerfile/docs/syntax.md#build-mounts-run---mount.
@@ -78,7 +76,7 @@ RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
###
### Stage 1: builder
###
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as builder
FROM docker.io/python:${PYTHON_VERSION}-slim as builder
# install the OS build deps
RUN \
@@ -139,7 +137,7 @@ RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
### Stage 2: runtime
###
FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye
FROM docker.io/python:${PYTHON_VERSION}-slim
LABEL org.opencontainers.image.url='https://matrix.org/docs/projects/server/synapse'
LABEL org.opencontainers.image.documentation='https://github.com/matrix-org/synapse/blob/master/docker/README.md'

docker/complement/Dockerfile

@@ -17,17 +17,7 @@ ARG SYNAPSE_VERSION=latest
# the same debian version as Synapse's docker image (so the versions of the
# shared libraries match).
# now build the final image, based on the Synapse image.
FROM matrixdotorg/synapse-workers:$SYNAPSE_VERSION
# copy the postgres installation over from the image we built above
RUN adduser --system --uid 999 postgres --home /var/lib/postgresql
COPY --from=postgres:13-bullseye /usr/lib/postgresql /usr/lib/postgresql
COPY --from=postgres:13-bullseye /usr/share/postgresql /usr/share/postgresql
RUN mkdir /var/run/postgresql && chown postgres /var/run/postgresql
ENV PATH="${PATH}:/usr/lib/postgresql/13/bin"
ENV PGDATA=/var/lib/postgresql/data
FROM postgres:13-bullseye AS postgres_base
# initialise the database cluster in /var/lib/postgresql
RUN gosu postgres initdb --locale=C --encoding=UTF-8 --auth-host password
@@ -35,6 +25,18 @@ FROM matrixdotorg/synapse-workers:$SYNAPSE_VERSION
RUN echo "ALTER USER postgres PASSWORD 'somesecret'" | gosu postgres postgres --single
RUN echo "CREATE DATABASE synapse" | gosu postgres postgres --single
# now build the final image, based on the Synapse image.
FROM matrixdotorg/synapse-workers:$SYNAPSE_VERSION
# copy the postgres installation over from the image we built above
RUN adduser --system --uid 999 postgres --home /var/lib/postgresql
COPY --from=postgres_base /var/lib/postgresql /var/lib/postgresql
COPY --from=postgres_base /usr/lib/postgresql /usr/lib/postgresql
COPY --from=postgres_base /usr/share/postgresql /usr/share/postgresql
RUN mkdir /var/run/postgresql && chown postgres /var/run/postgresql
ENV PATH="${PATH}:/usr/lib/postgresql/13/bin"
ENV PGDATA=/var/lib/postgresql/data
# Extend the shared homeserver config to disable rate-limiting,
# set Complement's static shared secret, enable registration, amongst other
# tweaks to get Synapse ready for testing.

docs/admin_api/user_admin_api.md

@@ -1155,41 +1155,3 @@ GET /_synapse/admin/v1/username_available?username=$localpart
The request and response format is the same as the
[/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API.
### Find a user based on their ID in an auth provider
The API is:
```
GET /_synapse/admin/v1/auth_providers/$provider/users/$external_id
```
When a user matches the given ID for the given provider, an HTTP code `200` is returned with a response body like the following:
```json
{
"user_id": "@hello:example.org"
}
```
**Parameters**
The following parameters should be set in the URL:
- `provider` - The ID of the authentication provider, as advertised by the [`GET /_matrix/client/v3/login`](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3login) API in the `m.login.sso` authentication method.
- `external_id` - The user ID from the authentication provider. Usually corresponds to the `sub` claim for OIDC providers, or to the `uid` attribute for SAML2 providers.
The `external_id` may have characters that are not URL-safe (typically `/`, `:` or `@`), so it is advised to URL-encode those parameters.
**Errors**
Returns a `404` HTTP status code if no user was found, with a response body like this:
```json
{
"errcode":"M_NOT_FOUND",
"error":"User not found"
}
```
_Added in Synapse 1.68.0._
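
A hedged usage sketch for the endpoint documented above. The homeserver URL, admin token, provider, and external ID are placeholders; only the endpoint path and response shapes come from the documentation.

```python
# Placeholder values throughout; the endpoint path and response shapes
# are taken from the admin API documentation above.
from urllib.parse import quote

import requests

BASE_URL = "https://homeserver.example.org"
ADMIN_TOKEN = "<admin access token>"

provider = "oidc"
# `external_id` may contain URL-unsafe characters (`/`, `:`, `@`), so
# URL-encode it as the docs advise.
external_id = quote("user/1234", safe="")

resp = requests.get(
    f"{BASE_URL}/_synapse/admin/v1/auth_providers/{provider}/users/{external_id}",
    headers={"Authorization": f"Bearer {ADMIN_TOKEN}"},
)
if resp.status_code == 200:
    print(resp.json()["user_id"])  # e.g. "@hello:example.org"
else:
    print(resp.json()["errcode"])  # "M_NOT_FOUND" when no user matches
```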

docs/development/dependencies.md

@@ -126,23 +126,6 @@ context of poetry's venv, without having to run `poetry shell` beforehand.
poetry install --extras all --remove-untracked
```
## ...delete everything and start over from scratch?
```shell
# Stop the current virtualenv if active
$ deactivate
# Remove all of the files from the current environment.
# Don't worry, even though it says "all", this will only
# remove the Poetry virtualenvs for the current project.
$ poetry env remove --all
# Reactivate Poetry shell to create the virtualenv again
$ poetry shell
# Install everything again
$ poetry install --extras all
```
## ...run a command in the `poetry` virtualenv?
Use `poetry run cmd args` when you need the python virtualenv context.
@@ -273,16 +256,6 @@ from PyPI. (This is what makes poetry seem slow when doing the first
`poetry install`.) Try `poetry cache list` and `poetry cache clear --all
<name of cache>` to see if that fixes things.
## Remove outdated egg-info
Delete the `matrix_synapse.egg-info/` directory from the root of your Synapse
install.
This stores some cached information about dependencies and often conflicts with
letting Poetry do the right thing.
## Try `--verbose` or `--dry-run` arguments.
Sometimes useful to see what poetry's internal logic is.

docs/reverse_proxy.md

@@ -45,10 +45,6 @@ listens to traffic on localhost. (Do not change `bind_addresses` to `127.0.0.1`
when using a containerized Synapse, as that will prevent it from responding
to proxied traffic.)
Optionally, you can also set
[`request_id_header`](../usage/configuration/config_documentation.md#listeners)
so that the server extracts and re-uses the same request ID format that the
reverse proxy is using.
## Reverse-proxy configuration examples

docs/usage/administration/request_log.md

@@ -12,14 +12,14 @@ See the following for how to decode the dense data available from the default lo
| Part | Explanation |
| ----- | ------------ |
| AAAA | Timestamp request was logged (not received) |
| AAAA | Timestamp request was logged (not recieved) |
| BBBB | Logger name (`synapse.access.(http\|https).<tag>`, where 'tag' is defined in the `listeners` config section, normally the port) |
| CCCC | Line number in code |
| DDDD | Log Level |
| EEEE | Request Identifier (This identifier is shared by related log lines)|
| FFFF | Source IP (Or X-Forwarded-For if enabled) |
| GGGG | Server Port |
| HHHH | Federated Server or Local User making request (blank if unauthenticated or not supplied).<br/>If this is of the form `@aaa:example.com|@bbb:example.com`, then that means that `@aaa:example.com` is authenticated but they are controlling `@bbb:example.com`, e.g. if `aaa` is controlling `bbb` [via the admin API](https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#login-as-a-user). |
| HHHH | Federated Server or Local User making request (blank if unauthenticated or not supplied) |
| IIII | Total Time to process the request |
| JJJJ | Time to send response over network once generated (this may be negative if the socket is closed before the response is generated)|
| KKKK | Userland CPU time |
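
The pipe-separated `HHHH` format described above (the authenticated user, then the user being controlled) splits mechanically. A hypothetical helper, not part of Synapse, just to make the format concrete:

```python
# Hypothetical helper (not from the diff): split the HHHH field when an
# admin puppets another user, e.g. "@aaa:example.com|@bbb:example.com".
def split_requester(entity: str) -> tuple[str, str]:
    authenticated, _, puppeted = entity.partition("|")
    # With no "|", the authenticated user is also the acting user.
    return authenticated, puppeted or authenticated

assert split_requester("@aaa:example.com|@bbb:example.com") == (
    "@aaa:example.com",
    "@bbb:example.com",
)
assert split_requester("@solo:example.com") == (
    "@solo:example.com",
    "@solo:example.com",
)
```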

docs/usage/configuration/config_documentation.md

@@ -434,16 +434,7 @@ Sub-options for each listener include:
* `tls`: set to true to enable TLS for this listener. Will use the TLS key/cert specified in tls_private_key_path / tls_certificate_path.
* `x_forwarded`: Only valid for an 'http' listener. Set to true to use the X-Forwarded-For header as the client IP. Useful when Synapse is
behind a [reverse-proxy](../../reverse_proxy.md).
* `request_id_header`: The header extracted from each incoming request that is
used as the basis for the request ID. The request ID is used in
[logs](../administration/request_log.md#request-log-format) and tracing to
correlate and match up requests. When unset, Synapse will automatically
generate sequential request IDs. This option is useful when Synapse is behind
a [reverse-proxy](../../reverse_proxy.md).
_Added in Synapse 1.68.0._
behind a reverse-proxy.
* `resources`: Only valid for an 'http' listener. A list of resources to host
on this port. Sub-options for each resource are:

synapse/config/server.py

@@ -206,7 +206,6 @@ class HttpListenerConfig:
resources: List[HttpResourceConfig] = attr.Factory(list)
additional_resources: Dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
request_id_header: Optional[str] = None
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -521,11 +520,9 @@ class ServerConfig(Config):
):
raise ConfigError("allowed_avatar_mimetypes must be a list")
listeners = config.get("listeners", [])
if not isinstance(listeners, list):
raise ConfigError("Expected a list", ("listeners",))
self.listeners = [parse_listener_def(i, x) for i, x in enumerate(listeners)]
self.listeners = [
parse_listener_def(i, x) for i, x in enumerate(config.get("listeners", []))
]
# no_tls is not really supported any more, but let's grandfather it in
# here.
@@ -892,9 +889,6 @@ def read_gc_thresholds(
def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
"""parse a listener config from the config file"""
if not isinstance(listener, dict):
raise ConfigError("Expected a dictionary", ("listeners", str(num)))
listener_type = listener["type"]
# Raise a helpful error if direct TCP replication is still configured.
if listener_type == "replication":
@@ -934,7 +928,6 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
resources=resources,
additional_resources=listener.get("additional_resources", {}),
tag=listener.get("tag"),
request_id_header=listener.get("request_id_header"),
)
return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)

synapse/handlers/e2e_keys.py

@@ -188,21 +188,18 @@ class E2eKeysHandler:
)
invalid_cached_users = cached_users - valid_cached_users
if invalid_cached_users:
# Fix up results. If we get here, it means there was either a bug in
# device list tracking, or we hit the race mentioned above.
# TODO: In practice, this path is hit fairly often in existing
# deployments when clients query the keys of departed remote
# users. A background update to mark the appropriate device
# lists as unsubscribed is needed.
# https://github.com/matrix-org/synapse/issues/13651
# Note that this currently introduces a failure mode when clients
# are trying to decrypt old messages from a remote user whose
# homeserver is no longer available. We may want to consider falling
# back to the cached data when we fail to retrieve a device list
# over federation for such remote users.
# Fix up results. If we get here, there is either a bug in device
# list tracking, or we hit the race mentioned above.
user_ids_not_in_cache.update(invalid_cached_users)
for invalid_user_id in invalid_cached_users:
remote_results.pop(invalid_user_id)
# This log message may be removed if it turns out it's almost
# entirely triggered by races.
logger.error(
"Devices for %s were cached, but the server no longer shares "
"any rooms with them. The cached device lists are stale.",
invalid_cached_users,
)
for user_id, devices in remote_results.items():
user_devices = results.setdefault(user_id, {})

synapse/handlers/federation_event.py

@@ -866,11 +866,6 @@ class FederationEventHandler:
event.room_id, event_id, str(err)
)
return
except Exception as exc:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(exc)
)
raise exc
try:
try:
@@ -913,11 +908,6 @@ class FederationEventHandler:
logger.warning("Pulled event %s failed history check.", event_id)
else:
raise
except Exception as exc:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(exc)
)
raise exc
@trace
async def _compute_event_context_with_maybe_missing_prevs(

synapse/handlers/message.py

@@ -56,16 +56,13 @@ from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
from synapse.storage.databases.main.events import PartialStateConflictError
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
MutableStateMap,
PersistedEventPosition,
Requester,
RoomAlias,
StateMap,
StreamToken,
UserID,
create_requester,
@@ -495,7 +492,6 @@ class EventCreationHandler:
self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
self.send_event = ReplicationSendEventRestServlet.make_client(hs)
self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
self.request_ratelimiter = hs.get_request_ratelimiter()
@@ -631,10 +627,49 @@ class EventCreationHandler:
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
builder = await self._get_and_validate_builder(event_dict)
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not maybe_room_version_obj:
# this can happen if support is withdrawn for a room version
raise UnsupportedRoomVersionError(room_version_id)
room_version_obj = maybe_room_version_obj
else:
try:
room_version_obj = await self.store.get_room_version(
event_dict["room_id"]
)
except NotFoundError:
raise AuthError(403, "Unknown room")
builder = self.event_builder_factory.for_room_version(
room_version_obj, event_dict
)
self.validator.validate_builder(builder)
if builder.type == EventTypes.Member:
await self._build_profile_data(builder)
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in self.membership_types_to_include_profile_data_in:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s", target, e
)
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
@@ -700,126 +735,6 @@ class EventCreationHandler:
return event, context
async def create_event_for_batch(
self,
requester: Requester,
event_dict: dict,
prev_event_ids: List[str],
depth: int,
state_map: StateMap,
txn_id: Optional[str] = None,
require_consent: bool = True,
outlier: bool = False,
) -> EventBase:
"""
Given a dict from a client, create a new event. Notably does not create an event
context. Adds display names to Join membership events.
Args:
requester
event_dict: An entire event
txn_id
prev_event_ids:
the forward extremities to use as the prev_events for the
new event.
state_map: a state_map of previously created events for batching. Will be used
to calculate the auth_ids for the event, as the previously created events for
batching will not yet have been persisted to the db
require_consent: Whether to check if the requester has
consented to the privacy policy.
outlier: Indicates whether the event is an `outlier`, i.e. if
it's from an arbitrary point and floating in the DAG as
opposed to being inline with the current DAG.
depth: Override the depth used to order the event in the DAG.
Returns:
the created event
"""
await self.auth_blocking.check_auth_blocking(requester=requester)
builder = await self._get_and_validate_builder(event_dict)
if builder.type == EventTypes.Member:
await self._build_profile_data(builder)
is_exempt = await self._is_exempt_from_privacy_policy(builder, requester)
if require_consent and not is_exempt:
await self.assert_accepted_privacy_policy(requester)
if requester.access_token_id is not None:
builder.internal_metadata.token_id = requester.access_token_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
builder.internal_metadata.outlier = outlier
auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
event = await builder.build(
prev_event_ids=prev_event_ids,
auth_event_ids=auth_ids,
depth=depth,
)
# Pass on the outlier property from the builder to the event
# after it is created
if builder.internal_metadata.outlier:
event.internal_metadata.outlier = True
self.validator.validate_new(event, self.config)
return event
async def _build_profile_data(self, builder: EventBuilder) -> None:
"""
Helper method to add profile information to membership event
"""
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership in self.membership_types_to_include_profile_data_in:
# If event doesn't include a display name, add one.
profile = self.profile_handler
content = builder.content
try:
if "displayname" not in content:
displayname = await profile.get_displayname(target)
if displayname is not None:
content["displayname"] = displayname
if "avatar_url" not in content:
avatar_url = await profile.get_avatar_url(target)
if avatar_url is not None:
content["avatar_url"] = avatar_url
except Exception as e:
logger.info("Failed to get profile information for %r: %s", target, e)
async def _get_and_validate_builder(self, event_dict: dict) -> EventBuilder:
"""
Helper method to create and validate a builder object when creating an event
"""
if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
room_version_id = event_dict["content"]["room_version"]
maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not maybe_room_version_obj:
# this can happen if support is withdrawn for a room version
raise UnsupportedRoomVersionError(room_version_id)
room_version_obj = maybe_room_version_obj
else:
try:
room_version_obj = await self.store.get_room_version(
event_dict["room_id"]
)
except NotFoundError:
raise AuthError(403, "Unknown room")
builder = self.event_builder_factory.for_room_version(
room_version_obj, event_dict
)
self.validator.validate_builder(builder)
return builder
async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester
) -> bool:
@@ -1319,147 +1234,6 @@ class EventCreationHandler:
400, "Cannot start threads from an event with a relation"
)
async def handle_create_room_events(
self,
requester: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Process a batch of room creation events. For each event in the list it checks
the authorization and that the event can be serialized. Returns the last event in the
list once it has been persisted.
Args:
requester: the room creator
events_and_ctx: a set of events and their associated contexts to persist
ratelimit: whether to ratelimit this request
"""
for event, context in events_and_ctx:
try:
validate_event_for_room_version(event)
await self._event_auth_handler.check_auth_rules_from_context(
event, context
)
except AuthError as err:
logger.warning("Denying new event %r because %s", event, err)
raise err
# Ensure that we can round trip before trying to persist in db
try:
dump = json_encoder.encode(event.content)
json_decoder.decode(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
# We now persist the events
try:
result = await self._persist_events_batch(
requester, events_and_ctx, ratelimit
)
except Exception as e:
logger.info(f"Encountered an error persisting events: {e}")
return result
async def _persist_events_batch(
self,
requester: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Processes the push actions and adds them to the push staging area before attempting to
persist the batch of events.
See handle_create_room_events for arguments
Returns the last event in the list if persisted successfully
"""
for event, context in events_and_ctx:
with opentracing.start_active_span("calculate_push_actions"):
await self._bulk_push_rule_evaluator.action_for_event_by_user(
event, context
)
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
if writer_instance != self._instance_name:
try:
result = await self.send_events(
instance_name=writer_instance,
store=self.store,
requester=requester,
events_and_ctx=events_and_ctx,
ratelimit=ratelimit,
)
except SynapseError as e:
if e.code == HTTPStatus.CONFLICT:
raise PartialStateConflictError()
raise
stream_id = result["stream_id"]
# If we newly persisted the event then we need to update its
# stream_ordering entry manually (as it was persisted on
# another worker).
event.internal_metadata.stream_ordering = stream_id
return event
last_event = await self.persist_and_notify_batched_events(
requester, events_and_ctx, ratelimit
)
except Exception:
# Ensure that we actually remove the entries in the push actions
# staging area, if we calculated them.
for event, _ in events_and_ctx:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
return last_event
async def persist_and_notify_batched_events(
self,
requester: Requester,
events_and_ctx: List[Tuple[EventBase, EventContext]],
ratelimit: bool = True,
) -> EventBase:
"""
Handles the actual persisting of a batch of events to the DB, and sends the appropriate
notifications when this is done.
Args:
requester: the room creator
events_and_ctx: list of events and their associated contexts to persist
ratelimit: whether to apply ratelimiting to this request
"""
if ratelimit:
await self.request_ratelimiter.ratelimit(requester)
for event, context in events_and_ctx:
await self._actions_by_event_type(event, context)
assert self._storage_controllers.persistence is not None
(
persisted_events,
max_stream_token,
) = await self._storage_controllers.persistence.persist_events(events_and_ctx)
stream_ordering = persisted_events[-1].internal_metadata.stream_ordering
assert stream_ordering is not None
pos = PersistedEventPosition(self._instance_name, stream_ordering)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
persisted_events[-1], pos, max_stream_token
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
return persisted_events[-1]
@measure_func("handle_new_client_event")
async def handle_new_client_event(
self,
@@ -1794,55 +1568,6 @@ class EventCreationHandler:
requester, is_admin_redaction=is_admin_redaction
)
# run checks/actions on event based on type
await self._actions_by_event_type(event, context)
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
if event.internal_metadata.is_historical():
backfilled = True
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due to transaction IDs.
(
event,
event_pos,
max_stream_token,
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event
async def _actions_by_event_type(
self, event: EventBase, context: EventContext
) -> None:
"""
Helper function to execute actions/checks based on the event type
"""
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
(
current_membership,
@@ -1863,13 +1588,11 @@ class EventCreationHandler:
original_event_id = event.unsigned.get("replaces_state")
if original_event_id:
original_alias_event = await self.store.get_event(original_event_id)
original_event = await self.store.get_event(original_event_id)
if original_alias_event:
original_alias = original_alias_event.content.get("alias", None)
original_alt_aliases = original_alias_event.content.get(
"alt_aliases", []
)
if original_event:
original_alias = original_event.content.get("alias", None)
original_alt_aliases = original_event.content.get("alt_aliases", [])
# Check the alias is currently valid (if it has changed).
room_alias_str = event.content.get("alias", None)
@@ -2047,6 +1770,46 @@ class EventCreationHandler:
errcode=Codes.INVALID_PARAM,
)
# Mark any `m.historical` messages as backfilled so they don't appear
# in `/sync` and have the proper decrementing `stream_ordering` as we import
backfilled = False
if event.internal_metadata.is_historical():
backfilled = True
# Note that this returns the event that was persisted, which may not be
# the same as we passed in if it was deduplicated due to transaction IDs.
(
event,
event_pos,
max_stream_token,
) = await self._storage_controllers.persistence.persist_event(
event, context=context, backfilled=backfilled
)
if self._ephemeral_events_enabled:
# If there's an expiry timestamp on the event, schedule its expiry.
self._message_handler.maybe_schedule_expiry(event)
async def _notify() -> None:
try:
await self.notifier.on_new_room_event(
event, event_pos, max_stream_token, extra_users=extra_users
)
except Exception:
logger.exception(
"Error notifying about new room event %s",
event.event_id,
)
run_in_background(_notify)
if event.type == EventTypes.Message:
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
run_in_background(self._bump_active_time, requester.user)
return event
async def _maybe_kick_guest_users(
self, event: EventBase, context: EventContext
) -> None:

synapse/handlers/room.py

@@ -108,7 +108,6 @@ class EventContext:
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.auth_blocking = hs.get_auth_blocking()
@@ -120,7 +119,6 @@ class RoomCreationHandler:
self._event_auth_handler = hs.get_event_auth_handler()
self.config = hs.config
self.request_ratelimiter = hs.get_request_ratelimiter()
self.builder = hs.get_event_builder_factory()
# Room state based off defined presets
self._presets_dict: Dict[str, Dict[str, Any]] = {
@@ -718,7 +716,7 @@ class RoomCreationHandler:
if (
self._server_notices_mxid is not None
and user_id == self._server_notices_mxid
and requester.user.to_string() == self._server_notices_mxid
):
# allow the server notices mxid to create rooms
is_requester_admin = True
@@ -1055,21 +1053,13 @@ class RoomCreationHandler:
"""
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
# the last event sent/persisted to the db
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
depth = 1
last_sent_event_id: Optional[str] = None
# the most recently created event
prev_event: List[str] = []
# a map of event types, state keys -> event_ids. We collect these mappings this
# as events are created (but not persisted to the db) to determine state for
# future created events (as this info can't be pulled from the db)
state_map: dict = {}
def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -1077,49 +1067,32 @@ class RoomCreationHandler:
return e
async def create_event(
etype: str,
content: JsonDict,
**kwargs: Any,
) -> EventBase:
nonlocal depth
nonlocal prev_event
event_dict = create_event_dict(etype, content, **kwargs)
event = await self.event_creation_handler.create_event_for_batch(
creator,
event_dict,
prev_event,
depth,
state_map,
)
depth += 1
prev_event = [event.event_id]
state_map[(event.type, event.state_key)] = event.event_id
return event
async def send(
event: EventBase,
context: synapse.events.snapshot.EventContext,
creator: Requester,
) -> int:
async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
nonlocal last_sent_event_id
nonlocal depth
ev = await self.event_creation_handler.handle_new_client_event(
requester=creator,
event=event,
context=context,
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
# Allow these events to be sent even if the user is shadow-banned to
# allow the room creation to complete.
(
sent_event,
last_stream_id,
) = await self.event_creation_handler.create_and_send_nonmember_event(
creator,
event,
ratelimit=False,
ignore_shadow_ban=True,
# Note: we don't pass state_event_ids here because this triggers
# an additional query per event to look them up from the events table.
prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
depth=depth,
)
last_sent_event_id = ev.event_id
last_sent_event_id = sent_event.event_id
depth += 1
# we know it was persisted, so must have a stream ordering
assert ev.internal_metadata.stream_ordering
return ev.internal_metadata.stream_ordering
return last_stream_id
try:
config = self._presets_dict[preset_config]
@@ -1129,15 +1102,9 @@ class RoomCreationHandler:
)
creation_content.update({"creator": creator_id})
creation_event = await create_event(
EventTypes.Create,
creation_content,
)
creation_context = await self.state.compute_event_context(creation_event)
await send(etype=EventTypes.Create, content=creation_content)
logger.debug("Sending %s in new room", EventTypes.Member)
await send(creation_event, creation_context, creator)
# Room create event must exist at this point
assert last_sent_event_id is not None
member_event_id, _ = await self.room_member_handler.update_membership(
@@ -1151,23 +1118,15 @@ class RoomCreationHandler:
prev_event_ids=[last_sent_event_id],
depth=depth,
)
# last_sent_event_id = member_event_id
prev_event = [member_event_id]
# update the depth and state map here, as these are otherwise updated in
# 'create_event'; the membership event has been created through a different
# code path
depth += 1
state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
last_sent_event_id = member_event_id
# We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None:
power_event = await create_event(EventTypes.PowerLevels, pl_content)
power_context = await self.state.compute_event_context(power_event)
current_state_group = power_context._state_group
await send(power_event, power_context, creator)
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=pl_content
)
else:
power_level_content: JsonDict = {
"users": {creator_id: 100},
@@ -1210,92 +1169,48 @@ class RoomCreationHandler:
# apply those.
if power_level_content_override:
power_level_content.update(power_level_content_override)
pl_event = await create_event(
EventTypes.PowerLevels,
power_level_content,
)
pl_context = await self.state.compute_event_context(pl_event)
current_state_group = pl_context._state_group
await send(pl_event, pl_context, creator)
events_to_send = []
last_sent_stream_id = await send(
etype=EventTypes.PowerLevels, content=power_level_content
)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event = await create_event(
EventTypes.CanonicalAlias,
{"alias": room_alias.to_string()},
last_sent_stream_id = await send(
etype=EventTypes.CanonicalAlias,
content={"alias": room_alias.to_string()},
)
assert current_state_group is not None
room_alias_context = await self.state.compute_event_context_for_batched(
room_alias_event, state_map, current_state_group
)
current_state_group = room_alias_context._state_group
events_to_send.append((room_alias_event, room_alias_context))
if (EventTypes.JoinRules, "") not in initial_state:
join_rules_event = await create_event(
EventTypes.JoinRules,
{"join_rule": config["join_rules"]},
last_sent_stream_id = await send(
etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
)
assert current_state_group is not None
join_rules_context = await self.state.compute_event_context_for_batched(
join_rules_event, state_map, current_state_group
)
current_state_group = join_rules_context._state_group
events_to_send.append((join_rules_event, join_rules_context))
if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
visibility_event = await create_event(
EventTypes.RoomHistoryVisibility,
{"history_visibility": config["history_visibility"]},
last_sent_stream_id = await send(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]},
)
assert current_state_group is not None
visibility_context = await self.state.compute_event_context_for_batched(
visibility_event, state_map, current_state_group
)
current_state_group = visibility_context._state_group
events_to_send.append((visibility_event, visibility_context))
if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state:
guest_access_event = await create_event(
EventTypes.GuestAccess,
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
last_sent_stream_id = await send(
etype=EventTypes.GuestAccess,
content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
)
assert current_state_group is not None
guest_access_context = (
await self.state.compute_event_context_for_batched(
guest_access_event, state_map, current_state_group
)
)
current_state_group = guest_access_context._state_group
events_to_send.append((guest_access_event, guest_access_context))
for (etype, state_key), content in initial_state.items():
event = await create_event(etype, content, state_key=state_key)
assert current_state_group is not None
context = await self.state.compute_event_context_for_batched(
event, state_map, current_state_group
last_sent_stream_id = await send(
etype=etype, state_key=state_key, content=content
)
current_state_group = context._state_group
events_to_send.append((event, context))
if config["encrypted"]:
encryption_event = await create_event(
EventTypes.RoomEncryption,
{"algorithm": RoomEncryptionAlgorithms.DEFAULT},
last_sent_stream_id = await send(
etype=EventTypes.RoomEncryption,
state_key="",
content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
)
assert current_state_group is not None
encryption_context = await self.state.compute_event_context_for_batched(
encryption_event, state_map, current_state_group
)
events_to_send.append((encryption_event, encryption_context))
last_event = await self.event_creation_handler.handle_create_room_events(
creator, events_to_send
)
assert last_event.internal_metadata.stream_ordering is not None
return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
return last_sent_stream_id, last_sent_event_id, depth
def _generate_room_id(self) -> str:
"""Generates a random room ID.

synapse/http/site.py

@@ -72,12 +72,10 @@ class SynapseRequest(Request):
site: "SynapseSite",
*args: Any,
max_request_body_size: int = 1024,
request_id_header: Optional[str] = None,
**kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.request_id_header = request_id_header
self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests
@@ -174,14 +172,7 @@ class SynapseRequest(Request):
self._opentracing_span = span
def get_request_id(self) -> str:
request_id_value = None
if self.request_id_header:
request_id_value = self.getHeader(self.request_id_header)
if request_id_value is None:
request_id_value = str(self.request_seq)
return "%s-%s" % (self.get_method(), request_id_value)
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if the URI
@@ -620,15 +611,12 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
request_id_header = config.http_options.request_id_header
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
request_id_header=request_id_header,
)
self.requestFactory = request_factory # type: ignore

synapse/replication/http/__init__.py

@@ -25,7 +25,6 @@ from synapse.replication.http import (
push,
register,
send_event,
send_events,
state,
streams,
)
@@ -44,7 +43,6 @@ class ReplicationRestResource(JsonResource):
def register_servlets(self, hs: "HomeServer") -> None:
send_event.register_servlets(hs, self)
send_events.register_servlets(hs, self)
federation.register_servlets(hs, self)
presence.register_servlets(hs, self)
membership.register_servlets(hs, self)

synapse/replication/http/send_events.py

@@ -1,165 +0,0 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, List, Tuple
from twisted.web.server import Request
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase, make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.http.server import HttpServer
from synapse.http.servlet import parse_json_object_from_request
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester
from synapse.util.metrics import Measure
if TYPE_CHECKING:
from synapse.server import HomeServer
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
class ReplicationSendEventsRestServlet(ReplicationEndpoint):
"""Handles batches of newly created events on workers, including persisting and
notifying.
The API looks like:
POST /_synapse/replication/send_events/:txn_id
{
"events": [{
"event": { .. serialized event .. },
"room_version": .., // "1", "2", "3", etc: the version of the room
// containing the event
"event_format_version": .., // 1,2,3 etc: the event format version
"internal_metadata": { .. serialized internal_metadata .. },
"outlier": true|false,
"rejected_reason": .., // The event.rejected_reason field
"context": { .. serialized event context .. },
"requester": { .. serialized requester .. },
"ratelimit": true,
}]
}
200 OK
{ "stream_id": 12345, "event_id": "$abcdef..." }
Responds with a 409 when a `PartialStateConflictError` is raised due to an event
context that needs to be recomputed due to the un-partial stating of a room.
"""
NAME = "send_events"
PATH_ARGS = ()
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
@staticmethod
async def _serialize_payload( # type: ignore[override]
store: "DataStore",
events_and_ctx: List[Tuple[EventBase, EventContext]],
requester: Requester,
ratelimit: bool,
) -> JsonDict:
"""
Args:
store
requester
events_and_ctx
ratelimit
"""
serialized_events = []
for event, context in events_and_ctx:
serialized_context = await context.serialize(event, store)
serialized_event = {
"event": event.get_pdu_json(),
"room_version": event.room_version.identifier,
"event_format_version": event.format_version,
"internal_metadata": event.internal_metadata.get_dict(),
"outlier": event.internal_metadata.is_outlier(),
"rejected_reason": event.rejected_reason,
"context": serialized_context,
"requester": requester.serialize(),
"ratelimit": ratelimit,
}
serialized_events.append(serialized_event)
payload = {"events": serialized_events}
return payload
async def _handle_request( # type: ignore[override]
self, request: Request
) -> Tuple[int, JsonDict]:
with Measure(self.clock, "repl_send_events_parse"):
payload = parse_json_object_from_request(request)
events_and_ctx = []
events = payload["events"]
for event_payload in events:
event_dict = event_payload["event"]
room_ver = KNOWN_ROOM_VERSIONS[event_payload["room_version"]]
internal_metadata = event_payload["internal_metadata"]
rejected_reason = event_payload["rejected_reason"]
event = make_event_from_dict(
event_dict, room_ver, internal_metadata, rejected_reason
)
event.internal_metadata.outlier = event_payload["outlier"]
requester = Requester.deserialize(
self.store, event_payload["requester"]
)
context = EventContext.deserialize(
self._storage_controllers, event_payload["context"]
)
ratelimit = event_payload["ratelimit"]
events_and_ctx.append((event, context))
logger.info(
"Got batch of events to send, last ID of batch is: %s, sending into room: %s",
event.event_id,
event.room_id,
)
last_event = (
await self.event_creation_handler.persist_and_notify_batched_events(
requester, events_and_ctx, ratelimit
)
)
return (
200,
{
"stream_id": last_event.internal_metadata.stream_ordering,
"event_id": last_event.event_id,
},
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationSendEventsRestServlet(hs).register(http_server)

synapse/rest/admin/__init__.py

@@ -80,7 +80,6 @@ from synapse.rest.admin.users import (
SearchUsersRestServlet,
ShadowBanRestServlet,
UserAdminServlet,
UserByExternalId,
UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
@@ -276,7 +275,6 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ListDestinationsRestServlet(hs).register(http_server)
RoomMessagesRestServlet(hs).register(http_server)
RoomTimestampToEventRestServlet(hs).register(http_server)
UserByExternalId(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:

synapse/rest/admin/users.py

@@ -1156,30 +1156,3 @@ class AccountDataRestServlet(RestServlet):
"rooms": by_room_data,
},
}
class UserByExternalId(RestServlet):
"""Find a user based on an external ID from an auth provider"""
PATTERNS = admin_patterns(
"/auth_providers/(?P<provider>[^/]*)/users/(?P<external_id>[^/]*)"
)
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastores().main
async def on_GET(
self,
request: SynapseRequest,
provider: str,
external_id: str,
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
user_id = await self._store.get_user_by_external_id(provider, external_id)
if user_id is None:
raise NotFoundError("User not found")
return HTTPStatus.OK, {"user_id": user_id}

synapse/rest/client/account.py

@@ -19,7 +19,6 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal
from twisted.web.server import Request
@@ -44,7 +43,6 @@ from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.rest.client.models import (
AuthenticationData,
ClientSecretStr,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
@@ -629,11 +627,6 @@ class ThreepidAddRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData] = None
client_secret: ClientSecretStr
sid: StrictStr
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
@@ -643,17 +636,22 @@ class ThreepidAddRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_and_validate_json_object_from_request(request, self.PostBody)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["client_secret", "sid"])
sid = body["sid"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body.dict(exclude_unset=True),
body,
"add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
body.client_secret, body.sid
client_secret, sid
)
if validation_session:
await self.auth_handler.add_threepid(
@@ -678,20 +676,23 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
class PostBody(RequestBodyModel):
client_secret: ClientSecretStr
id_access_token: StrictStr
id_server: StrictStr
sid: StrictStr
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_and_validate_json_object_from_request(request, self.PostBody)
body = parse_json_object_from_request(request)
assert_params_in_dict(
body, ["id_server", "sid", "id_access_token", "client_secret"]
)
id_server = body["id_server"]
sid = body["sid"]
id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
await self.identity_handler.bind_threepid(
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
client_secret, sid, user_id, id_server, id_access_token
)
return 200, {}
@@ -707,27 +708,23 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
body = parse_and_validate_json_object_from_request(request, self.PostBody)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
medium = body.get("medium")
address = body.get("address")
id_server = body.get("id_server")
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{
"address": body.address,
"medium": body.medium,
"id_server": body.id_server,
},
{"address": address, "medium": medium, "id_server": id_server},
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
@@ -741,25 +738,21 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
body = parse_and_validate_json_object_from_request(request, self.PostBody)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
ret = await self.auth_handler.delete_threepid(
user_id, body.medium, body.address, body.id_server
user_id, body["medium"], body["address"], body.get("id_server")
)
except Exception:
# NB. This endpoint should succeed if there is nothing to

View File

@@ -36,20 +36,18 @@ class AuthenticationData(RequestBodyModel):
type: Optional[StrictStr] = None
if TYPE_CHECKING:
ClientSecretStr = StrictStr
else:
# See also assert_valid_client_secret()
ClientSecretStr = constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=1,
max_length=255,
strict=True,
)
class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
# See also assert_valid_client_secret()
client_secret: constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=0,
max_length=255,
strict=True,
)
class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretStr
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
@@ -64,7 +62,7 @@ class ThreepidRequestTokenBody(RequestBodyModel):
return token
class EmailRequestTokenBody(ThreepidRequestTokenBody):
class EmailRequestTokenBody(ThreePidRequestTokenBody):
email: StrictStr
# Canonicalise the email address. The addresses are all stored canonicalised
@@ -82,6 +80,6 @@ else:
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr
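A hedged note on the constr-based ClientSecretStr above: pydantic v1 checks `regex` with `re.match`, which anchors only at the start of the string, so this single-character class effectively constrains just the first character; length is enforced separately by min_length/max_length, hence the pointer to assert_valid_client_secret(). A small sketch:

from pydantic import BaseModel, ValidationError, constr

class Body(BaseModel):
    # Same constraints as ClientSecretStr above.
    client_secret: constr(
        regex="[0-9a-zA-Z.=_-]", min_length=1, max_length=255, strict=True
    )

Body.parse_obj({"client_secret": "good.secret_123"})  # accepted
try:
    Body.parse_obj({"client_secret": ""})  # rejected: shorter than min_length
except ValidationError:
    pass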

View File

@@ -282,6 +282,7 @@ class StateHandler:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
assert not event.internal_metadata.is_outlier()
#
@@ -332,7 +333,6 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
# complete state here.
entry = await self.resolve_state_groups_for_events(
event.room_id,
event.prev_event_ids(),
@@ -420,69 +420,6 @@ class StateHandler:
partial_state=partial_state,
)
async def compute_event_context_for_batched(
self,
event: EventBase,
state_ids_before_event: StateMap[str],
current_state_group: int,
) -> EventContext:
"""
Generate an event context for an event that has not yet been persisted to the
database. Intended for use with events that are created to be persisted in a batch.
Args:
event: the event the context is being computed for
state_ids_before_event: a state map consisting of the state ids of the events
created prior to this event.
current_state_group: the current state group before the event.
"""
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
state_group_before_event = current_state_group
# if the event is not state, we are set
if not event.is_state():
return EventContext.with_state(
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
state_group=state_group_before_event,
state_delta_due_to_event={},
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=False,
)
# otherwise, we'll need to create a new state group for after the event
key = (event.type, event.state_key)
if state_ids_before_event is not None:
replaces = state_ids_before_event.get(key)
if replaces and replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
delta_ids = {key: event.event_id}
state_group_after_event = (
await self._state_storage_controller.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
delta_ids=delta_ids,
current_state_ids=None,
)
)
return EventContext.with_state(
storage=self._storage_controllers,
state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
state_delta_due_to_event=delta_ids,
prev_group=state_group_before_event,
delta_ids=delta_ids,
partial_state=False,
)
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True

View File

@@ -577,21 +577,6 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
if event.rejected_reason is not None:
# Do not admit previously rejected events into state.
# TODO: This isn't spec compliant. Events that were previously rejected due
# to failing auth checks at their state, but pass auth checks during
# state resolution should be accepted. Synapse does not handle the
# change of rejection status well, so we preserve the previous
# rejection status for now.
#
# Note that events rejected for non-state reasons, such as having the
# wrong auth events, should remain rejected.
#
# https://spec.matrix.org/v1.2/rooms/v9/#rejected-events
# https://github.com/matrix-org/synapse/issues/13797
continue
try:
event_auth.check_state_dependent_auth_rules(
event,

View File

@@ -533,7 +533,6 @@ class BackgroundUpdater:
index_name: name of index to add
table: table to add index to
columns: columns/expressions to include in index
where_clause: A WHERE clause to specify a partial unique index.
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)

View File

@@ -1191,7 +1191,6 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1204,7 +1203,6 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
Returns:
@@ -1215,12 +1213,7 @@ class DatabasePool:
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
txn, table, keyvalues, values, insertion_values=insertion_values
)
else:
return self.simple_upsert_txn_emulated(
@@ -1229,7 +1222,6 @@ class DatabasePool:
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
lock=lock,
)
@@ -1240,7 +1232,6 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1249,7 +1240,6 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1269,17 +1259,14 @@ class DatabasePool:
else:
return "%s = ?" % (key,)
# Generate a where clause of each keyvalue and optionally the provided
# index predicate.
where = [_getwhere(k) for k in keyvalues]
if where_clause:
where.append(where_clause)
if not values:
# If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
" AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.fetchall():
@@ -1290,7 +1277,7 @@ class DatabasePool:
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join(where),
" AND ".join(_getwhere(k) for k in keyvalues),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@@ -1320,7 +1307,6 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
Use the native UPSERT functionality in PostgreSQL.
@@ -1330,7 +1316,6 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1346,12 +1331,11 @@ class DatabasePool:
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
f"WHERE {where_clause}" if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
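As an illustration, the SQL string the native-upsert branch above builds once the where_clause support is removed (the table and column names here are hypothetical):

table = "user_ips"
keyvalues = {"user_id": "@alice:example.com", "device_id": "ADEVICE"}
values = {"last_seen": 1234}

allvalues = {**keyvalues, **values}
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
    table,
    ", ".join(allvalues),
    ", ".join("?" for _ in allvalues),
    ", ".join(keyvalues),
    latter,
)
# sql == "INSERT INTO user_ips (user_id, device_id, last_seen) VALUES (?, ?, ?)
#         ON CONFLICT (user_id, device_id) DO UPDATE SET last_seen=EXCLUDED.last_seen"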

View File

@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event.event_id)
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(

View File

@@ -1294,8 +1294,10 @@ class PersistEventsStore:
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Remove any existing cache entries for the event_ids
self.store.invalidate_get_event_cache_by_event_id_after_txn(
txn, event.event_id
)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1703,7 +1705,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
self.store.invalidate_get_event_cache_by_event_id_after_txn(txn, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))

View File

@@ -80,6 +80,7 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dual_lookup_cache import DualLookupCache
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -245,6 +246,8 @@ class EventsWorkerStore(SQLBaseStore):
] = AsyncLruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
cache_type=DualLookupCache,
dual_lookup_secondary_key_function=lambda v: (v.event.room_id,),
)
# Map from event ID to a deferred that will result in a map from event
@@ -733,7 +736,7 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def invalidate_get_event_cache_after_txn(
def invalidate_get_event_cache_by_event_id_after_txn(
self, txn: LoggingTransaction, event_id: str
) -> None:
"""
@@ -747,10 +750,31 @@ class EventsWorkerStore(SQLBaseStore):
event_id: the event ID to be invalidated from caches
"""
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
txn.call_after(self._invalidate_local_get_event_cache, event_id)
txn.async_call_after(
self._invalidate_async_get_event_cache_by_event_id, event_id
)
txn.call_after(self._invalidate_local_get_event_cache_by_event_id, event_id)
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
def invalidate_get_event_cache_by_room_id_after_txn(
self, txn: LoggingTransaction, room_id: str
) -> None:
"""
Prepares a database transaction to invalidate the get-event cache for a given
room ID once the transaction executes successfully. This is achieved by
attaching two callbacks to the transaction: one to invalidate the async cache
and one for the in-memory sync cache (importantly, called in that order).
Arguments:
txn: the database transaction to attach the callbacks to.
room_id: the room ID to invalidate all associated event caches for.
"""
txn.async_call_after(self._invalidate_async_get_event_cache_by_room_id, room_id)
txn.call_after(self._invalidate_local_get_event_cache_by_room_id, room_id)
async def _invalidate_async_get_event_cache_by_event_id(
self, event_id: str
) -> None:
"""
Invalidates an event in the asynchronous get event cache, which may be remote.
@@ -760,7 +784,18 @@ class EventsWorkerStore(SQLBaseStore):
await self._get_event_cache.invalidate((event_id,))
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
async def _invalidate_async_get_event_cache_by_room_id(self, room_id: str) -> None:
"""
Invalidates all events associated with a given room in the asynchronous get event
cache, which may be remote.
Arguments:
room_id: the room ID to invalidate associated events of.
"""
await self._get_event_cache.invalidate((room_id,))
def _invalidate_local_get_event_cache_by_event_id(self, event_id: str) -> None:
"""
Invalidates an event in local in-memory get event caches.
@@ -772,6 +807,18 @@ class EventsWorkerStore(SQLBaseStore):
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
def _invalidate_local_get_event_cache_by_room_id(self, room_id: str) -> None:
"""
Invalidates all events associated with a given room ID in local in-memory
get event caches.
Arguments:
room_id: the room ID to invalidate events of.
"""
self._get_event_cache.invalidate_local((room_id,))
# TODO: invalidate _event_ref and _current_event_fetches. How?
async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
@@ -2284,7 +2331,7 @@ class EventsWorkerStore(SQLBaseStore):
updatevalues={"rejection_reason": rejection_reason},
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
# TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
# call '_send_invalidation_to_replication', but we actually need the other

View File

@@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self.invalidate_get_event_cache_after_txn(txn, event_id)
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
logger.info("[purge] done")
@@ -419,7 +419,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"event_forward_extremities",
"event_push_actions",
"event_search",
"event_failed_pull_attempts",
"partial_state_events",
"events",
"federation_inbound_events_staging",
@@ -442,10 +441,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"insertion_events",
"insertion_event_extremities",
"insertion_event_edges",
"batch_events",
"room_account_data",
"room_tags",
# "rooms" happens last, to keep the foreign keys in the other tables
@@ -483,6 +478,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# XXX: as with purge_history, this is racy, but no worse than other races
# that already exist.
self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
self._invalidate_local_get_event_cache_by_room_id(room_id)
logger.info("[purge] done")

View File

@@ -83,8 +83,6 @@ Changes in SCHEMA_VERSION = 73;
event_push_summary, receipts_linearized, and receipts_graph.
- Add table `event_failed_pull_attempts` to keep track when we fail to pull
events over federation.
- Add indexes to various tables (`event_failed_pull_attempts`, `insertion_events`,
`batch_events`) to make it easy to delete all associated rows when purging a room.
"""

View File

@@ -1,22 +0,0 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add index so we can easily purge all rows from a given `room_id`
CREATE INDEX IF NOT EXISTS event_failed_pull_attempts_room_id ON event_failed_pull_attempts(room_id);
-- MSC2716 related tables:
-- Add indexes so we can easily purge all rows from a given `room_id`
CREATE INDEX IF NOT EXISTS insertion_events_room_id ON insertion_events(room_id);
CREATE INDEX IF NOT EXISTS batch_events_room_id ON batch_events(room_id);
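For reference, a sketch of the purge-time deletes these indexes (removed above) were added to accelerate; the connection parameters and room ID are illustrative:

import psycopg2

db_conn = psycopg2.connect(
    user="postgres", host="localhost", password="postgres", dbname="synapse"
)
with db_conn.cursor() as cur:
    # Without a room_id index, each of these deletes is a full table scan.
    for tbl in ("event_failed_pull_attempts", "insertion_events", "batch_events"):
        cur.execute(f"DELETE FROM {tbl} WHERE room_id = %s", ("!room:example.com",))
db_conn.commit()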

View File

@@ -0,0 +1,238 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
Callable,
Dict,
Generic,
ItemsView,
List,
Optional,
TypeVar,
Union,
ValuesView,
)
# Used to distinguish between a value not existing in a map and the value being 'None'.
SENTINEL = object()
# The type of the primary dict's keys.
PKT = TypeVar("PKT")
# The type of the primary dict's values.
PVT = TypeVar("PVT")
# The type of the secondary dict's keys.
SKT = TypeVar("SKT")
logger = logging.getLogger(__name__)
class SecondarySet(set):
"""
Used to differentiate between an entry in the secondary_dict, and a set stored
in the primary_dict. This is necessary as pop() can return either.
"""
class DualLookupCache(Generic[PKT, PVT, SKT]):
"""
A backing store for LruCache that supports multiple entry points.
Allows subsets of data to be deleted efficiently without requiring extra
information to query.
The data structure is two dictionaries:
* primary_dict containing a mapping of primary_key -> value.
* secondary_dict containing a mapping of secondary_key -> set of primary_key.
On insert, a mapping in the primary_dict must be created. A mapping in the
secondary_dict from a secondary_key to (a set containing) the same
primary_key will be made. The secondary_key
must be derived from the inserted value via a lambda function provided at cache
initialisation. This is so invalidated entries in the primary_dict may automatically
invalidate those in the secondary_dict. The secondary_key may be associated with one
or more primary_keys.
This creates an interface which allows for efficient lookups of a value given
a primary_key, as well as efficient invalidation of a subset of mappings in the
primary_dict given a secondary_key. A primary_key may not be associated with more
than one secondary_key.
As a worked example, consider storing a cache of room events. We could configure
the cache to store mappings between EventIDs and EventBase in the primary_dict,
while storing a mapping between room IDs and event IDs as the secondary_dict:
primary_dict: EventID -> EventBase
secondary_dict: RoomID -> {EventID, EventID, ...}
This would be efficient for the following operations:
* Given an EventID, look up the associated EventBase, and thus the room ID.
* Given a RoomID, invalidate all primary_dict entries for events in that room.
Since this is intended as a backing store for LruCache, when it comes time to evict
an entry from the primary_dict (EventID -> EventBase), the secondary_key can be
derived from the provided lambda function:
secondary_key = lambda event_base: event_base.room_id
The EventID set under room_id would then have the appropriate EventID entry evicted.
"""
def __init__(self, secondary_key_function: Callable[[PVT], SKT]) -> None:
self._primary_dict: Dict[PKT, PVT] = {}
self._secondary_dict: Dict[SKT, SecondarySet] = {}
self._secondary_key_function = secondary_key_function
def __setitem__(self, key: PKT, value: PVT) -> None:
self.set(key, value)
def __contains__(self, key: PKT) -> bool:
return key in self._primary_dict
def set(self, key: PKT, value: PVT) -> None:
"""Add an entry to the cache.
Will add an entry to the primary_dict consisting of key->value, as well as append
to the set referred to by secondary_key_function(value) in the secondary_dict.
Args:
key: The key for a new mapping in primary_dict.
value: The value for a new mapping in primary_dict.
"""
# Create an entry in the primary_dict.
self._primary_dict[key] = value
# Derive the secondary_key to use from the given primary_value.
secondary_key = self._secondary_key_function(value)
# TODO: If the lambda function resolves to None, don't insert an entry?
# And create a mapping in the secondary_dict to a set containing the
# primary_key, creating the set if necessary.
secondary_key_set = self._secondary_dict.setdefault(
secondary_key, SecondarySet()
)
secondary_key_set.add(key)
logger.info("*** Insert into primary_dict: %s: %s", key, value)
logger.info("*** Insert into secondary_dict: %s: %s", secondary_key, key)
def get(self, key: PKT, default: Optional[PVT] = None) -> Optional[PVT]:
"""Retrieve a value from the cache if it exists. If not, return the default
value.
This method simply pulls entries from the primary_dict.
# TODO: Any use cases for externally getting entries from the secondary_dict?
Args:
key: The key to search the cache for.
default: The default value to return if the given key is not found.
Returns:
The value referenced by the given key, if it exists in the cache. If not,
the value of `default` will be returned.
"""
logger.info("*** Retrieving key from primary_dict: %s", key)
return self._primary_dict.get(key, default)
def clear(self) -> None:
"""Evicts all entries from the cache."""
self._primary_dict.clear()
self._secondary_dict.clear()
def pop(
self, key: Union[PKT, SKT], default: Optional[Union[List[PVT], PVT]] = None
) -> Optional[Union[List[PVT], PVT]]:
"""Remove an entry from either the primary_dict or secondary_dict.
The primary_dict is checked first for the key. If an entry is found, it is
removed from the primary_dict and returned.
If no entry in the primary_dict exists, then the secondary_dict is checked.
If an entry exists, all associated entries in the primary_dict will be
deleted, and their popped values returned from this function as a list.
Args:
key: A key to drop from either the primary_dict or secondary_dict.
default: The default value if the key does not exist in either dict.
Returns:
Either a matched value from the primary_dict, a list of the values popped
via the secondary_dict, or `default` if the key is found in neither.
"""
# Attempt to remove from the primary_dict first.
primary_value = self._primary_dict.pop(key, SENTINEL)
if primary_value is not SENTINEL:
# We found a value in the primary_dict. Remove it from the corresponding
# entry in the secondary_dict, and then return it.
logger.info(
"*** Popped entry from primary_dict: %s: %s", key, primary_value
)
# Derive the secondary_key from the primary_value
secondary_key = self._secondary_key_function(primary_value)
# Pop the entry from the secondary_dict
secondary_key_set = self._secondary_dict[secondary_key]
if len(secondary_key_set) > 1:
# Delete just the set entry for the given key.
secondary_key_set.remove(key)
logger.info(
"*** Popping from secondary_dict: %s: %s", secondary_key, key
)
else:
# Delete the entire set referenced by the secondary_key, as it only
# has one entry.
del self._secondary_dict[secondary_key]
logger.info("*** Popping from secondary_dict: %s", secondary_key)
return primary_value
# There was no matching value in the primary_dict. Attempt the secondary_dict.
primary_key_set = self._secondary_dict.pop(key, SENTINEL)
if primary_key_set is not SENTINEL:
# We found a set in the secondary_dict.
logger.info(
"*** Found '%s' in secondary_dict: %s: ",
key,
primary_key_set,
)
popped_primary_dict_values: List[PVT] = []
# We found an entry in the secondary_dict. Delete all related entries in the
# primary_dict.
logger.info(
"*** Found key in secondary_dict to pop: %s. "
"Popping primary_dict entries",
key,
)
for primary_key in primary_key_set:
primary_value = self._primary_dict.pop(primary_key)
logger.info("*** Popping entry from primary_dict: %s - %s", primary_key, primary_value)
logger.info("*** primary_dict: %s", self._primary_dict)
popped_primary_dict_values.append(primary_value)
# Return the values that were popped from the primary_dict.
return popped_primary_dict_values
# No match in either dict.
return default
def values(self) -> ValuesView:
return self._primary_dict.values()
def items(self) -> ItemsView:
return self._primary_dict.items()
def __len__(self) -> int:
return len(self._primary_dict)
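A minimal usage sketch of DualLookupCache as defined above (the Event namedtuple is a stand-in for Synapse's EventBase; the WIP logger.info calls will emit noise if logging is configured):

from collections import namedtuple

Event = namedtuple("Event", ["event_id", "room_id"])

cache = DualLookupCache(secondary_key_function=lambda ev: ev.room_id)
cache["$a"] = Event("$a", "!room1")
cache["$b"] = Event("$b", "!room1")
cache["$c"] = Event("$c", "!room2")

assert cache.get("$a").room_id == "!room1"  # primary-key lookup
popped = cache.pop("!room1")  # secondary-key pop: drops every !room1 event
assert {e.event_id for e in popped} == {"$a", "$b"}
assert len(cache) == 1  # only the !room2 event remains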

View File

@@ -46,8 +46,10 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.dual_lookup_cache import DualLookupCache, SecondarySet
from synapse.util.caches.treecache import (
TreeCache,
TreeCacheNode,
iterate_tree_cache_entry,
iterate_tree_cache_items,
)
@@ -375,12 +377,13 @@ class LruCache(Generic[KT, VT]):
self,
max_size: int,
cache_name: Optional[str] = None,
cache_type: Type[Union[dict, TreeCache]] = dict,
cache_type: Type[Union[dict, TreeCache, DualLookupCache]] = dict,
size_callback: Optional[Callable[[VT], int]] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
clock: Optional[Clock] = None,
prune_unread_entries: bool = True,
dual_lookup_secondary_key_function: Optional[Callable[[Any], Any]] = None,
):
"""
Args:
@@ -411,6 +414,10 @@ class LruCache(Generic[KT, VT]):
prune_unread_entries: If True, cache entries that haven't been read recently
will be evicted from the cache in the background. Set to False to
opt-out of this behaviour.
# TODO: At this point we should probably just pass an initialised cache type
# to LruCache, no?
dual_lookup_secondary_key_function: when `cache_type` is `DualLookupCache`, a
function that derives the secondary key from a cached value; ignored for
other cache types.
"""
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think it's still `Optional`.
@@ -419,7 +426,30 @@ class LruCache(Generic[KT, VT]):
else:
real_clock = clock
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
# TODO: I've had to make this ugly to appease mypy :(
# Perhaps initialise the backing cache and then pass to LruCache?
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache, DualLookupCache]
if cache_type is DualLookupCache:
# The dual_lookup_secondary_key_function is a function that's intended to
# extract a key from the value in the cache. Since we wrap values given to
# us in a _Node object, this function will actually operate on a _Node,
# instead of directly on the object type callers are expecting.
#
# Thus, we wrap the function given by the caller in another one that
# extracts the value from the _Node, before then handing it off to the
# given function for processing.
def key_function_wrapper(node: Any) -> Any:
assert dual_lookup_secondary_key_function is not None
return dual_lookup_secondary_key_function(node.value)
cache = DualLookupCache(
secondary_key_function=key_function_wrapper,
)
elif cache_type is TreeCache:
cache = TreeCache()
else:
cache = {}
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
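Tying the pieces together, a sketch of constructing an LruCache backed by a DualLookupCache, mirroring the event-cache wiring earlier in this diff (the size and the value shape are illustrative):

event_cache = LruCache(
    max_size=10_000,
    cache_type=DualLookupCache,
    # Receives the cached value (not the _Node wrapper, thanks to
    # key_function_wrapper above) and derives the secondary key.
    dual_lookup_secondary_key_function=lambda v: (v.room_id,),
)
# Keys are event IDs; popping a (room_id,) key evicts every cached event in
# that room, via the SecondarySet branch in cache_pop in the next hunk.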
@@ -722,13 +752,25 @@ class LruCache(Generic[KT, VT]):
may be of lower cardinality than the TreeCache - in which case the whole
subtree is deleted.
"""
popped = cache.pop(key, None)
if popped is None:
# Remove an entry from the cache.
# In the case of a 'dict' cache type, we're just removing an entry from the
# dict. For a TreeCache, we're removing a subtree which has children. For a
# DualLookupCache, popping via a secondary key yields a SecondarySet of nodes.
popped_entry = cache.pop(key, None)
if popped_entry is None:
return
# for each deleted node, we now need to remove it from the linked list
# and run its callbacks.
for leaf in iterate_tree_cache_entry(popped):
delete_node(leaf)
if isinstance(popped_entry, TreeCacheNode):
# We've popped a subtree from a TreeCache - now we need to clean up
# each child node.
for leaf in iterate_tree_cache_entry(popped_entry):
# For each deleted child node, we remove it from the linked list and
# run its callbacks.
delete_node(leaf)
elif isinstance(popped_entry, SecondarySet):
for leaf in popped_entry:
delete_node(leaf)
else:
delete_node(popped_entry)
@synchronized
def cache_clear() -> None:

View File

@@ -11,23 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest import mock
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
check_state_dependent_auth_rules,
check_state_independent_auth_rules,
)
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -458,393 +449,3 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
main_store.get_event(pulled_event.event_id, allow_none=True)
)
self.assertIsNotNone(persisted, "pulled event was not persisted at all")
def test_process_pulled_event_with_rejected_missing_state(self) -> None:
"""Ensure that we correctly handle pulled events with missing state containing a
rejected state event
In this test, we pretend we are processing a "pulled" event (eg, via backfill
or get_missing_events). The pulled event has a prev_event we haven't previously
seen, so the server requests the state at that prev_event. We expect the server
to make a /state request.
We simulate a remote server whose /state includes a rejected kick event for a
local user. Notably, the kick event is rejected only because it cites a rejected
auth event and would otherwise be accepted based on the room state. During state
resolution, we re-run auth and can potentially introduce such rejected events
into the state if we are not careful.
We check that the pulled event is correctly persisted, and that the state
afterwards does not include the rejected kick.
"""
# The DAG we are testing looks like:
#
# ...
# |
# v
# remote admin user joins
# | |
# +-------+ +-------+
# | |
# | rejected power levels
# | from remote server
# | |
# | v
# | rejected kick of local user
# v from remote server
# new power levels |
# | v
# | missing event
# | from remote server
# | |
# +-------+ +-------+
# | |
# v v
# pulled event
# from remote server
#
# (arrows are in the opposite direction to prev_events.)
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
kermit_tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(
room_creator=kermit_user_id, tok=kermit_tok
)
room_version = self.get_success(main_store.get_room_version(room_id))
# Add another local user to the room. This user is going to be kicked in a
# rejected event.
bert_user_id = self.register_user("bert", "test")
bert_tok = self.login("bert", "test")
self.helper.join(room_id, user=bert_user_id, tok=bert_tok)
# Allow the remote user to kick bert.
# The remote user is going to send a rejected power levels event later on and we
# need state resolution to order it before another power levels event kermit is
# going to send later on. Hence we give both users the same power level, so that
# ties are broken by `origin_server_ts`.
self.helper.send_state(
room_id,
"m.room.power_levels",
{"users": {kermit_user_id: 100, OTHER_USER: 100}},
tok=kermit_tok,
)
# Add the remote user to the room.
other_member_event = self.get_success(
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
initial_state_map = self.get_success(
main_store.get_partial_current_state_ids(room_id)
)
create_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.create", "")])
)
bert_member_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.member", bert_user_id)])
)
power_levels_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.power_levels", "")])
)
# We now need a rejected state event that will fail
# `check_state_independent_auth_rules` but pass
# `check_state_dependent_auth_rules`.
# First, we create a power levels event that we pretend the remote server has
# accepted, but the local homeserver will reject.
next_depth = 100
next_timestamp = other_member_event.origin_server_ts + 100
rejected_power_levels_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.power_levels",
"state_key": "",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [other_member_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
initial_state_map[("m.room.power_levels", "")],
# The event will be rejected because of the duplicated auth
# event.
other_member_event.event_id,
other_member_event.event_id,
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": power_levels_event.content,
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
with LoggingContext("send_rejected_power_levels_event"):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME,
rejected_power_levels_event,
backfilled=False,
)
)
self.assertEqual(
self.get_success(
main_store.get_rejection_reason(
rejected_power_levels_event.event_id
)
),
"auth_error",
)
# Then we create a kick event for a local user that cites the rejected power
# levels event in its auth events. The kick event will be rejected solely
# because of the rejected auth event and would otherwise be accepted.
rejected_kick_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.member",
"state_key": bert_user_id,
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [rejected_power_levels_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
rejected_power_levels_event.event_id,
initial_state_map[("m.room.member", bert_user_id)],
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"membership": "leave"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# The kick event must fail the state-independent auth rules, but pass the
# state-dependent auth rules, so that it has a chance of making it through state
# resolution.
self.get_failure(
check_state_independent_auth_rules(main_store, rejected_kick_event),
AuthError,
)
check_state_dependent_auth_rules(
rejected_kick_event,
[create_event, power_levels_event, other_member_event, bert_member_event],
)
# The kick event must also win over the original member event during state
# resolution.
self.assertEqual(
self.get_success(
_mainline_sort(
self.clock,
room_id,
event_ids=[
bert_member_event.event_id,
rejected_kick_event.event_id,
],
resolved_power_event_id=power_levels_event.event_id,
event_map={
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=main_store,
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
"The rejected kick event will not be applied after bert's join event "
"during state resolution. The test setup is incorrect.",
)
with LoggingContext("send_rejected_kick_event"):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False
)
)
self.assertEqual(
self.get_success(
main_store.get_rejection_reason(rejected_kick_event.event_id)
),
"auth_error",
)
# We need another power levels event which will win over the rejected one during
# state resolution, otherwise we hit other issues where we end up with a
# rejected power levels event during state resolution.
self.reactor.advance(100) # ensure the `origin_server_ts` is larger
new_power_levels_event = self.get_success(
main_store.get_event(
self.helper.send_state(
room_id,
"m.room.power_levels",
{"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}},
tok=kermit_tok,
)["event_id"]
)
)
self.assertEqual(
self.get_success(
_reverse_topological_power_sort(
self.clock,
room_id,
event_ids=[
new_power_levels_event.event_id,
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=main_store,
full_conflicted_set=set(),
)
),
[rejected_power_levels_event.event_id, new_power_levels_event.event_id],
"The power levels events will not have the desired ordering during state "
"resolution. The test setup is incorrect.",
)
# Create a missing event, so that the local homeserver has to do a `/state` or
# `/state_ids` request to pull state from the remote homeserver.
missing_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.message",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [rejected_kick_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
initial_state_map[("m.room.power_levels", "")],
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"msgtype": "m.text", "body": "foo"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# The pulled event has two prev events, one of which is missing. We will make a
# `/state` or `/state_ids` request to the remote homeserver to ask it for the
# state before the missing prev event.
pulled_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.message",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [
new_power_levels_event.event_id,
missing_event.event_id,
],
"auth_events": [
initial_state_map[("m.room.create", "")],
new_power_levels_event.event_id,
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"msgtype": "m.text", "body": "bar"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# Prepare the response for the `/state` or `/state_ids` request.
# The remote server believes bert has been kicked, while the local server does
# not.
state_before_missing_event = self.get_success(
main_store.get_events_as_list(initial_state_map.values())
)
state_before_missing_event = [
event
for event in state_before_missing_event
if event.event_id != bert_member_event.event_id
]
state_before_missing_event.append(rejected_kick_event)
# We have to bump the clock a bit, to keep the retry logic in
# `FederationClient.get_pdu` happy
self.reactor.advance(60000)
with LoggingContext("send_pulled_event"):
async def get_event(
destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return {"pdus": [missing_event.get_pdu_json()]}
async def get_room_state_ids(
destination: str, room_id: str, event_id: str
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return {
"pdu_ids": [event.event_id for event in state_before_missing_event],
"auth_chain_ids": [],
}
async def get_room_state(
room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> StateRequestResponse:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return StateRequestResponse(
state=state_before_missing_event,
auth_events=[],
)
self.mock_federation_transport_client.get_event.side_effect = get_event
self.mock_federation_transport_client.get_room_state_ids.side_effect = (
get_room_state_ids
)
self.mock_federation_transport_client.get_room_state.side_effect = (
get_room_state
)
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=False
)
)
self.assertIsNone(
self.get_success(
main_store.get_rejection_reason(pulled_event.event_id)
),
"Pulled event was unexpectedly rejected, likely due to a problem with "
"the test setup.",
)
self.assertEqual(
{pulled_event.event_id},
self.get_success(
main_store.have_events_in_timeline([pulled_event.event_id])
),
"Pulled event was not persisted, likely due to a problem with the test "
"setup.",
)
# We must not accept rejected events into the room state, so we expect bert
# to not be kicked, even if the remote server believes so.
new_state_map = self.get_success(
main_store.get_partial_current_state_ids(room_id)
)
self.assertEqual(
new_state_map[("m.room.member", bert_user_id)],
bert_member_event.event_id,
"Rejected kick event unexpectedly became part of room state.",
)

View File

@@ -222,10 +222,9 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
)
self.assertEqual(len(alice_sync_result.joined), 1)
self.assertEqual(alice_sync_result.joined[0].room_id, room_id)
last_room_creation_event_ids = [
alice_sync_result.joined[0].timeline.events[-1].event_id,
alice_sync_result.joined[0].timeline.events[-2].event_id,
]
last_room_creation_event_id = (
alice_sync_result.joined[0].timeline.events[-1].event_id
)
# Eve, a ne'er-do-well, registers.
eve = self.register_user("eve", "password")
@@ -251,7 +250,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=MagicMock,
return_value=make_awaitable(last_room_creation_event_ids),
return_value=make_awaitable([last_room_creation_event_id]),
)
with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token)

View File

@@ -4140,90 +4140,3 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
{"b": 2},
channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
)
class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.get_success(
self.store.record_user_external_id(
"the-auth-provider", "the-external-id", self.other_user
)
)
self.get_success(
self.store.record_user_external_id(
"another-auth-provider", "a:complex@external/id", self.other_user
)
)
def test_no_auth(self) -> None:
"""Try to lookup a user without authentication."""
url = (
"/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
)
channel = self.make_request(
"GET",
url,
)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_binding_does_not_exist(self) -> None:
"""Tests that a lookup for an external ID that does not exist returns a 404"""
url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id"
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_success(self) -> None:
"""Tests a successful external ID lookup"""
url = (
"/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
)
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"user_id": self.other_user},
channel.json_body,
)
def test_success_urlencoded(self) -> None:
"""Tests a successful external ID lookup with an url-encoded ID"""
url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid"
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"user_id": self.other_user},
channel.json_body,
)

View File

@@ -11,37 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest as stdlib_unittest
import unittest
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from pydantic import ValidationError
from synapse.rest.client.models import EmailRequestTokenBody
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: Literal["email", "msisdn"]
def test_accepts_valid_medium_string(self) -> None:
"""Sanity check that Pydantic behaves sensibly with an enum-of-str
This is arguably more of a test of a class that inherits from str and Enum
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertEqual(model.medium, "email")
def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": "interpretive_dance"})
def test_rejects_invalid_medium_type(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": 123})
class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
class EmailRequestTokenBodyTestCase(unittest.TestCase):
base_request = {
"client_secret": "hunter2",
"email": "alice@wonderland.com",

View File

@@ -710,7 +710,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(36, channel.resource_usage.db_txn_count)
self.assertEqual(44, channel.resource_usage.db_txn_count)
def test_post_room_initial_state(self) -> None:
# POST with initial_state config key, expect new room id
@@ -723,7 +723,7 @@ class RoomsCreateTestCase(RoomBase):
self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
self.assertTrue("room_id" in channel.json_body)
assert channel.resource_usage is not None
self.assertEqual(39, channel.resource_usage.db_txn_count)
self.assertEqual(50, channel.resource_usage.db_txn_count)
def test_post_room_visibility_key(self) -> None:
# POST with visibility config key, expect new room id

View File

@@ -115,6 +115,5 @@ class PurgeTests(HomeserverTestCase):
)
# The events aren't found.
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)

View File

@@ -46,9 +46,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
user_id = UserID("us", "test")
our_user = create_requester(user_id)
room_creator = self.homeserver.get_room_creation_handler()
config = {"preset": "public_chat"}
self.room_id = self.get_success(
room_creator.create_room(our_user, config, ratelimit=False)
room_creator.create_room(
our_user, room_creator._presets_dict["public_chat"], ratelimit=False
)
)[0]["room_id"]
self.store = self.homeserver.get_datastores().main
@@ -98,8 +99,10 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
)
# Make sure we actually joined the room
res = self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))
assert "$join:test.serv" in res
self.assertEqual(
self.get_success(self.store.get_latest_event_ids_in_room(self.room_id))[0],
"$join:test.serv",
)
def test_cant_hide_direct_ancestors(self):
"""