Merge branch 'madlittlemods/13856-fix-have-seen-events-not-being-invalidated' into maddlittlemods/msc2716-many-batches-optimization
Conflicts: tests/storage/databases/main/test_events_worker.py
This commit is contained in:
3
.github/workflows/latest_deps.yml
vendored
3
.github/workflows/latest_deps.yml
vendored
@@ -201,10 +201,11 @@ jobs:
|
||||
open-issue:
|
||||
if: "failure() && github.event_name != 'push' && github.event_name != 'pull_request'"
|
||||
needs:
|
||||
# TODO: should mypy be included here? It feels more brittle than the other two.
|
||||
# TODO: should mypy be included here? It feels more brittle than the others.
|
||||
- mypy
|
||||
- trial
|
||||
- sytest
|
||||
- complement
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
|
||||
1
changelog.d/13667.feature
Normal file
1
changelog.d/13667.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add cache invalidation across workers to module API.
|
||||
1
changelog.d/13722.feature
Normal file
1
changelog.d/13722.feature
Normal file
@@ -0,0 +1 @@
|
||||
Experimental implementation of MSC3882 to allow an existing device/session to generate a login token for use on a new device/session.
|
||||
1
changelog.d/13772.doc
Normal file
1
changelog.d/13772.doc
Normal file
@@ -0,0 +1 @@
|
||||
Add `worker_main_http_uri` for the worker generator bash script.
|
||||
1
changelog.d/13799.feature
Normal file
1
changelog.d/13799.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
|
||||
1
changelog.d/13809.misc
Normal file
1
changelog.d/13809.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve the `synapse.api.auth.Auth` mock used in unit tests.
|
||||
1
changelog.d/13831.feature
Normal file
1
changelog.d/13831.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
|
||||
1
changelog.d/13832.feature
Normal file
1
changelog.d/13832.feature
Normal file
@@ -0,0 +1 @@
|
||||
Improve validation for the unspecced, internal-only `_matrix/client/unstable/add_threepid/msisdn/submit_token` endpoint.
|
||||
1
changelog.d/13836.doc
Normal file
1
changelog.d/13836.doc
Normal file
@@ -0,0 +1 @@
|
||||
Fix a mistake in sso_mapping_providers.md: `map_user_attributes` is expected to return `display_name` not `displayname`.
|
||||
1
changelog.d/13840.bugfix
Normal file
1
changelog.d/13840.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward.
|
||||
1
changelog.d/13850.misc
Normal file
1
changelog.d/13850.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix the release script not publishing binary wheels.
|
||||
1
changelog.d/13859.misc
Normal file
1
changelog.d/13859.misc
Normal file
@@ -0,0 +1 @@
|
||||
Raise issue if complement fails with latest deps.
|
||||
1
changelog.d/13860.feature
Normal file
1
changelog.d/13860.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add experimental support for [MSC3881: Remotely toggle push notifications for another client](https://github.com/matrix-org/matrix-spec-proposals/pull/3881).
|
||||
1
changelog.d/13863.bugfix
Normal file
1
changelog.d/13863.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix `have_seen_event` cache not being invalidated after we persist an event which causes inefficiency effects like extra `/state` federation calls.
|
||||
1
changelog.d/13870.doc
Normal file
1
changelog.d/13870.doc
Normal file
@@ -0,0 +1 @@
|
||||
Fix a cross-link from the register admin API to the `registration_shared_secret` configuration documentation.
|
||||
@@ -7,7 +7,7 @@ You can alternatively create multiple worker configuration files with a simple `
|
||||
#!/bin/bash
|
||||
for i in {1..5}
|
||||
do
|
||||
cat << EOF >> generic_worker$i.yaml
|
||||
cat << EOF > generic_worker$i.yaml
|
||||
worker_app: synapse.app.generic_worker
|
||||
worker_name: generic_worker$i
|
||||
|
||||
@@ -15,6 +15,8 @@ worker_name: generic_worker$i
|
||||
worker_replication_host: 127.0.0.1
|
||||
worker_replication_http_port: 9093
|
||||
|
||||
worker_main_http_uri: http://localhost:8008/
|
||||
|
||||
worker_listeners:
|
||||
- type: http
|
||||
port: 808$i
|
||||
|
||||
@@ -5,7 +5,7 @@ non-interactive way. This is generally used for bootstrapping a Synapse
|
||||
instance with administrator accounts.
|
||||
|
||||
To authenticate yourself to the server, you will need both the shared secret
|
||||
([`registration_shared_secret`](../configuration/config_documentation.md#registration_shared_secret)
|
||||
([`registration_shared_secret`](../usage/configuration/config_documentation.md#registration_shared_secret)
|
||||
in the homeserver configuration), and a one-time nonce. If the registration
|
||||
shared secret is not configured, this API is not enabled.
|
||||
|
||||
|
||||
@@ -73,8 +73,8 @@ A custom mapping provider must specify the following methods:
|
||||
* `async def map_user_attributes(self, userinfo, token, failures)`
|
||||
- This method must be async.
|
||||
- Arguments:
|
||||
- `userinfo` - A `authlib.oidc.core.claims.UserInfo` object to extract user
|
||||
information from.
|
||||
- `userinfo` - An [`authlib.oidc.core.claims.UserInfo`](https://docs.authlib.org/en/latest/specs/oidc.html#authlib.oidc.core.UserInfo)
|
||||
object to extract user information from.
|
||||
- `token` - A dictionary which includes information necessary to make
|
||||
further requests to the OpenID provider.
|
||||
- `failures` - An `int` that represents the amount of times the returned
|
||||
@@ -91,7 +91,13 @@ A custom mapping provider must specify the following methods:
|
||||
`None`, the user is prompted to pick their own username. This is only used
|
||||
during a user's first login. Once a localpart has been associated with a
|
||||
remote user ID (see `get_remote_user_id`) it cannot be updated.
|
||||
- `displayname`: An optional string, the display name for the user.
|
||||
- `confirm_localpart`: A boolean. If set to `True`, when a `localpart`
|
||||
string is returned from this method, Synapse will prompt the user to
|
||||
either accept this localpart or pick their own username. Otherwise this
|
||||
option has no effect. If omitted, defaults to `False`.
|
||||
- `display_name`: An optional string, the display name for the user.
|
||||
- `emails`: A list of strings, the email address(es) to associate with
|
||||
this user. If omitted, defaults to an empty list.
|
||||
* `async def get_extra_attributes(self, userinfo, token)`
|
||||
- This method must be async.
|
||||
- Arguments:
|
||||
|
||||
@@ -29,7 +29,7 @@ class SynapsePlugin(Plugin):
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
"synapse.util.caches.descriptors.CachedFunction.__call__"
|
||||
) or fullname.startswith(
|
||||
"synapse.util.caches.descriptors._LruCachedFunction.__call__"
|
||||
):
|
||||
@@ -38,7 +38,7 @@ class SynapsePlugin(Plugin):
|
||||
|
||||
|
||||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||
"""Fixes the `_CachedFunction.__call__` signature to be correct.
|
||||
"""Fixes the `CachedFunction.__call__` signature to be correct.
|
||||
|
||||
It already has *almost* the correct signature, except:
|
||||
|
||||
|
||||
@@ -427,11 +427,12 @@ def _publish(gh_token: str) -> None:
|
||||
|
||||
|
||||
@cli.command()
|
||||
def upload() -> None:
|
||||
_upload()
|
||||
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
|
||||
def upload(gh_token: Optional[str]) -> None:
|
||||
_upload(gh_token)
|
||||
|
||||
|
||||
def _upload() -> None:
|
||||
def _upload(gh_token: Optional[str]) -> None:
|
||||
"""Upload release to pypi."""
|
||||
|
||||
current_version = get_package_version()
|
||||
@@ -444,18 +445,40 @@ def _upload() -> None:
|
||||
click.echo("Tag {tag_name} (tag.commit) is not currently checked out!")
|
||||
click.get_current_context().abort()
|
||||
|
||||
pypi_asset_names = [
|
||||
f"matrix_synapse-{current_version}-py3-none-any.whl",
|
||||
f"matrix-synapse-{current_version}.tar.gz",
|
||||
]
|
||||
# Query all the assets corresponding to this release.
|
||||
gh = Github(gh_token)
|
||||
gh_repo = gh.get_repo("matrix-org/synapse")
|
||||
gh_release = gh_repo.get_release(tag_name)
|
||||
|
||||
all_assets = set(gh_release.get_assets())
|
||||
|
||||
# Only accept the wheels and sdist.
|
||||
# Notably: we don't care about debs.tar.xz.
|
||||
asset_names_and_urls = sorted(
|
||||
(asset.name, asset.browser_download_url)
|
||||
for asset in all_assets
|
||||
if asset.name.endswith((".whl", ".tar.gz"))
|
||||
)
|
||||
|
||||
# Print out what we've determined.
|
||||
print("Found relevant assets:")
|
||||
for asset_name, _ in asset_names_and_urls:
|
||||
print(f" - {asset_name}")
|
||||
|
||||
ignored_asset_names = sorted(
|
||||
{asset.name for asset in all_assets}
|
||||
- {asset_name for asset_name, _ in asset_names_and_urls}
|
||||
)
|
||||
print("\nIgnoring irrelevant assets:")
|
||||
for asset_name in ignored_asset_names:
|
||||
print(f" - {asset_name}")
|
||||
|
||||
with TemporaryDirectory(prefix=f"synapse_upload_{tag_name}_") as tmpdir:
|
||||
for name in pypi_asset_names:
|
||||
for name, asset_download_url in asset_names_and_urls:
|
||||
filename = path.join(tmpdir, name)
|
||||
url = f"https://github.com/matrix-org/synapse/releases/download/{tag_name}/{name}"
|
||||
|
||||
click.echo(f"Downloading {name} into {filename}")
|
||||
urllib.request.urlretrieve(url, filename=filename)
|
||||
urllib.request.urlretrieve(asset_download_url, filename=filename)
|
||||
|
||||
if click.confirm("Upload to PyPI?", default=True):
|
||||
subprocess.run("twine upload *", shell=True, cwd=tmpdir)
|
||||
@@ -672,7 +695,7 @@ def full(gh_token: str) -> None:
|
||||
_publish(gh_token)
|
||||
|
||||
click.echo("\n*** upload ***")
|
||||
_upload()
|
||||
_upload(gh_token)
|
||||
|
||||
click.echo("\n*** merge back ***")
|
||||
_merge_back()
|
||||
|
||||
@@ -111,6 +111,7 @@ BOOLEAN_COLUMNS = {
|
||||
"e2e_fallback_keys_json": ["used"],
|
||||
"access_tokens": ["used"],
|
||||
"device_lists_changes_in_room": ["converted_to_destinations"],
|
||||
"pushers": ["enabled"],
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -93,3 +93,13 @@ class ExperimentalConfig(Config):
|
||||
|
||||
# MSC3852: Expose last seen user agent field on /_matrix/client/v3/devices.
|
||||
self.msc3852_enabled: bool = experimental.get("msc3852_enabled", False)
|
||||
|
||||
# MSC3881: Remotely toggle push notifications for another client
|
||||
self.msc3881_enabled: bool = experimental.get("msc3881_enabled", False)
|
||||
|
||||
# MSC3882: Allow an existing session to sign in a new session
|
||||
self.msc3882_enabled: bool = experimental.get("msc3882_enabled", False)
|
||||
self.msc3882_ui_auth: bool = experimental.get("msc3882_ui_auth", True)
|
||||
self.msc3882_token_timeout = self.parse_duration(
|
||||
experimental.get("msc3882_token_timeout", "5m")
|
||||
)
|
||||
|
||||
@@ -997,7 +997,7 @@ class RegistrationHandler:
|
||||
assert user_tuple
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="email",
|
||||
@@ -1005,7 +1005,7 @@ class RegistrationHandler:
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=threepid["address"],
|
||||
pushkey=threepid["address"],
|
||||
lang=None, # We don't know a user's language here
|
||||
lang=None,
|
||||
data={},
|
||||
)
|
||||
|
||||
|
||||
@@ -128,6 +128,9 @@ class SsoIdentityProvider(Protocol):
|
||||
|
||||
@attr.s(auto_attribs=True)
|
||||
class UserAttributes:
|
||||
# NB: This struct is documented in docs/sso_mapping_providers.md so that users can
|
||||
# populate it with data from their own mapping providers.
|
||||
|
||||
# the localpart of the mxid that the mapper has assigned to the user.
|
||||
# if `None`, the mapper has not picked a userid, and the user should be prompted to
|
||||
# enter one.
|
||||
|
||||
@@ -125,7 +125,7 @@ from synapse.types import (
|
||||
)
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.caches.descriptors import CachedFunction, cached
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -836,6 +836,37 @@ class ModuleApi:
|
||||
self._store.db_pool.runInteraction(desc, func, *args, **kwargs) # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
def register_cached_function(self, cached_func: CachedFunction) -> None:
|
||||
"""Register a cached function that should be invalidated across workers.
|
||||
Invalidation local to a worker can be done directly using `cached_func.invalidate`,
|
||||
however invalidation that needs to go to other workers needs to call `invalidate_cache`
|
||||
on the module API instead.
|
||||
|
||||
Args:
|
||||
cached_function: The cached function that will be registered to receive invalidation
|
||||
locally and from other workers.
|
||||
"""
|
||||
self._store.register_external_cached_function(
|
||||
f"{cached_func.__module__}.{cached_func.__name__}", cached_func
|
||||
)
|
||||
|
||||
async def invalidate_cache(
|
||||
self, cached_func: CachedFunction, keys: Tuple[Any, ...]
|
||||
) -> None:
|
||||
"""Invalidate a cache entry of a cached function across workers. The cached function
|
||||
needs to be registered on all workers first with `register_cached_function`.
|
||||
|
||||
Args:
|
||||
cached_function: The cached function that needs an invalidation
|
||||
keys: keys of the entry to invalidate, usually matching the arguments of the
|
||||
cached function.
|
||||
"""
|
||||
cached_func.invalidate(keys)
|
||||
await self._store.send_invalidation_to_replication(
|
||||
f"{cached_func.__module__}.{cached_func.__name__}",
|
||||
keys,
|
||||
)
|
||||
|
||||
async def complete_sso_login_async(
|
||||
self,
|
||||
registered_user_id: str,
|
||||
|
||||
@@ -116,6 +116,8 @@ class PusherConfig:
|
||||
last_stream_ordering: int
|
||||
last_success: Optional[int]
|
||||
failing_since: Optional[int]
|
||||
enabled: bool
|
||||
device_id: Optional[str]
|
||||
|
||||
def as_dict(self) -> Dict[str, Any]:
|
||||
"""Information that can be retrieved about a pusher after creation."""
|
||||
@@ -128,6 +130,8 @@ class PusherConfig:
|
||||
"lang": self.lang,
|
||||
"profile_tag": self.profile_tag,
|
||||
"pushkey": self.pushkey,
|
||||
"enabled": self.enabled,
|
||||
"device_id": self.device_id,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class PusherPool:
|
||||
return
|
||||
run_as_background_process("start_pushers", self._start_pushers)
|
||||
|
||||
async def add_pusher(
|
||||
async def add_or_update_pusher(
|
||||
self,
|
||||
user_id: str,
|
||||
access_token: Optional[int],
|
||||
@@ -106,6 +106,8 @@ class PusherPool:
|
||||
lang: Optional[str],
|
||||
data: JsonDict,
|
||||
profile_tag: str = "",
|
||||
enabled: bool = True,
|
||||
device_id: Optional[str] = None,
|
||||
) -> Optional[Pusher]:
|
||||
"""Creates a new pusher and adds it to the pool
|
||||
|
||||
@@ -147,9 +149,22 @@ class PusherPool:
|
||||
last_stream_ordering=last_stream_ordering,
|
||||
last_success=None,
|
||||
failing_since=None,
|
||||
enabled=enabled,
|
||||
device_id=device_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Before we actually persist the pusher, we check if the user already has one
|
||||
# this app ID and pushkey. If so, we want to keep the access token and device ID
|
||||
# in place, since this could be one device modifying (e.g. enabling/disabling)
|
||||
# another device's pusher.
|
||||
existing_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
user_id, app_id, pushkey
|
||||
)
|
||||
if existing_config:
|
||||
access_token = existing_config.access_token
|
||||
device_id = existing_config.device_id
|
||||
|
||||
await self.store.add_pusher(
|
||||
user_id=user_id,
|
||||
access_token=access_token,
|
||||
@@ -163,8 +178,10 @@ class PusherPool:
|
||||
data=data,
|
||||
last_stream_ordering=last_stream_ordering,
|
||||
profile_tag=profile_tag,
|
||||
enabled=enabled,
|
||||
device_id=device_id,
|
||||
)
|
||||
pusher = await self.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
pusher = await self.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||
|
||||
return pusher
|
||||
|
||||
@@ -276,10 +293,25 @@ class PusherPool:
|
||||
except Exception:
|
||||
logger.exception("Exception in pusher on_new_receipts")
|
||||
|
||||
async def start_pusher_by_id(
|
||||
async def _get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
self, user_id: str, app_id: str, pushkey: str
|
||||
) -> Optional[PusherConfig]:
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
|
||||
pusher_config = None
|
||||
for r in resultlist:
|
||||
if r.user_name == user_id:
|
||||
pusher_config = r
|
||||
|
||||
return pusher_config
|
||||
|
||||
async def process_pusher_change_by_id(
|
||||
self, app_id: str, pushkey: str, user_id: str
|
||||
) -> Optional[Pusher]:
|
||||
"""Look up the details for the given pusher, and start it
|
||||
"""Look up the details for the given pusher, and either start it if its
|
||||
"enabled" flag is True, or try to stop it otherwise.
|
||||
|
||||
If the pusher is new and its "enabled" flag is False, the stop is a noop.
|
||||
|
||||
Returns:
|
||||
The pusher started, if any
|
||||
@@ -290,12 +322,13 @@ class PusherPool:
|
||||
if not self._pusher_shard_config.should_handle(self._instance_name, user_id):
|
||||
return None
|
||||
|
||||
resultlist = await self.store.get_pushers_by_app_id_and_pushkey(app_id, pushkey)
|
||||
pusher_config = await self._get_pusher_config_for_user_by_app_id_and_pushkey(
|
||||
user_id, app_id, pushkey
|
||||
)
|
||||
|
||||
pusher_config = None
|
||||
for r in resultlist:
|
||||
if r.user_name == user_id:
|
||||
pusher_config = r
|
||||
if pusher_config and not pusher_config.enabled:
|
||||
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||
return None
|
||||
|
||||
pusher = None
|
||||
if pusher_config:
|
||||
@@ -305,7 +338,7 @@ class PusherPool:
|
||||
|
||||
async def _start_pushers(self) -> None:
|
||||
"""Start all the pushers"""
|
||||
pushers = await self.store.get_all_pushers()
|
||||
pushers = await self.store.get_enabled_pushers()
|
||||
|
||||
# Stagger starting up the pushers so we don't completely drown the
|
||||
# process on start up.
|
||||
@@ -363,6 +396,8 @@ class PusherPool:
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).inc()
|
||||
|
||||
logger.info("Starting pusher %s / %s", pusher.user_id, appid_pushkey)
|
||||
|
||||
# Check if there *may* be push to process. We do this as this check is a
|
||||
# lot cheaper to do than actually fetching the exact rows we need to
|
||||
# push.
|
||||
@@ -382,16 +417,7 @@ class PusherPool:
|
||||
return pusher
|
||||
|
||||
async def remove_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
byuser = self.pushers.get(user_id, {})
|
||||
|
||||
if appid_pushkey in byuser:
|
||||
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
||||
pusher = byuser.pop(appid_pushkey)
|
||||
pusher.on_stop()
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||
self.maybe_stop_pusher(app_id, pushkey, user_id)
|
||||
|
||||
# We can only delete pushers on master.
|
||||
if self._remove_pusher_client:
|
||||
@@ -402,3 +428,22 @@ class PusherPool:
|
||||
await self.store.delete_pusher_by_app_id_pushkey_user_id(
|
||||
app_id, pushkey, user_id
|
||||
)
|
||||
|
||||
def maybe_stop_pusher(self, app_id: str, pushkey: str, user_id: str) -> None:
|
||||
"""Stops a pusher with the given app ID and push key if one is running.
|
||||
|
||||
Args:
|
||||
app_id: the pusher's app ID.
|
||||
pushkey: the pusher's push key.
|
||||
user_id: the user the pusher belongs to. Only used for logging.
|
||||
"""
|
||||
appid_pushkey = "%s:%s" % (app_id, pushkey)
|
||||
|
||||
byuser = self.pushers.get(user_id, {})
|
||||
|
||||
if appid_pushkey in byuser:
|
||||
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
|
||||
pusher = byuser.pop(appid_pushkey)
|
||||
pusher.on_stop()
|
||||
|
||||
synapse_pushers.labels(type(pusher).__name__, pusher.app_id).dec()
|
||||
|
||||
@@ -189,7 +189,9 @@ class ReplicationDataHandler:
|
||||
if row.deleted:
|
||||
self.stop_pusher(row.user_id, row.app_id, row.pushkey)
|
||||
else:
|
||||
await self.start_pusher(row.user_id, row.app_id, row.pushkey)
|
||||
await self.process_pusher_change(
|
||||
row.user_id, row.app_id, row.pushkey
|
||||
)
|
||||
elif stream_name == EventsStream.NAME:
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
# we don't need to optimise this for multiple rows.
|
||||
@@ -334,13 +336,15 @@ class ReplicationDataHandler:
|
||||
logger.info("Stopping pusher %r / %r", user_id, key)
|
||||
pusher.on_stop()
|
||||
|
||||
async def start_pusher(self, user_id: str, app_id: str, pushkey: str) -> None:
|
||||
async def process_pusher_change(
|
||||
self, user_id: str, app_id: str, pushkey: str
|
||||
) -> None:
|
||||
if not self._notify_pushers:
|
||||
return
|
||||
|
||||
key = "%s:%s" % (app_id, pushkey)
|
||||
logger.info("Starting pusher %r / %r", user_id, key)
|
||||
await self._pusher_pool.start_pusher_by_id(app_id, pushkey, user_id)
|
||||
await self._pusher_pool.process_pusher_change_by_id(app_id, pushkey, user_id)
|
||||
|
||||
|
||||
class FederationSenderHandler:
|
||||
|
||||
@@ -30,6 +30,7 @@ from synapse.rest.client import (
|
||||
keys,
|
||||
knock,
|
||||
login as v1_login,
|
||||
login_token_request,
|
||||
logout,
|
||||
mutual_rooms,
|
||||
notifications,
|
||||
@@ -130,3 +131,4 @@ class ClientRestResource(JsonResource):
|
||||
|
||||
# unstable
|
||||
mutual_rooms.register_servlets(hs, client_resource)
|
||||
login_token_request.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -375,7 +375,7 @@ class UserRestServletV2(RestServlet):
|
||||
and self.hs.config.email.email_notif_for_new_users
|
||||
and medium == "email"
|
||||
):
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=None,
|
||||
kind="email",
|
||||
@@ -383,7 +383,7 @@ class UserRestServletV2(RestServlet):
|
||||
app_display_name="Email Notifications",
|
||||
device_display_name=address,
|
||||
pushkey=address,
|
||||
lang=None, # We don't know a user's language here
|
||||
lang=None,
|
||||
data={},
|
||||
)
|
||||
|
||||
|
||||
@@ -534,6 +534,11 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
|
||||
"/add_threepid/msisdn/submit_token$", releases=(), unstable=True
|
||||
)
|
||||
|
||||
class PostBody(RequestBodyModel):
|
||||
client_secret: ClientSecretStr
|
||||
sid: StrictStr
|
||||
token: StrictStr
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.config = hs.config
|
||||
@@ -549,16 +554,14 @@ class AddThreepidMsisdnSubmitTokenServlet(RestServlet):
|
||||
"instead.",
|
||||
)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
assert_params_in_dict(body, ["client_secret", "sid", "token"])
|
||||
assert_valid_client_secret(body["client_secret"])
|
||||
body = parse_and_validate_json_object_from_request(request, self.PostBody)
|
||||
|
||||
# Proxy submit_token request to msisdn threepid delegate
|
||||
response = await self.identity_handler.proxy_msisdn_submit_token(
|
||||
self.config.registration.account_threepid_delegate_msisdn,
|
||||
body["client_secret"],
|
||||
body["sid"],
|
||||
body["token"],
|
||||
body.client_secret,
|
||||
body.sid,
|
||||
body.token,
|
||||
)
|
||||
return 200, response
|
||||
|
||||
@@ -581,6 +584,10 @@ class ThreepidRestServlet(RestServlet):
|
||||
|
||||
return 200, {"threepids": threepids}
|
||||
|
||||
# NOTE(dmr): I have chosen not to use Pydantic to parse this request's body, because
|
||||
# the endpoint is deprecated. (If you really want to, you could do this by reusing
|
||||
# ThreePidBindRestServelet.PostBody with an `alias_generator` to handle
|
||||
# `threePidCreds` versus `three_pid_creds`.
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
if not self.hs.config.registration.enable_3pid_changes:
|
||||
raise SynapseError(
|
||||
|
||||
94
synapse/rest/client/login_token_request.py
Normal file
94
synapse/rest/client/login_token_request.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from synapse.http.server import HttpServer
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client._base import client_patterns, interactive_auth_handler
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoginTokenRequestServlet(RestServlet):
|
||||
"""
|
||||
Get a token that can be used with `m.login.token` to log in a second device.
|
||||
|
||||
Request:
|
||||
|
||||
POST /login/token HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{}
|
||||
|
||||
Response:
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
{
|
||||
"login_token": "ABDEFGH",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
"""
|
||||
|
||||
PATTERNS = client_patterns("/login/token$")
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.config.server.server_name
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.token_timeout = hs.config.experimental.msc3882_token_timeout
|
||||
self.ui_auth = hs.config.experimental.msc3882_ui_auth
|
||||
|
||||
@interactive_auth_handler
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
if self.ui_auth:
|
||||
await self.auth_handler.validate_user_via_ui_auth(
|
||||
requester,
|
||||
request,
|
||||
body,
|
||||
"issue a new access token for your account",
|
||||
can_skip_ui_auth=False, # Don't allow skipping of UI auth
|
||||
)
|
||||
|
||||
login_token = self.macaroon_gen.generate_short_term_login_token(
|
||||
user_id=requester.user.to_string(),
|
||||
auth_provider_id="org.matrix.msc3882.login_token_request",
|
||||
duration_in_ms=self.token_timeout,
|
||||
)
|
||||
|
||||
return (
|
||||
200,
|
||||
{
|
||||
"login_token": login_token,
|
||||
"expires_in": self.token_timeout // 1000,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
if hs.config.experimental.msc3882_enabled:
|
||||
LoginTokenRequestServlet(hs).register(http_server)
|
||||
@@ -42,6 +42,7 @@ class PushersRestServlet(RestServlet):
|
||||
super().__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
@@ -51,9 +52,16 @@ class PushersRestServlet(RestServlet):
|
||||
user.to_string()
|
||||
)
|
||||
|
||||
filtered_pushers = [p.as_dict() for p in pushers]
|
||||
pusher_dicts = [p.as_dict() for p in pushers]
|
||||
|
||||
return 200, {"pushers": filtered_pushers}
|
||||
for pusher in pusher_dicts:
|
||||
if self._msc3881_enabled:
|
||||
pusher["org.matrix.msc3881.enabled"] = pusher["enabled"]
|
||||
pusher["org.matrix.msc3881.device_id"] = pusher["device_id"]
|
||||
del pusher["enabled"]
|
||||
del pusher["device_id"]
|
||||
|
||||
return 200, {"pushers": pusher_dicts}
|
||||
|
||||
|
||||
class PushersSetRestServlet(RestServlet):
|
||||
@@ -65,6 +73,7 @@ class PushersSetRestServlet(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.pusher_pool = self.hs.get_pusherpool()
|
||||
self._msc3881_enabled = self.hs.config.experimental.msc3881_enabled
|
||||
|
||||
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
@@ -103,6 +112,10 @@ class PushersSetRestServlet(RestServlet):
|
||||
if "append" in content:
|
||||
append = content["append"]
|
||||
|
||||
enabled = True
|
||||
if self._msc3881_enabled and "org.matrix.msc3881.enabled" in content:
|
||||
enabled = content["org.matrix.msc3881.enabled"]
|
||||
|
||||
if not append:
|
||||
await self.pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
app_id=content["app_id"],
|
||||
@@ -111,7 +124,7 @@ class PushersSetRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
try:
|
||||
await self.pusher_pool.add_pusher(
|
||||
await self.pusher_pool.add_or_update_pusher(
|
||||
user_id=user.to_string(),
|
||||
access_token=requester.access_token_id,
|
||||
kind=content["kind"],
|
||||
@@ -122,6 +135,8 @@ class PushersSetRestServlet(RestServlet):
|
||||
lang=content["lang"],
|
||||
data=content["data"],
|
||||
profile_tag=content.get("profile_tag", ""),
|
||||
enabled=enabled,
|
||||
device_id=requester.device_id,
|
||||
)
|
||||
except PusherConfigException as pce:
|
||||
raise SynapseError(
|
||||
|
||||
@@ -105,6 +105,10 @@ class VersionsRestServlet(RestServlet):
|
||||
"org.matrix.msc3440.stable": True, # TODO: remove when "v1.3" is added above
|
||||
# Allows moderators to fetch redacted event content as described in MSC2815
|
||||
"fi.mau.msc2815": self.config.experimental.msc2815_enabled,
|
||||
# Adds support for login token requests as per MSC3882
|
||||
"org.matrix.msc3882": self.config.experimental.msc3882_enabled,
|
||||
# Adds support for remotely enabling/disabling pushers, as per MSC3881
|
||||
"org.matrix.msc3881": self.config.experimental.msc3881_enabled,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -15,12 +15,13 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from abc import ABCMeta
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Collection, Dict, Iterable, Optional, Union
|
||||
|
||||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401; noqa: F401
|
||||
from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.caches.descriptors import CachedFunction
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -47,6 +48,8 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
self.database_engine = database.engine
|
||||
self.db_pool = database
|
||||
|
||||
self.external_cached_functions: Dict[str, CachedFunction] = {}
|
||||
|
||||
def process_replication_rows(
|
||||
self,
|
||||
stream_name: str,
|
||||
@@ -95,7 +98,7 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
|
||||
def _attempt_to_invalidate_cache(
|
||||
self, cache_name: str, key: Optional[Collection[Any]]
|
||||
) -> None:
|
||||
) -> bool:
|
||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||
cache doesn't exist. Mainly used for invalidating caches on workers,
|
||||
where they may not have the cache.
|
||||
@@ -113,9 +116,12 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
try:
|
||||
cache = getattr(self, cache_name)
|
||||
except AttributeError:
|
||||
# We probably haven't pulled in the cache in this worker,
|
||||
# which is fine.
|
||||
return
|
||||
# Check if an externally defined module cache has been registered
|
||||
cache = self.external_cached_functions.get(cache_name)
|
||||
if not cache:
|
||||
# We probably haven't pulled in the cache in this worker,
|
||||
# which is fine.
|
||||
return False
|
||||
|
||||
if key is None:
|
||||
cache.invalidate_all()
|
||||
@@ -125,6 +131,13 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
invalidate_method = getattr(cache, "invalidate_local", cache.invalidate)
|
||||
invalidate_method(tuple(key))
|
||||
|
||||
return True
|
||||
|
||||
def register_external_cached_function(
|
||||
self, cache_name: str, func: CachedFunction
|
||||
) -> None:
|
||||
self.external_cached_functions[cache_name] = func
|
||||
|
||||
|
||||
def db_to_json(db_content: Union[memoryview, bytes, bytearray, str]) -> Any:
|
||||
"""
|
||||
|
||||
@@ -33,7 +33,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.util.caches.descriptors import _CachedFunction
|
||||
from synapse.util.caches.descriptors import CachedFunction
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -269,9 +269,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
return
|
||||
|
||||
cache_func.invalidate(keys)
|
||||
await self.db_pool.runInteraction(
|
||||
"invalidate_cache_and_stream",
|
||||
self._send_invalidation_to_replication,
|
||||
await self.send_invalidation_to_replication(
|
||||
cache_func.__name__,
|
||||
keys,
|
||||
)
|
||||
@@ -279,7 +277,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
def _invalidate_cache_and_stream(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
cache_func: _CachedFunction,
|
||||
cache_func: CachedFunction,
|
||||
keys: Tuple[Any, ...],
|
||||
) -> None:
|
||||
"""Invalidates the cache and adds it to the cache stream so slaves
|
||||
@@ -293,7 +291,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
self._send_invalidation_to_replication(txn, cache_func.__name__, keys)
|
||||
|
||||
def _invalidate_all_cache_and_stream(
|
||||
self, txn: LoggingTransaction, cache_func: _CachedFunction
|
||||
self, txn: LoggingTransaction, cache_func: CachedFunction
|
||||
) -> None:
|
||||
"""Invalidates the entire cache and adds it to the cache stream so slaves
|
||||
will know to invalidate their caches.
|
||||
@@ -334,6 +332,16 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
txn, CURRENT_STATE_CACHE_NAME, [room_id]
|
||||
)
|
||||
|
||||
async def send_invalidation_to_replication(
|
||||
self, cache_name: str, keys: Optional[Collection[Any]]
|
||||
) -> None:
|
||||
await self.db_pool.runInteraction(
|
||||
"send_invalidation_to_replication",
|
||||
self._send_invalidation_to_replication,
|
||||
cache_name,
|
||||
keys,
|
||||
)
|
||||
|
||||
def _send_invalidation_to_replication(
|
||||
self, txn: LoggingTransaction, cache_name: str, keys: Optional[Iterable[Any]]
|
||||
) -> None:
|
||||
|
||||
@@ -412,6 +412,31 @@ class PersistEventsStore:
|
||||
assert min_stream_order
|
||||
assert max_stream_order
|
||||
|
||||
# Once the txn completes, invalidate all of the relevant caches. Note that we do this
|
||||
# up here because it captures all the events_and_contexts before any are removed.
|
||||
for event, _ in events_and_contexts:
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.event_id)
|
||||
if event.redacts:
|
||||
self.store.invalidate_get_event_cache_after_txn(txn, event.redacts)
|
||||
|
||||
relates_to = None
|
||||
relation = relation_from_event(event)
|
||||
if relation:
|
||||
relates_to = relation.parent_id
|
||||
|
||||
assert event.internal_metadata.stream_ordering is not None
|
||||
txn.call_after(
|
||||
self.store._invalidate_caches_for_event,
|
||||
event.internal_metadata.stream_ordering,
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
event.type,
|
||||
getattr(event, "state_key", None),
|
||||
event.redacts,
|
||||
relates_to,
|
||||
backfilled=False,
|
||||
)
|
||||
|
||||
self._update_forward_extremities_txn(
|
||||
txn,
|
||||
new_forward_extremities=new_forward_extremities,
|
||||
|
||||
@@ -1474,32 +1474,38 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# the batches as big as possible.
|
||||
|
||||
results: Set[str] = set()
|
||||
for chunk in batch_iter(event_ids, 500):
|
||||
r = await self._have_seen_events_dict(
|
||||
[(room_id, event_id) for event_id in chunk]
|
||||
for event_ids_chunk in batch_iter(event_ids, 500):
|
||||
events_seen_dict = await self._have_seen_events_dict(
|
||||
room_id, event_ids_chunk
|
||||
)
|
||||
results.update(
|
||||
eid for (eid, have_event) in events_seen_dict.items() if have_event
|
||||
)
|
||||
results.update(eid for ((_rid, eid), have_event) in r.items() if have_event)
|
||||
|
||||
return results
|
||||
|
||||
@cachedList(cached_method_name="have_seen_event", list_name="keys")
|
||||
@cachedList(cached_method_name="have_seen_event", list_name="event_ids")
|
||||
async def _have_seen_events_dict(
|
||||
self, keys: Collection[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], bool]:
|
||||
self,
|
||||
room_id: str,
|
||||
event_ids: Collection[str],
|
||||
) -> Dict[str, bool]:
|
||||
"""Helper for have_seen_events
|
||||
|
||||
Returns:
|
||||
a dict {(room_id, event_id)-> bool}
|
||||
a dict {event_id -> bool}
|
||||
"""
|
||||
# if the event cache contains the event, obviously we've seen it.
|
||||
|
||||
cache_results = {
|
||||
(rid, eid)
|
||||
for (rid, eid) in keys
|
||||
if await self._get_event_cache.contains((eid,))
|
||||
event_id
|
||||
for event_id in event_ids
|
||||
if await self._get_event_cache.contains((event_id,))
|
||||
}
|
||||
results = dict.fromkeys(cache_results, True)
|
||||
remaining = [k for k in keys if k not in cache_results]
|
||||
remaining = [
|
||||
event_id for event_id in event_ids if event_id not in cache_results
|
||||
]
|
||||
if not remaining:
|
||||
return results
|
||||
|
||||
@@ -1511,23 +1517,21 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
sql = "SELECT event_id FROM events AS e WHERE "
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine, "e.event_id", [eid for (_rid, eid) in remaining]
|
||||
txn.database_engine, "e.event_id", remaining
|
||||
)
|
||||
txn.execute(sql + clause, args)
|
||||
found_events = {eid for eid, in txn}
|
||||
|
||||
# ... and then we can update the results for each key
|
||||
results.update(
|
||||
{(rid, eid): (eid in found_events) for (rid, eid) in remaining}
|
||||
)
|
||||
results.update({eid: (eid in found_events) for eid in remaining})
|
||||
|
||||
await self.db_pool.runInteraction("have_seen_events", have_seen_events_txn)
|
||||
return results
|
||||
|
||||
@cached(max_entries=100000, tree=True)
|
||||
async def have_seen_event(self, room_id: str, event_id: str) -> bool:
|
||||
res = await self._have_seen_events_dict(((room_id, event_id),))
|
||||
return res[(room_id, event_id)]
|
||||
res = await self._have_seen_events_dict(room_id, [event_id])
|
||||
return res[event_id]
|
||||
|
||||
def _get_current_state_event_counts_txn(
|
||||
self, txn: LoggingTransaction, room_id: str
|
||||
|
||||
@@ -89,6 +89,11 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
)
|
||||
continue
|
||||
|
||||
# If we're using SQLite, then boolean values are integers. This is
|
||||
# troublesome since some code using the return value of this method might
|
||||
# expect it to be a boolean, or will expose it to clients (in responses).
|
||||
r["enabled"] = bool(r["enabled"])
|
||||
|
||||
yield PusherConfig(**r)
|
||||
|
||||
async def get_pushers_by_app_id_and_pushkey(
|
||||
@@ -100,38 +105,52 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
return await self.get_pushers_by({"user_name": user_id})
|
||||
|
||||
async def get_pushers_by(self, keyvalues: Dict[str, Any]) -> Iterator[PusherConfig]:
|
||||
ret = await self.db_pool.simple_select_list(
|
||||
"pushers",
|
||||
keyvalues,
|
||||
[
|
||||
"id",
|
||||
"user_name",
|
||||
"access_token",
|
||||
"profile_tag",
|
||||
"kind",
|
||||
"app_id",
|
||||
"app_display_name",
|
||||
"device_display_name",
|
||||
"pushkey",
|
||||
"ts",
|
||||
"lang",
|
||||
"data",
|
||||
"last_stream_ordering",
|
||||
"last_success",
|
||||
"failing_since",
|
||||
],
|
||||
"""Retrieve pushers that match the given criteria.
|
||||
|
||||
Args:
|
||||
keyvalues: A {column: value} dictionary.
|
||||
|
||||
Returns:
|
||||
The pushers for which the given columns have the given values.
|
||||
"""
|
||||
|
||||
def get_pushers_by_txn(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
# We could technically use simple_select_list here, but we need to call
|
||||
# COALESCE on the 'enabled' column. While it is technically possible to give
|
||||
# simple_select_list the whole `COALESCE(...) AS ...` as a column name, it
|
||||
# feels a bit hacky, so it's probably better to just inline the query.
|
||||
sql = """
|
||||
SELECT
|
||||
id, user_name, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name, pushkey, ts, lang, data,
|
||||
last_stream_ordering, last_success, failing_since,
|
||||
COALESCE(enabled, TRUE) AS enabled, device_id
|
||||
FROM pushers
|
||||
"""
|
||||
|
||||
sql += "WHERE %s" % (" AND ".join("%s = ?" % (k,) for k in keyvalues),)
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
ret = await self.db_pool.runInteraction(
|
||||
desc="get_pushers_by",
|
||||
func=get_pushers_by_txn,
|
||||
)
|
||||
|
||||
return self._decode_pushers_rows(ret)
|
||||
|
||||
async def get_all_pushers(self) -> Iterator[PusherConfig]:
|
||||
def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||
txn.execute("SELECT * FROM pushers")
|
||||
async def get_enabled_pushers(self) -> Iterator[PusherConfig]:
|
||||
def get_enabled_pushers_txn(txn: LoggingTransaction) -> Iterator[PusherConfig]:
|
||||
txn.execute("SELECT * FROM pushers WHERE COALESCE(enabled, TRUE)")
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self._decode_pushers_rows(rows)
|
||||
|
||||
return await self.db_pool.runInteraction("get_all_pushers", get_pushers)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_enabled_pushers", get_enabled_pushers_txn
|
||||
)
|
||||
|
||||
async def get_all_updated_pushers_rows(
|
||||
self, instance_name: str, last_id: int, current_id: int, limit: int
|
||||
@@ -458,7 +477,74 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
return number_deleted
|
||||
|
||||
|
||||
class PusherStore(PusherWorkerStore):
|
||||
class PusherBackgroundUpdatesStore(SQLBaseStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self.db_pool.updates.register_background_update_handler(
|
||||
"set_device_id_for_pushers", self._set_device_id_for_pushers
|
||||
)
|
||||
|
||||
async def _set_device_id_for_pushers(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""Background update to populate the device_id column of the pushers table."""
|
||||
last_pusher_id = progress.get("pusher_id", 0)
|
||||
|
||||
def set_device_id_for_pushers_txn(txn: LoggingTransaction) -> int:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT p.id, at.device_id
|
||||
FROM pushers AS p
|
||||
INNER JOIN access_tokens AS at
|
||||
ON p.access_token = at.id
|
||||
WHERE
|
||||
p.access_token IS NOT NULL
|
||||
AND at.device_id IS NOT NULL
|
||||
AND p.id > ?
|
||||
ORDER BY p.id
|
||||
LIMIT ?
|
||||
""",
|
||||
(last_pusher_id, batch_size),
|
||||
)
|
||||
|
||||
rows = self.db_pool.cursor_to_dict(txn)
|
||||
if len(rows) == 0:
|
||||
return 0
|
||||
|
||||
self.db_pool.simple_update_many_txn(
|
||||
txn=txn,
|
||||
table="pushers",
|
||||
key_names=("id",),
|
||||
key_values=[(row["id"],) for row in rows],
|
||||
value_names=("device_id",),
|
||||
value_values=[(row["device_id"],) for row in rows],
|
||||
)
|
||||
|
||||
self.db_pool.updates._background_update_progress_txn(
|
||||
txn, "set_device_id_for_pushers", {"pusher_id": rows[-1]["id"]}
|
||||
)
|
||||
|
||||
return len(rows)
|
||||
|
||||
nb_processed = await self.db_pool.runInteraction(
|
||||
"set_device_id_for_pushers", set_device_id_for_pushers_txn
|
||||
)
|
||||
|
||||
if nb_processed < batch_size:
|
||||
await self.db_pool.updates._end_background_update(
|
||||
"set_device_id_for_pushers"
|
||||
)
|
||||
|
||||
return nb_processed
|
||||
|
||||
|
||||
class PusherStore(PusherWorkerStore, PusherBackgroundUpdatesStore):
|
||||
def get_pushers_stream_token(self) -> int:
|
||||
return self._pushers_id_gen.get_current_token()
|
||||
|
||||
@@ -476,6 +562,8 @@ class PusherStore(PusherWorkerStore):
|
||||
data: Optional[JsonDict],
|
||||
last_stream_ordering: int,
|
||||
profile_tag: str = "",
|
||||
enabled: bool = True,
|
||||
device_id: Optional[str] = None,
|
||||
) -> None:
|
||||
async with self._pushers_id_gen.get_next() as stream_id:
|
||||
# no need to lock because `pushers` has a unique key on
|
||||
@@ -494,6 +582,8 @@ class PusherStore(PusherWorkerStore):
|
||||
"last_stream_ordering": last_stream_ordering,
|
||||
"profile_tag": profile_tag,
|
||||
"id": stream_id,
|
||||
"enabled": enabled,
|
||||
"device_id": device_id,
|
||||
},
|
||||
desc="add_pusher",
|
||||
lock=False,
|
||||
|
||||
@@ -52,6 +52,8 @@ class _RelatedEvent:
|
||||
event_id: str
|
||||
# The sender of the related event.
|
||||
sender: str
|
||||
topological_ordering: Optional[int]
|
||||
stream_ordering: int
|
||||
|
||||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
@@ -92,6 +94,9 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
# it. The `event_id` must match the `event.event_id`.
|
||||
assert event.event_id == event_id
|
||||
|
||||
# Ensure bad limits aren't being passed in.
|
||||
assert limit >= 0
|
||||
|
||||
where_clause = ["relates_to_id = ?", "room_id = ?"]
|
||||
where_args: List[Union[str, int]] = [event.event_id, room_id]
|
||||
is_redacted = event.internal_metadata.is_redacted()
|
||||
@@ -140,21 +145,34 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
|
||||
txn.execute(sql, where_args + [limit + 1])
|
||||
|
||||
last_topo_id = None
|
||||
last_stream_id = None
|
||||
events = []
|
||||
for row in txn:
|
||||
for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
|
||||
# Do not include edits for redacted events as they leak event
|
||||
# content.
|
||||
if not is_redacted or row[1] != RelationTypes.REPLACE:
|
||||
events.append(_RelatedEvent(row[0], row[2]))
|
||||
last_topo_id = row[3]
|
||||
last_stream_id = row[4]
|
||||
if not is_redacted or relation_type != RelationTypes.REPLACE:
|
||||
events.append(
|
||||
_RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
|
||||
)
|
||||
|
||||
# If there are more events, generate the next pagination key.
|
||||
# If there are more events, generate the next pagination key from the
|
||||
# last event returned.
|
||||
next_token = None
|
||||
if len(events) > limit and last_topo_id and last_stream_id:
|
||||
next_key = RoomStreamToken(last_topo_id, last_stream_id)
|
||||
if len(events) > limit:
|
||||
# Instead of using the last row (which tells us there is more
|
||||
# data), use the last row to be returned.
|
||||
events = events[:limit]
|
||||
|
||||
topo = events[-1].topological_ordering
|
||||
token = events[-1].stream_ordering
|
||||
if direction == "b":
|
||||
# Tokens are positions between events.
|
||||
# This token points *after* the last event in the chunk.
|
||||
# We need it to point to the event before it in the chunk
|
||||
# when we are going backwards so we subtract one from the
|
||||
# stream part.
|
||||
token -= 1
|
||||
next_key = RoomStreamToken(topo, token)
|
||||
|
||||
if from_token:
|
||||
next_token = from_token.copy_and_replace(
|
||||
StreamKeyType.ROOM, next_key
|
||||
|
||||
@@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
if rows:
|
||||
topo = rows[-1].topological_ordering
|
||||
toke = rows[-1].stream_ordering
|
||||
token = rows[-1].stream_ordering
|
||||
if direction == "b":
|
||||
# Tokens are positions between events.
|
||||
# This token points *after* the last event in the chunk.
|
||||
# We need it to point to the event before it in the chunk
|
||||
# when we are going backwards so we subtract one from the
|
||||
# stream part.
|
||||
toke -= 1
|
||||
next_token = RoomStreamToken(topo, toke)
|
||||
token -= 1
|
||||
next_token = RoomStreamToken(topo, token)
|
||||
else:
|
||||
# TODO (erikj): We should work out what to do here instead.
|
||||
next_token = to_token if to_token else from_token
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE pushers ADD COLUMN enabled BOOLEAN;
|
||||
20
synapse/storage/schema/main/delta/73/03pusher_device_id.sql
Normal file
20
synapse/storage/schema/main/delta/73/03pusher_device_id.sql
Normal file
@@ -0,0 +1,20 @@
|
||||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- Add a device_id column to track the device ID that created the pusher. It's NULLable
|
||||
-- on purpose, because a) it might not be possible to track down the device that created
|
||||
-- old pushers (pushers.access_token and access_tokens.device_id are both NULLable), and
|
||||
-- b) access tokens retrieved via the admin API don't have a device associated to them.
|
||||
ALTER TABLE pushers ADD COLUMN device_id TEXT;
|
||||
@@ -53,7 +53,7 @@ CacheKey = Union[Tuple, Any]
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
class _CachedFunction(Generic[F]):
|
||||
class CachedFunction(Generic[F]):
|
||||
invalidate: Any = None
|
||||
invalidate_all: Any = None
|
||||
prefill: Any = None
|
||||
@@ -242,7 +242,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
return ret2
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
wrapped = cast(CachedFunction, _wrapped)
|
||||
wrapped.cache = cache
|
||||
obj.__dict__[self.name] = wrapped
|
||||
|
||||
@@ -363,7 +363,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
|
||||
|
||||
return make_deferred_yieldable(ret)
|
||||
|
||||
wrapped = cast(_CachedFunction, _wrapped)
|
||||
wrapped = cast(CachedFunction, _wrapped)
|
||||
|
||||
if self.num_args == 1:
|
||||
assert not self.tree
|
||||
@@ -431,6 +431,12 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
|
||||
cache: DeferredCache[CacheKey, Any] = cached_method.cache
|
||||
num_args = cached_method.num_args
|
||||
|
||||
if num_args != self.num_args:
|
||||
raise Exception(
|
||||
"Number of args (%s) does not match underlying cache_method_name=%s (%s)."
|
||||
% (self.num_args, self.cached_method_name, num_args)
|
||||
)
|
||||
|
||||
@functools.wraps(self.orig)
|
||||
def wrapped(*args: Any, **kwargs: Any) -> "defer.Deferred[Dict]":
|
||||
# If we're passed a cache_context then we'll want to call its
|
||||
@@ -572,7 +578,7 @@ def cached(
|
||||
iterable: bool = False,
|
||||
prune_unread_entries: bool = True,
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
) -> Callable[[F], CachedFunction[F]]:
|
||||
func = lambda orig: DeferredCacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
@@ -585,7 +591,7 @@ def cached(
|
||||
name=name,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
return cast(Callable[[F], CachedFunction[F]], func)
|
||||
|
||||
|
||||
def cachedList(
|
||||
@@ -594,7 +600,7 @@ def cachedList(
|
||||
list_name: str,
|
||||
num_args: Optional[int] = None,
|
||||
name: Optional[str] = None,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
) -> Callable[[F], CachedFunction[F]]:
|
||||
"""Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. One of the arguments
|
||||
@@ -631,7 +637,7 @@ def cachedList(
|
||||
name=name,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
return cast(Callable[[F], CachedFunction[F]], func)
|
||||
|
||||
|
||||
def _get_cache_key_builder(
|
||||
|
||||
@@ -114,7 +114,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||
)
|
||||
|
||||
self.pusher = self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.user_id,
|
||||
access_token=self.token_id,
|
||||
kind="email",
|
||||
@@ -136,7 +136,7 @@ class EmailPusherTests(HomeserverTestCase):
|
||||
"""
|
||||
with self.assertRaises(SynapseError) as cm:
|
||||
self.get_success_or_raise(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.user_id,
|
||||
access_token=self.token_id,
|
||||
kind="email",
|
||||
|
||||
@@ -19,9 +19,10 @@ from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.rest.client import login, push_rule, receipts, room
|
||||
from synapse.push import PusherConfig, PusherConfigException
|
||||
from synapse.rest.client import login, push_rule, pusher, receipts, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
|
||||
@@ -35,6 +36,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
login.register_servlets,
|
||||
receipts.register_servlets,
|
||||
push_rule.register_servlets,
|
||||
pusher.register_servlets,
|
||||
]
|
||||
user_id = True
|
||||
hijack_auth = False
|
||||
@@ -74,7 +76,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
|
||||
def test_data(data: Optional[JsonDict]) -> None:
|
||||
self.get_failure(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -119,7 +121,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -235,7 +237,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -355,7 +357,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -441,7 +443,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -518,7 +520,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -624,7 +626,7 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -728,18 +730,38 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.json_body)
|
||||
|
||||
def _make_user_with_pusher(self, username: str) -> Tuple[str, str]:
|
||||
def _make_user_with_pusher(
|
||||
self, username: str, enabled: bool = True
|
||||
) -> Tuple[str, str]:
|
||||
"""Registers a user and creates a pusher for them.
|
||||
|
||||
Args:
|
||||
username: the localpart of the new user's Matrix ID.
|
||||
enabled: whether to create the pusher in an enabled or disabled state.
|
||||
"""
|
||||
user_id = self.register_user(username, "pass")
|
||||
access_token = self.login(username, "pass")
|
||||
|
||||
# Register the pusher
|
||||
self._set_pusher(user_id, access_token, enabled)
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def _set_pusher(self, user_id: str, access_token: str, enabled: bool) -> None:
|
||||
"""Creates or updates the pusher for the given user.
|
||||
|
||||
Args:
|
||||
user_id: the user's Matrix ID.
|
||||
access_token: the access token associated with the pusher.
|
||||
enabled: whether to enable or disable the pusher.
|
||||
"""
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
@@ -749,11 +771,11 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
pushkey="a@example.com",
|
||||
lang=None,
|
||||
data={"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
enabled=enabled,
|
||||
device_id=user_tuple.device_id,
|
||||
)
|
||||
)
|
||||
|
||||
return user_id, access_token
|
||||
|
||||
def test_dont_notify_rule_overrides_message(self) -> None:
|
||||
"""
|
||||
The override push rule will suppress notification
|
||||
@@ -791,3 +813,148 @@ class HTTPPusherTests(HomeserverTestCase):
|
||||
# The user sends a message back (sends a notification)
|
||||
self.helper.send(room, body="Hello", tok=access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_disable(self) -> None:
|
||||
"""Tests that disabling a pusher means it's not pushed to anymore."""
|
||||
user_id, access_token = self._make_user_with_pusher("user")
|
||||
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||
|
||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||
|
||||
# Send a message and check that it generated a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Disable the pusher.
|
||||
self._set_pusher(user_id, access_token, enabled=False)
|
||||
|
||||
# Send another message and check that it did not generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Get the pushers for the user and check that it is marked as disabled.
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
|
||||
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||
self.assertFalse(enabled)
|
||||
self.assertTrue(isinstance(enabled, bool))
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_enable(self) -> None:
|
||||
"""Tests that enabling a disabled pusher means it gets pushed to."""
|
||||
# Create the user with the pusher already disabled.
|
||||
user_id, access_token = self._make_user_with_pusher("user", enabled=False)
|
||||
other_user_id, other_access_token = self._make_user_with_pusher("otheruser")
|
||||
|
||||
room = self.helper.create_room_as(user_id, tok=access_token)
|
||||
self.helper.join(room=room, user=other_user_id, tok=other_access_token)
|
||||
|
||||
# Send a message and check that it did not generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 0)
|
||||
|
||||
# Enable the pusher.
|
||||
self._set_pusher(user_id, access_token, enabled=True)
|
||||
|
||||
# Send another message and check that it did generate a push.
|
||||
self.helper.send(room, body="Hi!", tok=other_access_token)
|
||||
self.assertEqual(len(self.push_attempts), 1)
|
||||
|
||||
# Get the pushers for the user and check that it is marked as enabled.
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
|
||||
enabled = channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"]
|
||||
self.assertTrue(enabled)
|
||||
self.assertTrue(isinstance(enabled, bool))
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_null_enabled(self) -> None:
|
||||
"""Tests that a pusher that has an 'enabled' column set to NULL (eg pushers
|
||||
created before the column was introduced) is considered enabled.
|
||||
"""
|
||||
# We intentionally set 'enabled' to None so that it's stored as NULL in the
|
||||
# database.
|
||||
user_id, access_token = self._make_user_with_pusher("user", enabled=None) # type: ignore[arg-type]
|
||||
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
self.assertTrue(channel.json_body["pushers"][0]["org.matrix.msc3881.enabled"])
|
||||
|
||||
def test_update_different_device_access_token_device_id(self) -> None:
|
||||
"""Tests that if we create a pusher from one device, the update it from another
|
||||
device, the access token and device ID associated with the pusher stays the
|
||||
same.
|
||||
"""
|
||||
# Create a user with a pusher.
|
||||
user_id, access_token = self._make_user_with_pusher("user")
|
||||
|
||||
# Get the token ID for the current access token, since that's what we store in
|
||||
# the pushers table. Also get the device ID from it.
|
||||
user_tuple = self.get_success(
|
||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||
)
|
||||
token_id = user_tuple.token_id
|
||||
device_id = user_tuple.device_id
|
||||
|
||||
# Generate a new access token, and update the pusher with it.
|
||||
new_token = self.login("user", "pass")
|
||||
self._set_pusher(user_id, new_token, enabled=False)
|
||||
|
||||
# Get the current list of pushers for the user.
|
||||
ret = self.get_success(
|
||||
self.hs.get_datastores().main.get_pushers_by({"user_name": user_id})
|
||||
)
|
||||
pushers: List[PusherConfig] = list(ret)
|
||||
|
||||
# Check that we still have one pusher, and that the access token and device ID
|
||||
# associated with it didn't change.
|
||||
self.assertEqual(len(pushers), 1)
|
||||
self.assertEqual(pushers[0].access_token, token_id)
|
||||
self.assertEqual(pushers[0].device_id, device_id)
|
||||
|
||||
@override_config({"experimental_features": {"msc3881_enabled": True}})
|
||||
def test_device_id(self) -> None:
|
||||
"""Tests that a pusher created with a given device ID shows that device ID in
|
||||
GET /pushers requests.
|
||||
"""
|
||||
self.register_user("user", "pass")
|
||||
access_token = self.login("user", "pass")
|
||||
|
||||
# We create the pusher with an HTTP request rather than with
|
||||
# _make_user_with_pusher so that we can test the device ID is correctly set when
|
||||
# creating a pusher via an API call.
|
||||
self.make_request(
|
||||
method="POST",
|
||||
path="/pushers/set",
|
||||
content={
|
||||
"kind": "http",
|
||||
"app_id": "m.http",
|
||||
"app_display_name": "HTTP Push Notifications",
|
||||
"device_display_name": "pushy push",
|
||||
"pushkey": "a@example.com",
|
||||
"lang": "en",
|
||||
"data": {"url": "http://example.com/_matrix/push/v1/notify"},
|
||||
},
|
||||
access_token=access_token,
|
||||
)
|
||||
|
||||
# Look up the user info for the access token so we can compare the device ID.
|
||||
lookup_result: TokenLookupResult = self.get_success(
|
||||
self.hs.get_datastores().main.get_user_by_access_token(access_token)
|
||||
)
|
||||
|
||||
# Get the user's devices and check it has the correct device ID.
|
||||
channel = self.make_request("GET", "/pushers", access_token=access_token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(len(channel.json_body["pushers"]), 1)
|
||||
self.assertEqual(
|
||||
channel.json_body["pushers"][0]["org.matrix.msc3881.device_id"],
|
||||
lookup_result.device_id,
|
||||
)
|
||||
|
||||
79
tests/replication/test_module_cache_invalidation.py
Normal file
79
tests/replication/test_module_cache_invalidation.py
Normal file
@@ -0,0 +1,79 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import synapse
|
||||
from synapse.module_api import cached
|
||||
|
||||
from tests.replication._base import BaseMultiWorkerStreamTestCase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
FIRST_VALUE = "one"
|
||||
SECOND_VALUE = "two"
|
||||
|
||||
KEY = "mykey"
|
||||
|
||||
|
||||
class TestCache:
|
||||
current_value = FIRST_VALUE
|
||||
|
||||
@cached()
|
||||
async def cached_function(self, user_id: str) -> str:
|
||||
return self.current_value
|
||||
|
||||
|
||||
class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets,
|
||||
]
|
||||
|
||||
def test_module_cache_full_invalidation(self):
|
||||
main_cache = TestCache()
|
||||
self.hs.get_module_api().register_cached_function(main_cache.cached_function)
|
||||
|
||||
worker_hs = self.make_worker_hs("synapse.app.generic_worker")
|
||||
|
||||
worker_cache = TestCache()
|
||||
worker_hs.get_module_api().register_cached_function(
|
||||
worker_cache.cached_function
|
||||
)
|
||||
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
self.assertEqual(
|
||||
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||
)
|
||||
|
||||
main_cache.current_value = SECOND_VALUE
|
||||
worker_cache.current_value = SECOND_VALUE
|
||||
# No invalidation yet, should return the cached value on both the main process and the worker
|
||||
self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY)))
|
||||
self.assertEqual(
|
||||
FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||
)
|
||||
|
||||
# Full invalidation on the main process, should be replicated on the worker that
|
||||
# should returned the updated value too
|
||||
self.get_success(
|
||||
self.hs.get_module_api().invalidate_cache(
|
||||
main_cache.cached_function, (KEY,)
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
SECOND_VALUE, self.get_success(main_cache.cached_function(KEY))
|
||||
)
|
||||
self.assertEqual(
|
||||
SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY))
|
||||
)
|
||||
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
token_id = user_dict.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=user_id,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
||||
@@ -2839,7 +2839,7 @@ class PushersRestTestCase(unittest.HomeserverTestCase):
|
||||
token_id = user_tuple.token_id
|
||||
|
||||
self.get_success(
|
||||
self.hs.get_pusherpool().add_pusher(
|
||||
self.hs.get_pusherpool().add_or_update_pusher(
|
||||
user_id=self.other_user,
|
||||
access_token=token_id,
|
||||
kind="http",
|
||||
|
||||
132
tests/rest/client/test_login_token_request.py
Normal file
132
tests/rest/client/test_login_token_request.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login, login_token_request
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
||||
|
||||
class LoginTokenRequestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
login.register_servlets,
|
||||
admin.register_servlets,
|
||||
login_token_request.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
self.hs = self.setup_test_homeserver()
|
||||
self.hs.config.registration.enable_registration = True
|
||||
self.hs.config.registration.registrations_require_3pid = []
|
||||
self.hs.config.registration.auto_join_rooms = []
|
||||
self.hs.config.captcha.enable_registration_captcha = False
|
||||
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
self.user = "user123"
|
||||
self.password = "password"
|
||||
|
||||
def test_disabled(self) -> None:
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=None)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
self.assertEqual(channel.code, 400)
|
||||
|
||||
@override_config({"experimental_features": {"msc3882_enabled": True}})
|
||||
def test_require_auth(self) -> None:
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=None)
|
||||
self.assertEqual(channel.code, 401)
|
||||
|
||||
@override_config({"experimental_features": {"msc3882_enabled": True}})
|
||||
def test_uia_on(self) -> None:
|
||||
user_id = self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
self.assertEqual(channel.code, 401)
|
||||
self.assertIn({"stages": ["m.login.password"]}, channel.json_body["flows"])
|
||||
|
||||
session = channel.json_body["session"]
|
||||
|
||||
uia = {
|
||||
"auth": {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": self.user},
|
||||
"password": self.password,
|
||||
"session": session,
|
||||
},
|
||||
}
|
||||
|
||||
channel = self.make_request("POST", "/login/token", uia, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 300)
|
||||
|
||||
login_token = channel.json_body["login_token"]
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||
|
||||
@override_config(
|
||||
{"experimental_features": {"msc3882_enabled": True, "msc3882_ui_auth": False}}
|
||||
)
|
||||
def test_uia_off(self) -> None:
|
||||
user_id = self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 300)
|
||||
|
||||
login_token = channel.json_body["login_token"]
|
||||
|
||||
channel = self.make_request(
|
||||
"POST",
|
||||
"/login",
|
||||
content={"type": "m.login.token", "token": login_token},
|
||||
)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
self.assertEqual(channel.json_body["user_id"], user_id)
|
||||
|
||||
@override_config(
|
||||
{
|
||||
"experimental_features": {
|
||||
"msc3882_enabled": True,
|
||||
"msc3882_ui_auth": False,
|
||||
"msc3882_token_timeout": "15s",
|
||||
}
|
||||
}
|
||||
)
|
||||
def test_expires_in(self) -> None:
|
||||
self.register_user(self.user, self.password)
|
||||
token = self.login(self.user, self.password)
|
||||
|
||||
channel = self.make_request("POST", "/login/token", {}, access_token=token)
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["expires_in"], 15)
|
||||
@@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
|
||||
channel.json_body["chunk"][0],
|
||||
)
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
|
||||
def test_repeated_paginate_relations(self) -> None:
|
||||
"""Test that if we paginate using a limit and tokens then we get the
|
||||
expected events.
|
||||
@@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
|
||||
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
@@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
|
||||
found_event_ids.reverse()
|
||||
self.assertEqual(found_event_ids, expected_event_ids)
|
||||
|
||||
# Test forward pagination.
|
||||
prev_token = ""
|
||||
found_event_ids = []
|
||||
for _ in range(20):
|
||||
from_token = ""
|
||||
if prev_token:
|
||||
from_token = "&from=" + prev_token
|
||||
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}",
|
||||
access_token=self.user_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
|
||||
found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
|
||||
next_batch = channel.json_body.get("next_batch")
|
||||
|
||||
self.assertNotEqual(prev_token, next_batch)
|
||||
prev_token = next_batch
|
||||
|
||||
if not prev_token:
|
||||
break
|
||||
|
||||
self.assertEqual(found_event_ids, expected_event_ids)
|
||||
|
||||
def test_pagination_from_sync_and_messages(self) -> None:
|
||||
"""Pagination tokens from /sync and /messages can be used to paginate /relations."""
|
||||
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
|
||||
|
||||
@@ -103,6 +103,11 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 0)
|
||||
|
||||
def test_persisting_event_invalidates_cache(self):
|
||||
"""
|
||||
Test to make sure that the `have_seen_event` cache
|
||||
is invalidated after we persist an event and returns
|
||||
the updated value.
|
||||
"""
|
||||
event, event_context = self.get_success(
|
||||
create_event(
|
||||
self.hs,
|
||||
@@ -145,6 +150,33 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
|
||||
# That should result in a single db query to lookup
|
||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
|
||||
|
||||
def test_invalidate_cache_by_room_id(self):
|
||||
"""
|
||||
Test to make sure that all events associated with the given `(room_id,)`
|
||||
are invalidated in the `have_seen_event` cache.
|
||||
"""
|
||||
with LoggingContext(name="test") as ctx:
|
||||
# Prime the cache with some values
|
||||
res = self.get_success(
|
||||
self.store.have_seen_events(self.room_id, self.event_ids)
|
||||
)
|
||||
self.assertEqual(res, set(self.event_ids))
|
||||
|
||||
# That should result in a single db query to lookup
|
||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
|
||||
|
||||
# Clear the cache with any events associated with the `room_id`
|
||||
self.store.have_seen_event.invalidate((self.room_id,))
|
||||
|
||||
with LoggingContext(name="test") as ctx:
|
||||
res = self.get_success(
|
||||
self.store.have_seen_events(self.room_id, self.event_ids)
|
||||
)
|
||||
self.assertEqual(res, set(self.event_ids))
|
||||
|
||||
# Since we cleared the cache, it should result in another db query to lookup
|
||||
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
|
||||
|
||||
|
||||
class EventCacheTestCase(unittest.HomeserverTestCase):
|
||||
"""Test that the various layers of event cache works."""
|
||||
|
||||
@@ -300,47 +300,31 @@ class HomeserverTestCase(TestCase):
|
||||
if hasattr(self, "user_id"):
|
||||
if self.hijack_auth:
|
||||
assert self.helper.auth_user_id is not None
|
||||
token = "some_fake_token"
|
||||
|
||||
# We need a valid token ID to satisfy foreign key constraints.
|
||||
token_id = self.get_success(
|
||||
self.hs.get_datastores().main.add_access_token_to_user(
|
||||
self.helper.auth_user_id,
|
||||
"some_fake_token",
|
||||
token,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
)
|
||||
|
||||
async def get_user_by_access_token(
|
||||
token: Optional[str] = None, allow_guest: bool = False
|
||||
) -> JsonDict:
|
||||
assert self.helper.auth_user_id is not None
|
||||
return {
|
||||
"user": UserID.from_string(self.helper.auth_user_id),
|
||||
"token_id": token_id,
|
||||
"is_guest": False,
|
||||
}
|
||||
|
||||
async def get_user_by_req(
|
||||
request: SynapseRequest,
|
||||
allow_guest: bool = False,
|
||||
allow_expired: bool = False,
|
||||
) -> Requester:
|
||||
# This has to be a function and not just a Mock, because
|
||||
# `self.helper.auth_user_id` is temporarily reassigned in some tests
|
||||
async def get_requester(*args, **kwargs) -> Requester:
|
||||
assert self.helper.auth_user_id is not None
|
||||
return create_requester(
|
||||
UserID.from_string(self.helper.auth_user_id),
|
||||
token_id,
|
||||
False,
|
||||
False,
|
||||
None,
|
||||
user_id=UserID.from_string(self.helper.auth_user_id),
|
||||
access_token_id=token_id,
|
||||
)
|
||||
|
||||
# Type ignore: mypy doesn't like us assigning to methods.
|
||||
self.hs.get_auth().get_user_by_req = get_user_by_req # type: ignore[assignment]
|
||||
self.hs.get_auth().get_user_by_access_token = get_user_by_access_token # type: ignore[assignment]
|
||||
self.hs.get_auth().get_access_token_from_request = Mock( # type: ignore[assignment]
|
||||
return_value="1234"
|
||||
)
|
||||
self.hs.get_auth().get_user_by_req = get_requester # type: ignore[assignment]
|
||||
self.hs.get_auth().get_user_by_access_token = get_requester # type: ignore[assignment]
|
||||
self.hs.get_auth().get_access_token_from_request = Mock(return_value=token) # type: ignore[assignment]
|
||||
|
||||
if self.needs_threadpool:
|
||||
self.reactor.threadpool = ThreadPool() # type: ignore[assignment]
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Set
|
||||
from typing import Iterable, Set, Tuple
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
@@ -1008,3 +1008,34 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
obj.inner_context_was_finished, "Tried to restart a finished logcontext"
|
||||
)
|
||||
self.assertEqual(current_context(), SENTINEL_CONTEXT)
|
||||
|
||||
def test_num_args_mismatch(self):
|
||||
"""
|
||||
Make sure someone does not accidentally use @cachedList on a method with
|
||||
a mismatch in the number args to the underlying single cache method.
|
||||
"""
|
||||
|
||||
class Cls:
|
||||
@descriptors.cached(tree=True)
|
||||
def fn(self, room_id, event_id):
|
||||
pass
|
||||
|
||||
# This is wrong ❌. `@cachedList` expects to be given the same number
|
||||
# of arguments as the underlying cached function, just with one of
|
||||
# the arguments being an iterable
|
||||
@descriptors.cachedList(cached_method_name="fn", list_name="keys")
|
||||
def list_fn(self, keys: Iterable[Tuple[str, str]]):
|
||||
pass
|
||||
|
||||
# Corrected syntax ✅
|
||||
#
|
||||
# @cachedList(cached_method_name="fn", list_name="event_ids")
|
||||
# async def list_fn(
|
||||
# self, room_id: str, event_ids: Collection[str],
|
||||
# )
|
||||
|
||||
obj = Cls()
|
||||
|
||||
# Make sure this raises an error about the arg mismatch
|
||||
with self.assertRaises(Exception):
|
||||
obj.list_fn([("foo", "bar")])
|
||||
|
||||
Reference in New Issue
Block a user