From 8f33017653477caedada088c4f1283a102bf26a5 Mon Sep 17 00:00:00 2001 From: Will Hunt Date: Fri, 27 Jun 2025 17:25:28 +0100 Subject: [PATCH] Add functions to store and get the deleted room members --- synapse/storage/databases/main/roommember.py | 80 +++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 7ca73abb83..23f4ac27f4 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -102,6 +102,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): super().__init__(database, db_conn, hs) self._server_notices_mxid = hs.config.servernotices.server_notices_mxid + self._our_server_name = hs.config.server.server_name if ( self.hs.config.worker.run_background_tasks @@ -1388,7 +1389,7 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): rows = cast( List[Tuple[str, str, str]], - await self.db_pool.simple_select_many_batch( + await self.db_pool.simple_select_onecol( table="room_memberships", column="event_id", iterable=member_event_ids, @@ -1845,6 +1846,83 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore): "_get_room_participation_txn", _get_room_participation_txn, user_id, room_id ) + async def store_deleted_room_members( + self, + room_id: str, + ) -> None: + """TODO: WRITE + + Args: + room_id: the ID of the room + stream_id: stream ID at the point the room was deleted + user_ids: all users who were ever present in the room + """ + + # Welcome to the pain zone. We need to first extract all the local members + sql = """ + SELECT state_key, membership, event_stream_ordering FROM current_state_events + WHERE type = 'm.room.member' + AND room_id = ? + AND state_key LIKE ? + """ + + # TODO: Should we check for any joins, everyone should be banned or left at this point... + + # We do need to be careful to ensure that host doesn't have any wild cards + # in it, but we checked above for known ones and we'll check below that + # the returned user actually has the correct domain. + user_rows = await self.db_pool.execute( + "store_deleted_room_members_get_members", sql, room_id, ("%:" + self._our_server_name) + ) + + logger.info("store_deleted_room_members %s %s %s %s %s", room_id, user_rows, sql, room_id, ("%:" + self._our_server_name)) + + await self.db_pool.runInteraction( + "store_deleted_room_members", + self._store_deleted_room_members_txn, + room_id, + user_rows + ) + + def _store_deleted_room_members_txn( + self, + txn: LoggingTransaction, + room_id: str, + users: Iterable[Tuple[str, str, int]], + ) -> None: + # If the user is still currently joined, they are about to get kicked so + # use the latest stream position + max = self.get_room_max_stream_ordering() + return DatabasePool.simple_insert_many_txn( + txn, + table="deleted_room_members", + keys=("room_id", "user_id", "deleted_at_stream_id"), + values=[(room_id, user[0], user[2] if user[1] in [Membership.BAN, Membership.LEAVE] else max) for user in users], + ) + + async def get_deleted_rooms_for_user( + self, user_id: str, stream_pos: int + ) -> list[(str, int)]: + """Checks if the given rooms have partial state. + + Returns true for "partial-state" rooms, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + def _get_deleted_rooms_for_user(txn: LoggingTransaction) -> list[(str, int)]: + sql = """ + SELECT room_id, deleted_at_stream_id FROM deleted_room_members + WHERE user_id = ? + AND ? < deleted_at_stream_id + """ + txn.execute(sql, (user_id, stream_pos)) + return set([(r[0], r[1]) for r in txn]) + + return await self.db_pool.runInteraction( + "get_deleted_rooms_for_user", + _get_deleted_rooms_for_user + ) class RoomMemberBackgroundUpdateStore(SQLBaseStore): def __init__(