1
0

Compare commits

..

10 Commits

Author SHA1 Message Date
David Robertson ee2fd939b0 Changelog 2022-09-16 14:19:06 +01:00
David Robertson 675b1d47d7 Try caching black 2022-09-16 14:18:17 +01:00
Quentin Gliech 74f60cec92 Add an admin API endpoint to find a user based on its external ID in an auth provider. (#13810) 2022-09-16 12:29:03 +00:00
reivilibre f7a77ad717 Update request log format documentation to mention the format used when the authenticated user is controlling another user. (#13794) 2022-09-16 11:48:41 +00:00
Sean Quah b73cbb8215 Avoid putting rejected events in room state (#13723)
Signed-off-by: Sean Quah <seanq@matrix.org>
2022-09-16 12:45:04 +01:00
Eric Eastwood 6986bcbf39 Document common fix of Poetry problems by removing egg-info (#13785)
`matrix_synapse.egg-info/`

Mentioned at https://matrix.to/#/!vcyiEtMVHIhWXcJAfl:sw1v.org/$aKy_IjrKwb70aTVZWeW_6zt0k7OIZ1YkyZpkP9uiRaM?via=matrix.org&via=element.io&via=beeper.com and many other places.
2022-09-15 16:28:03 -05:00
Eric Eastwood 5093cbf88d Be able to correlate timeouts in reverse-proxy layer in front of Synapse (pull request ID from header) (#13801)
Fix https://github.com/matrix-org/synapse/issues/13685

New config:

```diff
  listeners:
    - port: 8008
      tls: false
      type: http
      x_forwarded: true
+     request_id_header: "cf-ray"
      bind_addresses: ['::1', '127.0.0.1', '0.0.0.0']
```
2022-09-15 15:32:25 -05:00
Eric Eastwood 140af0cdb6 Record any exception when processing a pulled event (#13814)
Part of https://github.com/matrix-org/synapse/issues/13700 and https://github.com/matrix-org/synapse/issues/13356

Follow-up to https://github.com/matrix-org/synapse/pull/13589
2022-09-15 14:40:49 -05:00
Patrick Cloke b2b0c85279 Support providing an index predicate for upserts. (#13822)
This is useful to upsert against a table which has a unique
partial index while avoiding conflicts.
2022-09-15 18:28:48 +00:00
David Robertson 742f9f9d78 A third batch of Pydantic validation for rest/client/account.py (#13736) 2022-09-15 18:36:02 +01:00
36 changed files with 776 additions and 410 deletions
+1 -1
View File
@@ -45,7 +45,7 @@ jobs:
- run: scripts-dev/check_schema_delta.py --force-colors
lint:
uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@v1"
uses: "matrix-org/backend-meta/.github/workflows/python-poetry-ci.yml@dmr/try-caching-black"
with:
typechecking-extras: "all"
+1 -1
View File
@@ -1 +1 @@
Keep track when we attempt to backfill an event but fail so we can intelligently back-off in the future.
Keep track when we fail to process a pulled event over federation so we can intelligently back-off in the future.
+1
View File
@@ -0,0 +1 @@
Fix a long-standing bug where previously rejected events could end up in room state because they pass auth checks given the current state of the room.
+1
View File
@@ -0,0 +1 @@
Improve validation of request bodies for the following client-server API endpoints: [`/account/3pid/add`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidadd), [`/account/3pid/bind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidbind), [`/account/3pid/delete`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3piddelete) and [`/account/3pid/unbind`](https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3account3pidunbind).
+1
View File
@@ -0,0 +1 @@
Add docs for common fix of deleting the `matrix_synapse.egg-info/` directory for fixing Python dependency problems.
+1
View File
@@ -0,0 +1 @@
Update request log format documentation to mention the format used when the authenticated user is controlling another user.
+1
View File
@@ -0,0 +1 @@
Add `listeners[x].request_id_header` config to specify which request header to extract and use as the request ID in order to correlate requests from a reverse-proxy.
+1
View File
@@ -0,0 +1 @@
Add an admin API endpoint to find a user based on its external ID in an auth provider.
+1
View File
@@ -0,0 +1 @@
Keep track when we fail to process a pulled event over federation so we can intelligently back-off in the future.
+1
View File
@@ -0,0 +1 @@
Support providing an index predicate clause when doing upserts.
+1
View File
@@ -0,0 +1 @@
Dummy for testing CI speed.
+38
View File
@@ -1155,3 +1155,41 @@ GET /_synapse/admin/v1/username_available?username=$localpart
The request and response format is the same as the
[/_matrix/client/r0/register/available](https://matrix.org/docs/spec/client_server/r0.6.0#get-matrix-client-r0-register-available) API.
### Find a user based on their ID in an auth provider
The API is:
```
GET /_synapse/admin/v1/auth_providers/$provider/users/$external_id
```
When a user matched the given ID for the given provider, an HTTP code `200` with a response body like the following is returned:
```json
{
"user_id": "@hello:example.org"
}
```
**Parameters**
The following parameters should be set in the URL:
- `provider` - The ID of the authentication provider, as advertised by the [`GET /_matrix/client/v3/login`](https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3login) API in the `m.login.sso` authentication method.
- `external_id` - The user ID from the authentication provider. Usually corresponds to the `sub` claim for OIDC providers, or to the `uid` attestation for SAML2 providers.
The `external_id` may have characters that are not URL-safe (typically `/`, `:` or `@`), so it is advised to URL-encode those parameters.
**Errors**
Returns a `404` HTTP status code if no user was found, with a response body like this:
```json
{
"errcode":"M_NOT_FOUND",
"error":"User not found"
}
```
_Added in Synapse 1.68.0._
+27
View File
@@ -126,6 +126,23 @@ context of poetry's venv, without having to run `poetry shell` beforehand.
poetry install --extras all --remove-untracked
```
## ...delete everything and start over from scratch?
```shell
# Stop the current virtualenv if active
$ deactivate
# Remove all of the files from the current environment.
# Don't worry, even though it says "all", this will only
# remove the Poetry virtualenvs for the current project.
$ poetry env remove --all
# Reactivate Poetry shell to create the virtualenv again
$ poetry shell
# Install everything again
$ poetry install --extras all
```
## ...run a command in the `poetry` virtualenv?
Use `poetry run cmd args` when you need the python virtualenv context.
@@ -256,6 +273,16 @@ from PyPI. (This is what makes poetry seem slow when doing the first
`poetry install`.) Try `poetry cache list` and `poetry cache clear --all
<name of cache>` to see if that fixes things.
## Remove outdated egg-info
Delete the `matrix_synapse.egg-info/` directory from the root of your Synapse
install.
This stores some cached information about dependencies and often conflicts with
letting Poetry do the right thing.
## Try `--verbose` or `--dry-run` arguments.
Sometimes useful to see what poetry's internal logic is.
+4
View File
@@ -45,6 +45,10 @@ listens to traffic on localhost. (Do not change `bind_addresses` to `127.0.0.1`
when using a containerized Synapse, as that will prevent it from responding
to proxied traffic.)
Optionally, you can also set
[`request_id_header`](../usage/configuration/config_documentation.md#listeners)
so that the server extracts and re-uses the same request ID format that the
reverse proxy is using.
## Reverse-proxy configuration examples
+2 -2
View File
@@ -12,14 +12,14 @@ See the following for how to decode the dense data available from the default lo
| Part | Explanation |
| ----- | ------------ |
| AAAA | Timestamp request was logged (not recieved) |
| AAAA | Timestamp request was logged (not received) |
| BBBB | Logger name (`synapse.access.(http\|https).<tag>`, where 'tag' is defined in the `listeners` config section, normally the port) |
| CCCC | Line number in code |
| DDDD | Log Level |
| EEEE | Request Identifier (This identifier is shared by related log lines)|
| FFFF | Source IP (Or X-Forwarded-For if enabled) |
| GGGG | Server Port |
| HHHH | Federated Server or Local User making request (blank if unauthenticated or not supplied) |
| HHHH | Federated Server or Local User making request (blank if unauthenticated or not supplied).<br/>If this is of the form `@aaa:example.com|@bbb:example.com`, then that means that `@aaa:example.com` is authenticated but they are controlling `@bbb:example.com`, e.g. if `aaa` is controlling `bbb` [via the admin API](https://matrix-org.github.io/synapse/latest/admin_api/user_admin_api.html#login-as-a-user). |
| IIII | Total Time to process the request |
| JJJJ | Time to send response over network once generated (this may be negative if the socket is closed before the response is generated)|
| KKKK | Userland CPU time |
@@ -434,7 +434,16 @@ Sub-options for each listener include:
* `tls`: set to true to enable TLS for this listener. Will use the TLS key/cert specified in tls_private_key_path / tls_certificate_path.
* `x_forwarded`: Only valid for an 'http' listener. Set to true to use the X-Forwarded-For header as the client IP. Useful when Synapse is
behind a reverse-proxy.
behind a [reverse-proxy](../../reverse_proxy.md).
* `request_id_header`: The header extracted from each incoming request that is
used as the basis for the request ID. The request ID is used in
[logs](../administration/request_log.md#request-log-format) and tracing to
correlate and match up requests. When unset, Synapse will automatically
generate sequential request IDs. This option is useful when Synapse is behind
a [reverse-proxy](../../reverse_proxy.md).
_Added in Synapse 1.68.0._
* `resources`: Only valid for an 'http' listener. A list of resources to host
on this port. Sub-options for each resource are:
+10 -3
View File
@@ -206,6 +206,7 @@ class HttpListenerConfig:
resources: List[HttpResourceConfig] = attr.Factory(list)
additional_resources: Dict[str, dict] = attr.Factory(dict)
tag: Optional[str] = None
request_id_header: Optional[str] = None
@attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -520,9 +521,11 @@ class ServerConfig(Config):
):
raise ConfigError("allowed_avatar_mimetypes must be a list")
self.listeners = [
parse_listener_def(i, x) for i, x in enumerate(config.get("listeners", []))
]
listeners = config.get("listeners", [])
if not isinstance(listeners, list):
raise ConfigError("Expected a list", ("listeners",))
self.listeners = [parse_listener_def(i, x) for i, x in enumerate(listeners)]
# no_tls is not really supported any more, but let's grandfather it in
# here.
@@ -889,6 +892,9 @@ def read_gc_thresholds(
def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
"""parse a listener config from the config file"""
if not isinstance(listener, dict):
raise ConfigError("Expected a dictionary", ("listeners", str(num)))
listener_type = listener["type"]
# Raise a helpful error if direct TCP replication is still configured.
if listener_type == "replication":
@@ -928,6 +934,7 @@ def parse_listener_def(num: int, listener: Any) -> ListenerConfig:
resources=resources,
additional_resources=listener.get("additional_resources", {}),
tag=listener.get("tag"),
request_id_header=listener.get("request_id_header"),
)
return ListenerConfig(port, bind_addresses, listener_type, tls, http_config)
+10
View File
@@ -866,6 +866,11 @@ class FederationEventHandler:
event.room_id, event_id, str(err)
)
return
except Exception as exc:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(exc)
)
raise exc
try:
try:
@@ -908,6 +913,11 @@ class FederationEventHandler:
logger.warning("Pulled event %s failed history check.", event_id)
else:
raise
except Exception as exc:
await self._store.record_event_failed_pull_attempt(
event.room_id, event_id, str(exc)
)
raise exc
@trace
async def _compute_event_context_with_maybe_missing_prevs(
+13 -1
View File
@@ -72,10 +72,12 @@ class SynapseRequest(Request):
site: "SynapseSite",
*args: Any,
max_request_body_size: int = 1024,
request_id_header: Optional[str] = None,
**kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
self.request_id_header = request_id_header
self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests
@@ -172,7 +174,14 @@ class SynapseRequest(Request):
self._opentracing_span = span
def get_request_id(self) -> str:
return "%s-%i" % (self.get_method(), self.request_seq)
request_id_value = None
if self.request_id_header:
request_id_value = self.getHeader(self.request_id_header)
if request_id_value is None:
request_id_value = str(self.request_seq)
return "%s-%s" % (self.get_method(), request_id_value)
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if the URI
@@ -611,12 +620,15 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
request_id_header = config.http_options.request_id_header
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
request_id_header=request_id_header,
)
self.requestFactory = request_factory # type: ignore
+2
View File
@@ -80,6 +80,7 @@ from synapse.rest.admin.users import (
SearchUsersRestServlet,
ShadowBanRestServlet,
UserAdminServlet,
UserByExternalId,
UserMembershipRestServlet,
UserRegisterServlet,
UserRestServletV2,
@@ -275,6 +276,7 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ListDestinationsRestServlet(hs).register(http_server)
RoomMessagesRestServlet(hs).register(http_server)
RoomTimestampToEventRestServlet(hs).register(http_server)
UserByExternalId(hs).register(http_server)
# Some servlets only get registered for the main process.
if hs.config.worker.worker_app is None:
+27
View File
@@ -1156,3 +1156,30 @@ class AccountDataRestServlet(RestServlet):
"rooms": by_room_data,
},
}
class UserByExternalId(RestServlet):
"""Find a user based on an external ID from an auth provider"""
PATTERNS = admin_patterns(
"/auth_providers/(?P<provider>[^/]*)/users/(?P<external_id>[^/]*)"
)
def __init__(self, hs: "HomeServer"):
self._auth = hs.get_auth()
self._store = hs.get_datastores().main
async def on_GET(
self,
request: SynapseRequest,
provider: str,
external_id: str,
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self._auth, request)
user_id = await self._store.get_user_by_external_id(provider, external_id)
if user_id is None:
raise NotFoundError("User not found")
return HTTPStatus.OK, {"user_id": user_id}
+36 -29
View File
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
from urllib.parse import urlparse
from pydantic import StrictBool, StrictStr, constr
from typing_extensions import Literal
from twisted.web.server import Request
@@ -43,6 +44,7 @@ from synapse.metrics import threepid_send_requests
from synapse.push.mailer import Mailer
from synapse.rest.client.models import (
AuthenticationData,
ClientSecretStr,
EmailRequestTokenBody,
MsisdnRequestTokenBody,
)
@@ -627,6 +629,11 @@ class ThreepidAddRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
auth: Optional[AuthenticationData] = None
client_secret: ClientSecretStr
sid: StrictStr
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
@@ -636,22 +643,17 @@ class ThreepidAddRestServlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["client_secret", "sid"])
sid = body["sid"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
body = parse_and_validate_json_object_from_request(request, self.PostBody)
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
body.dict(exclude_unset=True),
"add a third-party identifier to your account",
)
validation_session = await self.identity_handler.validate_threepid_session(
client_secret, sid
body.client_secret, body.sid
)
if validation_session:
await self.auth_handler.add_threepid(
@@ -676,23 +678,20 @@ class ThreepidBindRestServlet(RestServlet):
self.identity_handler = hs.get_identity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_json_object_from_request(request)
class PostBody(RequestBodyModel):
client_secret: ClientSecretStr
id_access_token: StrictStr
id_server: StrictStr
sid: StrictStr
assert_params_in_dict(
body, ["id_server", "sid", "id_access_token", "client_secret"]
)
id_server = body["id_server"]
sid = body["sid"]
id_access_token = body["id_access_token"]
client_secret = body["client_secret"]
assert_valid_client_secret(client_secret)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
await self.identity_handler.bind_threepid(
client_secret, sid, user_id, id_server, id_access_token
body.client_secret, body.sid, user_id, body.id_server, body.id_access_token
)
return 200, {}
@@ -708,23 +707,27 @@ class ThreepidUnbindRestServlet(RestServlet):
self.auth = hs.get_auth()
self.datastore = self.hs.get_datastores().main
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""Unbind the given 3pid from a specific identity server, or identity servers that are
known to have this 3pid bound
"""
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
medium = body.get("medium")
address = body.get("address")
id_server = body.get("id_server")
body = parse_and_validate_json_object_from_request(request, self.PostBody)
# Attempt to unbind the threepid from an identity server. If id_server is None, try to
# unbind from all identity servers this threepid has been added to in the past
result = await self.identity_handler.try_unbind_threepid(
requester.user.to_string(),
{"address": address, "medium": medium, "id_server": id_server},
{
"address": body.address,
"medium": body.medium,
"id_server": body.id_server,
},
)
return 200, {"id_server_unbind_result": "success" if result else "no-support"}
@@ -738,21 +741,25 @@ class ThreepidDeleteRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
class PostBody(RequestBodyModel):
address: StrictStr
id_server: Optional[StrictStr] = None
medium: Literal["email", "msisdn"]
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(
400, "3PID changes are disabled on this server", Codes.FORBIDDEN
)
body = parse_json_object_from_request(request)
assert_params_in_dict(body, ["medium", "address"])
body = parse_and_validate_json_object_from_request(request, self.PostBody)
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
try:
ret = await self.auth_handler.delete_threepid(
user_id, body["medium"], body["address"], body.get("id_server")
user_id, body.medium, body.address, body.id_server
)
except Exception:
# NB. This endpoint should succeed if there is nothing to
+15 -13
View File
@@ -36,18 +36,20 @@ class AuthenticationData(RequestBodyModel):
type: Optional[StrictStr] = None
class ThreePidRequestTokenBody(RequestBodyModel):
if TYPE_CHECKING:
client_secret: StrictStr
else:
# See also assert_valid_client_secret()
client_secret: constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=0,
max_length=255,
strict=True,
)
if TYPE_CHECKING:
ClientSecretStr = StrictStr
else:
# See also assert_valid_client_secret()
ClientSecretStr = constr(
regex="[0-9a-zA-Z.=_-]", # noqa: F722
min_length=1,
max_length=255,
strict=True,
)
class ThreepidRequestTokenBody(RequestBodyModel):
client_secret: ClientSecretStr
id_server: Optional[StrictStr]
id_access_token: Optional[StrictStr]
next_link: Optional[StrictStr]
@@ -62,7 +64,7 @@ class ThreePidRequestTokenBody(RequestBodyModel):
return token
class EmailRequestTokenBody(ThreePidRequestTokenBody):
class EmailRequestTokenBody(ThreepidRequestTokenBody):
email: StrictStr
# Canonicalise the email address. The addresses are all stored canonicalised
@@ -80,6 +82,6 @@ else:
ISO3116_1_Alpha_2 = constr(regex="[A-Z]{2}", strict=True)
class MsisdnRequestTokenBody(ThreePidRequestTokenBody):
class MsisdnRequestTokenBody(ThreepidRequestTokenBody):
country: ISO3116_1_Alpha_2
phone_number: StrictStr
+15
View File
@@ -577,6 +577,21 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
if event.rejected_reason is not None:
# Do not admit previously rejected events into state.
# TODO: This isn't spec compliant. Events that were previously rejected due
# to failing auth checks at their state, but pass auth checks during
# state resolution should be accepted. Synapse does not handle the
# change of rejection status well, so we preserve the previous
# rejection status for now.
#
# Note that events rejected for non-state reasons, such as having the
# wrong auth events, should remain rejected.
#
# https://spec.matrix.org/v1.2/rooms/v9/#rejected-events
# https://github.com/matrix-org/synapse/issues/13797
continue
try:
event_auth.check_state_dependent_auth_rules(
event,
+1
View File
@@ -533,6 +533,7 @@ class BackgroundUpdater:
index_name: name of index to add
table: table to add index to
columns: columns/expressions to include in index
where_clause: A WHERE clause to specify a partial unique index.
unique: true to make a UNIQUE index
psql_only: true to only create this index on psql databases (useful
for virtual sqlite tables)
+23 -7
View File
@@ -1191,6 +1191,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1203,6 +1204,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert. Unused when performing
a native upsert.
Returns:
@@ -1213,7 +1215,12 @@ class DatabasePool:
if table not in self._unsafe_to_upsert_tables:
return self.simple_upsert_txn_native_upsert(
txn, table, keyvalues, values, insertion_values=insertion_values
txn,
table,
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
)
else:
return self.simple_upsert_txn_emulated(
@@ -1222,6 +1229,7 @@ class DatabasePool:
keyvalues,
values,
insertion_values=insertion_values,
where_clause=where_clause,
lock=lock,
)
@@ -1232,6 +1240,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
lock: bool = True,
) -> bool:
"""
@@ -1240,6 +1249,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
lock: True to lock the table when doing the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1259,14 +1269,17 @@ class DatabasePool:
else:
return "%s = ?" % (key,)
# Generate a where clause of each keyvalue and optionally the provided
# index predicate.
where = [_getwhere(k) for k in keyvalues]
if where_clause:
where.append(where_clause)
if not values:
# If `values` is empty, then all of the values we care about are in
# the unique key, so there is nothing to UPDATE. We can just do a
# SELECT instead to see if it exists.
sql = "SELECT 1 FROM %s WHERE %s" % (
table,
" AND ".join(_getwhere(k) for k in keyvalues),
)
sql = "SELECT 1 FROM %s WHERE %s" % (table, " AND ".join(where))
sqlargs = list(keyvalues.values())
txn.execute(sql, sqlargs)
if txn.fetchall():
@@ -1277,7 +1290,7 @@ class DatabasePool:
sql = "UPDATE %s SET %s WHERE %s" % (
table,
", ".join("%s = ?" % (k,) for k in values),
" AND ".join(_getwhere(k) for k in keyvalues),
" AND ".join(where),
)
sqlargs = list(values.values()) + list(keyvalues.values())
@@ -1307,6 +1320,7 @@ class DatabasePool:
keyvalues: Dict[str, Any],
values: Dict[str, Any],
insertion_values: Optional[Dict[str, Any]] = None,
where_clause: Optional[str] = None,
) -> bool:
"""
Use the native UPSERT functionality in PostgreSQL.
@@ -1316,6 +1330,7 @@ class DatabasePool:
keyvalues: The unique key tables and their new values
values: The nonunique columns and their new values
insertion_values: additional key/values to use only when inserting
where_clause: An index predicate to apply to the upsert.
Returns:
Returns True if a row was inserted or updated (i.e. if `values` is
@@ -1331,11 +1346,12 @@ class DatabasePool:
allvalues.update(values)
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
sql = ("INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s") % (
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) %s DO %s" % (
table,
", ".join(k for k in allvalues),
", ".join("?" for _ in allvalues),
", ".join(k for k in keyvalues),
f"WHERE {where_clause}" if where_clause else "",
latter,
)
txn.execute(sql, list(allvalues.values()))
@@ -194,7 +194,7 @@ class CensorEventsStore(EventsWorkerStore, CacheInvalidationWorkerStore, SQLBase
# changed its content in the database. We can't call
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
# right type.
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event.event_id)
self.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Send that invalidation to replication so that other workers also invalidate
# the event cache.
self._send_invalidation_to_replication(
+3 -5
View File
@@ -1294,10 +1294,8 @@ class PersistEventsStore:
"""
depth_updates: Dict[str, int] = {}
for event, context in events_and_contexts:
# Remove any existing cache entries for the event_ids
self.store.invalidate_get_event_cache_by_event_id_after_txn(
txn, event.event_id
)
# Remove the any existing cache entries for the event_ids
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
# Then update the `stream_ordering` position to mark the latest
# event as the front of the room. This should not be done for
# backfilled events because backfilled events have negative
@@ -1705,7 +1703,7 @@ class PersistEventsStore:
_invalidate_caches_for_event.
"""
assert event.redacts is not None
self.store.invalidate_get_event_cache_by_event_id_after_txn(txn, event.redacts)
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
@@ -80,7 +80,6 @@ from synapse.types import JsonDict, get_domain_from_id
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.dual_lookup_cache import DualLookupCache
from synapse.util.caches.lrucache import AsyncLruCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -246,8 +245,6 @@ class EventsWorkerStore(SQLBaseStore):
] = AsyncLruCache(
cache_name="*getEvent*",
max_size=hs.config.caches.event_cache_size,
cache_type=DualLookupCache,
dual_lookup_secondary_key_function=lambda v: (v.event.room_id,),
)
# Map from event ID to a deferred that will result in a map from event
@@ -736,7 +733,7 @@ class EventsWorkerStore(SQLBaseStore):
return event_entry_map
def invalidate_get_event_cache_by_event_id_after_txn(
def invalidate_get_event_cache_after_txn(
self, txn: LoggingTransaction, event_id: str
) -> None:
"""
@@ -750,31 +747,10 @@ class EventsWorkerStore(SQLBaseStore):
event_id: the event ID to be invalidated from caches
"""
txn.async_call_after(
self._invalidate_async_get_event_cache_by_event_id, event_id
)
txn.call_after(self._invalidate_local_get_event_cache_by_event_id, event_id)
txn.async_call_after(self._invalidate_async_get_event_cache, event_id)
txn.call_after(self._invalidate_local_get_event_cache, event_id)
def invalidate_get_event_cache_by_room_id_after_txn(
self, txn: LoggingTransaction, room_id: str
) -> None:
"""
Prepares a database transaction to invalidate the get event cache for a given
room ID when executed successfully. This is achieved by attaching two callbacks
to the transaction, one to invalidate the async cache and one for the in memory
sync cache (importantly called in that order).
Arguments:
txn: the database transaction to attach the callbacks to.
room_id: the room ID to invalidate all associated event caches for.
"""
txn.async_call_after(self._invalidate_async_get_event_cache_by_room_id, room_id)
txn.call_after(self._invalidate_local_get_event_cache_by_room_id, room_id)
async def _invalidate_async_get_event_cache_by_event_id(
self, event_id: str
) -> None:
async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
"""
Invalidates an event in the asyncronous get event cache, which may be remote.
@@ -784,18 +760,7 @@ class EventsWorkerStore(SQLBaseStore):
await self._get_event_cache.invalidate((event_id,))
async def _invalidate_async_get_event_cache_by_room_id(self, room_id: str) -> None:
"""
Invalidates all events associated with a given room in the asyncronous get event
cache, which may be remote.
Arguments:
room_id: the room ID to invalidate associated events of.
"""
await self._get_event_cache.invalidate((room_id,))
def _invalidate_local_get_event_cache_by_event_id(self, event_id: str) -> None:
def _invalidate_local_get_event_cache(self, event_id: str) -> None:
"""
Invalidates an event in local in-memory get event caches.
@@ -807,18 +772,6 @@ class EventsWorkerStore(SQLBaseStore):
self._event_ref.pop(event_id, None)
self._current_event_fetches.pop(event_id, None)
def _invalidate_local_get_event_cache_by_room_id(self, room_id: str) -> None:
"""
Invalidates all events associated with a given room ID in local in-memory
get event caches.
Arguments:
room_id: the room ID to invalidate events of.
"""
self._get_event_cache.invalidate_local((room_id,))
# TODO: invalidate _event_ref and _current_event_fetches. How?
async def _get_events_from_cache(
self, events: Iterable[str], update_metrics: bool = True
) -> Dict[str, EventCacheEntry]:
@@ -2331,7 +2284,7 @@ class EventsWorkerStore(SQLBaseStore):
updatevalues={"rejection_reason": rejection_reason},
)
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
self.invalidate_get_event_cache_after_txn(txn, event_id)
# TODO(faster_joins): invalidate the cache on workers. Ideally we'd just
# call '_send_invalidation_to_replication', but we actually need the other
@@ -304,7 +304,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
self._invalidate_cache_and_stream(
txn, self.have_seen_event, (room_id, event_id)
)
self.invalidate_get_event_cache_by_event_id_after_txn(txn, event_id)
self.invalidate_get_event_cache_after_txn(txn, event_id)
logger.info("[purge] done")
@@ -478,7 +478,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
# XXX: as with purge_history, this is racy, but no worse than other races
# that already exist.
self._invalidate_cache_and_stream(txn, self.have_seen_event, (room_id,))
self._invalidate_local_get_event_cache_by_room_id(room_id)
logger.info("[purge] done")
-238
View File
@@ -1,238 +0,0 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
Callable,
Dict,
Generic,
ItemsView,
List,
Optional,
TypeVar,
Union,
ValuesView,
)
# Used to discern between a value not existing in a map, or the value being 'None'.
SENTINEL = object()
# The type of the primary dict's keys.
PKT = TypeVar("PKT")
# The type of the primary dict's values.
PVT = TypeVar("PVT")
# The type of the secondary dict's keys.
SKT = TypeVar("SKT")
logger = logging.getLogger(__name__)
class SecondarySet(set):
"""
Used to differentiate between an entry in the secondary_dict, and a set stored
in the primary_dict. This is necessary as pop() can return either.
"""
class DualLookupCache(Generic[PKT, PVT, SKT]):
"""
A backing store for LruCache that supports multiple entry points.
Allows subsets of data to be deleted efficiently without requiring extra
information to query.
The data structure is two dictionaries:
* primary_dict containing a mapping of primary_key -> value.
* secondary_dict containing a mapping of secondary_key -> set of primary_key.
On insert, a mapping in the primary_dict must be created. A mapping in the
secondary_dict from a secondary_key to (a set containing) the same
primary_key will be made. The secondary_key
must be derived from the inserted value via a lambda function provided at cache
initialisation. This is so invalidated entries in the primary_dict may automatically
invalidate those in the secondary_dict. The secondary_key may be associated with one
or more primary_key's.
This creates an interface which allows for efficient lookups of a value given
a primary_key, as well as efficient invalidation of a subset of mapping in the
primary_dict given a secondary_key. A primary_key may not be associated with more
than one secondary_key.
As a worked example, consider storing a cache of room events. We could configure
the cache to store mappings between EventIDs and EventBase in the primary_dict,
while storing a mapping between room IDs and event IDs as the secondary_dict:
primary_dict: EventID -> EventBase
secondary_dict: RoomID -> {EventID, EventID, ...}
This would be efficient for the following operations:
* Given an EventID, look up the associated EventBase, and thus the roomID.
* Given a RoomID, invalidate all primary_dict entries for events in that room.
Since this is intended as a backing store for LRUCache, when it came time to evict
an entry from the primary_dict (EventID -> EventBase), the secondary_key could be
derived from a provided lambda function:
secondary_key = lambda event_base: event_base.room_id
The EventID set under room_id would then have the appropriate EventID entry evicted.
"""
def __init__(self, secondary_key_function: Callable[[PVT], SKT]) -> None:
self._primary_dict: Dict[PKT, PVT] = {}
self._secondary_dict: Dict[SKT, SecondarySet] = {}
self._secondary_key_function = secondary_key_function
def __setitem__(self, key: PKT, value: PVT) -> None:
self.set(key, value)
def __contains__(self, key: PKT) -> bool:
return key in self._primary_dict
def set(self, key: PKT, value: PVT) -> None:
"""Add an entry to the cache.
Will add an entry to the primary_dict consisting of key->value, as well as append
to the set referred to by secondary_key_function(value) in the secondary_dict.
Args:
key: The key for a new mapping in primary_dict.
value: The value for a new mapping in primary_dict.
"""
# Create an entry in the primary_dict.
self._primary_dict[key] = value
# Derive the secondary_key to use from the given primary_value.
secondary_key = self._secondary_key_function(value)
# TODO: If the lambda function resolves to None, don't insert an entry?
# And create a mapping in the secondary_dict to a set containing the
# primary_key, creating the set if necessary.
secondary_key_set = self._secondary_dict.setdefault(
secondary_key, SecondarySet()
)
secondary_key_set.add(key)
logger.info("*** Insert into primary_dict: %s: %s", key, value)
logger.info("*** Insert into secondary_dict: %s: %s", secondary_key, key)
def get(self, key: PKT, default: Optional[PVT] = None) -> Optional[PVT]:
"""Retrieve a value from the cache if it exists. If not, return the default
value.
This method simply pulls entries from the primary_dict.
# TODO: Any use cases for externally getting entries from the secondary_dict?
Args:
key: The key to search the cache for.
default: The default value to return if the given key is not found.
Returns:
The value referenced by the given key, if it exists in the cache. If not,
the value of `default` will be returned.
"""
logger.info("*** Retrieving key from primary_dict: %s", key)
return self._primary_dict.get(key, default)
def clear(self) -> None:
"""Evicts all entries from the cache."""
self._primary_dict.clear()
self._secondary_dict.clear()
def pop(
self, key: Union[PKT, SKT], default: Optional[Union[Dict[PKT, PVT], PVT]] = None
) -> Optional[Union[Dict[PKT, PVT], PVT]]:
"""Remove an entry from either the primary_dict or secondary_dict.
The primary_dict is checked first for the key. If an entry is found, it is
removed from the primary_dict and returned.
If no entry in the primary_dict exists, then the secondary_dict is checked.
If an entry exists, all associated entries in the primary_dict will be
deleted, and all primary_dict keys returned from this function in a SecondarySet.
Args:
key: A key to drop from either the primary_dict or secondary_dict.
default: The default value if the key does not exist in either dict.
Returns:
Either a matched value from the primary_dict or the secondary_dict. If no
value is found for the key, then None.
"""
# Attempt to remove from the primary_dict first.
primary_value = self._primary_dict.pop(key, SENTINEL)
if primary_value is not SENTINEL:
# We found a value in the primary_dict. Remove it from the corresponding
# entry in the secondary_dict, and then return it.
logger.info(
"*** Popped entry from primary_dict: %s: %s", key, primary_value
)
# Derive the secondary_key from the primary_value
secondary_key = self._secondary_key_function(primary_value)
# Pop the entry from the secondary_dict
secondary_key_set = self._secondary_dict[secondary_key]
if len(secondary_key_set) > 1:
# Delete just the set entry for the given key.
secondary_key_set.remove(key)
logger.info(
"*** Popping from secondary_dict: %s: %s", secondary_key, key
)
else:
# Delete the entire set referenced by the secondary_key, as it only
# has one entry.
del self._secondary_dict[secondary_key]
logger.info("*** Popping from secondary_dict: %s", secondary_key)
return primary_value
# There was no matching value in the primary_dict. Attempt the secondary_dict.
primary_key_set = self._secondary_dict.pop(key, SENTINEL)
if primary_key_set is not SENTINEL:
# We found a set in the secondary_dict.
logger.info(
"*** Found '%s' in secondary_dict: %s: ",
key,
primary_key_set,
)
popped_primary_dict_values: List[PVT] = []
# We found an entry in the secondary_dict. Delete all related entries in the
# primary_dict.
logger.info(
"*** Found key in secondary_dict to pop: %s. "
"Popping primary_dict entries",
key,
)
for primary_key in primary_key_set:
primary_value = self._primary_dict.pop(primary_key)
logger.info("*** Popping entry from primary_dict: %s - %s", primary_key, primary_value)
logger.info("*** primary_dict: %s", self._primary_dict)
popped_primary_dict_values.append(primary_value)
# Now return the unmodified copy of the set.
return popped_primary_dict_values
# No match in either dict.
return default
def values(self) -> ValuesView:
return self._primary_dict.values()
def items(self) -> ItemsView:
return self._primary_dict.items()
def __len__(self) -> int:
return len(self._primary_dict)
+8 -50
View File
@@ -46,10 +46,8 @@ from synapse.metrics.background_process_metrics import wrap_as_background_proces
from synapse.metrics.jemalloc import get_jemalloc_stats
from synapse.util import Clock, caches
from synapse.util.caches import CacheMetric, EvictionReason, register_cache
from synapse.util.caches.dual_lookup_cache import DualLookupCache, SecondarySet
from synapse.util.caches.treecache import (
TreeCache,
TreeCacheNode,
iterate_tree_cache_entry,
iterate_tree_cache_items,
)
@@ -377,13 +375,12 @@ class LruCache(Generic[KT, VT]):
self,
max_size: int,
cache_name: Optional[str] = None,
cache_type: Type[Union[dict, TreeCache, DualLookupCache]] = dict,
cache_type: Type[Union[dict, TreeCache]] = dict,
size_callback: Optional[Callable[[VT], int]] = None,
metrics_collection_callback: Optional[Callable[[], None]] = None,
apply_cache_factor_from_config: bool = True,
clock: Optional[Clock] = None,
prune_unread_entries: bool = True,
dual_lookup_secondary_key_function: Optional[Callable[[Any], Any]] = None,
):
"""
Args:
@@ -414,10 +411,6 @@ class LruCache(Generic[KT, VT]):
prune_unread_entries: If True, cache entries that haven't been read recently
will be evicted from the cache in the background. Set to False to
opt-out of this behaviour.
# TODO: At this point we should probably just pass an initialised cache type
# to LruCache, no?
dual_lookup_secondary_key_function:
"""
# Default `clock` to something sensible. Note that we rename it to
# `real_clock` so that mypy doesn't think its still `Optional`.
@@ -426,30 +419,7 @@ class LruCache(Generic[KT, VT]):
else:
real_clock = clock
# TODO: I've had to make this ugly to appease mypy :(
# Perhaps initialise the backing cache and then pass to LruCache?
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache, DualLookupCache]
if cache_type is DualLookupCache:
# The dual_lookup_secondary_key_function is a function that's intended to
# extract a key from the value in the cache. Since we wrap values given to
# us in a _Node object, this function will actually operate on a _Node,
# instead of directly on the object type callers are expecting.
#
# Thus, we wrap the function given by the caller in another one that
# extracts the value from the _Node, before then handing it off to the
# given function for processing.
def key_function_wrapper(node: Any) -> Any:
assert dual_lookup_secondary_key_function is not None
return dual_lookup_secondary_key_function(node.value)
cache = DualLookupCache(
secondary_key_function=key_function_wrapper,
)
elif cache_type is TreeCache:
cache = TreeCache()
else:
cache = {}
cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type()
self.cache = cache # Used for introspection.
self.apply_cache_factor_from_config = apply_cache_factor_from_config
@@ -752,25 +722,13 @@ class LruCache(Generic[KT, VT]):
may be of lower cardinality than the TreeCache - in which case the whole
subtree is deleted.
"""
# Remove an entry from the cache.
# In the case of a 'dict' cache type, we're just removing an entry from the
# dict. For a TreeCache, we're removing a subtree which has children.
popped_entry: _Node[KT, VT] = cache.pop(key, None)
if popped_entry is None:
popped = cache.pop(key, None)
if popped is None:
return
if isinstance(popped_entry, TreeCacheNode):
# We've popped a subtree from a TreeCache - now we need to clean up
# each child node.
for leaf in iterate_tree_cache_entry(popped_entry):
# For each deleted child node, we remove it from the linked list and
# run its callbacks.
delete_node(leaf)
elif isinstance(popped_entry, SecondarySet):
for leaf in popped_entry:
delete_node(leaf)
else:
delete_node(popped_entry)
# for each deleted node, we now need to remove it from the linked list
# and run its callbacks.
for leaf in iterate_tree_cache_entry(popped):
delete_node(leaf)
@synchronized
def cache_clear() -> None:
+399
View File
@@ -11,14 +11,23 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from unittest import mock
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersion
from synapse.event_auth import (
check_state_dependent_auth_rules,
check_state_independent_auth_rules,
)
from synapse.events import make_event_from_dict
from synapse.events.snapshot import EventContext
from synapse.federation.transport.client import StateRequestResponse
from synapse.logging.context import LoggingContext
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.state.v2 import _mainline_sort, _reverse_topological_power_sort
from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import event_injection, make_awaitable
@@ -449,3 +458,393 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
main_store.get_event(pulled_event.event_id, allow_none=True)
)
self.assertIsNotNone(persisted, "pulled event was not persisted at all")
def test_process_pulled_event_with_rejected_missing_state(self) -> None:
"""Ensure that we correctly handle pulled events with missing state containing a
rejected state event
In this test, we pretend we are processing a "pulled" event (eg, via backfill
or get_missing_events). The pulled event has a prev_event we haven't previously
seen, so the server requests the state at that prev_event. We expect the server
to make a /state request.
We simulate a remote server whose /state includes a rejected kick event for a
local user. Notably, the kick event is rejected only because it cites a rejected
auth event and would otherwise be accepted based on the room state. During state
resolution, we re-run auth and can potentially introduce such rejected events
into the state if we are not careful.
We check that the pulled event is correctly persisted, and that the state
afterwards does not include the rejected kick.
"""
# The DAG we are testing looks like:
#
# ...
# |
# v
# remote admin user joins
# | |
# +-------+ +-------+
# | |
# | rejected power levels
# | from remote server
# | |
# | v
# | rejected kick of local user
# v from remote server
# new power levels |
# | v
# | missing event
# | from remote server
# | |
# +-------+ +-------+
# | |
# v v
# pulled event
# from remote server
#
# (arrows are in the opposite direction to prev_events.)
OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
main_store = self.hs.get_datastores().main
# Create the room.
kermit_user_id = self.register_user("kermit", "test")
kermit_tok = self.login("kermit", "test")
room_id = self.helper.create_room_as(
room_creator=kermit_user_id, tok=kermit_tok
)
room_version = self.get_success(main_store.get_room_version(room_id))
# Add another local user to the room. This user is going to be kicked in a
# rejected event.
bert_user_id = self.register_user("bert", "test")
bert_tok = self.login("bert", "test")
self.helper.join(room_id, user=bert_user_id, tok=bert_tok)
# Allow the remote user to kick bert.
# The remote user is going to send a rejected power levels event later on and we
# need state resolution to order it before another power levels event kermit is
# going to send later on. Hence we give both users the same power level, so that
# ties are broken by `origin_server_ts`.
self.helper.send_state(
room_id,
"m.room.power_levels",
{"users": {kermit_user_id: 100, OTHER_USER: 100}},
tok=kermit_tok,
)
# Add the remote user to the room.
other_member_event = self.get_success(
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
initial_state_map = self.get_success(
main_store.get_partial_current_state_ids(room_id)
)
create_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.create", "")])
)
bert_member_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.member", bert_user_id)])
)
power_levels_event = self.get_success(
main_store.get_event(initial_state_map[("m.room.power_levels", "")])
)
# We now need a rejected state event that will fail
# `check_state_independent_auth_rules` but pass
# `check_state_dependent_auth_rules`.
# First, we create a power levels event that we pretend the remote server has
# accepted, but the local homeserver will reject.
next_depth = 100
next_timestamp = other_member_event.origin_server_ts + 100
rejected_power_levels_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.power_levels",
"state_key": "",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [other_member_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
initial_state_map[("m.room.power_levels", "")],
# The event will be rejected because of the duplicated auth
# event.
other_member_event.event_id,
other_member_event.event_id,
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": power_levels_event.content,
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
with LoggingContext("send_rejected_power_levels_event"):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME,
rejected_power_levels_event,
backfilled=False,
)
)
self.assertEqual(
self.get_success(
main_store.get_rejection_reason(
rejected_power_levels_event.event_id
)
),
"auth_error",
)
# Then we create a kick event for a local user that cites the rejected power
# levels event in its auth events. The kick event will be rejected solely
# because of the rejected auth event and would otherwise be accepted.
rejected_kick_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.member",
"state_key": bert_user_id,
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [rejected_power_levels_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
rejected_power_levels_event.event_id,
initial_state_map[("m.room.member", bert_user_id)],
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"membership": "leave"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# The kick event must fail the state-independent auth rules, but pass the
# state-dependent auth rules, so that it has a chance of making it through state
# resolution.
self.get_failure(
check_state_independent_auth_rules(main_store, rejected_kick_event),
AuthError,
)
check_state_dependent_auth_rules(
rejected_kick_event,
[create_event, power_levels_event, other_member_event, bert_member_event],
)
# The kick event must also win over the original member event during state
# resolution.
self.assertEqual(
self.get_success(
_mainline_sort(
self.clock,
room_id,
event_ids=[
bert_member_event.event_id,
rejected_kick_event.event_id,
],
resolved_power_event_id=power_levels_event.event_id,
event_map={
bert_member_event.event_id: bert_member_event,
rejected_kick_event.event_id: rejected_kick_event,
},
state_res_store=main_store,
)
),
[bert_member_event.event_id, rejected_kick_event.event_id],
"The rejected kick event will not be applied after bert's join event "
"during state resolution. The test setup is incorrect.",
)
with LoggingContext("send_rejected_kick_event"):
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, rejected_kick_event, backfilled=False
)
)
self.assertEqual(
self.get_success(
main_store.get_rejection_reason(rejected_kick_event.event_id)
),
"auth_error",
)
# We need another power levels event which will win over the rejected one during
# state resolution, otherwise we hit other issues where we end up with rejected
# a power levels event during state resolution.
self.reactor.advance(100) # ensure the `origin_server_ts` is larger
new_power_levels_event = self.get_success(
main_store.get_event(
self.helper.send_state(
room_id,
"m.room.power_levels",
{"users": {kermit_user_id: 100, OTHER_USER: 100, bert_user_id: 1}},
tok=kermit_tok,
)["event_id"]
)
)
self.assertEqual(
self.get_success(
_reverse_topological_power_sort(
self.clock,
room_id,
event_ids=[
new_power_levels_event.event_id,
rejected_power_levels_event.event_id,
],
event_map={},
state_res_store=main_store,
full_conflicted_set=set(),
)
),
[rejected_power_levels_event.event_id, new_power_levels_event.event_id],
"The power levels events will not have the desired ordering during state "
"resolution. The test setup is incorrect.",
)
# Create a missing event, so that the local homeserver has to do a `/state` or
# `/state_ids` request to pull state from the remote homeserver.
missing_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.message",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [rejected_kick_event.event_id],
"auth_events": [
initial_state_map[("m.room.create", "")],
initial_state_map[("m.room.power_levels", "")],
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"msgtype": "m.text", "body": "foo"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# The pulled event has two prev events, one of which is missing. We will make a
# `/state` or `/state_ids` request to the remote homeserver to ask it for the
# state before the missing prev event.
pulled_event = make_event_from_dict(
self.add_hashes_and_signatures_from_other_server(
{
"type": "m.room.message",
"room_id": room_id,
"sender": OTHER_USER,
"prev_events": [
new_power_levels_event.event_id,
missing_event.event_id,
],
"auth_events": [
initial_state_map[("m.room.create", "")],
new_power_levels_event.event_id,
initial_state_map[("m.room.member", OTHER_USER)],
],
"origin_server_ts": next_timestamp,
"depth": next_depth,
"content": {"msgtype": "m.text", "body": "bar"},
}
),
room_version,
)
next_depth += 1
next_timestamp += 100
# Prepare the response for the `/state` or `/state_ids` request.
# The remote server believes bert has been kicked, while the local server does
# not.
state_before_missing_event = self.get_success(
main_store.get_events_as_list(initial_state_map.values())
)
state_before_missing_event = [
event
for event in state_before_missing_event
if event.event_id != bert_member_event.event_id
]
state_before_missing_event.append(rejected_kick_event)
# We have to bump the clock a bit, to keep the retry logic in
# `FederationClient.get_pdu` happy
self.reactor.advance(60000)
with LoggingContext("send_pulled_event"):
async def get_event(
destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return {"pdus": [missing_event.get_pdu_json()]}
async def get_room_state_ids(
destination: str, room_id: str, event_id: str
) -> JsonDict:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return {
"pdu_ids": [event.event_id for event in state_before_missing_event],
"auth_chain_ids": [],
}
async def get_room_state(
room_version: RoomVersion, destination: str, room_id: str, event_id: str
) -> StateRequestResponse:
self.assertEqual(destination, self.OTHER_SERVER_NAME)
self.assertEqual(event_id, missing_event.event_id)
return StateRequestResponse(
state=state_before_missing_event,
auth_events=[],
)
self.mock_federation_transport_client.get_event.side_effect = get_event
self.mock_federation_transport_client.get_room_state_ids.side_effect = (
get_room_state_ids
)
self.mock_federation_transport_client.get_room_state.side_effect = (
get_room_state
)
self.get_success(
self.hs.get_federation_event_handler()._process_pulled_event(
self.OTHER_SERVER_NAME, pulled_event, backfilled=False
)
)
self.assertIsNone(
self.get_success(
main_store.get_rejection_reason(pulled_event.event_id)
),
"Pulled event was unexpectedly rejected, likely due to a problem with "
"the test setup.",
)
self.assertEqual(
{pulled_event.event_id},
self.get_success(
main_store.have_events_in_timeline([pulled_event.event_id])
),
"Pulled event was not persisted, likely due to a problem with the test "
"setup.",
)
# We must not accept rejected events into the room state, so we expect bert
# to not be kicked, even if the remote server believes so.
new_state_map = self.get_success(
main_store.get_partial_current_state_ids(room_id)
)
self.assertEqual(
new_state_map[("m.room.member", bert_user_id)],
bert_member_event.event_id,
"Rejected kick event unexpectedly became part of room state.",
)
+87
View File
@@ -4140,3 +4140,90 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
{"b": 2},
channel.json_body["account_data"]["rooms"]["test_room"]["m.per_room"],
)
class UsersByExternalIdTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
self.other_user = self.register_user("user", "pass")
self.get_success(
self.store.record_user_external_id(
"the-auth-provider", "the-external-id", self.other_user
)
)
self.get_success(
self.store.record_user_external_id(
"another-auth-provider", "a:complex@external/id", self.other_user
)
)
def test_no_auth(self) -> None:
"""Try to lookup a user without authentication."""
url = (
"/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
)
channel = self.make_request(
"GET",
url,
)
self.assertEqual(401, channel.code, msg=channel.json_body)
self.assertEqual(Codes.MISSING_TOKEN, channel.json_body["errcode"])
def test_binding_does_not_exist(self) -> None:
"""Tests that a lookup for an external ID that does not exist returns a 404"""
url = "/_synapse/admin/v1/auth_providers/the-auth-provider/users/unknown-id"
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(404, channel.code, msg=channel.json_body)
self.assertEqual(Codes.NOT_FOUND, channel.json_body["errcode"])
def test_success(self) -> None:
"""Tests a successful external ID lookup"""
url = (
"/_synapse/admin/v1/auth_providers/the-auth-provider/users/the-external-id"
)
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"user_id": self.other_user},
channel.json_body,
)
def test_success_urlencoded(self) -> None:
"""Tests a successful external ID lookup with an url-encoded ID"""
url = "/_synapse/admin/v1/auth_providers/another-auth-provider/users/a%3Acomplex%40external%2Fid"
channel = self.make_request(
"GET",
url,
access_token=self.admin_user_tok,
)
self.assertEqual(200, channel.code, msg=channel.json_body)
self.assertEqual(
{"user_id": self.other_user},
channel.json_body,
)
+26 -3
View File
@@ -11,14 +11,37 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import unittest as stdlib_unittest
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from typing_extensions import Literal
from synapse.rest.client.models import EmailRequestTokenBody
class EmailRequestTokenBodyTestCase(unittest.TestCase):
class ThreepidMediumEnumTestCase(stdlib_unittest.TestCase):
class Model(BaseModel):
medium: Literal["email", "msisdn"]
def test_accepts_valid_medium_string(self) -> None:
"""Sanity check that Pydantic behaves sensibly with an enum-of-str
This is arguably more of a test of a class that inherits from str and Enum
simultaneously.
"""
model = self.Model.parse_obj({"medium": "email"})
self.assertEqual(model.medium, "email")
def test_rejects_invalid_medium_value(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": "interpretive_dance"})
def test_rejects_invalid_medium_type(self) -> None:
with self.assertRaises(ValidationError):
self.Model.parse_obj({"medium": 123})
class EmailRequestTokenBodyTestCase(stdlib_unittest.TestCase):
base_request = {
"client_secret": "hunter2",
"email": "alice@wonderland.com",
+1
View File
@@ -115,5 +115,6 @@ class PurgeTests(HomeserverTestCase):
)
# The events aren't found.
self.store._invalidate_local_get_event_cache(create_event.event_id)
self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)