1
0

Compare commits

...

30 Commits

Author SHA1 Message Date
Erik Johnston
055dc16d49 Fix sending receipts 2022-05-23 09:40:36 +01:00
Erik Johnston
9319bf036c Log when we get state at sync 2022-05-21 17:10:54 +01:00
Erik Johnston
e5c2ea6341 Don't pull out state for catchup 2022-05-21 16:53:01 +01:00
Erik Johnston
8141a0d0b3 Fix test for prev 2022-05-21 16:51:58 +01:00
Erik Johnston
8e47d72992 Pull current hosts out from current_state table 2022-05-21 16:42:51 +01:00
Erik Johnston
da10dfc311 Merge branch 'erikj/less_state_on_missing' into erikj/push_hack 2022-05-21 14:13:10 +01:00
Erik Johnston
f9d470b2da Fix tests 2022-05-21 14:11:52 +01:00
Erik Johnston
8b33331cb5 Newsfile 2022-05-21 14:01:21 +01:00
Erik Johnston
2ebb0c6f99 Pull out less state when handling gaps 2022-05-21 13:58:52 +01:00
Erik Johnston
3bbe3074fb Ignore display name push 2022-05-20 18:57:26 +01:00
Erik Johnston
6fd8b850ed Reduce state that push rules pull from DB 2022-05-20 18:57:13 +01:00
Erik Johnston
4b5a1a45da Don't pull out stuff for push 2022-05-20 18:50:59 +01:00
Erik Johnston
2dd2ca17a0 Fix lint 2022-05-20 18:34:05 +01:00
reivilibre
39dee30f01 Send USER_IP commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands. (#12809) 2022-05-20 15:28:23 +01:00
Erik Johnston
456a394bf7 Use 'get_domains_from_state' still for backfill 2022-05-20 15:13:06 +01:00
David Teller
10280fc943 Uniformize spam-checker API, part 1: the Code enum. (#12703) 2022-05-20 14:53:25 +02:00
Erik Johnston
4bd06c9c98 Newsfile 2022-05-20 13:02:47 +01:00
Erik Johnston
c8c12ac13a Pull out less state when checking soft fail 2022-05-20 12:58:50 +01:00
Erik Johnston
9bb3bbe153 Require 'latest_event_ids' 2022-05-20 12:58:50 +01:00
Erik Johnston
68ff8f3575 Remove 'get_current_state' from StateHandler 2022-05-20 12:58:50 +01:00
Erik Johnston
11efe7231f Use new helper functions 2022-05-20 12:51:53 +01:00
Erik Johnston
f69785e875 Add helper methods to store 2022-05-20 12:51:53 +01:00
Erik Johnston
151cb6e2f4 Use new store.get_current_state_event 2022-05-20 12:51:53 +01:00
Erik Johnston
d882ee6219 Use helper function elsewhere 2022-05-20 12:07:37 +01:00
Erik Johnston
94cd2cad4f Use helper function in auth 2022-05-20 12:03:40 +01:00
Erik Johnston
155399a145 Add helper function to get the current state event in the room 2022-05-20 12:03:40 +01:00
Shay
71e8afe34d Update EventContext get_current_event_ids and get_prev_event_ids to accept state filters and update calls where possible (#12791) 2022-05-20 09:54:12 +01:00
Sean Quah
2be5a2b07b Fix RetryDestinationLimiter re-starting finished log contexts (#12803)
Signed-off-by: Sean Quah <seanq@matrix.org>
2022-05-19 20:17:10 +01:00
Andrew Morgan
96df31239c Add a unit test for copying over arbitrary room types when upgrading a room (#12792) 2022-05-19 18:32:48 +01:00
reivilibre
177b884ad7 Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic. (#12672) 2022-05-19 16:29:08 +01:00
49 changed files with 718 additions and 440 deletions

View File

@@ -0,0 +1 @@
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.

1
changelog.d/12703.misc Normal file
View File

@@ -0,0 +1 @@
Convert namespace class `Codes` into a string enum.

1
changelog.d/12791.misc Normal file
View File

@@ -0,0 +1 @@
Update EventContext `get_current_event_ids` and `get_prev_event_ids` to accept state filters and update calls where possible.

View File

@@ -0,0 +1 @@
Implement [MSC3818: Copy room type on upgrade](https://github.com/matrix-org/matrix-spec-proposals/pull/3818).

1
changelog.d/12803.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a long-standing bug where finished log contexts would be re-started when failing to contact remote homeservers.

View File

@@ -0,0 +1 @@
Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.

1
changelog.d/12811.misc Normal file
View File

@@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.

1
changelog.d/12828.misc Normal file
View File

@@ -0,0 +1 @@
Pull out less state when handling gaps in room DAG.

View File

@@ -61,7 +61,6 @@ class Auth:
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
@@ -81,7 +80,7 @@ class Auth:
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
) -> EventBase:
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -99,29 +98,28 @@ class Auth:
Raises:
AuthError if the user is/was not in the room.
Returns:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
The current membership of the user in the room and the
membership event ID of the user.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
if member:
membership = member.membership
(
membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN:
return member
return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@@ -602,7 +600,8 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = await self.state.get_current_state(
power_level_event = await self.store.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
@@ -693,12 +692,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = await self.check_user_in_room(
return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = await self.state.get_current_state(
visibility = await self.store.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (

View File

@@ -17,6 +17,7 @@
import logging
import typing
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
@@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
logger = logging.getLogger(__name__)
class Codes:
class Codes(str, Enum):
"""
All known error codes, as an enum of strings.
"""
UNRECOGNIZED = "M_UNRECOGNIZED"
UNAUTHORIZED = "M_UNAUTHORIZED"
FORBIDDEN = "M_FORBIDDEN"
@@ -265,7 +270,9 @@ class UnrecognizedRequestError(SynapseError):
"""An error indicating we don't understand the request you're trying to make"""
def __init__(
self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
self,
msg: str = "Unrecognized request",
errcode: str = Codes.UNRECOGNIZED,
):
super().__init__(400, msg, errcode)

View File

@@ -24,6 +24,7 @@ from synapse.types import JsonDict, StateMap
if TYPE_CHECKING:
from synapse.storage import Storage
from synapse.storage.databases.main import DataStore
from synapse.storage.state import StateFilter
@attr.s(slots=True, auto_attribs=True)
@@ -196,7 +197,9 @@ class EventContext:
return self._state_group
async def get_current_state_ids(self) -> Optional[StateMap[str]]:
async def get_current_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> Optional[StateMap[str]]:
"""
Gets the room state map, including this event - ie, the state in ``state_group``
@@ -204,6 +207,9 @@ class EventContext:
not make it into the room state. This method will raise an exception if
``rejected`` is set.
Arg:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Returns None if state_group is None, which happens when the associated
event is an outlier.
@@ -216,7 +222,7 @@ class EventContext:
assert self._state_delta_due_to_event is not None
prev_state_ids = await self.get_prev_state_ids()
prev_state_ids = await self.get_prev_state_ids(state_filter)
if self._state_delta_due_to_event:
prev_state_ids = dict(prev_state_ids)
@@ -224,12 +230,17 @@ class EventContext:
return prev_state_ids
async def get_prev_state_ids(self) -> StateMap[str]:
async def get_prev_state_ids(
self, state_filter: Optional["StateFilter"] = None
) -> StateMap[str]:
"""
Gets the room state map, excluding this event.
For a non-state event, this will be the same as get_current_state_ids().
Args:
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Returns {} if state_group is None, which happens when the associated
event is an outlier.
@@ -239,7 +250,7 @@ class EventContext:
"""
assert self.state_group_before_event is not None
return await self._storage.state.get_state_ids_for_group(
self.state_group_before_event
self.state_group_before_event, state_filter
)

View File

@@ -1167,14 +1167,10 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
acl_event = await self.store.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")

View File

@@ -602,7 +602,7 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self.store.get_current_hosts_in_room(room_id)
domains = [
d
for d in domains_set

View File

@@ -36,6 +36,7 @@ from synapse.metrics import sent_transactions_counter
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import ReadReceipt
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
from synapse.visibility import filter_events_for_server
if TYPE_CHECKING:
import synapse.server
@@ -76,6 +77,7 @@ class PerDestinationQueue:
):
self._server_name = hs.hostname
self._clock = hs.get_clock()
self._storage = hs.get_storage()
self._store = hs.get_datastores().main
self._transaction_manager = transaction_manager
self._instance_name = hs.get_instance_name()
@@ -441,6 +443,12 @@ class PerDestinationQueue:
"This should not happen." % event_ids
)
logger.info(
"Catching up destination %s with %d PDUs",
self._destination,
len(catchup_pdus),
)
# We send transactions with events from one room only, as its likely
# that the remote will have to do additional processing, which may
# take some time. It's better to give it small amounts of work
@@ -486,19 +494,17 @@ class PerDestinationQueue:
):
continue
# Filter out events where the server is not in the room,
# e.g. it may have left/been kicked. *Ideally* we'd pull
# out the kick and send that, but it's a rare edge case
# so we don't bother for now (the server that sent the
# kick should send it out if its online).
hosts = await self._state.get_hosts_in_room_at_events(
p.room_id, [p.event_id]
)
if self._destination not in hosts:
continue
new_pdus.append(p)
# Filter out events where the server is not in the room,
# e.g. it may have left/been kicked. *Ideally* we'd pull
# out the kick and send that, but it's a rare edge case
# so we don't bother for now (the server that sent the
# kick should send it out if its online).
new_pdus = await filter_events_for_server(
self._storage, self._destination, new_pdus, redact=False
)
# If we've filtered out all the extremities, fall back to
# sending the original event. This should ensure that the
# server gets at least some of missed events (especially if

View File

@@ -319,7 +319,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
alias_event = await self.store.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)

View File

@@ -54,6 +54,7 @@ from synapse.replication.http.federation import (
ReplicationStoreRoomOnOutlierMembershipRestServlet,
)
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.retryutils import NotRetryingDestination
@@ -352,7 +353,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self.store.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
@@ -1259,7 +1260,9 @@ class FederationHandler:
event.content["third_party_invite"]["signed"]["token"],
)
original_invite = None
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
)
original_invite_id = prev_state_ids.get(key)
if original_invite_id:
original_invite = await self.store.get_event(
@@ -1308,7 +1311,9 @@ class FederationHandler:
signed = event.content["third_party_invite"]["signed"]
token = signed["token"]
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.ThirdPartyInvite, None)])
)
invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
invite_event = None

View File

@@ -30,6 +30,7 @@ from typing import (
from prometheus_client import Counter
from synapse import event_auth
from synapse.api.constants import (
EventContentFields,
EventTypes,
@@ -63,6 +64,7 @@ from synapse.replication.http.federation import (
)
from synapse.state import StateResolutionStore
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
PersistedEventPosition,
RoomStreamToken,
@@ -461,7 +463,9 @@ class FederationEventHandler:
with nested_logging_context(suffix=event.event_id):
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in state
},
partial_state=partial_state,
)
@@ -499,7 +503,7 @@ class FederationEventHandler:
# build a new state group for it if need be
context = await self._state_handler.compute_event_context(
event,
old_state=state,
state_ids_before_event=state,
)
if context.partial_state:
# this can happen if some or all of the event's prev_events still have
@@ -763,7 +767,7 @@ class FederationEventHandler:
async def _resolve_state_at_missing_prevs(
self, dest: str, event: EventBase
) -> Optional[Iterable[EventBase]]:
) -> Optional[StateMap[str]]:
"""Calculate the state at an event with missing prev_events.
This is used when we have pulled a batch of events from a remote server, and
@@ -790,8 +794,8 @@ class FederationEventHandler:
event: an event to check for missing prevs.
Returns:
if we already had all the prev events, `None`. Otherwise, returns a list of
the events in the state at `event`.
if we already had all the prev events, `None`. Otherwise, returns
the state at `event`.
"""
room_id = event.room_id
event_id = event.event_id
@@ -835,13 +839,7 @@ class FederationEventHandler:
dest, room_id, p
)
remote_state_map = {
(x.type, x.state_key): x.event_id for x in remote_state
}
state_maps.append(remote_state_map)
for x in remote_state:
event_map[x.event_id] = x
state_maps.append(remote_state)
room_version = await self._store.get_room_version_id(room_id)
state_map = await self._state_resolution_handler.resolve_events_with_store(
@@ -852,19 +850,6 @@ class FederationEventHandler:
state_res_store=StateResolutionStore(self._store),
)
# We need to give _process_received_pdu the actual state events
# rather than event ids, so generate that now.
# First though we need to fetch all the events that are in
# state_map, so we can build up the state below.
evs = await self._store.get_events(
list(state_map.values()),
get_prev_content=False,
redact_behaviour=EventRedactBehaviour.as_is,
)
event_map.update(evs)
state = [event_map[e] for e in state_map.values()]
except Exception:
logger.warning(
"Error attempting to resolve state at missing prev_events",
@@ -876,14 +861,14 @@ class FederationEventHandler:
"We can't get valid state history.",
affected=event_id,
)
return state
return state_map
async def _get_state_after_missing_prev_event(
self,
destination: str,
room_id: str,
event_id: str,
) -> List[EventBase]:
) -> StateMap[str]:
"""Requests all of the room state at a given event from a remote homeserver.
Args:
@@ -892,7 +877,7 @@ class FederationEventHandler:
event_id: The id of the event we want the state at.
Returns:
A list of events in the state, including the event itself
The state *after* the given event.
"""
(
state_event_ids,
@@ -911,15 +896,13 @@ class FederationEventHandler:
desired_events = set(state_event_ids)
desired_events.add(event_id)
logger.debug("Fetching %i events from cache/store", len(desired_events))
fetched_events = await self._store.get_events(
desired_events, allow_rejected=True
)
have_events = await self._store.have_seen_events(room_id, desired_events)
missing_desired_events = desired_events - fetched_events.keys()
missing_desired_events = desired_events - have_events
logger.debug(
"We are missing %i events (got %i)",
len(missing_desired_events),
len(fetched_events),
len(have_events),
)
# We probably won't need most of the auth events, so let's just check which
@@ -930,7 +913,7 @@ class FederationEventHandler:
# already have a bunch of the state events. It would be nice if the
# federation api gave us a way of finding out which we actually need.
missing_auth_events = set(auth_event_ids) - fetched_events.keys()
missing_auth_events = set(auth_event_ids) - have_events
missing_auth_events.difference_update(
await self._store.have_seen_events(room_id, missing_auth_events)
)
@@ -956,47 +939,54 @@ class FederationEventHandler:
destination=destination, room_id=room_id, event_ids=missing_events
)
# we need to make sure we re-load from the database to get the rejected
# state correct.
fetched_events.update(
await self._store.get_events(missing_desired_events, allow_rejected=True)
)
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
# check for events which were in the wrong room.
#
# this can happen if a remote server claims that the state or
# auth_events at an event in room A are actually events in room B
bad_events = [
(event_id, event.room_id)
for event_id, event in fetched_events.items()
if event.room_id != room_id
]
event_metadata = await self._store.get_metadata_for_events(state_event_ids)
for bad_event_id, bad_room_id in bad_events:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
bad_event_id,
bad_room_id,
room_id,
)
state_map = {}
del fetched_events[bad_event_id]
for state_event_id, metadata in event_metadata.items():
if metadata.room_id != room_id:
# This is a bogus situation, but since we may only discover it a long time
# after it happened, we try our best to carry on, by just omitting the
# bad events from the returned state set.
logger.warning(
"Remote server %s claims event %s in room %s is an auth/state "
"event in room %s",
destination,
state_event_id,
metadata.room_id,
room_id,
)
continue
if metadata.state_key is None:
logger.warning(
"Remote server gave us non-state event in state: %s", state_event_id
)
continue
state_map[(metadata.event_type, metadata.state_key)] = state_event_id
# if we couldn't get the prev event in question, that's a problem.
remote_event = fetched_events.get(event_id)
remote_event = await self._store.get_event(
event_id,
allow_none=True,
allow_rejected=True,
redact_behaviour=EventRedactBehaviour.as_is,
)
if not remote_event:
raise Exception("Unable to get missing prev_event %s" % (event_id,))
# missing state at that event is a warning, not a blocker
# XXX: this doesn't sound right? it means that we'll end up with incomplete
# state.
failed_to_fetch = desired_events - fetched_events.keys()
failed_to_fetch = desired_events - event_metadata.keys()
if failed_to_fetch:
logger.warning(
"Failed to fetch missing state events for %s %s",
@@ -1004,14 +994,12 @@ class FederationEventHandler:
failed_to_fetch,
)
remote_state = [
fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
]
if remote_event.is_state() and remote_event.rejected_reason is None:
remote_state.append(remote_event)
state_map[
(remote_event.type, remote_event.state_key)
] = remote_event.event_id
return remote_state
return state_map
async def _get_state_and_persist(
self, destination: str, room_id: str, event_id: str
@@ -1038,7 +1026,7 @@ class FederationEventHandler:
self,
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
backfilled: bool = False,
) -> None:
"""Called when we have a new non-outlier event.
@@ -1072,7 +1060,7 @@ class FederationEventHandler:
try:
context = await self._state_handler.compute_event_context(
event, old_state=state
event, state_ids_before_event=state
)
context = await self._check_event_auth(
origin,
@@ -1500,7 +1488,11 @@ class FederationEventHandler:
return context
# now check auth against what we think the auth events *should* be.
prev_state_ids = await context.get_prev_state_ids()
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)
@@ -1552,14 +1544,14 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
current_state_map = await self._state_handler.get_current_state(event.room_id)
current_state = list(current_state_map.values())
await self._get_room_member_handler().kick_guest_users(current_state)
current_state = await self._store.get_current_state(event.room_id)
current_state_list = list(current_state.values())
await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
event: EventBase,
state: Optional[Iterable[EventBase]],
state: Optional[StateMap[str]],
origin: str,
) -> None:
"""Checks if we should soft fail the event; if so, marks the event as
@@ -1582,6 +1574,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# The event types we want to pull from the "current" state.
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
if state is not None:
# If we're explicitly given the state then we won't have all the
@@ -1596,20 +1591,24 @@ class FederationEventHandler:
# given state at the event. This should correctly handle cases
# like bans, especially with state res v2.
state_sets_d = await self._state_store.get_state_groups(
state_sets_d = await self._state_store.get_state_groups_ids(
event.room_id, extrem_ids
)
state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
state_sets: List[StateMap[str]] = list(state_sets_d.values())
state_sets.append(state)
current_states = await self._state_handler.resolve_events(
room_version, state_sets, event
current_state_ids = (
await self._state_resolution_handler.resolve_events_with_store(
event.room_id,
room_version,
state_sets,
event_map={},
state_res_store=StateResolutionStore(self._store),
)
)
current_state_ids: StateMap[str] = {
k: e.event_id for k, e in current_states.items()
}
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
current_state_ids = await self._store.get_filtered_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)
)
logger.debug(
@@ -1619,7 +1618,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]

View File

@@ -190,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
self.state_handler.get_current_state, event.room_id
self.store.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@@ -404,7 +404,7 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
current_state = await self.store.get_current_state(room_id=room_id)
# TODO: These concurrently
time_now = self.clock.time_msec()

View File

@@ -117,7 +117,9 @@ class MessageHandler:
)
if membership == Membership.JOIN:
data = await self.state.get_current_state(room_id, event_type, state_key)
data = await self.store.get_current_state_event(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
@@ -634,7 +636,9 @@ class EventCreationHandler:
# federation as well as those created locally. As of room v3, aliases events
# can be created by users that are not in the room, therefore we have to
# tolerate them in event_auth.check().
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
prev_event = (
await self.store.get_event(prev_event_id, allow_none=True)
@@ -761,7 +765,9 @@ class EventCreationHandler:
# This can happen due to out of band memberships
return None
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(event.type, None)])
)
prev_event_id = prev_state_ids.get((event.type, event.state_key))
if not prev_event_id:
return None
@@ -1017,8 +1023,21 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
old_state = await self.store.get_events_as_list(state_event_ids)
context = await self.state.compute_event_context(event, old_state=old_state)
metadata = await self.store.get_metadata_for_events(state_event_ids)
state_map = {}
for event_id, data in metadata.items():
if data.state_key is None:
raise Exception(
"Trying to set non-state event as state: %s", event_id
)
state_map[(data.event_type, data.state_key)] = event_id
context = await self.state.compute_event_context(
event,
state_ids_before_event=state_map,
)
else:
context = await self.state.compute_event_context(event)
@@ -1547,7 +1566,11 @@ class EventCreationHandler:
"Redacting MSC2716 events is not supported in this room version",
)
prev_state_ids = await context.get_prev_state_ids()
event_types = event_auth.auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
auth_events_ids = self._event_auth_handler.compute_auth_events(
event, prev_state_ids, for_verification=True
)

View File

@@ -303,7 +303,10 @@ class RoomCreationHandler:
context=tombstone_context,
)
old_room_state = await tombstone_context.get_current_state_ids()
state_filter = StateFilter.from_types(
[(EventTypes.CanonicalAlias, ""), (EventTypes.PowerLevels, "")]
)
old_room_state = await tombstone_context.get_current_state_ids(state_filter)
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
@@ -427,7 +430,7 @@ class RoomCreationHandler:
requester: the user requesting the upgrade
old_room_id : the id of the room to be replaced
new_room_id: the id to give the new room (should already have been
created with _gemerate_room_id())
created with _generate_room_id())
new_room_version: the new room version to use
tombstone_event_id: the ID of the tombstone event in the old room.
"""
@@ -1396,7 +1399,7 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self.store.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name

View File

@@ -38,6 +38,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
@@ -362,7 +363,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
historical=historical,
)
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.Member, None)])
)
prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
@@ -1160,7 +1163,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
else:
requester = types.create_requester(target_user)
prev_state_ids = await context.get_prev_state_ids()
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types([(EventTypes.GuestAccess, None)])
)
if event.membership == Membership.JOIN:
if requester.is_guest:
guest_can_join = await self._can_guest_join(prev_state_ids)
@@ -1404,7 +1409,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
room_state = await self.store.get_filtered_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Member, user.to_string()),
(EventTypes.CanonicalAlias, ""),
(EventTypes.Name, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
(EventTypes.RoomAvatar, ""),
]
),
)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -1800,7 +1817,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(
member = await self.store.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None

View File

@@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state = await self.store.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(

View File

@@ -18,6 +18,8 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
import attr
from prometheus_client import Counter
from twisted.python import failure
from synapse.api.constants import EventTypes, Membership, ReceiptTypes
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState
@@ -643,6 +645,13 @@ class SyncHandler:
event: event of interest
state_filter: The state filter used to fetch state from the database.
"""
f = failure.Failure()
logger.info(
"SYNC get_state_after_event in room %s",
event.room_id,
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
state_ids = await self.state_store.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
)

View File

@@ -681,7 +681,7 @@ class Notifier:
return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
state = await self.store.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:

View File

@@ -20,7 +20,7 @@ import attr
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
from synapse.event_auth import auth_types_for_event, get_user_power_level
from synapse.events import EventBase, relation_from_event
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
@@ -31,6 +31,7 @@ from synapse.util.caches.descriptors import lru_cache
from synapse.util.caches.lrucache import LruCache
from synapse.util.metrics import measure_func
from ..storage.state import StateFilter
from .push_rule_evaluator import PushRuleEvaluatorForEvent
if TYPE_CHECKING:
@@ -168,8 +169,12 @@ class BulkPushRuleEvaluator:
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
) -> Tuple[dict, int]:
prev_state_ids = await context.get_prev_state_ids()
event_types = auth_types_for_event(event.room_version, event)
prev_state_ids = await context.get_prev_state_ids(
StateFilter.from_types(event_types)
)
pl_event_id = prev_state_ids.get(POWER_KEY)
if pl_event_id:
# fastpath: if there's a power level event, that's all we need, and
# not having a power level event is an extreme edge case
@@ -204,7 +209,12 @@ class BulkPushRuleEvaluator:
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user: Dict[str, List[Union[dict, str]]] = {}
room_members = await self.store.get_joined_users_from_context(event, context)
# FIXME!!!
# room_members = await self.store.get_joined_users_from_context(event, context)
room_member_count = await self.store.get_number_joined_users_in_room(
event.room_id
)
(
power_levels,
@@ -212,7 +222,7 @@ class BulkPushRuleEvaluator:
) = await self._get_power_levels_and_sender_level(event, context)
evaluator = PushRuleEvaluatorForEvent(
event, len(room_members), sender_power_level, power_levels
event, room_member_count, sender_power_level, power_levels
)
# If the event is not a state event check if any users ignore the sender.
@@ -229,9 +239,10 @@ class BulkPushRuleEvaluator:
continue
display_name = None
profile_info = room_members.get(uid)
if profile_info:
display_name = profile_info.display_name
# FIXME!!!
# profile_info = room_members.get(uid)
# if profile_info:
# display_name = profile_info.display_name
if not display_name:
# Handle the case where we are pushing a membership event to
@@ -382,77 +393,27 @@ class RulesForRoom:
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
ret_rules_by_user = {}
missing_member_event_ids = {}
if state_group and self.data.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
ret_rules_by_user = self.data.rules_by_user
current_state_ids = context.delta_ids
push_rules_delta_state_cache_metric.inc_hits()
else:
current_state_ids = await context.get_current_state_ids()
push_rules_delta_state_cache_metric.inc_misses()
# Ensure the state IDs exist.
assert current_state_ids is not None
push_rules_state_size_counter.inc(len(current_state_ids))
logger.debug(
"Looking for member changes in %r %r", state_group, current_state_ids
local_users = await self.store.get_local_users_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
# Loop through to see which member events we've seen and have rules
# for and which we need to fetch
for key in current_state_ids:
typ, user_id = key
if typ != EventTypes.Member:
continue
if event.type == EventTypes.Member and event.membership == Membership.JOIN:
if self.is_mine_id(event.state_key):
local_users = list(local_users)
local_users.append(event.state_key)
if user_id in self.data.uninteresting_user_set:
continue
ret_rules_by_user = await self.store.bulk_get_push_rules(
local_users, on_invalidate=self.invalidate_all_cb
)
if not self.is_mine_id(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
logger.info("Users in room: %s", local_users)
if self.store.get_if_app_services_interested_in_user(user_id):
self.data.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
res = self.data.member_map.get(event_id, None)
if res:
if res.membership == Membership.JOIN:
rules = self.data.rules_by_user.get(res.user_id, None)
if rules:
ret_rules_by_user[res.user_id] = rules
continue
# If a user has left a room we remove their push rule. If they
# joined then we re-add it later in _update_rules_with_member_event_ids
ret_rules_by_user.pop(user_id, None)
missing_member_event_ids[user_id] = event_id
if missing_member_event_ids:
# If we have some member events we haven't seen, look them up
# and fetch push rules for them if appropriate.
logger.debug("Found new member events %r", missing_member_event_ids)
await self._update_rules_with_member_event_ids(
ret_rules_by_user, missing_member_event_ids, state_group, event
)
else:
# The push rules didn't change but lets update the cache anyway
self.update_cache(
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
)
self.update_cache(
self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -460,67 +421,6 @@ class RulesForRoom:
)
return ret_rules_by_user
async def _update_rules_with_member_event_ids(
self,
ret_rules_by_user: Dict[str, list],
member_event_ids: Dict[str, str],
state_group: Optional[int],
event: EventBase,
) -> None:
"""Update the partially filled rules_by_user dict by fetching rules for
any newly joined users in the `member_event_ids` list.
Args:
ret_rules_by_user: Partially filled dict of push rules. Gets
updated with any new rules.
member_event_ids: Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
sequence = self.data.sequence
members = await self.store.get_membership_from_event_ids(
member_event_ids.values()
)
# If the event is a join event then it will be in current state events
# map but not in the DB, so we have to explicitly insert it.
if event.type == EventTypes.Member:
for event_id in member_event_ids.values():
if event_id == event.event_id:
members[event_id] = EventIdMembership(
user_id=event.state_key, membership=event.membership
)
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
joined_user_ids = {
entry.user_id
for entry in members.values()
if entry and entry.membership == Membership.JOIN
}
logger.debug("Joined: %r", joined_user_ids)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, joined_user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb
)
ret_rules_by_user.update(
item for item in rules_by_user.items() if item[0] is not None
)
self.update_cache(sequence, members, ret_rules_by_user, state_group)
def update_cache(
self,
sequence: int,

View File

@@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
# by default, we just use the command name.
return self.NAME
def redis_channel_name(self, prefix: str) -> str:
"""
Returns the Redis channel name upon which to publish this command.
Args:
prefix: The prefix for the channel.
"""
return prefix
SC = TypeVar("SC", bound="_SimpleCommand")
@@ -395,6 +404,9 @@ class UserIpCommand(Command):
f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
)
def redis_channel_name(self, prefix: str) -> str:
return f"{prefix}/USER_IP"
class RemoteServerUpCommand(_SimpleCommand):
"""Sent when a worker has detected that a remote server is no longer

View File

@@ -1,5 +1,5 @@
# Copyright 2017 Vector Creations Ltd
# Copyright 2020 The Matrix.org Foundation C.I.C.
# Copyright 2020, 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -101,6 +101,9 @@ class ReplicationCommandHandler:
self._instance_id = hs.get_instance_id()
self._instance_name = hs.get_instance_name()
# Additional Redis channel suffixes to subscribe to.
self._channels_to_subscribe_to: List[str] = []
self._is_presence_writer = (
hs.get_instance_name() in hs.config.worker.writers.presence
)
@@ -243,6 +246,31 @@ class ReplicationCommandHandler:
# If we're NOT using Redis, this must be handled by the master
self._should_insert_client_ips = hs.get_instance_name() == "master"
if self._is_master or self._should_insert_client_ips:
self.subscribe_to_channel("USER_IP")
def subscribe_to_channel(self, channel_name: str) -> None:
"""
Indicates that we wish to subscribe to a Redis channel by name.
(The name will later be prefixed with the server name; i.e. subscribing
to the 'ABC' channel actually subscribes to 'example.com/ABC' Redis-side.)
Raises:
- If replication has already started, then it's too late to subscribe
to new channels.
"""
if self._factory is not None:
# We don't allow subscribing after the fact to avoid the chance
# of missing an important message because we didn't subscribe in time.
raise RuntimeError(
"Cannot subscribe to more channels after replication started."
)
if channel_name not in self._channels_to_subscribe_to:
self._channels_to_subscribe_to.append(channel_name)
def _add_command_to_stream_queue(
self, conn: IReplicationConnection, cmd: Union[RdataCommand, PositionCommand]
) -> None:
@@ -321,7 +349,9 @@ class ReplicationCommandHandler:
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(
hs, outbound_redis_connection
hs,
outbound_redis_connection,
channel_names=self._channels_to_subscribe_to,
)
hs.get_reactor().connectTCP(
hs.config.redis.redis_host,

View File

@@ -14,7 +14,7 @@
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Generic, Optional, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, List, Optional, Type, TypeVar, cast
import attr
import txredisapi
@@ -85,14 +85,15 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
Attributes:
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
synapse_stream_prefix: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
synapse_handler: "ReplicationCommandHandler"
synapse_stream_name: str
synapse_stream_prefix: str
synapse_channel_names: List[str]
synapse_outbound_redis_connection: txredisapi.ConnectionHandler
def __init__(self, *args: Any, **kwargs: Any):
@@ -117,8 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
fully_qualified_stream_names = [
f"{self.synapse_stream_prefix}/{stream_suffix}"
for stream_suffix in self.synapse_channel_names
] + [self.synapse_stream_prefix]
logger.info("Sending redis SUBSCRIBE for %r", fully_qualified_stream_names)
await make_deferred_yieldable(self.subscribe(fully_qualified_stream_names))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
@@ -215,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
# remote instances.
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
await make_deferred_yieldable(
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
)
self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
)
@@ -300,20 +306,27 @@ def format_address(address: IAddress) -> str:
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
subscribes to some streams.
Args:
hs
outbound_redis_connection: A connection to redis that will be used to
send outbound commands (this is separate to the redis connection
used to subscribe).
channel_names: A list of channel names to append to the base channel name
to additionally subscribe to.
e.g. if ['ABC', 'DEF'] is specified then we'll listen to:
example.com; example.com/ABC; and example.com/DEF.
"""
maxDelay = 5
protocol = RedisSubscriber
def __init__(
self, hs: "HomeServer", outbound_redis_connection: txredisapi.ConnectionHandler
self,
hs: "HomeServer",
outbound_redis_connection: txredisapi.ConnectionHandler,
channel_names: List[str],
):
super().__init__(
@@ -326,7 +339,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
)
self.synapse_handler = hs.get_replication_command_handler()
self.synapse_stream_name = hs.hostname
self.synapse_stream_prefix = hs.hostname
self.synapse_channel_names = channel_names
self.synapse_outbound_redis_connection = outbound_redis_connection
@@ -340,7 +354,8 @@ class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
# protocol.
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name
p.synapse_stream_prefix = self.synapse_stream_prefix
p.synapse_channel_names = self.synapse_channel_names
return p

View File

@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@@ -447,7 +448,7 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
self.store = hs.get_datastores().main
self.is_mine = hs.is_mine
async def on_POST(
@@ -489,8 +490,9 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
join_rules_event = await self.store.get_current_state_event(
room_id, EventTypes.JoinRules, ""
)
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
@@ -552,12 +554,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id)
if not room_state:
filtered_room_state = await self.store.get_filtered_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Create, ""),
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Member, user_to_add),
]
),
)
if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, ""))
create_event = filtered_room_state[(EventTypes.Create, "")]
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None:
# We pick the local user with the highest power.
@@ -633,7 +645,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them.
member_event = room_state.get((EventTypes.Member, user_to_add))
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False
if member_event:
is_joined = member_event.content["membership"] in (
@@ -644,7 +656,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined:
return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, ""))
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False
if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC

