diff --git a/synapse/replication/tcp/streams/_base.py b/synapse/replication/tcp/streams/_base.py index 2699e466bc..c14dff6c64 100644 --- a/synapse/replication/tcp/streams/_base.py +++ b/synapse/replication/tcp/streams/_base.py @@ -16,7 +16,7 @@ import logging from collections import namedtuple -from typing import Any, Awaitable, Callable, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, List, Optional, Tuple import attr @@ -29,6 +29,15 @@ logger = logging.getLogger(__name__) MAX_EVENTS_BEHIND = 500000 +# Some type aliases to make things a bit easier. + +# A stream position token +Token = int + +# A pair of position in stream and args used to create an instance of `ROW_TYPE`. +StreamRow = Tuple[Token, tuple] + + class Stream(object): """Base class for the streams. @@ -66,7 +75,7 @@ class Stream(object): """ self.last_token = self.current_token() - async def get_updates(self) -> Tuple[List[Tuple[int, JsonDict]], int, bool]: + async def get_updates(self) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Gets all updates since the last time this function was called (or since the stream was constructed if it hadn't been called before). @@ -85,8 +94,8 @@ class Stream(object): return updates, current_token, limited async def get_updates_since( - self, from_token: Union[int, str], upto_token: int, limit: int = 100 - ) -> Tuple[List[Tuple[int, JsonDict]], int, bool]: + self, from_token: Token, upto_token: Token, limit: int = 100 + ) -> Tuple[List[Tuple[Token, JsonDict]], Token, bool]: """Like get_updates except allows specifying from when we should stream updates @@ -128,8 +137,8 @@ class Stream(object): def db_query_to_update_function( - query_function: Callable[[int, int, int], Awaitable[List[tuple]]] -) -> Callable[[int, int, int], Awaitable[Tuple[List[Tuple[int, tuple]], int, bool]]]: + query_function: Callable[[Token, Token, int], Awaitable[List[tuple]]] +) -> Callable[[Token, Token, int], Awaitable[Tuple[List[StreamRow], Token, bool]]]: """Wraps a db query function which returns a list of rows to make it suitable for use as an `update_function` for the Stream class """ @@ -147,6 +156,28 @@ def db_query_to_update_function( return update_function +def make_http_update_function( + hs, stream_name: str +) -> Callable[[Token, Token, Token], Awaitable[Tuple[List[StreamRow], Token, bool]]]: + """Makes a suitable function for use as an `update_function` that queries + the master process for updates. + """ + + client = ReplicationGetStreamUpdates.make_client(hs) + + async def update_function( + from_token: int, upto_token: int, limit: int + ) -> Tuple[List[Tuple[int, tuple]], int, bool]: + return await client( + stream_name=stream_name, + from_token=from_token, + upto_token=upto_token, + limit=limit, + ) + + return update_function + + class BackfillStream(Stream): """We fetched some old events and either we had never seen that event before or it went from being an outlier to not. @@ -204,7 +235,7 @@ class PresenceStream(Stream): self.update_function = db_query_to_update_function(presence_handler.get_all_presence_updates) # type: ignore else: # Query master process - self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(PresenceStream, self).__init__(hs) @@ -226,7 +257,7 @@ class TypingStream(Stream): self.update_function = db_query_to_update_function(typing_handler.get_all_typing_updates) # type: ignore else: # Query master process - self.update_function = ReplicationGetStreamUpdates.make_client(hs) # type: ignore + self.update_function = make_http_update_function(hs, self.NAME) # type: ignore super(TypingStream, self).__init__(hs)