diff --git a/docker/complement/conf/workers-shared-extra.yaml.j2 b/docker/complement/conf/workers-shared-extra.yaml.j2 index 94e74df9d1..cb963c04ec 100644 --- a/docker/complement/conf/workers-shared-extra.yaml.j2 +++ b/docker/complement/conf/workers-shared-extra.yaml.j2 @@ -135,6 +135,8 @@ experimental_features: msc4155_enabled: true # Thread Subscriptions msc4306_enabled: true + # Sticky Events + msc4354_enabled: true server_notices: system_mxid_localpart: _server diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d086deab3f..2cff80c753 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -592,3 +592,6 @@ class ExperimentalConfig(Config): # MSC4306: Thread Subscriptions # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + + # MSC4354: Sticky Events + self.msc4354_enabled: bool = experimental.get("msc4354_enabled", False) diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a1c9c286ac..c71def8c76 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -195,6 +195,8 @@ class FederationBase: # using the event in prev_events). redacted_event = prune_event(pdu) redacted_event.internal_metadata.soft_failed = True + # Mark this as spam so we don't re-evaluate soft-failure status. + pdu.internal_metadata.policy_server_spammy = True return redacted_event return pdu diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 64deae7650..01790bc1e4 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -82,6 +82,9 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +MSC4354_STICKY_DURATION_QUERY_PARAM = "msc4354_stick_duration_ms" +MSC4354_STICKY_EVENT_KEY = "msc4354_sticky" + class _RoomSize(Enum): """ @@ -206,6 +209,7 @@ class RoomStateEventRestServlet(RestServlet): self.clock = hs.get_clock() self._max_event_delay_ms = hs.config.server.max_event_delay_ms self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker + self.msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/state/$eventtype @@ -364,6 +368,14 @@ class RoomStateEventRestServlet(RestServlet): "room_id": room_id, "sender": requester.user.to_string(), } + if self.msc4354_enabled: + sticky_duration_ms = parse_integer( + request, MSC4354_STICKY_DURATION_QUERY_PARAM + ) + if sticky_duration_ms is not None: + event_dict[MSC4354_STICKY_EVENT_KEY] = { + "duration_ms": sticky_duration_ms, + } if state_key is not None: event_dict["state_key"] = state_key @@ -396,6 +408,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): self.delayed_events_handler = hs.get_delayed_events_handler() self.auth = hs.get_auth() self._max_event_delay_ms = hs.config.server.max_event_delay_ms + self.msc4354_enabled = hs.config.experimental.msc4354_enabled def register(self, http_server: HttpServer) -> None: # /rooms/$roomid/send/$event_type[/$txn_id] @@ -442,6 +455,15 @@ class RoomSendEventRestServlet(TransactionRestServlet): if origin_server_ts is not None: event_dict["origin_server_ts"] = origin_server_ts + if self.msc4354_enabled: + sticky_duration_ms = parse_integer( + request, MSC4354_STICKY_DURATION_QUERY_PARAM + ) + if sticky_duration_ms is not None: + event_dict[MSC4354_STICKY_EVENT_KEY] = { + "duration_ms": sticky_duration_ms, + } + try: ( event, diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py index a50e889b9d..b9542d9f22 100644 --- a/synapse/storage/databases/main/events.py +++ b/synapse/storage/databases/main/events.py @@ -22,6 +22,7 @@ import collections import itertools import logging +import time from collections import OrderedDict from typing import ( TYPE_CHECKING, @@ -1174,6 +1175,13 @@ class PersistEventsStore: self._update_sliding_sync_tables_with_new_persisted_events_txn( txn, room_id, events_and_contexts ) + # process events which are sticky as well as re-evaluate soft-failed sticky events. + self._handle_sticky_events_txn( + txn, + room_id, + events_and_contexts, + state_delta_for_room, + ) def _persist_event_auth_chain_txn( self, @@ -2921,6 +2929,225 @@ class PersistEventsStore: }, ) + def _handle_sticky_events_txn( + self, + txn: LoggingTransaction, + room_id: str, + events_and_contexts: List[EventPersistencePair], + state_delta_for_room: Optional[DeltaState], + ) -> None: + """ + Update the sticky events table, used in MSC4354. + + This function assumes that `_store_event_txn()` (to persist the event) and + `_update_current_state_txn(...)` (so the current state has taken the events into account) + have already been run. + + "Handling" sticky events is broken into two phases: + - for each sticky event in events_and_contexts, mark them as sticky in the sticky events table. + - for each still-sticky soft-failed event in the room, re-evaluate soft-failedness. + + Args: + txn + room_id: The room that all of the events belong to + events_and_contexts: The events being persisted. + state_delta_for_room: The changes to the current state, used to detect if we need to + re-evaluate soft-failed sticky events. + """ + if len(events_and_contexts) == 0: + return + + # TODO: finish the impl + # fetch soft failed sticky events to recheck now, before we insert new sticky events, else + # we could incorrectly re-evaluate new sticky events + # event_ids_to_check = self._get_soft_failed_sticky_events_to_recheck(txn, room_id, state_delta_for_room) + # logger.info(f"_get_soft_failed_sticky_events_to_recheck => {event_ids_to_check}") + # recheck them and update any that now pass soft-fail checks. + # self._recheck_soft_failed_events(txn, room_id, event_ids_to_check) + + # insert brand new sticky events. + self._insert_sticky_events_txn(txn, events_and_contexts) + + def _insert_sticky_events_txn( + self, + txn: LoggingTransaction, + events_and_contexts: List[EventPersistencePair], + ) -> None: + sticky_events: List[EventBase] = [] + for ev, _ in events_and_contexts: + # MSC: Note: policy servers and other similar antispam techniques still apply to these events. + if ev.internal_metadata.policy_server_spammy: + continue + # We shouldn't be passed rejected events, but if we do, we filter them out too. + if ev.rejected_reason is not None: + continue + # MSC: The presence of sticky.duration_ms with a valid value makes the event “sticky” + sticky_obj = ev.get("sticky", None) + if type(sticky_obj) is dict: + sticky_duration_ms = sticky_obj.get("duration_ms", None) + # MSC: Valid values are the integer range 0-3600000 (1 hour). + if ( + type(sticky_duration_ms) is int + and sticky_duration_ms >= 0 + and sticky_duration_ms <= 3600000 + ): + sticky_events.append(ev) + + if len(sticky_events) == 0: + return + now_ms = round(time.time() * 1000) + self.db_pool.simple_insert_many_txn( + txn, + "sticky_events", + keys=("room_id", "event_id", "sender", "expires_at", "soft_failed"), + values=[ + ( + ev.room_id, + ev.event_id, + ev.sender, + # MSC: The start time is min(now, origin_server_ts). + # This ensures that malicious origin timestamps cannot specify start times in the future. + # Calculate the end time as start_time + min(sticky.duration_ms, 3600000). + min(ev.origin_server_ts, now_ms) + + min(ev.get_dict()["sticky"]["duration_ms"], 3600000), + ev.internal_metadata.soft_failed, + ) + for ev in sticky_events + ], + ) + + def _get_soft_failed_sticky_events_to_recheck( + self, + txn: LoggingTransaction, + room_id: str, + state_delta_for_room: Optional[DeltaState], + ) -> List[str]: + """Fetch soft-failed sticky events which should be rechecked against the current state. + + Soft-failed events are not rejected, so they pass auth at the state before + the event and at the auth_events in the event. Instead, soft-failed events failed auth at + the _current state of the room_. We only need to recheck soft failure if we have a reason to + believe the event may pass that check now. + + Note that we don't bother rechecking accepted events that may now be soft-failed, because + by that point it's too late as we've already sent the event to clients. + + Returns: + A list of event IDs to recheck + """ + + if state_delta_for_room is None: + # No change to current state => no way soft failure status could be different. + return [] + + # any change to critical auth events may change soft failure status. This means any changes + # to join rules, power levels or member events. If the state has changed but it isn't one + # of those events, we don't need to recheck. + critical_auth_types = ( + EventTypes.JoinRules, + EventTypes.PowerLevels, + EventTypes.Member, + ) + critical_auth_types_changed = set() + critical_auth_types_changed.update( + [ + typ + for typ, _ in state_delta_for_room.to_delete + if typ in critical_auth_types + ] + ) + critical_auth_types_changed.update( + [ + typ + for typ, _ in state_delta_for_room.to_insert + if typ in critical_auth_types + ] + ) + if len(critical_auth_types_changed) == 0: + # No change to critical auth events => no way soft failure status could be different. + return [] + + if critical_auth_types_changed == {EventTypes.Member}: + # the final case we want to catch is when unprivileged users join/leave rooms. These users cause + # changes in the critical auth types (the member event) but ultimately have no effect on soft + # failure status for anyone but that user themselves. + # + # Grab the set of senders that have been modified and see if any of them sent a soft-failed + # sticky event. If they did, then we need to re-evaluate. If they didn't, then we don't need to. + new_membership_changes = set( + [ + skey + for typ, skey in state_delta_for_room.to_insert + if typ == EventTypes.Member + ] + + [ + skey + for typ, skey in state_delta_for_room.to_delete + if typ == EventTypes.Member + ] + ) + # pull out senders of sticky events in this room + events_to_recheck: List[Tuple[str]] = self.db_pool.simple_select_many_txn( + txn, + table="sticky_events", + column="sender", + iterable=new_membership_changes, + keyvalues={ + "room_id": room_id, + "soft_failed": True, + }, + retcols=("event_id"), + ) + return [event_id for (event_id,) in events_to_recheck] + + # otherwise one of the following must be true: + # - there was a change in PL or join rules + # - there was a change in the membership of a sender of a soft-failed sticky event. + # In both of these cases we want to re-evaluate soft failure status. + # + # NB: event auth checks are NOT recursive. We don't need to specifically handle the case where + # an admin user's membership changes which causes a PL event to be allowed, as when the PL event + # gets allowed we will re-evaluate anyway. E.g: + # + # PL(send_event=0, sender=Admin) + # ^ ^_____________________ + # | | + # . PL(send_event=50, sender=Mod) sticky event (sender=User) + # + # In this scenario, the sticky event is soft-failed due to the Mod updating the PL event to + # set send_event=50, which User does not have. If we learn of an event which makes Mod's PL + # event invalid (say, Mod was banned by Admin concurrently to Mod setting the PL event), then + # the act of seeing the ban event will cause the old PL event to be in the state delta, meaning + # we will re-evaluate the sticky event due to the PL changing. We don't need to specially handle case.a + events_to_recheck = self.db_pool.simple_select_list_txn( + txn, + table="sticky_events", + keyvalues={ + "room_id": room_id, + "soft_failed": True, + }, + retcols=("event_id"), + ) + return [event_id for (event_id,) in events_to_recheck] + + def _recheck_soft_failed_events( + self, + txn: LoggingTransaction, + room_id: str, + event_ids: List[str], + ) -> None: + """ + Recheck authorised but soft-failed events. The provided event IDs must have already passed + all auth checks (so the event isn't rejected) but soft-failure checks. + + Args: + txn: The SQL transaction + room_id: The room the event IDs are in. + event_ids: The soft-failed events to re-evaluate. + """ + # We know the events are otherwise authorised, so we only need to load the current state + # and check if the events pass auth at the current state. + def insert_labels_for_event_txn( self, txn: LoggingTransaction,