Stub in threads sliding sync extension
This commit is contained in:
@@ -595,3 +595,6 @@ class ExperimentalConfig(Config):
|
||||
# MSC4306: Thread Subscriptions
|
||||
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
|
||||
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)
|
||||
|
||||
# MSC4360: Threads Extension to Sliding Sync
|
||||
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
|
||||
|
||||
@@ -77,6 +77,7 @@ class SlidingSyncExtensionHandler:
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.push_rules_handler = hs.get_push_rules_handler()
|
||||
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
|
||||
self._enable_threads_ext = hs.config.experimental.msc4360_enabled
|
||||
|
||||
@trace
|
||||
async def get_extensions_response(
|
||||
@@ -177,6 +178,18 @@ class SlidingSyncExtensionHandler:
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
threads_coro = None
|
||||
if (
|
||||
sync_config.extensions.threads is not None
|
||||
and self._enable_threads_ext
|
||||
):
|
||||
threads_coro = self.get_threads_extension_response(
|
||||
sync_config=sync_config,
|
||||
threads_request=sync_config.extensions.threads,
|
||||
to_token=to_token,
|
||||
from_token=from_token,
|
||||
)
|
||||
|
||||
(
|
||||
to_device_response,
|
||||
e2ee_response,
|
||||
@@ -184,6 +197,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts_response,
|
||||
typing_response,
|
||||
thread_subs_response,
|
||||
threads_response,
|
||||
) = await gather_optional_coroutines(
|
||||
to_device_coro,
|
||||
e2ee_coro,
|
||||
@@ -191,6 +205,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts_coro,
|
||||
typing_coro,
|
||||
thread_subs_coro,
|
||||
threads_coro,
|
||||
)
|
||||
|
||||
return SlidingSyncResult.Extensions(
|
||||
@@ -200,6 +215,7 @@ class SlidingSyncExtensionHandler:
|
||||
receipts=receipts_response,
|
||||
typing=typing_response,
|
||||
thread_subscriptions=thread_subs_response,
|
||||
threads=threads_response,
|
||||
)
|
||||
|
||||
def find_relevant_room_ids_for_extension(
|
||||
@@ -970,3 +986,35 @@ class SlidingSyncExtensionHandler:
|
||||
unsubscribed=unsubscribed_threads,
|
||||
prev_batch=prev_batch,
|
||||
)
|
||||
|
||||
async def get_threads_extension_response(
|
||||
self,
|
||||
sync_config: SlidingSyncConfig,
|
||||
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
|
||||
to_token: StreamToken,
|
||||
from_token: Optional[SlidingSyncStreamToken],
|
||||
) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
|
||||
"""Handle Threads extension (MSC4360)
|
||||
|
||||
Args:
|
||||
sync_config: Sync configuration
|
||||
threads_request: The threads extension from the request
|
||||
to_token: The point in the stream to sync up to.
|
||||
from_token: The point in the stream to sync from.
|
||||
|
||||
Returns:
|
||||
the response (None if empty or threads extension is disabled)
|
||||
"""
|
||||
if not threads_request.enabled:
|
||||
return None
|
||||
|
||||
# TODO: implement
|
||||
|
||||
_limit = threads_request.limit
|
||||
|
||||
prev_batch = None
|
||||
|
||||
return SlidingSyncResult.Extensions.ThreadsExtension(
|
||||
updates=None,
|
||||
prev_batch=prev_batch,
|
||||
)
|
||||
|
||||
@@ -648,6 +648,7 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
- receipts (MSC3960)
|
||||
- account data (MSC3959)
|
||||
- thread subscriptions (MSC4308)
|
||||
- threads (MSC4360)
|
||||
|
||||
Request query parameters:
|
||||
timeout: How long to wait for new events in milliseconds.
|
||||
@@ -1091,6 +1092,12 @@ class SlidingSyncRestServlet(RestServlet):
|
||||
_serialise_thread_subscriptions(extensions.thread_subscriptions)
|
||||
)
|
||||
|
||||
# excludes both None and falsy `threads`
|
||||
if extensions.threads:
|
||||
serialized_extensions["io.element.msc4360.threads"] = _serialise_threads(
|
||||
extensions.threads
|
||||
)
|
||||
|
||||
return serialized_extensions
|
||||
|
||||
|
||||
@@ -1127,6 +1134,30 @@ def _serialise_thread_subscriptions(
|
||||
return out
|
||||
|
||||
|
||||
# TODO: is this necessary for serialization?
|
||||
def _serialise_threads(
|
||||
threads: SlidingSyncResult.Extensions.ThreadsExtension,
|
||||
) -> JsonDict:
|
||||
out: JsonDict = {}
|
||||
|
||||
if threads.updates:
|
||||
out["updates"] = {
|
||||
room_id: {
|
||||
thread_root_id: {
|
||||
"thread_root": update.thread_root,
|
||||
"prev_batch": update.prev_batch,
|
||||
}
|
||||
for thread_root_id, update in thread_updates.items()
|
||||
}
|
||||
for room_id, thread_updates in threads.updates.items()
|
||||
}
|
||||
|
||||
if threads.prev_batch:
|
||||
out["prev_batch"] = threads.prev_batch.to_string()
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
|
||||
SyncRestServlet(hs).register(http_server)
|
||||
|
||||
|
||||
@@ -396,12 +396,35 @@ class SlidingSyncResult:
|
||||
or bool(self.prev_batch)
|
||||
)
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadsExtension:
|
||||
# TODO: comment
|
||||
"""The Threads extension (MSC4360)
|
||||
|
||||
Attributes:
|
||||
"""
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ThreadUpdates:
|
||||
# TODO: comment
|
||||
thread_root: Optional[EventBase]
|
||||
|
||||
# TODO: comment
|
||||
prev_batch: Optional[StreamToken]
|
||||
|
||||
updates: Optional[Mapping[str, Mapping[str, ThreadUpdates]]]
|
||||
prev_batch: Optional[ThreadSubscriptionsToken]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.updates) or bool(self.prev_batch)
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
receipts: Optional[ReceiptsExtension] = None
|
||||
typing: Optional[TypingExtension] = None
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
|
||||
threads: Optional[ThreadsExtension] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(
|
||||
@@ -411,6 +434,7 @@ class SlidingSyncResult:
|
||||
or self.receipts
|
||||
or self.typing
|
||||
or self.thread_subscriptions
|
||||
or self.threads
|
||||
)
|
||||
|
||||
next_pos: SlidingSyncStreamToken
|
||||
|
||||
@@ -376,6 +376,19 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
enabled: Optional[StrictBool] = False
|
||||
limit: StrictInt = 100
|
||||
|
||||
class ThreadsExtension(RequestBodyModel):
|
||||
"""The Threads extension (MSC4360)
|
||||
|
||||
Attributes:
|
||||
enabled
|
||||
include_roots: whether to include thread root events in the extension response.
|
||||
limit: maximum number of thread updates to return.
|
||||
"""
|
||||
|
||||
enabled: Optional[StrictBool] = False
|
||||
include_roots: StrictBool = False
|
||||
limit: StrictInt = 100
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
@@ -384,6 +397,7 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
|
||||
alias="io.element.msc4308.thread_subscriptions"
|
||||
)
|
||||
threads: Optional[ThreadsExtension] = Field(alias="io.element.msc4360.threads")
|
||||
|
||||
conn_id: Optional[StrictStr]
|
||||
|
||||
|
||||
@@ -347,6 +347,7 @@ T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
T6 = TypeVar("T6")
|
||||
T7 = TypeVar("T7")
|
||||
|
||||
|
||||
@overload
|
||||
@@ -478,6 +479,30 @@ async def gather_optional_coroutines(
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
Optional[Coroutine[Any, Any, T5]],
|
||||
Optional[Coroutine[Any, Any, T6]],
|
||||
Optional[Coroutine[Any, Any, T7]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[
|
||||
Optional[T1],
|
||||
Optional[T2],
|
||||
Optional[T3],
|
||||
Optional[T4],
|
||||
Optional[T5],
|
||||
Optional[T6],
|
||||
Optional[T7],
|
||||
]: ...
|
||||
|
||||
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
|
||||
) -> Tuple[Optional[T1], ...]:
|
||||
|
||||
Reference in New Issue
Block a user