Compare commits: develop...hs-voip-st (36 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a4cac852a9 | |
| | 4d02a4cca0 | |
| | b2c967fd1c | |
| | f0689cee5e | |
| | b1af5fece6 | |
| | adb601b2d1 | |
| | 4def40414e | |
| | 686ce52723 | |
| | 58bf128581 | |
| | 7f1e057cca | |
| | 075312cf2d | |
| | aac3c846a8 | |
| | 888ab79b3b | |
| | aa45bf7c3a | |
| | 15453d4e6e | |
| | 78c40973f4 | |
| | 4acc98d23e | |
| | 651e829632 | |
| | 105d2cd05b | |
| | de3e9b49ec | |
| | 148caefcba | |
| | 33d80be69f | |
| | ad6a2b9e0c | |
| | 771692addd | |
| | 666e94b75a | |
| | 2728b21f3d | |
| | 1e812e4df0 | |
| | ac0f8c20e8 | |
| | 7c8daf4ed9 | |
| | 0cfdd0d6b5 | |
| | 7af74298b3 | |
| | 3e7a5a6bd6 | |
| | e01a22b2de | |
| | 7801e68a33 | |
| | 869953456a | |
| | abf658c712 | |
changelog.d/18968.feature (new file, 1 line)

@@ -0,0 +1 @@
Implement support for MSC4354: Sticky Events.

@@ -135,6 +135,8 @@ experimental_features:
  msc4155_enabled: true
  # Thread Subscriptions
  msc4306_enabled: true
  # Sticky Events
  msc4354_enabled: true

server_notices:
  system_mxid_localpart: _server

@@ -231,6 +231,7 @@ test_packages=(
    ./tests/msc4140
    ./tests/msc4155
    ./tests/msc4306
    ./tests/msc4354
)

# Enable dirty runs, so tests will reuse the same container where possible.

@@ -135,6 +135,7 @@ BOOLEAN_COLUMNS = {
        "has_known_state",
        "is_encrypted",
    ],
    "sticky_events": ["soft_failed"],
    "thread_subscriptions": ["subscribed", "automatic"],
    "users": ["shadow_banned", "approved", "locked", "suspended"],
    "un_partial_stated_event_stream": ["rejection_status_changed"],

@@ -24,7 +24,7 @@
"""Contains constants from the specification."""

import enum
from typing import Final
from typing import Final, TypedDict

# the max size of a (canonical-json-encoded) event
MAX_PDU_SIZE = 65536

@@ -279,6 +279,8 @@ class EventUnsignedContentFields:
    # Requesting user's membership, per MSC4115
    MEMBERSHIP: Final = "membership"

    STICKY_TTL: Final = "msc4354_sticky_duration_ttl_ms"


class MTextFields:
    """Fields found inside m.text content blocks."""

@@ -360,3 +362,14 @@ class Direction(enum.Enum):
class ProfileFields:
    DISPLAYNAME: Final = "displayname"
    AVATAR_URL: Final = "avatar_url"


class StickyEventField(TypedDict):
    duration_ms: int


class StickyEvent:
    QUERY_PARAM_NAME: Final = "org.matrix.msc4354.sticky_duration_ms"
    FIELD_NAME: Final = "msc4354_sticky"
    MAX_DURATION_MS: Final = 3600000  # 1 hour
    MAX_EVENTS_IN_SYNC: Final = 100
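For orientation, the constants above imply an event shape roughly like the following. This is a hedged sketch based only on the field names in this diff (`msc4354_sticky`, `duration_ms`, the `msc4354_sticky_duration_ttl_ms` unsigned field); the room ID, sender and content are placeholders, and MSC4354 itself remains the authoritative definition.

```python
# Hypothetical example only: a message event marked sticky for 5 minutes,
# using the field names defined by StickyEvent / EventUnsignedContentFields above.
sticky_event = {
    "type": "m.room.message",
    "room_id": "!placeholder:example.org",   # placeholder
    "sender": "@alice:example.org",          # placeholder
    "content": {"msgtype": "m.text", "body": "on air"},
    # Requested by the sender; must be an int in 0..StickyEvent.MAX_DURATION_MS
    "msc4354_sticky": {"duration_ms": 300_000},
    "unsigned": {
        # Remaining time-to-live a server may annotate when serving the event
        "msc4354_sticky_duration_ttl_ms": 240_000,
    },
}
```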
@@ -102,6 +102,7 @@ from synapse.storage.databases.main.signatures import SignatureWorkerStore
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
from synapse.storage.databases.main.state import StateGroupWorkerStore
from synapse.storage.databases.main.stats import StatsStore
from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore
from synapse.storage.databases.main.stream import StreamWorkerStore
from synapse.storage.databases.main.tags import TagsWorkerStore
from synapse.storage.databases.main.task_scheduler import TaskSchedulerWorkerStore

@@ -136,6 +137,7 @@ class GenericWorkerStore(
    RoomWorkerStore,
    DirectoryWorkerStore,
    ThreadSubscriptionsWorkerStore,
    StickyEventsWorkerStore,
    PushRulesWorkerStore,
    ApplicationServiceTransactionWorkerStore,
    ApplicationServiceWorkerStore,

@@ -595,3 +595,6 @@ class ExperimentalConfig(Config):
        # MSC4306: Thread Subscriptions
        # (and MSC4308: Thread Subscriptions extension to Sliding Sync)
        self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

        # MSC4354: Sticky Events
        self.msc4354_enabled: bool = experimental.get("msc4354_enabled", False)

@@ -127,7 +127,7 @@ class WriterLocations:
    """Specifies the instances that write various streams.

    Attributes:
        events: The instances that write to the event and backfill streams.
        events: The instances that write to the event, backfill and sticky events streams.
        typing: The instances that write to the typing stream. Currently
            can only be a single instance.
        to_device: The instances that write to the to_device stream. Currently

@@ -41,7 +41,12 @@ from typing import (
import attr
from unpaddedbase64 import encode_base64

from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.api.constants import (
    EventContentFields,
    EventTypes,
    RelationTypes,
    StickyEvent,
)
from synapse.api.room_versions import EventFormatVersions, RoomVersion, RoomVersions
from synapse.synapse_rust.events import EventInternalMetadata
from synapse.types import (

@@ -323,6 +328,20 @@ class EventBase(metaclass=abc.ABCMeta):
        # this will be a no-op if the event dict is already frozen.
        self._dict = freeze(self._dict)

    def sticky_duration(self) -> Optional[int]:
        sticky_obj = self.get_dict().get(StickyEvent.FIELD_NAME, None)
        if type(sticky_obj) is not dict:
            return None
        sticky_duration_ms = sticky_obj.get("duration_ms", None)
        # MSC: Valid values are the integer range 0-MAX_DURATION_MS
        if (
            type(sticky_duration_ms) is int
            and sticky_duration_ms >= 0
            and sticky_duration_ms <= StickyEvent.MAX_DURATION_MS
        ):
            return sticky_duration_ms
        return None

    def __str__(self) -> str:
        return self.__repr__()
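A short illustration of how `sticky_duration()` is intended to behave, assuming the validation rules shown above (non-dict payloads, non-integer durations, negatives and values above `MAX_DURATION_MS` are all treated as "not sticky"). The event dicts here are hypothetical fixtures, not real Synapse test data.

```python
# Sketch: the same bounds check as EventBase.sticky_duration(), applied to raw dicts.
from typing import Optional

MAX_DURATION_MS = 3600000  # mirrors StickyEvent.MAX_DURATION_MS

def sticky_duration_of(event_dict: dict) -> Optional[int]:
    sticky_obj = event_dict.get("msc4354_sticky")
    if type(sticky_obj) is not dict:
        return None
    duration = sticky_obj.get("duration_ms")
    if type(duration) is int and 0 <= duration <= MAX_DURATION_MS:
        return duration
    return None

assert sticky_duration_of({"msc4354_sticky": {"duration_ms": 300_000}}) == 300_000
assert sticky_duration_of({"msc4354_sticky": {"duration_ms": -1}}) is None          # negative
assert sticky_duration_of({"msc4354_sticky": {"duration_ms": 3_600_001}}) is None   # > 1 hour
assert sticky_duration_of({"msc4354_sticky": "5m"}) is None                         # wrong type
assert sticky_duration_of({}) is None                                               # not sticky
```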
@@ -24,7 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import attr
from signedjson.types import SigningKey

from synapse.api.constants import MAX_DEPTH, EventTypes
from synapse.api.constants import MAX_DEPTH, EventTypes, StickyEvent, StickyEventField
from synapse.api.room_versions import (
    KNOWN_EVENT_FORMAT_VERSIONS,
    EventFormatVersions,

@@ -89,6 +89,7 @@ class EventBuilder:

    content: JsonDict = attr.Factory(dict)
    unsigned: JsonDict = attr.Factory(dict)
    sticky: Optional[StickyEventField] = None

    # These only exist on a subset of events, so they raise AttributeError if
    # someone tries to get them when they don't exist.

@@ -269,6 +270,9 @@ class EventBuilder:
        if self._origin_server_ts is not None:
            event_dict["origin_server_ts"] = self._origin_server_ts

        if self.sticky is not None:
            event_dict[StickyEvent.FIELD_NAME] = self.sticky

        return create_local_event_from_event_dict(
            clock=self._clock,
            hostname=self._hostname,

@@ -318,6 +322,7 @@ class EventBuilderFactory:
            unsigned=key_values.get("unsigned", {}),
            redacts=key_values.get("redacts", None),
            origin_server_ts=key_values.get("origin_server_ts", None),
            sticky=key_values.get(StickyEvent.FIELD_NAME, None),
        )

@@ -195,6 +195,8 @@ class FederationBase:
            # using the event in prev_events).
            redacted_event = prune_event(pdu)
            redacted_event.internal_metadata.soft_failed = True
            # Mark this as spam so we don't re-evaluate soft-failure status.
            redacted_event.internal_metadata.policy_server_spammy = True
            return redacted_event

        return pdu
@@ -216,6 +216,11 @@ class FederationRemoteSendQueue(AbstractFederationSender):
        # This should never get called.
        raise NotImplementedError()

    def notify_new_server_joined(self, server: str, room_id: str) -> None:
        """As per FederationSender"""
        # This should never get called.
        raise NotImplementedError()

    def build_and_send_edu(
        self,
        destination: str,

@@ -180,6 +180,7 @@ from synapse.types import (
from synapse.util.clock import Clock
from synapse.util.metrics import Measure
from synapse.util.retryutils import filter_destinations_by_retry_limiter
from synapse.visibility import filter_events_for_server

if TYPE_CHECKING:
    from synapse.events.presence_router import PresenceRouter

@@ -243,6 +244,13 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
        """
        raise NotImplementedError()

    @abc.abstractmethod
    def notify_new_server_joined(self, server: str, room_id: str) -> None:
        """This gets called when a new server has joined a room. We might
        want to send out some events to this server.
        """
        raise NotImplementedError()

    @abc.abstractmethod
    async def send_read_receipt(self, receipt: ReadReceipt) -> None:
        """Send a RR to any other servers in the room

@@ -505,6 +513,64 @@ class FederationSender(AbstractFederationSender):
            self._per_destination_queues[destination] = queue
        return queue

    def notify_new_server_joined(self, server: str, room_id: str) -> None:
        # We currently only use this notification for MSC4354: Sticky Events.
        if not self.hs.config.experimental.msc4354_enabled:
            return
        # fire off a processing loop in the background
        self.hs.run_as_background_process(
            "process_new_server_joined_over_federation",
            self._process_new_server_joined_over_federation,
            server,
            room_id,
        )

    async def _process_new_server_joined_over_federation(
        self, new_server: str, room_id: str
    ) -> None:
        sticky_event_ids = await self.store.get_sticky_event_ids_sent_by_self(
            room_id,
            0,
        )
        sticky_events = await self.store.get_events_as_list(sticky_event_ids)

        # We must not send events that are outliers / lack a stream ordering, else we won't be able to
        # satisfy /get_missing_events requests
        sticky_events = [
            ev
            for ev in sticky_events
            if ev.internal_metadata.stream_ordering is not None
            and not ev.internal_metadata.is_outlier()
        ]
        # order by stream ordering so we present things in the right timeline order on the receiver
        sticky_events = sorted(
            sticky_events,
            key=lambda ev: ev.internal_metadata.stream_ordering
            or 0,  # not possible to be 0
        )

        sticky_events = await filter_events_for_server(
            self._storage_controllers,
            new_server,
            self.server_name,
            sticky_events,
            redact=False,
            filter_out_erased_senders=True,
            filter_out_remote_partial_state_events=True,
        )
        if sticky_events:
            logger.info(
                "sending %d sticky events to newly joined server %s in room %s",
                len(sticky_events),
                new_server,
                room_id,
            )
            # We don't track that we sent up to this stream position, since it wouldn't make any
            # difference: notify_new_server_joined is only called once, when the server first joins.
            await self._transaction_manager.send_new_transaction(
                new_server, sticky_events, []
            )

    def notify_new_events(self, max_token: RoomStreamToken) -> None:
        """This gets called when we have some new events we might want to
        send out to other servers.
@@ -104,6 +104,7 @@ class PerDestinationQueue:
        self._instance_name = hs.get_instance_name()
        self._federation_shard_config = hs.config.worker.federation_shard_config
        self._state = hs.get_state_handler()
        self.msc4354_enabled = hs.config.experimental.msc4354_enabled

        self._should_send_on_this_instance = True
        if not self._federation_shard_config.should_handle(

@@ -581,6 +582,33 @@ class PerDestinationQueue:
        # send.
        extrem_events = await self._store.get_events_as_list(extrems)

        if self.msc4354_enabled:
            # we also want to send sticky events that are still active in this room
            sticky_event_ids = (
                await self._store.get_sticky_event_ids_sent_by_self(
                    pdu.room_id,
                    last_successful_stream_ordering,
                )
            )
            # skip any that are actually the forward extremities we want to send anyway
            sticky_events = await self._store.get_events_as_list(
                [
                    event_id
                    for event_id in sticky_event_ids
                    if event_id not in extrems
                ]
            )
            if sticky_events:
                # *prepend* these to the extrem list, so they are processed first.
                # This ensures they will show up before the forward extrem in stream order
                extrem_events = sticky_events + extrem_events
                logger.info(
                    "Sending %d missed sticky events to %s: %r",
                    len(sticky_events),
                    self._destination,
                    pdu.room_id,
                )

        new_pdus = []
        for p in extrem_events:
            # We pulled this from the DB, so it'll be non-null
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, List, Optional, Set, Tuple

from twisted.internet.interfaces import IDelayedCall

from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, StickyEvent
from synapse.api.errors import ShadowBanError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.config.workers import MAIN_PROCESS_INSTANCE_NAME

@@ -331,6 +331,7 @@ class DelayedEventsHandler:
        origin_server_ts: Optional[int],
        content: JsonDict,
        delay: int,
        sticky_duration_ms: Optional[int],
    ) -> str:
        """
        Creates a new delayed event and schedules its delivery.

@@ -344,7 +345,7 @@ class DelayedEventsHandler:
                If None, the timestamp will be the actual time when the event is sent.
            content: The content of the event to be sent.
            delay: How long (in milliseconds) to wait before automatically sending the event.
            sticky_duration_ms: The sticky duration, if any; see MSC4354.

        Returns: The ID of the added delayed event.

        Raises:

@@ -380,6 +381,7 @@ class DelayedEventsHandler:
            origin_server_ts=origin_server_ts,
            content=content,
            delay=delay,
            sticky_duration_ms=sticky_duration_ms,
        )

        if self._repl_client is not None:

@@ -487,6 +489,7 @@ class DelayedEventsHandler:
                origin_server_ts=event.origin_server_ts,
                content=event.content,
                device_id=event.device_id,
                sticky_duration_ms=event.sticky_duration_ms,
            )
        )

@@ -596,7 +599,10 @@ class DelayedEventsHandler:

        if event.state_key is not None:
            event_dict["state_key"] = event.state_key

        if event.sticky_duration_ms is not None:
            event_dict[StickyEvent.FIELD_NAME] = {
                "duration_ms": event.sticky_duration_ms,
            }
        (
            sent_event,
            _,
@@ -67,6 +67,7 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
from synapse.events.validator import EventValidator
from synapse.federation.federation_client import InvalidResponseError
from synapse.federation.federation_server import _INBOUND_EVENT_HANDLING_LOCK_NAME
from synapse.handlers.pagination import PURGE_PAGINATION_LOCK_NAME
from synapse.http.servlet import assert_params_in_dict
from synapse.logging.context import nested_logging_context

@@ -74,6 +75,7 @@ from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
from synapse.metrics import SERVER_NAME_LABEL
from synapse.module_api import NOT_SPAM
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.databases.main.lock import Lock
from synapse.storage.invite_rule import InviteRule
from synapse.types import JsonDict, StrCollection, get_domain_from_id
from synapse.types.state import StateFilter
@@ -644,125 +646,158 @@ class FederationHandler:
        except ValueError:
            pass

        lock: Optional[Lock] = None
        async with self._is_partial_state_room_linearizer.queue(room_id):
            try:
                # MSC4354: Sticky Events causes existing servers in the room to send sticky events
                # to the newly joined server as soon as they realise the new server is in the room.
                # If they do this before we've persisted the /send_join response we will be unable to
                # process those PDUs. Therefore, we take a lock out now for this room, and release it
                # once we have processed the /send_join response, to buffer up these inbound messages.
                # This may be useful to do even without MSC4354, but it's gated behind an
                # experimental flag check to reduce the chance of this having unintended side-effects
                # e.g accidental deadlocks. Once we're confident of this behaviour, we can probably
                # drop the flag check. We take the lock AFTER we have been queued by the linearizer
                # else we would just hold the lock for no reason whilst in the queue: we want to hold
                # the lock for the smallest amount of time possible.
                if self.config.experimental.msc4354_enabled:
                    lock = await self.store.try_acquire_lock(
                        _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id
                    )
                    # Insert the room into the rooms table now so we can process potential incoming
                    # /send transactions enough to be able to insert into the federation staging
                    # area. We won't process the staging area until we release the lock above.
                    await self.store.upsert_room_on_join(
                        room_id=room_id,
                        room_version=room_version_obj,
                        state_events=None,
                    )

                already_partial_state_room = await self.store.is_partial_state_room(
                    room_id
                )

                ret = await self.federation_client.send_join(
                    host_list,
                    event,
                    room_version_obj,
                    # Perform a full join when we are already in the room and it is a
                    # full state room, since we are not allowed to persist a partial
                    # state join event in a full state room. In the future, we could
                    # optimize this by always performing a partial state join and
                    # computing the state ourselves or retrieving it from the remote
                    # homeserver if necessary.
                    #
                    # There's a race where we leave the room, then perform a full join
                    # anyway. This should end up being fast anyway, since we would
                    # already have the full room state and auth chain persisted.
                    partial_state=not is_host_joined or already_partial_state_room,
                )

                event = ret.event
                origin = ret.origin
                state = ret.state
                auth_chain = ret.auth_chain
                auth_chain.sort(key=lambda e: e.depth)

                logger.debug("do_invite_join auth_chain: %s", auth_chain)
                logger.debug("do_invite_join state: %s", state)

                logger.debug("do_invite_join event: %s", event)

                # if this is the first time we've joined this room, it's time to add
                # a row to `rooms` with the correct room version. If there's already a
                # row there, we should override it, since it may have been populated
                # based on an invite request which lied about the room version.
                #
                # federation_client.send_join has already checked that the room
                # version in the received create event is the same as room_version_obj,
                # so we can rely on it now.
                #
                await self.store.upsert_room_on_join(
                    room_id=room_id,
                    room_version=room_version_obj,
                    state_events=state,
                )

                if ret.partial_state and not already_partial_state_room:
                    # Mark the room as having partial state.
                    # The background process is responsible for unmarking this flag,
                    # even if the join fails.
                    # TODO(faster_joins):
                    #     We may want to reset the partial state info if it's from an
                    #     old, failed partial state join.
                    #     https://github.com/matrix-org/synapse/issues/13000

                    # FIXME: Ideally, we would store the full stream token here
                    # not just the minimum stream ID, so that we can compute an
                    # accurate list of device changes when un-partial-ing the
                    # room. The only side effect of this is that we may send
                    # extra unnecessary device list outbound pokes through
                    # federation, which is harmless.
                    device_lists_stream_id = (
                        self.store.get_device_stream_token().stream
                    )

                    await self.store.store_partial_state_room(
                        room_id=room_id,
                        servers=ret.servers_in_room,
                        device_lists_stream_id=device_lists_stream_id,
                        joined_via=origin,
                    )

                try:
                    max_stream_id = (
                        await self._federation_event_handler.process_remote_join(
                            origin,
                            room_id,
                            auth_chain,
                            state,
                            event,
                            room_version_obj,
                            partial_state=ret.partial_state,
                        )
                    )
                except PartialStateConflictError:
                    # This should be impossible, since we hold the lock on the room's
                    # partial statedness.
                    logger.error(
                        "Room %s was un-partial stated while processing remote join.",
                        room_id,
                    )
                    raise
                else:
                    # Record the join event id for future use (when we finish the full
                    # join). We have to do this after persisting the event to keep
                    # foreign key constraints intact.
                    if ret.partial_state and not already_partial_state_room:
                        # TODO(faster_joins):
                        #     We may want to reset the partial state info if it's from
                        #     an old, failed partial state join.
                        #     https://github.com/matrix-org/synapse/issues/13000
                        await self.store.write_partial_state_rooms_join_event_id(
                            room_id, event.event_id
                        )
                finally:
                    # Always kick off the background process that asynchronously fetches
                    # state for the room.
                    # If the join failed, the background process is responsible for
                    # cleaning up, including unmarking the room as a partial state
                    # room.
                    if ret.partial_state:
                        # Kick off the process of asynchronously fetching the state for
                        # this room.
                        self._start_partial_state_room_sync(
                            initial_destination=origin,
                            other_destinations=ret.servers_in_room,
                            room_id=room_id,
                        )
            finally:
                # allow inbound events which happened during the join to be processed.
                # Also ensures we release the lock on unexpected errors e.g db errors from
                # upsert_room_on_join or network errors from send_join.
                if lock:
                    await lock.release()

        # We wait here until this instance has seen the events come down
        # replication (if we're using replication) as the below uses caches.
        await self._replication.wait_for_stream_position(
@@ -54,6 +54,7 @@ from synapse.util.async_helpers import (
    concurrently_execute,
    gather_optional_coroutines,
)
from synapse.visibility import filter_events_for_client

_ThreadSubscription: TypeAlias = (
    SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription

@@ -76,7 +77,10 @@ class SlidingSyncExtensionHandler:
        self.event_sources = hs.get_event_sources()
        self.device_handler = hs.get_device_handler()
        self.push_rules_handler = hs.get_push_rules_handler()
        self.clock = hs.get_clock()
        self._storage_controllers = hs.get_storage_controllers()
        self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
        self._enable_sticky_events = hs.config.experimental.msc4354_enabled

    @trace
    async def get_extensions_response(

@@ -177,6 +181,19 @@ class SlidingSyncExtensionHandler:
                from_token=from_token,
            )

        sticky_events_coro = None
        if (
            sync_config.extensions.sticky_events is not None
            and self._enable_sticky_events
        ):
            sticky_events_coro = self.get_sticky_events_extension_response(
                sync_config=sync_config,
                sticky_events_request=sync_config.extensions.sticky_events,
                actual_room_ids=actual_room_ids,
                to_token=to_token,
                from_token=from_token,
            )

        (
            to_device_response,
            e2ee_response,

@@ -184,6 +201,7 @@ class SlidingSyncExtensionHandler:
            receipts_response,
            typing_response,
            thread_subs_response,
            sticky_events_response,
        ) = await gather_optional_coroutines(
            to_device_coro,
            e2ee_coro,

@@ -191,6 +209,7 @@ class SlidingSyncExtensionHandler:
            receipts_coro,
            typing_coro,
            thread_subs_coro,
            sticky_events_coro,
        )

        return SlidingSyncResult.Extensions(

@@ -200,6 +219,7 @@ class SlidingSyncExtensionHandler:
            receipts=receipts_response,
            typing=typing_response,
            thread_subscriptions=thread_subs_response,
            sticky_events=sticky_events_response,
        )

    def find_relevant_room_ids_for_extension(

@@ -970,3 +990,47 @@ class SlidingSyncExtensionHandler:
            unsubscribed=unsubscribed_threads,
            prev_batch=prev_batch,
        )

    async def get_sticky_events_extension_response(
        self,
        sync_config: SlidingSyncConfig,
        sticky_events_request: SlidingSyncConfig.Extensions.StickyEventsExtension,
        actual_room_ids: Set[str],
        to_token: StreamToken,
        from_token: Optional[SlidingSyncStreamToken],
    ) -> Optional[SlidingSyncResult.Extensions.StickyEventsExtension]:
        if not sticky_events_request.enabled:
            return None
        now = self.clock.time_msec()
        from_id = from_token.stream_token.sticky_events_key if from_token else 0
        _, room_to_event_ids = await self.store.get_sticky_events_in_rooms(
            actual_room_ids,
            from_id,
            to_token.sticky_events_key,
            now,
            # We set no limit here because the client can control when they get sticky events.
            # Furthermore, it doesn't seem possible to set a limit with the internal API shape
            # as given, as we cannot manipulate the to_token.sticky_events_key sent to the client...
            limit=0,
        )
        all_sticky_event_ids = {
            ev_id for evs in room_to_event_ids.values() for ev_id in evs
        }
        event_map = await self.store.get_events(all_sticky_event_ids)
        filtered_events = await filter_events_for_client(
            self._storage_controllers,
            sync_config.user.to_string(),
            list(event_map.values()),
            always_include_ids=frozenset(all_sticky_event_ids),
        )
        event_map = {ev.event_id: ev for ev in filtered_events}
        return SlidingSyncResult.Extensions.StickyEventsExtension(
            room_id_to_sticky_events={
                room_id: {
                    event_map[event_id]
                    for event_id in sticky_event_ids
                    if event_id in event_map
                }
                for room_id, sticky_event_ids in room_to_event_ids.items()
            }
        )
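To show where this extension plugs in from the client side, here is a hypothetical sliding sync request body enabling it. The extension key shown in the request mirrors the response key used by the servlet later in this diff ("org.matrix.msc4354.sticky_events"); that request-side key is an assumption for illustration, and the exact request shape is defined by MSC4354 and the sliding sync API rather than by this diff.

```python
# Hypothetical client-side sliding sync request body with the sticky events
# extension enabled (key name assumed to match the response serialisation).
sliding_sync_request = {
    "lists": {
        "rooms": {"ranges": [[0, 10]], "timeline_limit": 10},
    },
    "extensions": {
        "org.matrix.msc4354.sticky_events": {"enabled": True},
    },
}
```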
@@ -20,6 +20,7 @@
#
import itertools
import logging
import time
from typing import (
    TYPE_CHECKING,
    AbstractSet,

@@ -44,6 +45,7 @@ from synapse.api.constants import (
    EventTypes,
    JoinRules,
    Membership,
    StickyEvent,
)
from synapse.api.filtering import FilterCollection
from synapse.api.presence import UserPresenceState

@@ -153,6 +155,7 @@ class JoinedSyncResult:
    state: StateMap[EventBase]
    ephemeral: List[JsonDict]
    account_data: List[JsonDict]
    sticky: List[EventBase]
    unread_notifications: JsonDict
    unread_thread_notifications: JsonDict
    summary: Optional[JsonDict]

@@ -163,7 +166,11 @@ class JoinedSyncResult:
        to tell if room needs to be part of the sync result.
        """
        return bool(
            self.timeline or self.state or self.ephemeral or self.account_data
            self.timeline
            or self.state
            or self.ephemeral
            or self.account_data
            or self.sticky
            # nb the notification count does not, er, count: if there's nothing
            # else in the result, we don't need to send it.
        )

@@ -603,6 +610,41 @@ class SyncHandler:

        return now_token, ephemeral_by_room

    async def sticky_events_by_room(
        self,
        sync_result_builder: "SyncResultBuilder",
        now_token: StreamToken,
        since_token: Optional[StreamToken] = None,
    ) -> Tuple[StreamToken, Dict[str, Set[str]]]:
        """Get the sticky events for each room the user is in

        Args:
            sync_result_builder
            now_token: Where the server is currently up to.
            since_token: Where the server was when the client last synced.
        Returns:
            A tuple of the now StreamToken, updated to reflect which sticky
            events are included, and a dict mapping from room_id to the set of
            sticky event IDs for that room.
        """
        now = round(time.time() * 1000)
        with Measure(
            self.clock, name="sticky_events_by_room", server_name=self.server_name
        ):
            from_id = since_token.sticky_events_key if since_token else 0

            room_ids = sync_result_builder.joined_room_ids

            to_id, sticky_by_room = await self.store.get_sticky_events_in_rooms(
                room_ids,
                from_id,
                now_token.sticky_events_key,
                now,
                StickyEvent.MAX_EVENTS_IN_SYNC,
            )
            now_token = now_token.copy_and_replace(StreamKeyType.STICKY_EVENTS, to_id)

            return now_token, sticky_by_room

    async def _load_filtered_recents(
        self,
        room_id: str,

@@ -2181,6 +2223,13 @@ class SyncHandler:
        )
        sync_result_builder.now_token = now_token

        sticky_by_room: Dict[str, Set[str]] = {}
        if self.hs_config.experimental.msc4354_enabled:
            now_token, sticky_by_room = await self.sticky_events_by_room(
                sync_result_builder, now_token, since_token
            )
            sync_result_builder.now_token = now_token

        # 2. We check up front if anything has changed, if it hasn't then there is
        # no point in going further.
        if not sync_result_builder.full_state:

@@ -2191,7 +2240,7 @@ class SyncHandler:
                tags_by_room = await self.store.get_updated_tags(
                    user_id, since_token.account_data_key
                )
                if not tags_by_room:
                if not tags_by_room and not sticky_by_room:
                    logger.debug("no-oping sync")
                    return set(), set()

@@ -2211,7 +2260,6 @@ class SyncHandler:
            tags_by_room = await self.store.get_tags_for_user(user_id)

        log_kv({"rooms_changed": len(room_changes.room_entries)})

        room_entries = room_changes.room_entries
        invited = room_changes.invited
        knocked = room_changes.knocked

@@ -2229,6 +2277,7 @@ class SyncHandler:
                    ephemeral=ephemeral_by_room.get(room_entry.room_id, []),
                    tags=tags_by_room.get(room_entry.room_id),
                    account_data=account_data_by_room.get(room_entry.room_id, {}),
                    sticky_event_ids=sticky_by_room.get(room_entry.room_id, set()),
                    always_include=sync_result_builder.full_state,
                )
                logger.debug("Generated room entry for %s", room_entry.room_id)

@@ -2615,6 +2664,7 @@ class SyncHandler:
        ephemeral: List[JsonDict],
        tags: Optional[Mapping[str, JsonMapping]],
        account_data: Mapping[str, JsonMapping],
        sticky_event_ids: Set[str],
        always_include: bool = False,
    ) -> None:
        """Populates the `joined` and `archived` section of `sync_result_builder`

@@ -2644,6 +2694,7 @@ class SyncHandler:
            tags: List of *all* tags for room, or None if there has been
                no change.
            account_data: List of new account data for room
            sticky_event_ids: MSC4354 sticky events in the room, if any.
            always_include: Always include this room in the sync response,
                even if empty.
        """

@@ -2654,7 +2705,13 @@ class SyncHandler:
        events = room_builder.events

        # We want to shortcut out as early as possible.
        if not (always_include or account_data or ephemeral or full_state):
        if not (
            always_include
            or account_data
            or ephemeral
            or full_state
            or sticky_event_ids
        ):
            if events == [] and tags is None:
                return

@@ -2746,6 +2803,7 @@ class SyncHandler:
            or account_data_events
            or ephemeral
            or full_state
            or sticky_event_ids
        ):
            return

@@ -2792,6 +2850,22 @@ class SyncHandler:

        if room_builder.rtype == "joined":
            unread_notifications: Dict[str, int] = {}
            sticky_events: List[EventBase] = []
            if sticky_event_ids:
                # remove sticky events that are in the timeline, else we will needlessly duplicate
                # events. This is particularly important given the risk of sticky events spam since
                # anyone can send sticky events, so halving the bandwidth on average for each sticky
                # event is helpful.
                timeline = {ev.event_id for ev in batch.events}
                sticky_event_ids = sticky_event_ids.difference(timeline)
                if sticky_event_ids:
                    sticky_event_map = await self.store.get_events(sticky_event_ids)
                    sticky_events = await filter_events_for_client(
                        self._storage_controllers,
                        sync_result_builder.sync_config.user.to_string(),
                        list(sticky_event_map.values()),
                        always_include_ids=frozenset(sticky_event_ids),
                    )
            room_sync = JoinedSyncResult(
                room_id=room_id,
                timeline=batch,

@@ -2802,6 +2876,7 @@ class SyncHandler:
                unread_thread_notifications={},
                summary=summary,
                unread_count=0,
                sticky=sticky_events,
            )

            if room_sync or always_include:
@@ -533,6 +533,7 @@ class Notifier:
            StreamKeyType.TYPING,
            StreamKeyType.UN_PARTIAL_STATED_ROOMS,
            StreamKeyType.THREAD_SUBSCRIPTIONS,
            StreamKeyType.STICKY_EVENTS,
        ],
        new_token: int,
        users: Optional[Collection[Union[str, UserID]]] = None,

@@ -939,6 +940,11 @@ class Notifier:
        # that any in flight requests can be immediately retried.
        self._federation_client.wake_destination(server)

    def notify_new_server_joined(self, server: str, room_id: str) -> None:
        # Inform the federation_sender that it may need to send events to the new server.
        if self.federation_sender:
            self.federation_sender.notify_new_server_joined(server, room_id)

    def add_lock_released_callback(
        self, callback: Callable[[str, str, str], None]
    ) -> None:
@@ -43,7 +43,10 @@ from synapse.replication.tcp.streams import (
    UnPartialStatedEventStream,
    UnPartialStatedRoomStream,
)
from synapse.replication.tcp.streams._base import ThreadSubscriptionsStream
from synapse.replication.tcp.streams._base import (
    StickyEventsStream,
    ThreadSubscriptionsStream,
)
from synapse.replication.tcp.streams.events import (
    EventsStream,
    EventsStreamEventRow,

@@ -261,6 +264,12 @@ class ReplicationDataHandler:
                token,
                users=[row.user_id for row in rows],
            )
        elif stream_name == StickyEventsStream.NAME:
            self.notifier.on_new_event(
                StreamKeyType.STICKY_EVENTS,
                token,
                rooms=[row.room_id for row in rows],
            )

        await self._presence_handler.process_replication_rows(
            stream_name, instance_name, token, rows
@@ -462,6 +462,32 @@ class RemoteServerUpCommand(_SimpleCommand):
    NAME = "REMOTE_SERVER_UP"


class NewServerJoinedCommand(Command):
    """Sent when a worker has detected that a new remote server has joined a room.

    Format::

        NEW_SERVER_JOINED <server> <room_id>
    """

    NAME = "NEW_SERVER_JOINED"
    __slots__ = ["server", "room_id"]

    def __init__(self, server: str, room_id: str):
        self.server = server
        self.room_id = room_id

    @classmethod
    def from_line(
        cls: Type["NewServerJoinedCommand"], line: str
    ) -> "NewServerJoinedCommand":
        server, room_id = line.split(" ")
        return cls(server, room_id)

    def to_line(self) -> str:
        return "%s %s" % (self.server, self.room_id)
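A quick round-trip of the wire format documented above, using the class just defined (the server and room names are placeholders):

```python
# Sketch: NEW_SERVER_JOINED is serialised as "<server> <room_id>" after the command name.
from synapse.replication.tcp.commands import NewServerJoinedCommand

cmd = NewServerJoinedCommand("other.example.org", "!abc123:example.org")
assert cmd.to_line() == "other.example.org !abc123:example.org"

# On the receiving side the payload is split back into its two fields.
parsed = NewServerJoinedCommand.from_line("other.example.org !abc123:example.org")
assert (parsed.server, parsed.room_id) == ("other.example.org", "!abc123:example.org")
```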
class LockReleasedCommand(Command):
    """Sent to inform other instances that a given lock has been dropped.

@@ -517,6 +543,7 @@ _COMMANDS: Tuple[Type[Command], ...] = (
    FederationAckCommand,
    UserIpCommand,
    RemoteServerUpCommand,
    NewServerJoinedCommand,
    ClearUserSyncsCommand,
    LockReleasedCommand,
    NewActiveTaskCommand,

@@ -533,6 +560,7 @@ VALID_SERVER_COMMANDS = (
    ErrorCommand.NAME,
    PingCommand.NAME,
    RemoteServerUpCommand.NAME,
    NewServerJoinedCommand.NAME,
    LockReleasedCommand.NAME,
)

@@ -547,6 +575,7 @@ VALID_CLIENT_COMMANDS = (
    UserIpCommand.NAME,
    ErrorCommand.NAME,
    RemoteServerUpCommand.NAME,
    NewServerJoinedCommand.NAME,
    LockReleasedCommand.NAME,
)
@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
    FederationAckCommand,
    LockReleasedCommand,
    NewActiveTaskCommand,
    NewServerJoinedCommand,
    PositionCommand,
    RdataCommand,
    RemoteServerUpCommand,

@@ -73,6 +74,7 @@ from synapse.replication.tcp.streams import (
)
from synapse.replication.tcp.streams._base import (
    DeviceListsStream,
    StickyEventsStream,
    ThreadSubscriptionsStream,
)

@@ -224,6 +226,12 @@ class ReplicationCommandHandler:

                continue

            if isinstance(stream, StickyEventsStream):
                if hs.get_instance_name() in hs.config.worker.writers.events:
                    self._streams_to_replicate.append(stream)

                continue

            if isinstance(stream, DeviceListsStream):
                if hs.get_instance_name() in hs.config.worker.writers.device_lists:
                    self._streams_to_replicate.append(stream)

@@ -756,6 +764,12 @@ class ReplicationCommandHandler:
        """Called when we get a new REMOTE_SERVER_UP command."""
        self._notifier.notify_remote_server_up(cmd.data)

    def on_NEW_SERVER_JOINED(
        self, conn: IReplicationConnection, cmd: NewServerJoinedCommand
    ) -> None:
        """Called when we get a new NEW_SERVER_JOINED command."""
        self._notifier.notify_new_server_joined(cmd.server, cmd.room_id)

    def on_LOCK_RELEASED(
        self, conn: IReplicationConnection, cmd: LockReleasedCommand
    ) -> None:

@@ -878,6 +892,9 @@ class ReplicationCommandHandler:
    def send_remote_server_up(self, server: str) -> None:
        self.send_command(RemoteServerUpCommand(server))

    def send_new_server_joined(self, server: str, room_id: str) -> None:
        self.send_command(NewServerJoinedCommand(server, room_id))

    def stream_update(self, stream_name: str, token: Optional[int], data: Any) -> None:
        """Called when a new update is available to stream to Redis subscribers.
@@ -40,6 +40,7 @@ from synapse.replication.tcp.streams._base import (
    PushersStream,
    PushRulesStream,
    ReceiptsStream,
    StickyEventsStream,
    Stream,
    ThreadSubscriptionsStream,
    ToDeviceStream,

@@ -68,6 +69,7 @@ STREAMS_MAP = {
    ToDeviceStream,
    FederationStream,
    AccountDataStream,
    StickyEventsStream,
    ThreadSubscriptionsStream,
    UnPartialStatedRoomStream,
    UnPartialStatedEventStream,

@@ -90,6 +92,7 @@ __all__ = [
    "ToDeviceStream",
    "FederationStream",
    "AccountDataStream",
    "StickyEventsStream",
    "ThreadSubscriptionsStream",
    "UnPartialStatedRoomStream",
    "UnPartialStatedEventStream",
@@ -766,3 +766,46 @@ class ThreadSubscriptionsStream(_StreamFromIdGen):
            return [], to_token, False

        return rows, rows[-1][0], len(updates) == limit


@attr.s(slots=True, auto_attribs=True)
class StickyEventsStreamRow:
    """A row in the stream that informs workers about changes to sticky events."""

    room_id: str
    event_id: str  # The sticky event ID


class StickyEventsStream(_StreamFromIdGen):
    """A sticky event was changed."""

    NAME = "sticky_events"
    ROW_TYPE = StickyEventsStreamRow

    def __init__(self, hs: "HomeServer"):
        self.store = hs.get_datastores().main
        super().__init__(
            hs.get_instance_name(),
            self._update_function,
            self.store._sticky_events_id_gen,
        )

    async def _update_function(
        self, instance_name: str, from_token: int, to_token: int, limit: int
    ) -> StreamUpdateResult:
        updates = await self.store.get_updated_sticky_events(
            from_id=from_token, to_id=to_token, limit=limit
        )
        rows = [
            (
                stream_id,
                # These are the args to `StickyEventsStreamRow`
                (room_id, event_id),
            )
            for stream_id, room_id, event_id, _ in updates
        ]

        if not rows:
            return [], to_token, False

        return rows, rows[-1][0], len(updates) == limit
@@ -33,7 +33,7 @@ from prometheus_client.core import Histogram
from twisted.web.server import Request

from synapse import event_auth
from synapse.api.constants import Direction, EventTypes, Membership
from synapse.api.constants import Direction, EventTypes, Membership, StickyEvent
from synapse.api.errors import (
    AuthError,
    Codes,

@@ -205,6 +205,7 @@ class RoomStateEventRestServlet(RestServlet):
        self.clock = hs.get_clock()
        self._max_event_delay_ms = hs.config.server.max_event_delay_ms
        self._spam_checker_module_callbacks = hs.get_module_api_callbacks().spam_checker
        self.msc4354_enabled = hs.config.experimental.msc4354_enabled

    def register(self, http_server: HttpServer) -> None:
        # /rooms/$roomid/state/$eventtype

@@ -326,6 +327,10 @@ class RoomStateEventRestServlet(RestServlet):
        if requester.app_service:
            origin_server_ts = parse_integer(request, "ts")

        sticky_duration_ms: Optional[int] = None
        if self.msc4354_enabled:
            sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME)

        delay = _parse_request_delay(request, self._max_event_delay_ms)
        if delay is not None:
            delay_id = await self.delayed_events_handler.add(

@@ -336,6 +341,7 @@ class RoomStateEventRestServlet(RestServlet):
                origin_server_ts=origin_server_ts,
                content=content,
                delay=delay,
                sticky_duration_ms=sticky_duration_ms,
            )

            set_tag("delay_id", delay_id)

@@ -363,6 +369,10 @@ class RoomStateEventRestServlet(RestServlet):
            "room_id": room_id,
            "sender": requester.user.to_string(),
        }
        if sticky_duration_ms is not None:
            event_dict[StickyEvent.FIELD_NAME] = {
                "duration_ms": sticky_duration_ms,
            }

        if state_key is not None:
            event_dict["state_key"] = state_key

@@ -395,6 +405,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
        self.delayed_events_handler = hs.get_delayed_events_handler()
        self.auth = hs.get_auth()
        self._max_event_delay_ms = hs.config.server.max_event_delay_ms
        self.msc4354_enabled = hs.config.experimental.msc4354_enabled

    def register(self, http_server: HttpServer) -> None:
        # /rooms/$roomid/send/$event_type[/$txn_id]

@@ -415,6 +426,10 @@ class RoomSendEventRestServlet(TransactionRestServlet):
        if requester.app_service:
            origin_server_ts = parse_integer(request, "ts")

        sticky_duration_ms: Optional[int] = None
        if self.msc4354_enabled:
            sticky_duration_ms = parse_integer(request, StickyEvent.QUERY_PARAM_NAME)

        delay = _parse_request_delay(request, self._max_event_delay_ms)
        if delay is not None:
            delay_id = await self.delayed_events_handler.add(

@@ -425,6 +440,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
                origin_server_ts=origin_server_ts,
                content=content,
                delay=delay,
                sticky_duration_ms=sticky_duration_ms,
            )

            set_tag("delay_id", delay_id)

@@ -441,6 +457,11 @@ class RoomSendEventRestServlet(TransactionRestServlet):
        if origin_server_ts is not None:
            event_dict["origin_server_ts"] = origin_server_ts

        if sticky_duration_ms is not None:
            event_dict[StickyEvent.FIELD_NAME] = {
                "duration_ms": sticky_duration_ms,
            }

        try:
            (
                event,
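Putting the servlet changes together, a client opts an event into stickiness purely via the query string. A sketch follows: the endpoint and event body are ordinary Matrix /send usage, the query parameter name comes from `StickyEvent.QUERY_PARAM_NAME` above, and the homeserver URL, room ID and transaction ID are placeholders.

```python
from urllib.parse import quote, urlencode

room_id = "!abc123:example.org"   # placeholder
txn_id = "m1"                     # placeholder transaction ID
base = "https://synapse.example.org/_matrix/client/v3"

# Ask the server to treat this message as sticky for 5 minutes (MSC4354).
query = urlencode({"org.matrix.msc4354.sticky_duration_ms": 300_000})
url = f"{base}/rooms/{quote(room_id)}/send/m.room.message/{txn_id}?{query}"

body = {"msgtype": "m.text", "body": "on air"}
# PUT `body` to `url` with an "Authorization: Bearer <access_token>" header;
# the servlet parses the query parameter and copies it into the event's
# msc4354_sticky field before sending.
```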
@@ -617,6 +617,11 @@ class SyncRestServlet(RestServlet):
            ephemeral_events = room.ephemeral
            result["ephemeral"] = {"events": ephemeral_events}
        result["unread_notifications"] = room.unread_notifications
        if room.sticky:
            serialized_sticky = await self._event_serializer.serialize_events(
                room.sticky, time_now, config=serialize_options
            )
            result["msc4354_sticky"] = {"events": serialized_sticky}
        if room.unread_thread_notifications:
            result["unread_thread_notifications"] = room.unread_thread_notifications
        if self._msc3773_enabled:
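For reference, the per-room section of a /sync response gains a block shaped roughly like this. This is a hedged sketch: event contents and IDs are placeholders, and only the fields touched by this diff are shown.

```python
# Hypothetical /sync fragment for one joined room with a sticky event that is
# not already present in the timeline.
joined_room_fragment = {
    "timeline": {"events": ["..."], "limited": False},
    "unread_notifications": {},
    "msc4354_sticky": {
        "events": [
            {
                "type": "m.room.message",
                "event_id": "$sticky1:example.org",   # placeholder
                "sender": "@alice:example.org",        # placeholder
                "content": {"msgtype": "m.text", "body": "on air"},
                "msc4354_sticky": {"duration_ms": 300_000},
            }
        ]
    },
}
```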
@@ -646,6 +651,7 @@ class SlidingSyncRestServlet(RestServlet):
        - receipts (MSC3960)
        - account data (MSC3959)
        - thread subscriptions (MSC4308)
        - sticky events (MSC4354)

    Request query parameters:
        timeout: How long to wait for new events in milliseconds.

@@ -1089,8 +1095,35 @@ class SlidingSyncRestServlet(RestServlet):
                _serialise_thread_subscriptions(extensions.thread_subscriptions)
            )

        if extensions.sticky_events:
            serialized_extensions[
                "org.matrix.msc4354.sticky_events"
            ] = await self._serialise_sticky_events(requester, extensions.sticky_events)

        return serialized_extensions

    async def _serialise_sticky_events(
        self,
        requester: Requester,
        sticky_events: SlidingSyncResult.Extensions.StickyEventsExtension,
    ) -> JsonDict:
        time_now = self.clock.time_msec()
        # Same as SSS timelines. TODO: support more options like /sync does.
        serialize_options = SerializeEventConfig(
            event_format=format_event_for_client_v2_without_room_id,
            requester=requester,
        )
        return {
            "rooms": {
                room_id: {
                    "events": await self.event_serializer.serialize_events(
                        sticky_events, time_now, config=serialize_options
                    )
                }
                for room_id, sticky_events in sticky_events.room_id_to_sticky_events.items()
            },
        }
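The extension block that ends up in the sliding sync response therefore looks roughly like the following. This is a sketch: room and event IDs are placeholders, and events are shown without a room_id because of the `format_event_for_client_v2_without_room_id` format used above.

```python
# Hypothetical "extensions" block in a sliding sync response.
extensions_fragment = {
    "org.matrix.msc4354.sticky_events": {
        "rooms": {
            "!abc123:example.org": {                      # placeholder room ID
                "events": [
                    {
                        "type": "m.room.message",
                        "event_id": "$sticky1:example.org",   # placeholder
                        "sender": "@alice:example.org",        # placeholder
                        "content": {"msgtype": "m.text", "body": "on air"},
                        "msc4354_sticky": {"duration_ms": 300_000},
                    }
                ]
            }
        }
    }
}
```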
def _serialise_thread_subscriptions(
    thread_subscriptions: SlidingSyncResult.Extensions.ThreadSubscriptionsExtension,
@@ -182,6 +182,8 @@ class VersionsRestServlet(RestServlet):
                    "org.matrix.msc4306": self.config.experimental.msc4306_enabled,
                    # MSC4169: Backwards-compatible redaction sending using `/send`
                    "com.beeper.msc4169": self.config.experimental.msc4169_enabled,
                    # MSC4354: Sticky events
                    "org.matrix.msc4354": self.config.experimental.msc4354_enabled,
                },
            },
        )
@@ -665,6 +665,29 @@ class EventsPersistenceStorageController:
        async with self._state_deletion_store.persisting_state_group_references(
            events_and_contexts
        ):
            new_servers: Optional[Set[str]] = None
            if self.hs.config.experimental.msc4354_enabled and state_delta_for_room:
                # We specifically only consider events in `chunk` to reduce the risk of state rollbacks
                # causing servers to appear to repeatedly rejoin rooms. This works because we only
                # persist events once, whereas the state delta may unreliably flap between joined members
                # on unrelated events. This means we may miss cases where the /first/ join event for a server
                # is as a result of a state rollback and not as a result of a new join event. That is fine
                # because the chance of that happening is vanishingly rare because the join event would need to be
                # persisted without it affecting the current state (e.g there's a concurrent ban for that user)
                # which is then revoked concurrently by a later event (e.g the user is unbanned).
                # If state resolution were more reliable (in terms of state resets) then we could feasibly only
                # consider the events in the state_delta_for_room, but we aren't there yet.
                new_event_ids_in_current_state = set(
                    state_delta_for_room.to_insert.values()
                )
                new_servers = await self._check_new_servers_joined(
                    room_id,
                    [
                        ev
                        for (ev, _) in chunk
                        if ev.event_id in new_event_ids_in_current_state
                    ],
                )
            await self.persist_events_store._persist_events_and_state_updates(
                room_id,
                chunk,

@@ -674,9 +697,71 @@ class EventsPersistenceStorageController:
                inhibit_local_membership_updates=backfilled,
                new_event_links=new_event_links,
            )
            if new_servers:
                # Notify other workers after the server has joined so they can take into account
                # the latest events that are in `chunk`.
                for server_name in new_servers:
                    self.hs.get_notifier().notify_new_server_joined(
                        server_name, room_id
                    )
                    self.hs.get_replication_command_handler().send_new_server_joined(
                        server_name, room_id
                    )

        return replaced_events

    async def _check_new_servers_joined(
        self, room_id: str, new_events_in_current_state: List[EventBase]
    ) -> Optional[Set[str]]:
        """Check if new servers have joined the given room.

        Assumes this function is called BEFORE the current_state_events table is updated.

        A new server is "joined" if this is the first join event seen from this domain.

        Args:
            room_id: The room in question
            new_events_in_current_state: A list of events that will become part of the current state,
                but have not yet been persisted.
        """
        # filter to only join events from other servers. We're obviously joined if we are getting full events
        # so needn't consider ourselves.
        join_events = [
            ev
            for ev in new_events_in_current_state
            if ev.type == EventTypes.Member
            and ev.is_state()
            and not self.is_mine_id(ev.state_key)
            and ev.membership == Membership.JOIN
        ]
        if not join_events:
            return None

        joining_domains = {get_domain_from_id(ev.state_key) for ev in join_events}

        # load all joined members from the current_state_events table as this table is fast and has what we want.
        # This is the current state prior to applying the update.
        joined_members: List[
            Tuple[str]
        ] = await self.main_store.db_pool.simple_select_list(
            "current_state_events",
            {
                "room_id": room_id,
                "type": EventTypes.Member,
                "membership": Membership.JOIN,
            },
            retcols=["state_key"],
            desc="_check_new_servers_joined",
        )
        joined_domains = {
            get_domain_from_id(state_key) for (state_key,) in joined_members
        }

        newly_joined_domains = joining_domains.difference(joined_domains)
        if not newly_joined_domains:
            return None
        return newly_joined_domains

    async def _calculate_new_forward_extremities_and_state_delta(
        self, room_id: str, ev_ctx_rm: List[EventPersistencePair]
    ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
@@ -34,6 +34,7 @@ from synapse.storage.database import (
|
||||
)
|
||||
from synapse.storage.databases.main.sliding_sync import SlidingSyncStore
|
||||
from synapse.storage.databases.main.stats import UserSortOrder
|
||||
from synapse.storage.databases.main.sticky_events import StickyEventsWorkerStore
|
||||
from synapse.storage.databases.main.thread_subscriptions import (
|
||||
ThreadSubscriptionsWorkerStore,
|
||||
)
|
||||
@@ -144,6 +145,7 @@ class DataStore(
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
ThreadSubscriptionsWorkerStore,
|
||||
StickyEventsWorkerStore,
|
||||
PushRulesWorkerStore,
|
||||
StreamWorkerStore,
|
||||
OpenIdStore,
|
||||
|
||||
@@ -46,6 +46,7 @@ class EventDetails:
|
||||
origin_server_ts: Optional[Timestamp]
|
||||
content: JsonDict
|
||||
device_id: Optional[DeviceID]
|
||||
sticky_duration_ms: Optional[int]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@@ -93,6 +94,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
origin_server_ts: Optional[int],
|
||||
content: JsonDict,
|
||||
delay: int,
|
||||
sticky_duration_ms: Optional[int],
|
||||
) -> Tuple[DelayID, Timestamp]:
|
||||
"""
|
||||
Inserts a new delayed event in the DB.
|
||||
@@ -119,6 +121,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"state_key": state_key,
|
||||
"origin_server_ts": origin_server_ts,
|
||||
"content": json_encoder.encode(content),
|
||||
"sticky_duration_ms": sticky_duration_ms,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -265,6 +268,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"send_ts",
|
||||
"content",
|
||||
"device_id",
|
||||
"sticky_duration_ms",
|
||||
)
|
||||
)
|
||||
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
|
||||
@@ -305,6 +309,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
Timestamp(row[5] if row[5] is not None else row[6]),
|
||||
db_to_json(row[7]),
|
||||
DeviceID(row[8]) if row[8] is not None else None,
|
||||
int(row[9]) if row[9] is not None else None,
|
||||
DelayID(row[0]),
|
||||
UserLocalpart(row[1]),
|
||||
)
|
||||
@@ -355,6 +360,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
"origin_server_ts",
|
||||
"content",
|
||||
"device_id",
|
||||
"sticky_duration_ms",
|
||||
)
|
||||
)
|
||||
sql_update = "UPDATE delayed_events SET is_processed = TRUE"
|
||||
@@ -382,6 +388,7 @@ class DelayedEventsStore(SQLBaseStore):
|
||||
Timestamp(row[3]) if row[3] is not None else None,
|
||||
db_to_json(row[4]),
|
||||
DeviceID(row[5]) if row[5] is not None else None,
|
||||
int(row[6]) if row[6] is not None else None,
|
||||
)
|
||||
|
||||
return event, self._get_next_delayed_event_send_ts_txn(txn)
|
||||
|
||||
@@ -251,6 +251,7 @@ class PersistEventsStore:
|
||||
self.database_engine = db.engine
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
self.msc4354_sticky_events = hs.config.experimental.msc4354_enabled
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.server.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -370,6 +371,21 @@ class PersistEventsStore:
|
||||
len(events_and_contexts)
|
||||
)
|
||||
|
||||
# TODO: are we guaranteed to call the below code if we were to die now?
|
||||
# On startup we will already think we have persisted the events?
|
||||
|
||||
# This was originally in _persist_events_txn but it relies on non-txn functions like
|
||||
# get_events_as_list and get_partial_filtered_current_state_ids to handle soft-failure
|
||||
# re-evaluation, so it can't do that without leaking out the txn currently, hence it
|
||||
# now just lives outside.
|
||||
if self.msc4354_sticky_events:
|
||||
# re-evaluate soft-failed sticky events.
|
||||
await self.store.reevaluate_soft_failed_sticky_events(
|
||||
room_id,
|
||||
events_and_contexts,
|
||||
state_delta_for_room,
|
||||
)
|
||||
|
||||
if not use_negative_stream_ordering:
|
||||
# we don't want to set the event_persisted_position to a negative
|
||||
# stream_ordering.
|
||||
@@ -1170,6 +1186,11 @@ class PersistEventsStore:
|
||||
sliding_sync_table_changes,
|
||||
)
|
||||
|
||||
if self.msc4354_sticky_events:
|
||||
self.store.insert_sticky_events_txn(
|
||||
txn, [ev for ev, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
# We only update the sliding sync tables for non-backfilled events.
|
||||
self._update_sliding_sync_tables_with_new_persisted_events_txn(
|
||||
txn, room_id, events_and_contexts
|
||||
@@ -2631,6 +2652,11 @@ class PersistEventsStore:
|
||||
# event isn't an outlier any more.
|
||||
self._update_backward_extremeties(txn, [event])
|
||||
|
||||
if self.msc4354_sticky_events and event.sticky_duration():
|
||||
# The de-outliered event is sticky. Update the sticky events table to ensure
|
||||
# we deliver it down /sync.
|
||||
self.store.insert_sticky_events_txn(txn, [event])
|
||||
|
||||
return [ec for ec in events_and_contexts if ec[0] not in to_remove]
|
||||
|
||||
def _store_event_txn(
|
||||
|
||||
@@ -73,6 +73,10 @@ from synapse.metrics.background_process_metrics import (
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream
|
||||
from synapse.replication.tcp.streams._base import (
|
||||
StickyEventsStream,
|
||||
StickyEventsStreamRow,
|
||||
)
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.replication.tcp.streams.partial_state import UnPartialStatedEventStreamRow
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
@@ -463,6 +467,11 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# If the partial-stated event became rejected or unrejected
|
||||
# when it wasn't before, we need to invalidate this cache.
|
||||
self._invalidate_local_get_event_cache(row.event_id)
|
||||
elif stream_name == StickyEventsStream.NAME:
|
||||
for row in rows:
|
||||
assert isinstance(row, StickyEventsStreamRow)
|
||||
# In case soft-failure status changed, invalidate the cache.
|
||||
self._invalidate_local_get_event_cache(row.event_id)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
|
||||
@@ -2460,7 +2460,10 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
async def upsert_room_on_join(
|
||||
self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
|
||||
self,
|
||||
room_id: str,
|
||||
room_version: RoomVersion,
|
||||
state_events: Optional[List[EventBase]],
|
||||
) -> None:
|
||||
"""Ensure that the room is stored in the table
|
||||
|
||||
@@ -2472,36 +2475,46 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
|
||||
# mark the room as having an auth chain cover index.
|
||||
has_auth_chain_index = await self.has_auth_chain_index(room_id)
|
||||
|
||||
create_event = None
|
||||
for e in state_events:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
# We may want to insert a row into the rooms table BEFORE having the state events in the
|
||||
# room, in order to correctly handle the race condition where the /send_join is processed
|
||||
# remotely which causes remote servers to send us events before we've processed the /send_join
|
||||
# response. Therefore, we allow state_events (and thus the creator column) to be optional.
|
||||
# When we get the /send_join response, we'll patch this up.
|
||||
room_creator: Optional[str] = None
|
||||
if state_events:
|
||||
create_event = None
|
||||
for e in state_events:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
create_event = e
|
||||
break
|
||||
|
||||
if create_event is None:
|
||||
# If the state doesn't have a create event then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise StoreError(400, "No create event in state")
|
||||
|
||||
# Before MSC2175, the room creator was a separate field.
|
||||
if not room_version.implicit_room_creator:
|
||||
room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
|
||||
|
||||
if not isinstance(room_creator, str):
|
||||
# If the create event does not have a creator then the room is
|
||||
if create_event is None:
|
||||
# If the state doesn't have a create event then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise StoreError(400, "No creator defined on the create event")
|
||||
else:
|
||||
room_creator = create_event.sender
|
||||
raise StoreError(400, "No create event in state")
|
||||
|
||||
# Before MSC2175, the room creator was a separate field.
|
||||
if not room_version.implicit_room_creator:
|
||||
room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
|
||||
|
||||
if not isinstance(room_creator, str):
|
||||
# If the create event does not have a creator then the room is
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise StoreError(400, "No creator defined on the create event")
|
||||
else:
|
||||
room_creator = create_event.sender
|
||||
|
||||
update_with = {"room_version": room_version.identifier}
|
||||
if room_creator:
|
||||
update_with["creator"] = room_creator
|
||||
|
||||
await self.db_pool.simple_upsert(
|
||||
desc="upsert_room_on_join",
|
||||
table="rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
values={"room_version": room_version.identifier},
|
||||
values=update_with,
|
||||
insertion_values={
|
||||
"is_public": False,
|
||||
"creator": room_creator,
|
||||
"has_auth_chain_index": has_auth_chain_index,
|
||||
},
|
||||
)
|
||||
|
||||
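To illustrate the two-phase flow that the relaxed `upsert_room_on_join` signature above enables, here is a minimal sketch; it is not part of the diff, and the `store` and `handle_remote_join_sketch` names (plus the exact call site) are assumptions drawn from the comment about the /send_join race.

```python
# Sketch only. Assumes `store` exposes the upsert_room_on_join method shown above.
from typing import List, Optional

from synapse.api.room_versions import RoomVersion
from synapse.events import EventBase


async def handle_remote_join_sketch(
    store,
    room_id: str,
    room_version: RoomVersion,
    state_events: Optional[List[EventBase]],
) -> None:
    # Phase 1: record the room before we have its state, so events that race ahead
    # of the /send_join response have a rooms row to attach to.
    await store.upsert_room_on_join(room_id, room_version, state_events=None)

    # ... process the /send_join response ...

    # Phase 2: call again with the real state; the upsert patches up the creator column.
    await store.upsert_room_on_join(room_id, room_version, state_events=state_events)
```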
634
synapse/storage/databases/main/sticky_events.py
Normal file
@@ -0,0 +1,634 @@
|
||||
#
|
||||
# This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
#
|
||||
# Copyright (C) 2025 New Vector, Ltd
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# See the GNU Affero General Public License for more details:
|
||||
# <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes, StickyEvent
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventPersistencePair
|
||||
from synapse.replication.tcp.streams._base import StickyEventsStream
|
||||
from synapse.storage.database import (
|
||||
DatabasePool,
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
make_in_list_sql_clause,
|
||||
)
|
||||
from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
|
||||
from synapse.storage.databases.main.events import DeltaState
|
||||
from synapse.storage.databases.main.state import StateGroupWorkerStore
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator
|
||||
from synapse.types.state import StateFilter
|
||||
from synapse.util.stringutils import shortstr
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Remove entries from the sticky_events table at this frequency.
|
||||
# Note: this does NOT mean we don't honour shorter expiration timeouts.
|
||||
# Consumers call 'get_sticky_events_in_rooms' which has `WHERE expires_at > ?`
|
||||
# to filter out expired sticky events that have yet to be deleted.
|
||||
DELETE_EXPIRED_STICKY_EVENTS_MS = 60 * 1000 * 60 # 1 hour
|
||||
|
||||
|
||||
class StickyEventsWorkerStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
def __init__(
|
||||
self,
|
||||
database: DatabasePool,
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
self._can_write_to_sticky_events = (
|
||||
self._instance_name in hs.config.worker.writers.events
|
||||
)
|
||||
|
||||
# Technically this means we will clean up N times, once per event persister; maybe this should run on master?
|
||||
if self._can_write_to_sticky_events:
|
||||
self.clock.looping_call(
|
||||
self._run_background_cleanup, DELETE_EXPIRED_STICKY_EVENTS_MS
|
||||
)
|
||||
|
||||
self._sticky_events_id_gen: MultiWriterIdGenerator = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
notifier=hs.get_replication_notifier(),
|
||||
stream_name="sticky_events",
|
||||
server_name=self.server_name,
|
||||
instance_name=self._instance_name,
|
||||
tables=[
|
||||
("sticky_events", "instance_name", "stream_id"),
|
||||
],
|
||||
sequence_name="sticky_events_sequence",
|
||||
writers=hs.config.worker.writers.events,
|
||||
)
|
||||
|
||||
def process_replication_rows(
|
||||
self,
|
||||
stream_name: str,
|
||||
instance_name: str,
|
||||
token: int,
|
||||
rows: Iterable[Any],
|
||||
) -> None:
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
def process_replication_position(
|
||||
self, stream_name: str, instance_name: str, token: int
|
||||
) -> None:
|
||||
if stream_name == StickyEventsStream.NAME:
|
||||
self._sticky_events_id_gen.advance(instance_name, token)
|
||||
super().process_replication_position(stream_name, instance_name, token)
|
||||
|
||||
def get_max_sticky_events_stream_id(self) -> int:
|
||||
"""Get the current maximum stream_id for thread subscriptions.
|
||||
|
||||
Returns:
|
||||
The maximum stream_id
|
||||
"""
|
||||
return self._sticky_events_id_gen.get_current_token()
|
||||
|
||||
def get_sticky_events_stream_id_generator(self) -> MultiWriterIdGenerator:
|
||||
return self._sticky_events_id_gen
|
||||
|
||||
async def get_sticky_events_in_rooms(
|
||||
self,
|
||||
room_ids: Collection[str],
|
||||
from_id: int,
|
||||
to_id: int,
|
||||
now: int,
|
||||
limit: int,
|
||||
) -> Tuple[int, Dict[str, Set[str]]]:
|
||||
"""
|
||||
Fetch all the sticky events in the given rooms, from the given sticky stream ID.
|
||||
|
||||
Args:
|
||||
room_ids: The room IDs to return sticky events in.
|
||||
from_id: The sticky stream ID that sticky events should be returned from (exclusive).
|
||||
to_id: The sticky stream ID that sticky events should end at (inclusive).
|
||||
now: The current time in unix millis, used for skipping expired events.
|
||||
limit: Max sticky events to return. If <= 0, no limit is applied.
|
||||
Returns:
|
||||
A tuple of (to_id, map[room_id, event_ids])
|
||||
"""
|
||||
sticky_events_rows = await self.db_pool.runInteraction(
|
||||
"get_sticky_events_in_rooms",
|
||||
self._get_sticky_events_in_rooms_txn,
|
||||
room_ids,
|
||||
from_id,
|
||||
to_id,
|
||||
now,
|
||||
limit,
|
||||
)
|
||||
new_to_id = from_id
|
||||
room_to_events: Dict[str, Set[str]] = {}
|
||||
for stream_id, room_id, event_id in sticky_events_rows:
|
||||
new_to_id = max(new_to_id, stream_id)
|
||||
events = room_to_events.get(room_id, set())
|
||||
events.add(event_id)
|
||||
room_to_events[room_id] = events
|
||||
return (new_to_id, room_to_events)
|
||||
|
||||
def _get_sticky_events_in_rooms_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_ids: Collection[str],
|
||||
from_id: int,
|
||||
to_id: int,
|
||||
now: int,
|
||||
limit: int,
|
||||
) -> List[Tuple[int, str, str]]:
|
||||
if len(room_ids) == 0:
|
||||
return []
|
||||
clause, room_id_values = make_in_list_sql_clause(
|
||||
txn.database_engine, "room_id", room_ids
|
||||
)
|
||||
query = f"""
|
||||
SELECT stream_id, room_id, event_id FROM sticky_events
|
||||
WHERE soft_failed != ? AND expires_at > ? AND stream_id > ? AND stream_id <= ? AND {clause}
|
||||
ORDER BY stream_id ASC
|
||||
"""
|
||||
params = (True, now, from_id, to_id, *room_id_values)
|
||||
if limit > 0:
|
||||
query += "LIMIT ?"
|
||||
params += (limit,)
|
||||
txn.execute(query, params)
|
||||
return cast(List[Tuple[int, str, str]], txn.fetchall())
|
||||
|
||||
async def get_updated_sticky_events(
|
||||
self, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, bool]]:
|
||||
"""Get updates to sticky events between two stream IDs.
|
||||
|
||||
Args:
|
||||
from_id: The starting stream ID (exclusive)
|
||||
to_id: The ending stream ID (inclusive)
|
||||
limit: The maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
list of (stream_id, room_id, event_id, soft_failed) tuples
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_updated_sticky_events",
|
||||
self._get_updated_sticky_events_txn,
|
||||
from_id,
|
||||
to_id,
|
||||
limit,
|
||||
)
|
||||
|
||||
def _get_updated_sticky_events_txn(
|
||||
self, txn: LoggingTransaction, from_id: int, to_id: int, limit: int
|
||||
) -> List[Tuple[int, str, str, bool]]:
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT stream_id, room_id, event_id, soft_failed FROM sticky_events WHERE stream_id > ? AND stream_id <= ? LIMIT ?
|
||||
""",
|
||||
(from_id, to_id, limit),
|
||||
)
|
||||
return cast(List[Tuple[int, str, str, bool]], txn.fetchall())
|
||||
|
||||
async def get_sticky_event_ids_sent_by_self(
|
||||
self, room_id: str, from_stream_pos: int
|
||||
) -> List[str]:
|
||||
"""Get unexpired sticky event IDs which have been sent by users on this homeserver.
|
||||
|
||||
Used when sending sticky events eagerly to newly joined servers, or when catching up over federation.
|
||||
|
||||
Args:
|
||||
room_id: The room to fetch sticky events in.
|
||||
from_stream_pos: The stream position to return events from. May be 0 for newly joined servers.
|
||||
Returns:
|
||||
A list of event IDs, which may be empty.
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_sticky_event_ids_sent_by_self",
|
||||
self._get_sticky_event_ids_sent_by_self_txn,
|
||||
room_id,
|
||||
from_stream_pos,
|
||||
)
|
||||
|
||||
def _get_sticky_event_ids_sent_by_self_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, from_stream_pos: int
|
||||
) -> List[str]:
|
||||
now_ms = self._now()
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT sticky_events.event_id, sticky_events.sender, events.stream_ordering FROM sticky_events
|
||||
INNER JOIN events ON events.event_id = sticky_events.event_id
|
||||
WHERE soft_failed=? AND expires_at > ? AND sticky_events.room_id = ?
|
||||
""",
|
||||
(False, now_ms, room_id),
|
||||
)
|
||||
rows = cast(List[Tuple[str, str, int]], txn.fetchall())
|
||||
return [
|
||||
row[0]
|
||||
for row in rows
|
||||
if row[2] > from_stream_pos and self.hs.is_mine_id(row[1])
|
||||
]
|
||||
|
||||
async def reevaluate_soft_failed_sticky_events(
|
||||
self,
|
||||
room_id: str,
|
||||
events_and_contexts: List[EventPersistencePair],
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
) -> None:
|
||||
"""Re-evaluate soft failed events in the room provided.
|
||||
|
||||
Args:
|
||||
room_id: The room that all of the events belong to
|
||||
events_and_contexts: The events just persisted. These are not eligible for re-evaluation.
|
||||
state_delta_for_room: The changes to the current state, used to detect if we need to
|
||||
re-evaluate soft-failed sticky events.
|
||||
"""
|
||||
assert self._can_write_to_sticky_events
|
||||
|
||||
# fetch soft failed sticky events to recheck
|
||||
event_ids_to_check = await self._get_soft_failed_sticky_events_to_recheck(
|
||||
room_id, state_delta_for_room
|
||||
)
|
||||
# filter out soft-failed events in events_and_contexts as we just inserted them, so the
|
||||
# soft failure status won't have changed for them.
|
||||
persisting_event_ids = {ev.event_id for ev, _ in events_and_contexts}
|
||||
event_ids_to_check = [
|
||||
item for item in event_ids_to_check if item not in persisting_event_ids
|
||||
]
|
||||
if event_ids_to_check:
|
||||
logger.info(
|
||||
"_get_soft_failed_sticky_events_to_recheck => %s", event_ids_to_check
|
||||
)
|
||||
# recheck them and update any that now pass soft-fail checks.
|
||||
await self._recheck_soft_failed_events(room_id, event_ids_to_check)
|
||||
|
||||
def insert_sticky_events_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
) -> None:
|
||||
now_ms = self._now()
|
||||
# event, expires_at, stream_id
|
||||
sticky_events: List[Tuple[EventBase, int, int]] = []
|
||||
for ev in events:
|
||||
# MSC: Note: policy servers and other similar antispam techniques still apply to these events.
|
||||
if ev.internal_metadata.policy_server_spammy:
|
||||
continue
|
||||
# We shouldn't be passed rejected events, but if we do, we filter them out too.
|
||||
if ev.rejected_reason is not None:
|
||||
continue
|
||||
# We can't persist outlier sticky events as we don't know the room state at that event
|
||||
if ev.internal_metadata.is_outlier():
|
||||
continue
|
||||
# MSC: The presence of sticky.duration_ms with a valid value makes the event “sticky”
|
||||
sticky_duration = ev.sticky_duration()
|
||||
if sticky_duration:
|
||||
# MSC: The start time is min(now, origin_server_ts).
|
||||
# This ensures that malicious origin timestamps cannot specify start times in the future.
|
||||
# Calculate the end time as start_time + min(sticky.duration_ms, MAX_DURATION_MS).
|
||||
expires_at = min(ev.origin_server_ts, now_ms) + min(
|
||||
ev.get_dict()[StickyEvent.FIELD_NAME]["duration_ms"],
|
||||
StickyEvent.MAX_DURATION_MS,
|
||||
)
|
||||
# filter out already expired sticky events
|
||||
if expires_at > now_ms:
|
||||
sticky_events.append(
|
||||
(ev, expires_at, self._sticky_events_id_gen.get_next_txn(txn))
|
||||
)
|
||||
if len(sticky_events) == 0:
|
||||
return
|
||||
logger.info(
|
||||
"inserting %d sticky events in room %s",
|
||||
len(sticky_events),
|
||||
sticky_events[0][0].room_id,
|
||||
)
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
"sticky_events",
|
||||
keys=(
|
||||
"instance_name",
|
||||
"stream_id",
|
||||
"room_id",
|
||||
"event_id",
|
||||
"sender",
|
||||
"expires_at",
|
||||
"soft_failed",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
self._instance_name,
|
||||
stream_id,
|
||||
ev.room_id,
|
||||
ev.event_id,
|
||||
ev.sender,
|
||||
expires_at,
|
||||
ev.internal_metadata.is_soft_failed(),
|
||||
)
|
||||
for (ev, expires_at, stream_id) in sticky_events
|
||||
],
|
||||
)
|
||||
|
||||
async def _get_soft_failed_sticky_events_to_recheck(
|
||||
self,
|
||||
room_id: str,
|
||||
state_delta_for_room: Optional[DeltaState],
|
||||
) -> List[str]:
|
||||
"""Fetch soft-failed sticky events which should be rechecked against the current state.
|
||||
|
||||
Soft-failed events are not rejected, so they pass auth at the state before
|
||||
the event and at the auth_events in the event. Instead, soft-failed events failed auth at
|
||||
the *current* state of the room. We only need to recheck soft failure if we have a reason to
|
||||
believe the event may pass that check now.
|
||||
|
||||
Note that we don't bother rechecking accepted events that may now be soft-failed, because
|
||||
by that point it's too late as we've already sent the event to clients.
|
||||
|
||||
Returns:
|
||||
A list of event IDs to recheck
|
||||
"""
|
||||
|
||||
if state_delta_for_room is None:
|
||||
# No change to current state => no way soft failure status could be different.
|
||||
return []
|
||||
|
||||
# any change to critical auth events may change soft failure status. This means any changes
|
||||
# to join rules, power levels or member events. If the state has changed but it isn't one
|
||||
# of those events, we don't need to recheck.
|
||||
critical_auth_types = (
|
||||
EventTypes.JoinRules,
|
||||
EventTypes.PowerLevels,
|
||||
EventTypes.Member,
|
||||
)
|
||||
critical_auth_types_changed = set()
|
||||
critical_auth_types_changed.update(
|
||||
[
|
||||
typ
|
||||
for typ, _ in state_delta_for_room.to_delete
|
||||
if typ in critical_auth_types
|
||||
]
|
||||
)
|
||||
critical_auth_types_changed.update(
|
||||
[
|
||||
typ
|
||||
for typ, _ in state_delta_for_room.to_insert
|
||||
if typ in critical_auth_types
|
||||
]
|
||||
)
|
||||
if len(critical_auth_types_changed) == 0:
|
||||
# No change to critical auth events => no way soft failure status could be different.
|
||||
return []
|
||||
|
||||
if critical_auth_types_changed == {EventTypes.Member}:
|
||||
# the final case we want to catch is when unprivileged users join/leave rooms. These users cause
|
||||
# changes in the critical auth types (the member event) but ultimately have no effect on soft
|
||||
# failure status for anyone but that user themselves.
|
||||
#
|
||||
# Grab the set of senders that have been modified and see if any of them sent a soft-failed
|
||||
# sticky event. If they did, then we need to re-evaluate. If they didn't, then we don't need to.
|
||||
new_membership_changes = set(
|
||||
[
|
||||
skey
|
||||
for typ, skey in state_delta_for_room.to_insert
|
||||
if typ == EventTypes.Member
|
||||
]
|
||||
+ [
|
||||
skey
|
||||
for typ, skey in state_delta_for_room.to_delete
|
||||
if typ == EventTypes.Member
|
||||
]
|
||||
)
|
||||
|
||||
# pull out senders of sticky events in this room
|
||||
events_to_recheck: List[
|
||||
Tuple[str]
|
||||
] = await self.db_pool.simple_select_many_batch(
|
||||
table="sticky_events",
|
||||
column="sender",
|
||||
iterable=new_membership_changes,
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"soft_failed": True,
|
||||
},
|
||||
retcols=("event_id",),
|
||||
desc="_get_soft_failed_sticky_events_to_recheck_members",
|
||||
)
|
||||
return [event_id for (event_id,) in events_to_recheck]
|
||||
|
||||
# otherwise one of the following must be true:
|
||||
# - there was a change in PL or join rules
|
||||
# - there was a change in the membership of a sender of a soft-failed sticky event.
|
||||
# In both of these cases we want to re-evaluate soft failure status.
|
||||
#
|
||||
# NB: event auth checks are NOT recursive. We don't need to specifically handle the case where
|
||||
# an admin user's membership changes which causes a PL event to be allowed, as when the PL event
|
||||
# gets allowed we will re-evaluate anyway. E.g:
|
||||
#
|
||||
# PL(send_event=0, sender=Admin) #1
|
||||
# ^ ^_____________________
|
||||
# | |
|
||||
# . PL(send_event=50, sender=Mod) #2 sticky event (sender=User) #3
|
||||
#
|
||||
# In this scenario, the sticky event is soft-failed due to the Mod updating the PL event to
|
||||
# set send_event=50, which User does not have. If we learn of an event which makes Mod's PL
|
||||
# event invalid (say, Mod was banned by Admin concurrently to Mod setting the PL event), then
|
||||
# the act of seeing the ban event will cause the old PL event to be in the state delta, meaning
|
||||
# we will re-evaluate the sticky event due to the PL changing. We don't need to specially handle
|
||||
# this case.
|
||||
events_to_recheck = await self.db_pool.simple_select_list(
|
||||
table="sticky_events",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"soft_failed": True,
|
||||
},
|
||||
retcols=("event_id",),
|
||||
desc="_get_soft_failed_sticky_events_to_recheck",
|
||||
)
|
||||
return [event_id for (event_id,) in events_to_recheck]
|
||||
|
||||
async def _recheck_soft_failed_events(
|
||||
self,
|
||||
room_id: str,
|
||||
soft_failed_event_ids: Collection[str],
|
||||
) -> None:
|
||||
"""
|
||||
Recheck authorised but soft-failed events. The provided event IDs must have already passed
|
||||
all auth checks (so the event isn't rejected) but failed soft-failure checks.
|
||||
|
||||
Args:
|
||||
|
||||
room_id: The room the event IDs are in.
|
||||
soft_failed_event_ids: The soft-failed events to re-evaluate.
|
||||
"""
|
||||
# Load all the soft-failed events to recheck, and pull out the precise state tuples we need
|
||||
soft_failed_event_map = await self.get_events(
|
||||
soft_failed_event_ids, allow_rejected=False
|
||||
)
|
||||
needed_tuples: Set[Tuple[str, str]] = set()
|
||||
for ev in soft_failed_event_map.values():
|
||||
needed_tuples.update(event_auth.auth_types_for_event(ev.room_version, ev))
|
||||
|
||||
# We know the events are otherwise authorised, so we only need to load the needed tuples from
|
||||
# the current state to check if the events pass auth.
|
||||
current_state_map = await self.get_partial_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types(needed_tuples)
|
||||
)
|
||||
current_state_ids_list = [e for _, e in current_state_map.items()]
|
||||
current_auth_events = await self.get_events_as_list(current_state_ids_list)
|
||||
passing_event_ids: Set[str] = set()
|
||||
for soft_failed_event in soft_failed_event_map.values():
|
||||
if soft_failed_event.internal_metadata.policy_server_spammy:
|
||||
# don't re-evaluate spam.
|
||||
continue
|
||||
try:
|
||||
# We don't need to check_state_independent_auth_rules as that doesn't depend on room state,
|
||||
# so if it passed once it'll pass again.
|
||||
event_auth.check_state_dependent_auth_rules(
|
||||
soft_failed_event, current_auth_events
|
||||
)
|
||||
passing_event_ids.add(soft_failed_event.event_id)
|
||||
except AuthError:
|
||||
pass
|
||||
|
||||
if not passing_event_ids:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"%s soft-failed events now pass current state checks in room %s : %s",
|
||||
len(passing_event_ids),
|
||||
room_id,
|
||||
shortstr(passing_event_ids),
|
||||
)
|
||||
# Update the DB with the new soft-failure status
|
||||
await self.db_pool.runInteraction(
|
||||
"_recheck_soft_failed_events",
|
||||
self._update_soft_failure_status_txn,
|
||||
passing_event_ids,
|
||||
)
|
||||
|
||||
def _update_soft_failure_status_txn(
|
||||
self, txn: LoggingTransaction, passing_event_ids: Set[str]
|
||||
) -> None:
|
||||
# Update the sticky events table so we notify downstream of the change in soft-failure status
|
||||
new_stream_ids: List[Tuple[str, int]] = [
|
||||
(event_id, self._sticky_events_id_gen.get_next_txn(txn))
|
||||
for event_id in passing_event_ids
|
||||
]
|
||||
# [event_id, stream_pos, event_id, stream_pos, ...]
|
||||
params = [p for pair in new_stream_ids for p in pair]
|
||||
if isinstance(txn.database_engine, PostgresEngine):
|
||||
values_placeholders = ", ".join(["(?, ?)"] * len(new_stream_ids))
|
||||
txn.execute(
|
||||
f"""
|
||||
UPDATE sticky_events AS se
|
||||
SET
|
||||
soft_failed = FALSE,
|
||||
stream_id = v.stream_id
|
||||
FROM (VALUES
|
||||
{values_placeholders}
|
||||
) AS v(event_id, stream_id)
|
||||
WHERE se.event_id = v.event_id;
|
||||
""",
|
||||
params,
|
||||
)
|
||||
# Also update the internal metadata on the event itself, so when we filter_events_for_client
|
||||
# we don't filter them out. It's a bit sad internal_metadata is TEXT and not JSONB...
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine,
|
||||
"event_id",
|
||||
passing_event_ids,
|
||||
)
|
||||
txn.execute(
|
||||
"""
|
||||
UPDATE event_json
|
||||
SET internal_metadata = (
|
||||
jsonb_set(internal_metadata::jsonb, '{soft_failed}', 'false'::jsonb)
|
||||
)::text
|
||||
WHERE %s
|
||||
"""
|
||||
% clause,
|
||||
args,
|
||||
)
|
||||
else:
|
||||
# Use a CASE expression to update in bulk for sqlite
|
||||
case_expr = " ".join(["WHEN ? THEN ? " for _ in new_stream_ids])
|
||||
txn.execute(
|
||||
f"""
|
||||
UPDATE sticky_events
|
||||
SET
|
||||
soft_failed = FALSE,
|
||||
stream_id = CASE event_id
|
||||
{case_expr}
|
||||
ELSE stream_id
|
||||
END
|
||||
WHERE event_id IN ({",".join("?" * len(new_stream_ids))});
|
||||
""",
|
||||
params + [eid for eid, _ in new_stream_ids],
|
||||
)
|
||||
clause, args = make_in_list_sql_clause(
|
||||
txn.database_engine,
|
||||
"event_id",
|
||||
passing_event_ids,
|
||||
)
|
||||
txn.execute(
|
||||
f"""
|
||||
UPDATE event_json
|
||||
SET internal_metadata = json_set(internal_metadata, '$.soft_failed', json('false'))
|
||||
WHERE {clause}
|
||||
""",
|
||||
args,
|
||||
)
|
||||
# finally, invalidate caches
|
||||
for event_id in passing_event_ids:
|
||||
self.invalidate_get_event_cache_after_txn(txn, event_id)
|
||||
|
||||
async def _delete_expired_sticky_events(self) -> None:
|
||||
logger.info("delete_expired_sticky_events")
|
||||
await self.db_pool.runInteraction(
|
||||
"_delete_expired_sticky_events",
|
||||
self._delete_expired_sticky_events_txn,
|
||||
self._now(),
|
||||
)
|
||||
|
||||
def _delete_expired_sticky_events_txn(
|
||||
self, txn: LoggingTransaction, now: int
|
||||
) -> None:
|
||||
txn.execute(
|
||||
"""
|
||||
DELETE FROM sticky_events WHERE expires_at < ?
|
||||
""",
|
||||
(now,),
|
||||
)
|
||||
|
||||
def _now(self) -> int:
|
||||
return round(time.time() * 1000)
|
||||
|
||||
def _run_background_cleanup(self) -> Deferred:
|
||||
return self.hs.run_as_background_process(
|
||||
"delete_expired_sticky_events",
|
||||
self._delete_expired_sticky_events,
|
||||
)
|
||||
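For clarity, the expiry rule applied by `insert_sticky_events_txn` above (and mirrored by `filter_events_for_client` later in this diff) can be summarised as a small helper. This is a sketch only, not part of the diff; the function and parameter names are assumptions, and `max_duration_ms` stands in for `StickyEvent.MAX_DURATION_MS`.

```python
# Sketch only: the MSC4354 expiry rule as a pure function.
def sticky_expiry_ms(
    origin_server_ts: int, duration_ms: int, now_ms: int, max_duration_ms: int
) -> int:
    """Return the unix-millisecond timestamp at which a sticky event expires."""
    # The start time is min(now, origin_server_ts), so a malicious origin timestamp
    # cannot place the start of the sticky window in the future.
    start_ms = min(origin_server_ts, now_ms)
    # The requested duration is capped at the server-side maximum.
    return start_ms + min(duration_ms, max_duration_ms)
```

At read time an event is treated as expired once `sticky_expiry_ms(...) <= now_ms`, which matches the `expires_at > ?` filters in the queries above.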
@@ -380,7 +380,7 @@ class TransactionWorkerStore(CacheInvalidationWorkerStore):
|
||||
) -> List[str]:
|
||||
"""
|
||||
Returns at most 50 event IDs and their corresponding stream_orderings
|
||||
that correspond to the oldest events that have not yet been sent to
|
||||
that correspond to the newest events that have not yet been sent to
|
||||
the destination.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
#
|
||||
#
|
||||
|
||||
SCHEMA_VERSION = 92 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 93 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
|
||||
28
synapse/storage/schema/main/delta/93/01_sticky_events.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
CREATE TABLE sticky_events (
|
||||
stream_id INTEGER NOT NULL PRIMARY KEY,
|
||||
instance_name TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
sender TEXT NOT NULL,
|
||||
expires_at BIGINT NOT NULL,
|
||||
soft_failed BOOLEAN NOT NULL
|
||||
);
|
||||
|
||||
-- for pulling out soft failed events by room
|
||||
CREATE INDEX sticky_events_room_idx ON sticky_events (room_id, soft_failed);
|
||||
|
||||
-- An optional int for combining sticky events with delayed events. Used at send time.
|
||||
ALTER TABLE delayed_events ADD COLUMN sticky_duration_ms BIGINT;
|
||||
@@ -0,0 +1,18 @@
|
||||
--
|
||||
-- This file is licensed under the Affero General Public License (AGPL) version 3.
|
||||
--
|
||||
-- Copyright (C) 2025 New Vector, Ltd
|
||||
--
|
||||
-- This program is free software: you can redistribute it and/or modify
|
||||
-- it under the terms of the GNU Affero General Public License as
|
||||
-- published by the Free Software Foundation, either version 3 of the
|
||||
-- License, or (at your option) any later version.
|
||||
--
|
||||
-- See the GNU Affero General Public License for more details:
|
||||
-- <https://www.gnu.org/licenses/agpl-3.0.html>.
|
||||
|
||||
CREATE SEQUENCE sticky_events_sequence;
|
||||
-- Synapse streams start at 2, because the default position is 1
|
||||
-- so any item inserted at position 1 is ignored.
|
||||
-- We have to use nextval not START WITH 2, see https://github.com/element-hq/synapse/issues/18712
|
||||
SELECT nextval('sticky_events_sequence');
|
||||
@@ -84,6 +84,7 @@ class EventSources:
|
||||
self._instance_name
|
||||
)
|
||||
thread_subscriptions_key = self.store.get_max_thread_subscriptions_stream_id()
|
||||
sticky_events_key = self.store.get_max_sticky_events_stream_id()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=self.sources.room.get_current_key(),
|
||||
@@ -98,6 +99,7 @@ class EventSources:
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key=thread_subscriptions_key,
|
||||
sticky_events_key=sticky_events_key,
|
||||
)
|
||||
return token
|
||||
|
||||
@@ -125,6 +127,7 @@ class EventSources:
|
||||
StreamKeyType.DEVICE_LIST: self.store.get_device_stream_id_generator(),
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS: self.store.get_un_partial_stated_rooms_id_generator(),
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS: self.store.get_thread_subscriptions_stream_id_generator(),
|
||||
StreamKeyType.STICKY_EVENTS: self.store.get_sticky_events_stream_id_generator(),
|
||||
}
|
||||
|
||||
for _, key in StreamKeyType.__members__.items():
|
||||
|
||||
@@ -1011,6 +1011,7 @@ class StreamKeyType(Enum):
|
||||
DEVICE_LIST = "device_list_key"
|
||||
UN_PARTIAL_STATED_ROOMS = "un_partial_stated_rooms_key"
|
||||
THREAD_SUBSCRIPTIONS = "thread_subscriptions_key"
|
||||
STICKY_EVENTS = "sticky_events_key"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@@ -1032,6 +1033,7 @@ class StreamToken:
|
||||
9. `groups_key`: `1` (note that this key is now unused)
|
||||
10. `un_partial_stated_rooms_key`: `379`
|
||||
11. `thread_subscriptions_key`: `4242`
12. `sticky_events_key`: `4141`
|
||||
|
||||
You can see how many of these keys correspond to the various
|
||||
fields in a "/sync" response:
|
||||
@@ -1091,6 +1093,7 @@ class StreamToken:
|
||||
groups_key: int
|
||||
un_partial_stated_rooms_key: int
|
||||
thread_subscriptions_key: int
|
||||
sticky_events_key: int
|
||||
|
||||
_SEPARATOR = "_"
|
||||
START: ClassVar["StreamToken"]
|
||||
@@ -1119,6 +1122,7 @@ class StreamToken:
|
||||
groups_key,
|
||||
un_partial_stated_rooms_key,
|
||||
thread_subscriptions_key,
|
||||
sticky_events_key,
|
||||
) = keys
|
||||
|
||||
return cls(
|
||||
@@ -1135,6 +1139,7 @@ class StreamToken:
|
||||
groups_key=int(groups_key),
|
||||
un_partial_stated_rooms_key=int(un_partial_stated_rooms_key),
|
||||
thread_subscriptions_key=int(thread_subscriptions_key),
|
||||
sticky_events_key=int(sticky_events_key),
|
||||
)
|
||||
except CancelledError:
|
||||
raise
|
||||
@@ -1158,6 +1163,7 @@ class StreamToken:
|
||||
str(self.groups_key),
|
||||
str(self.un_partial_stated_rooms_key),
|
||||
str(self.thread_subscriptions_key),
|
||||
str(self.sticky_events_key),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -1223,6 +1229,7 @@ class StreamToken:
|
||||
StreamKeyType.TYPING,
|
||||
StreamKeyType.UN_PARTIAL_STATED_ROOMS,
|
||||
StreamKeyType.THREAD_SUBSCRIPTIONS,
|
||||
StreamKeyType.STICKY_EVENTS,
|
||||
],
|
||||
) -> int: ...
|
||||
|
||||
@@ -1279,7 +1286,7 @@ class StreamToken:
|
||||
f"account_data: {self.account_data_key}, push_rules: {self.push_rules_key}, "
|
||||
f"to_device: {self.to_device_key}, device_list: {self.device_list_key}, "
|
||||
f"groups: {self.groups_key}, un_partial_stated_rooms: {self.un_partial_stated_rooms_key},"
|
||||
f"thread_subscriptions: {self.thread_subscriptions_key})"
|
||||
f"thread_subscriptions: {self.thread_subscriptions_key}, sticky_events: {self.sticky_events_key})"
|
||||
)
|
||||
|
||||
|
||||
@@ -1295,6 +1302,7 @@ StreamToken.START = StreamToken(
|
||||
groups_key=0,
|
||||
un_partial_stated_rooms_key=0,
|
||||
thread_subscriptions_key=0,
|
||||
sticky_events_key=0,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from typing import (
|
||||
AbstractSet,
|
||||
Any,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Final,
|
||||
Generic,
|
||||
@@ -396,12 +397,26 @@ class SlidingSyncResult:
|
||||
or bool(self.prev_batch)
|
||||
)
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class StickyEventsExtension:
|
||||
"""The Sticky Events extension (MSC4354)
|
||||
|
||||
Attributes:
|
||||
room_id_to_sticky_events: map (room_id -> [unexpired_sticky_events])
|
||||
"""
|
||||
|
||||
room_id_to_sticky_events: Mapping[str, Collection[EventBase]]
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.room_id_to_sticky_events)
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
receipts: Optional[ReceiptsExtension] = None
|
||||
typing: Optional[TypingExtension] = None
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = None
|
||||
sticky_events: Optional[StickyEventsExtension] = None
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(
|
||||
@@ -411,6 +426,7 @@ class SlidingSyncResult:
|
||||
or self.receipts
|
||||
or self.typing
|
||||
or self.thread_subscriptions
|
||||
or self.sticky_events
|
||||
)
|
||||
|
||||
next_pos: SlidingSyncStreamToken
|
||||
|
||||
@@ -376,6 +376,15 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
enabled: Optional[StrictBool] = False
|
||||
limit: StrictInt = 100
|
||||
|
||||
class StickyEventsExtension(RequestBodyModel):
|
||||
"""The Sticky Events extension (MSC4354)
|
||||
|
||||
Attributes:
|
||||
enabled
|
||||
"""
|
||||
|
||||
enabled: Optional[StrictBool] = False
|
||||
|
||||
to_device: Optional[ToDeviceExtension] = None
|
||||
e2ee: Optional[E2eeExtension] = None
|
||||
account_data: Optional[AccountDataExtension] = None
|
||||
@@ -384,6 +393,9 @@ class SlidingSyncBody(RequestBodyModel):
|
||||
thread_subscriptions: Optional[ThreadSubscriptionsExtension] = Field(
|
||||
alias="io.element.msc4308.thread_subscriptions"
|
||||
)
|
||||
sticky_events: Optional[StickyEventsExtension] = Field(
|
||||
alias="org.matrix.msc4354.sticky_events"
|
||||
)
|
||||
|
||||
conn_id: Optional[StrictStr]
|
||||
|
||||
|
||||
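As a rough illustration of how a client opts into the new extension via the `Field` alias above, a sliding sync request body would enable it under the unstable prefix. This is a sketch only, not part of the diff; the request is trimmed to the relevant part and the variable name is an assumption.

```python
# Sketch only: enabling the MSC4354 sticky events extension in a sliding sync request.
sliding_sync_request_body = {
    "extensions": {
        "org.matrix.msc4354.sticky_events": {
            "enabled": True,
        },
    },
}
```

The corresponding response extension carries a per-room collection of unexpired sticky events, per the `StickyEventsExtension` result type defined earlier in this diff; its exact wire serialisation is not shown here.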
@@ -346,6 +346,7 @@ T3 = TypeVar("T3")
|
||||
T4 = TypeVar("T4")
|
||||
T5 = TypeVar("T5")
|
||||
T6 = TypeVar("T6")
|
||||
T7 = TypeVar("T7")
|
||||
|
||||
|
||||
@overload
|
||||
@@ -477,6 +478,30 @@ async def gather_optional_coroutines(
|
||||
]: ...
|
||||
|
||||
|
||||
@overload
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[
|
||||
Tuple[
|
||||
Optional[Coroutine[Any, Any, T1]],
|
||||
Optional[Coroutine[Any, Any, T2]],
|
||||
Optional[Coroutine[Any, Any, T3]],
|
||||
Optional[Coroutine[Any, Any, T4]],
|
||||
Optional[Coroutine[Any, Any, T5]],
|
||||
Optional[Coroutine[Any, Any, T6]],
|
||||
Optional[Coroutine[Any, Any, T7]],
|
||||
]
|
||||
],
|
||||
) -> Tuple[
|
||||
Optional[T1],
|
||||
Optional[T2],
|
||||
Optional[T3],
|
||||
Optional[T4],
|
||||
Optional[T5],
|
||||
Optional[T6],
|
||||
Optional[T7],
|
||||
]: ...
|
||||
|
||||
|
||||
async def gather_optional_coroutines(
|
||||
*coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]],
|
||||
) -> Tuple[Optional[T1], ...]:
|
||||
|
||||
@@ -209,6 +209,15 @@ async def filter_events_for_client(
|
||||
# to the cache!
|
||||
cloned = clone_event(filtered)
|
||||
cloned.unsigned[EventUnsignedContentFields.MEMBERSHIP] = user_membership
|
||||
if storage.main.config.experimental.msc4354_enabled:
|
||||
sticky_duration = cloned.sticky_duration()
|
||||
if sticky_duration:
|
||||
now = storage.main.clock.time_msec()
|
||||
expires_at = min(cloned.origin_server_ts, now) + sticky_duration
|
||||
if sticky_duration and expires_at > now:
|
||||
cloned.unsigned[EventUnsignedContentFields.STICKY_TTL] = (
|
||||
expires_at - now
|
||||
)
|
||||
|
||||
return cloned
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import time
|
||||
from typing import Callable, Collection, List, Optional, Tuple
|
||||
from unittest import mock
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
@@ -19,6 +20,7 @@ from synapse.types import JsonDict
|
||||
from synapse.util.clock import Clock
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import event_injection
|
||||
from tests.unittest import FederatingHomeserverTestCase
|
||||
|
||||
@@ -452,6 +454,60 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
|
||||
# has been successfully sent.
|
||||
self.assertCountEqual(woken, set(server_names[:-1]))
|
||||
|
||||
@unittest.override_config({"experimental_features": {"msc4354_enabled": True}})
|
||||
def test_sends_sticky_events(self) -> None:
|
||||
"""Test that we send sticky events in addition to the latest event in the room when catching up."""
|
||||
# make the clock used when generating origin_server_ts the same as the clock used to check expiry
|
||||
self.reactor.advance(time.time())
|
||||
per_dest_queue, sent_pdus = self.make_fake_destination_queue()
|
||||
|
||||
# Make a room with a local user, and two servers. One will go offline
|
||||
# and one will send some events.
|
||||
self.register_user("u1", "you the one")
|
||||
u1_token = self.login("u1", "you the one")
|
||||
room_1 = self.helper.create_room_as("u1", tok=u1_token)
|
||||
|
||||
self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host2", "join")
|
||||
)
|
||||
event_1 = self.get_success(
|
||||
event_injection.inject_member_event(self.hs, room_1, "@user:host3", "join")
|
||||
)
|
||||
|
||||
# now we send a sticky event that we expect to be bundled with the fwd extrem event
|
||||
sticky_event_id = self.helper.send_sticky_event(
|
||||
room_1, "m.room.sticky", 60000, tok=u1_token
|
||||
)["event_id"]
|
||||
# ..and other uninteresting events
|
||||
self.helper.send(room_1, "you hear me!!", tok=u1_token)
|
||||
|
||||
# Now simulate us receiving an event from the still online remote.
|
||||
fwd_extrem_event = self.get_success(
|
||||
event_injection.inject_event(
|
||||
self.hs,
|
||||
type=EventTypes.Message,
|
||||
sender="@user:host3",
|
||||
room_id=room_1,
|
||||
content={"msgtype": "m.text", "body": "Hello"},
|
||||
)
|
||||
)
|
||||
|
||||
assert event_1.internal_metadata.stream_ordering is not None
|
||||
self.get_success(
|
||||
self.hs.get_datastores().main.set_destination_last_successful_stream_ordering(
|
||||
"host2", event_1.internal_metadata.stream_ordering
|
||||
)
|
||||
)
|
||||
|
||||
self.get_success(per_dest_queue._catch_up_transmission_loop())
|
||||
|
||||
# We expect the sticky event and the fwd extrem to be sent
|
||||
self.assertEqual(len(sent_pdus), 2)
|
||||
# We expect the sticky event to appear before the fwd extrem
|
||||
self.assertEqual(sent_pdus[0].event_id, sticky_event_id)
|
||||
self.assertEqual(sent_pdus[1].event_id, fwd_extrem_event.event_id)
|
||||
self.assertFalse(per_dest_queue._catching_up)
|
||||
|
||||
def test_not_latest_event(self) -> None:
|
||||
"""Test that we send the latest event in the room even if its not ours."""
|
||||
|
||||
|
||||
@@ -2244,7 +2244,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
"""Test Topo Token is accepted."""
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
@@ -2258,7 +2258,7 @@ class RoomMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
"""Test that stream token is accepted for forward pagination."""
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
"/_synapse/admin/v1/rooms/%s/messages?from=%s" % (self.room_id, token),
|
||||
|
||||
@@ -2245,7 +2245,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_topo_token_is_accepted(self) -> None:
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
@@ -2256,7 +2256,7 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self) -> None:
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0_0_0_0"
|
||||
channel = self.make_request(
|
||||
"GET", "/rooms/%s/messages?access_token=x&from=%s" % (self.room_id, token)
|
||||
)
|
||||
|
||||
@@ -456,6 +456,49 @@ class RestHelper:
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def send_sticky_event(
|
||||
self,
|
||||
room_id: str,
|
||||
type: str,
|
||||
duration_ms: int,
|
||||
content: Optional[dict] = None,
|
||||
txn_id: Optional[str] = None,
|
||||
tok: Optional[str] = None,
|
||||
expect_code: int = HTTPStatus.OK,
|
||||
custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
|
||||
) -> JsonDict:
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
path = (
|
||||
"/_matrix/client/r0/rooms/%s/send/%s/%s?org.matrix.msc4354.sticky_duration_ms=%d"
|
||||
% (
|
||||
room_id,
|
||||
type,
|
||||
txn_id,
|
||||
duration_ms,
|
||||
)
|
||||
)
|
||||
if tok:
|
||||
path = path + "&access_token=%s" % tok
|
||||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
self.site,
|
||||
"PUT",
|
||||
path,
|
||||
content or {},
|
||||
custom_headers=custom_headers,
|
||||
)
|
||||
|
||||
assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
|
||||
expect_code,
|
||||
channel.code,
|
||||
channel.result["body"],
|
||||
)
|
||||
|
||||
return channel.json_body
|
||||
|
||||
def get_event(
|
||||
self,
|
||||
room_id: str,
|
||||
|
||||