diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d7a3d67558..8546e2fc40 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -595,3 +595,6 @@ class ExperimentalConfig(Config): # MSC4306: Thread Subscriptions # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + + # MSC4360: Threads Extension to Sliding Sync + self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False) diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 25ee954b7f..5ce185ff80 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -77,6 +77,7 @@ class SlidingSyncExtensionHandler: self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled + self._enable_threads_ext = hs.config.experimental.msc4360_enabled @trace async def get_extensions_response( @@ -177,6 +178,18 @@ class SlidingSyncExtensionHandler: from_token=from_token, ) + threads_coro = None + if ( + sync_config.extensions.threads is not None + and self._enable_threads_ext + ): + threads_coro = self.get_threads_extension_response( + sync_config=sync_config, + threads_request=sync_config.extensions.threads, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, @@ -184,6 +197,7 @@ class SlidingSyncExtensionHandler: receipts_response, typing_response, thread_subs_response, + threads_response, ) = await gather_optional_coroutines( to_device_coro, e2ee_coro, @@ -191,6 +205,7 @@ class SlidingSyncExtensionHandler: receipts_coro, typing_coro, thread_subs_coro, + threads_coro, ) return SlidingSyncResult.Extensions( @@ -200,6 +215,7 @@ class SlidingSyncExtensionHandler: receipts=receipts_response, typing=typing_response, thread_subscriptions=thread_subs_response, + threads=threads_response, ) def find_relevant_room_ids_for_extension( @@ -970,3 +986,35 @@ class SlidingSyncExtensionHandler: unsubscribed=unsubscribed_threads, prev_batch=prev_batch, ) + + async def get_threads_extension_response( + self, + sync_config: SlidingSyncConfig, + threads_request: SlidingSyncConfig.Extensions.ThreadsExtension, + to_token: StreamToken, + from_token: Optional[SlidingSyncStreamToken], + ) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]: + """Handle Threads extension (MSC4360) + + Args: + sync_config: Sync configuration + threads_request: The threads extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. + + Returns: + the response (None if empty or threads extension is disabled) + """ + if not threads_request.enabled: + return None + + # TODO: implement + + _limit = threads_request.limit + + prev_batch = None + + return SlidingSyncResult.Extensions.ThreadsExtension( + updates=None, + prev_batch=prev_batch, + ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index bb63b51599..6b08fa148c 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -648,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet): - receipts (MSC3960) - account data (MSC3959) - thread subscriptions (MSC4308) + - threads (MSC4360) Request query parameters: timeout: How long to wait for new events in milliseconds. @@ -1091,6 +1092,12 @@ class SlidingSyncRestServlet(RestServlet): _serialise_thread_subscriptions(extensions.thread_subscriptions) ) + # excludes both None and falsy `threads` + if extensions.threads: + serialized_extensions["io.element.msc4360.threads"] = _serialise_threads( + extensions.threads + ) + return serialized_extensions @@ -1127,6 +1134,30 @@ def _serialise_thread_subscriptions( return out +# TODO: is this necessary for serialization? +def _serialise_threads( + threads: SlidingSyncResult.Extensions.ThreadsExtension, +) -> JsonDict: + out: JsonDict = {} + + if threads.updates: + out["updates"] = { + room_id: { + thread_root_id: { + "thread_root": update.thread_root, + "prev_batch": update.prev_batch, + } + for thread_root_id, update in thread_updates.items() + } + for room_id, thread_updates in threads.updates.items() + } + + if threads.prev_batch: + out["prev_batch"] = threads.prev_batch.to_string() + + return out + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index b7bc565464..afa564c6b1 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -396,12 +396,35 @@ class SlidingSyncResult: or bool(self.prev_batch) ) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadsExtension: + # TODO: comment + """The Threads extension (MSC4360) + + Attributes: + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ThreadUpdates: + # TODO: comment + thread_root: Optional[EventBase] + + # TODO: comment + prev_batch: Optional[StreamToken] + + updates: Optional[Mapping[str, Mapping[str, ThreadUpdates]]] + prev_batch: Optional[ThreadSubscriptionsToken] + + def __bool__(self) -> bool: + return bool(self.updates) or bool(self.prev_batch) + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None receipts: Optional[ReceiptsExtension] = None typing: Optional[TypingExtension] = None thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None + threads: Optional[ThreadsExtension] = None def __bool__(self) -> bool: return bool( @@ -411,6 +434,7 @@ class SlidingSyncResult: or self.receipts or self.typing or self.thread_subscriptions + or self.threads ) next_pos: SlidingSyncStreamToken diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 11d7e59b43..e8f08e434a 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -376,6 +376,19 @@ class SlidingSyncBody(RequestBodyModel): enabled: Optional[StrictBool] = False limit: StrictInt = 100 + class ThreadsExtension(RequestBodyModel): + """The Threads extension (MSC4360) + + Attributes: + enabled + include_roots: whether to include thread root events in the extension response. + limit: maximum number of thread updates to return. + """ + + enabled: Optional[StrictBool] = False + include_roots: StrictBool = False + limit: StrictInt = 100 + to_device: Optional[ToDeviceExtension] = None e2ee: Optional[E2eeExtension] = None account_data: Optional[AccountDataExtension] = None @@ -384,6 +397,7 @@ class SlidingSyncBody(RequestBodyModel): thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field( alias="io.element.msc4308.thread_subscriptions" ) + threads: Optional[ThreadsExtension] = Field(alias="io.element.msc4360.threads") conn_id: Optional[StrictStr] diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 1f90988525..5c828420a2 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -347,6 +347,7 @@ T3 = TypeVar("T3") T4 = TypeVar("T4") T5 = TypeVar("T5") T6 = TypeVar("T6") +T7 = TypeVar("T7") @overload @@ -478,6 +479,30 @@ async def gather_optional_coroutines( ]: ... +@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + Optional[Coroutine[Any, Any, T6]], + Optional[Coroutine[Any, Any, T7]], + ] + ], +) -> Tuple[ + Optional[T1], + Optional[T2], + Optional[T3], + Optional[T4], + Optional[T5], + Optional[T6], + Optional[T7], +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], ) -> Tuple[Optional[T1], ...]: