1
0

update compute_event_context_for_batched to take list and assign state groups

This commit is contained in:
H. Shay
2022-11-08 11:28:56 -08:00
parent b840328d36
commit 8cd1196c4b

View File

@@ -422,49 +422,59 @@ class StateHandler:
async def compute_event_context_for_batched(
self,
event: EventBase,
state_ids_before_event: StateMap[str],
) -> EventContext:
events_and_context: List[Tuple[EventBase, EventContext]],
prev_group: int,
state_ids_before_event: StateMap,
) -> List[Tuple[EventBase, EventContext]]:
"""
Generate an event context for an event that has not yet been persisted to the
database. Intended for use with events that are created to be persisted in a batch.
Args:
event: the event the context is being computed for
state_ids_before_event: a state map consisting of the state ids of the events
created prior to this event.
current_state_group: the current state group before the event.
events_and_context: a list of events and their associated contexts
prev_group: the state group of the last event persisted before the batched events
were created
state_ids_before_event: a state map consisting of current state ids
"""
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
# separate out state and non-state contexts
state_events = []
for event, context in events_and_context:
if event.is_state():
state_events.append((event, context))
# if the event is not state, we are set
if not event.is_state():
return EventContext.without_state_group(
storage=self._storage_controllers,
state_delta_due_to_event={},
prev_group=state_group_before_event_prev_group,
delta_ids=deltas_to_state_group_before_event,
partial_state=False,
)
# otherwise, we'll need to create a new state group for after the event
key = (event.type, event.state_key)
if state_ids_before_event is not None:
replaces = state_ids_before_event.get(key)
if replaces and replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
delta_ids = {key: event.event_id}
context = EventContext.without_state_group(
storage=self._storage_controllers,
state_delta_due_to_event=delta_ids,
delta_ids=delta_ids,
partial_state=False,
# get state groups for state events
room_id = events_and_context[0][0].room_id
assert self.hs.datastores is not None
await self.hs.datastores.state.store_state_deltas_for_batched(
state_events, room_id, prev_group=prev_group
)
return context
# iterate through all contexts and update everything
current_state_group = prev_group
for event, context in events_and_context:
# if the event is not state, we need to update it
if not event.is_state():
context._state_group = current_state_group
context.state_group_before_event = current_state_group
context._state_delta_due_to_event = {}
context.prev_group = None
context.delta_ids = None
context.partial_state = False
# the context should have been updated when storing the state groups but let's
# be sure - if it does not have a state group there is a problem
if context._state_group is None:
raise RuntimeError(
f"Event {event.event_id} is missing a state group."
)
current_state_group = context._state_group
key = (event.type, event.state_key)
replaces = state_ids_before_event.get(key)
if replaces and replaces != event.event_id:
event.unsigned["replaces_state"] = replaces
return events_and_context
@measure_func()
async def resolve_state_groups_for_events(