From 5949ab86f8db0ef3dac2063e42210030f17786fb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2022 11:57:49 +0100 Subject: [PATCH 01/14] Fix potential thumbnail memory leaks. (#12932) --- changelog.d/12932.bugfix | 1 + synapse/rest/media/v1/media_repository.py | 269 ++++++++++++---------- synapse/rest/media/v1/thumbnailer.py | 71 +++++- 3 files changed, 204 insertions(+), 137 deletions(-) create mode 100644 changelog.d/12932.bugfix diff --git a/changelog.d/12932.bugfix b/changelog.d/12932.bugfix new file mode 100644 index 0000000000..506f92b427 --- /dev/null +++ b/changelog.d/12932.bugfix @@ -0,0 +1 @@ +Fix potential memory leak when generating thumbnails. diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 20af366538..a551458a9f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -587,15 +587,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -657,15 +658,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -749,119 +751,134 @@ class MediaRepository: ) return None - m_width = thumbnailer.width - m_height = thumbnailer.height + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height - if m_width * m_height >= 
self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) - - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. - thumbnails: Dict[Tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), 
- ) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - await self.media_storage.write_to_file(t_byte_source, f) - await finish() - finally: - t_byte_source.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. - try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = await self.store.get_remote_media_thumbnail( - server_name, - media_id, - t_width, - t_height, - t_type, - ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. 
+ thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: %r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. 
This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 390491eb83..9b93b9b4f6 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -14,7 +14,8 @@ # limitations under the License. import logging from io import BytesIO -from typing import Tuple +from types import TracebackType +from typing import Optional, Tuple, Type from PIL import Image @@ -45,6 +46,9 @@ class Thumbnailer: Image.MAX_IMAGE_PIXELS = max_image_pixels def __init__(self, input_path: str): + # Have we closed the image? + self._closed = False + try: self.image = Image.open(input_path) except OSError as e: @@ -89,7 +93,8 @@ class Thumbnailer: # Safety: `transpose` takes an int rather than e.g. an IntEnum. # self.transpose_method is set above to be a value in # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values. - self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] + with self.image: + self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] self.width, self.height = self.image.size self.transpose_method = None # We don't need EXIF any more @@ -122,9 +127,11 @@ class Thumbnailer: # If the image has transparency, use RGBA instead. 
if self.image.mode in ["1", "L", "P"]: if self.image.info.get("transparency", None) is not None: - self.image = self.image.convert("RGBA") + with self.image: + self.image = self.image.convert("RGBA") else: - self.image = self.image.convert("RGB") + with self.image: + self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: @@ -133,8 +140,8 @@ class Thumbnailer: Returns: BytesIO: the bytes of the encoded image ready to be written to disk """ - scaled = self._resize(width, height) - return self._encode_image(scaled, output_type) + with self._resize(width, height) as scaled: + return self._encode_image(scaled, output_type) def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving @@ -151,18 +158,21 @@ class Thumbnailer: BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: + scaled_width = width scaled_height = (width * self.height) // self.width - scaled_image = self._resize(width, scaled_height) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top - cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) + crop = (0, crop_top, width, crop_bottom) else: scaled_width = (height * self.width) // self.height - scaled_image = self._resize(scaled_width, height) + scaled_height = height crop_left = (scaled_width - width) // 2 crop_right = width + crop_left - cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - return self._encode_image(cropped, output_type) + crop = (crop_left, 0, crop_right, height) + + with self._resize(scaled_width, scaled_height) as scaled_image: + with scaled_image.crop(crop) as cropped: + return self._encode_image(cropped, output_type) def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() @@ -171,3 +181,42 
@@ class Thumbnailer: output_image = output_image.convert("RGB") output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io + + def close(self) -> None: + """Closes the underlying image file. + + Once closed no other functions can be called. + + Can be called multiple times. + """ + + if self._closed: + return + + self._closed = True + + # Since we run this on the finalizer then we need to handle `__init__` + # raising an exception before it can define `self.image`. + image = getattr(self, "image", None) + if image is None: + return + + image.close() + + def __enter__(self) -> "Thumbnailer": + """Make `Thumbnailer` a context manager that calls `close` on + `__exit__`. + """ + return self + + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + + def __del__(self) -> None: + # Make sure we actually do close the image, rather than leak data. + self.close() From 79dadf7216836170af2ac5ef130bfc012b86821c Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Wed, 1 Jun 2022 12:29:51 +0100 Subject: [PATCH 02/14] Fix 404 on `/sync` when the last event is a redaction of an unknown/purged event (#12905) Currently, we try to pull the event corresponding to a sync token from the database. However, when we fetch redaction events, we check the target of that redaction (because we aren't allowed to send redactions to clients without validating them). So, if the sync token points to a redaction of an event that we don't have, we have a problem. It turns out we don't really need that event, and can just work with its ID and metadata, which sidesteps the whole problem. 
--- changelog.d/12905.bugfix | 1 + synapse/handlers/message.py | 114 +++++++++++++++-------- synapse/handlers/sync.py | 27 ++++-- synapse/storage/databases/main/state.py | 12 ++- synapse/storage/databases/main/stream.py | 12 +-- synapse/visibility.py | 28 ++++-- 6 files changed, 129 insertions(+), 65 deletions(-) create mode 100644 changelog.d/12905.bugfix diff --git a/changelog.d/12905.bugfix b/changelog.d/12905.bugfix new file mode 100644 index 0000000000..67e95d0398 --- /dev/null +++ b/changelog.d/12905.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cf7c2d1979..ac911a2ddc 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -28,6 +28,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, GuestAccess, + HistoryVisibility, Membership, RelationTypes, UserTypes, @@ -66,7 +67,7 @@ from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstErr from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func -from synapse.visibility import filter_events_for_client +from synapse.visibility import get_effective_room_visibility_from_state if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules @@ -182,51 +183,31 @@ class MessageHandler: state_filter = state_filter or StateFilter.all() if at_token: - last_event = await self.store.get_last_event_in_room_before_stream_ordering( - room_id, - end_token=at_token.room_key, + last_event_id = ( + await self.store.get_last_event_in_room_before_stream_ordering( + room_id, + end_token=at_token.room_key, + ) ) - if not last_event: + if not last_event_id: raise NotFoundError("Can't find event for token %s" % (at_token,)) - # check whether 
the user is in the room at that time to determine - # whether they should be treated as peeking. - state_map = await self._state_storage_controller.get_state_for_event( - last_event.event_id, - StateFilter.from_types([(EventTypes.Member, user_id)]), - ) - - joined = False - membership_event = state_map.get((EventTypes.Member, user_id)) - if membership_event: - joined = membership_event.membership == Membership.JOIN - - is_peeking = not joined - - visible_events = await filter_events_for_client( - self._storage_controllers, - user_id, - [last_event], - filter_send_to_client=False, - is_peeking=is_peeking, - ) - - if visible_events: - room_state_events = ( - await self._state_storage_controller.get_state_for_events( - [last_event.event_id], state_filter=state_filter - ) - ) - room_state: Mapping[Any, EventBase] = room_state_events[ - last_event.event_id - ] - else: + if not await self._user_can_see_state_at_event( + user_id, room_id, last_event_id + ): raise AuthError( 403, "User %s not allowed to view events in room %s at token %s" % (user_id, room_id, at_token), ) + + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event_id], state_filter=state_filter + ) + ) + room_state: Mapping[Any, EventBase] = room_state_events[last_event_id] else: ( membership, @@ -256,6 +237,65 @@ class MessageHandler: events = self._event_serializer.serialize_events(room_state.values(), now) return events + async def _user_can_see_state_at_event( + self, user_id: str, room_id: str, event_id: str + ) -> bool: + # check whether the user was in the room, and the history visibility, + # at that time. 
+ state_map = await self._state_storage_controller.get_state_for_event( + event_id, + StateFilter.from_types( + [ + (EventTypes.Member, user_id), + (EventTypes.RoomHistoryVisibility, ""), + ] + ), + ) + + membership = None + membership_event = state_map.get((EventTypes.Member, user_id)) + if membership_event: + membership = membership_event.membership + + # if the user was a member of the room at the time of the event, + # they can see it. + if membership == Membership.JOIN: + return True + + # otherwise, it depends on the history visibility. + visibility = get_effective_room_visibility_from_state(state_map) + + if visibility == HistoryVisibility.JOINED: + # we weren't a member at the time of the event, so we can't see this event. + return False + + # otherwise *invited* is good enough + if membership == Membership.INVITE: + return True + + if visibility == HistoryVisibility.INVITED: + # we weren't invited, so we can't see this event. + return False + + if visibility == HistoryVisibility.WORLD_READABLE: + return True + + # So it's SHARED, and the user was not a member at the time. The user cannot + # see history, unless they have *subsequently* joined the room. + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. We just treat it as though they + # never joined, and restrict access. + + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id, event_id + ) + return current_membership == Membership.JOIN + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b5859dcb28..a1d41358d9 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -621,21 +621,32 @@ class SyncHandler: ) async def get_state_after_event( - self, event: EventBase, state_filter: Optional[StateFilter] = None + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the room state after the given event Args: - event: event of interest + event_id: event of interest state_filter: The state filter used to fetch state from the database. """ state_ids = await self._state_storage_controller.get_state_ids_for_event( - event.event_id, state_filter=state_filter or StateFilter.all() + event_id, state_filter=state_filter or StateFilter.all() ) - if event.is_state(): + + # using get_metadata_for_events here (instead of get_event) sidesteps an issue + # with redactions: if `event_id` is a redaction event, and we don't have the + # original (possibly because it got purged), get_event will refuse to return + # the redaction event, which isn't terribly helpful here. + # + # (To be fair, in that case we could assume it's *not* a state event, and + # therefore we don't need to worry about it. But still, it seems cleaner just + # to pull the metadata.) + m = (await self.store.get_metadata_for_events([event_id]))[event_id] + if m.state_key is not None and m.rejection_reason is None: state_ids = dict(state_ids) - state_ids[(event.type, event.state_key)] = event.event_id + state_ids[(m.event_type, m.state_key)] = event_id + return state_ids async def get_state_at( @@ -654,14 +665,14 @@ class SyncHandler: # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. 
- last_event = await self.store.get_last_event_in_room_before_stream_ordering( + last_event_id = await self.store.get_last_event_in_room_before_stream_ordering( room_id, end_token=stream_position.room_key, ) - if last_event: + if last_event_id: state = await self.get_state_after_event( - last_event, state_filter=state_filter or StateFilter.all() + last_event_id, state_filter=state_filter or StateFilter.all() ) else: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a07ad85582..3f2be3854b 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -54,6 +54,7 @@ class EventMetadata: room_id: str event_type: str state_key: Optional[str] + rejection_reason: Optional[str] def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: @@ -167,17 +168,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) sql = f""" - SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e + SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason + FROM events AS e LEFT JOIN state_events se USING (event_id) + LEFT JOIN rejections r USING (event_id) WHERE {clause} """ txn.execute(sql, args) return { event_id: EventMetadata( - room_id=room_id, event_type=event_type, state_key=state_key + room_id=room_id, + event_type=event_type, + state_key=state_key, + rejection_reason=rejection_reason, ) - for event_id, room_id, event_type, state_key in txn + for event_id, room_id, event_type, state_key, rejection_reason in txn } result_map: Dict[str, EventMetadata] = {} diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 0e3a23a140..8e88784d3c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - ) -> Optional[EventBase]: - 
"""Returns the last event in a room at or before a stream ordering + ) -> Optional[str]: + """Returns the ID of the last event in a room at or before a stream ordering Args: room_id end_token: The token used to stream from Returns: - The most recent event. + The ID of the most recent event, or None if there are no events in the room + before this stream ordering. """ last_row = await self.get_room_event_before_stream_ordering( @@ -781,10 +782,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering=end_token.stream, ) if last_row: - _, _, event_id = last_row - event = await self.get_event(event_id, get_prev_content=True) - return event - + return last_row[2] return None async def get_current_room_stream_token_for_room_id( diff --git a/synapse/visibility.py b/synapse/visibility.py index 97548c14e3..8aaa8c709f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,16 +162,7 @@ async def filter_events_for_client( state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. - visibility_event = state.get(_HISTORY_VIS_KEY, None) - if visibility_event: - visibility = visibility_event.content.get( - "history_visibility", HistoryVisibility.SHARED - ) - else: - visibility = HistoryVisibility.SHARED - - if visibility not in VISIBILITY_PRIORITY: - visibility = HistoryVisibility.SHARED + visibility = get_effective_room_visibility_from_state(state) # Always allow history visibility events on boundaries. This is done # by setting the effective visibility to the least restrictive @@ -267,6 +258,23 @@ async def filter_events_for_client( return [ev for ev in filtered_events if ev] +def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: + """Get the actual history vis, from a state map including the history_visibility event + + Handles missing and invalid history visibility events. 
+ """ + visibility_event = state.get(_HISTORY_VIS_KEY, None) + if not visibility_event: + return HistoryVisibility.SHARED + + visibility = visibility_event.content.get( + "history_visibility", HistoryVisibility.SHARED + ) + if visibility not in VISIBILITY_PRIORITY: + visibility = HistoryVisibility.SHARED + return visibility + + async def filter_events_for_server( storage: StorageControllers, server_name: str, From 88193f2125ad2e1dc1c83d6876757cc5eb3c467d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jacek=20Ku=C5=9Bnierz?= Date: Wed, 1 Jun 2022 13:32:35 +0200 Subject: [PATCH 03/14] Remove direct refeferences to PyNaCl (use signedjson instead). (#12902) --- changelog.d/12902.misc | 1 + contrib/cmdclient/console.py | 9 ++++----- poetry.lock | 2 +- pyproject.toml | 1 - tests/crypto/test_event_signing.py | 17 +++++------------ tests/crypto/test_keyring.py | 2 +- 6 files changed, 12 insertions(+), 20 deletions(-) create mode 100644 changelog.d/12902.misc diff --git a/changelog.d/12902.misc b/changelog.d/12902.misc new file mode 100644 index 0000000000..3ee8f92552 --- /dev/null +++ b/changelog.d/12902.misc @@ -0,0 +1 @@ +Remove PyNaCl occurrences directly used in Synapse code. \ No newline at end of file diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 856dd437db..895b2a7af1 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -16,6 +16,7 @@ """ Starts a synapse client console. """ import argparse +import binascii import cmd import getpass import json @@ -26,9 +27,8 @@ import urllib from http import TwistedHttpClient from typing import Optional -import nacl.encoding -import nacl.signing import urlparse +from signedjson.key import NACL_ED25519, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer, reactor, threads @@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"] class SynapseCmd(cmd.Cmd): - """Basic synapse command-line processor. 
This processes commands from the user and calls the relevant HTTP methods. @@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd): pubKey = None pubKeyObj = yield self.http_client.do_request("GET", url) if "public_key" in pubKeyObj: - pubKey = nacl.signing.VerifyKey( - pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder + pubKey = decode_verify_key_bytes( + NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"]) ) else: print("No public key found in pubkey response!") diff --git a/poetry.lock b/poetry.lock index 6b4686545b..7c561e3182 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1563,7 +1563,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "d39d5ac5d51c014581186b7691999b861058b569084c525523baf70b77f292b1" +content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831" [metadata.files] attrs = [ diff --git a/pyproject.toml b/pyproject.toml index 75251c863d..ec6e81f254 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0" canonicaljson = ">=1.4.0" # we use the type definitions added in signedjson 1.1. signedjson = ">=1.1.0" -PyNaCl = ">=1.2.1" # validating SSL certs for IP addresses requires service_identity 18.1. service-identity = ">=18.1.0" # Twisted 18.9 introduces some logger improvements that the structured diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 06e0545a4f..8fa710c9dc 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
- -import nacl.signing -import signedjson.types -from unpaddedbase64 import decode_base64 +from signedjson.key import decode_signing_key_base64 +from signedjson.types import SigningKey from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures @@ -25,7 +23,7 @@ from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. -SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") +SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" KEY_ALG = "ed25519" KEY_VER = "1" @@ -36,14 +34,9 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): def setUp(self): - # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been - # monkeypatched to include new `alg` and `version` attributes. This is captured - # by the `signedjson.types.SigningKey` protocol. - self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment] - SIGNING_KEY_SEED + self.signing_key: SigningKey = decode_signing_key_base64( + KEY_ALG, KEY_VER, SIGNING_KEY_SEED ) - self.signing_key.alg = KEY_ALG - self.signing_key.version = KEY_VER def test_sign_minimal(self): event_dict = { diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d00ef24ca8..820a1a54e2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,8 +19,8 @@ import attr import canonicaljson import signedjson.key import signedjson.sign -from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key +from signedjson.types import SigningKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred From 7bc08f320147a1d80371eb13258328c88073fad0 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Wed, 1 Jun 2022 09:41:25 -0400 Subject: [PATCH 04/14] Remove remaining bits of 
groups code. (#12936) * Update worker docs to remove group endpoints. * Removes an unused parameter to `ApplicationService`. * Break dependency between media repo and groups. * Avoid copying `m.room.related_groups` state events during room upgrades. --- changelog.d/12936.removal | 1 + docs/workers.md | 6 ------ synapse/api/constants.py | 1 - synapse/appservice/__init__.py | 2 -- synapse/config/appservice.py | 1 - synapse/handlers/room.py | 1 - synapse/storage/databases/main/media_repository.py | 4 ---- tests/api/test_auth.py | 2 -- tests/api/test_ratelimiting.py | 2 -- tests/appservice/test_api.py | 1 - tests/appservice/test_appservice.py | 1 - tests/handlers/test_appservice.py | 3 --- tests/handlers/test_user_directory.py | 1 - tests/rest/client/test_account.py | 1 - tests/rest/client/test_login.py | 2 -- tests/rest/client/test_register.py | 2 -- tests/rest/client/test_room_batch.py | 1 - tests/storage/test_user_directory.py | 1 - tests/test_mau.py | 3 --- 19 files changed, 1 insertion(+), 35 deletions(-) create mode 100644 changelog.d/12936.removal diff --git a/changelog.d/12936.removal b/changelog.d/12936.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12936.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/docs/workers.md b/docs/workers.md index 78973a498c..6969c424d8 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,7 +191,6 @@ information. ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ - ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/v1/hierarchy/ @@ -213,9 +212,6 @@ information. 
^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ @@ -255,9 +251,7 @@ information. Additionally, the following REST endpoints can be handled for GET requests: - ^/_matrix/federation/v1/groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ - ^/_matrix/client/(r0|v3|unstable)/groups/ Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. Additionally, care must be taken to diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f03fdd6dae..e1d31cabed 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -95,7 +95,6 @@ class EventTypes: Aliases: Final = "m.room.aliases" Redaction: Final = "m.room.redaction" ThirdPartyInvite: Final = "m.room.third_party_invite" - RelatedGroups: Final = "m.room.related_groups" RoomHistoryVisibility: Final = "m.room.history_visibility" CanonicalAlias: Final = "m.room.canonical_alias" diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index ed92c2e910..0dfa00df44 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -70,7 +70,6 @@ class ApplicationService: def __init__( self, token: str, - hostname: str, id: str, sender: str, url: Optional[str] = None, @@ -88,7 +87,6 @@ class ApplicationService: ) # url must not end with a slash self.hs_token = hs_token self.sender = sender - self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 
24498e7944..16f93273b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -179,7 +179,6 @@ def _load_appservice( return ApplicationService( token=as_info["as_token"], - hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5c91d33f58..e1341dd9bb 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -468,7 +468,6 @@ class RoomCreationHandler: (EventTypes.RoomAvatar, ""), (EventTypes.RoomEncryption, ""), (EventTypes.ServerACL, ""), - (EventTypes.RelatedGroups, ""), (EventTypes.PowerLevels, ""), ] diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 40ac377ca9..deffdc19ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -276,10 +276,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): (SELECT 1 FROM profiles WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) - AND NOT EXISTS - (SELECT 1 - FROM groups - WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) AND NOT EXISTS (SELECT 1 FROM room_memberships diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d547df8a64..bc75ddd3e9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483d5463ad..f661a9ff8e 100644 
--- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=True, sender="@as:example.com", @@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=False, sender="@as:example.com", diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 3e0db4dd98..532b676365 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -37,7 +37,6 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): url=URL, token="unused", hs_token=TOKEN, - hostname="myserver", ) def test_query_3pe_authenticates_token(self): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 7135362f76..3018d3fc6f 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase): sender="@as:test", url="some_url", token="some_token", - hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( event_id="$abc:xyz", diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 0e100c404d..d96d5aa138 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -697,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create an application service appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -776,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Create an appservice that 
is interested in "local_user" appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -843,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): self._service_token = "VERYSECRET" self._service = ApplicationService( self._service_token, - "as1.invalid", "as1", "@as.sender:test", namespaces={ diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index a68c2ffd45..9e39cd97e5 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not match the regex above, so that tests diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e0a11da97b..a43a137273 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 4920468f7a..f4ea1209d9 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.service = ApplicationService( id="unique_identifier", token="some_token", - hostname="example.com", sender="@asbot:example.com", namespaces={ ApplicationService.NS_USERS: [ @@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.another_service = ApplicationService( 
id="another__identifier", token="another_token", - hostname="example.com", sender="@as2bot:example.com", namespaces={ ApplicationService.NS_USERS: [ diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9aebf1735a..afb08b2736 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 1b7ee08ab2..9d5cb60d16 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not have to match the regex above diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f1964eb6a..5b60cf5285 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -134,7 +134,6 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/test_mau.py b/tests/test_mau.py index 
5bbc361aa2..f14fcb7db9 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender:test", namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, @@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_1, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender_1:test", namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, @@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_2, - hostname=self.hs.hostname, id="AnotherASID", sender="@as_sender_2:test", namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]}, From 782cb7420a88fe29241dcecdfee91e25940b2ac7 Mon Sep 17 00:00:00 2001 From: Michael Telatynski <7t3chguy@gmail.com> Date: Wed, 1 Jun 2022 15:57:09 +0100 Subject: [PATCH 05/14] Fix complement tests using the wrong path (#12933) --- .github/workflows/tests.yml | 2 +- changelog.d/12933.misc | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 changelog.d/12933.misc diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 3693cf06c3..83ab727378 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -372,7 +372,7 @@ jobs: # Attempt to check out the same branch of Complement as the PR. If it # doesn't exist, fallback to HEAD. - name: Checkout complement - run: .ci/scripts/checkout_complement.sh + run: synapse/.ci/scripts/checkout_complement.sh - run: | set -o pipefail diff --git a/changelog.d/12933.misc b/changelog.d/12933.misc new file mode 100644 index 0000000000..e29bf02407 --- /dev/null +++ b/changelog.d/12933.misc @@ -0,0 +1 @@ +Test Synapse against Complement with workers. 
From 888a29f4127723a8d048ce47cff37ee8a7a6f1b9 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 1 Jun 2022 16:02:53 +0100 Subject: [PATCH 06/14] Wait for lazy join to complete when getting current state (#12872) --- changelog.d/12872.misc | 1 + synapse/events/third_party_rules.py | 3 +- synapse/federation/federation_server.py | 4 +- synapse/handlers/device.py | 2 +- synapse/handlers/directory.py | 7 +- synapse/handlers/federation.py | 7 +- synapse/handlers/message.py | 2 +- synapse/handlers/presence.py | 6 +- synapse/handlers/register.py | 3 +- synapse/handlers/room.py | 13 +- synapse/handlers/room_list.py | 3 +- synapse/handlers/room_member.py | 5 +- synapse/handlers/room_summary.py | 11 +- synapse/handlers/stats.py | 6 +- synapse/handlers/sync.py | 13 +- synapse/handlers/user_directory.py | 6 +- synapse/module_api/__init__.py | 19 ++- synapse/push/mailer.py | 4 +- synapse/rest/admin/rooms.py | 3 +- synapse/storage/_base.py | 2 +- synapse/storage/controllers/__init__.py | 4 +- synapse/storage/controllers/persist_events.py | 4 +- synapse/storage/controllers/state.py | 112 +++++++++++++++++- synapse/storage/databases/main/room.py | 18 +++ synapse/storage/databases/main/state.py | 38 ++---- .../storage/databases/main/state_deltas.py | 4 +- .../storage/databases/main/user_directory.py | 4 +- .../util/partial_state_events_tracker.py | 60 ++++++++++ tests/handlers/test_federation.py | 6 +- tests/handlers/test_federation_event.py | 4 +- tests/handlers/test_typing.py | 2 +- tests/rest/client/test_upgrade_room.py | 8 +- .../util/test_partial_state_events_tracker.py | 59 ++++++++- 33 files changed, 361 insertions(+), 82 deletions(-) create mode 100644 changelog.d/12872.misc diff --git a/changelog.d/12872.misc b/changelog.d/12872.misc new file mode 100644 index 0000000000..f60a756f21 --- /dev/null +++ b/changelog.d/12872.misc @@ -0,0 +1 @@ +Faster room joins: when querying the current state of the room, wait for state to be populated. 
diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9f4ff9799c..35f3f3690f 100644 --- a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -152,6 +152,7 @@ class ThirdPartyEventRules: self.third_party_rules = None self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -463,7 +464,7 @@ class ThirdPartyEventRules: Returns: A dict mapping (event type, state key) to state event. """ - state_ids = await self.store.get_filtered_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) state_events = {} diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 12591dc8db..f4af121c4d 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,8 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._state_storage_controller = hs.get_storage_controllers().state + self.device_handler = hs.get_device_handler() # Ensure the following handlers are loaded since they register callbacks @@ -1221,7 +1223,7 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 72faf2ee38..a0cbeedc30 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -166,7 +166,7 @@ class 
DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4aa33df884..44e84698c4 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -45,6 +45,7 @@ class DirectoryHandler: self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases @@ -463,7 +464,11 @@ class DirectoryHandler: making_public = visibility == "public" if making_public: room_aliases = await self.store.get_aliases_for_room(room_id) - canonical_alias = await self.store.get_canonical_alias_for_room(room_id) + canonical_alias = ( + await self._storage_controllers.state.get_canonical_alias_for_room( + room_id + ) + ) if canonical_alias: room_aliases.append(canonical_alias) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 659f279441..b212ee2172 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -750,7 +750,9 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. 
if room_version.msc3083_join_rules: - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id + ) if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): @@ -1552,6 +1554,9 @@ class FederationHandler: success = await self.store.clear_partial_state_room(room_id) if success: logger.info("State resync complete for %s", room_id) + self._storage_controllers.state.notify_room_un_partial_stated( + room_id + ) # TODO(faster_joins) update room stats and user directory? return diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index ac911a2ddc..081625f0bd 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -217,7 +217,7 @@ class MessageHandler: ) if membership == Membership.JOIN: - state_ids = await self.store.get_filtered_current_state_ids( + state_ids = await self._state_storage_controller.get_current_state_ids( room_id, state_filter=state_filter ) room_state = await self.store.get_events(state_ids.values()) diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index bf112b9e1e..895ea63ed3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 
05bb1e0225..338204287f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -87,6 +87,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -528,7 +529,7 @@ class RegistrationHandler: if requires_invite: # If the server is in the room, check if the room is public. - state = await self.store.get_filtered_current_state_ids( + state = await self._storage_controllers.state.get_current_state_ids( room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e1341dd9bb..e2b0e519d4 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -107,6 +107,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -480,8 +481,10 @@ class RoomCreationHandler: if room_type == RoomTypes.SPACE: types_to_copy.append((EventTypes.SpaceChild, None)) - old_room_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy) + old_room_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types(types_to_copy) + ) ) # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) @@ -558,8 +561,10 @@ class RoomCreationHandler: ) # Transfer membership events - old_room_member_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + old_room_member_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, 
StateFilter.from_types([(EventTypes.Member, None)]) + ) ) # map from event_id to BaseEvent diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index f3577b5d5a..183d4ae3c4 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ @@ -274,7 +275,7 @@ class RoomListHandler: if aliases: result["aliases"] = aliases - current_state_ids = await self.store.get_current_state_ids( + current_state_ids = await self._storage_controllers.state.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00662dc961..70c674ff8e 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. room_version = await self.store.get_room_version(room_id) - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) # If restricted join rules are not being used, a local join can always # be used. 
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 75aee6a111..13098f56ed 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,6 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -537,7 +538,7 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. """ - state_ids = await self._store.get_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) # If there's no state for the room, it isn't known. if not state_ids: @@ -702,7 +703,9 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) create_event = await self._store.get_event( current_state_ids[(EventTypes.Create, "")] ) @@ -760,7 +763,9 @@ class RoomSummaryHandler: """ # look for child rooms/spaces. 
- current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) events = await self._store.get_events_as_list( [ diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 436cd971ce..f45e06eb0e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -40,6 +40,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() @@ -105,7 +106,10 @@ class StatsHandler: logger.debug( "Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index a1d41358d9..b4ead79f97 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -506,8 +506,10 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + current_state_ids_map = ( + await self._state_storage_controller.get_current_state_ids( + room_id + ) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -574,8 +576,11 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Is this the correct way of doing it? 
+ current_state_ids_map = ( + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 74f7fdfe6c..8c3c52e1ca 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler): super().__init__(hs) self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b7451fc870..a8ad575fcd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,6 +194,7 @@ class ModuleApi: self._store: Union[ DataStore, "GenericWorkerSlavedStore" ] = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -911,7 +912,7 @@ class ModuleApi: The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( - self._store.get_filtered_current_state_ids( + self._storage_controllers.state.get_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) ) @@ -1289,20 +1290,16 @@ class ModuleApi: # regardless of their state key ] """ + state_filter = None if event_filter: # If a filter was provided, turn it into a StateFilter and retrieve a filtered # view of the state. 
state_filter = StateFilter.from_types(event_filter) - state_ids = await self._store.get_filtered_current_state_ids( - room_id, - state_filter, - ) - else: - # If no filter was provided, get the whole state. We could also reuse the call - # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`, - # but get_filtered_current_state_ids isn't cached and `get_current_state_ids` - # is, so using the latter when we can is better for perf. - state_ids = await self._store.get_current_state_ids(room_id) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + ) state_events = await self._store.get_events(state_ids.values()) diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 63aefd07f5..015c19b2d9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -255,7 +255,9 @@ class Mailer: user_display_name = user_id async def _fetch_room_state(room_id: str) -> None: - room_state = await self.store.get_current_state_ids(room_id) + room_state = await self._state_storage_controller.get_current_state_ids( + room_id + ) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 356d6f74d7..1cacd1a4f0 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -418,6 +418,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self._event_serializer = hs.get_event_client_serializer() @@ -430,7 +431,7 @@ class RoomStateRestServlet(RestServlet): if not ret: raise NotFoundError("Room not found") - event_ids = await self.store.get_current_state_ids(room_id) + event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) events = await 
self.store.get_events(event_ids.values()) now = self.clock.time_msec() room_state = self._event_serializer.serialize_events(events.values(), now) diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8df80664a2..57bd74700e 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -77,7 +77,7 @@ class SQLBaseStore(metaclass=ABCMeta): # Purge other caches based on room state. self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) - self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 992261d07b..55649719f6 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import ( EventsPersistenceStorageController, ) from synapse.storage.controllers.purge_events import PurgeEventsStorageController -from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -39,7 +39,7 @@ class StorageControllers: self.main = stores.main self.purge_events = PurgeEventsStorageController(hs, stores) - self.state = StateGroupStorageController(hs, stores) + self.state = StateStorageController(hs, stores) self.persistence = None if stores.persist_events: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ef8c135b12..4caaa81808 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -994,7 +994,7 @@ class EventsPersistenceStorageController: Assumes that we are only 
persisting events for one room at a time. """ - existing_state = await self.main_store.get_current_state_ids(room_id) + existing_state = await self.main_store.get_partial_current_state_ids(room_id) to_delete = [key for key in existing_state if key not in current_state] @@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController: # The server will leave the room, so we go and find out which remote # users will still be joined when we leave. if current_state is None: - current_state = await self.main_store.get_current_state_ids(room_id) + current_state = await self.main_store.get_partial_current_state_ids(room_id) current_state = dict(current_state) for key in delta.to_delete: current_state.pop(key, None) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0f09953086..9952b00493 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,7 +14,9 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, + Callable, Collection, Dict, Iterable, @@ -24,9 +26,13 @@ from typing import ( Tuple, ) +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.storage.state import StateFilter -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: @@ -36,17 +42,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class StateGroupStorageController: - """High level interface to fetching state for event.""" +class StateStorageController: + """High level interface to fetching state for an event, or the current state + in a room. 
+ """ def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) + def notify_room_un_partial_stated(self, room_id: str) -> None: + """Notify that the room no longer has any partial state. + + Must be called after `DataStore.clear_partial_state_room` + """ + self._partial_state_room_tracker.notify_un_partial_stated(room_id) + async def get_state_group_delta( self, state_group: int ) -> Tuple[Optional[int], Optional[StateMap[str]]]: @@ -349,3 +365,93 @@ class StateGroupStorageController: return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) + + async def get_current_state_ids( + self, + room_id: str, + state_filter: Optional[StateFilter] = None, + on_invalidate: Optional[Callable[[], None]] = None, + ) -> StateMap[str]: + """Get the current state event ids for a room based on the + current_state_events table. + + If a state filter is given (that is not `StateFilter.all()`) the query + result is *not* cached. + + Args: + room_id: The room to get the state IDs of. state_filter: The state + filter used to fetch state from the + database. + on_invalidate: Callback for when the `get_current_state_ids` cache + for the room gets invalidated. + + Returns: + The current state of the room. 
+ """ + if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + await self._partial_state_room_tracker.await_full_state(room_id) + + if state_filter and not state_filter.is_full(): + return await self.stores.main.get_partial_filtered_current_state_ids( + room_id, state_filter + ) + else: + return await self.stores.main.get_partial_current_state_ids( + room_id, on_invalidate=on_invalidate + ) + + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + """Get canonical alias for room, if any + + Args: + room_id: The room ID + + Returns: + The canonical alias, if any + """ + + state = await self.get_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return None + + event = await self.stores.main.get_event(event_id, allow_none=True) + if not event: + return None + + return event.content.get("canonical_alias") + + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + # FIXME(faster_joins): what do we do here? 
+ + return await self.stores.main.get_partial_current_state_deltas( + prev_stream_id, max_stream_id + ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cfd8ce1624..68d4fc2e64 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) + async def is_partial_state_room(self, room_id: str) -> bool: + """Checks if this room has partial state. + + Returns true if this is a "partial-state" room, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. + """ + + entry = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + desc="is_partial_state_room", + ) + + return entry is not None + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index 3f2be3854b..bdd00273cd 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -242,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: NotFoundError if the room is unknown """ - state_ids = await self.get_current_state_ids(room_id) + state_ids = await self.get_partial_current_state_ids(room_id) if not state_ids: raise NotFoundError(f"Current state for room {room_id} is empty") @@ -258,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - async def get_current_state_ids(self, room_id: str) -> StateMap[str]: + async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. 
+ This may be the partial state if we're lazy joining the room. + Args: room_id: The room to get the state IDs of. @@ -280,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return await self.db_pool.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn + "get_partial_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - async def get_filtered_current_state_ids( + async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + This may be the partial state if we're lazy joining the room. + Args: room_id state_filter: The state filter used to fetch state @@ -306,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not where_clause: # We delegate to the cached version - return await self.get_current_state_ids(room_id) + return await self.get_partial_current_state_ids(room_id) def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, @@ -334,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return None - - event = await self.get_event(event_id, allow_none=True) - if not event: - return None - - return 
event.content.get("canonical_alias") - @cached(max_entries=50000) async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 188afec332..445213e12a 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? _curr_state_delta_stream_cache: StreamChangeCache - async def get_current_state_deltas( + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id @@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore): - prev_event_id (str|None): previous event_id for this state key. None if it's new state. + This may be the partial state if we're lazy joining the room. + Args: prev_stream_id: point to get changes since (exclusive) max_stream_id: the point that we know has been correctly persisted diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 2282242e9d..ddb25b5cea 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] + # Getting the partial state is fine, as we're not looking at membership + # events. 
+ current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index a61a951ef0..211437cfaa 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -118,3 +119,62 @@ class PartialStateEventsTracker: observer_set.discard(observer) if not observer_set: del self._observers[event_id] + + +class PartialCurrentStateTracker: + """Keeps track of which rooms have partial state, after partial-state joins""" + + def __init__(self, store: RoomWorkerStore): + self._store = store + + # a map from room id to a set of Deferreds which are waiting for that room to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, room_id: str) -> None: + """Notify that we now have full current state for a given room + + Unblocks any callers to await_full_state() for that room. + + Args: + room_id: the room that now has full current state. 
+ """ + observers = self._observers.pop(room_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of room %s", + len(observers), + room_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, room_id: str) -> None: + # We add the deferred immediately so that the DB call to check for + # partial state doesn't race when we unpartial the room. + d: Deferred[None] = Deferred() + self._observers.setdefault(room_id, set()).add(d) + + try: + # Check if the room has partial current state or not. + has_partial_state = await self._store.is_partial_state_room(room_id) + if not has_partial_state: + return + + logger.info( + "Awaiting un-partial-stating of room %s", + room_id, + ) + + await make_deferred_yieldable(d) + + logger.info("Room has un-partial-stated") + finally: + # Remove the added observer, and remove the room entry if its empty. + ds = self._observers.get(room_id) + if ds is not None: + ds.discard(d) + if not ds: + self._observers.pop(room_id, None) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 500c9ccfbc..e0eda545b9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) current_state = self.get_success( self.store.get_events_as_list( - (self.get_success(self.store.get_current_state_ids(room_id))).values() + ( + self.get_success(self.store.get_partial_current_state_ids(room_id)) + ).values() ) ) @@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): self.get_success(d) # sanity-check: the room should show that the new user is a member - r = self.get_success(self.store.get_current_state_ids(room_id)) + r = self.get_success(self.store.get_partial_current_state_ids(room_id)) self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return 
join_event diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1d5b2492c0..1a36c25c41 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") ) - initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) auth_event_ids = [ initial_state_map[("m.room.create", "")], diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 057256cecd..14a0ee4922 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -146,7 +146,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = ( diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a21cbe9fa8..98c1039d33 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_space_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_space_id) + ) # Ensure the new room is still a space. 
create_event = self.get_success( @@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_room_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_room_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_room_id) + ) # Ensure the new room is the same type as the old room. create_event = self.get_success( diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b6c..cae14151c0 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) + + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated 
should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_partial_state_room(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_partial_state_room.side_effect = is_partial_state_room + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2) From 01df5bacac3aa0e8356fed889ea0b69c4c044535 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Jun 2022 12:09:12 -0400 Subject: [PATCH 07/14] Improve URL previews for some pages (#12951) * Skip `og` and `meta` tags where the value is empty. * Fallback to the favicon if there are no other images. * Ignore tags meant for navigation. --- changelog.d/12951.feature | 1 + synapse/rest/media/v1/preview_html.py | 52 ++++++++++++++++-------- tests/rest/media/v1/test_html_preview.py | 37 ++++++++++++++++- 3 files changed, 72 insertions(+), 18 deletions(-) create mode 100644 changelog.d/12951.feature diff --git a/changelog.d/12951.feature b/changelog.d/12951.feature new file mode 100644 index 0000000000..f885be9fe4 --- /dev/null +++ b/changelog.d/12951.feature @@ -0,0 +1 @@ +Improve URL previews for pages with empty elements. 
diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 13ec7ab533..ed8f21a483 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -30,6 +30,9 @@ _xml_encoding_match = re.compile( ) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +# Certain elements aren't meant for display. +ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + def _normalise_encoding(encoding: str) -> Optional[str]: """Use the Python codec's name as the normalised entry.""" @@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", og: Dict[str, Optional[str]] = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if "content" in tag.attrib: - # if we've got more than 50 tags, someone is taking the piss - if len(og) >= 50: - logger.warning("Skipping OG for page with too many 'og:' tags") - return {} - og[tag.attrib["property"]] = tag.attrib["content"] + for tag in tree.xpath( + "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> if "og:title" not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - if title and title[0].text is not None: - og["og:title"] = title[0].text.strip() + # Attempt to find a title from the title tag, or the biggest header on the page. 
+ title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() else: og["og:title"] = None if "og:image" not in og: - # TODO: extract a favicon failing all else meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" ) + # If a meta image is found, use it. if meta_image: og["og:image"] = meta_image[0] else: + # Try to find images which are larger than 10px by 10px. + # # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") images = sorted( @@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: -1 * float(i.attrib["width"]) * float(i.attrib["height"]) ), ) + # If no images were found, try to find *any* images. if not images: - images = tree.xpath("//img[@src]") + images = tree.xpath("//img[@src][1]") if images: og["og:image"] = images[0].attrib["src"] + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + if "og:description" not in og: + # Check the first meta description tag for content. meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content" + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" ) + # If a meta description is found with content, use it. if meta_description: og["og:description"] = meta_description[0] else: @@ -306,6 +320,10 @@ def _iterate_over_text( if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: + # If the element isn't meant for display, ignore it. 
+ if el.get("role") in ARIA_ROLES_TO_IGNORE: + continue + # el.text is the text before the first child, so we can immediately # return it if the text exists. if el.text: diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 62e308814d..ea9e5889bf 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase): ) -class CalcOgTestCase(unittest.TestCase): +class OpenGraphFromHtmlTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + # Another variant is a title with no content. + html = b""" + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) + def test_h1_as_title(self) -> None: html = b""" @@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + def test_empty_description(self) -> None: + """Description tags with empty content should be ignored.""" + html = b""" + + + + + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) + def test_missing_title_and_broken_h1(self) -> None: html = b""" From 6b46c3eb3d526d903e1e4833b2e8ae9b73de8502 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Fri, 3 Jun 2022 12:13:35 -0400 Subject: [PATCH 08/14] Remove groups code from synapse_port_db. (#12899) --- changelog.d/12899.removal | 1 + synapse/_scripts/synapse_port_db.py | 23 ++++++++++++------- .../storage/databases/main/group_server.py | 9 ++------ 3 files changed, 18 insertions(+), 15 deletions(-) create mode 100644 changelog.d/12899.removal diff --git a/changelog.d/12899.removal b/changelog.d/12899.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12899.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index d7dfa92bd1..4939573f30 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = { "devices": ["hidden"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], - "groups": ["is_public"], - "group_rooms": ["is_public"], - "group_users": ["is_public", "is_admin"], - "group_summary_rooms": ["is_public"], - "group_room_categories": ["is_public"], - "group_summary_users": ["is_public"], - "group_roles": ["is_public"], - "local_group_membership": ["is_publicised", "is_admin"], "e2e_room_keys": ["is_verified"], "account_validity": ["email_sent"], "redactions": ["have_censored"], @@ -175,6 +167,21 @@ IGNORED_TABLES = { "ui_auth_sessions", "ui_auth_sessions_credentials", "ui_auth_sessions_ips", + # Groups/communities is no longer supported. 
+ "group_attestations_remote", + "group_attestations_renewals", + "group_invites", + "group_roles", + "group_room_categories", + "group_rooms", + "group_summary_roles", + "group_summary_room_categories", + "group_summary_rooms", + "group_summary_users", + "group_users", + "groups", + "local_group_membership", + "local_group_updates", } diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index da21a50144..c15a7136b6 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -29,11 +29,6 @@ class GroupServerStore(SQLBaseStore): db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - database.updates.register_background_index_update( - update_name="local_group_updates_index", - index_name="local_group_updates_stream_id_index", - table="local_group_updates", - columns=("stream_id",), - unique=True, - ) + # Register a legacy groups background update as a no-op. + database.updates.register_noop_background_update("local_group_updates_index") super().__init__(database, db_conn, hs) From e3163e2e11cf8bffa4cb3e58ac0b86a83eca314c Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2022 11:24:12 +0300 Subject: [PATCH 09/14] Reduce the amount of state we pull from the DB (#12811) --- changelog.d/12811.misc | 1 + synapse/api/auth.py | 45 ++++++----- synapse/federation/federation_base.py | 1 + synapse/federation/federation_server.py | 12 +-- synapse/handlers/directory.py | 2 +- synapse/handlers/federation.py | 2 +- synapse/handlers/federation_event.py | 18 +++-- synapse/handlers/initial_sync.py | 6 +- synapse/handlers/message.py | 4 +- synapse/handlers/room.py | 5 +- synapse/handlers/room_member.py | 16 +++- synapse/handlers/search.py | 2 +- synapse/notifier.py | 2 +- synapse/rest/admin/rooms.py | 34 ++++++--- synapse/rest/client/room.py | 7 +- .../resource_limits_server_notices.py | 7 +- synapse/state/__init__.py | 75 +------------------ 
synapse/storage/controllers/state.py | 27 +++++++ tests/federation/test_federation_server.py | 6 +- tests/handlers/test_directory.py | 3 +- tests/storage/test_events.py | 17 +++-- tests/storage/test_purge.py | 5 +- tests/storage/test_room.py | 12 ++- 23 files changed, 162 insertions(+), 147 deletions(-) create mode 100644 changelog.d/12811.misc diff --git a/changelog.d/12811.misc b/changelog.d/12811.misc new file mode 100644 index 0000000000..d57e1aca6b --- /dev/null +++ b/changelog.d/12811.misc @@ -0,0 +1 @@ +Reduce the amount of state we pull from the DB. diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 931750668e..5a410f805a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -29,12 +29,11 @@ from synapse.api.errors import ( MissingClientTokenError, ) from synapse.appservice import ApplicationService -from synapse.events import EventBase from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester, StateMap, UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry @@ -61,8 +60,8 @@ class Auth: self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() + self._storage_controllers = hs.get_storage_controllers() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -79,9 +78,8 @@ class Auth: self, room_id: str, user_id: str, - current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ) -> EventBase: + ) -> Tuple[str, Optional[str]]: """Check if the user is in the room, or was at 
some point. Args: room_id: The room to check. @@ -99,29 +97,28 @@ class Auth: Raises: AuthError if the user is/was not in the room. Returns: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + The current membership of the user in the room and the + membership event ID of the user. """ - if current_state: - member = current_state.get((EventTypes.Member, user_id), None) - else: - member = await self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) - if member: - membership = member.membership + ( + membership, + member_event_id, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + if membership: if membership == Membership.JOIN: - return member + return membership, member_event_id # XXX this looks totally bogus. Why do we not allow users who have been banned, # or those who were members previously and have been re-invited? if allow_departed_users and membership == Membership.LEAVE: forgot = await self.store.did_forget(user_id, room_id) if not forgot: - return member + return membership, member_event_id raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) @@ -602,8 +599,11 @@ class Auth: # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = await self.state.get_current_state( - room_id, EventTypes.PowerLevels, "" + + power_level_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.PowerLevels, "" + ) ) auth_events = {} @@ -693,12 +693,11 @@ class Auth: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. 
- member_event = await self.check_user_in_room( + return await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) - return member_event.membership, member_event.event_id except AuthError: - visibility = await self.state.get_current_state( + visibility = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a6232e048b..2522bf78fc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -53,6 +53,7 @@ class FederationBase: self.spam_checker = hs.get_spam_checker() self.store = hs.get_datastores().main self._clock = hs.get_clock() + self._storage_controllers = hs.get_storage_controllers() async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index f4af121c4d..3e1518f1f6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -1223,14 +1223,10 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await self._state_storage_controller.get_current_state_ids(room_id) - acl_event_id = state_ids.get((EventTypes.ServerACL, "")) - - if not acl_event_id: - return - - acl_event = await self.store.get_event(acl_event_id) - if server_matches_acl_event(server_name, acl_event): + acl_event = await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.ServerACL, "" + ) + if not acl_event or server_matches_acl_event(server_name, acl_event): return raise AuthError(code=403, msg="Server is banned from room") diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 44e84698c4..1459a046de 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -320,7 +320,7 
@@ class DirectoryHandler: Raises: ShadowBanError if the requester has been shadow-banned. """ - alias_event = await self.state.get_current_state( + alias_event = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.CanonicalAlias, "" ) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index b212ee2172..6a143440d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -371,7 +371,7 @@ class FederationHandler: # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = await self.state_handler.get_current_state(room_id) + curr_state = await self._storage_controllers.state.get_current_state(room_id) curr_domains = get_domains_from_state(curr_state) diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 549b066dd9..87a0608359 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1584,9 +1584,11 @@ class FederationEventHandler: if guest_access == GuestAccess.CAN_JOIN: return - current_state_map = await self._state_handler.get_current_state(event.room_id) - current_state = list(current_state_map.values()) - await self._get_room_member_handler().kick_guest_users(current_state) + current_state = await self._storage_controllers.state.get_current_state( + event.room_id + ) + current_state_list = list(current_state.values()) + await self._get_room_member_handler().kick_guest_users(current_state_list) async def _check_for_soft_fail( self, @@ -1614,6 +1616,9 @@ class FederationEventHandler: room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + # The event types we want to pull from the "current" state. + auth_types = auth_types_for_event(room_version_obj, event) + # Calculate the "current state". 
if state_ids is not None: # If we're explicitly given the state then we won't have all the @@ -1643,8 +1648,10 @@ class FederationEventHandler: ) ) else: - current_state_ids = await self._state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids + current_state_ids = ( + await self._state_storage_controller.get_current_state_ids( + event.room_id, StateFilter.from_types(auth_types) + ) ) logger.debug( @@ -1654,7 +1661,6 @@ class FederationEventHandler: ) # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) current_state_ids_list = [ e for k, e in current_state_ids.items() if k in auth_types ] diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d2b489e816..85b472f250 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -190,7 +190,7 @@ class InitialSyncHandler: if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, event.room_id + self._state_storage_controller.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = RoomStreamToken( @@ -407,7 +407,9 @@ class InitialSyncHandler: membership: str, is_peeking: bool, ) -> JsonDict: - current_state = await self.state.get_current_state(room_id=room_id) + current_state = await self._storage_controllers.state.get_current_state( + room_id=room_id + ) # TODO: These concurrently time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 081625f0bd..f455158a2c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -125,7 +125,9 @@ class MessageHandler: ) if membership == Membership.JOIN: - data = await self.state.get_current_state(room_id, event_type, state_key) + data = await self._storage_controllers.state.get_current_state_event( + room_id, event_type, 
state_key + ) elif membership == Membership.LEAVE: key = (event_type, state_key) # If the membership is not JOIN, then the event ID should exist. diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index e2b0e519d4..520663f172 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -1333,6 +1333,7 @@ class TimestampLookupHandler: self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() + self._storage_controllers = hs.get_storage_controllers() async def get_event_for_timestamp( self, @@ -1406,7 +1407,9 @@ class TimestampLookupHandler: ) # Find other homeservers from the given state in the room - curr_state = await self.state_handler.get_current_state(room_id) + curr_state = await self._storage_controllers.state.get_current_state( + room_id + ) curr_domains = get_domains_from_state(curr_state) likely_domains = [ domain for domain, depth in curr_domains if domain != self.server_name diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 70c674ff8e..d1199a0644 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -1401,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: - room_state = await self.state_handler.get_current_state(room_id) + room_state = await self._storage_controllers.state.get_current_state( + room_id, + StateFilter.from_types( + [ + (EventTypes.Member, user.to_string()), + (EventTypes.CanonicalAlias, ""), + (EventTypes.Name, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + (EventTypes.RoomAvatar, ""), + ] + ), + ) inviter_display_name = "" inviter_avatar_url = "" @@ -1797,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() - member = await self.state_handler.get_current_state( + member = await 
self._storage_controllers.state.get_current_state_event( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 659f99f7e2..bcab98c6d5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -348,7 +348,7 @@ class SearchHandler: state_results = {} if include_state: for room_id in {e.room_id for e in search_result.allowed_events}: - state = await self.state_handler.get_current_state(room_id) + state = await self._storage_controllers.state.get_current_state(room_id) state_results[room_id] = list(state.values()) aggregations = await self._relations_handler.get_bundled_aggregations( diff --git a/synapse/notifier.py b/synapse/notifier.py index 1100434b3f..54b0ec4b97 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -681,7 +681,7 @@ class Notifier: return joined_room_ids, True async def _is_world_readable(self, room_id: str) -> bool: - state = await self.state_handler.get_current_state( + state = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 1cacd1a4f0..9d953d58de 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,6 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder +from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder @@ -448,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): super().__init__(hs) self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.state_handler = hs.get_state_handler() + self.store = hs.get_datastores().main + self._storage_controllers = 
hs.get_storage_controllers() self.is_mine = hs.is_mine async def on_POST( @@ -490,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ) # send invite if room has "JoinRules.INVITE" - room_state = await self.state_handler.get_current_state(room_id) - join_rules_event = room_state.get((EventTypes.JoinRules, "")) + join_rules_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.JoinRules, "" + ) + ) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): # update_membership with an action of "invite" can raise a @@ -536,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): super().__init__(hs) self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -553,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): user_to_add = content.get("user_id", requester.user.to_string()) # Figure out which local users currently have power in the room, if any. - room_state = await self.state_handler.get_current_state(room_id) - if not room_state: + filtered_room_state = await self._state_storage_controller.get_current_state( + room_id, + StateFilter.from_types( + [ + (EventTypes.Create, ""), + (EventTypes.PowerLevels, ""), + (EventTypes.JoinRules, ""), + (EventTypes.Member, user_to_add), + ] + ), + ) + if not filtered_room_state: raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") - create_event = room_state[(EventTypes.Create, "")] - power_levels = room_state.get((EventTypes.PowerLevels, "")) + create_event = filtered_room_state[(EventTypes.Create, "")] + power_levels = filtered_room_state.get((EventTypes.PowerLevels, "")) if power_levels is not None: # We pick the local user with the highest power. 
@@ -634,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Now we check if the user we're granting admin rights to is already in # the room. If not and it's not a public room we invite them. - member_event = room_state.get((EventTypes.Member, user_to_add)) + member_event = filtered_room_state.get((EventTypes.Member, user_to_add)) is_joined = False if member_event: is_joined = member_event.content["membership"] in ( @@ -645,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if is_joined: return HTTPStatus.OK, {} - join_rules = room_state.get((EventTypes.JoinRules, "")) + join_rules = filtered_room_state.get((EventTypes.JoinRules, "")) is_public = False if join_rules: is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 7a5ce8ad0e..a26e976492 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet): self.clock = hs.get_clock() self._store = hs.get_datastores().main self._state = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() @@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet): if include_unredacted_content and not await self.auth.is_server_admin( requester.user ): - power_level_event = await self._state.get_current_state( - room_id, EventTypes.PowerLevels, "" + power_level_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.PowerLevels, "" + ) ) auth_events = {} diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index b5f3a0c74e..6863020778 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ 
b/synapse/server_notices/resource_limits_server_notices.py @@ -36,6 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False @@ -178,8 +179,10 @@ class ResourceLimitsServerNotices: currently_blocked = False pinned_state_event = None try: - pinned_state_event = await self._state.get_current_state( - room_id, event_type=EventTypes.Pinned + pinned_state_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, event_type=EventTypes.Pinned, state_key="" + ) ) except AuthError: # The user has yet to join the server notices room diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index bf09f5128a..ab68e2b6a4 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,13 +32,11 @@ from typing import ( Set, Tuple, Union, - overload, ) import attr from frozendict import frozendict from prometheus_client import Counter, Histogram -from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -132,85 +130,20 @@ class StateHandler: self._state_resolution_handler = hs.get_state_resolution_handler() self._storage_controllers = hs.get_storage_controllers() - @overload - async def get_current_state( - self, - room_id: str, - event_type: Literal[None] = None, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> StateMap[EventBase]: - ... - - @overload - async def get_current_state( - self, - room_id: str, - event_type: str, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> Optional[EventBase]: - ... 
- - async def get_current_state( - self, - room_id: str, - event_type: Optional[str] = None, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> Union[Optional[EventBase], StateMap[EventBase]]: - """Retrieves the current state for the room. This is done by - calling `get_latest_events_in_room` to get the leading edges of the - event graph and then resolving any of the state conflicts. - - This is equivalent to getting the state of an event that were to send - next before receiving any new events. - - Returns: - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - - Otherwise, a map from (type, state_key) to event. - """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) - assert latest_event_ids is not None - - logger.debug("calling resolve_state_groups from get_current_state") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state - - if event_type: - event_id = state.get((event_type, state_key)) - event = None - if event_id: - event = await self.store.get_event(event_id, allow_none=True) - return event - - state_map = await self.store.get_events( - list(state.values()), get_prev_content=False - ) - return { - key: state_map[e_id] for key, e_id in state.items() if e_id in state_map - } - async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Collection[str]] = None + self, + room_id: str, + latest_event_ids: Collection[str], ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: room_id: - latest_event_ids: if given, the forward extremities to resolve. If - None, we look them up from the database (via a cache). + latest_event_ids: The forward extremities to resolve. 
Returns: the state dict, mapping from (event_type, state_key) -> event_id """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) - assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return ret.state diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 9952b00493..63a78ebc87 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -455,3 +455,30 @@ class StateStorageController: return await self.stores.main.get_partial_current_state_deltas( prev_stream_id, max_stream_id ) + + async def get_current_state( + self, room_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """Same as `get_current_state_ids` but also fetches the events""" + state_map_ids = await self.get_current_state_ids(room_id, state_filter) + + event_map = await self.stores.main.get_events(list(state_map_ids.values())) + + state_map = {} + for key, event_id in state_map_ids.items(): + event = event_map.get(event_id) + if event: + state_map[key] = event + + return state_map + + async def get_current_state_event( + self, room_id: str, event_type: str, state_key: str + ) -> Optional[EventBase]: + """Get the current state event for the given type/state_key.""" + + key = (event_type, state_key) + state_map = await self.get_current_state( + room_id, StateFilter.from_types((key,)) + ) + return state_map.get(key) diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b19365b81a..413b3c9426 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): 
super().prepare(reactor, clock, hs) + self._storage_controllers = hs.get_storage_controllers() + # create the room creator_user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") @@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 11ad44223d..53d49ca896 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() # Create user self.admin_user = self.register_user("admin", "pass", admin=True) @@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def _get_canonical_alias(self): """Get the canonical alias state of the room.""" return self.get_success( - self.state_handler.get_current_state( + self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index a76718e8f9..2ff88e64a5 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,6 
+32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence + self._state_storage_controller = self.hs.get_storage_controllers().state self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. state_before_gap = dict( - self.get_success(self.state.get_current_state_ids(self.room_id)) + self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) ) state_before_gap.pop(("m.room.history_visibility", "")) @@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -299,7 +302,7 @@ class 
ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 92cd0dfc05..8dfaa0559b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase): first = self.helper.send(self.room_id, body="test1") # Get the current room state. - state_handler = self.hs.get_state_handler() create_event = self.get_success( - state_handler.get_current_state(self.room_id, "m.room.create", "") + self._storage_controllers.state.get_current_state_event( + self.room_id, "m.room.create", "" + ) ) self.assertIsNotNone(create_event) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d497a19f63..3c79dabc9f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self._storage = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self._storage.persistence.persist_event( + self._storage_controllers.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) 
) @@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) @@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) From fcd8703508ce5bfe481fc2f1510b05731477ce32 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Christian=20Gr=C3=BCnhage?= Date: Mon, 6 Jun 2022 13:10:13 +0200 Subject: [PATCH 10/14] Allow updating passwords using the admin api without logging out devices (#12952) --- changelog.d/12952.feature | 1 + docs/admin_api/user_admin_api.md | 4 +++- synapse/rest/admin/users.py | 8 +++++++- 3 files changed, 11 insertions(+), 2 deletions(-) create mode 100644 changelog.d/12952.feature diff --git a/changelog.d/12952.feature b/changelog.d/12952.feature new file mode 100644 index 0000000000..7329bcc3d4 --- /dev/null +++ b/changelog.d/12952.feature @@ -0,0 +1 @@ +Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage. diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c8794299e7..62f89e8cba 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -115,7 +115,9 @@ URL parameters: Body parameters: - `password` - string, optional. If provided, the user's password is updated and all - devices are logged out. + devices are logged out, unless `logout_devices` is set to `false`. +- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't + logged out even when `password` is provided. - `displayname` - string, optional, defaults to the value of `user_id`. 
- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8e29ada8a0..f0614a2897 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet): if not isinstance(password, str) or len(password) > 512: raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") + logout_devices = body.get("logout_devices", True) + if not isinstance(logout_devices, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'logout_devices' parameter is not of type boolean", + ) + deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): raise SynapseError( @@ -305,7 +312,6 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if password is not None: - logout_devices = True new_password_hash = await self.auth_handler.hash(password) await self.set_password_handler.set_password( From 1acc897c317f2ed66c28a0cc27b6c584b8afdd6a Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Jun 2022 07:18:04 -0400 Subject: [PATCH 11/14] Implement MSC3816, consider the root event for thread participation. (#12766) As opposed to only considering a user to have "participated" if they replied to the thread. --- changelog.d/12766.bugfix | 1 + synapse/handlers/relations.py | 58 +++++++++++++------- tests/rest/client/test_relations.py | 85 ++++++++++++++++++++--------- 3 files changed, 97 insertions(+), 47 deletions(-) create mode 100644 changelog.d/12766.bugfix diff --git a/changelog.d/12766.bugfix b/changelog.d/12766.bugfix new file mode 100644 index 0000000000..912c3deb70 --- /dev/null +++ b/changelog.d/12766.bugfix @@ -0,0 +1 @@ +Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it. 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 9a1cc11bb3..0b63cd2186 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple import attr @@ -256,13 +247,19 @@ class RelationsHandler: return filtered_results - async def get_threads_for_events( - self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str] + async def _get_threads_for_events( + self, + events_by_id: Dict[str, EventBase], + relations_by_id: Dict[str, str], + user_id: str, + ignored_users: FrozenSet[str], ) -> Dict[str, _ThreadAggregation]: """Get the bundled aggregations for threads for the requested events. Args: - event_ids: Events to get aggregations for threads. + events_by_id: A map of event_id to events to get aggregations for threads. + relations_by_id: A map of event_id to the relation type, if one exists + for that event. user_id: The user requesting the bundled aggregations. ignored_users: The users ignored by the requesting user. @@ -273,16 +270,34 @@ class RelationsHandler: """ user = UserID.from_string(user_id) + # It is not valid to start a thread on an event which itself relates to another event. + event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] + # Fetch thread summaries. summaries = await self._main_store.get_thread_summaries(event_ids) - # Only fetch participated for a limited selection based on what had - # summaries. + # Limit fetching whether the requester has participated in a thread to + # events which are thread roots. 
thread_event_ids = [ event_id for event_id, summary in summaries.items() if summary ] - participated = await self._main_store.get_threads_participated( - thread_event_ids, user_id + + # Pre-seed thread participation with whether the requester sent the event. + participated = { + event_id: events_by_id[event_id].sender == user_id + for event_id in thread_event_ids + } + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [ + event_id + for event_id in thread_event_ids + if not participated[event_id] + ], + user_id, + ) ) # Then subtract off the results for any ignored users. @@ -343,7 +358,8 @@ class RelationsHandler: count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. - current_user_participated=participated[event_id], + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], ) return results @@ -401,9 +417,9 @@ class RelationsHandler: # events to be fetched. Thus, we check those first! # Fetch thread summaries (but only for the directly requested events). - threads = await self.get_threads_for_events( - # It is not valid to start a thread on an event which itself relates to another event. 
- [eid for eid in events_by_id.keys() if eid not in relations_by_id], + threads = await self._get_threads_for_events( + events_by_id, + relations_by_id, user_id, ignored_users, ) diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py index bc9cc51b92..62e4db23ef 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): relation_type: str, assertion_callable: Callable[[JsonDict], None], expected_db_txn_for_event: int, + access_token: Optional[str] = None, ) -> None: """ Makes requests to various endpoints which should include bundled aggregations @@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): for relation-specific assertions. expected_db_txn_for_event: The number of database transactions which are expected for a call to /event/. + access_token: The access token to use, defaults to self.user_token. """ + access_token = access_token or self.user_token def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" @@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) @@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) @@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", - 
access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) @@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Request sync. filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", f"/sync?filter={filter}", access_token=access_token ) self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] @@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "/search", # Search term matches the parent message. content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) chunk = [ @@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): """ Test that threads get correctly bundled. """ - self._send_relation(RelationTypes.THREAD, "m.room.test") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + # The root message is from "user", send replies as "user2". + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) thread_2 = channel.json_body["event_id"] - def assert_thread(bundled_aggregations: JsonDict) -> None: - self.assertEqual(2, bundled_aggregations.get("count")) - self.assertTrue(bundled_aggregations.get("current_user_participated")) - # The latest thread event has some fields that don't matter. 
- self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } + # This needs two assertion functions which are identical except for whether + # the current_user_participated flag is True, create a factory for the + # two versions. + def _gen_assert(participated: bool) -> Callable[[JsonDict], None]: + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertEqual( + participated, bundled_aggregations.get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user2_id, + "type": "m.room.test", }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - bundled_aggregations.get("latest_event"), - ) + bundled_aggregations.get("latest_event"), + ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + return assert_thread + + # The "user" sent the root event and is making queries for the bundled + # aggregations: they have participated. + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + # The "user2" sent replies in the thread and is making queries for the + # bundled aggregations: they have participated. + # + # Note that this re-uses some cached values, so the total number of + # queries is much smaller. + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + ) + + # A user with no interactions with the thread: they have not participated. 
+ user3_id, user3_token = self._create_user("charlie") + self.helper.join(self.room, user=user3_id, tok=user3_token) + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: """ @@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ From 148fe58a247d61ffb76c566ba397285480d93f74 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Jun 2022 07:46:04 -0400 Subject: [PATCH 12/14] Do not break URL previews if an image is unreachable. (#12950) Avoid breaking a URL preview completely if the chosen image 404s or is unreachable for some other reason (e.g. DNS). --- changelog.d/12950.bugfix | 1 + synapse/rest/media/v1/preview_url_resource.py | 23 ++++++++---- tests/rest/media/v1/test_url_preview.py | 35 +++++++++++++++++++ 3 files changed, 53 insertions(+), 6 deletions(-) create mode 100644 changelog.d/12950.bugfix diff --git a/changelog.d/12950.bugfix b/changelog.d/12950.bugfix new file mode 100644 index 0000000000..e835d9aa72 --- /dev/null +++ b/changelog.d/12950.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a URL preview would break if the image failed to download. diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 2b2db63bf7..54a849eac9 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource): og: The Open Graph dictionary. This is modified with image information. """ # If there's no image or it is blank, there's nothing to do. 
- if "og:image" not in og or not og["og:image"]: + if "og:image" not in og: + return + + # Remove the raw image URL, this will be replaced with an MXC URL, if successful. + image_url = og.pop("og:image") + if not image_url: return # The image URL from the HTML might be relative to the previewed page, # convert it to an URL which can be requested directly. - image_url = og["og:image"] url_parts = urlparse(image_url) if url_parts.scheme != "data": image_url = urljoin(media_info.uri, image_url) @@ -599,7 +603,16 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._handle_url(image_url, user, allow_data_urls=True) + try: + image_info = await self._handle_url(image_url, user, allow_data_urls=True) + except Exception as e: + # Pre-caching the image failed, don't block the entire URL preview. 
+ logger.warning( + "Pre-caching image failed during URL preview: %s errored with %s", + image_url, + e, + ) + return if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images @@ -611,13 +624,11 @@ class PreviewUrlResource(DirectServeJsonResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warning("Couldn't get dims for %s", og["og:image"]) + logger.warning("Couldn't get dims for %s", image_url) og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}" og["og:image:type"] = image_info.media_type og["matrix:image:size"] = image_info.media_length - else: - del og["og:image"] async def _handle_oembed_response( self, url: str, media_info: MediaInfo, expiration_ms: int diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 3b24d0ace6..2c321f8d04 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_nonexistent_image(self) -> None: + """If the preview image doesn't exist, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + end_content = ( + b"""""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. 
+ self.assertNotIn("og:image", channel.json_body) + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. From 44de53bb79f961147386ea2a8bfbeb54b007cd41 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 6 Jun 2022 18:46:11 +0300 Subject: [PATCH 13/14] Reduce state pulled from DB due to sending typing and receipts over federation (#12964) Reducing the amount of state we pull from the DB is useful as fetching state is expensive in terms of DB, CPU and memory. --- changelog.d/12964.misc | 1 + synapse/federation/sender/__init__.py | 6 +++- synapse/handlers/typing.py | 7 ++-- synapse/state/__init__.py | 4 --- synapse/storage/_base.py | 1 + synapse/storage/controllers/state.py | 8 +++++ synapse/storage/databases/main/roommember.py | 37 ++++++++++++++++++++ tests/federation/test_federation_sender.py | 14 ++++---- tests/handlers/test_typing.py | 6 ++-- 9 files changed, 68 insertions(+), 16 deletions(-) create mode 100644 changelog.d/12964.misc diff --git a/changelog.d/12964.misc b/changelog.d/12964.misc new file mode 100644 index 0000000000..d57e1aca6b --- /dev/null +++ b/changelog.d/12964.misc @@ -0,0 +1 @@ +Reduce the amount of state we pull from the DB. diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index dbe303ed9b..99a794c042 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender): self.store = hs.get_datastores().main self.state = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() + self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. 
- domains_set = await self.state.get_current_hosts_in_room(room_id) + domains_set = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) domains = [ d for d in domains_set diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0aeab86bbb..d104ea07fe 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -59,6 +59,7 @@ class FollowerTypingHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -131,7 +132,6 @@ class FollowerTypingHandler: return try: - users = await self.store.get_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -139,7 +139,10 @@ class FollowerTypingHandler: now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in {get_domain_from_id(u) for u in users}: + hosts = await self._storage_controllers.state.get_current_hosts_in_room( + member.room_id + ) + for domain in hosts: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index ab68e2b6a4..da25f20ae5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -172,10 +172,6 @@ class StateHandler: entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]: - event_ids = await self.store.get_latest_event_ids_in_room(room_id) - return await self.get_hosts_in_room_at_events(room_id, event_ids) - async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] ) -> FrozenSet[str]: diff --git a/synapse/storage/_base.py 
b/synapse/storage/_base.py index 57bd74700e..abfc56b061 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -71,6 +71,7 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) if members_changed: self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) self._attempt_to_invalidate_cache( "get_users_in_room_with_profiles", (room_id,) ) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 63a78ebc87..3b4cdb67eb 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -23,6 +23,7 @@ from typing import ( List, Mapping, Optional, + Set, Tuple, ) @@ -482,3 +483,10 @@ class StateStorageController: room_id, StateFilter.from_types((key,)) ) return state_map.get(key) + + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_current_hosts_in_room(room_id) diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e222b7bd1f..31bc8c5601 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True + @cached(iterable=True, max_entries=10000) + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + # First we check if we already have `get_users_in_room` in the cache, as + # we can just calculate result from that + users = self.get_users_in_room.cache.get_immediate( + (room_id,), None, update_metrics=False + ) + if users is not None: + return {get_domain_from_id(u) for u in users} + + if 
isinstance(self.database_engine, Sqlite3Engine): + # If we're using SQLite then let's just always use + # `get_users_in_room` rather than funky SQL. + users = await self.get_users_in_room(room_id) + return {get_domain_from_id(u) for u in users} + + # For PostgreSQL we can use a regex to pull out the domains from the + # joined users in `current_state_events` via regex. + + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ + SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') + FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND room_id = ? + """ + txn.execute(sql, (room_id,)) + return {d for d, in txn} + + return await self.db_pool.runInteraction( + "get_current_hosts_in_room", get_current_hosts_in_room_txn + ) + async def get_joined_hosts( self, room_id: str, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b5be727fe4..01a1db6115 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) - # Ensure a new Awaitable is created for each call. 
- mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( - ["test", "host2"] - ) - return self.setup_test_homeserver( - state_handler=mock_state_handler, + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( + return_value=make_awaitable({"test", "host2"}) + ) + + return hs + @override_config({"send_federation": True}) def test_send_receipts(self): mock_send_transaction = ( diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 14a0ee4922..7af1333126 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id: str): + async def get_current_hosts_in_room(room_id: str): return {member.domain for member in self.room_members} - self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + hs.get_storage_controllers().state.get_current_hosts_in_room = ( + get_current_hosts_in_room + ) async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} From f7baffd8ece67c96fac6cd17d50c4aba92f323c5 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 6 Jun 2022 13:20:05 -0400 Subject: [PATCH 14/14] Remove remaining pieces of groups code. (#12966) * Remove an unused stream ID generator. * Remove the now unused remote profile cache. 
--- changelog.d/12966.removal | 1 + synapse/_scripts/synapse_port_db.py | 1 + synapse/handlers/profile.py | 83 +------------- synapse/storage/databases/main/__init__.py | 18 --- synapse/storage/databases/main/profile.py | 107 +----------------- .../storage/databases/main/purge_events.py | 2 - synapse/storage/schema/__init__.py | 1 + tests/rest/admin/test_room.py | 2 - 8 files changed, 6 insertions(+), 209 deletions(-) create mode 100644 changelog.d/12966.removal diff --git a/changelog.d/12966.removal b/changelog.d/12966.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12966.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index 4939573f30..361b51d2fa 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -182,6 +182,7 @@ IGNORED_TABLES = { "groups", "local_group_membership", "local_group_updates", + "remote_profile_cache", } diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 239b0aa744..6eed3826a7 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -23,14 +23,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.types import ( - JsonDict, - Requester, - UserID, - create_requester, - get_domain_from_id, -) +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -50,9 +43,6 @@ class ProfileHandler: delegate to master when necessary. 
""" - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,11 +63,6 @@ class ProfileHandler: self._third_party_rules = hs.get_third_party_event_rules() - if hs.config.worker.run_background_tasks: - self.clock.looping_call( - self._update_remote_profile_cache, self.PROFILE_UPDATE_MS - ) - async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) @@ -116,30 +101,6 @@ class ProfileHandler: raise SynapseError(502, "Failed to fetch profile") raise e.to_synapse_error() - async def get_profile_from_cache(self, user_id: str) -> JsonDict: - """Get the profile information from our local cache. If the user is - ours then the profile information will always be correct. Otherwise, - it may be out of date/missing. - """ - target_user = UserID.from_string(user_id) - if self.hs.is_mine(target_user): - try: - displayname = await self.store.get_profile_displayname( - target_user.localpart - ) - avatar_url = await self.store.get_profile_avatar_url( - target_user.localpart - ) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) - raise - - return {"displayname": displayname, "avatar_url": avatar_url} - else: - profile = await self.store.get_from_remote_profile_cache(user_id) - return profile or {} - async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: @@ -509,45 +470,3 @@ class ProfileHandler: # so we act as if we couldn't find the profile. raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) raise - - @wrap_as_background_process("Update remote profile") - async def _update_remote_profile_cache(self) -> None: - """Called periodically to check profiles of remote users we haven't - checked in a while. 
- """ - entries = await self.store.get_remote_profile_cache_entries_that_expire( - last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS - ) - - for user_id, displayname, avatar_url in entries: - is_subscribed = await self.store.is_subscribed_remote_profile_for_user( - user_id - ) - if not is_subscribed: - await self.store.maybe_delete_remote_profile_cache(user_id) - continue - - try: - profile = await self.federation.make_query( - destination=get_domain_from_id(user_id), - query_type="profile", - args={"user_id": user_id}, - ignore_backoff=True, - ) - except Exception: - logger.exception("Failed to get avatar_url") - - await self.store.update_remote_profile_cache( - user_id, displayname, avatar_url - ) - continue - - new_name = profile.get("displayname") - if not isinstance(new_name, str): - new_name = None - new_avatar = profile.get("avatar_url") - if not isinstance(new_avatar, str): - new_avatar = None - - # We always hit update to update the last_check timestamp - await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index d545a1c002..11d9d16c19 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -151,10 +151,6 @@ class DataStore( ], ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about @@ -197,20 +193,6 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = 
StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index e197b7203e..a1747f04ce 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, List, Optional +from typing import Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) - async def get_from_remote_profile_cache( - self, user_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url"), - allow_none=True, - desc="get_from_remote_profile_cache", - ) - async def create_profile(self, user_localpart: str) -> None: await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" @@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) - async def update_remote_profile_cache( - self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] - ) -> int: - return await self.db_pool.simple_update( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - updatevalues={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": 
self._clock.time_msec(), - }, - desc="update_remote_profile_cache", - ) - - async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: - """Check if we still care about the remote user's profile, and if we - don't then remove their profile from the cache - """ - subscribed = await self.is_subscribed_remote_profile_for_user(user_id) - if not subscribed: - await self.db_pool.simple_delete( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - desc="delete_remote_profile_cache", - ) - - async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: - """Check whether we are interested in a remote user's profile.""" - res: Optional[str] = await self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - - res = await self.db_pool.simple_select_one_onecol( - table="group_invites", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - return False - - async def get_remote_profile_cache_entries_that_expire( - self, last_checked: int - ) -> List[Dict[str, str]]: - """Get all users who haven't been checked since `last_checked`""" - - def _get_remote_profile_cache_entries_that_expire_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - sql = """ - SELECT user_id, displayname, avatar_url - FROM remote_profile_cache - WHERE last_check < ? 
- """ - - txn.execute(sql, (last_checked,)) - - return self.db_pool.cursor_to_dict(txn) - - return await self.db_pool.runInteraction( - "get_remote_profile_cache_entries_that_expire", - _get_remote_profile_cache_entries_that_expire_txn, - ) - class ProfileStore(ProfileWorkerStore): - async def add_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str - ) -> None: - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. - """ - await self.db_pool.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) + pass diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 2353c120e9..ba385f9fc4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -393,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "partial_state_events", "events", "federation_inbound_events_staging", - "group_rooms", "local_current_membership", "partial_state_rooms_servers", "partial_state_rooms", @@ -413,7 +412,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", # "rooms" happens last, to keep the foreign keys in the other tables diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 19466150d4..5843fae605 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -70,6 +70,7 @@ Changes in SCHEMA_VERSION = 70: Changes in SCHEMA_VERSION = 71: - event_edges.room_id is no longer read from. 
+ - Tables related to groups are no longer accessed. """ diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 608d3f2dc3..ca6af9417b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2467,7 +2467,6 @@ PURGE_TABLES = [ "event_push_actions", "event_search", "events", - "group_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -2484,7 +2483,6 @@ PURGE_TABLES = [ "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around.