1
0

Compare commits

...

1 Commits

Author SHA1 Message Date
Richard van der Hoff
070c0279d4 await_lazy_loading 2022-05-19 16:18:30 +01:00
5 changed files with 78 additions and 19 deletions

View File

@@ -120,8 +120,10 @@ class EventBuilder:
The signed and hashed event. The signed and hashed event.
""" """
if auth_event_ids is None: if auth_event_ids is None:
# we pick the auth events based on our best knowledge of the current state
# of the room, so we don't need to await full state.
state_ids = await self._state.get_current_state_ids( state_ids = await self._state.get_current_state_ids(
self.room_id, prev_event_ids self.room_id, prev_event_ids, await_full_state=False
) )
auth_event_ids = self._event_auth_handler.compute_auth_events( auth_event_ids = self._event_auth_handler.compute_auth_events(
self, state_ids self, state_ids

View File

@@ -902,11 +902,15 @@ class SyncHandler:
if full_state: if full_state:
if batch: if batch:
current_state_ids = await self.state_store.get_state_ids_for_event( current_state_ids = await self.state_store.get_state_ids_for_event(
batch.events[-1].event_id, state_filter=state_filter batch.events[-1].event_id,
state_filter=state_filter,
await_full_state=not lazy_load_members, # TODO
) )
state_ids = await self.state_store.get_state_ids_for_event( state_ids = await self.state_store.get_state_ids_for_event(
batch.events[0].event_id, state_filter=state_filter batch.events[0].event_id,
state_filter=state_filter,
await_full_state=not lazy_load_members, # TODO
) )
else: else:

View File

@@ -48,6 +48,7 @@ from synapse.logging.context import ContextResourceUsage
from synapse.state import v1, v2 from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
@@ -177,7 +178,16 @@ class StateHandler:
assert latest_event_ids is not None assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state") logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
filter = StateFilter.all()
if event_type:
filter = StateFilter.from_types(((event_type, state_key),))
ret = await self.resolve_state_groups_for_events(
room_id,
latest_event_ids,
await_full_state=filter.must_await_full_state(self.hs.is_mine_id),
)
state = ret.state state = ret.state
if event_type: if event_type:
@@ -195,7 +205,10 @@ class StateHandler:
} }
async def get_current_state_ids( async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None self,
room_id: str,
latest_event_ids: Optional[Collection[str]] = None,
await_full_state: bool = True,
) -> StateMap[str]: ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room """Get the current state, or the state at a set of events, for a room
@@ -203,6 +216,8 @@ class StateHandler:
room_id: room_id:
latest_event_ids: if given, the forward extremities to resolve. If latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache). None, we look them up from the database (via a cache).
await_full_state: if true, will block if we do not yet have complete
state at the latest events.
Returns: Returns:
the state dict, mapping from (event_type, state_key) -> event_id the state dict, mapping from (event_type, state_key) -> event_id
@@ -212,7 +227,9 @@ class StateHandler:
assert latest_event_ids is not None assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids") logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) ret = await self.resolve_state_groups_for_events(
room_id, latest_event_ids, await_full_state=await_full_state
)
return ret.state return ret.state
async def get_current_users_in_room( async def get_current_users_in_room(
@@ -323,7 +340,9 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context") logger.debug("calling resolve_state_groups from compute_event_context")
entry = await self.resolve_state_groups_for_events( entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids() event.room_id,
event.prev_event_ids(),
await_full_state=False,
) )
state_ids_before_event = entry.state state_ids_before_event = entry.state
@@ -404,7 +423,7 @@ class StateHandler:
@measure_func() @measure_func()
async def resolve_state_groups_for_events( async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str] self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
) -> _StateCacheEntry: ) -> _StateCacheEntry:
"""Given a list of event_ids this method fetches the state at each """Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them. event, resolves conflicts between them and returns them.
@@ -412,13 +431,17 @@ class StateHandler:
Args: Args:
room_id room_id
event_ids event_ids
await_full_state: if true, will block if we do not yet have complete
state at these events.
Returns: Returns:
The resolved state The resolved state
""" """
logger.debug("resolve_state_groups event_ids %s", event_ids) logger.debug("resolve_state_groups event_ids %s", event_ids)
state_groups = await self.state_store.get_state_group_for_events(event_ids) state_groups = await self.state_store.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
state_group_ids = state_groups.values() state_group_ids = state_groups.values()

View File

@@ -609,13 +609,18 @@ class StateGroupStorage:
return state_group_delta.prev_group, state_group_delta.delta_ids return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids( async def get_state_groups_ids(
self, _room_id: str, event_ids: Collection[str] self,
_room_id: str,
event_ids: Collection[str],
await_full_state: bool = True,
) -> Dict[int, MutableStateMap[str]]: ) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events """Get the event IDs of all the state for the state groups for the given events
Args: Args:
_room_id: id of the room for these events _room_id: id of the room for these events
event_ids: ids of the events event_ids: ids of the events
await_full_state: if true, will block if we do not yet have complete
state at these events.
Returns: Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id) dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -627,7 +632,9 @@ class StateGroupStorage:
if not event_ids: if not event_ids:
return {} return {}
event_to_groups = await self.get_state_group_for_events(event_ids) event_to_groups = await self.get_state_group_for_events(
event_ids, await_full_state=await_full_state
)
groups = set(event_to_groups.values()) groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups) group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -700,7 +707,10 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter) return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events( async def get_state_for_events(
self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> Dict[str, StateMap[EventBase]]: ) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state """Given a list of event_ids and type tuples, return a list of state
dicts for each event. dicts for each event.
@@ -708,6 +718,8 @@ class StateGroupStorage:
Args: Args:
event_ids: The events to fetch the state of. event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state. state_filter: The state filter used to fetch state.
await_full_state: if true, will block if the state_filter includes state
which is not yet complete.
Returns: Returns:
A dict of (event_id) -> (type, state_key) -> [state_events] A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -716,8 +728,11 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
await_full_state = True if (
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): await_full_state
and state_filter
and not state_filter.must_await_full_state(self._is_mine_id)
):
await_full_state = False await_full_state = False
event_to_groups = await self.get_state_group_for_events( event_to_groups = await self.get_state_group_for_events(
@@ -749,6 +764,7 @@ class StateGroupStorage:
self, self,
event_ids: Collection[str], event_ids: Collection[str],
state_filter: Optional[StateFilter] = None, state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> Dict[str, StateMap[str]]: ) -> Dict[str, StateMap[str]]:
""" """
Get the state dicts corresponding to a list of events, containing the event_ids Get the state dicts corresponding to a list of events, containing the event_ids
@@ -757,6 +773,8 @@ class StateGroupStorage:
Args: Args:
event_ids: events whose state should be returned event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if true, will block if the state_filter includes state
which is not yet complete.
Returns: Returns:
A dict from event_id -> (type, state_key) -> event_id A dict from event_id -> (type, state_key) -> event_id
@@ -765,8 +783,12 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown) (ie they are outliers or unknown)
""" """
await_full_state = True
if state_filter and not state_filter.must_await_full_state(self._is_mine_id): if (
await_full_state
and state_filter
and not state_filter.must_await_full_state(self._is_mine_id)
):
await_full_state = False await_full_state = False
event_to_groups = await self.get_state_group_for_events( event_to_groups = await self.get_state_group_for_events(
@@ -808,7 +830,10 @@ class StateGroupStorage:
return state_map[event_id] return state_map[event_id]
async def get_state_ids_for_event( async def get_state_ids_for_event(
self, event_id: str, state_filter: Optional[StateFilter] = None self,
event_id: str,
state_filter: Optional[StateFilter] = None,
await_full_state: bool = True,
) -> StateMap[str]: ) -> StateMap[str]:
""" """
Get the state dict corresponding to a particular event Get the state dict corresponding to a particular event
@@ -816,6 +841,8 @@ class StateGroupStorage:
Args: Args:
event_id: event whose state should be returned event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database. state_filter: The state filter used to fetch state from the database.
await_full_state: if true, will block if the state_filter includes state
which is not yet complete.
Returns: Returns:
A dict from (type, state_key) -> state_event_id A dict from (type, state_key) -> state_event_id
@@ -825,7 +852,9 @@ class StateGroupStorage:
outlier or is unknown) outlier or is unknown)
""" """
state_map = await self.get_state_ids_for_events( state_map = await self.get_state_ids_for_events(
[event_id], state_filter or StateFilter.all() [event_id],
state_filter or StateFilter.all(),
await_full_state=await_full_state,
) )
return state_map[event_id] return state_map[event_id]
@@ -857,7 +886,7 @@ class StateGroupStorage:
Args: Args:
event_ids: events to get state groups for event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete await_full_state: if true, will block if we do not yet have complete
state at these events. state at these event.
""" """
if await_full_state: if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids) await self._partial_state_events_tracker.await_full_state(event_ids)

View File

@@ -85,6 +85,7 @@ async def filter_events_for_client(
event_id_to_state = await storage.state.get_state_for_events( event_id_to_state = await storage.state.get_state_for_events(
frozenset(e.event_id for e in events if not e.internal_metadata.outlier), frozenset(e.event_id for e in events if not e.internal_metadata.outlier),
state_filter=StateFilter.from_types(types), state_filter=StateFilter.from_types(types),
await_full_state=False,
) )
# Get the users who are ignored by the requesting user. # Get the users who are ignored by the requesting user.