diff --git a/.github/workflows/latest_deps.yml b/.github/workflows/latest_deps.yml index 8366ac9393..9a708286a4 100644 --- a/.github/workflows/latest_deps.yml +++ b/.github/workflows/latest_deps.yml @@ -201,10 +201,11 @@ jobs: open-issue: if: "failure() && github.event_name != 'push' && github.event_name != 'pull_request'" needs: - # TODO: should mypy be included here? It feels more brittle than the other two. + # TODO: should mypy be included here? It feels more brittle than the others. - mypy - trial - sytest + - complement runs-on: ubuntu-latest diff --git a/changelog.d/13667.feature b/changelog.d/13667.feature new file mode 100644 index 0000000000..a0b3cfe18c --- /dev/null +++ b/changelog.d/13667.feature @@ -0,0 +1 @@ +Add cache invalidation across workers to module API. diff --git a/changelog.d/13722.feature b/changelog.d/13722.feature new file mode 100644 index 0000000000..588d143c0f --- /dev/null +++ b/changelog.d/13722.feature @@ -0,0 +1 @@ +Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session. diff --git a/changelog.d/13772.doc b/changelog.d/13772.doc new file mode 100644 index 0000000000..3398ff3765 --- /dev/null +++ b/changelog.d/13772.doc @@ -0,0 +1 @@ +Add `worker_main_http_uri` for the worker generator bash script. diff --git a/changelog.d/13799.feature b/changelog.d/13799.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13799.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/changelog.d/13809.misc b/changelog.d/13809.misc new file mode 100644 index 0000000000..c2dacca2f2 --- /dev/null +++ b/changelog.d/13809.misc @@ -0,0 +1 @@ +Improve the `synapse.api.auth.Auth` mock used in unit tests. diff --git a/changelog.d/13831.feature b/changelog.d/13831.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13831.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/changelog.d/13832.feature b/changelog.d/13832.feature new file mode 100644 index 0000000000..1dc1d66efe --- /dev/null +++ b/changelog.d/13832.feature @@ -0,0 +1 @@ +Improve validation for the unspecced, internal-only `_matrix/client/unstable/add_threepid/msisdn/submit_token` endpoint. diff --git a/changelog.d/13836.doc b/changelog.d/13836.doc new file mode 100644 index 0000000000..f2edab00f4 --- /dev/null +++ b/changelog.d/13836.doc @@ -0,0 +1 @@ +Fix a mistake in sso_mapping_providers.md: `map_user_attributes` is expected to return `display_name` not `displayname`. diff --git a/changelog.d/13840.bugfix b/changelog.d/13840.bugfix new file mode 100644 index 0000000000..0f014439a8 --- /dev/null +++ b/changelog.d/13840.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward. diff --git a/changelog.d/13850.misc b/changelog.d/13850.misc new file mode 100644 index 0000000000..a973118aaf --- /dev/null +++ b/changelog.d/13850.misc @@ -0,0 +1 @@ +Fix the release script not publishing binary wheels. \ No newline at end of file diff --git a/changelog.d/13859.misc b/changelog.d/13859.misc new file mode 100644 index 0000000000..2780a4af3c --- /dev/null +++ b/changelog.d/13859.misc @@ -0,0 +1 @@ +Raise issue if complement fails with latest deps. diff --git a/changelog.d/13860.feature b/changelog.d/13860.feature new file mode 100644 index 0000000000..6c8e5cffe2 --- /dev/null +++ b/changelog.d/13860.feature @@ -0,0 +1 @@ +Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881). diff --git a/changelog.d/13863.bugfix b/changelog.d/13863.bugfix new file mode 100644 index 0000000000..74264a4fab --- /dev/null +++ b/changelog.d/13863.bugfix @@ -0,0 +1 @@ +Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls. diff --git a/changelog.d/13870.doc b/changelog.d/13870.doc new file mode 100644 index 0000000000..2598bc270c --- /dev/null +++ b/changelog.d/13870.doc @@ -0,0 +1 @@ +Fix a cross-link from the register admin API to the `registration_shared_secret` configuration documentation. diff --git a/contrib/workers-bash-scripts/create-multiple-generic-workers.md b/contrib/workers-bash-scripts/create-multiple-generic-workers.md index d303101429..c9be707b3c 100644 --- a/contrib/workers-bash-scripts/create-multiple-generic-workers.md +++ b/contrib/workers-bash-scripts/create-multiple-generic-workers.md @@ -7,7 +7,7 @@ You can alternatively create multiple worker configuration files with a simple ` #!/bin/bash for i in {1..5} do -cat << EOF >> generic_worker$i.yaml +cat << EOF > generic_worker$i.yaml worker_app: synapse.app.generic_worker worker_name: generic_worker$i @@ -15,6 +15,8 @@ worker_name: generic_worker$i worker_replication_host: 127.0.0.1 worker_replication_http_port: 9093 +worker_main_http_uri: http://localhost:8008/ + worker_listeners: - type: http port: 808$i diff --git a/docs/admin_api/register_api.md b/docs/admin_api/register_api.md index f6be31b443..dd2830f3a1 100644 --- a/docs/admin_api/register_api.md +++ b/docs/admin_api/register_api.md @@ -5,7 +5,7 @@ non-interactive way. This is generally used for bootstrapping a Synapse instance with administrator accounts. To authenticate yourself to the server, you will need both the shared secret -([`registration_shared_secret`](../configuration/config_documentation.md#registration_shared_secret) +([`registration_shared_secret`](../usage/configuration/config_documentation.md#registration_shared_secret) in the homeserver configuration), and a one-time nonce. If the registration shared secret is not configured, this API is not enabled. diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md index 817499149f..9f5e5fbbe1 100644 --- a/docs/sso_mapping_providers.md +++ b/docs/sso_mapping_providers.md @@ -73,8 +73,8 @@ A custom mapping provider must specify the following methods: * `async def map_user_attributes(self, userinfo, token, failures)` - This method must be async. - Arguments: - - `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user - information from. + - `userinfo` - An [`authlib.oidc.core.claims.UserInfo`](https://docs.authlib.org/en/latest/specs/oidc.html#authlib.oidc.core.UserInfo) + object to extract user information from. - `token` - A dictionary which includes information necessary to make further requests to the OpenID provider. - `failures` - An `int` that represents the amount of times the returned @@ -91,7 +91,13 @@ A custom mapping provider must specify the following methods: `None`, the user is prompted to pick their own username. This is only used during a user's first login. Once a localpart has been associated with a remote user ID (see `get_remote_user_id`) it cannot be updated. - - `displayname`: An optional string, the display name for the user. + - `confirm_localpart`: A boolean. If set to `True`, when a `localpart` + string is returned from this method, Synapse will prompt the user to + either accept this localpart or pick their own username. Otherwise this + option has no effect. If omitted, defaults to `False`. + - `display_name`: An optional string, the display name for the user. + - `emails`: A list of strings, the email address(es) to associate with + this user. If omitted, defaults to an empty list. * `async def get_extra_attributes(self, userinfo, token)` - This method must be async. - Arguments: diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py index d08517a953..2c377533c0 100644 --- a/scripts-dev/mypy_synapse_plugin.py +++ b/scripts-dev/mypy_synapse_plugin.py @@ -29,7 +29,7 @@ class SynapsePlugin(Plugin): self, fullname: str ) -> Optional[Callable[[MethodSigContext], CallableType]]: if fullname.startswith( - "synapse.util.caches.descriptors._CachedFunction.__call__" + "synapse.util.caches.descriptors.CachedFunction.__call__" ) or fullname.startswith( "synapse.util.caches.descriptors._LruCachedFunction.__call__" ): @@ -38,7 +38,7 @@ class SynapsePlugin(Plugin): def cached_function_method_signature(ctx: MethodSigContext) -> CallableType: - """Fixes the `_CachedFunction.__call__` signature to be correct. + """Fixes the `CachedFunction.__call__` signature to be correct. It already has *almost* the correct signature, except: diff --git a/scripts-dev/release.py b/scripts-dev/release.py index 6603bc593b..c82c58c54b 100755 --- a/scripts-dev/release.py +++ b/scripts-dev/release.py @@ -427,11 +427,12 @@ def _publish(gh_token: str) -> None: @cli.command() -def upload() -> None: - _upload() +@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False) +def upload(gh_token: Optional[str]) -> None: + _upload(gh_token) -def _upload() -> None: +def _upload(gh_token: Optional[str]) -> None: """Upload release to pypi.""" current_version = get_package_version() @@ -444,18 +445,40 @@ def _upload() -> None: click.echo("Tag {tag_name} (tag.commit) is not currently checked out!") click.get_current_context().abort() - pypi_asset_names = [ - f"matrix_synapse-{current_version}-py3-none-any.whl", - f"matrix-synapse-{current_version}.tar.gz", - ] + # Query all the assets corresponding to this release. + gh = Github(gh_token) + gh_repo = gh.get_repo("matrix-org/synapse") + gh_release = gh_repo.get_release(tag_name) + + all_assets = set(gh_release.get_assets()) + + # Only accept the wheels and sdist. + # Notably: we don't care about debs.tar.xz. + asset_names_and_urls = sorted( + (asset.name, asset.browser_download_url) + for asset in all_assets + if asset.name.endswith((".whl", ".tar.gz")) + ) + + # Print out what we've determined. + print("Found relevant assets:") + for asset_name, _ in asset_names_and_urls: + print(f" - {asset_name}") + + ignored_asset_names = sorted( + {asset.name for asset in all_assets} + - {asset_name for asset_name, _ in asset_names_and_urls} + ) + print("\nIgnoring irrelevant assets:") + for asset_name in ignored_asset_names: + print(f" - {asset_name}") with TemporaryDirectory(prefix=f"synapse_upload_{tag_name}_") as tmpdir: - for name in pypi_asset_names: + for name, asset_download_url in asset_names_and_urls: filename = path.join(tmpdir, name) - url = f"https://github.com/matrix-org/synapse/releases/download/{tag_name}/{name}" click.echo(f"Downloading {name} into {filename}") - urllib.request.urlretrieve(url, filename=filename) + urllib.request.urlretrieve(asset_download_url, filename=filename) if click.confirm("Upload to PyPI?", default=True): subprocess.run("twine upload *", shell=True, cwd=tmpdir) @@ -672,7 +695,7 @@ def full(gh_token: str) -> None: _publish(gh_token) click.echo("\n*** upload ***") - _upload() + _upload(gh_token) click.echo("\n*** merge back ***") _merge_back() diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 30983c47fb..450ba462ba 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = { "e2e_fallback_keys_json": ["used"], "access_tokens": ["used"], "device_lists_changes_in_room": ["converted_to_destinations"], + "pushers": ["enabled"], } diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 702b81e636..bf27f6c101 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -93,3 +93,13 @@ class ExperimentalConfig(Config): # MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices. self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False) + + # MSC3881: Remotely toggle push notifications for another client + self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False) + + # MSC3882: Allow an existing session to sign in a new session + self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False) + self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True) + self.msc3882_token_timeout = self.parse_duration( + experimental.get("msc3882_token_timeout", "5m") + ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 20ec22105a..cfcadb34db 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -997,7 +997,7 @@ class RegistrationHandler: assert user_tuple token_id = user_tuple.token_id - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=token_id, kind="email", @@ -1005,7 +1005,7 @@ class RegistrationHandler: app_display_name="Email Notifications", device_display_name=threepid["address"], pushkey=threepid["address"], - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 1e171f3f71..6bc1cbd787 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -128,6 +128,9 @@ class SsoIdentityProvider(Protocol): @attr.s(auto_attribs=True) class UserAttributes: + # NB: This struct is documented in docs/sso_mapping_providers.md so that users can + # populate it with data from their own mapping providers. + # the localpart of the mxid that the mapper has assigned to the user. # if `None`, the mapper has not picked a userid, and the user should be prompted to # enter one. diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 9287c0fb8d..59755bff6d 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -125,7 +125,7 @@ from synapse.types import ( ) from synapse.util import Clock from synapse.util.async_helpers import maybe_awaitable -from synapse.util.caches.descriptors import cached +from synapse.util.caches.descriptors import CachedFunction, cached from synapse.util.frozenutils import freeze if TYPE_CHECKING: @@ -836,6 +836,37 @@ class ModuleApi: self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type] ) + def register_cached_function(self, cached_func: CachedFunction) -> None: + """Register a cached function that should be invalidated across workers. + Invalidation local to a worker can be done directly using `cached_func.invalidate`, + however invalidation that needs to go to other workers needs to call `invalidate_cache` + on the module API instead. + + Args: + cached_function: The cached function that will be registered to receive invalidation + locally and from other workers. + """ + self._store.register_external_cached_function( + f"{cached_func.__module__}.{cached_func.__name__}", cached_func + ) + + async def invalidate_cache( + self, cached_func: CachedFunction, keys: Tuple[Any, ...] + ) -> None: + """Invalidate a cache entry of a cached function across workers. The cached function + needs to be registered on all workers first with `register_cached_function`. + + Args: + cached_function: The cached function that needs an invalidation + keys: keys of the entry to invalidate, usually matching the arguments of the + cached function. + """ + cached_func.invalidate(keys) + await self._store.send_invalidation_to_replication( + f"{cached_func.__module__}.{cached_func.__name__}", + keys, + ) + async def complete_sso_login_async( self, registered_user_id: str, diff --git a/synapse/push/__init__.py b/synapse/push/__init__.py index 57c4d70466..a0c760239d 100644 --- a/synapse/push/__init__.py +++ b/synapse/push/__init__.py @@ -116,6 +116,8 @@ class PusherConfig: last_stream_ordering: int last_success: Optional[int] failing_since: Optional[int] + enabled: bool + device_id: Optional[str] def as_dict(self) -> Dict[str, Any]: """Information that can be retrieved about a pusher after creation.""" @@ -128,6 +130,8 @@ class PusherConfig: "lang": self.lang, "profile_tag": self.profile_tag, "pushkey": self.pushkey, + "enabled": self.enabled, + "device_id": self.device_id, } diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index 1e0ef44fc7..e2648cbc93 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -94,7 +94,7 @@ class PusherPool: return run_as_background_process("start_pushers", self._start_pushers) - async def add_pusher( + async def add_or_update_pusher( self, user_id: str, access_token: Optional[int], @@ -106,6 +106,8 @@ class PusherPool: lang: Optional[str], data: JsonDict, profile_tag: str = "", + enabled: bool = True, + device_id: Optional[str] = None, ) -> Optional[Pusher]: """Creates a new pusher and adds it to the pool @@ -147,9 +149,22 @@ class PusherPool: last_stream_ordering=last_stream_ordering, last_success=None, failing_since=None, + enabled=enabled, + device_id=device_id, ) ) + # Before we actually persist the pusher, we check if the user already has one + # this app ID and pushkey. If so, we want to keep the access token and device ID + # in place, since this could be one device modifying (e.g. enabling/disabling) + # another device's pusher. + existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) + if existing_config: + access_token = existing_config.access_token + device_id = existing_config.device_id + await self.store.add_pusher( user_id=user_id, access_token=access_token, @@ -163,8 +178,10 @@ class PusherPool: data=data, last_stream_ordering=last_stream_ordering, profile_tag=profile_tag, + enabled=enabled, + device_id=device_id, ) - pusher = await self.start_pusher_by_id(app_id, pushkey, user_id) + pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id) return pusher @@ -276,10 +293,25 @@ class PusherPool: except Exception: logger.exception("Exception in pusher on_new_receipts") - async def start_pusher_by_id( + async def _get_pusher_config_for_user_by_app_id_and_pushkey( + self, user_id: str, app_id: str, pushkey: str + ) -> Optional[PusherConfig]: + resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + + pusher_config = None + for r in resultlist: + if r.user_name == user_id: + pusher_config = r + + return pusher_config + + async def process_pusher_change_by_id( self, app_id: str, pushkey: str, user_id: str ) -> Optional[Pusher]: - """Look up the details for the given pusher, and start it + """Look up the details for the given pusher, and either start it if its + "enabled" flag is True, or try to stop it otherwise. + + If the pusher is new and its "enabled" flag is False, the stop is a noop. Returns: The pusher started, if any @@ -290,12 +322,13 @@ class PusherPool: if not self._pusher_shard_config.should_handle(self._instance_name, user_id): return None - resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey) + pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey( + user_id, app_id, pushkey + ) - pusher_config = None - for r in resultlist: - if r.user_name == user_id: - pusher_config = r + if pusher_config and not pusher_config.enabled: + self.maybe_stop_pusher(app_id, pushkey, user_id) + return None pusher = None if pusher_config: @@ -305,7 +338,7 @@ class PusherPool: async def _start_pushers(self) -> None: """Start all the pushers""" - pushers = await self.store.get_all_pushers() + pushers = await self.store.get_enabled_pushers() # Stagger starting up the pushers so we don't completely drown the # process on start up. @@ -363,6 +396,8 @@ class PusherPool: synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc() + logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey) + # Check if there *may* be push to process. We do this as this check is a # lot cheaper to do than actually fetching the exact rows we need to # push. @@ -382,16 +417,7 @@ class PusherPool: return pusher async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: - appid_pushkey = "%s:%s" % (app_id, pushkey) - - byuser = self.pushers.get(user_id, {}) - - if appid_pushkey in byuser: - logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) - pusher = byuser.pop(appid_pushkey) - pusher.on_stop() - - synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() + self.maybe_stop_pusher(app_id, pushkey, user_id) # We can only delete pushers on master. if self._remove_pusher_client: @@ -402,3 +428,22 @@ class PusherPool: await self.store.delete_pusher_by_app_id_pushkey_user_id( app_id, pushkey, user_id ) + + def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None: + """Stops a pusher with the given app ID and push key if one is running. + + Args: + app_id: the pusher's app ID. + pushkey: the pusher's push key. + user_id: the user the pusher belongs to. Only used for logging. + """ + appid_pushkey = "%s:%s" % (app_id, pushkey) + + byuser = self.pushers.get(user_id, {}) + + if appid_pushkey in byuser: + logger.info("Stopping pusher %s / %s", user_id, appid_pushkey) + pusher = byuser.pop(appid_pushkey) + pusher.on_stop() + + synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec() diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index e4f2201c92..cf9cd6833b 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -189,7 +189,9 @@ class ReplicationDataHandler: if row.deleted: self.stop_pusher(row.user_id, row.app_id, row.pushkey) else: - await self.start_pusher(row.user_id, row.app_id, row.pushkey) + await self.process_pusher_change( + row.user_id, row.app_id, row.pushkey + ) elif stream_name == EventsStream.NAME: # We shouldn't get multiple rows per token for events stream, so # we don't need to optimise this for multiple rows. @@ -334,13 +336,15 @@ class ReplicationDataHandler: logger.info("Stopping pusher %r / %r", user_id, key) pusher.on_stop() - async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None: + async def process_pusher_change( + self, user_id: str, app_id: str, pushkey: str + ) -> None: if not self._notify_pushers: return key = "%s:%s" % (app_id, pushkey) logger.info("Starting pusher %r / %r", user_id, key) - await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id) + await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id) class FederationSenderHandler: diff --git a/synapse/rest/__init__.py b/synapse/rest/__init__.py index b712215112..9a2ab99ede 100644 --- a/synapse/rest/__init__.py +++ b/synapse/rest/__init__.py @@ -30,6 +30,7 @@ from synapse.rest.client import ( keys, knock, login as v1_login, + login_token_request, logout, mutual_rooms, notifications, @@ -130,3 +131,4 @@ class ClientRestResource(JsonResource): # unstable mutual_rooms.register_servlets(hs, client_resource) + login_token_request.register_servlets(hs, client_resource) diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 2ca6b2d08a..1274773d7e 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet): and self.hs.config.email.email_notif_for_new_users and medium == "email" ): - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user_id, access_token=None, kind="email", @@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet): app_display_name="Email Notifications", device_display_name=address, pushkey=address, - lang=None, # We don't know a user's language here + lang=None, data={}, ) diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py index 2db2a04f95..44f622bcce 100644 --- a/synapse/rest/client/account.py +++ b/synapse/rest/client/account.py @@ -534,6 +534,11 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "/add_threepid/msisdn/submit_token$", releases=(), unstable=True ) + class PostBody(RequestBodyModel): + client_secret: ClientSecretStr + sid: StrictStr + token: StrictStr + def __init__(self, hs: "HomeServer"): super().__init__() self.config = hs.config @@ -549,16 +554,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet): "instead.", ) - body = parse_json_object_from_request(request) - assert_params_in_dict(body, ["client_secret", "sid", "token"]) - assert_valid_client_secret(body["client_secret"]) + body = parse_and_validate_json_object_from_request(request, self.PostBody) # Proxy submit_token request to msisdn threepid delegate response = await self.identity_handler.proxy_msisdn_submit_token( self.config.registration.account_threepid_delegate_msisdn, - body["client_secret"], - body["sid"], - body["token"], + body.client_secret, + body.sid, + body.token, ) return 200, response @@ -581,6 +584,10 @@ class ThreepidRestServlet(RestServlet): return 200, {"threepids": threepids} + # NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because + # the endpoint is deprecated. (If you really want to, you could do this by reusing + # ThreePidBindRestServelet.PostBody with an `alias_generator` to handle + # `threePidCreds` versus `three_pid_creds`. async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: if not self.hs.config.registration.enable_3pid_changes: raise SynapseError( diff --git a/synapse/rest/client/login_token_request.py b/synapse/rest/client/login_token_request.py new file mode 100644 index 0000000000..ca5c54bf17 --- /dev/null +++ b/synapse/rest/client/login_token_request.py @@ -0,0 +1,94 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import TYPE_CHECKING, Tuple + +from synapse.http.server import HttpServer +from synapse.http.servlet import RestServlet, parse_json_object_from_request +from synapse.http.site import SynapseRequest +from synapse.rest.client._base import client_patterns, interactive_auth_handler +from synapse.types import JsonDict + +if TYPE_CHECKING: + from synapse.server import HomeServer + +logger = logging.getLogger(__name__) + + +class LoginTokenRequestServlet(RestServlet): + """ + Get a token that can be used with `m.login.token` to log in a second device. + + Request: + + POST /login/token HTTP/1.1 + Content-Type: application/json + + {} + + Response: + + HTTP/1.1 200 OK + { + "login_token": "ABDEFGH", + "expires_in": 3600, + } + """ + + PATTERNS = client_patterns("/login/token$") + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.clock = hs.get_clock() + self.server_name = hs.config.server.server_name + self.macaroon_gen = hs.get_macaroon_generator() + self.auth_handler = hs.get_auth_handler() + self.token_timeout = hs.config.experimental.msc3882_token_timeout + self.ui_auth = hs.config.experimental.msc3882_ui_auth + + @interactive_auth_handler + async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + body = parse_json_object_from_request(request) + + if self.ui_auth: + await self.auth_handler.validate_user_via_ui_auth( + requester, + request, + body, + "issue a new access token for your account", + can_skip_ui_auth=False, # Don't allow skipping of UI auth + ) + + login_token = self.macaroon_gen.generate_short_term_login_token( + user_id=requester.user.to_string(), + auth_provider_id="org.matrix.msc3882.login_token_request", + duration_in_ms=self.token_timeout, + ) + + return ( + 200, + { + "login_token": login_token, + "expires_in": self.token_timeout // 1000, + }, + ) + + +def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: + if hs.config.experimental.msc3882_enabled: + LoginTokenRequestServlet(hs).register(http_server) diff --git a/synapse/rest/client/pusher.py b/synapse/rest/client/pusher.py index 9a1f10f4be..975eef2144 100644 --- a/synapse/rest/client/pusher.py +++ b/synapse/rest/client/pusher.py @@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet): super().__init__() self.hs = hs self.auth = hs.get_auth() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -51,9 +52,16 @@ class PushersRestServlet(RestServlet): user.to_string() ) - filtered_pushers = [p.as_dict() for p in pushers] + pusher_dicts = [p.as_dict() for p in pushers] - return 200, {"pushers": filtered_pushers} + for pusher in pusher_dicts: + if self._msc3881_enabled: + pusher["org.matrix.msc3881.enabled"] = pusher["enabled"] + pusher["org.matrix.msc3881.device_id"] = pusher["device_id"] + del pusher["enabled"] + del pusher["device_id"] + + return 200, {"pushers": pusher_dicts} class PushersSetRestServlet(RestServlet): @@ -65,6 +73,7 @@ class PushersSetRestServlet(RestServlet): self.auth = hs.get_auth() self.notifier = hs.get_notifier() self.pusher_pool = self.hs.get_pusherpool() + self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -103,6 +112,10 @@ class PushersSetRestServlet(RestServlet): if "append" in content: append = content["append"] + enabled = True + if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content: + enabled = content["org.matrix.msc3881.enabled"] + if not append: await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user( app_id=content["app_id"], @@ -111,7 +124,7 @@ class PushersSetRestServlet(RestServlet): ) try: - await self.pusher_pool.add_pusher( + await self.pusher_pool.add_or_update_pusher( user_id=user.to_string(), access_token=requester.access_token_id, kind=content["kind"], @@ -122,6 +135,8 @@ class PushersSetRestServlet(RestServlet): lang=content["lang"], data=content["data"], profile_tag=content.get("profile_tag", ""), + enabled=enabled, + device_id=requester.device_id, ) except PusherConfigException as pce: raise SynapseError( diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py index c516cda95d..b3917a5abc 100644 --- a/synapse/rest/client/versions.py +++ b/synapse/rest/client/versions.py @@ -105,6 +105,10 @@ class VersionsRestServlet(RestServlet): "org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above # Allows moderators to fetch redacted event content as described in MSC2815 "fi.mau.msc2815": self.config.experimental.msc2815_enabled, + # Adds support for login token requests as per MSC3882 + "org.matrix.msc3882": self.config.experimental.msc3882_enabled, + # Adds support for remotely enabling/disabling pushers, as per MSC3881 + "org.matrix.msc3881": self.config.experimental.msc3881_enabled, }, }, ) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index e30f9c76d4..303a5d5298 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -15,12 +15,13 @@ # limitations under the License. import logging from abc import ABCMeta -from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union +from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401 from synapse.storage.database import DatabasePool, LoggingDatabaseConnection from synapse.types import get_domain_from_id from synapse.util import json_decoder +from synapse.util.caches.descriptors import CachedFunction if TYPE_CHECKING: from synapse.server import HomeServer @@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta): self.database_engine = database.engine self.db_pool = database + self.external_cached_functions: Dict[str, CachedFunction] = {} + def process_replication_rows( self, stream_name: str, @@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta): def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] - ) -> None: + ) -> bool: """Attempts to invalidate the cache of the given name, ignoring if the cache doesn't exist. Mainly used for invalidating caches on workers, where they may not have the cache. @@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta): try: cache = getattr(self, cache_name) except AttributeError: - # We probably haven't pulled in the cache in this worker, - # which is fine. - return + # Check if an externally defined module cache has been registered + cache = self.external_cached_functions.get(cache_name) + if not cache: + # We probably haven't pulled in the cache in this worker, + # which is fine. + return False if key is None: cache.invalidate_all() @@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta): invalidate_method = getattr(cache, "invalidate_local", cache.invalidate) invalidate_method(tuple(key)) + return True + + def register_external_cached_function( + self, cache_name: str, func: CachedFunction + ) -> None: + self.external_cached_functions[cache_name] = func + def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any: """ diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index 99b0b5d09f..2dbf93a4e5 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -33,7 +33,7 @@ from synapse.storage.database import ( ) from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.util.caches.descriptors import _CachedFunction +from synapse.util.caches.descriptors import CachedFunction from synapse.util.iterutils import batch_iter if TYPE_CHECKING: @@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): return cache_func.invalidate(keys) - await self.db_pool.runInteraction( - "invalidate_cache_and_stream", - self._send_invalidation_to_replication, + await self.send_invalidation_to_replication( cache_func.__name__, keys, ) @@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): def _invalidate_cache_and_stream( self, txn: LoggingTransaction, - cache_func: _CachedFunction, + cache_func: CachedFunction, keys: Tuple[Any, ...], ) -> None: """Invalidates the cache and adds it to the cache stream so slaves @@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore): self._send_invalidation_to_replication(txn, cache_func.__name__, keys) def _invalidate_all_cache_and_stream( - self, txn: LoggingTransaction, cache_func: _CachedFunction + self, txn: LoggingTransaction, cache_func: CachedFunction ) -> None: """Invalidates the entire cache and adds it to the cache stream so slaves will know to invalidate their caches. @@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore): txn, CURRENT_STATE_CACHE_NAME, [room_id] ) + async def send_invalidation_to_replication( + self, cache_name: str, keys: Optional[Collection[Any]] + ) -> None: + await self.db_pool.runInteraction( + "send_invalidation_to_replication", + self._send_invalidation_to_replication, + cache_name, + keys, + ) + def _send_invalidation_to_replication( self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]] ) -> None: diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index 9f88ea8c7f..e08d2fe74a 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -412,6 +412,31 @@ class PersistEventsStore: assert min_stream_order assert max_stream_order + # Once the txn completes, invalidate all of the relevant caches. Note that we do this + # up here because it captures all the events_and_contexts before any are removed. + for event, _ in events_and_contexts: + self.store.invalidate_get_event_cache_after_txn(txn, event.event_id) + if event.redacts: + self.store.invalidate_get_event_cache_after_txn(txn, event.redacts) + + relates_to = None + relation = relation_from_event(event) + if relation: + relates_to = relation.parent_id + + assert event.internal_metadata.stream_ordering is not None + txn.call_after( + self.store._invalidate_caches_for_event, + event.internal_metadata.stream_ordering, + event.event_id, + event.room_id, + event.type, + getattr(event, "state_key", None), + event.redacts, + relates_to, + backfilled=False, + ) + self._update_forward_extremities_txn( txn, new_forward_extremities=new_forward_extremities, diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 9f6b1fcef1..74991a4992 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore): # the batches as big as possible. results: Set[str] = set() - for chunk in batch_iter(event_ids, 500): - r = await self._have_seen_events_dict( - [(room_id, event_id) for event_id in chunk] + for event_ids_chunk in batch_iter(event_ids, 500): + events_seen_dict = await self._have_seen_events_dict( + room_id, event_ids_chunk + ) + results.update( + eid for (eid, have_event) in events_seen_dict.items() if have_event ) - results.update(eid for ((_rid, eid), have_event) in r.items() if have_event) return results - @cachedList(cached_method_name="have_seen_event", list_name="keys") + @cachedList(cached_method_name="have_seen_event", list_name="event_ids") async def _have_seen_events_dict( - self, keys: Collection[Tuple[str, str]] - ) -> Dict[Tuple[str, str], bool]: + self, + room_id: str, + event_ids: Collection[str], + ) -> Dict[str, bool]: """Helper for have_seen_events Returns: - a dict {(room_id, event_id)-> bool} + a dict {event_id -> bool} """ # if the event cache contains the event, obviously we've seen it. cache_results = { - (rid, eid) - for (rid, eid) in keys - if await self._get_event_cache.contains((eid,)) + event_id + for event_id in event_ids + if await self._get_event_cache.contains((event_id,)) } results = dict.fromkeys(cache_results, True) - remaining = [k for k in keys if k not in cache_results] + remaining = [ + event_id for event_id in event_ids if event_id not in cache_results + ] if not remaining: return results @@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore): sql = "SELECT event_id FROM events AS e WHERE " clause, args = make_in_list_sql_clause( - txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining] + txn.database_engine, "e.event_id", remaining ) txn.execute(sql + clause, args) found_events = {eid for eid, in txn} # ... and then we can update the results for each key - results.update( - {(rid, eid): (eid in found_events) for (rid, eid) in remaining} - ) + results.update({eid: (eid in found_events) for eid in remaining}) await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn) return results @cached(max_entries=100000, tree=True) async def have_seen_event(self, room_id: str, event_id: str) -> bool: - res = await self._have_seen_events_dict(((room_id, event_id),)) - return res[(room_id, event_id)] + res = await self._have_seen_events_dict(room_id, [event_id]) + return res[event_id] def _get_current_state_event_counts_txn( self, txn: LoggingTransaction, room_id: str diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py index bd0cfa7f32..01206950a9 100644 --- a/synapse/storage/databases/main/pusher.py +++ b/synapse/storage/databases/main/pusher.py @@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore): ) continue + # If we're using SQLite, then boolean values are integers. This is + # troublesome since some code using the return value of this method might + # expect it to be a boolean, or will expose it to clients (in responses). + r["enabled"] = bool(r["enabled"]) + yield PusherConfig(**r) async def get_pushers_by_app_id_and_pushkey( @@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore): return await self.get_pushers_by({"user_name": user_id}) async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]: - ret = await self.db_pool.simple_select_list( - "pushers", - keyvalues, - [ - "id", - "user_name", - "access_token", - "profile_tag", - "kind", - "app_id", - "app_display_name", - "device_display_name", - "pushkey", - "ts", - "lang", - "data", - "last_stream_ordering", - "last_success", - "failing_since", - ], + """Retrieve pushers that match the given criteria. + + Args: + keyvalues: A {column: value} dictionary. + + Returns: + The pushers for which the given columns have the given values. + """ + + def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]: + # We could technically use simple_select_list here, but we need to call + # COALESCE on the 'enabled' column. While it is technically possible to give + # simple_select_list the whole `COALESCE(...) AS ...` as a column name, it + # feels a bit hacky, so it's probably better to just inline the query. + sql = """ + SELECT + id, user_name, access_token, profile_tag, kind, app_id, + app_display_name, device_display_name, pushkey, ts, lang, data, + last_stream_ordering, last_success, failing_since, + COALESCE(enabled, TRUE) AS enabled, device_id + FROM pushers + """ + + sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),) + + txn.execute(sql, list(keyvalues.values())) + + return self.db_pool.cursor_to_dict(txn) + + ret = await self.db_pool.runInteraction( desc="get_pushers_by", + func=get_pushers_by_txn, ) + return self._decode_pushers_rows(ret) - async def get_all_pushers(self) -> Iterator[PusherConfig]: - def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]: - txn.execute("SELECT * FROM pushers") + async def get_enabled_pushers(self) -> Iterator[PusherConfig]: + def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]: + txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)") rows = self.db_pool.cursor_to_dict(txn) return self._decode_pushers_rows(rows) - return await self.db_pool.runInteraction("get_all_pushers", get_pushers) + return await self.db_pool.runInteraction( + "get_enabled_pushers", get_enabled_pushers_txn + ) async def get_all_updated_pushers_rows( self, instance_name: str, last_id: int, current_id: int, limit: int @@ -458,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore): return number_deleted -class PusherStore(PusherWorkerStore): +class PusherBackgroundUpdatesStore(SQLBaseStore): + def __init__( + self, + database: DatabasePool, + db_conn: LoggingDatabaseConnection, + hs: "HomeServer", + ): + super().__init__(database, db_conn, hs) + + self.db_pool.updates.register_background_update_handler( + "set_device_id_for_pushers", self._set_device_id_for_pushers + ) + + async def _set_device_id_for_pushers( + self, progress: JsonDict, batch_size: int + ) -> int: + """Background update to populate the device_id column of the pushers table.""" + last_pusher_id = progress.get("pusher_id", 0) + + def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int: + txn.execute( + """ + SELECT p.id, at.device_id + FROM pushers AS p + INNER JOIN access_tokens AS at + ON p.access_token = at.id + WHERE + p.access_token IS NOT NULL + AND at.device_id IS NOT NULL + AND p.id > ? + ORDER BY p.id + LIMIT ? + """, + (last_pusher_id, batch_size), + ) + + rows = self.db_pool.cursor_to_dict(txn) + if len(rows) == 0: + return 0 + + self.db_pool.simple_update_many_txn( + txn=txn, + table="pushers", + key_names=("id",), + key_values=[(row["id"],) for row in rows], + value_names=("device_id",), + value_values=[(row["device_id"],) for row in rows], + ) + + self.db_pool.updates._background_update_progress_txn( + txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]} + ) + + return len(rows) + + nb_processed = await self.db_pool.runInteraction( + "set_device_id_for_pushers", set_device_id_for_pushers_txn + ) + + if nb_processed < batch_size: + await self.db_pool.updates._end_background_update( + "set_device_id_for_pushers" + ) + + return nb_processed + + +class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore): def get_pushers_stream_token(self) -> int: return self._pushers_id_gen.get_current_token() @@ -476,6 +562,8 @@ class PusherStore(PusherWorkerStore): data: Optional[JsonDict], last_stream_ordering: int, profile_tag: str = "", + enabled: bool = True, + device_id: Optional[str] = None, ) -> None: async with self._pushers_id_gen.get_next() as stream_id: # no need to lock because `pushers` has a unique key on @@ -494,6 +582,8 @@ class PusherStore(PusherWorkerStore): "last_stream_ordering": last_stream_ordering, "profile_tag": profile_tag, "id": stream_id, + "enabled": enabled, + "device_id": device_id, }, desc="add_pusher", lock=False, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 255854cd66..64cba763c4 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -52,6 +52,8 @@ class _RelatedEvent: event_id: str # The sender of the related event. sender: str + topological_ordering: Optional[int] + stream_ordering: int class RelationsWorkerStore(SQLBaseStore): @@ -92,6 +94,9 @@ class RelationsWorkerStore(SQLBaseStore): # it. The `event_id` must match the `event.event_id`. assert event.event_id == event_id + # Ensure bad limits aren't being passed in. + assert limit >= 0 + where_clause = ["relates_to_id = ?", "room_id = ?"] where_args: List[Union[str, int]] = [event.event_id, room_id] is_redacted = event.internal_metadata.is_redacted() @@ -140,21 +145,34 @@ class RelationsWorkerStore(SQLBaseStore): ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]: txn.execute(sql, where_args + [limit + 1]) - last_topo_id = None - last_stream_id = None events = [] - for row in txn: + for event_id, relation_type, sender, topo_ordering, stream_ordering in txn: # Do not include edits for redacted events as they leak event # content. - if not is_redacted or row[1] != RelationTypes.REPLACE: - events.append(_RelatedEvent(row[0], row[2])) - last_topo_id = row[3] - last_stream_id = row[4] + if not is_redacted or relation_type != RelationTypes.REPLACE: + events.append( + _RelatedEvent(event_id, sender, topo_ordering, stream_ordering) + ) - # If there are more events, generate the next pagination key. + # If there are more events, generate the next pagination key from the + # last event returned. next_token = None - if len(events) > limit and last_topo_id and last_stream_id: - next_key = RoomStreamToken(last_topo_id, last_stream_id) + if len(events) > limit: + # Instead of using the last row (which tells us there is more + # data), use the last row to be returned. + events = events[:limit] + + topo = events[-1].topological_ordering + token = events[-1].stream_ordering + if direction == "b": + # Tokens are positions between events. + # This token points *after* the last event in the chunk. + # We need it to point to the event before it in the chunk + # when we are going backwards so we subtract one from the + # stream part. + token -= 1 + next_key = RoomStreamToken(topo, token) + if from_token: next_token = from_token.copy_and_replace( StreamKeyType.ROOM, next_key diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index f0b179eea5..323c7bf7a5 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): if rows: topo = rows[-1].topological_ordering - toke = rows[-1].stream_ordering + token = rows[-1].stream_ordering if direction == "b": # Tokens are positions between events. # This token points *after* the last event in the chunk. # We need it to point to the event before it in the chunk # when we are going backwards so we subtract one from the # stream part. - toke -= 1 - next_token = RoomStreamToken(topo, toke) + token -= 1 + next_token = RoomStreamToken(topo, token) else: # TODO (erikj): We should work out what to do here instead. next_token = to_token if to_token else from_token diff --git a/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql new file mode 100644 index 0000000000..dba3b4900b --- /dev/null +++ b/synapse/storage/schema/main/delta/73/02add_pusher_enabled.sql @@ -0,0 +1,16 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +ALTER TABLE pushers ADD COLUMN enabled BOOLEAN; \ No newline at end of file diff --git a/synapse/storage/schema/main/delta/73/03pusher_device_id.sql b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql new file mode 100644 index 0000000000..1b4ffbeebe --- /dev/null +++ b/synapse/storage/schema/main/delta/73/03pusher_device_id.sql @@ -0,0 +1,20 @@ +/* Copyright 2022 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +-- Add a device_id column to track the device ID that created the pusher. It's NULLable +-- on purpose, because a) it might not be possible to track down the device that created +-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and +-- b) access tokens retrieved via the admin API don't have a device associated to them. +ALTER TABLE pushers ADD COLUMN device_id TEXT; \ No newline at end of file diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 10aff4d04a..0391966462 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any] F = TypeVar("F", bound=Callable[..., Any]) -class _CachedFunction(Generic[F]): +class CachedFunction(Generic[F]): invalidate: Any = None invalidate_all: Any = None prefill: Any = None @@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase): return ret2 - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) wrapped.cache = cache obj.__dict__[self.name] = wrapped @@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase): return make_deferred_yieldable(ret) - wrapped = cast(_CachedFunction, _wrapped) + wrapped = cast(CachedFunction, _wrapped) if self.num_args == 1: assert not self.tree @@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase): cache: DeferredCache[CacheKey, Any] = cached_method.cache num_args = cached_method.num_args + if num_args != self.num_args: + raise Exception( + "Number of args (%s) does not match underlying cache_method_name=%s (%s)." + % (self.num_args, self.cached_method_name, num_args) + ) + @functools.wraps(self.orig) def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]": # If we're passed a cache_context then we'll want to call its @@ -572,7 +578,7 @@ def cached( iterable: bool = False, prune_unread_entries: bool = True, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: func = lambda orig: DeferredCacheDescriptor( orig, max_entries=max_entries, @@ -585,7 +591,7 @@ def cached( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def cachedList( @@ -594,7 +600,7 @@ def cachedList( list_name: str, num_args: Optional[int] = None, name: Optional[str] = None, -) -> Callable[[F], _CachedFunction[F]]: +) -> Callable[[F], CachedFunction[F]]: """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`. Used to do batch lookups for an already created cache. One of the arguments @@ -631,7 +637,7 @@ def cachedList( name=name, ) - return cast(Callable[[F], _CachedFunction[F]], func) + return cast(Callable[[F], CachedFunction[F]], func) def _get_cache_key_builder( diff --git a/tests/push/test_email.py b/tests/push/test_email.py index 7a3b0d6755..fd14568f55 100644 --- a/tests/push/test_email.py +++ b/tests/push/test_email.py @@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase): ) self.pusher = self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", @@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase): """ with self.assertRaises(SynapseError) as cm: self.get_success_or_raise( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.user_id, access_token=self.token_id, kind="email", diff --git a/tests/push/test_http.py b/tests/push/test_http.py index d9c68cdd2d..b383b8401f 100644 --- a/tests/push/test_http.py +++ b/tests/push/test_http.py @@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor import synapse.rest.admin from synapse.logging.context import make_deferred_yieldable -from synapse.push import PusherConfigException -from synapse.rest.client import login, push_rule, receipts, room +from synapse.push import PusherConfig, PusherConfigException +from synapse.rest.client import login, push_rule, pusher, receipts, room from synapse.server import HomeServer +from synapse.storage.databases.main.registration import TokenLookupResult from synapse.types import JsonDict from synapse.util import Clock @@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase): login.register_servlets, receipts.register_servlets, push_rule.register_servlets, + pusher.register_servlets, ] user_id = True hijack_auth = False @@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase): def test_data(data: Optional[JsonDict]) -> None: self.get_failure( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase): ) self.assertEqual(channel.code, 200, channel.json_body) - def _make_user_with_pusher(self, username: str) -> Tuple[str, str]: + def _make_user_with_pusher( + self, username: str, enabled: bool = True + ) -> Tuple[str, str]: + """Registers a user and creates a pusher for them. + + Args: + username: the localpart of the new user's Matrix ID. + enabled: whether to create the pusher in an enabled or disabled state. + """ user_id = self.register_user(username, "pass") access_token = self.login(username, "pass") # Register the pusher + self._set_pusher(user_id, access_token, enabled) + + return user_id, access_token + + def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None: + """Creates or updates the pusher for the given user. + + Args: + user_id: the user's Matrix ID. + access_token: the access token associated with the pusher. + enabled: whether to enable or disable the pusher. + """ user_tuple = self.get_success( self.hs.get_datastores().main.get_user_by_access_token(access_token) ) token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", @@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase): pushkey="a@example.com", lang=None, data={"url": "http://example.com/_matrix/push/v1/notify"}, + enabled=enabled, + device_id=user_tuple.device_id, ) ) - return user_id, access_token - def test_dont_notify_rule_overrides_message(self) -> None: """ The override push rule will suppress notification @@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase): # The user sends a message back (sends a notification) self.helper.send(room, body="Hello", tok=access_token) self.assertEqual(len(self.push_attempts), 1) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_disable(self) -> None: + """Tests that disabling a pusher means it's not pushed to anymore.""" + user_id, access_token = self._make_user_with_pusher("user") + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it generated a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Disable the pusher. + self._set_pusher(user_id, access_token, enabled=False) + + # Send another message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as disabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertFalse(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_enable(self) -> None: + """Tests that enabling a disabled pusher means it gets pushed to.""" + # Create the user with the pusher already disabled. + user_id, access_token = self._make_user_with_pusher("user", enabled=False) + other_user_id, other_access_token = self._make_user_with_pusher("otheruser") + + room = self.helper.create_room_as(user_id, tok=access_token) + self.helper.join(room=room, user=other_user_id, tok=other_access_token) + + # Send a message and check that it did not generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 0) + + # Enable the pusher. + self._set_pusher(user_id, access_token, enabled=True) + + # Send another message and check that it did generate a push. + self.helper.send(room, body="Hi!", tok=other_access_token) + self.assertEqual(len(self.push_attempts), 1) + + # Get the pushers for the user and check that it is marked as enabled. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + + enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"] + self.assertTrue(enabled) + self.assertTrue(isinstance(enabled, bool)) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_null_enabled(self) -> None: + """Tests that a pusher that has an 'enabled' column set to NULL (eg pushers + created before the column was introduced) is considered enabled. + """ + # We intentionally set 'enabled' to None so that it's stored as NULL in the + # database. + user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type] + + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]) + + def test_update_different_device_access_token_device_id(self) -> None: + """Tests that if we create a pusher from one device, the update it from another + device, the access token and device ID associated with the pusher stays the + same. + """ + # Create a user with a pusher. + user_id, access_token = self._make_user_with_pusher("user") + + # Get the token ID for the current access token, since that's what we store in + # the pushers table. Also get the device ID from it. + user_tuple = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + token_id = user_tuple.token_id + device_id = user_tuple.device_id + + # Generate a new access token, and update the pusher with it. + new_token = self.login("user", "pass") + self._set_pusher(user_id, new_token, enabled=False) + + # Get the current list of pushers for the user. + ret = self.get_success( + self.hs.get_datastores().main.get_pushers_by({"user_name": user_id}) + ) + pushers: List[PusherConfig] = list(ret) + + # Check that we still have one pusher, and that the access token and device ID + # associated with it didn't change. + self.assertEqual(len(pushers), 1) + self.assertEqual(pushers[0].access_token, token_id) + self.assertEqual(pushers[0].device_id, device_id) + + @override_config({"experimental_features": {"msc3881_enabled": True}}) + def test_device_id(self) -> None: + """Tests that a pusher created with a given device ID shows that device ID in + GET /pushers requests. + """ + self.register_user("user", "pass") + access_token = self.login("user", "pass") + + # We create the pusher with an HTTP request rather than with + # _make_user_with_pusher so that we can test the device ID is correctly set when + # creating a pusher via an API call. + self.make_request( + method="POST", + path="/pushers/set", + content={ + "kind": "http", + "app_id": "m.http", + "app_display_name": "HTTP Push Notifications", + "device_display_name": "pushy push", + "pushkey": "a@example.com", + "lang": "en", + "data": {"url": "http://example.com/_matrix/push/v1/notify"}, + }, + access_token=access_token, + ) + + # Look up the user info for the access token so we can compare the device ID. + lookup_result: TokenLookupResult = self.get_success( + self.hs.get_datastores().main.get_user_by_access_token(access_token) + ) + + # Get the user's devices and check it has the correct device ID. + channel = self.make_request("GET", "/pushers", access_token=access_token) + self.assertEqual(channel.code, 200) + self.assertEqual(len(channel.json_body["pushers"]), 1) + self.assertEqual( + channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"], + lookup_result.device_id, + ) diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py new file mode 100644 index 0000000000..b93cae67d3 --- /dev/null +++ b/tests/replication/test_module_cache_invalidation.py @@ -0,0 +1,79 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import synapse +from synapse.module_api import cached + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + +FIRST_VALUE = "one" +SECOND_VALUE = "two" + +KEY = "mykey" + + +class TestCache: + current_value = FIRST_VALUE + + @cached() + async def cached_function(self, user_id: str) -> str: + return self.current_value + + +class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + ] + + def test_module_cache_full_invalidation(self): + main_cache = TestCache() + self.hs.get_module_api().register_cached_function(main_cache.cached_function) + + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + worker_cache = TestCache() + worker_hs.get_module_api().register_cached_function( + worker_cache.cached_function + ) + + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + main_cache.current_value = SECOND_VALUE + worker_cache.current_value = SECOND_VALUE + # No invalidation yet, should return the cached value on both the main process and the worker + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + # Full invalidation on the main process, should be replicated on the worker that + # should returned the updated value too + self.get_success( + self.hs.get_module_api().invalidate_cache( + main_cache.cached_function, (KEY,) + ) + ) + + self.assertEqual( + SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) + ) + self.assertEqual( + SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py index 9f536ceeb3..1847e6ad6b 100644 --- a/tests/rest/admin/test_user.py +++ b/tests/rest/admin/test_user.py @@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase): token_id = user_tuple.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=self.other_user, access_token=token_id, kind="http", diff --git a/tests/rest/client/test_login_token_request.py b/tests/rest/client/test_login_token_request.py new file mode 100644 index 0000000000..d5bb16c98d --- /dev/null +++ b/tests/rest/client/test_login_token_request.py @@ -0,0 +1,132 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin +from synapse.rest.client import login, login_token_request +from synapse.server import HomeServer +from synapse.util import Clock + +from tests import unittest +from tests.unittest import override_config + + +class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase): + + servlets = [ + login.register_servlets, + admin.register_servlets, + login_token_request.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + self.hs = self.setup_test_homeserver() + self.hs.config.registration.enable_registration = True + self.hs.config.registration.registrations_require_3pid = [] + self.hs.config.registration.auto_join_rooms = [] + self.hs.config.captcha.enable_registration_captcha = False + + return self.hs + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.user = "user123" + self.password = "password" + + def test_disabled(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 400) + + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 400) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_require_auth(self) -> None: + channel = self.make_request("POST", "/login/token", {}, access_token=None) + self.assertEqual(channel.code, 401) + + @override_config({"experimental_features": {"msc3882_enabled": True}}) + def test_uia_on(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 401) + self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"]) + + session = channel.json_body["session"] + + uia = { + "auth": { + "type": "m.login.password", + "identifier": {"type": "m.id.user", "user": self.user}, + "password": self.password, + "session": session, + }, + } + + channel = self.make_request("POST", "/login/token", uia, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + {"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}} + ) + def test_uia_off(self) -> None: + user_id = self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 300) + + login_token = channel.json_body["login_token"] + + channel = self.make_request( + "POST", + "/login", + content={"type": "m.login.token", "token": login_token}, + ) + self.assertEqual(channel.code, 200, channel.result) + self.assertEqual(channel.json_body["user_id"], user_id) + + @override_config( + { + "experimental_features": { + "msc3882_enabled": True, + "msc3882_ui_auth": False, + "msc3882_token_timeout": "15s", + } + } + ) + def test_expires_in(self) -> None: + self.register_user(self.user, self.password) + token = self.login(self.user, self.password) + + channel = self.make_request("POST", "/login/token", {}, access_token=token) + self.assertEqual(channel.code, 200) + self.assertEqual(channel.json_body["expires_in"], 15) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index 651f4f415d..d33e34d829 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel.json_body["chunk"][0], ) + @unittest.override_config({"experimental_features": {"msc3715_enabled": True}}) def test_repeated_paginate_relations(self) -> None: """Test that if we paginate using a limit and tokens then we get the expected events. @@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", - f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}", access_token=self.user_token, ) self.assertEqual(200, channel.code, channel.json_body) @@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase): found_event_ids.reverse() self.assertEqual(found_event_ids, expected_event_ids) + # Test forward pagination. + prev_token = "" + found_event_ids = [] + for _ in range(20): + from_token = "" + if prev_token: + from_token = "&from=" + prev_token + + channel = self.make_request( + "GET", + f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}", + access_token=self.user_token, + ) + self.assertEqual(200, channel.code, channel.json_body) + + found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"]) + next_batch = channel.json_body.get("next_batch") + + self.assertNotEqual(prev_token, next_batch) + prev_token = next_batch + + if not prev_token: + break + + self.assertEqual(found_event_ids, expected_event_ids) + def test_pagination_from_sync_and_messages(self) -> None: """Pagination tokens from /sync and /messages can be used to paginate /relations.""" channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A") diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py index 158ad1f439..32a798d74b 100644 --- a/tests/storage/databases/main/test_events_worker.py +++ b/tests/storage/databases/main/test_events_worker.py @@ -103,6 +103,11 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): self.assertEqual(ctx.get_resource_usage().db_txn_count, 0) def test_persisting_event_invalidates_cache(self): + """ + Test to make sure that the `have_seen_event` cache + is invalidated after we persist an event and returns + the updated value. + """ event, event_context = self.get_success( create_event( self.hs, @@ -145,6 +150,33 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase): # That should result in a single db query to lookup self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + def test_invalidate_cache_by_room_id(self): + """ + Test to make sure that all events associated with the given `(room_id,)` + are invalidated in the `have_seen_event` cache. + """ + with LoggingContext(name="test") as ctx: + # Prime the cache with some values + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # That should result in a single db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + + # Clear the cache with any events associated with the `room_id` + self.store.have_seen_event.invalidate((self.room_id,)) + + with LoggingContext(name="test") as ctx: + res = self.get_success( + self.store.have_seen_events(self.room_id, self.event_ids) + ) + self.assertEqual(res, set(self.event_ids)) + + # Since we cleared the cache, it should result in another db query to lookup + self.assertEqual(ctx.get_resource_usage().db_txn_count, 1) + class EventCacheTestCase(unittest.HomeserverTestCase): """Test that the various layers of event cache works.""" diff --git a/tests/unittest.py b/tests/unittest.py index 975b0a23a7..00cb023198 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase): if hasattr(self, "user_id"): if self.hijack_auth: assert self.helper.auth_user_id is not None + token = "some_fake_token" # We need a valid token ID to satisfy foreign key constraints. token_id = self.get_success( self.hs.get_datastores().main.add_access_token_to_user( self.helper.auth_user_id, - "some_fake_token", + token, None, None, ) ) - async def get_user_by_access_token( - token: Optional[str] = None, allow_guest: bool = False - ) -> JsonDict: - assert self.helper.auth_user_id is not None - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": token_id, - "is_guest": False, - } - - async def get_user_by_req( - request: SynapseRequest, - allow_guest: bool = False, - allow_expired: bool = False, - ) -> Requester: + # This has to be a function and not just a Mock, because + # `self.helper.auth_user_id` is temporarily reassigned in some tests + async def get_requester(*args, **kwargs) -> Requester: assert self.helper.auth_user_id is not None return create_requester( - UserID.from_string(self.helper.auth_user_id), - token_id, - False, - False, - None, + user_id=UserID.from_string(self.helper.auth_user_id), + access_token_id=token_id, ) # Type ignore: mypy doesn't like us assigning to methods. - self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment] - self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment] - self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment] - return_value="1234" - ) + self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment] + self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment] + self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment] if self.needs_threadpool: self.reactor.threadpool = ThreadPool() # type: ignore[assignment] diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 48e616ac74..90861fe522 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Set +from typing import Iterable, Set, Tuple from unittest import mock from twisted.internet import defer, reactor @@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase): obj.inner_context_was_finished, "Tried to restart a finished logcontext" ) self.assertEqual(current_context(), SENTINEL_CONTEXT) + + def test_num_args_mismatch(self): + """ + Make sure someone does not accidentally use @cachedList on a method with + a mismatch in the number args to the underlying single cache method. + """ + + class Cls: + @descriptors.cached(tree=True) + def fn(self, room_id, event_id): + pass + + # This is wrong ❌. `@cachedList` expects to be given the same number + # of arguments as the underlying cached function, just with one of + # the arguments being an iterable + @descriptors.cachedList(cached_method_name="fn", list_name="keys") + def list_fn(self, keys: Iterable[Tuple[str, str]]): + pass + + # Corrected syntax ✅ + # + # @cachedList(cached_method_name="fn", list_name="event_ids") + # async def list_fn( + # self, room_id: str, event_ids: Collection[str], + # ) + + obj = Cls() + + # Make sure this raises an error about the arg mismatch + with self.assertRaises(Exception): + obj.list_fn([("foo", "bar")])