View File

@@ -673,7 +673,7 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin(
requester.user
):
power_level_event = await self._state.get_current_state(
power_level_event = await self._store.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)

View File

@@ -178,8 +178,8 @@ class ResourceLimitsServerNotices:
currently_blocked = False
pinned_state_event = None
try:
pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
pinned_state_event = await self._store.get_current_state_event(
room_id, event_type=EventTypes.Pinned, state_key=""
)
except AuthError:
# The user has yet to join the server notices room

View File

@@ -32,13 +32,11 @@ from typing import (
Set,
Tuple,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -132,85 +130,20 @@ class StateHandler:
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage = hs.get_storage()
@overload
async def get_current_state(
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
self,
room_id: str,
latest_event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
latest_event_ids: The forward extremities to resolve.
Returns:
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
@@ -239,10 +172,6 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
@@ -261,7 +190,7 @@ class StateHandler:
async def compute_event_context(
self,
event: EventBase,
old_state: Optional[Iterable[EventBase]] = None,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: bool = False,
) -> EventContext:
"""Build an EventContext structure for a non-outlier event.
@@ -273,12 +202,12 @@ class StateHandler:
Args:
event:
old_state: The state at the event if it can't be
state_ids_before_event: The state at the event if it can't be
calculated from existing events. This is normally only specified
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
partial_state: True if `old_state` is partial and omits non-critical
membership events
partial_state: True if `state_ids_before_event` is partial and omits
non-critical membership events
Returns:
The event context.
"""
@@ -288,11 +217,7 @@ class StateHandler:
#
# first of all, figure out the state before the event
#
if old_state:
# if we're given the state before the event, then we use that
state_ids_before_event: StateMap[str] = {
(s.type, s.state_key): s.event_id for s in old_state
}
if state_ids_before_event:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None

View File

@@ -74,6 +74,9 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
self._attempt_to_invalidate_cache(
"get_number_joined_users_in_room.invalidate", (room_id,)
)
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))

