1
0

Stub in threads sliding sync extension

This commit is contained in:
Devon Hudson
2025-10-01 10:04:29 -06:00
parent ad8dcc2119
commit cd4f4223de
6 changed files with 145 additions and 0 deletions
+3
View File
@@ -595,3 +595,6 @@ class ExperimentalConfig(Config):
# MSC4306: Thread Subscriptions
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
# MSC4360: Threads Extension to Sliding Sync
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
@@ -77,6 +77,7 @@ class SlidingSyncExtensionHandler:
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_threads_ext = hs.config.experimental.msc4360_enabled
@trace
async def get_extensions_response(
@@ -177,6 +178,18 @@ class SlidingSyncExtensionHandler:
from_token=from_token,
)
threads_coro = None
if (
sync_config.extensions.threads is not None
and self._enable_threads_ext
):
threads_coro = self.get_threads_extension_response(
sync_config=sync_config,
threads_request=sync_config.extensions.threads,
to_token=to_token,
from_token=from_token,
)
(
to_device_response,
e2ee_response,
@@ -184,6 +197,7 @@ class SlidingSyncExtensionHandler:
receipts_response,
typing_response,
thread_subs_response,
threads_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
@@ -191,6 +205,7 @@ class SlidingSyncExtensionHandler:
receipts_coro,
typing_coro,
thread_subs_coro,
threads_coro,
)
return SlidingSyncResult.Extensions(
@@ -200,6 +215,7 @@ class SlidingSyncExtensionHandler:
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
threads=threads_response,
)
def find_relevant_room_ids_for_extension(
@@ -970,3 +986,35 @@ class SlidingSyncExtensionHandler:
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)
async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
"""Handle Threads extension (MSC4360)
Args:
sync_config: Sync configuration
threads_request: The threads extension from the request
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.
Returns:
the response (None if empty or threads extension is disabled)
"""
if not threads_request.enabled:
return None
# TODO: implement
_limit = threads_request.limit
prev_batch = None
return SlidingSyncResult.Extensions.ThreadsExtension(
updates=None,
prev_batch=prev_batch,
)
+31
View File
@@ -648,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- threads (MSC4360)
Request query parameters:
timeout: How long to wait for new events in milliseconds.
@@ -1091,6 +1092,12 @@ class SlidingSyncRestServlet(RestServlet):
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)
# excludes both None and falsy `threads`
if extensions.threads:
serialized_extensions["io.element.msc4360.threads"] = _serialise_threads(
extensions.threads
)
return serialized_extensions
@@ -1127,6 +1134,30 @@ def _serialise_thread_subscriptions(
return out
# TODO: is this necessary for serialization?
def _serialise_threads(
threads: SlidingSyncResult.Extensions.ThreadsExtension,
) -> JsonDict:
out: JsonDict = {}
if threads.updates:
out["updates"] = {
room_id: {
thread_root_id: {
"thread_root": update.thread_root,
"prev_batch": update.prev_batch,
}
for thread_root_id, update in thread_updates.items()
}
for room_id, thread_updates in threads.updates.items()
}
if threads.prev_batch:
out["prev_batch"] = threads.prev_batch.to_string()
return out
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)
+24
View File
@@ -396,12 +396,35 @@ class SlidingSyncResult:
or bool(self.prev_batch)
)
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsExtension:
# TODO: comment
"""The Threads extension (MSC4360)
Attributes:
"""
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdates:
# TODO: comment
thread_root: Optional[EventBase]
# TODO: comment
prev_batch: Optional[StreamToken]
updates: Optional[Mapping[str, Mapping[str, ThreadUpdates]]]
prev_batch: Optional[ThreadSubscriptionsToken]
def __bool__(self) -> bool:
return bool(self.updates) or bool(self.prev_batch)
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
receipts: Optional[ReceiptsExtension] = None
typing: Optional[TypingExtension] = None
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
threads: Optional[ThreadsExtension] = None
def __bool__(self) -> bool:
return bool(
@@ -411,6 +434,7 @@ class SlidingSyncResult:
or self.receipts
or self.typing
or self.thread_subscriptions
or self.threads
)
next_pos: SlidingSyncStreamToken
+14
View File
@@ -376,6 +376,19 @@ class SlidingSyncBody(RequestBodyModel):
enabled: Optional[StrictBool] = False
limit: StrictInt = 100
class ThreadsExtension(RequestBodyModel):
"""The Threads extension (MSC4360)
Attributes:
enabled
include_roots: whether to include thread root events in the extension response.
limit: maximum number of thread updates to return.
"""
enabled: Optional[StrictBool] = False
include_roots: StrictBool = False
limit: StrictInt = 100
to_device: Optional[ToDeviceExtension] = None
e2ee: Optional[E2eeExtension] = None
account_data: Optional[AccountDataExtension] = None
@@ -384,6 +397,7 @@ class SlidingSyncBody(RequestBodyModel):
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
alias="io.element.msc4308.thread_subscriptions"
)
threads: Optional[ThreadsExtension] = Field(alias="io.element.msc4360.threads")
conn_id: Optional[StrictStr]
+25
View File
@@ -347,6 +347,7 @@ T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
T6 = TypeVar("T6")
T7 = TypeVar("T7")
@overload
@@ -478,6 +479,30 @@ async def gather_optional_coroutines(
]: ...
@overload
async def gather_optional_coroutines(
*coroutines: Unpack[
Tuple[
Optional[Coroutine[Any, Any, T1]],
Optional[Coroutine[Any, Any, T2]],
Optional[Coroutine[Any, Any, T3]],
Optional[Coroutine[Any, Any, T4]],
Optional[Coroutine[Any, Any, T5]],
Optional[Coroutine[Any, Any, T6]],
Optional[Coroutine[Any, Any, T7]],
]
],
) -> Tuple[
Optional[T1],
Optional[T2],
Optional[T3],
Optional[T4],
Optional[T5],
Optional[T6],
Optional[T7],
]: ...
async def gather_optional_coroutines(
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
) -> Tuple[Optional[T1], ...]: