Compare commits

...

12 Commits

Author SHA1 Message Date
Andrew Morgan 6452c704ce changelog 2024-05-17 10:45:20 +01:00
Andrew Morgan 0abc0cbe52 Remove v2 variant from endpoints 2024-05-15 11:59:15 +01:00
Dominic Schubert 336a85d1fa Knocking Endpoints added (federated)
Knocking endpoints were missing for the federated worker
2024-04-07 01:08:57 +02:00
Erik Johnston 5360baeb64 Pull out fewer receipts from DB when doing push (#17049)
Before we were pulling out *all* read receipts for a user for every
event we pushed. Instead let's only pull out the relevant receipts.

This also pulled out the event rows for each receipt, causing load on
the events table.
2024-04-05 12:46:34 +01:00
Richard van der Hoff 0e68e9b7f4 Fix bug in calculating state for non-gappy syncs (#16942)
Unfortunately, the optimisation we applied here for non-gappy syncs is
not actually valid.

Fixes https://github.com/element-hq/synapse/issues/16941.

~~Based on https://github.com/element-hq/synapse/pull/16930.~~
Requires https://github.com/matrix-org/sytest/pull/1374.
2024-04-04 16:15:35 +00:00
Richard van der Hoff 230b709d9d /sync: fix bug in calculating state response (#16930)
Fix a long-standing issue which could cause state to be omitted from the
sync response if the last event was filtered out.

Fixes: https://github.com/element-hq/synapse/issues/16928
2024-04-04 12:14:24 +00:00
Richard van der Hoff 05957ac70f Fix bug in /sync response for archived rooms (#16932)
This PR fixes a very, very niche edge-case, but I've got some more work
coming which will otherwise make the problem worse.

The bug happens when the syncing user leaves a room, and has a sync
filter which includes "left" rooms, but sets the timeline limit to 0. In
that case, the state returned in the `state` section is calculated
incorrectly.

The fix is to pass a token corresponding to the point that the user
leaves the room through to `compute_state_delta`.
2024-04-04 12:47:59 +01:00
Erik Johnston 31122b71bc Add missing index to access_tokens table (#17045)
This was causing sequential scans when using refresh tokens.
2024-04-04 11:05:40 +01:00
Erik Johnston 51776745b9 Merge branch 'master' into develop 2024-04-02 18:44:47 +01:00
Erik Johnston ec174d0470 Refactor chain fetching (#17044)
Since these queries are duplicated in two places.
2024-04-02 15:33:56 +01:00
Erik Johnston fd48fc4585 Fixups to new push stream (#17038)
Follow on from #17037
2024-03-28 16:29:23 +00:00
Erik Johnston ea6bfae0fc Add support for moving /push_rules off of main process (#17037) 2024-03-28 15:44:07 +00:00
25 changed files with 861 additions and 268 deletions
+1
View File
@@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.
+1
View File
@@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.
+1
View File
@@ -0,0 +1 @@
Fix various long-standing bugs which could cause incorrect state to be returned from `/sync` in certain situations.
+1
View File
@@ -0,0 +1 @@
Add support for moving `/pushrules` off of main process.
+1
View File
@@ -0,0 +1 @@
Add support for moving `/pushrules` off of main process.
+1
View File
@@ -0,0 +1 @@
Refactor auth chain fetching to reduce duplication.
+1
View File
@@ -0,0 +1 @@
Improve database performance by adding a missing index to `access_tokens.refresh_token_id`.
+1
View File
@@ -0,0 +1 @@
Improve database performance by reducing number of receipts fetched when sending push notifications.
+1
View File
@@ -0,0 +1 @@
Document [`/v1/make_knock`](https://spec.matrix.org/v1.10/server-server-api/#get_matrixfederationv1make_knockroomiduserid) and [`/v1/send_knock/`](https://spec.matrix.org/v1.10/server-server-api/#put_matrixfederationv1send_knockroomideventid) federation endpoints as worker-compatible.
+8
View File
@@ -310,6 +310,13 @@ WORKERS_CONFIG: Dict[str, Dict[str, Any]] = {
"shared_extra_conf": {}, "shared_extra_conf": {},
"worker_extra_conf": "", "worker_extra_conf": "",
}, },
"push_rules": {
"app": "synapse.app.generic_worker",
"listener_resources": ["client", "replication"],
"endpoint_patterns": ["^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/"],
"shared_extra_conf": {},
"worker_extra_conf": "",
},
} }
# Templates for sections that may be inserted multiple times in config files # Templates for sections that may be inserted multiple times in config files
@@ -401,6 +408,7 @@ def add_worker_roles_to_shared_config(
"receipts", "receipts",
"to_device", "to_device",
"typing", "typing",
"push_rules",
] ]
# Worker-type specific sharding config. Now a single worker can fulfill multiple # Worker-type specific sharding config. Now a single worker can fulfill multiple
+9
View File
@@ -211,6 +211,8 @@ information.
^/_matrix/federation/v1/make_leave/ ^/_matrix/federation/v1/make_leave/
^/_matrix/federation/(v1|v2)/send_join/ ^/_matrix/federation/(v1|v2)/send_join/
^/_matrix/federation/(v1|v2)/send_leave/ ^/_matrix/federation/(v1|v2)/send_leave/
^/_matrix/federation/v1/make_knock/
^/_matrix/federation/v1/send_knock/
^/_matrix/federation/(v1|v2)/invite/ ^/_matrix/federation/(v1|v2)/invite/
^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/timestamp_to_event/ ^/_matrix/federation/v1/timestamp_to_event/
@@ -532,6 +534,13 @@ the stream writer for the `presence` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/ ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
##### The `push_rules` stream
The following endpoints should be routed directly to the worker configured as
the stream writer for the `push_rules` stream:
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
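As a quick sanity check of the routing rule above, the following sketch tests a few request paths against the documented pattern. The helper name `routes_to_push_rules_writer` and the sample paths are illustrative assumptions, not part of Synapse; only the regular expression is taken from the documentation entry.

```python
import re

# Pattern copied from the worker documentation entry for the `push_rules` stream.
PUSH_RULES_PATTERN = re.compile(
    r"^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/"
)


def routes_to_push_rules_writer(path: str) -> bool:
    """Return True if a proxy using the pattern would send `path` to the writer.

    Hypothetical helper for illustration; a real deployment would express
    this rule in the reverse proxy configuration instead.
    """
    return PUSH_RULES_PATTERN.match(path) is not None


if __name__ == "__main__":
    # Sample paths are made up for the example.
    for path in (
        "/_matrix/client/v3/pushrules/global/override/.m.rule.master",
        "/_matrix/client/r0/pushrules/",
        "/_matrix/client/v3/pushers",
    ):
        print(path, routes_to_push_rules_writer(path))
```

Note that the pattern is anchored at the start of the path and requires the trailing slash after `pushrules`, so unrelated endpoints such as `/pushers` are not captured.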
#### Restrict outbound federation traffic to a specific set of workers #### Restrict outbound federation traffic to a specific set of workers
The The
+3 -6
View File
@@ -60,7 +60,7 @@ from synapse.logging.context import (
) )
from synapse.notifier import ReplicationNotifier from synapse.notifier import ReplicationNotifier
from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
from synapse.storage.databases.main import FilteringWorkerStore, PushRuleStore from synapse.storage.databases.main import FilteringWorkerStore
from synapse.storage.databases.main.account_data import AccountDataWorkerStore from synapse.storage.databases.main.account_data import AccountDataWorkerStore
from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore from synapse.storage.databases.main.client_ips import ClientIpBackgroundUpdateStore
from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore from synapse.storage.databases.main.deviceinbox import DeviceInboxBackgroundUpdateStore
@@ -77,10 +77,8 @@ from synapse.storage.databases.main.media_repository import (
) )
from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore from synapse.storage.databases.main.presence import PresenceBackgroundUpdateStore
from synapse.storage.databases.main.profile import ProfileWorkerStore from synapse.storage.databases.main.profile import ProfileWorkerStore
from synapse.storage.databases.main.pusher import ( from synapse.storage.databases.main.push_rule import PusherWorkerStore
PusherBackgroundUpdatesStore, from synapse.storage.databases.main.pusher import PusherBackgroundUpdatesStore
PusherWorkerStore,
)
from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore from synapse.storage.databases.main.receipts import ReceiptsBackgroundUpdateStore
from synapse.storage.databases.main.registration import ( from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore, RegistrationBackgroundUpdateStore,
@@ -245,7 +243,6 @@ class Store(
AccountDataWorkerStore, AccountDataWorkerStore,
FilteringWorkerStore, FilteringWorkerStore,
ProfileWorkerStore, ProfileWorkerStore,
PushRuleStore,
PusherWorkerStore, PusherWorkerStore,
PusherBackgroundUpdatesStore, PusherBackgroundUpdatesStore,
PresenceBackgroundUpdateStore, PresenceBackgroundUpdateStore,
+12
View File
@@ -156,6 +156,8 @@ class WriterLocations:
can only be a single instance. can only be a single instance.
presence: The instances that write to the presence stream. Currently presence: The instances that write to the presence stream. Currently
can only be a single instance. can only be a single instance.
push_rules: The instances that write to the push_rules stream. Currently
can only be a single instance.
""" """
events: List[str] = attr.ib( events: List[str] = attr.ib(
@@ -182,6 +184,10 @@ class WriterLocations:
default=["master"], default=["master"],
converter=_instance_to_list_converter, converter=_instance_to_list_converter,
) )
push_rules: List[str] = attr.ib(
default=["master"],
converter=_instance_to_list_converter,
)
@attr.s(auto_attribs=True) @attr.s(auto_attribs=True)
@@ -341,6 +347,7 @@ class WorkerConfig(Config):
"account_data", "account_data",
"receipts", "receipts",
"presence", "presence",
"push_rules",
): ):
instances = _instance_to_list_converter(getattr(self.writers, stream)) instances = _instance_to_list_converter(getattr(self.writers, stream))
for instance in instances: for instance in instances:
@@ -378,6 +385,11 @@ class WorkerConfig(Config):
"Must only specify one instance to handle `presence` messages." "Must only specify one instance to handle `presence` messages."
) )
if len(self.writers.push_rules) != 1:
raise ConfigError(
"Must only specify one instance to handle `push_rules` messages."
)
self.events_shard_config = RoutableShardedWorkerHandlingConfig( self.events_shard_config = RoutableShardedWorkerHandlingConfig(
self.writers.events self.writers.events
) )
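The configuration change above follows a common pattern: accept either a single instance name or a list, normalise to a list, then reject anything other than exactly one writer. Below is a minimal standalone sketch of that pattern. It uses `dataclasses` and `ValueError` as stand-ins for the `attr.ib` converter and `ConfigError` in the diff, and the names `WriterLocationsSketch`, `instance_to_list` and `check_single_push_rules_writer` are hypothetical.

```python
from dataclasses import dataclass, field
from typing import List, Union


def instance_to_list(value: Union[str, List[str]]) -> List[str]:
    """Accept a single instance name or a list of names; always return a list."""
    if isinstance(value, str):
        return [value]
    return list(value)


@dataclass
class WriterLocationsSketch:
    """Hypothetical stand-in for the writer locations config, one stream only."""

    # Defaults to the main process, as in the diff above.
    push_rules: List[str] = field(default_factory=lambda: ["master"])

    def __post_init__(self) -> None:
        # Normalise after construction, mirroring what a converter would do.
        self.push_rules = instance_to_list(self.push_rules)


def check_single_push_rules_writer(writers: WriterLocationsSketch) -> None:
    """Raise if the stream is not configured with exactly one writer."""
    if len(writers.push_rules) != 1:
        # ValueError is an assumption here; the diff raises ConfigError.
        raise ValueError("expected exactly one push_rules writer")


if __name__ == "__main__":
    check_single_push_rules_writer(WriterLocationsSketch())
    check_single_push_rules_writer(WriterLocationsSketch(push_rules="worker1"))
```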
+18 -3
View File
@@ -51,6 +51,7 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
from synapse.logging import opentracing from synapse.logging import opentracing
from synapse.metrics import event_processing_positions from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.push import ReplicationCopyPusherRestServlet
from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.databases.main.state_deltas import StateDelta
from synapse.types import ( from synapse.types import (
JsonDict, JsonDict,
@@ -181,6 +182,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
hs.config.server.forgotten_room_retention_period hs.config.server.forgotten_room_retention_period
) )
self._is_push_writer = (
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
self._push_writer = hs.config.worker.writers.push_rules[0]
self._copy_push_client = ReplicationCopyPusherRestServlet.make_client(hs)
def _on_user_joined_room(self, event_id: str, room_id: str) -> None: def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
"""Notify the rate limiter that a room join has occurred. """Notify the rate limiter that a room join has occurred.
@@ -1301,9 +1308,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
old_room_id, new_room_id, user_id old_room_id, new_room_id, user_id
) )
# Copy over push rules # Copy over push rules
await self.store.copy_push_rules_from_room_to_room_for_user( if self._is_push_writer:
old_room_id, new_room_id, user_id await self.store.copy_push_rules_from_room_to_room_for_user(
) old_room_id, new_room_id, user_id
)
else:
await self._copy_push_client(
instance_name=self._push_writer,
user_id=user_id,
old_room_id=old_room_id,
new_room_id=new_room_id,
)
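The branch above writes locally when the current instance is the push rules writer and otherwise forwards the request to the writer over replication. A minimal sketch of that dispatch, with plain async callables standing in for the datastore method and the replication client, might look like the following. The class `PushRuleCopier` and its constructor arguments are hypothetical and exist only for this illustration.

```python
import asyncio
from typing import Awaitable, Callable, List


class PushRuleCopier:
    """Hypothetical helper: copy push rules locally or via the stream writer."""

    def __init__(
        self,
        instance_name: str,
        push_rules_writers: List[str],
        local_copy: Callable[[str, str, str], Awaitable[None]],
        remote_copy: Callable[..., Awaitable[None]],
    ) -> None:
        # Only the configured writer may touch the push rules tables directly.
        self._is_push_writer = instance_name in push_rules_writers
        self._push_writer = push_rules_writers[0]
        self._local_copy = local_copy
        self._remote_copy = remote_copy

    async def copy(self, user_id: str, old_room_id: str, new_room_id: str) -> None:
        if self._is_push_writer:
            # Same argument order as the datastore call in the diff.
            await self._local_copy(old_room_id, new_room_id, user_id)
        else:
            # Ask the writer instance to perform the copy on our behalf.
            await self._remote_copy(
                instance_name=self._push_writer,
                user_id=user_id,
                old_room_id=old_room_id,
                new_room_id=new_room_id,
            )


if __name__ == "__main__":

    async def fake_local(old_room_id: str, new_room_id: str, user_id: str) -> None:
        print("local copy", old_room_id, new_room_id, user_id)

    async def fake_remote(**kwargs: str) -> None:
        print("remote copy", kwargs)

    copier = PushRuleCopier("worker1", ["master"], fake_local, fake_remote)
    asyncio.run(copier.copy("@alice:example.org", "!old:example.org", "!new:example.org"))
```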
except Exception: except Exception:
logger.exception( logger.exception(
"Error copying tags and/or push rules from rooms %s to %s for user %s. " "Error copying tags and/or push rules from rooms %s to %s for user %s. "
+138 -90
View File
@@ -953,7 +953,7 @@ class SyncHandler:
batch: TimelineBatch, batch: TimelineBatch,
sync_config: SyncConfig, sync_config: SyncConfig,
since_token: Optional[StreamToken], since_token: Optional[StreamToken],
now_token: StreamToken, end_token: StreamToken,
full_state: bool, full_state: bool,
) -> MutableStateMap[EventBase]: ) -> MutableStateMap[EventBase]:
"""Works out the difference in state between the end of the previous sync and """Works out the difference in state between the end of the previous sync and
@@ -964,7 +964,9 @@ class SyncHandler:
batch: The timeline batch for the room that will be sent to the user. batch: The timeline batch for the room that will be sent to the user.
sync_config: sync_config:
since_token: Token of the end of the previous batch. May be `None`. since_token: Token of the end of the previous batch. May be `None`.
now_token: Token of the end of the current batch. end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
full_state: Whether to force returning the full state. full_state: Whether to force returning the full state.
`lazy_load_members` still applies when `full_state` is `True`. `lazy_load_members` still applies when `full_state` is `True`.
@@ -1044,7 +1046,7 @@ class SyncHandler:
room_id, room_id,
sync_config.user, sync_config.user,
batch, batch,
now_token, end_token,
members_to_fetch, members_to_fetch,
timeline_state, timeline_state,
) )
@@ -1058,7 +1060,7 @@ class SyncHandler:
room_id, room_id,
batch, batch,
since_token, since_token,
now_token, end_token,
members_to_fetch, members_to_fetch,
timeline_state, timeline_state,
) )
@@ -1130,7 +1132,7 @@ class SyncHandler:
room_id: str, room_id: str,
syncing_user: UserID, syncing_user: UserID,
batch: TimelineBatch, batch: TimelineBatch,
now_token: StreamToken, end_token: StreamToken,
members_to_fetch: Optional[Set[str]], members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str], timeline_state: StateMap[str],
) -> StateMap[str]: ) -> StateMap[str]:
@@ -1143,7 +1145,9 @@ class SyncHandler:
room_id: The room we are calculating for. room_id: The room we are calculating for.
syncing_user: The user that is calling `/sync`. syncing_user: The user that is calling `/sync`.
batch: The timeline batch for the room that will be sent to the user. batch: The timeline batch for the room that will be sent to the user.
now_token: Token of the end of the current batch. end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
members_to_fetch: If lazy-loading is enabled, the memberships needed for members_to_fetch: If lazy-loading is enabled, the memberships needed for
events in the timeline. events in the timeline.
timeline_state: The contribution to the room state from state events in timeline_state: The contribution to the room state from state events in
@@ -1183,15 +1187,16 @@ class SyncHandler:
await_full_state = True await_full_state = True
lazy_load_members = False lazy_load_members = False
if batch: state_at_timeline_end = await self.get_state_at(
state_at_timeline_end = ( room_id,
await self._state_storage_controller.get_state_ids_for_event( stream_position=end_token,
batch.events[-1].event_id, state_filter=state_filter,
state_filter=state_filter, await_full_state=await_full_state,
await_full_state=await_full_state, )
)
)
if batch:
# Strictly speaking, this returns the state *after* the first event in the
# timeline, but that is good enough here.
state_at_timeline_start = ( state_at_timeline_start = (
await self._state_storage_controller.get_state_ids_for_event( await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id, batch.events[0].event_id,
@@ -1200,13 +1205,6 @@ class SyncHandler:
) )
) )
else: else:
state_at_timeline_end = await self.get_state_at(
room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
state_at_timeline_start = state_at_timeline_end state_at_timeline_start = state_at_timeline_end
state_ids = _calculate_state( state_ids = _calculate_state(
@@ -1223,7 +1221,7 @@ class SyncHandler:
room_id: str, room_id: str,
batch: TimelineBatch, batch: TimelineBatch,
since_token: StreamToken, since_token: StreamToken,
now_token: StreamToken, end_token: StreamToken,
members_to_fetch: Optional[Set[str]], members_to_fetch: Optional[Set[str]],
timeline_state: StateMap[str], timeline_state: StateMap[str],
) -> StateMap[str]: ) -> StateMap[str]:
@@ -1239,7 +1237,9 @@ class SyncHandler:
room_id: The room we are calculating for. room_id: The room we are calculating for.
batch: The timeline batch for the room that will be sent to the user. batch: The timeline batch for the room that will be sent to the user.
since_token: Token of the end of the previous batch. since_token: Token of the end of the previous batch.
now_token: Token of the end of the current batch. end_token: Token of the end of the current batch. Normally this will be
the same as the global "now_token", but if the user has left the room,
the point just after their leave event.
members_to_fetch: If lazy-loading is enabled, the memberships needed for members_to_fetch: If lazy-loading is enabled, the memberships needed for
events in the timeline. Otherwise, `None`. events in the timeline. Otherwise, `None`.
timeline_state: The contribution to the room state from state events in timeline_state: The contribution to the room state from state events in
@@ -1259,25 +1259,25 @@ class SyncHandler:
await_full_state = True await_full_state = True
lazy_load_members = False lazy_load_members = False
if batch.limited: if batch:
if batch: state_at_timeline_start = (
state_at_timeline_start = ( await self._state_storage_controller.get_state_ids_for_event(
await self._state_storage_controller.get_state_ids_for_event( batch.events[0].event_id,
batch.events[0].event_id,
state_filter=state_filter,
await_full_state=await_full_state,
)
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = await self.get_state_at(
room_id,
stream_position=now_token,
state_filter=state_filter, state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_start = await self.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
if batch.limited:
# for now, we disable LL for gappy syncs - see # for now, we disable LL for gappy syncs - see
# https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346 # https://github.com/vector-im/riot-web/issues/7211#issuecomment-419976346
# N.B. this slows down incr syncs as we are now processing way # N.B. this slows down incr syncs as we are now processing way
@@ -1292,58 +1292,28 @@ class SyncHandler:
# about them). # about them).
state_filter = StateFilter.all() state_filter = StateFilter.all()
state_at_previous_sync = await self.get_state_at( state_at_previous_sync = await self.get_state_at(
room_id, room_id,
stream_position=since_token, stream_position=since_token,
state_filter=state_filter, state_filter=state_filter,
await_full_state=await_full_state, await_full_state=await_full_state,
) )
if batch: state_at_timeline_end = await self.get_state_at(
state_at_timeline_end = ( room_id,
await self._state_storage_controller.get_state_ids_for_event( stream_position=end_token,
batch.events[-1].event_id, state_filter=state_filter,
state_filter=state_filter, await_full_state=await_full_state,
await_full_state=await_full_state, )
)
)
else:
# We can get here if the user has ignored the senders of all
# the recent events.
state_at_timeline_end = await self.get_state_at(
room_id,
stream_position=now_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
state_ids = _calculate_state( state_ids = _calculate_state(
timeline_contains=timeline_state, timeline_contains=timeline_state,
timeline_start=state_at_timeline_start, timeline_start=state_at_timeline_start,
timeline_end=state_at_timeline_end, timeline_end=state_at_timeline_end,
previous_timeline_end=state_at_previous_sync, previous_timeline_end=state_at_previous_sync,
lazy_load_members=lazy_load_members, lazy_load_members=lazy_load_members,
) )
else:
state_ids = {}
if lazy_load_members:
if members_to_fetch and batch.events:
# We're returning an incremental sync, with no
# "gap" since the previous sync, so normally there would be
# no state to return.
# But we're lazy-loading, so the client might need some more
# member events to understand the events in this timeline.
# So we fish out all the member events corresponding to the
# timeline here. The caller will then dedupe any redundant ones.
state_ids = await self._state_storage_controller.get_state_ids_for_event(
batch.events[0].event_id,
# we only want members!
state_filter=StateFilter.from_types(
(EventTypes.Member, member) for member in members_to_fetch
),
await_full_state=False,
)
return state_ids return state_ids
async def _find_missing_partial_state_memberships( async def _find_missing_partial_state_memberships(
@@ -2344,6 +2314,7 @@ class SyncHandler:
full_state=False, full_state=False,
since_token=since_token, since_token=since_token,
upto_token=leave_token, upto_token=leave_token,
end_token=leave_token,
out_of_band=leave_event.internal_metadata.is_out_of_band_membership(), out_of_band=leave_event.internal_metadata.is_out_of_band_membership(),
) )
) )
@@ -2381,6 +2352,7 @@ class SyncHandler:
full_state=False, full_state=False,
since_token=None if newly_joined else since_token, since_token=None if newly_joined else since_token,
upto_token=prev_batch_token, upto_token=prev_batch_token,
end_token=now_token,
) )
else: else:
entry = RoomSyncResultBuilder( entry = RoomSyncResultBuilder(
@@ -2391,6 +2363,7 @@ class SyncHandler:
full_state=False, full_state=False,
since_token=since_token, since_token=since_token,
upto_token=since_token, upto_token=since_token,
end_token=now_token,
) )
room_entries.append(entry) room_entries.append(entry)
@@ -2449,6 +2422,7 @@ class SyncHandler:
full_state=True, full_state=True,
since_token=since_token, since_token=since_token,
upto_token=now_token, upto_token=now_token,
end_token=now_token,
) )
) )
elif event.membership == Membership.INVITE: elif event.membership == Membership.INVITE:
@@ -2478,6 +2452,7 @@ class SyncHandler:
full_state=True, full_state=True,
since_token=since_token, since_token=since_token,
upto_token=leave_token, upto_token=leave_token,
end_token=leave_token,
) )
) )
@@ -2548,6 +2523,7 @@ class SyncHandler:
{ {
"since_token": since_token, "since_token": since_token,
"upto_token": upto_token, "upto_token": upto_token,
"end_token": room_builder.end_token,
} }
) )
@@ -2621,7 +2597,7 @@ class SyncHandler:
batch, batch,
sync_config, sync_config,
since_token, since_token,
now_token, room_builder.end_token,
full_state=full_state, full_state=full_state,
) )
else: else:
@@ -2781,6 +2757,61 @@ def _calculate_state(
e for t, e in timeline_start.items() if t[0] == EventTypes.Member e for t, e in timeline_start.items() if t[0] == EventTypes.Member
) )
# Naively, we would just return the difference between the state at the start
# of the timeline (`timeline_start_ids`) and that at the end of the previous sync
# (`previous_timeline_end_ids`). However, that fails in the presence of forks in
# the DAG.
#
# For example, consider a DAG such as the following:
#
# E1
# ↗ ↖
# | S2
# | ↑
# --|------|----
# | |
# E3 |
# ↖ /
# E4
#
# ... and a filter that means we only return 2 events, represented by the dashed
# horizontal line. Assuming S2 was *not* included in the previous sync, we need to
# include it in the `state` section.
#
# Note that the state at the start of the timeline (E3) does not include S2. So,
# to make sure it gets included in the calculation here, we actually look at
# the state at the *end* of the timeline, and subtract any events that are present
# in the timeline.
#
# ----------
#
# Aside 1: You may then wonder if we need to include `timeline_start` in the
# calculation. Consider a linear DAG:
#
# E1
# ↑
# S2
# ↑
# ----|------
# |
# E3
# ↑
# S4
# ↑
# E5
#
# ... where S2 and S4 change the same piece of state; and where we have a filter
# that returns 3 events (E3, S4, E5). We still need to tell the client about S2,
# because it might affect the display of E3. However, the state at the end of the
# timeline only tells us about S4; if we don't inspect `timeline_start` we won't
# find out about S2.
#
# (There are yet more complicated cases in which a state event is excluded from the
# timeline, but whose effect actually lands in the DAG in the *middle* of the
# timeline. We have no way to represent that in the /sync response, and we don't
# even try; it is either omitted or plonked into `state` as if it were at the start
# of the timeline, depending on what else is in the timeline.)
state_ids = ( state_ids = (
(timeline_end_ids | timeline_start_ids) (timeline_end_ids | timeline_start_ids)
- previous_timeline_end_ids - previous_timeline_end_ids
@@ -2883,13 +2914,30 @@ class RoomSyncResultBuilder:
Attributes: Attributes:
room_id room_id
rtype: One of `"joined"` or `"archived"` rtype: One of `"joined"` or `"archived"`
events: List of events to include in the room (more events may be added events: List of events to include in the room (more events may be added
when generating result). when generating result).
newly_joined: If the user has newly joined the room newly_joined: If the user has newly joined the room
full_state: Whether the full state should be sent in result full_state: Whether the full state should be sent in result
since_token: Earliest point to return events from, or None since_token: Earliest point to return events from, or None
upto_token: Latest point to return events from.
upto_token: Latest point to return events from. If `events` is populated,
this is set to the token at the start of `events`
end_token: The last point in the timeline that the client should see events
from. Normally this will be the same as the global `now_token`, but in
the case of rooms where the user has left the room, this will be the point
just after their leave event.
This is used in the calculation of the state which is returned in `state`:
any state changes *up to* `end_token` (and not beyond!) which are not
reflected in the timeline need to be returned in `state`.
out_of_band: whether the events in the room are "out of band" events out_of_band: whether the events in the room are "out of band" events
and the server isn't in the room. and the server isn't in the room.
""" """
@@ -2901,5 +2949,5 @@ class RoomSyncResultBuilder:
full_state: bool full_state: bool
since_token: Optional[StreamToken] since_token: Optional[StreamToken]
upto_token: StreamToken upto_token: StreamToken
end_token: StreamToken
out_of_band: bool = False out_of_band: bool = False
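The long comment added to `_calculate_state` can be checked with plain set arithmetic. The sketch below models each state snapshot as a set of event IDs and reproduces the two DAGs from the comment. It covers only the case without lazy-loading, and the function name `state_delta_ids` and the event labels are assumptions made for the example.

```python
from typing import Set


def state_delta_ids(
    timeline_start: Set[str],
    timeline_end: Set[str],
    previous_timeline_end: Set[str],
    timeline_contains: Set[str],
) -> Set[str]:
    """State event IDs to return in `state` (no lazy-loading).

    Take everything known at either end of the timeline, then drop what the
    client already saw in the previous sync and what it will see in the
    timeline itself.
    """
    return (timeline_end | timeline_start) - previous_timeline_end - timeline_contains


if __name__ == "__main__":
    # Forked DAG: the timeline is (E3, E4); S2 sits on the other branch.
    # The state after E3 lacks S2, but the state after E4 includes it.
    forked = state_delta_ids(
        timeline_start=set(),
        timeline_end={"S2"},
        previous_timeline_end=set(),
        timeline_contains=set(),
    )
    print(forked)

    # Linear DAG: the timeline is (E3, S4, E5); S2 precedes it and S4
    # replaces it. Only `timeline_start` still remembers S2.
    linear = state_delta_ids(
        timeline_start={"S2"},
        timeline_end={"S4"},
        previous_timeline_end=set(),
        timeline_contains={"S4"},
    )
    print(linear)
```

In both cases the result is the single event S2. Using only the state at the start of the timeline would miss it in the forked case, and using only the state at the end would miss it in the linear case.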
+41
View File
@@ -77,5 +77,46 @@ class ReplicationRemovePusherRestServlet(ReplicationEndpoint):
return 200, {} return 200, {}
class ReplicationCopyPusherRestServlet(ReplicationEndpoint):
"""Copies push rules from an old room to new room.
Request format:
POST /_synapse/replication/copy_push_rules/:user_id/:old_room_id/:new_room_id
{}
"""
NAME = "copy_push_rules"
PATH_ARGS = ("user_id", "old_room_id", "new_room_id")
CACHE = False
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._store = hs.get_datastores().main
@staticmethod
async def _serialize_payload(user_id: str, old_room_id: str, new_room_id: str) -> JsonDict: # type: ignore[override]
return {}
async def _handle_request( # type: ignore[override]
self,
request: Request,
content: JsonDict,
user_id: str,
old_room_id: str,
new_room_id: str,
) -> Tuple[int, JsonDict]:
await self._store.copy_push_rules_from_room_to_room_for_user(
old_room_id, new_room_id, user_id
)
return 200, {}
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
ReplicationRemovePusherRestServlet(hs).register(http_server) ReplicationRemovePusherRestServlet(hs).register(http_server)
ReplicationCopyPusherRestServlet(hs).register(http_server)
+7
View File
@@ -66,6 +66,7 @@ from synapse.replication.tcp.streams import (
FederationStream, FederationStream,
PresenceFederationStream, PresenceFederationStream,
PresenceStream, PresenceStream,
PushRulesStream,
ReceiptsStream, ReceiptsStream,
Stream, Stream,
ToDeviceStream, ToDeviceStream,
@@ -178,6 +179,12 @@ class ReplicationCommandHandler:
continue continue
if isinstance(stream, PushRulesStream):
if hs.get_instance_name() in hs.config.worker.writers.push_rules:
self._streams_to_replicate.append(stream)
continue
# Only add any other streams if we're on master. # Only add any other streams if we're on master.
if hs.config.worker.worker_app is not None: if hs.config.worker.worker_app is not None:
continue continue
+5 -3
View File
@@ -59,12 +59,14 @@ class PushRuleRestServlet(RestServlet):
self.auth = hs.get_auth() self.auth = hs.get_auth()
self.store = hs.get_datastores().main self.store = hs.get_datastores().main
self.notifier = hs.get_notifier() self.notifier = hs.get_notifier()
self._is_worker = hs.config.worker.worker_app is not None self._is_push_worker = (
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
self._push_rules_handler = hs.get_push_rules_handler() self._push_rules_handler = hs.get_push_rules_handler()
self._push_rule_linearizer = Linearizer(name="push_rules") self._push_rule_linearizer = Linearizer(name="push_rules")
async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]: async def on_PUT(self, request: SynapseRequest, path: str) -> Tuple[int, JsonDict]:
if self._is_worker: if not self._is_push_worker:
raise Exception("Cannot handle PUT /push_rules on worker") raise Exception("Cannot handle PUT /push_rules on worker")
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@@ -137,7 +139,7 @@ class PushRuleRestServlet(RestServlet):
async def on_DELETE( async def on_DELETE(
self, request: SynapseRequest, path: str self, request: SynapseRequest, path: str
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
if self._is_worker: if not self._is_push_worker:
raise Exception("Cannot handle DELETE /push_rules on worker") raise Exception("Cannot handle DELETE /push_rules on worker")
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
+2 -2
View File
@@ -63,7 +63,7 @@ from .openid import OpenIdStore
from .presence import PresenceStore from .presence import PresenceStore
from .profile import ProfileStore from .profile import ProfileStore
from .purge_events import PurgeEventsStore from .purge_events import PurgeEventsStore
from .push_rule import PushRuleStore from .push_rule import PushRulesWorkerStore
from .pusher import PusherStore from .pusher import PusherStore
from .receipts import ReceiptsStore from .receipts import ReceiptsStore
from .registration import RegistrationStore from .registration import RegistrationStore
@@ -130,7 +130,6 @@ class DataStore(
RejectionsStore, RejectionsStore,
FilteringWorkerStore, FilteringWorkerStore,
PusherStore, PusherStore,
PushRuleStore,
ApplicationServiceTransactionStore, ApplicationServiceTransactionStore,
EventPushActionsStore, EventPushActionsStore,
ServerMetricsStore, ServerMetricsStore,
@@ -140,6 +139,7 @@ class DataStore(
SearchStore, SearchStore,
TagsStore, TagsStore,
AccountDataStore, AccountDataStore,
PushRulesWorkerStore,
StreamWorkerStore, StreamWorkerStore,
OpenIdStore, OpenIdStore,
ClientIpWorkerStore, ClientIpWorkerStore,
@@ -27,6 +27,7 @@ from typing import (
Collection, Collection,
Dict, Dict,
FrozenSet, FrozenSet,
Generator,
Iterable, Iterable,
List, List,
Optional, Optional,
@@ -279,64 +280,16 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that # Now we look up all links for the chains we have, adding chains that
# are reachable from any event. # are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# A map from chain ID to max sequence number *reachable* from any event ID. # A map from chain ID to max sequence number *reachable* from any event ID.
chains: Dict[int, int] = {} chains: Dict[int, int] = {}
for links in self._get_chain_links(txn, set(event_chains.keys())):
# Add all linked chains reachable from initial set of chains.
chains_to_fetch = set(event_chains.keys())
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
for chain_id in links: for chain_id in links:
if chain_id not in event_chains: if chain_id not in event_chains:
continue continue
_materialize(chain_id, event_chains[chain_id], links, chains) _materialize(chain_id, event_chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
# Add the initial set of chains, excluding the sequence corresponding to # Add the initial set of chains, excluding the sequence corresponding to
# initial event. # initial event.
for chain_id, seq_no in event_chains.items(): for chain_id, seq_no in event_chains.items():
@@ -380,6 +333,68 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results return results
@classmethod
def _get_chain_links(
cls, txn: LoggingTransaction, chains_to_fetch: Set[int]
) -> Generator[Dict[int, List[Tuple[int, int, int]]], None, None]:
"""Fetch all auth chain links from the given set of chains, and all
links from those chains, recursively.
Note: This may return links that are not reachable from the given
chains.
Returns a generator that produces dicts from origin chain ID to 3-tuple
of origin sequence number, target chain ID and target sequence number.
"""
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
chains_to_fetch.difference_update(batch2)
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
chains_to_fetch.difference_update(links)
yield links
def _get_auth_chain_ids_txn( def _get_auth_chain_ids_txn(
self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
) -> Set[str]: ) -> Set[str]:
@@ -564,53 +579,9 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# Now we look up all links for the chains we have, adding chains that # Now we look up all links for the chains we have, adding chains that
# are reachable from any event. # are reachable from any event.
#
# This query is structured to first get all chain IDs reachable, and
# then pull out all links from those chains. This does pull out more
# rows than is strictly necessary, however there isn't a way of
# structuring the recursive part of query to pull out the links without
# also returning large quantities of redundant data (which can make it a
# lot slower).
sql = """
WITH RECURSIVE links(chain_id) AS (
SELECT
DISTINCT origin_chain_id
FROM event_auth_chain_links WHERE %s
UNION
SELECT
target_chain_id
FROM event_auth_chain_links
INNER JOIN links ON (chain_id = origin_chain_id)
)
SELECT
origin_chain_id, origin_sequence_number,
target_chain_id, target_sequence_number
FROM links
INNER JOIN event_auth_chain_links ON (chain_id = origin_chain_id)
"""
# (We need to take a copy of `seen_chains` as we want to mutate it in
# the loop)
chains_to_fetch = set(seen_chains)
while chains_to_fetch:
batch2 = tuple(itertools.islice(chains_to_fetch, 1000))
clause, args = make_in_list_sql_clause(
txn.database_engine, "origin_chain_id", batch2
)
txn.execute(sql % (clause,), args)
links: Dict[int, List[Tuple[int, int, int]]] = {}
for (
origin_chain_id,
origin_sequence_number,
target_chain_id,
target_sequence_number,
) in txn:
links.setdefault(origin_chain_id, []).append(
(origin_sequence_number, target_chain_id, target_sequence_number)
)
# (We need to take a copy of `seen_chains` as the function mutates it)
for links in self._get_chain_links(txn, set(seen_chains)):
for chains in set_to_chain: for chains in set_to_chain:
for chain_id in links: for chain_id in links:
if chain_id not in chains: if chain_id not in chains:
@@ -618,7 +589,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
_materialize(chain_id, chains[chain_id], links, chains) _materialize(chain_id, chains[chain_id], links, chains)
chains_to_fetch.difference_update(chains)
seen_chains.update(chains) seen_chains.update(chains)
# Now for each chain we figure out the maximum sequence number reachable # Now for each chain we figure out the maximum sequence number reachable
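The refactor above moves the batched lookup into a generator, `_get_chain_links`, that both call sites now share. Below is a minimal sketch of that pattern with no database involved. The names `iter_chain_links`, `fake_fetch` and `GRAPH` are hypothetical, and `fake_fetch` only imitates what the recursive query returns (links for every chain reachable from the batch). It is an illustration of the control flow, not Synapse's implementation.

```python
import itertools
from typing import Callable, Dict, Generator, List, Set, Tuple

# (origin_sequence_number, target_chain_id, target_sequence_number)
Link = Tuple[int, int, int]

# Hypothetical in-memory stand-in for the event_auth_chain_links table.
GRAPH: Dict[int, List[Link]] = {
    1: [(1, 2, 1)],
    2: [(1, 3, 1)],
    4: [(2, 1, 1)],
}


def fake_fetch(batch: Tuple[int, ...]) -> Dict[int, List[Link]]:
    """Return links for every chain reachable from `batch` (assumed behaviour)."""
    reachable: Set[int] = set()
    stack = list(batch)
    while stack:
        chain_id = stack.pop()
        if chain_id in reachable:
            continue
        reachable.add(chain_id)
        stack.extend(target for _, target, _ in GRAPH.get(chain_id, []))
    return {c: GRAPH[c] for c in reachable if c in GRAPH}


def iter_chain_links(
    chains_to_fetch: Set[int],
    fetch_links: Callable[[Tuple[int, ...]], Dict[int, List[Link]]],
    batch_size: int = 1000,
) -> Generator[Dict[int, List[Link]], None, None]:
    """Yield one dict of links per batch. Note: mutates `chains_to_fetch`."""
    while chains_to_fetch:
        batch = tuple(itertools.islice(chains_to_fetch, batch_size))
        chains_to_fetch.difference_update(batch)
        links = fetch_links(batch)
        # Chains whose links we already pulled out need not be queried again.
        chains_to_fetch.difference_update(links)
        yield links


to_fetch = {1, 4}
merged: Dict[int, List[Link]] = {}
for links in iter_chain_links(to_fetch, fake_fetch, batch_size=1):
    merged.update(links)
```

Because the generator empties the set it is given, a caller that still needs the original set should pass a copy, which is what the `set(seen_chains)` call in the diff does.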
@@ -106,7 +106,7 @@ from synapse.storage.database import (
) )
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.types import JsonDict from synapse.types import JsonDict, StrCollection
from synapse.util import json_encoder from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached from synapse.util.caches.descriptors import cached
@@ -859,37 +859,86 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
return await self.db_pool.runInteraction("get_push_action_users_in_range", f) return await self.db_pool.runInteraction("get_push_action_users_in_range", f)
def _get_receipts_by_room_txn( def _get_receipts_for_room_and_threads_txn(
self, txn: LoggingTransaction, user_id: str self,
txn: LoggingTransaction,
user_id: str,
room_ids: StrCollection,
thread_ids: StrCollection,
) -> Dict[str, _RoomReceipt]: ) -> Dict[str, _RoomReceipt]:
""" """
Generate a map of room ID to the latest stream ordering that has been Get (private) read receipts for a user in each of the given room IDs
read by the given user. and thread IDs.
Args: Note: The corresponding room ID for each thread must appear in
txn: `room_ids` arg.
user_id: The user to fetch receipts for.
Returns: Returns:
A map including all rooms the user is in with a receipt. It maps A map including all rooms the user is in with a receipt. It maps
room IDs to _RoomReceipt instances room IDs to _RoomReceipt instances
""" """
receipt_types_clause, args = make_in_list_sql_clause(
receipt_types_clause, receipts_args = make_in_list_sql_clause(
self.database_engine, self.database_engine,
"receipt_type", "receipt_type",
(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
) )
thread_ids_clause, thread_ids_args = make_in_list_sql_clause(
self.database_engine,
"thread_id",
thread_ids,
)
room_ids_clause, room_ids_args = make_in_list_sql_clause(
self.database_engine,
"room_id",
room_ids,
)
# We use the union of two (almost identical) queries here, the first to
# fetch the specific thread receipts and the second to fetch the
# unthreaded receipts.
#
# This SQL is optimized to use the indices we have on
# `receipts_linearized`.
#
# We compare room ID and thread IDs independently due to the above,
# which means that this query might return more rows than we need if the
# same thread ID appears across different rooms (e.g. 'main' thread ID).
# This doesn't cause any logic issues, and isn't a performance concern
# given this function generally gets called with only one room and
# thread ID.
sql = f""" sql = f"""
SELECT room_id, thread_id, MAX(stream_ordering) SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized FROM receipts_linearized
INNER JOIN events USING (room_id, event_id) INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause} WHERE {receipt_types_clause}
AND {thread_ids_clause}
AND {room_ids_clause}
AND user_id = ?
GROUP BY room_id, thread_id
UNION ALL
SELECT room_id, thread_id, MAX(stream_ordering)
FROM receipts_linearized
INNER JOIN events USING (room_id, event_id)
WHERE {receipt_types_clause}
AND {room_ids_clause}
AND thread_id IS NULL
AND user_id = ? AND user_id = ?
GROUP BY room_id, thread_id GROUP BY room_id, thread_id
""" """
args.extend((user_id,)) args = list(receipts_args)
args.extend(thread_ids_args)
args.extend(room_ids_args)
args.append(user_id)
args.extend(receipts_args)
args.extend(room_ids_args)
args.append(user_id)
txn.execute(sql, args) txn.execute(sql, args)
result: Dict[str, _RoomReceipt] = {} result: Dict[str, _RoomReceipt] = {}
@@ -925,12 +974,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn( def get_push_actions_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool]]: ) -> List[Tuple[str, str, str, int, str, bool]]:
@@ -952,6 +995,27 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_http", get_push_actions_txn "get_unread_push_actions_for_user_in_range_http", get_push_actions_txn
) )
room_ids = set()
thread_ids = []
for (
_,
room_id,
thread_id,
_,
_,
_,
) in push_actions:
room_ids.add(room_id)
thread_ids.append(thread_id)
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_http_receipts",
self._get_receipts_for_room_and_threads_txn,
user_id=user_id,
room_ids=room_ids,
thread_ids=thread_ids,
)
notifs = [ notifs = [
HttpPushAction( HttpPushAction(
event_id=event_id, event_id=event_id,
@@ -998,12 +1062,6 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
The list will have between 0~limit entries. The list will have between 0~limit entries.
""" """
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_by_room_txn,
user_id=user_id,
)
def get_push_actions_txn( def get_push_actions_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> List[Tuple[str, str, str, int, str, bool, int]]: ) -> List[Tuple[str, str, str, int, str, bool, int]]:
@@ -1026,6 +1084,28 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
"get_unread_push_actions_for_user_in_range_email", get_push_actions_txn "get_unread_push_actions_for_user_in_range_email", get_push_actions_txn
) )
room_ids = set()
thread_ids = []
for (
_,
room_id,
thread_id,
_,
_,
_,
_,
) in push_actions:
room_ids.add(room_id)
thread_ids.append(thread_id)
receipts_by_room = await self.db_pool.runInteraction(
"get_unread_push_actions_for_user_in_range_email_receipts",
self._get_receipts_for_room_and_threads_txn,
user_id=user_id,
room_ids=room_ids,
thread_ids=thread_ids,
)
# Make a list of dicts from the two sets of results. # Make a list of dicts from the two sets of results.
notifs = [ notifs = [
EmailPushAction( EmailPushAction(
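The receipts query above is a `UNION ALL` of two nearly identical selects, so the positional arguments have to be supplied once per select, in the order the placeholders appear. The following is a small sketch of that ordering using SQLite from the standard library. The `receipts` table is a simplified assumption (it carries `stream_ordering` directly rather than joining to `events` and omits `receipt_type`), and `in_clause` is a hypothetical stand-in for `make_in_list_sql_clause`.

```python
import sqlite3
from typing import Iterable, List, Optional, Tuple


def in_clause(column: str, values: Iterable[str]) -> Tuple[str, List[str]]:
    """Hypothetical helper: build `column IN (?, ...)` plus its arguments."""
    values = list(values)
    placeholders = ",".join("?" for _ in values)
    return f"{column} IN ({placeholders})", values


def latest_receipts(
    conn: sqlite3.Connection,
    user_id: str,
    room_ids: Iterable[str],
    thread_ids: Iterable[str],
) -> List[Tuple[str, Optional[str], int]]:
    rooms_sql, rooms_args = in_clause("room_id", room_ids)
    threads_sql, threads_args = in_clause("thread_id", thread_ids)
    sql = f"""
        SELECT room_id, thread_id, MAX(stream_ordering)
        FROM receipts
        WHERE {threads_sql} AND {rooms_sql} AND user_id = ?
        GROUP BY room_id, thread_id
        UNION ALL
        SELECT room_id, thread_id, MAX(stream_ordering)
        FROM receipts
        WHERE {rooms_sql} AND thread_id IS NULL AND user_id = ?
        GROUP BY room_id, thread_id
    """
    # Arguments follow placeholder order: first select, then second select.
    args = threads_args + rooms_args + [user_id] + rooms_args + [user_id]
    return conn.execute(sql, args).fetchall()


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts "
    "(room_id TEXT, thread_id TEXT, user_id TEXT, stream_ordering INTEGER)"
)
conn.executemany(
    "INSERT INTO receipts VALUES (?, ?, ?, ?)",
    [
        ("!a", "main", "@u", 5),
        ("!a", "main", "@u", 7),
        ("!a", None, "@u", 9),      # unthreaded receipt
        ("!b", "main", "@u", 3),    # other room, filtered out
        ("!a", "main", "@other", 50),  # other user, filtered out
    ],
)
rows = latest_receipts(conn, "@u", ["!a"], ["main"])
```

The second select is needed because `thread_id IN (...)` never matches a NULL thread ID, so unthreaded receipts have to be fetched with an explicit `IS NULL` test.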
+43 -26
@@ -53,11 +53,7 @@ from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import ( from synapse.storage.util.id_generators import IdGenerator, StreamIdGenerator
AbstractStreamIdGenerator,
IdGenerator,
StreamIdGenerator,
)
from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules from synapse.synapse_rust.push import FilteredPushRules, PushRule, PushRules
from synapse.types import JsonDict from synapse.types import JsonDict
from synapse.util import json_encoder, unwrapFirstError from synapse.util import json_encoder, unwrapFirstError
@@ -130,6 +126,8 @@ class PushRulesWorkerStore(
`get_max_push_rules_stream_id` which can be called in the initializer. `get_max_push_rules_stream_id` which can be called in the initializer.
""" """
_push_rules_stream_id_gen: StreamIdGenerator
def __init__( def __init__(
self, self,
database: DatabasePool, database: DatabasePool,
@@ -138,6 +136,10 @@ class PushRulesWorkerStore(
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self._is_push_writer = (
hs.get_instance_name() in hs.config.worker.writers.push_rules
)
# In the worker store this is an ID tracker which we overwrite in the non-worker # In the worker store this is an ID tracker which we overwrite in the non-worker
# class below that is used on the main process. # class below that is used on the main process.
self._push_rules_stream_id_gen = StreamIdGenerator( self._push_rules_stream_id_gen = StreamIdGenerator(
@@ -145,7 +147,7 @@ class PushRulesWorkerStore(
hs.get_replication_notifier(), hs.get_replication_notifier(),
"push_rules_stream", "push_rules_stream",
"stream_id", "stream_id",
is_writer=hs.config.worker.worker_app is None, is_writer=self._is_push_writer,
) )
push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict( push_rules_prefill, push_rules_id = self.db_pool.get_cache_dict(
@@ -162,6 +164,9 @@ class PushRulesWorkerStore(
prefilled_cache=push_rules_prefill, prefilled_cache=push_rules_prefill,
) )
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
def get_max_push_rules_stream_id(self) -> int: def get_max_push_rules_stream_id(self) -> int:
"""Get the position of the push rules stream. """Get the position of the push rules stream.
@@ -383,23 +388,6 @@ class PushRulesWorkerStore(
"get_all_push_rule_updates", get_all_push_rule_updates_txn "get_all_push_rule_updates", get_all_push_rule_updates_txn
) )
class PushRuleStore(PushRulesWorkerStore):
# Because we have write access, this will be a StreamIdGenerator
# (see PushRulesWorkerStore.__init__)
_push_rules_stream_id_gen: AbstractStreamIdGenerator
def __init__(
self,
database: DatabasePool,
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
super().__init__(database, db_conn, hs)
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
self._push_rules_enable_id_gen = IdGenerator(db_conn, "push_rules_enable", "id")
async def add_push_rule( async def add_push_rule(
self, self,
user_id: str, user_id: str,
@@ -410,6 +398,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: Optional[str] = None, before: Optional[str] = None,
after: Optional[str] = None, after: Optional[str] = None,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
conditions_json = json_encoder.encode(conditions) conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions) actions_json = json_encoder.encode(actions)
async with self._push_rules_stream_id_gen.get_next() as stream_id: async with self._push_rules_stream_id_gen.get_next() as stream_id:
@@ -455,6 +446,9 @@ class PushRuleStore(PushRulesWorkerStore):
before: str, before: str,
after: str, after: str,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
relative_to_rule = before or after relative_to_rule = before or after
sql = """ sql = """
@@ -524,6 +518,9 @@ class PushRuleStore(PushRulesWorkerStore):
conditions_json: str, conditions_json: str,
actions_json: str, actions_json: str,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
if isinstance(self.database_engine, PostgresEngine): if isinstance(self.database_engine, PostgresEngine):
# Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first # Postgres doesn't do FOR UPDATE on aggregate functions, so select the rows first
# then re-select the count/max below. # then re-select the count/max below.
@@ -575,6 +572,9 @@ class PushRuleStore(PushRulesWorkerStore):
actions_json: str, actions_json: str,
update_stream: bool = True, update_stream: bool = True,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
"""Specialised version of simple_upsert_txn that picks a push_rule_id """Specialised version of simple_upsert_txn that picks a push_rule_id
using the _push_rule_id_gen if it needs to insert the rule. It assumes using the _push_rule_id_gen if it needs to insert the rule. It assumes
that the "push_rules" table is locked""" that the "push_rules" table is locked"""
@@ -653,6 +653,8 @@ class PushRuleStore(PushRulesWorkerStore):
user_id: The matrix ID of the push rule owner user_id: The matrix ID of the push rule owner
rule_id: The rule_id of the rule to be deleted rule_id: The rule_id of the rule to be deleted
""" """
if not self._is_push_writer:
raise Exception("Not a push writer")
def delete_push_rule_txn( def delete_push_rule_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
@@ -704,6 +706,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises: Raises:
RuleNotFoundException if the rule does not exist. RuleNotFoundException if the rule does not exist.
""" """
if not self._is_push_writer:
raise Exception("Not a push writer")
async with self._push_rules_stream_id_gen.get_next() as stream_id: async with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token() event_stream_ordering = self._stream_id_gen.get_current_token()
await self.db_pool.runInteraction( await self.db_pool.runInteraction(
@@ -727,6 +732,9 @@ class PushRuleStore(PushRulesWorkerStore):
enabled: bool, enabled: bool,
is_default_rule: bool, is_default_rule: bool,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
new_id = self._push_rules_enable_id_gen.get_next() new_id = self._push_rules_enable_id_gen.get_next()
if not is_default_rule: if not is_default_rule:
@@ -796,6 +804,9 @@ class PushRuleStore(PushRulesWorkerStore):
Raises: Raises:
RuleNotFoundException if the rule does not exist. RuleNotFoundException if the rule does not exist.
""" """
if not self._is_push_writer:
raise Exception("Not a push writer")
actions_json = json_encoder.encode(actions) actions_json = json_encoder.encode(actions)
def set_push_rule_actions_txn( def set_push_rule_actions_txn(
@@ -865,6 +876,9 @@ class PushRuleStore(PushRulesWorkerStore):
op: str, op: str,
data: Optional[JsonDict] = None, data: Optional[JsonDict] = None,
) -> None: ) -> None:
if not self._is_push_writer:
raise Exception("Not a push writer")
values = { values = {
"stream_id": stream_id, "stream_id": stream_id,
"event_stream_ordering": event_stream_ordering, "event_stream_ordering": event_stream_ordering,
@@ -882,9 +896,6 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
) )
def get_max_push_rules_stream_id(self) -> int:
return self._push_rules_stream_id_gen.get_current_token()
async def copy_push_rule_from_room_to_room( async def copy_push_rule_from_room_to_room(
self, new_room_id: str, user_id: str, rule: PushRule self, new_room_id: str, user_id: str, rule: PushRule
) -> None: ) -> None:
@@ -895,6 +906,9 @@ class PushRuleStore(PushRulesWorkerStore):
user_id : ID of user the push rule belongs to. user_id : ID of user the push rule belongs to.
rule: A push rule. rule: A push rule.
""" """
if not self._is_push_writer:
raise Exception("Not a push writer")
# Create new rule id # Create new rule id
rule_id_scope = "/".join(rule.rule_id.split("/")[:-1]) rule_id_scope = "/".join(rule.rule_id.split("/")[:-1])
new_rule_id = rule_id_scope + "/" + new_room_id new_rule_id = rule_id_scope + "/" + new_room_id
@@ -930,6 +944,9 @@ class PushRuleStore(PushRulesWorkerStore):
new_room_id: ID of the new room. new_room_id: ID of the new room.
user_id: ID of user to copy push rules for. user_id: ID of user to copy push rules for.
""" """
if not self._is_push_writer:
raise Exception("Not a push writer")
# Retrieve push rules for this user # Retrieve push rules for this user
user_push_rules = await self.get_push_rules_for_user(user_id) user_push_rules = await self.get_push_rules_for_user(user_id)
@@ -2266,6 +2266,13 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
): ):
super().__init__(database, db_conn, hs) super().__init__(database, db_conn, hs)
self.db_pool.updates.register_background_index_update(
update_name="access_tokens_refresh_token_id_idx",
index_name="access_tokens_refresh_token_id_idx",
table="access_tokens",
columns=("refresh_token_id",),
)
self._ignore_unknown_session_error = ( self._ignore_unknown_session_error = (
hs.config.server.request_token_inhibit_3pid_errors hs.config.server.request_token_inhibit_3pid_errors
) )
+380 -13
@@ -17,14 +17,16 @@
# [This file includes modifications made by New Vector Limited] # [This file includes modifications made by New Vector Limited]
# #
# #
from typing import Collection, List, Optional from typing import Collection, ContextManager, List, Optional
from unittest.mock import AsyncMock, Mock, patch from unittest.mock import AsyncMock, Mock, patch
from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes, JoinRules from synapse.api.constants import EventTypes, JoinRules
from synapse.api.errors import Codes, ResourceLimitError from synapse.api.errors import Codes, ResourceLimitError
from synapse.api.filtering import Filtering from synapse.api.filtering import FilterCollection, Filtering
from synapse.api.room_versions import RoomVersion, RoomVersions from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.events import EventBase from synapse.events import EventBase
from synapse.events.snapshot import EventContext from synapse.events.snapshot import EventContext
@@ -33,7 +35,7 @@ from synapse.handlers.sync import SyncConfig, SyncResult
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import knock, login, room from synapse.rest.client import knock, login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.types import UserID, create_requester from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock from synapse.util import Clock
import tests.unittest import tests.unittest
@@ -258,13 +260,7 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
# Eve tries to join the room. We monkey patch the internal logic which selects # Eve tries to join the room. We monkey patch the internal logic which selects
# the prev_events used when creating the join event, such that the ban does not # the prev_events used when creating the join event, such that the ban does not
# precede the join. # precede the join.
mocked_get_prev_events = patch.object( with self._patch_get_latest_events([last_room_creation_event_id]):
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=AsyncMock,
return_value=[last_room_creation_event_id],
)
with mocked_get_prev_events:
self.helper.join(room_id, eve, tok=eve_token) self.helper.join(room_id, eve, tok=eve_token)
# Eve makes a second, incremental sync. # Eve makes a second, incremental sync.
@@ -288,6 +284,365 @@ class SyncTestCase(tests.unittest.HomeserverTestCase):
) )
self.assertEqual(eve_initial_sync_after_join.joined, []) self.assertEqual(eve_initial_sync_after_join.joined, [])
def test_state_includes_changes_on_forks(self) -> None:
"""State changes that happen on a fork of the DAG must be included in `state`
Given the following DAG:
E1
| S2
|
--|------|----
| |
E3 |
/
E4
... and a filter that means we only return 2 events, represented by the dashed
horizontal line: `S2` must be included in the `state` section.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Send a final event, joining the two branches of the dag
e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
# do an incremental sync, with a filter that will ensure we only get two of
# the three new events.
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 2}}}
),
),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event, e4_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
def test_state_includes_changes_on_forks_when_events_excluded(self) -> None:
"""A variation on the previous test, but where one event is filtered
The DAG is the same as the previous test, but E4 is excluded by the filter.
E1
| S2
|
--|------|----
| |
E3 |
/
(E4)
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync as Alice to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Send a final event, joining the two branches of the dag
self.helper.send(room_id, "e4", type="not_a_normal_message", tok=alice_tok)[
"event_id"
]
# do an incremental sync, with a filter that will only return E3, excluding S2
# and E4.
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs,
{
"room": {
"timeline": {
"limit": 1,
"not_types": ["not_a_normal_message"],
}
}
},
),
),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertTrue(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
def test_state_includes_changes_on_ungappy_syncs(self) -> None:
"""Test `state` where the sync is not gappy.
We start with a DAG like this:
E1
| S2
|
--|---
|
E3
... and initial sync with `limit=1`, represented by the horizontal dashed line.
At this point, we do not expect S2 to appear in the response at all (since
it is excluded from the timeline by the `limit`, and the state is based on the
state after the most recent event before the sync token (E3), which doesn't
include S2).
Now more events arrive, and we do an incremental sync:
E1
| S2
|
E3 |
|
--|------|----
| |
E4 |
/
E5
This is the last chance for us to tell the client about S2, so it *must* be
included in the response.
"""
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
alice_requester = create_requester(alice)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
# Do an initial sync to get a known starting point.
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester, generate_sync_config(alice)
)
)
last_room_creation_event_id = (
initial_sync_result.joined[0].timeline.events[-1].event_id
)
# Send a state event, and a regular event, both using the same prev ID
with self._patch_get_latest_events([last_room_creation_event_id]):
s2_event = self.helper.send_state(room_id, "s2", {}, tok=alice_tok)[
"event_id"
]
e3_event = self.helper.send(room_id, "e3", tok=alice_tok)["event_id"]
# Another initial sync, with limit=1
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(
alice,
filter_collection=FilterCollection(
self.hs, {"room": {"timeline": {"limit": 1}}}
),
),
)
)
room_sync = initial_sync_result.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e3_event],
)
self.assertNotIn(s2_event, [e.event_id for e in room_sync.state.values()])
# More events, E4 and E5
with self._patch_get_latest_events([e3_event]):
e4_event = self.helper.send(room_id, "e4", tok=alice_tok)["event_id"]
e5_event = self.helper.send(room_id, "e5", tok=alice_tok)["event_id"]
# Now incremental sync
incremental_sync = self.get_success(
self.sync_handler.wait_for_sync_for_user(
alice_requester,
generate_sync_config(alice),
since_token=initial_sync_result.next_batch,
)
)
# The state event should appear in the 'state' section of the response.
room_sync = incremental_sync.joined[0]
self.assertEqual(room_sync.room_id, room_id)
self.assertFalse(room_sync.timeline.limited)
self.assertEqual(
[e.event_id for e in room_sync.timeline.events],
[e4_event, e5_event],
)
self.assertEqual(
[e.event_id for e in room_sync.state.values()],
[s2_event],
)
@parameterized.expand(
[
(False, False),
(True, False),
(False, True),
(True, True),
]
)
def test_archived_rooms_do_not_include_state_after_leave(
self, initial_sync: bool, empty_timeline: bool
) -> None:
"""If the user leaves the room, state changes that happen after they leave are not returned.
We try with both a zero and a normal timeline limit,
and we try both an initial sync and an incremental sync for both.
"""
if empty_timeline and not initial_sync:
# FIXME synapse doesn't return the room at all in this situation!
self.skipTest("Synapse does not correctly handle this case")
# Alice creates the room, and bob joins.
alice = self.register_user("alice", "password")
alice_tok = self.login(alice, "password")
bob = self.register_user("bob", "password")
bob_tok = self.login(bob, "password")
bob_requester = create_requester(bob)
room_id = self.helper.create_room_as(alice, is_public=True, tok=alice_tok)
self.helper.join(room_id, bob, tok=bob_tok)
initial_sync_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester, generate_sync_config(bob)
)
)
# Alice sends a message and a state event
before_message_event = self.helper.send(room_id, "before", tok=alice_tok)[
"event_id"
]
before_state_event = self.helper.send_state(
room_id, "test_state", {"body": "before"}, tok=alice_tok
)["event_id"]
# Bob leaves
leave_event = self.helper.leave(room_id, bob, tok=bob_tok)["event_id"]
# Alice sends some more stuff
self.helper.send(room_id, "after", tok=alice_tok)["event_id"]
self.helper.send_state(room_id, "test_state", {"body": "after"}, tok=alice_tok)[
"event_id"
]
# And now, Bob resyncs.
filter_dict: JsonDict = {"room": {"include_leave": True}}
if empty_timeline:
filter_dict["room"]["timeline"] = {"limit": 0}
sync_room_result = self.get_success(
self.sync_handler.wait_for_sync_for_user(
bob_requester,
generate_sync_config(
bob, filter_collection=FilterCollection(self.hs, filter_dict)
),
since_token=None if initial_sync else initial_sync_result.next_batch,
)
).archived[0]
if empty_timeline:
# The timeline should be empty
self.assertEqual(sync_room_result.timeline.events, [])
# And the state should include the leave event...
self.assertEqual(
sync_room_result.state[("m.room.member", bob)].event_id, leave_event
)
# ... and the state change before he left.
self.assertEqual(
sync_room_result.state[("test_state", "")].event_id, before_state_event
)
else:
# The last three events in the timeline should be those leading up to the
# leave
self.assertEqual(
[e.event_id for e in sync_room_result.timeline.events[-3:]],
[before_message_event, before_state_event, leave_event],
)
# ... And the state should be empty
self.assertEqual(sync_room_result.state, {})
def _patch_get_latest_events(self, latest_events: List[str]) -> ContextManager:
"""Monkey-patch `get_prev_events_for_room`
Returns a context manager which will replace the implementation of
`get_prev_events_for_room` with one which returns `latest_events`.
"""
return patch.object(
self.hs.get_datastores().main,
"get_prev_events_for_room",
new_callable=AsyncMock,
return_value=latest_events,
)
def test_call_invite_in_public_room_not_returned(self) -> None:
user = self.register_user("alice", "password")
tok = self.login(user, "password")
@@ -401,14 +756,26 @@ _request_key = 0
 def generate_sync_config(
-    user_id: str, device_id: Optional[str] = "device_id"
+    user_id: str,
+    device_id: Optional[str] = "device_id",
+    filter_collection: Optional[FilterCollection] = None,
 ) -> SyncConfig:
-    """Generate a sync config (with a unique request key)."""
+    """Generate a sync config (with a unique request key).
+    Args:
+        user_id: user who is syncing.
+        device_id: device that is syncing. Defaults to "device_id".
+        filter_collection: filter to apply. Defaults to the default filter (ie,
+            return everything, with a default limit)
+    """
+    if filter_collection is None:
+        filter_collection = Filtering(Mock()).DEFAULT_FILTER_COLLECTION
     global _request_key
     _request_key += 1
     return SyncConfig(
         user=UserID.from_string(user_id),
-        filter_collection=Filtering(Mock()).DEFAULT_FILTER_COLLECTION,
+        filter_collection=filter_collection,
         is_guest=False,
         request_key=("request_key", _request_key),
         device_id=device_id,
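Note on the `_patch_get_latest_events` helper added in this file: it relies on `unittest.mock.patch.object` with `new_callable=AsyncMock` to force the prev events chosen for new events. The following is a minimal standalone sketch of that mocking pattern, not Synapse code. `FakeStore` and `build_prev_events` are hypothetical stand-ins invented for illustration; only `patch.object` and `AsyncMock` are real library calls.

```python
import asyncio
from typing import List
from unittest.mock import AsyncMock, patch


class FakeStore:
    """Hypothetical stand-in for a datastore; not part of Synapse."""

    async def get_prev_events_for_room(self, room_id: str) -> List[str]:
        return ["$real_extremity"]


async def build_prev_events(store: FakeStore, room_id: str) -> List[str]:
    # Hypothetical code under test: awaits the store to pick prev events.
    return await store.get_prev_events_for_room(room_id)


store = FakeStore()

# Inside the context manager, the method is replaced by an AsyncMock whose
# awaited result is the list supplied via `return_value`.
with patch.object(
    store,
    "get_prev_events_for_room",
    new_callable=AsyncMock,
    return_value=["$forced_prev"],
) as mocked:
    patched_result = asyncio.run(build_prev_events(store, "!room:test"))
    mocked.assert_awaited_once_with("!room:test")

# On leaving the block, the original implementation is visible again.
restored_result = asyncio.run(build_prev_events(store, "!room:test"))
```

Because the patch only applies inside the `with` block, a test can pin the prev events for some sends and let later sends pick up the real forward extremities.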
+11 -7
@@ -170,8 +170,8 @@ class RestHelper:
         targ: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
-    ) -> None:
-        self.change_membership(
+    ) -> JsonDict:
+        return self.change_membership(
             room=room,
             src=src,
             targ=targ,
@@ -189,8 +189,8 @@ class RestHelper:
         appservice_user_id: Optional[str] = None,
         expect_errcode: Optional[Codes] = None,
         expect_additional_fields: Optional[dict] = None,
-    ) -> None:
-        self.change_membership(
+    ) -> JsonDict:
+        return self.change_membership(
             room=room,
             src=user,
             targ=user,
@@ -242,8 +242,8 @@ class RestHelper:
         user: Optional[str] = None,
         expect_code: int = HTTPStatus.OK,
         tok: Optional[str] = None,
-    ) -> None:
-        self.change_membership(
+    ) -> JsonDict:
+        return self.change_membership(
             room=room,
             src=user,
             targ=user,
@@ -282,7 +282,7 @@ class RestHelper:
         expect_code: int = HTTPStatus.OK,
         expect_errcode: Optional[str] = None,
         expect_additional_fields: Optional[dict] = None,
-    ) -> None:
+    ) -> JsonDict:
         """
         Send a membership state event into a room.
@@ -298,6 +298,9 @@ class RestHelper:
                 using an application service access token in `tok`.
             expect_code: The expected HTTP response code
             expect_errcode: The expected Matrix error code
+        Returns:
+            The JSON response
         """
         temp_id = self.auth_user_id
         self.auth_user_id = src
@@ -356,6 +359,7 @@ class RestHelper:
         )
         self.auth_user_id = temp_id
+        return channel.json_body
     def send(
         self,