View File

@@ -217,6 +217,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
self.get_local_users_in_room.invalidate((room_id,))
self.get_number_joined_users_in_room((room_id,))
if relates_to:
self.get_relations_for_event.invalidate((relates_to,))

View File

@@ -1766,6 +1766,14 @@ class PersistEventsStore:
self.store.get_invited_rooms_for_local_user.invalidate,
(event.state_key,),
)
txn.call_after(
self.store.get_local_users_in_room.invalidate,
(event.room_id,),
)
txn.call_after(
self.store.get_number_joined_users_in_room.invalidate,
(event.room_id,),
)
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.

View File

@@ -337,6 +337,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
"get_room_summary", _get_room_summary_txn
)
@cached()
async def get_number_joined_users_in_room(self, room_id: str) -> int:
return await self.db_pool.simple_select_one_onecol(
table="current_state_events",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="COUNT(*)",
desc="get_number_joined_users_in_room",
)
@cached()
async def get_invited_rooms_for_local_user(
self, user_id: str
@@ -444,6 +453,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
@cached()
async def get_local_users_in_room(self, room_id: str) -> List[str]:
return await self.db_pool.simple_select_onecol(
table="local_current_membership",
keyvalues={"room_id": room_id, "membership": Membership.JOIN},
retcol="user_id",
desc="get_local_users_in_room",
)
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
@@ -869,6 +887,29 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room."""
if isinstance(self.database_engine, Sqlite3Engine):
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:

View File

@@ -16,6 +16,8 @@ import collections.abc
import logging
from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
import attr
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
MAX_STATE_DELTA_HOPS = 100
@attr.s(slots=True, frozen=True, auto_attribs=True)
class EventMetadata:
"""Returned by `get_metadata_for_events`"""
room_id: str
event_type: str
state_key: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
v = KNOWN_ROOM_VERSIONS.get(room_version_id)
if not v:
@@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return room_version
async def get_metadata_for_events(
self, event_ids: Collection[str]
) -> Dict[str, EventMetadata]:
"""Get some metadata (room_id, type, state_key) for the given events."""
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", event_ids
)
sql = f"""
SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
LEFT JOIN state_events USING (event_id)
WHERE {clause}
"""
def get_metadata_for_events_txn(
txn: LoggingTransaction,
) -> Dict[str, EventMetadata]:
txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
)
for event_id, room_id, event_type, state_key in txn
}
return await self.db_pool.runInteraction(
"get_metadata_for_events", get_metadata_for_events_txn
)
async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
"""Get the predecessor of an upgraded room if it exists.
Otherwise return None.
@@ -218,6 +260,20 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_current_state_ids", _get_current_state_ids_txn
)
async def get_current_state(self, room_id: str) -> StateMap[EventBase]:
"""Same as `get_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_current_state_ids(room_id)
event_map = await self.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
@@ -269,6 +325,39 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
async def get_filtered_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""Same as `get_filtered_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_filtered_current_state_ids(room_id, state_filter)
event_map = await self.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
async def get_current_state_event(
self, room_id: str, event_type: str, state_key: str
) -> Optional[EventBase]:
"""Get the current state event for the given type/state_key."""
key = (event_type, state_key)
state_map = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types((key,))
)
event_id = state_map.get(key)
event = None
if event_id:
event = await self.get_event(event_id, allow_none=True)
return event
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any

View File

@@ -634,16 +634,19 @@ class StateGroupStorage:
return group_to_state
async def get_state_ids_for_group(self, state_group: int) -> StateMap[str]:
async def get_state_ids_for_group(
self, state_group: int, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the event IDs of all the state in the given state group
Args:
state_group: A state group for which we want to get the state IDs.
state_filter: specifies the type of state event to fetch from DB, example: EventTypes.JoinRules
Returns:
Resolves to a map of (type, state_key) -> event_id
"""
group_to_state = await self.get_state_for_groups((state_group,))
group_to_state = await self.get_state_for_groups((state_group,), state_filter)
return group_to_state[state_group]

