
Merge branch 'madlittlemods/13856-fix-have-seen-events-not-being-invalidated' into maddlittlemods/msc2716-many-batches-optimization

Conflicts:
	tests/storage/databases/main/test_events_worker.py
Eric Eastwood
2022-09-23 00:16:50 -05:00
54 changed files with 1093 additions and 175 deletions


@@ -201,10 +201,11 @@ jobs:
open-issue:
if: "failure() && github.event_name != 'push' && github.event_name != 'pull_request'"
needs:
- # TODO: should mypy be included here? It feels more brittle than the other two.
+ # TODO: should mypy be included here? It feels more brittle than the others.
- mypy
- trial
- sytest
- complement
runs-on: ubuntu-latest


@@ -0,0 +1 @@
Add cache invalidation across workers to module API.


@@ -0,0 +1 @@
Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session.

changelog.d/13772.doc Normal file

@@ -0,0 +1 @@
Add `worker_main_http_uri` for the worker generator bash script.


@@ -0,0 +1 @@
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).

changelog.d/13809.misc Normal file

@@ -0,0 +1 @@
Improve the `synapse.api.auth.Auth` mock used in unit tests.


@@ -0,0 +1 @@
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).


@@ -0,0 +1 @@
Improve validation for the unspecced, internal-only `_matrix/client/unstable/add_threepid/msisdn/submit_token` endpoint.

changelog.d/13836.doc Normal file

@@ -0,0 +1 @@
Fix a mistake in sso_mapping_providers.md: `map_user_attributes` is expected to return `display_name` not `displayname`.

changelog.d/13840.bugfix Normal file

@@ -0,0 +1 @@
Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward.

changelog.d/13850.misc Normal file

@@ -0,0 +1 @@
Fix the release script not publishing binary wheels.

changelog.d/13859.misc Normal file

@@ -0,0 +1 @@
Raise issue if complement fails with latest deps.


@@ -0,0 +1 @@
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).

changelog.d/13863.bugfix Normal file

@@ -0,0 +1 @@
Fix the `have_seen_event` cache not being invalidated after we persist an event, which could cause inefficiencies such as extra `/state` federation calls.

changelog.d/13870.doc Normal file

@@ -0,0 +1 @@
Fix a cross-link from the register admin API to the `registration_shared_secret` configuration documentation.


@@ -7,7 +7,7 @@ You can alternatively create multiple worker configuration files with a simple `
#!/bin/bash
for i in {1..5}
do
- cat << EOF >> generic_worker$i.yaml
+ cat << EOF > generic_worker$i.yaml
worker_app: synapse.app.generic_worker
worker_name: generic_worker$i
@@ -15,6 +15,8 @@ worker_name: generic_worker$i
worker_replication_host: 127.0.0.1
worker_replication_http_port: 9093
+ worker_main_http_uri: http://localhost:8008/
worker_listeners:
- type: http
port: 808$i


@@ -5,7 +5,7 @@ non-interactive way. This is generally used for bootstrapping a Synapse
instance with administrator accounts.
To authenticate yourself to the server, you will need both the shared secret
- ([`registration_shared_secret`](../configuration/config_documentation.md#registration_shared_secret)
+ ([`registration_shared_secret`](../usage/configuration/config_documentation.md#registration_shared_secret)
in the homeserver configuration), and a one-time nonce. If the registration
shared secret is not configured, this API is not enabled.


@@ -73,8 +73,8 @@ A custom mapping provider must specify the following methods:
* `async def map_user_attributes(self, userinfo, token, failures)`
- This method must be async.
- Arguments:
- - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
- information from.
+ - `userinfo` - An [`authlib.oidc.core.claims.UserInfo`](https://docs.authlib.org/en/latest/specs/oidc.html#authlib.oidc.core.UserInfo)
+ object to extract user information from.
- `token` - A dictionary which includes information necessary to make
further requests to the OpenID provider.
- `failures` - An `int` that represents the amount of times the returned
@@ -91,7 +91,13 @@ A custom mapping provider must specify the following methods:
`None`, the user is prompted to pick their own username. This is only used
during a user's first login. Once a localpart has been associated with a
remote user ID (see `get_remote_user_id`) it cannot be updated.
- - `displayname`: An optional string, the display name for the user.
+ - `confirm_localpart`: A boolean. If set to `True`, when a `localpart`
+ string is returned from this method, Synapse will prompt the user to
+ either accept this localpart or pick their own username. Otherwise this
+ option has no effect. If omitted, defaults to `False`.
+ - `display_name`: An optional string, the display name for the user.
- `emails`: A list of strings, the email address(es) to associate with
this user. If omitted, defaults to an empty list.
* `async def get_extra_attributes(self, userinfo, token)`
- This method must be async.
- Arguments:

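A hedged sketch of a `map_user_attributes` implementation that follows the corrected contract above (returning `display_name`, not `displayname`). The `userinfo` claim names used here ("preferred_username", "name", "email") are illustrative, not mandated by Synapse:

```python
class ExampleOidcMappingProvider:  # hypothetical provider
    async def map_user_attributes(self, userinfo, token, failures):
        localpart = userinfo["preferred_username"]
        if failures:
            # A previous candidate collided with an existing user;
            # append a suffix and let Synapse retry.
            localpart += str(failures)
        return {
            "localpart": localpart,
            "display_name": userinfo.get("name"),  # note: display_name
            "emails": [userinfo["email"]] if "email" in userinfo else [],
            "confirm_localpart": False,
        }
```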

@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
if fullname.startswith(
- "synapse.util.caches.descriptors._CachedFunction.__call__"
+ "synapse.util.caches.descriptors.CachedFunction.__call__"
) or fullname.startswith(
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
):
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
- """Fixes the `_CachedFunction.__call__` signature to be correct.
+ """Fixes the `CachedFunction.__call__` signature to be correct.
It already has *almost* the correct signature, except:


@@ -427,11 +427,12 @@ def _publish(gh_token: str) -> None:
@cli.command()
- def upload() -> None:
- _upload()
+ @click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
+ def upload(gh_token: Optional[str]) -> None:
+ _upload(gh_token)
- def _upload() -> None:
+ def _upload(gh_token: Optional[str]) -> None:
"""Upload release to pypi."""
current_version = get_package_version()
@@ -444,18 +445,40 @@ def _upload() -> None:
click.echo("Tag {tag_name} (tag.commit) is not currently checked out!")
click.get_current_context().abort()
- pypi_asset_names = [
- f"matrix_synapse-{current_version}-py3-none-any.whl",
- f"matrix-synapse-{current_version}.tar.gz",
- ]
+ # Query all the assets corresponding to this release.
+ gh = Github(gh_token)
+ gh_repo = gh.get_repo("matrix-org/synapse")
+ gh_release = gh_repo.get_release(tag_name)
+ all_assets = set(gh_release.get_assets())
+ # Only accept the wheels and sdist.
+ # Notably: we don't care about debs.tar.xz.
+ asset_names_and_urls = sorted(
+ (asset.name, asset.browser_download_url)
+ for asset in all_assets
+ if asset.name.endswith((".whl", ".tar.gz"))
+ )
+ # Print out what we've determined.
+ print("Found relevant assets:")
+ for asset_name, _ in asset_names_and_urls:
+ print(f" - {asset_name}")
+ ignored_asset_names = sorted(
+ {asset.name for asset in all_assets}
+ - {asset_name for asset_name, _ in asset_names_and_urls}
+ )
+ print("\nIgnoring irrelevant assets:")
+ for asset_name in ignored_asset_names:
+ print(f" - {asset_name}")
with TemporaryDirectory(prefix=f"synapse_upload_{tag_name}_") as tmpdir:
- for name in pypi_asset_names:
+ for name, asset_download_url in asset_names_and_urls:
filename = path.join(tmpdir, name)
- url = f"https://github.com/matrix-org/synapse/releases/download/{tag_name}/{name}"
click.echo(f"Downloading {name} into {filename}")
- urllib.request.urlretrieve(url, filename=filename)
+ urllib.request.urlretrieve(asset_download_url, filename=filename)
if click.confirm("Upload to PyPI?", default=True):
subprocess.run("twine upload *", shell=True, cwd=tmpdir)
@@ -672,7 +695,7 @@ def full(gh_token: str) -> None:
_publish(gh_token)
click.echo("\n*** upload ***")
- _upload()
+ _upload(gh_token)
click.echo("\n*** merge back ***")
_merge_back()


@@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
"e2e_fallback_keys_json": ["used"],
"access_tokens": ["used"],
"device_lists_changes_in_room": ["converted_to_destinations"],
"pushers": ["enabled"],
}


@@ -93,3 +93,13 @@ class ExperimentalConfig(Config):
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
# MSC3881: Remotely toggle push notifications for another client
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
# MSC3882: Allow an existing session to sign in a new session
self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False)
self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True)
self.msc3882_token_timeout = self.parse_duration(
experimental.get("msc3882_token_timeout", "5m")
)


@@ -997,7 +997,7 @@ class RegistrationHandler:
assert user_tuple
token_id = user_tuple.token_id
- await self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="email",
@@ -1005,7 +1005,7 @@ class RegistrationHandler:
app_display_name="Email Notifications",
device_display_name=threepid["address"],
pushkey=threepid["address"],
- lang=None, # We don't know a user's language here
+ lang=None,
data={},
)


@@ -128,6 +128,9 @@ class SsoIdentityProvider(Protocol):
@attr.s(auto_attribs=True)
class UserAttributes:
# NB: This struct is documented in docs/sso_mapping_providers.md so that users can
# populate it with data from their own mapping providers.
# the localpart of the mxid that the mapper has assigned to the user.
# if `None`, the mapper has not picked a userid, and the user should be prompted to
# enter one.


@@ -125,7 +125,7 @@ from synapse.types import (
)
from synapse.util import Clock
from synapse.util.async_helpers import maybe_awaitable
- from synapse.util.caches.descriptors import cached
+ from synapse.util.caches.descriptors import CachedFunction, cached
from synapse.util.frozenutils import freeze
if TYPE_CHECKING:
@@ -836,6 +836,37 @@ class ModuleApi:
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
)
def register_cached_function(self, cached_func: CachedFunction) -> None:
"""Register a cached function that should be invalidated across workers.
Invalidation local to a worker can be done directly using `cached_func.invalidate`;
invalidation that needs to reach other workers must instead go through `invalidate_cache`
on the module API.
Args:
cached_func: The cached function that will be registered to receive invalidation
locally and from other workers.
"""
self._store.register_external_cached_function(
f"{cached_func.__module__}.{cached_func.__name__}", cached_func
)
async def invalidate_cache(
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
) -> None:
"""Invalidate a cache entry of a cached function across workers. The cached function
needs to be registered on all workers first with `register_cached_function`.
Args:
cached_func: The cached function that needs an invalidation
keys: keys of the entry to invalidate, usually matching the arguments of the
cached function.
"""
cached_func.invalidate(keys)
await self._store.send_invalidation_to_replication(
f"{cached_func.__module__}.{cached_func.__name__}",
keys,
)
async def complete_sso_login_async(
self,
registered_user_id: str,

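A hedged sketch of how a module might use this pair of APIs together; `ExampleModule` and its methods are hypothetical, and only `cached`, `register_cached_function` and `invalidate_cache` come from the code above:

```python
from synapse.module_api import ModuleApi, cached

class ExampleModule:  # hypothetical module
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Register the cached function so that invalidations sent by other
        # workers are applied to this worker's copy of the cache too.
        api.register_cached_function(self.get_remote_config)

    @cached()
    async def get_remote_config(self, user_id: str) -> dict:
        return await self._fetch_config(user_id)  # hypothetical helper

    async def on_config_change(self, user_id: str) -> None:
        # Invalidates locally and streams the invalidation to other workers.
        await self._api.invalidate_cache(self.get_remote_config, (user_id,))
```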

@@ -116,6 +116,8 @@ class PusherConfig:
last_stream_ordering: int
last_success: Optional[int]
failing_since: Optional[int]
enabled: bool
device_id: Optional[str]
def as_dict(self) -> Dict[str, Any]:
"""Information that can be retrieved about a pusher after creation."""
@@ -128,6 +130,8 @@ class PusherConfig:
"lang": self.lang,
"profile_tag": self.profile_tag,
"pushkey": self.pushkey,
"enabled": self.enabled,
"device_id": self.device_id,
}


@@ -94,7 +94,7 @@ class PusherPool:
return
run_as_background_process("start_pushers", self._start_pushers)
- async def add_pusher(
+ async def add_or_update_pusher(
self,
user_id: str,
access_token: Optional[int],
@@ -106,6 +106,8 @@ class PusherPool:
lang: Optional[str],
data: JsonDict,
profile_tag: str = "",
enabled: bool = True,
device_id: Optional[str] = None,
) -> Optional[Pusher]:
"""Creates a new pusher and adds it to the pool
@@ -147,9 +149,22 @@ class PusherPool:
last_stream_ordering=last_stream_ordering,
last_success=None,
failing_since=None,
enabled=enabled,
device_id=device_id,
)
)
# Before we actually persist the pusher, we check if the user already has a
# pusher with this app ID and pushkey. If so, we want to keep the access token and
# device ID in place, since this could be one device modifying
# (e.g. enabling/disabling) another device's pusher.
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
user_id, app_id, pushkey
)
if existing_config:
access_token = existing_config.access_token
device_id = existing_config.device_id
await self.store.add_pusher(
user_id=user_id,
access_token=access_token,
@@ -163,8 +178,10 @@ class PusherPool:
data=data,
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
enabled=enabled,
device_id=device_id,
)
- pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
+ pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
return pusher
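A hedged sketch of the update path this enables: calling `add_or_update_pusher` again for an existing `(user_id, app_id, pushkey)` updates the row in place, keeping the original access token and device ID per the check above. All concrete values, and the surrounding `hs`/token setup, are illustrative:

```python
pusher_pool = hs.get_pusherpool()
await pusher_pool.add_or_update_pusher(
    user_id="@alice:example.org",
    access_token=new_device_token_id,  # ignored: the existing pusher's token is kept
    kind="http",
    app_id="m.http",
    app_display_name="HTTP Push Notifications",
    device_display_name="new phone",
    pushkey="a@example.com",
    lang="en",
    data={"url": "http://example.com/_matrix/push/v1/notify"},
    enabled=False,  # e.g. one device disabling another device's pusher
)
```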
@@ -276,10 +293,25 @@ class PusherPool:
except Exception:
logger.exception("Exception in pusher on_new_receipts")
- async def start_pusher_by_id(
+ async def _get_pusher_config_for_user_by_app_id_and_pushkey(
+ self, user_id: str, app_id: str, pushkey: str
+ ) -> Optional[PusherConfig]:
+ resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
+ pusher_config = None
+ for r in resultlist:
+ if r.user_name == user_id:
+ pusher_config = r
+ return pusher_config
+ async def process_pusher_change_by_id(
self, app_id: str, pushkey: str, user_id: str
) -> Optional[Pusher]:
- """Look up the details for the given pusher, and start it
+ """Look up the details for the given pusher, and either start it if its
+ "enabled" flag is True, or try to stop it otherwise.
+ If the pusher is new and its "enabled" flag is False, the stop is a noop.
Returns:
The pusher started, if any
@@ -290,12 +322,13 @@ class PusherPool:
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
return None
- resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
- pusher_config = None
- for r in resultlist:
- if r.user_name == user_id:
- pusher_config = r
+ pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
+ user_id, app_id, pushkey
+ )
+ if pusher_config and not pusher_config.enabled:
+ self.maybe_stop_pusher(app_id, pushkey, user_id)
+ return None
pusher = None
if pusher_config:
@@ -305,7 +338,7 @@ class PusherPool:
async def _start_pushers(self) -> None:
"""Start all the pushers"""
- pushers = await self.store.get_all_pushers()
+ pushers = await self.store.get_enabled_pushers()
# Stagger starting up the pushers so we don't completely drown the
# process on start up.
@@ -363,6 +396,8 @@ class PusherPool:
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
# Check if there *may* be push to process. We do this as this check is a
# lot cheaper to do than actually fetching the exact rows we need to
# push.
@@ -382,16 +417,7 @@ class PusherPool:
return pusher
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
self.maybe_stop_pusher(app_id, pushkey, user_id)
# We can only delete pushers on master.
if self._remove_pusher_client:
@@ -402,3 +428,22 @@ class PusherPool:
await self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)
def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
"""Stops a pusher with the given app ID and push key if one is running.
Args:
app_id: the pusher's app ID.
pushkey: the pusher's push key.
user_id: the user the pusher belongs to. Only used for logging.
"""
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
pusher = byuser.pop(appid_pushkey)
pusher.on_stop()
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()


@@ -189,7 +189,9 @@ class ReplicationDataHandler:
if row.deleted:
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
else:
- await self.start_pusher(row.user_id, row.app_id, row.pushkey)
+ await self.process_pusher_change(
+ row.user_id, row.app_id, row.pushkey
+ )
elif stream_name == EventsStream.NAME:
# We shouldn't get multiple rows per token for events stream, so
# we don't need to optimise this for multiple rows.
@@ -334,13 +336,15 @@ class ReplicationDataHandler:
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
- async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
+ async def process_pusher_change(
+ self, user_id: str, app_id: str, pushkey: str
+ ) -> None:
if not self._notify_pushers:
return
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
- await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
+ await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
class FederationSenderHandler:


@@ -30,6 +30,7 @@ from synapse.rest.client import (
keys,
knock,
login as v1_login,
login_token_request,
logout,
mutual_rooms,
notifications,
@@ -130,3 +131,4 @@ class ClientRestResource(JsonResource):
# unstable
mutual_rooms.register_servlets(hs, client_resource)
login_token_request.register_servlets(hs, client_resource)


@@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
and self.hs.config.email.email_notif_for_new_users
and medium == "email"
):
- await self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_or_update_pusher(
user_id=user_id,
access_token=None,
kind="email",
@@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
app_display_name="Email Notifications",
device_display_name=address,
pushkey=address,
- lang=None, # We don't know a user's language here
+ lang=None,
data={},
)


@@ -534,6 +534,11 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
)
class PostBody(RequestBodyModel):
client_secret: ClientSecretStr
sid: StrictStr
token: StrictStr
def __init__(self, hs: "HomeServer"):
super().__init__()
self.config = hs.config
@@ -549,16 +554,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
"instead.",
)
- body = parse_json_object_from_request(request)
- assert_params_in_dict(body, ["client_secret", "sid", "token"])
- assert_valid_client_secret(body["client_secret"])
+ body = parse_and_validate_json_object_from_request(request, self.PostBody)
# Proxy submit_token request to msisdn threepid delegate
response = await self.identity_handler.proxy_msisdn_submit_token(
self.config.registration.account_threepid_delegate_msisdn,
body["client_secret"],
body["sid"],
body["token"],
body.client_secret,
body.sid,
body.token,
)
return 200, response
@@ -581,6 +584,10 @@ class ThreepidRestServlet(RestServlet):
return 200, {"threepids": threepids}
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
# the endpoint is deprecated. (If you really want to, you could do this by reusing
# ThreePidBindRestServlet.PostBody with an `alias_generator` to handle
# `threePidCreds` versus `three_pid_creds`.)
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
if not self.hs.config.registration.enable_3pid_changes:
raise SynapseError(

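The parsing pattern used above, in miniature. A hedged sketch with a hypothetical servlet; the import paths are assumed to match the Synapse tree at this commit:

```python
from pydantic import StrictStr

from synapse.http.servlet import (
    RestServlet,
    parse_and_validate_json_object_from_request,
)
from synapse.rest.models import RequestBodyModel

class ExampleServlet(RestServlet):  # hypothetical
    class PostBody(RequestBodyModel):
        client_secret: StrictStr
        sid: StrictStr
        token: StrictStr

    async def on_POST(self, request):
        # Raises a 400 with a descriptive message if a field is missing
        # or has the wrong type, replacing the manual assert_* calls.
        body = parse_and_validate_json_object_from_request(request, self.PostBody)
        return 200, {"sid": body.sid}
```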

@@ -0,0 +1,94 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns, interactive_auth_handler
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
class LoginTokenRequestServlet(RestServlet):
"""
Get a token that can be used with `m.login.token` to log in a second device.
Request:
POST /login/token HTTP/1.1
Content-Type: application/json
{}
Response:
HTTP/1.1 200 OK
{
"login_token": "ABDEFGH",
"expires_in": 3600,
}
"""
PATTERNS = client_patterns("/login/token$")
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
self.server_name = hs.config.server.server_name
self.macaroon_gen = hs.get_macaroon_generator()
self.auth_handler = hs.get_auth_handler()
self.token_timeout = hs.config.experimental.msc3882_token_timeout
self.ui_auth = hs.config.experimental.msc3882_ui_auth
@interactive_auth_handler
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
body = parse_json_object_from_request(request)
if self.ui_auth:
await self.auth_handler.validate_user_via_ui_auth(
requester,
request,
body,
"issue a new access token for your account",
can_skip_ui_auth=False, # Don't allow skipping of UI auth
)
login_token = self.macaroon_gen.generate_short_term_login_token(
user_id=requester.user.to_string(),
auth_provider_id="org.matrix.msc3882.login_token_request",
duration_in_ms=self.token_timeout,
)
return (
200,
{
"login_token": login_token,
"expires_in": self.token_timeout // 1000,
},
)
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
if hs.config.experimental.msc3882_enabled:
LoginTokenRequestServlet(hs).register(http_server)
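A hedged sketch of the end-to-end client flow this servlet enables, using plain `requests`. The exact URL prefixes depend on `client_patterns` and the client's login implementation; the homeserver URL and tokens are illustrative:

```python
import requests

HS = "https://matrix.example.org"
existing_token = "syt_..."  # access token of the already-logged-in device

# 1. The existing device requests a short-lived login token. If
#    msc3882_ui_auth is enabled, the server will first answer 401 with
#    user-interactive auth flows that must be completed.
resp = requests.post(
    f"{HS}/_matrix/client/v1/login/token",
    headers={"Authorization": f"Bearer {existing_token}"},
    json={},
)
login_token = resp.json()["login_token"]

# 2. The new device exchanges it via the standard m.login.token flow.
resp = requests.post(
    f"{HS}/_matrix/client/v3/login",
    json={"type": "m.login.token", "token": login_token},
)
new_access_token = resp.json()["access_token"]
```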


@@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -51,9 +52,16 @@ class PushersRestServlet(RestServlet):
user.to_string()
)
- filtered_pushers = [p.as_dict() for p in pushers]
+ pusher_dicts = [p.as_dict() for p in pushers]
- return 200, {"pushers": filtered_pushers}
+ for pusher in pusher_dicts:
+ if self._msc3881_enabled:
+ pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
+ pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
+ del pusher["enabled"]
+ del pusher["device_id"]
+ return 200, {"pushers": pusher_dicts}
class PushersSetRestServlet(RestServlet):
@@ -65,6 +73,7 @@ class PushersSetRestServlet(RestServlet):
self.auth = hs.get_auth()
self.notifier = hs.get_notifier()
self.pusher_pool = self.hs.get_pusherpool()
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -103,6 +112,10 @@ class PushersSetRestServlet(RestServlet):
if "append" in content:
append = content["append"]
enabled = True
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
enabled = content["org.matrix.msc3881.enabled"]
if not append:
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
app_id=content["app_id"],
@@ -111,7 +124,7 @@ class PushersSetRestServlet(RestServlet):
)
try:
- await self.pusher_pool.add_pusher(
+ await self.pusher_pool.add_or_update_pusher(
user_id=user.to_string(),
access_token=requester.access_token_id,
kind=content["kind"],
@@ -122,6 +135,8 @@ class PushersSetRestServlet(RestServlet):
lang=content["lang"],
data=content["data"],
profile_tag=content.get("profile_tag", ""),
enabled=enabled,
device_id=requester.device_id,
)
except PusherConfigException as pce:
raise SynapseError(

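A hedged sketch of a client request using the new field; everything except the `org.matrix.msc3881.enabled` key (taken from the servlet code above) is illustrative:

```python
# Body for POST /_matrix/client/v3/pushers/set, sent with the user's
# access token. Re-sending the same pusher with the flag flipped toggles
# it without deleting it.
toggle_body = {
    "kind": "http",
    "app_id": "m.http",
    "app_display_name": "HTTP Push Notifications",
    "device_display_name": "pushy push",
    "pushkey": "a@example.com",
    "lang": "en",
    "data": {"url": "http://example.com/_matrix/push/v1/notify"},
    # MSC3881: only honoured when msc3881_enabled is set in the
    # experimental config; otherwise the pusher stays enabled.
    "org.matrix.msc3881.enabled": False,
}
```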

@@ -105,6 +105,10 @@ class VersionsRestServlet(RestServlet):
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
# Allows moderators to fetch redacted event content as described in MSC2815
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
# Adds support for login token requests as per MSC3882
"org.matrix.msc3882": self.config.experimental.msc3882_enabled,
# Adds support for remotely enabling/disabling pushers, as per MSC3881
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
},
},
)


@@ -15,12 +15,13 @@
# limitations under the License.
import logging
from abc import ABCMeta
- from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
+ from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
from synapse.types import get_domain_from_id
from synapse.util import json_decoder
from synapse.util.caches.descriptors import CachedFunction
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
self.database_engine = database.engine
self.db_pool = database
self.external_cached_functions: Dict[str, CachedFunction] = {}
def process_replication_rows(
self,
stream_name: str,
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
- ) -> None:
+ ) -> bool:
"""Attempts to invalidate the cache of the given name, ignoring if the
cache doesn't exist. Mainly used for invalidating caches on workers,
where they may not have the cache.
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
try:
cache = getattr(self, cache_name)
except AttributeError:
- # We probably haven't pulled in the cache in this worker,
- # which is fine.
- return
+ # Check if an externally defined module cache has been registered
+ cache = self.external_cached_functions.get(cache_name)
+ if not cache:
+ # We probably haven't pulled in the cache in this worker,
+ # which is fine.
+ return False
if key is None:
cache.invalidate_all()
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
invalidate_method(tuple(key))
return True
def register_external_cached_function(
self, cache_name: str, func: CachedFunction
) -> None:
self.external_cached_functions[cache_name] = func
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
"""


@@ -33,7 +33,7 @@ from synapse.storage.database import (
)
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
- from synapse.util.caches.descriptors import _CachedFunction
+ from synapse.util.caches.descriptors import CachedFunction
from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
return
cache_func.invalidate(keys)
- await self.db_pool.runInteraction(
- "invalidate_cache_and_stream",
- self._send_invalidation_to_replication,
+ await self.send_invalidation_to_replication(
cache_func.__name__,
keys,
)
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
def _invalidate_cache_and_stream(
self,
txn: LoggingTransaction,
- cache_func: _CachedFunction,
+ cache_func: CachedFunction,
keys: Tuple[Any, ...],
) -> None:
"""Invalidates the cache and adds it to the cache stream so slaves
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
def _invalidate_all_cache_and_stream(
- self, txn: LoggingTransaction, cache_func: _CachedFunction
+ self, txn: LoggingTransaction, cache_func: CachedFunction
) -> None:
"""Invalidates the entire cache and adds it to the cache stream so slaves
will know to invalidate their caches.
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
txn, CURRENT_STATE_CACHE_NAME, [room_id]
)
async def send_invalidation_to_replication(
self, cache_name: str, keys: Optional[Collection[Any]]
) -> None:
await self.db_pool.runInteraction(
"send_invalidation_to_replication",
self._send_invalidation_to_replication,
cache_name,
keys,
)
def _send_invalidation_to_replication(
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
) -> None:

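A hedged sketch of the new public entry point in isolation; `store` and the `get_thing` cache are hypothetical, and the call is assumed to run inside an async function:

```python
# Stream an invalidation of get_thing((thing_id,)) to the other workers,
# outside of any existing database transaction. The local cache still has
# to be invalidated separately (as the module API's invalidate_cache does).
await store.send_invalidation_to_replication("get_thing", (thing_id,))
```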

@@ -412,6 +412,31 @@ class PersistEventsStore:
assert min_stream_order
assert max_stream_order
# Once the txn completes, invalidate all of the relevant caches. Note that we do this
# up here because it captures all the events_and_contexts before any are removed.
for event, _ in events_and_contexts:
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
if event.redacts:
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
relates_to = None
relation = relation_from_event(event)
if relation:
relates_to = relation.parent_id
assert event.internal_metadata.stream_ordering is not None
txn.call_after(
self.store._invalidate_caches_for_event,
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.type,
getattr(event, "state_key", None),
event.redacts,
relates_to,
backfilled=False,
)
self._update_forward_extremities_txn(
txn,
new_forward_extremities=new_forward_extremities,

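For context, a hedged sketch of the `txn.call_after` pattern relied on above: callbacks registered on a `LoggingTransaction` run only after the transaction commits, so a rolled-back persist never invalidates caches. The store method and table here are hypothetical:

```python
def _delete_thing_txn(self, txn, thing_id: str) -> None:
    txn.execute("DELETE FROM things WHERE thing_id = ?", (thing_id,))
    # Runs only on commit; a rollback leaves the cache untouched.
    txn.call_after(self.get_thing.invalidate, (thing_id,))
```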

@@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
# the batches as big as possible.
results: Set[str] = set()
- for chunk in batch_iter(event_ids, 500):
- r = await self._have_seen_events_dict(
- [(room_id, event_id) for event_id in chunk]
- )
- results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
+ for event_ids_chunk in batch_iter(event_ids, 500):
+ events_seen_dict = await self._have_seen_events_dict(
+ room_id, event_ids_chunk
+ )
+ results.update(
+ eid for (eid, have_event) in events_seen_dict.items() if have_event
+ )
return results
- @cachedList(cached_method_name="have_seen_event", list_name="keys")
+ @cachedList(cached_method_name="have_seen_event", list_name="event_ids")
async def _have_seen_events_dict(
- self, keys: Collection[Tuple[str, str]]
- ) -> Dict[Tuple[str, str], bool]:
+ self,
+ room_id: str,
+ event_ids: Collection[str],
+ ) -> Dict[str, bool]:
"""Helper for have_seen_events
Returns:
- a dict {(room_id, event_id) -> bool}
+ a dict {event_id -> bool}
"""
# if the event cache contains the event, obviously we've seen it.
cache_results = {
- (rid, eid)
- for (rid, eid) in keys
- if await self._get_event_cache.contains((eid,))
+ event_id
+ for event_id in event_ids
+ if await self._get_event_cache.contains((event_id,))
}
results = dict.fromkeys(cache_results, True)
- remaining = [k for k in keys if k not in cache_results]
+ remaining = [
+ event_id for event_id in event_ids if event_id not in cache_results
+ ]
if not remaining:
return results
@@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
sql = "SELECT event_id FROM events AS e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
txn.database_engine, "e.event_id", remaining
)
txn.execute(sql + clause, args)
found_events = {eid for eid, in txn}
# ... and then we can update the results for each key
- results.update(
- {(rid, eid): (eid in found_events) for (rid, eid) in remaining}
- )
+ results.update({eid: (eid in found_events) for eid in remaining})
await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
return results
@cached(max_entries=100000, tree=True)
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
- res = await self._have_seen_events_dict(((room_id, event_id),))
- return res[(room_id, event_id)]
+ res = await self._have_seen_events_dict(room_id, [event_id])
+ return res[event_id]
def _get_current_state_event_counts_txn(
self, txn: LoggingTransaction, room_id: str

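The point of rekeying `_have_seen_events_dict` on `(room_id, event_ids)`: `have_seen_event` is cached with `tree=True`, so entries are keyed as `(room_id, event_id)` and a whole room's entries can be dropped with a prefix key. A hedged sketch, assuming partial-key invalidation behaves as for other tree caches:

```python
# Drop the cached answer for a single event:
store.have_seen_event.invalidate((room_id, event_id))
# With tree=True, a prefix key drops every entry for the room at once:
store.have_seen_event.invalidate((room_id,))
```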

@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
)
continue
# If we're using SQLite, then boolean values are integers. This is
# troublesome since some code using the return value of this method might
# expect it to be a boolean, or will expose it to clients (in responses).
r["enabled"] = bool(r["enabled"])
yield PusherConfig(**r)
async def get_pushers_by_app_id_and_pushkey(
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
return await self.get_pushers_by({"user_name": user_id})
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
- ret = await self.db_pool.simple_select_list(
- "pushers",
- keyvalues,
- [
- "id",
- "user_name",
- "access_token",
- "profile_tag",
- "kind",
- "app_id",
- "app_display_name",
- "device_display_name",
- "pushkey",
- "ts",
- "lang",
- "data",
- "last_stream_ordering",
- "last_success",
- "failing_since",
- ],
- )
+ """Retrieve pushers that match the given criteria.
+ Args:
+ keyvalues: A {column: value} dictionary.
+ Returns:
+ The pushers for which the given columns have the given values.
+ """
+ def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
+ # We could technically use simple_select_list here, but we need to call
+ # COALESCE on the 'enabled' column. While it is technically possible to give
+ # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
+ # feels a bit hacky, so it's probably better to just inline the query.
+ sql = """
+ SELECT
+ id, user_name, access_token, profile_tag, kind, app_id,
+ app_display_name, device_display_name, pushkey, ts, lang, data,
+ last_stream_ordering, last_success, failing_since,
+ COALESCE(enabled, TRUE) AS enabled, device_id
+ FROM pushers
+ """
+ sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
+ txn.execute(sql, list(keyvalues.values()))
+ return self.db_pool.cursor_to_dict(txn)
+ ret = await self.db_pool.runInteraction(
+ desc="get_pushers_by",
+ func=get_pushers_by_txn,
+ )
return self._decode_pushers_rows(ret)
- async def get_all_pushers(self) -> Iterator[PusherConfig]:
- def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
- txn.execute("SELECT * FROM pushers")
+ async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
+ def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
+ txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
rows = self.db_pool.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
- return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
+ return await self.db_pool.runInteraction(
+ "get_enabled_pushers", get_enabled_pushers_txn
+ )
async def get_all_updated_pushers_rows(
self, instance_name: str, last_id: int, current_id: int, limit: int
@@ -458,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore):
return number_deleted
- class PusherStore(PusherWorkerStore):
+ class PusherBackgroundUpdatesStore(SQLBaseStore):
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_update_handler(
"set_device_id_for_pushers", self._set_device_id_for_pushers
)
async def _set_device_id_for_pushers(
self, progress: JsonDict, batch_size: int
) -> int:
"""Background update to populate the device_id column of the pushers table."""
last_pusher_id = progress.get("pusher_id", 0)
def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
txn.execute(
"""
SELECT p.id, at.device_id
FROM pushers AS p
INNER JOIN access_tokens AS at
ON p.access_token = at.id
WHERE
p.access_token IS NOT NULL
AND at.device_id IS NOT NULL
AND p.id > ?
ORDER BY p.id
LIMIT ?
""",
(last_pusher_id, batch_size),
)
rows = self.db_pool.cursor_to_dict(txn)
if len(rows) == 0:
return 0
self.db_pool.simple_update_many_txn(
txn=txn,
table="pushers",
key_names=("id",),
key_values=[(row["id"],) for row in rows],
value_names=("device_id",),
value_values=[(row["device_id"],) for row in rows],
)
self.db_pool.updates._background_update_progress_txn(
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
)
return len(rows)
nb_processed = await self.db_pool.runInteraction(
"set_device_id_for_pushers", set_device_id_for_pushers_txn
)
if nb_processed < batch_size:
await self.db_pool.updates._end_background_update(
"set_device_id_for_pushers"
)
return nb_processed
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
def get_pushers_stream_token(self) -> int:
return self._pushers_id_gen.get_current_token()
@@ -476,6 +562,8 @@ class PusherStore(PusherWorkerStore):
data: Optional[JsonDict],
last_stream_ordering: int,
profile_tag: str = "",
enabled: bool = True,
device_id: Optional[str] = None,
) -> None:
async with self._pushers_id_gen.get_next() as stream_id:
# no need to lock because `pushers` has a unique key on
@@ -494,6 +582,8 @@ class PusherStore(PusherWorkerStore):
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
"enabled": enabled,
"device_id": device_id,
},
desc="add_pusher",
lock=False,


@@ -52,6 +52,8 @@ class _RelatedEvent:
event_id: str
# The sender of the related event.
sender: str
topological_ordering: Optional[int]
stream_ordering: int
class RelationsWorkerStore(SQLBaseStore):
@@ -92,6 +94,9 @@ class RelationsWorkerStore(SQLBaseStore):
# it. The `event_id` must match the `event.event_id`.
assert event.event_id == event_id
# Ensure bad limits aren't being passed in.
assert limit >= 0
where_clause = ["relates_to_id = ?", "room_id = ?"]
where_args: List[Union[str, int]] = [event.event_id, room_id]
is_redacted = event.internal_metadata.is_redacted()
@@ -140,21 +145,34 @@ class RelationsWorkerStore(SQLBaseStore):
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
txn.execute(sql, where_args + [limit + 1])
- last_topo_id = None
- last_stream_id = None
events = []
- for row in txn:
+ for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
# Do not include edits for redacted events as they leak event
# content.
- if not is_redacted or row[1] != RelationTypes.REPLACE:
- events.append(_RelatedEvent(row[0], row[2]))
- last_topo_id = row[3]
- last_stream_id = row[4]
+ if not is_redacted or relation_type != RelationTypes.REPLACE:
+ events.append(
+ _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
+ )
- # If there are more events, generate the next pagination key.
+ # If there are more events, generate the next pagination key from the
+ # last event returned.
next_token = None
- if len(events) > limit and last_topo_id and last_stream_id:
- next_key = RoomStreamToken(last_topo_id, last_stream_id)
+ if len(events) > limit:
+ # Instead of using the last row (which tells us there is more
+ # data), use the last row to be returned.
+ events = events[:limit]
+ topo = events[-1].topological_ordering
+ token = events[-1].stream_ordering
+ if direction == "b":
+ # Tokens are positions between events.
+ # This token points *after* the last event in the chunk.
+ # We need it to point to the event before it in the chunk
+ # when we are going backwards so we subtract one from the
+ # stream part.
+ token -= 1
+ next_key = RoomStreamToken(topo, token)
if from_token:
next_token = from_token.copy_and_replace(
StreamKeyType.ROOM, next_key

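A worked example of the token arithmetic above, with illustrative orderings:

```python
# Paginating backwards with limit=3 over events whose stream orderings are
# [10, 9, 8, 7, ...]: the last event returned has stream_ordering 8. A
# RoomStreamToken is a position *between* events, so to resume backwards
# from just before event 8 the stream part must be decremented:
last_returned = 8
token = last_returned - 1  # the position between events 7 and 8
```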

@@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
if rows:
topo = rows[-1].topological_ordering
- toke = rows[-1].stream_ordering
+ token = rows[-1].stream_ordering
if direction == "b":
# Tokens are positions between events.
# This token points *after* the last event in the chunk.
# We need it to point to the event before it in the chunk
# when we are going backwards so we subtract one from the
# stream part.
- toke -= 1
- next_token = RoomStreamToken(topo, toke)
+ token -= 1
+ next_token = RoomStreamToken(topo, token)
else:
# TODO (erikj): We should work out what to do here instead.
next_token = to_token if to_token else from_token


@@ -0,0 +1,16 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;


@@ -0,0 +1,20 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Add a device_id column to track the device ID that created the pusher. It's NULLable
-- on purpose, because a) it might not be possible to track down the device that created
-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and
-- b) access tokens retrieved via the admin API don't have a device associated to them.
ALTER TABLE pushers ADD COLUMN device_id TEXT;


@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
- class _CachedFunction(Generic[F]):
+ class CachedFunction(Generic[F]):
invalidate: Any = None
invalidate_all: Any = None
prefill: Any = None
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
return ret2
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
wrapped.cache = cache
obj.__dict__[self.name] = wrapped
@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
return make_deferred_yieldable(ret)
- wrapped = cast(_CachedFunction, _wrapped)
+ wrapped = cast(CachedFunction, _wrapped)
if self.num_args == 1:
assert not self.tree
@@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
cache: DeferredCache[CacheKey, Any] = cached_method.cache
num_args = cached_method.num_args
if num_args != self.num_args:
raise Exception(
"Number of args (%s) does not match underlying cache_method_name=%s (%s)."
% (self.num_args, self.cached_method_name, num_args)
)
@functools.wraps(self.orig)
def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
# If we're passed a cache_context then we'll want to call its
@@ -572,7 +578,7 @@ def cached(
iterable: bool = False,
prune_unread_entries: bool = True,
name: Optional[str] = None,
- ) -> Callable[[F], _CachedFunction[F]]:
+ ) -> Callable[[F], CachedFunction[F]]:
func = lambda orig: DeferredCacheDescriptor(
orig,
max_entries=max_entries,
@@ -585,7 +591,7 @@ def cached(
name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def cachedList(
@@ -594,7 +600,7 @@ def cachedList(
list_name: str,
num_args: Optional[int] = None,
name: Optional[str] = None,
- ) -> Callable[[F], _CachedFunction[F]]:
+ ) -> Callable[[F], CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
Used to do batch lookups for an already created cache. One of the arguments
@@ -631,7 +637,7 @@ def cachedList(
name=name,
)
- return cast(Callable[[F], _CachedFunction[F]], func)
+ return cast(Callable[[F], CachedFunction[F]], func)
def _get_cache_key_builder(

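What the new `num_args` check guards against, sketched with a hypothetical store: the `@cachedList` variant must take the same leading arguments as the cached method it batches, with one of them replaced by the list:

```python
from synapse.util.caches.descriptors import cached, cachedList

class ExampleStore:  # hypothetical
    @cached()
    async def get_thing(self, room_id: str, thing_id: str) -> str:
        ...

    # OK: same arguments, with `thing_id` batched into `thing_ids`.
    @cachedList(cached_method_name="get_thing", list_name="thing_ids")
    async def get_things(self, room_id: str, thing_ids: list) -> dict:
        ...

    # Would now raise when used: the argument count no longer matches
    # `get_thing`, which the new check reports explicitly.
    # @cachedList(cached_method_name="get_thing", list_name="thing_ids")
    # async def bad(self, thing_ids: list) -> dict: ...
```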

@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
)
self.pusher = self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
"""
with self.assertRaises(SynapseError) as cm:
self.get_success_or_raise(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.user_id,
access_token=self.token_id,
kind="email",


@@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor
import synapse.rest.admin
from synapse.logging.context import make_deferred_yieldable
- from synapse.push import PusherConfigException
- from synapse.rest.client import login, push_rule, receipts, room
+ from synapse.push import PusherConfig, PusherConfigException
+ from synapse.rest.client import login, push_rule, pusher, receipts, room
from synapse.server import HomeServer
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import JsonDict
from synapse.util import Clock
@@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase):
login.register_servlets,
receipts.register_servlets,
push_rule.register_servlets,
pusher.register_servlets,
]
user_id = True
hijack_auth = False
@@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase):
def test_data(data: Optional[JsonDict]) -> None:
self.get_failure(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase):
)
self.assertEqual(channel.code, 200, channel.json_body)
- def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
+ def _make_user_with_pusher(
+ self, username: str, enabled: bool = True
+ ) -> Tuple[str, str]:
"""Registers a user and creates a pusher for them.
Args:
username: the localpart of the new user's Matrix ID.
enabled: whether to create the pusher in an enabled or disabled state.
"""
user_id = self.register_user(username, "pass")
access_token = self.login(username, "pass")
# Register the pusher
self._set_pusher(user_id, access_token, enabled)
return user_id, access_token
def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
"""Creates or updates the pusher for the given user.
Args:
user_id: the user's Matrix ID.
access_token: the access token associated with the pusher.
enabled: whether to enable or disable the pusher.
"""
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
self.get_success(
- self.hs.get_pusherpool().add_pusher(
+ self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",
@@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase):
pushkey="a@example.com",
lang=None,
data={"url": "http://example.com/_matrix/push/v1/notify"},
enabled=enabled,
device_id=user_tuple.device_id,
)
)
- return user_id, access_token
def test_dont_notify_rule_overrides_message(self) -> None:
"""
The override push rule will suppress notification
@@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase):
# The user sends a message back (sends a notification)
self.helper.send(room, body="Hello", tok=access_token)
self.assertEqual(len(self.push_attempts), 1)
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_disable(self) -> None:
"""Tests that disabling a pusher means it's not pushed to anymore."""
user_id, access_token = self._make_user_with_pusher("user")
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
room = self.helper.create_room_as(user_id, tok=access_token)
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# Send a message and check that it generated a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Disable the pusher.
self._set_pusher(user_id, access_token, enabled=False)
# Send another message and check that it did not generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Get the pushers for the user and check that it is marked as disabled.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
self.assertFalse(enabled)
self.assertTrue(isinstance(enabled, bool))
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_enable(self) -> None:
"""Tests that enabling a disabled pusher means it gets pushed to."""
# Create the user with the pusher already disabled.
user_id, access_token = self._make_user_with_pusher("user", enabled=False)
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
room = self.helper.create_room_as(user_id, tok=access_token)
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
# Send a message and check that it did not generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 0)
# Enable the pusher.
self._set_pusher(user_id, access_token, enabled=True)
# Send another message and check that it did generate a push.
self.helper.send(room, body="Hi!", tok=other_access_token)
self.assertEqual(len(self.push_attempts), 1)
# Get the pushers for the user and check that it is marked as enabled.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
self.assertTrue(enabled)
self.assertTrue(isinstance(enabled, bool))
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_null_enabled(self) -> None:
"""Tests that a pusher that has an 'enabled' column set to NULL (e.g. pushers
created before the column was introduced) is considered enabled.
"""
# We intentionally set 'enabled' to None so that it's stored as NULL in the
# database.
user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
def test_update_different_device_access_token_device_id(self) -> None:
"""Tests that if we create a pusher from one device, then update it from another
device, the access token and device ID associated with the pusher stay the
same.
"""
# Create a user with a pusher.
user_id, access_token = self._make_user_with_pusher("user")
# Get the token ID for the current access token, since that's what we store in
# the pushers table. Also get the device ID from it.
user_tuple = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
token_id = user_tuple.token_id
device_id = user_tuple.device_id
# Generate a new access token, and update the pusher with it.
new_token = self.login("user", "pass")
self._set_pusher(user_id, new_token, enabled=False)
# Get the current list of pushers for the user.
ret = self.get_success(
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
)
pushers: List[PusherConfig] = list(ret)
# Check that we still have one pusher, and that the access token and device ID
# associated with it didn't change.
self.assertEqual(len(pushers), 1)
self.assertEqual(pushers[0].access_token, token_id)
self.assertEqual(pushers[0].device_id, device_id)
@override_config({"experimental_features": {"msc3881_enabled": True}})
def test_device_id(self) -> None:
"""Tests that a pusher created with a given device ID shows that device ID in
GET /pushers requests.
"""
self.register_user("user", "pass")
access_token = self.login("user", "pass")
# We create the pusher with an HTTP request rather than with
# _make_user_with_pusher so that we can test the device ID is correctly set when
# creating a pusher via an API call.
self.make_request(
method="POST",
path="/pushers/set",
content={
"kind": "http",
"app_id": "m.http",
"app_display_name": "HTTP Push Notifications",
"device_display_name": "pushy push",
"pushkey": "a@example.com",
"lang": "en",
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
},
access_token=access_token,
)
# Look up the user info for the access token so we can compare the device ID.
lookup_result: TokenLookupResult = self.get_success(
self.hs.get_datastores().main.get_user_by_access_token(access_token)
)
# Get the user's devices and check it has the correct device ID.
channel = self.make_request("GET", "/pushers", access_token=access_token)
self.assertEqual(channel.code, 200)
self.assertEqual(len(channel.json_body["pushers"]), 1)
self.assertEqual(
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
lookup_result.device_id,
)


@@ -0,0 +1,79 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import synapse
from synapse.module_api import cached
from tests.replication._base import BaseMultiWorkerStreamTestCase
logger = logging.getLogger(__name__)
FIRST_VALUE = "one"
SECOND_VALUE = "two"
KEY = "mykey"
class TestCache:
current_value = FIRST_VALUE
@cached()
async def cached_function(self, user_id: str) -> str:
return self.current_value
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
servlets = [
synapse.rest.admin.register_servlets,
]
def test_module_cache_full_invalidation(self):
main_cache = TestCache()
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
worker_cache = TestCache()
worker_hs.get_module_api().register_cached_function(
worker_cache.cached_function
)
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
self.assertEqual(
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
)
main_cache.current_value = SECOND_VALUE
worker_cache.current_value = SECOND_VALUE
# No invalidation yet, should return the cached value on both the main process and the worker
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
self.assertEqual(
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
)
# Full invalidation on the main process; this should be replicated to the
# worker, which should then return the updated value too
self.get_success(
self.hs.get_module_api().invalidate_cache(
main_cache.cached_function, (KEY,)
)
)
self.assertEqual(
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
)
self.assertEqual(
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
)
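Together with the changelog entry about cross-worker cache invalidation, this test implies the module-facing pattern sketched below. Only `register_cached_function` and `invalidate_cache` are taken from this diff; the storage helpers are hypothetical:

from synapse.module_api import ModuleApi, cached

class MyModule:
    def __init__(self, config: dict, api: ModuleApi):
        self._api = api
        # Registering the cached function lets invalidations replicate to workers.
        api.register_cached_function(self.get_display_colour)

    @cached()
    async def get_display_colour(self, user_id: str) -> str:
        # Hypothetical backing lookup; the result is cached per user_id.
        return await load_colour(user_id)

    async def set_display_colour(self, user_id: str, colour: str) -> None:
        await store_colour(user_id, colour)  # hypothetical write
        # Drop the stale entry on the main process and every worker.
        await self._api.invalidate_cache(self.get_display_colour, (user_id,))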

View File

@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
token_id = user_dict.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=user_id,
access_token=token_id,
kind="http",

View File

@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
token_id = user_tuple.token_id
self.get_success(
self.hs.get_pusherpool().add_pusher(
self.hs.get_pusherpool().add_or_update_pusher(
user_id=self.other_user,
access_token=token_id,
kind="http",
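The same mechanical rename from add_pusher to add_or_update_pusher lands in both test files above. A sketch of an updated call site, showing only the keyword arguments visible in these hunks; the remaining pusher fields are unchanged and deliberately elided:

# Sketch: the renamed call; kwargs beyond those shown in the hunks are elided.
self.get_success(
    self.hs.get_pusherpool().add_or_update_pusher(
        user_id=user_id,
        access_token=token_id,
        kind="http",
        # ...app_id, pushkey, lang, data as before...
    )
)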

View File

@@ -0,0 +1,132 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, login_token_request
from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
from tests.unittest import override_config
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
servlets = [
login.register_servlets,
admin.register_servlets,
login_token_request.register_servlets,
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.hs = self.setup_test_homeserver()
self.hs.config.registration.enable_registration = True
self.hs.config.registration.registrations_require_3pid = []
self.hs.config.registration.auto_join_rooms = []
self.hs.config.captcha.enable_registration_captcha = False
return self.hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.user = "user123"
self.password = "password"
def test_disabled(self) -> None:
channel = self.make_request("POST", "/login/token", {}, access_token=None)
self.assertEqual(channel.code, 400)
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
self.assertEqual(channel.code, 400)
@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_require_auth(self) -> None:
channel = self.make_request("POST", "/login/token", {}, access_token=None)
self.assertEqual(channel.code, 401)
@override_config({"experimental_features": {"msc3882_enabled": True}})
def test_uia_on(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
self.assertEqual(channel.code, 401)
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
session = channel.json_body["session"]
uia = {
"auth": {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": self.user},
"password": self.password,
"session": session,
},
}
channel = self.make_request("POST", "/login/token", uia, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
login_token = channel.json_body["login_token"]
channel = self.make_request(
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)
@override_config(
{"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
)
def test_uia_off(self) -> None:
user_id = self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 300)
login_token = channel.json_body["login_token"]
channel = self.make_request(
"POST",
"/login",
content={"type": "m.login.token", "token": login_token},
)
self.assertEqual(channel.code, 200, channel.result)
self.assertEqual(channel.json_body["user_id"], user_id)
@override_config(
{
"experimental_features": {
"msc3882_enabled": True,
"msc3882_ui_auth": False,
"msc3882_token_timeout": "15s",
}
}
)
def test_expires_in(self) -> None:
self.register_user(self.user, self.password)
token = self.login(self.user, self.password)
channel = self.make_request("POST", "/login/token", {}, access_token=token)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["expires_in"], 15)
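Read end to end, test_uia_off plus the redemption step trace the whole MSC3882 flow. The same exchange from a client's point of view, sketched with the requests library; the bare /login/token and /login paths mirror the test helper's routing, and a real client may need versioned or unstable prefixes instead:

import requests

BASE = "https://homeserver.example"  # placeholder base URL

# 1. An existing session asks for a short-lived, single-use login token.
#    (With msc3882_ui_auth left on, this first gets a 401 plus a UIA flow
#    to complete, as test_uia_on shows.)
resp = requests.post(
    f"{BASE}/login/token",
    headers={"Authorization": "Bearer EXISTING_ACCESS_TOKEN"},
    json={},
)
login_token = resp.json()["login_token"]
expires_in = resp.json()["expires_in"]  # 300s by default, 15s with the override

# 2. The new device redeems the token for its own access token.
resp = requests.post(
    f"{BASE}/login",
    json={"type": "m.login.token", "token": login_token},
)
new_access_token = resp.json()["access_token"]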

View File

@@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel.json_body["chunk"][0],
)
@unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_repeated_paginate_relations(self) -> None:
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
found_event_ids.reverse()
self.assertEqual(found_event_ids, expected_event_ids)
# Test forward pagination.
prev_token = ""
found_event_ids = []
for _ in range(20):
from_token = ""
if prev_token:
from_token = "&from=" + prev_token
channel = self.make_request(
"GET",
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}",
access_token=self.user_token,
)
self.assertEqual(200, channel.code, channel.json_body)
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
next_batch = channel.json_body.get("next_batch")
self.assertNotEqual(prev_token, next_batch)
prev_token = next_batch
if not prev_token:
break
self.assertEqual(found_event_ids, expected_event_ids)
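The loop added above is the standard next_batch token-chaining idiom. Isolated as a sketch, with a hypothetical fetch_page standing in for the HTTP request:

# Sketch of the pagination idiom; `fetch_page` is a hypothetical helper that
# returns one parsed /relations page as a dict.
def collect_event_ids(fetch_page) -> list:
    event_ids = []
    prev_token = ""
    while True:
        body = fetch_page(from_token=prev_token or None)
        event_ids.extend(e["event_id"] for e in body["chunk"])
        next_batch = body.get("next_batch")
        if not next_batch:  # an absent token means the final page was reached
            return event_ids
        assert next_batch != prev_token  # guards against looping forever
        prev_token = next_batch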
def test_pagination_from_sync_and_messages(self) -> None:
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")

View File

@@ -103,6 +103,11 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
def test_persisting_event_invalidates_cache(self):
"""
Test to make sure that the `have_seen_event` cache
is invalidated after we persist an event, so that
subsequent lookups return the updated value.
"""
event, event_context = self.get_success(
create_event(
self.hs,
@@ -145,6 +150,33 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to look them up
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
def test_invalidate_cache_by_room_id(self):
"""
Test to make sure that all events associated with the given `(room_id,)`
are invalidated in the `have_seen_event` cache.
"""
with LoggingContext(name="test") as ctx:
# Prime the cache with some values
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
)
self.assertEqual(res, set(self.event_ids))
# That should result in a single db query to look them up
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
# Clear the cache of any events associated with the `room_id`
self.store.have_seen_event.invalidate((self.room_id,))
with LoggingContext(name="test") as ctx:
res = self.get_success(
self.store.have_seen_events(self.room_id, self.event_ids)
)
self.assertEqual(res, set(self.event_ids))
# Since we cleared the cache, it should result in another db query to look them up
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
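Invalidating with the partial key (room_id,) works because have_seen_event is backed by a tree cache: keys are tuples, and clearing a prefix drops every entry beneath it. A minimal sketch of that behaviour, assuming the cached(tree=True) decorator seen elsewhere in this commit:

from synapse.util.caches.descriptors import cached

class SketchStore:
    @cached(tree=True)  # tuple keys form a tree: (room_id, event_id)
    async def have_seen_event(self, room_id: str, event_id: str) -> bool:
        ...  # real database lookup elided

store = SketchStore()
# Invalidate a single entry:
store.have_seen_event.invalidate(("!room:test", "$event:test"))
# Or prune everything under one room, as the test above does:
store.have_seen_event.invalidate(("!room:test",))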
class EventCacheTestCase(unittest.HomeserverTestCase):
"""Test that the various layers of event cache works."""

View File

@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
assert self.helper.auth_user_id is not None
token = "some_fake_token"
# We need a valid token ID to satisfy foreign key constraints.
token_id = self.get_success(
self.hs.get_datastores().main.add_access_token_to_user(
self.helper.auth_user_id,
"some_fake_token",
token,
None,
None,
)
)
async def get_user_by_access_token(
token: Optional[str] = None, allow_guest: bool = False
) -> JsonDict:
assert self.helper.auth_user_id is not None
return {
"user": UserID.from_string(self.helper.auth_user_id),
"token_id": token_id,
"is_guest": False,
}
async def get_user_by_req(
request: SynapseRequest,
allow_guest: bool = False,
allow_expired: bool = False,
) -> Requester:
# This has to be a function and not just a Mock, because
# `self.helper.auth_user_id` is temporarily reassigned in some tests
async def get_requester(*args, **kwargs) -> Requester:
assert self.helper.auth_user_id is not None
return create_requester(
UserID.from_string(self.helper.auth_user_id),
token_id,
False,
False,
None,
user_id=UserID.from_string(self.helper.auth_user_id),
access_token_id=token_id,
)
# Type ignore: mypy doesn't like us assigning to methods.
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
return_value="1234"
)
self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment]
self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment]
self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment]
if self.needs_threadpool:
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
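The rewrite above replaces two differently-shaped mocks with a single get_requester coroutine returning a real Requester. From a test author's perspective, the hijack is driven by two class attributes, as in this sketch (imports, servlet registration, and other boilerplate elided; the whoami assertion is illustrative):

class MyEndpointTestCase(HomeserverTestCase):
    user_id = "@alice:test"  # presence of this attribute triggers the stubbing
    hijack_auth = True       # every request is then authenticated as user_id

    def test_whoami(self) -> None:
        # No access token needed: auth is stubbed out by the harness above.
        channel = self.make_request("GET", "/_matrix/client/v3/account/whoami")
        self.assertEqual(channel.json_body["user_id"], self.user_id)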

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Set
from typing import Iterable, Set, Tuple
from unittest import mock
from twisted.internet import defer, reactor
@@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
)
self.assertEqual(current_context(), SENTINEL_CONTEXT)
def test_num_args_mismatch(self):
"""
Make sure someone does not accidentally use @cachedList on a method with
a mismatch in the number of args to the underlying single cache method.
"""
class Cls:
@descriptors.cached(tree=True)
def fn(self, room_id, event_id):
pass
# This is wrong ❌. `@cachedList` expects to be given the same number
# of arguments as the underlying cached function, just with one of
# the arguments being an iterable
@descriptors.cachedList(cached_method_name="fn", list_name="keys")
def list_fn(self, keys: Iterable[Tuple[str, str]]):
pass
# Corrected syntax ✅
#
# @cachedList(cached_method_name="fn", list_name="event_ids")
# async def list_fn(
# self, room_id: str, event_ids: Collection[str],
# )
obj = Cls()
# Make sure this raises an error about the arg mismatch
with self.assertRaises(Exception):
obj.list_fn([("foo", "bar")])
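For contrast with the deliberately wrong pairing above, a correct pairing keeps every positional argument of the single-entry method and widens exactly one of them, named by list_name, into a collection. A sketch:

from typing import Collection, Dict

from synapse.util.caches.descriptors import cached, cachedList

class SketchStore:
    @cached(tree=True)
    async def fn(self, room_id: str, event_id: str) -> bool:
        ...  # single-entry lookup elided

    # Correct ✅: same args as `fn`, with event_id widened to `event_ids`.
    @cachedList(cached_method_name="fn", list_name="event_ids")
    async def list_fn(
        self, room_id: str, event_ids: Collection[str]
    ) -> Dict[str, bool]:
        ...  # batched lookup elided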