1
0

Add types to function signatures in SyncHandler

This commit is contained in:
Erik Johnston
2020-01-31 12:56:55 +00:00
parent ad5e4de70d
commit 4e60e6cb39

View File

@@ -16,7 +16,7 @@
import itertools
import logging
from typing import Any, List, Optional, Set, Tuple, FrozenSet, Dict
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from six import iteritems, itervalues
@@ -30,7 +30,14 @@ from synapse.logging.context import LoggingContext
from synapse.push.clientformat import format_push_rules_for_user
from synapse.storage.roommember import MemberSummary
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomStreamToken, StateMap, StreamToken, UserID
from synapse.types import (
JsonDict,
RoomStreamToken,
StateMap,
StreamToken,
UserID,
Collection,
)
from synapse.util.async_helpers import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.caches.lrucache import LruCache
@@ -76,7 +83,7 @@ class SyncConfig:
@attr.s(slots=True, frozen=True)
class TimelineBatch:
prev_batch = attr.ib(type=str)
prev_batch = attr.ib(type=StreamToken)
events = attr.ib(type=List[EventBase])
limited = attr.ib(bool)
@@ -91,13 +98,13 @@ class TimelineBatch:
@attr.s(slots=True, frozen=True)
class JoinedSyncResult:
room_id = attr.ib(type=bool)
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
state = attr.ib(type=StateMap[EventBase])
ephemeral = attr.ib(type=List[JsonDict])
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -163,8 +170,8 @@ class DeviceLists:
left: List of user_ids whose devices we no longer track
"""
changed = attr.ib(type=List[str])
left = attr.ib(type=List[str])
changed = attr.ib(type=Collection[str])
left = attr.ib(type=Collection[str])
def __nonzero__(self) -> bool:
return bool(self.changed or self.left)
@@ -254,13 +261,15 @@ class SyncHandler(object):
)
async def wait_for_sync_for_user(
self, sync_config, since_token=None, timeout=0, full_state=False
):
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult:
"""Get the sync for a client if we have new data for it now. Otherwise
wait for new data to arrive on the server. If the timeout expires, then
return an empty sync result.
Returns:
Deferred[SyncResult]
"""
# If the user is not part of the mau group, then check that limits have
# not been exceeded (if not part of the group by this point, almost certain
@@ -279,8 +288,12 @@ class SyncHandler(object):
return res
async def _wait_for_sync_for_user(
self, sync_config, since_token, timeout, full_state
):
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
timeout: int = 0,
full_state: bool = False,
) -> SyncResult:
if since_token is None:
sync_type = "initial_sync"
elif full_state:
@@ -319,25 +332,33 @@ class SyncHandler(object):
return result
def current_sync_for_user(self, sync_config, since_token=None, full_state=False):
async def current_sync_for_user(
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Get the sync for client needed to match what the server has now.
Returns:
A Deferred SyncResult.
"""
return self.generate_sync_result(sync_config, since_token, full_state)
return await self.generate_sync_result(sync_config, since_token, full_state)
async def push_rules_for_user(self, user):
async def push_rules_for_user(self, user: UserID) -> JsonDict:
user_id = user.to_string()
rules = await self.store.get_push_rules_for_user(user_id)
rules = format_push_rules_for_user(user, rules)
return rules
async def ephemeral_by_room(self, sync_result_builder, now_token, since_token=None):
async def ephemeral_by_room(
self,
sync_result_builder: "SyncResultBuilder",
now_token: StreamToken,
since_token: Optional[StreamToken] = None,
) -> Tuple[StreamToken, Dict[str, List[JsonDict]]]:
"""Get the ephemeral events for each room the user is in
Args:
sync_result_builder(SyncResultBuilder)
now_token (StreamToken): Where the server is currently up to.
since_token (StreamToken): Where the server was when the client
sync_result_builder
now_token: Where the server is currently up to.
since_token: Where the server was when the client
last synced.
Returns:
A tuple of the now StreamToken, updated to reflect the which typing
@@ -394,13 +415,13 @@ class SyncHandler(object):
async def _load_filtered_recents(
self,
room_id,
sync_config,
now_token,
since_token=None,
recents=None,
newly_joined_room=False,
):
room_id: str,
sync_config: SyncConfig,
now_token: StreamToken,
since_token: Optional[StreamToken] = None,
potential_recents: Optional[List[EventBase]] = None,
newly_joined_room: bool = False,
) -> TimelineBatch:
"""
Returns:
a Deferred TimelineBatch
@@ -411,20 +432,28 @@ class SyncHandler(object):
sync_config.filter_collection.blocks_all_room_timeline()
)
if recents is None or newly_joined_room or timeline_limit < len(recents):
if (
potential_recents is None
or newly_joined_room
or timeline_limit < len(potential_recents)
):
limited = True
else:
limited = False
if recents:
recents = sync_config.filter_collection.filter_room_timeline(recents)
if potential_recents:
recents = sync_config.filter_collection.filter_room_timeline(
potential_recents
)
# We check if there are any state events, if there are then we pass
# all current state events to the filter_events function. This is to
# ensure that we always include current state in the timeline
current_state_ids = frozenset() # type: FrozenSet[str]
if any(e.is_state() for e in recents):
current_state_ids_map = await self.state.get_current_state_ids(room_id)
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
current_state_ids = frozenset(itervalues(current_state_ids_map))
recents = await filter_events_for_client(
@@ -477,7 +506,9 @@ class SyncHandler(object):
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids_map = await self.state.get_current_state_ids(room_id)
current_state_ids_map = await self.state.get_current_state_ids(
room_id
)
current_state_ids = frozenset(itervalues(current_state_ids_map))
loaded_recents = await filter_events_for_client(
@@ -507,17 +538,15 @@ class SyncHandler(object):
limited=limited or newly_joined_room,
)
async def get_state_after_event(self, event, state_filter=StateFilter.all()):
async def get_state_after_event(
self, event: EventBase, state_filter: StateFilter = StateFilter.all()
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event(synapse.events.EventBase): event of interest
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter
@@ -528,18 +557,17 @@ class SyncHandler(object):
return state_ids
async def get_state_at(
self, room_id, stream_position, state_filter=StateFilter.all()
):
self,
room_id: str,
stream_position: StreamToken,
state_filter: StateFilter = StateFilter.all(),
) -> StateMap[str]:
""" Get the room state at a particular stream position
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
state_filter (StateFilter): The state filter used to fetch state
from the database.
Returns:
A Deferred map from ((type, state_key)->Event)
room_id: room for which to get state
stream_position: point at which to get state
state_filter: The state filter used to fetch state from the database.
"""
# FIXME this claims to get the state at a stream position, but
# get_recent_events_for_room operates by topo ordering. This therefore
@@ -560,23 +588,25 @@ class SyncHandler(object):
state = {}
return state
async def compute_summary(self, room_id, sync_config, batch, state, now_token):
async def compute_summary(
self,
room_id: str,
sync_config: SyncConfig,
batch: TimelineBatch,
state: StateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number
of joined members in the room, and providing the 'hero' members if the
room has no name so clients can consistently name rooms. Also adds
state events to 'state' if needed to describe the heroes.
Args:
room_id(str):
sync_config(synapse.handlers.sync.SyncConfig):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
state(dict): dict of (type, state_key) -> Event as returned by
compute_state_delta
now_token(str): Token of the end of the current batch.
Returns:
A deferred dict describing the room summary
Args
room_id
sync_config
batch: The timeline batch for the room that will be sent to the user.
state: State as returned by compute_state_delta
now_token: Token of the end of the current batch.
"""
# FIXME: we could/should get this from room_stats when matthew/stats lands
@@ -695,7 +725,7 @@ class SyncHandler(object):
return summary
def get_lazy_loaded_members_cache(self, cache_key):
def get_lazy_loaded_members_cache(self, cache_key: Any) -> LruCache:
cache = self.lazy_loaded_members_cache.get(cache_key)
if cache is None:
logger.debug("creating LruCache for %r", cache_key)
@@ -706,23 +736,24 @@ class SyncHandler(object):
return cache
async def compute_state_delta(
self, room_id, batch, sync_config, since_token, now_token, full_state
):
self,
room_id: str,
batch: TimelineBatch,
sync_config: SyncConfig,
since_token: Optional[StreamToken],
now_token: StreamToken,
full_state: bool,
) -> StateMap[EventBase]:
""" Works out the difference in state between the start of the timeline
and the previous sync.
Args:
room_id(str):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
sync_config(synapse.handlers.sync.SyncConfig):
since_token(str|None): Token of the end of the previous batch. May
be None.
now_token(str): Token of the end of the current batch.
full_state(bool): Whether to force returning the full state.
Returns:
A deferred dict of (type, state_key) -> Event
room_id:
batch: The timeline batch for the room that will be sent to the user.
sync_config:
since_token: Token of the end of the previous batch. May be None.
now_token: Token of the end of the current batch.
full_state: Whether to force returning the full state.
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -814,9 +845,12 @@ class SyncHandler(object):
# about them).
state_filter = StateFilter.all()
state_at_previous_sync = await self.get_state_at(
room_id, stream_position=since_token, state_filter=state_filter
)
if since_token:
state_at_previous_sync = await self.get_state_at(
room_id, stream_position=since_token, state_filter=state_filter
)
else:
state_at_previous_sync = {}
if batch:
current_state_ids = await self.state_store.get_state_ids_for_event(
@@ -899,7 +933,9 @@ class SyncHandler(object):
)
}
async def unread_notifs_for_room_id(self, room_id, sync_config):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> Optional[Dict[str, str]]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -918,17 +954,12 @@ class SyncHandler(object):
return None
async def generate_sync_result(
self, sync_config, since_token=None, full_state=False
):
self,
sync_config: SyncConfig,
since_token: Optional[StreamToken] = None,
full_state: bool = False,
) -> SyncResult:
"""Generates a sync result.
Args:
sync_config (SyncConfig)
since_token (StreamToken)
full_state (bool)
Returns:
Deferred(SyncResult)
"""
# NB: The now_token gets changed by some of the generate_sync_* methods,
# this is due to some of the underlying streams not supporting the ability
@@ -1020,7 +1051,9 @@ class SyncHandler(object):
)
@measure_func("_generate_sync_entry_for_groups")
async def _generate_sync_entry_for_groups(self, sync_result_builder):
async def _generate_sync_entry_for_groups(
self, sync_result_builder: "SyncResultBuilder"
):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
now_token = sync_result_builder.now_token
@@ -1065,27 +1098,22 @@ class SyncHandler(object):
@measure_func("_generate_sync_entry_for_device_list")
async def _generate_sync_entry_for_device_list(
self,
sync_result_builder,
newly_joined_rooms,
newly_joined_or_invited_users,
newly_left_rooms,
newly_left_users,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
newly_joined_or_invited_users: Set[str],
newly_left_rooms: Set[str],
newly_left_users: Set[str],
):
"""Generate the DeviceLists section of sync
Args:
sync_result_builder (SyncResultBuilder)
newly_joined_rooms (set[str]): Set of rooms user has joined since
sync_result_builder
newly_joined_rooms: Set of rooms user has joined since previous sync
newly_joined_or_invited_users: Set of users that have joined or
been invited to a room since previous sync.
newly_left_rooms: Set of rooms user has left since previous sync
newly_left_users: Set of users that have left a room we're in since
previous sync
newly_joined_or_invited_users (set[str]): Set of users that have
joined or been invited to a room since previous sync.
newly_left_rooms (set[str]): Set of rooms user has left since
previous sync
newly_left_users (set[str]): Set of users that have left a room
we're in since previous sync
Returns:
Deferred[DeviceLists]
"""
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1146,15 +1174,11 @@ class SyncHandler(object):
else:
return DeviceLists(changed=[], left=[])
async def _generate_sync_entry_for_to_device(self, sync_result_builder):
async def _generate_sync_entry_for_to_device(
self, sync_result_builder: "SyncResultBuilder"
):
"""Generates the portion of the sync response. Populates
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
Returns:
Deferred(dict): A dictionary containing the per room account data.
"""
user_id = sync_result_builder.sync_config.user.to_string()
device_id = sync_result_builder.sync_config.device_id
@@ -1192,15 +1216,17 @@ class SyncHandler(object):
else:
sync_result_builder.to_device = []
async def _generate_sync_entry_for_account_data(self, sync_result_builder):
async def _generate_sync_entry_for_account_data(
self, sync_result_builder: "SyncResultBuilder"
) -> Dict[str, Dict[str, JsonDict]]:
"""Generates the account data portion of the sync response. Populates
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
sync_result_builder
Returns:
Deferred(dict): A dictionary containing the per room account data.
A dictionary containing the per room account data.
"""
sync_config = sync_result_builder.sync_config
user_id = sync_result_builder.sync_config.user.to_string()
@@ -1244,18 +1270,21 @@ class SyncHandler(object):
return account_data_by_room
async def _generate_sync_entry_for_presence(
self, sync_result_builder, newly_joined_rooms, newly_joined_or_invited_users
self,
sync_result_builder: "SyncResultBuilder",
newly_joined_rooms: Set[str],
newly_joined_or_invited_users: Set[str],
):
"""Generates the presence portion of the sync response. Populates the
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
newly_joined_rooms(list): List of rooms that the user has joined
since the last sync (or empty if an initial sync)
newly_joined_or_invited_users(list): List of users that have joined
or been invited to rooms since the last sync (or empty if an initial
sync)
sync_result_builder
newly_joined_rooms: Set of rooms that the user has joined since
the last sync (or empty if an initial sync)
newly_joined_or_invited_users: Set of users that have joined or
been invited to rooms since the last sync (or empty if an
initial sync)
"""
now_token = sync_result_builder.now_token
sync_config = sync_result_builder.sync_config
@@ -1299,17 +1328,19 @@ class SyncHandler(object):
sync_result_builder.presence = presence
async def _generate_sync_entry_for_rooms(
self, sync_result_builder, account_data_by_room
):
self,
sync_result_builder: "SyncResultBuilder",
account_data_by_room: Dict[str, Dict[str, JsonDict]],
) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
"""Generates the rooms portion of the sync response. Populates the
`sync_result_builder` with the result.
Args:
sync_result_builder(SyncResultBuilder)
account_data_by_room(dict): Dictionary of per room account data
sync_result_builder
account_data_by_room: Dictionary of per room account data
Returns:
Deferred(tuple): Returns a 4-tuple of
Returns a 4-tuple of
`(newly_joined_rooms, newly_joined_or_invited_users,
newly_left_rooms, newly_left_users)`
"""
@@ -1341,7 +1372,7 @@ class SyncHandler(object):
)
if not tags_by_room:
logger.debug("no-oping sync")
return [], [], [], []
return set(), set(), set(), set()
ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
"m.ignored_user_list", user_id=user_id
@@ -1408,13 +1439,15 @@ class SyncHandler(object):
newly_left_users -= newly_joined_or_invited_users
return (
newly_joined_rooms,
set(newly_joined_rooms),
newly_joined_or_invited_users,
newly_left_rooms,
set(newly_left_rooms),
newly_left_users,
)
async def _have_rooms_changed(self, sync_result_builder):
async def _have_rooms_changed(
self, sync_result_builder: "SyncResultBuilder"
) -> bool:
"""Returns whether there may be any new events that should be sent down
the sync. Returns True if there are.
"""
@@ -1716,26 +1749,26 @@ class SyncHandler(object):
async def _generate_room_entry(
self,
sync_result_builder,
ignored_users,
room_builder,
ephemeral,
tags,
account_data,
always_include=False,
sync_result_builder: "SyncResultBuilder",
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[List[JsonDict]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):
"""Populates the `joined` and `archived` section of `sync_result_builder`
based on the `room_builder`.
Args:
sync_result_builder(SyncResultBuilder)
ignored_users(set(str)): Set of users ignored by user.
room_builder(RoomSyncResultBuilder)
ephemeral(list): List of new ephemeral events for room
tags(list): List of *all* tags for room, or None if there has been
sync_result_builder
ignored_users: Set of users ignored by user.
room_builder
ephemeral: List of new ephemeral events for room
tags: List of *all* tags for room, or None if there has been
no change.
account_data(list): List of new account data for room
always_include(bool): Always include this room in the sync response,
account_data: List of new account data for room
always_include: Always include this room in the sync response,
even if empty.
"""
newly_joined = room_builder.newly_joined
@@ -1761,7 +1794,7 @@ class SyncHandler(object):
sync_config,
now_token=upto_token,
since_token=since_token,
recents=events,
potential_recents=events,
newly_joined_room=newly_joined,
)
@@ -1812,7 +1845,7 @@ class SyncHandler(object):
room_id, batch, sync_config, since_token, now_token, full_state=full_state
)
summary = {} # type: JsonDict
summary = {} # type: Optional[JsonDict]
# we include a summary in room responses when we're lazy loading
# members (as the client otherwise doesn't have enough info to form
@@ -1874,7 +1907,9 @@ class SyncHandler(object):
else:
raise Exception("Unrecognized rtype: %r", room_builder.rtype)
async def get_rooms_for_user_at(self, user_id, stream_ordering):
async def get_rooms_for_user_at(
self, user_id: str, stream_ordering: int
) -> FrozenSet[str]:
"""Get set of joined rooms for a user at the given stream ordering.
The stream ordering *must* be recent, otherwise this may throw an
@@ -1882,12 +1917,11 @@ class SyncHandler(object):
current token, which should be perfectly fine).
Args:
user_id (str)
stream_ordering (int)
user_id
stream_ordering
ReturnValue:
Deferred[frozenset[str]]: Set of room_ids the user is in at given
stream_ordering.
Set of room_ids the user is in at given stream_ordering.
"""
joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
@@ -1917,7 +1951,7 @@ class SyncHandler(object):
return frozenset(joined_room_ids)
def _action_has_highlight(actions):
def _action_has_highlight(actions: List[JsonDict]) -> bool:
for action in actions:
try:
if action.get("set_tweak", None) == "highlight":
@@ -1929,22 +1963,23 @@ def _action_has_highlight(actions):
def _calculate_state(
timeline_contains, timeline_start, previous, current, lazy_load_members
):
timeline_contains: StateMap[str],
timeline_start: StateMap[str],
previous: StateMap[str],
current: StateMap[str],
lazy_load_members: bool,
) -> StateMap[str]:
"""Works out what state to include in a sync response.
Args:
timeline_contains (dict): state in the timeline
timeline_start (dict): state at the start of the timeline
previous (dict): state at the end of the previous sync (or empty dict
timeline_contains: state in the timeline
timeline_start: state at the start of the timeline
previous: state at the end of the previous sync (or empty dict
if this is an initial sync)
current (dict): state at the end of the timeline
lazy_load_members (bool): whether to return members from timeline_start
current: state at the end of the timeline
lazy_load_members: whether to return members from timeline_start
or not. assumes that timeline_start has already been filtered to
include only the members the client needs to know about.
Returns:
dict
"""
event_id_to_key = {
e: key
@@ -2006,7 +2041,7 @@ class SyncResultBuilder:
full_state = attr.ib(type=bool)
since_token = attr.ib(type=Optional[StreamToken])
now_token = attr.ib(type=StreamToken)
joined_room_ids = attr.ib(type=List[str])
joined_room_ids = attr.ib(type=FrozenSet[str])
presence = attr.ib(type=List[JsonDict], default=attr.Factory(list))
account_data = attr.ib(type=List[JsonDict], default=attr.Factory(list))
@@ -2039,4 +2074,4 @@ class RoomSyncResultBuilder(object):
newly_joined = attr.ib(type=bool)
full_state = attr.ib(type=bool)
since_token = attr.ib(type=Optional[StreamToken])
upto_token = attr.ib(type=Optional[StreamToken])
upto_token = attr.ib(type=StreamToken)