View File

@@ -16,8 +16,8 @@ import random
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, Type
import synapse.logging.context
from synapse.api.errors import CodeMessageException
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage import DataStore
from synapse.util import Clock
@@ -265,4 +265,4 @@ class RetryDestinationLimiter:
logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
synapse.logging.context.run_in_background(store_retry_timings)
run_as_background_process("store_retry_timings", store_retry_timings)

View File

@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
hs.get_datastores().main.get_current_hosts_in_room = Mock(
return_value=make_awaitable(["test", "host2"])
)
return hs
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (

View File

@@ -207,7 +207,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self.hs.get_datastores().main.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@@ -258,7 +258,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self.hs.get_datastores().main.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")

View File

@@ -335,7 +335,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self.store.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)

View File

@@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
# federation handler wanting to backfill the fake event.
self.get_success(
federation_event_handler._process_received_pdu(
self.OTHER_SERVER_NAME, event, state=current_state
self.OTHER_SERVER_NAME,
event,
state={(e.type, e.state_key): e.event_id for e in current_state},
)
)

View File

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
@@ -32,6 +33,7 @@ from synapse.server import HomeServer
from tests import unittest
from tests.server import FakeTransport
from tests.utils import USE_POSTGRES_FOR_TESTS
try:
import hiredis
@@ -475,22 +477,25 @@ class FakeRedisPubSubServer:
"""A fake Redis server for pub/sub."""
def __init__(self):
self._subscribers = set()
self._subscribers_by_channel: Dict[
bytes, Set["FakeRedisPubSubProtocol"]
] = defaultdict(set)
def add_subscriber(self, conn):
def add_subscriber(self, conn, channel: bytes):
"""A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
self._subscribers_by_channel[channel].add(conn)
def remove_subscriber(self, conn):
"""A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
"""A connection has lost connection"""
for subscribers in self._subscribers_by_channel.values():
subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
def publish(self, conn, channel: bytes, msg) -> int:
"""A connection want to publish a message to subscribers."""
for sub in self._subscribers:
for sub in self._subscribers_by_channel[channel]:
sub.send(["message", channel, msg])
return len(self._subscribers)
return len(self._subscribers_by_channel)
def buildProtocol(self, addr):
return FakeRedisPubSubProtocol(self)
@@ -531,9 +536,10 @@ class FakeRedisPubSubProtocol(Protocol):
num_subscribers = self._server.publish(self, channel, message)
self.send(num_subscribers)
elif command == b"SUBSCRIBE":
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
for idx, channel in enumerate(args):
num_channels = idx + 1
self._server.add_subscriber(self, channel)
self.send(["subscribe", channel, num_channels])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
@@ -576,3 +582,27 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
"""
A test case that enables Redis, providing a fake Redis server.
"""
if not hiredis:
skip = "Requires hiredis"
if not USE_POSTGRES_FOR_TESTS:
# Redis replication only takes place on Postgres
skip = "Requires Postgres"
def default_config(self) -> Dict[str, Any]:
"""
Overrides the default config to enable Redis.
Even if the test only uses make_worker_hs, the main process needs Redis
enabled otherwise it won't create a Fake Redis server to listen on the
Redis port and accept fake TCP connections.
"""
base = super().default_config()
base["redis"] = {"enabled": True}
return base

View File

@@ -0,0 +1,73 @@
# Copyright 2022 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from tests.replication._base import RedisMultiWorkerStreamTestCase
class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
self.hs.get_replication_command_handler()._channels_to_subscribe_to,
["USER_IP"],
)
def test_background_worker_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker1 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker1",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertIn(
"USER_IP",
worker1.get_replication_command_handler()._channels_to_subscribe_to,
)
# Advance so the Redis subscription gets processed
self.pump(0.1)
# The counts are 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 2
)
def test_non_background_worker_not_subscribed_to_user_ip(self) -> None:
# The default main process is subscribed to the USER_IP channel.
worker2 = self.make_worker_hs(
"synapse.app.generic_worker",
extra_config={
"worker_name": "worker2",
"run_background_tasks_on": "worker1",
"redis": {"enabled": True},
},
)
self.assertNotIn(
"USER_IP",
worker2.get_replication_command_handler()._channels_to_subscribe_to,
)
# Advance so the Redis subscription gets processed
self.pump(0.1)
# The count is 2 because both the main process and the worker are subscribed.
self.assertEqual(len(self._redis_server._subscribers_by_channel[b"test"]), 2)
# For USER_IP, the count is 1 because only the main process is subscribed.
self.assertEqual(
len(self._redis_server._subscribers_by_channel[b"test/USER_IP"]), 1
)

View File

@@ -76,7 +76,7 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
"""
Upgrading a room should work fine.
"""
# THe user isn't in the room.
# The user isn't in the room.
roomless = self.register_user("roomless", "pass")
roomless_token = self.login(roomless, "pass")
@@ -263,3 +263,33 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
self.assertIn((EventTypes.SpaceChild, self.room_id), state_ids)
# The child that was removed should not be copied over.
self.assertNotIn((EventTypes.SpaceChild, old_room_id), state_ids)
def test_custom_room_type(self) -> None:
"""Test upgrading a room that has a custom room type set."""
test_room_type = "com.example.my_custom_room_type"
# Create a room with a custom room type.
room_id = self.helper.create_room_as(
self.creator,
tok=self.creator_token,
extra_content={
"creation_content": {EventContentFields.ROOM_TYPE: test_room_type}
},
)
# Upgrade the room!
channel = self._upgrade_room(room_id=room_id)
self.assertEqual(200, channel.code, channel.result)
self.assertIn("replacement_room", channel.json_body)
new_room_id = channel.json_body["replacement_room"]
state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
# Ensure the new room is the same type as the old room.
create_event = self.get_success(
self.store.get_event(state_ids[(EventTypes.Create, "")])
)
self.assertEqual(
create_event.content.get(EventContentFields.ROOM_TYPE), test_room_type
)

View File

@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
def persist_event(self, event, state=None):
"""Persist the event, with optional state"""
context = self.get_success(
self.state.compute_event_context(event, old_state=state)
self.state.compute_event_context(event, state_ids_before_event=state)
)
self.get_success(self.persistence.persist_event(event, context))
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
# setting. The state resolution across the old and new event will then
# include it, and so the resolved state won't match the new state.
state_before_gap = dict(
self.get_success(self.state.get_current_state(self.room_id))
self.get_success(self.store.get_current_state_ids(self.room_id))
)
state_before_gap.pop(("m.room.history_visibility", ""))
context = self.get_success(
self.state.compute_event_context(
remote_event_2, old_state=state_before_gap.values()
remote_event_2,
state_ids_before_event=state_before_gap,
)
)
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
RoomVersions.V6,
)
state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
state_before_gap = self.get_success(
self.store.get_current_state_ids(self.room_id)
)
self.persist_event(remote_event_2, state=state_before_gap.values())
self.persist_event(remote_event_2, state=state_before_gap)
# Check the new extremity is just the new remote event.
self.assert_extremities([local_message_event_id, remote_event_2.event_id])

View File

@@ -98,9 +98,8 @@ class PurgeTests(HomeserverTestCase):
first = self.helper.send(self.room_id, body="test1")
# Get the current room state.
state_handler = self.hs.get_state_handler()
create_event = self.get_success(
state_handler.get_current_state(self.room_id, "m.room.create", "")
self.store.get_current_state_event(self.room_id, "m.room.create", "")
)
self.assertIsNotNone(create_event)

View File

@@ -88,7 +88,7 @@ class _DummyStore:
return groups
async def get_state_ids_for_group(self, state_group):
async def get_state_ids_for_group(self, state_group, state_filter=None):
return self._group_to_state[state_group]
async def store_state_group(
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
]
context = yield defer.ensureDeferred(
self.state.compute_event_context(event, old_state=old_state)
self.state.compute_event_context(
event,
state_ids_before_event={
(e.type, e.state_key): e.event_id for e in old_state
},
)
)
prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())