Compare commits
14 Commits
v1.59.0rc1
...
babolivier
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
df3f489f5b | ||
|
|
32e64b4bcd | ||
|
|
84facf769e | ||
|
|
c72d26c1e1 | ||
|
|
c997bfb926 | ||
|
|
29f06704b8 | ||
|
|
989fa33096 | ||
|
|
147f098fb4 | ||
|
|
dbb12a0b54 | ||
|
|
5cfb004595 | ||
|
|
5c00151c28 | ||
|
|
2aad0ae57f | ||
|
|
b44fbdffa4 | ||
|
|
02cdace707 |
1
changelog.d/12477.misc
Normal file
1
changelog.d/12477.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add some type hints to datastore.
|
||||
1
changelog.d/12586.misc
Normal file
1
changelog.d/12586.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add `@cancellable` decorator, for use on endpoint methods that can be cancelled when clients disconnect.
|
||||
1
changelog.d/12588.misc
Normal file
1
changelog.d/12588.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add ability to cancel disconnected requests to `SynapseRequest`.
|
||||
1
changelog.d/12630.misc
Normal file
1
changelog.d/12630.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add a helper class for testing request cancellation.
|
||||
1
changelog.d/12676.misc
Normal file
1
changelog.d/12676.misc
Normal file
@@ -0,0 +1 @@
|
||||
Improve documentation of the `synapse.push` module.
|
||||
1
changelog.d/12677.misc
Normal file
1
changelog.d/12677.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor functions to on `PushRuleEvaluatorForEvent`.
|
||||
1
changelog.d/12679.misc
Normal file
1
changelog.d/12679.misc
Normal file
@@ -0,0 +1 @@
|
||||
Preparation for database schema simplifications: stop writing to `event_reference_hashes`.
|
||||
1
changelog.d/12683.bugfix
Normal file
1
changelog.d/12683.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.57.0 where `/messages` would throw a 500 error when querying for a non-existent room.
|
||||
1
changelog.d/12689.misc
Normal file
1
changelog.d/12689.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor `EventContext` class.
|
||||
1
changelog.d/12694.misc
Normal file
1
changelog.d/12694.misc
Normal file
@@ -0,0 +1 @@
|
||||
Capture the `Deferred` for request cancellation in `_AsyncResource`.
|
||||
1
changelog.d/12695.misc
Normal file
1
changelog.d/12695.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fixes an incorrect type hint for `Filter._check_event_relations`.
|
||||
1
changelog.d/12712.bugfix
Normal file
1
changelog.d/12712.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.38.0 where empty rooms would be created automatically if an MAU limit was set and time-based cache expiry was turned on.
|
||||
@@ -19,6 +19,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
@@ -444,9 +445,9 @@ class Filter:
|
||||
return room_ids
|
||||
|
||||
async def _check_event_relations(
|
||||
self, events: Iterable[FilterEvent]
|
||||
self, events: Collection[FilterEvent]
|
||||
) -> List[FilterEvent]:
|
||||
# The event IDs to check, mypy doesn't understand the ifinstance check.
|
||||
# The event IDs to check, mypy doesn't understand the isinstance check.
|
||||
event_ids = [event.event_id for event in events if isinstance(event, EventBase)] # type: ignore[attr-defined]
|
||||
event_ids_to_keep = set(
|
||||
await self._store.events_have_relations(
|
||||
|
||||
@@ -15,12 +15,10 @@ from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.logging.context import make_deferred_yieldable, run_in_background
|
||||
from synapse.types import JsonDict, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -60,6 +58,9 @@ class EventContext:
|
||||
If ``state_group`` is None (ie, the event is an outlier),
|
||||
``state_group_before_event`` will always also be ``None``.
|
||||
|
||||
state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
|
||||
then this is the delta of the state between the two groups.
|
||||
|
||||
prev_group: If it is known, ``state_group``'s prev_group. Note that this being
|
||||
None does not necessarily mean that ``state_group`` does not have
|
||||
a prev_group!
|
||||
@@ -78,73 +79,47 @@ class EventContext:
|
||||
app_service: If this event is being sent by a (local) application service, that
|
||||
app service.
|
||||
|
||||
_current_state_ids: The room state map, including this event - ie, the state
|
||||
in ``state_group``.
|
||||
|
||||
(type, state_key) -> event_id
|
||||
|
||||
For an outlier, this is {}
|
||||
|
||||
Note that this is a private attribute: it should be accessed via
|
||||
``get_current_state_ids``. _AsyncEventContext impl calculates this
|
||||
on-demand: it will be None until that happens.
|
||||
|
||||
_prev_state_ids: The room state map, excluding this event - ie, the state
|
||||
in ``state_group_before_event``. For a non-state
|
||||
event, this will be the same as _current_state_events.
|
||||
|
||||
Note that it is a completely different thing to prev_group!
|
||||
|
||||
(type, state_key) -> event_id
|
||||
|
||||
For an outlier, this is {}
|
||||
|
||||
As with _current_state_ids, this is a private attribute. It should be
|
||||
accessed via get_prev_state_ids.
|
||||
|
||||
partial_state: if True, we may be storing this event with a temporary,
|
||||
incomplete state.
|
||||
"""
|
||||
|
||||
rejected: Union[bool, str] = False
|
||||
_storage: "Storage"
|
||||
rejected: Union[Literal[False], str] = False
|
||||
_state_group: Optional[int] = None
|
||||
state_group_before_event: Optional[int] = None
|
||||
_state_delta_due_to_event: Optional[StateMap[str]] = None
|
||||
prev_group: Optional[int] = None
|
||||
delta_ids: Optional[StateMap[str]] = None
|
||||
app_service: Optional[ApplicationService] = None
|
||||
|
||||
_current_state_ids: Optional[StateMap[str]] = None
|
||||
_prev_state_ids: Optional[StateMap[str]] = None
|
||||
|
||||
partial_state: bool = False
|
||||
|
||||
@staticmethod
|
||||
def with_state(
|
||||
storage: "Storage",
|
||||
state_group: Optional[int],
|
||||
state_group_before_event: Optional[int],
|
||||
current_state_ids: Optional[StateMap[str]],
|
||||
prev_state_ids: Optional[StateMap[str]],
|
||||
state_delta_due_to_event: Optional[StateMap[str]],
|
||||
partial_state: bool,
|
||||
prev_group: Optional[int] = None,
|
||||
delta_ids: Optional[StateMap[str]] = None,
|
||||
) -> "EventContext":
|
||||
return EventContext(
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
storage=storage,
|
||||
state_group=state_group,
|
||||
state_group_before_event=state_group_before_event,
|
||||
state_delta_due_to_event=state_delta_due_to_event,
|
||||
prev_group=prev_group,
|
||||
delta_ids=delta_ids,
|
||||
partial_state=partial_state,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def for_outlier() -> "EventContext":
|
||||
def for_outlier(
|
||||
storage: "Storage",
|
||||
) -> "EventContext":
|
||||
"""Return an EventContext instance suitable for persisting an outlier event"""
|
||||
return EventContext(
|
||||
current_state_ids={},
|
||||
prev_state_ids={},
|
||||
)
|
||||
return EventContext(storage=storage)
|
||||
|
||||
async def serialize(self, event: EventBase, store: "DataStore") -> JsonDict:
|
||||
"""Converts self to a type that can be serialized as JSON, and then
|
||||
@@ -157,24 +132,14 @@ class EventContext:
|
||||
The serialized event.
|
||||
"""
|
||||
|
||||
# We don't serialize the full state dicts, instead they get pulled out
|
||||
# of the DB on the other side. However, the other side can't figure out
|
||||
# the prev_state_ids, so if we're a state event we include the event
|
||||
# id that we replaced in the state.
|
||||
if event.is_state():
|
||||
prev_state_ids = await self.get_prev_state_ids()
|
||||
prev_state_id = prev_state_ids.get((event.type, event.state_key))
|
||||
else:
|
||||
prev_state_id = None
|
||||
|
||||
return {
|
||||
"prev_state_id": prev_state_id,
|
||||
"event_type": event.type,
|
||||
"event_state_key": event.get_state_key(),
|
||||
"state_group": self._state_group,
|
||||
"state_group_before_event": self.state_group_before_event,
|
||||
"rejected": self.rejected,
|
||||
"prev_group": self.prev_group,
|
||||
"state_delta_due_to_event": _encode_state_dict(
|
||||
self._state_delta_due_to_event
|
||||
),
|
||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||
"app_service_id": self.app_service.id if self.app_service else None,
|
||||
"partial_state": self.partial_state,
|
||||
@@ -192,16 +157,16 @@ class EventContext:
|
||||
Returns:
|
||||
The event context.
|
||||
"""
|
||||
context = _AsyncEventContextImpl(
|
||||
context = EventContext(
|
||||
# We use the state_group and prev_state_id stuff to pull the
|
||||
# current_state_ids out of the DB and construct prev_state_ids.
|
||||
storage=storage,
|
||||
prev_state_id=input["prev_state_id"],
|
||||
event_type=input["event_type"],
|
||||
event_state_key=input["event_state_key"],
|
||||
state_group=input["state_group"],
|
||||
state_group_before_event=input["state_group_before_event"],
|
||||
prev_group=input["prev_group"],
|
||||
state_delta_due_to_event=_decode_state_dict(
|
||||
input["state_delta_due_to_event"]
|
||||
),
|
||||
delta_ids=_decode_state_dict(input["delta_ids"]),
|
||||
rejected=input["rejected"],
|
||||
partial_state=input.get("partial_state", False),
|
||||
@@ -249,8 +214,15 @@ class EventContext:
|
||||
if self.rejected:
|
||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||
|
||||
await self._ensure_fetched()
|
||||
return self._current_state_ids
|
||||
assert self._state_delta_due_to_event is not None
|
||||
|
||||
prev_state_ids = await self.get_prev_state_ids()
|
||||
|
||||
if self._state_delta_due_to_event:
|
||||
prev_state_ids = dict(prev_state_ids)
|
||||
prev_state_ids.update(self._state_delta_due_to_event)
|
||||
|
||||
return prev_state_ids
|
||||
|
||||
async def get_prev_state_ids(self) -> StateMap[str]:
|
||||
"""
|
||||
@@ -265,94 +237,10 @@ class EventContext:
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
await self._ensure_fetched()
|
||||
# There *should* be previous state IDs now.
|
||||
assert self._prev_state_ids is not None
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self) -> Optional[StateMap[str]]:
|
||||
"""Gets the current state IDs if we have them already cached.
|
||||
|
||||
It is an error to access this for a rejected event, since rejected state should
|
||||
not make it into the room state. This method will raise an exception if
|
||||
``rejected`` is set.
|
||||
|
||||
Returns:
|
||||
Returns None if we haven't cached the state or if state_group is None
|
||||
(which happens when the associated event is an outlier).
|
||||
|
||||
Otherwise, returns the the current state IDs.
|
||||
"""
|
||||
if self.rejected:
|
||||
raise RuntimeError("Attempt to access state_ids of rejected event")
|
||||
|
||||
return self._current_state_ids
|
||||
|
||||
async def _ensure_fetched(self) -> None:
|
||||
return None
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncEventContextImpl(EventContext):
|
||||
"""
|
||||
An implementation of EventContext which fetches _current_state_ids and
|
||||
_prev_state_ids from the database on demand.
|
||||
|
||||
Attributes:
|
||||
|
||||
_storage
|
||||
|
||||
_fetching_state_deferred: Resolves when *_state_ids have been calculated.
|
||||
None if we haven't started calculating yet
|
||||
|
||||
_event_type: The type of the event the context is associated with.
|
||||
|
||||
_event_state_key: The state_key of the event the context is associated with.
|
||||
|
||||
_prev_state_id: If the event associated with the context is a state event,
|
||||
then `_prev_state_id` is the event_id of the state that was replaced.
|
||||
"""
|
||||
|
||||
# This needs to have a default as we're inheriting
|
||||
_storage: "Storage" = attr.ib(default=None)
|
||||
_prev_state_id: Optional[str] = attr.ib(default=None)
|
||||
_event_type: str = attr.ib(default=None)
|
||||
_event_state_key: Optional[str] = attr.ib(default=None)
|
||||
_fetching_state_deferred: Optional["Deferred[None]"] = attr.ib(default=None)
|
||||
|
||||
async def _ensure_fetched(self) -> None:
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(self._fill_out_state)
|
||||
|
||||
await make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
async def _fill_out_state(self) -> None:
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
attributes by loading from the database.
|
||||
"""
|
||||
if self.state_group is None:
|
||||
# No state group means the event is an outlier. Usually the state_ids dicts are also
|
||||
# pre-set to empty dicts, but they get reset when the context is serialized, so set
|
||||
# them to empty dicts again here.
|
||||
self._current_state_ids = {}
|
||||
self._prev_state_ids = {}
|
||||
return
|
||||
|
||||
current_state_ids = await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group
|
||||
assert self.state_group_before_event is not None
|
||||
return await self._storage.state.get_state_ids_for_group(
|
||||
self.state_group_before_event
|
||||
)
|
||||
# Set this separately so mypy knows current_state_ids is not None.
|
||||
self._current_state_ids = current_state_ids
|
||||
if self._event_state_key is not None:
|
||||
self._prev_state_ids = dict(current_state_ids)
|
||||
|
||||
key = (self._event_type, self._event_state_key)
|
||||
if self._prev_state_id:
|
||||
self._prev_state_ids[key] = self._prev_state_id
|
||||
else:
|
||||
self._prev_state_ids.pop(key, None)
|
||||
else:
|
||||
self._prev_state_ids = current_state_ids
|
||||
|
||||
|
||||
def _encode_state_dict(
|
||||
|
||||
@@ -659,7 +659,7 @@ class FederationHandler:
|
||||
# in the invitee's sync stream. It is stripped out for all other local users.
|
||||
event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
|
||||
|
||||
context = EventContext.for_outlier()
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
@@ -848,7 +848,7 @@ class FederationHandler:
|
||||
)
|
||||
)
|
||||
|
||||
context = EventContext.for_outlier()
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
@@ -877,7 +877,7 @@ class FederationHandler:
|
||||
|
||||
await self.federation_client.send_leave(host_list, event)
|
||||
|
||||
context = EventContext.for_outlier()
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
stream_id = await self._federation_event_handler.persist_events_and_notify(
|
||||
event.room_id, [(event, context)]
|
||||
)
|
||||
|
||||
@@ -1423,7 +1423,7 @@ class FederationEventHandler:
|
||||
# we're not bothering about room state, so flag the event as an outlier.
|
||||
event.internal_metadata.outlier = True
|
||||
|
||||
context = EventContext.for_outlier()
|
||||
context = EventContext.for_outlier(self._storage)
|
||||
try:
|
||||
validate_event_for_room_version(room_version_obj, event)
|
||||
check_auth_rules_for_event(room_version_obj, event, auth)
|
||||
@@ -1874,10 +1874,10 @@ class FederationEventHandler:
|
||||
)
|
||||
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
state_group=state_group,
|
||||
state_group_before_event=context.state_group_before_event,
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
state_delta_due_to_event=state_updates,
|
||||
prev_group=prev_group,
|
||||
delta_ids=state_updates,
|
||||
partial_state=context.partial_state,
|
||||
|
||||
@@ -757,6 +757,10 @@ class EventCreationHandler:
|
||||
The previous version of the event is returned, if it is found in the
|
||||
event context. Otherwise, None is returned.
|
||||
"""
|
||||
if event.internal_metadata.is_outlier():
|
||||
# This can happen due to out of band memberships
|
||||
return None
|
||||
|
||||
prev_state_ids = await context.get_prev_state_ids()
|
||||
prev_event_id = prev_state_ids.get((event.type, event.state_key))
|
||||
if not prev_event_id:
|
||||
@@ -1001,7 +1005,7 @@ class EventCreationHandler:
|
||||
# after it is created
|
||||
if builder.internal_metadata.outlier:
|
||||
event.internal_metadata.outlier = True
|
||||
context = EventContext.for_outlier()
|
||||
context = EventContext.for_outlier(self.storage)
|
||||
elif (
|
||||
event.type == EventTypes.MSC2716_INSERTION
|
||||
and state_event_ids
|
||||
|
||||
@@ -448,7 +448,7 @@ class PaginationHandler:
|
||||
)
|
||||
# We expect `/messages` to use historic pagination tokens by default but
|
||||
# `/messages` should still works with live tokens when manually provided.
|
||||
assert from_token.room_key.topological
|
||||
assert from_token.room_key.topological is not None
|
||||
|
||||
if pagin_config.limit is None:
|
||||
# This shouldn't happen as we've set a default limit before this
|
||||
|
||||
@@ -33,6 +33,7 @@ from typing import (
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
|
||||
@@ -92,6 +93,66 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
|
||||
HTTP_STATUS_REQUEST_CANCELLED = 499
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
_cancellable_method_names = frozenset(
|
||||
{
|
||||
# `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
|
||||
# methods
|
||||
"on_GET",
|
||||
"on_PUT",
|
||||
"on_POST",
|
||||
"on_DELETE",
|
||||
# `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
|
||||
# methods
|
||||
"_async_render_GET",
|
||||
"_async_render_PUT",
|
||||
"_async_render_POST",
|
||||
"_async_render_DELETE",
|
||||
"_async_render_OPTIONS",
|
||||
# `ReplicationEndpoint` methods
|
||||
"_handle_request",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def cancellable(method: F) -> F:
|
||||
"""Marks a servlet method as cancellable.
|
||||
|
||||
Methods with this decorator will be cancelled if the client disconnects before we
|
||||
finish processing the request.
|
||||
|
||||
During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
|
||||
the method. The `cancel()` call will propagate down to the `Deferred` that is
|
||||
currently being waited on. That `Deferred` will raise a `CancelledError`, which will
|
||||
propagate up, as per normal exception handling.
|
||||
|
||||
Before applying this decorator to a new endpoint, you MUST recursively check
|
||||
that all `await`s in the function are on `async` functions or `Deferred`s that
|
||||
handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
|
||||
premature logging context closure, to stuck requests, to database corruption.
|
||||
|
||||
Usage:
|
||||
class SomeServlet(RestServlet):
|
||||
@cancellable
|
||||
async def on_GET(self, request: SynapseRequest) -> ...:
|
||||
...
|
||||
"""
|
||||
if method.__name__ not in _cancellable_method_names:
|
||||
raise ValueError(
|
||||
"@cancellable decorator can only be applied to servlet methods."
|
||||
)
|
||||
|
||||
method.cancellable = True # type: ignore[attr-defined]
|
||||
return method
|
||||
|
||||
|
||||
def is_method_cancellable(method: Callable[..., Any]) -> bool:
|
||||
"""Checks whether a servlet method has the `@cancellable` flag."""
|
||||
return getattr(method, "cancellable", False)
|
||||
|
||||
|
||||
def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
|
||||
"""Sends a JSON error response to clients."""
|
||||
|
||||
@@ -283,7 +344,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
|
||||
def render(self, request: SynapseRequest) -> int:
|
||||
"""This gets called by twisted every time someone sends us a request."""
|
||||
defer.ensureDeferred(self._async_render_wrapper(request))
|
||||
request.render_deferred = defer.ensureDeferred(
|
||||
self._async_render_wrapper(request)
|
||||
)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@wrap_async_request_handler
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
|
||||
import attr
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.internet.interfaces import IAddress, IReactorTime
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.http import HTTPChannel
|
||||
@@ -91,6 +92,14 @@ class SynapseRequest(Request):
|
||||
# we can't yet create the logcontext, as we don't know the method.
|
||||
self.logcontext: Optional[LoggingContext] = None
|
||||
|
||||
# The `Deferred` to cancel if the client disconnects early and
|
||||
# `is_render_cancellable` is set. Expected to be set by `Resource.render`.
|
||||
self.render_deferred: Optional["Deferred[None]"] = None
|
||||
# A boolean indicating whether `render_deferred` should be cancelled if the
|
||||
# client disconnects early. Expected to be set by the coroutine started by
|
||||
# `Resource.render`, if rendering is asynchronous.
|
||||
self.is_render_cancellable = False
|
||||
|
||||
global _next_request_seq
|
||||
self.request_seq = _next_request_seq
|
||||
_next_request_seq += 1
|
||||
@@ -357,7 +366,21 @@ class SynapseRequest(Request):
|
||||
{"event": "client connection lost", "reason": str(reason.value)}
|
||||
)
|
||||
|
||||
if not self._is_processing:
|
||||
if self._is_processing:
|
||||
if self.is_render_cancellable:
|
||||
if self.render_deferred is not None:
|
||||
# Throw a cancellation into the request processing, in the hope
|
||||
# that it will finish up sooner than it normally would.
|
||||
# The `self.processing()` context manager will call
|
||||
# `_finished_processing()` when done.
|
||||
with PreserveLoggingContext():
|
||||
self.render_deferred.cancel()
|
||||
else:
|
||||
logger.error(
|
||||
"Connection from client lost, but have no Deferred to "
|
||||
"cancel even though the request is marked as cancellable."
|
||||
)
|
||||
else:
|
||||
self._finished_processing()
|
||||
|
||||
def _started_processing(self, servlet_name: str) -> None:
|
||||
|
||||
@@ -12,6 +12,85 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
This module implements the push rules & notifications portion of the Matrix
|
||||
specification.
|
||||
|
||||
There's a few related features:
|
||||
|
||||
* Push notifications (i.e. email or outgoing requests to a Push Gateway).
|
||||
* Calculation of unread notifications (for /sync and /notifications).
|
||||
|
||||
When Synapse receives a new event (locally, via the Client-Server API, or via
|
||||
federation), the following occurs:
|
||||
|
||||
1. The push rules get evaluated to generate a set of per-user actions.
|
||||
2. The event is persisted into the database.
|
||||
3. (In the background) The notifier is notified about the new event.
|
||||
|
||||
The per-user actions are initially stored in the event_push_actions_staging table,
|
||||
before getting moved into the event_push_actions table when the event is persisted.
|
||||
The event_push_actions table is periodically summarised into the event_push_summary
|
||||
and event_push_summary_stream_ordering tables.
|
||||
|
||||
Since push actions block an event from being persisted the generation of push
|
||||
actions is performance sensitive.
|
||||
|
||||
The general interaction of the classes are:
|
||||
|
||||
+---------------------------------------------+
|
||||
| FederationEventHandler/EventCreationHandler |
|
||||
+---------------------------------------------+
|
||||
|
|
||||
v
|
||||
+-----------------+
|
||||
| ActionGenerator |
|
||||
+-----------------+
|
||||
|
|
||||
v
|
||||
+-----------------------+ +---------------------------+
|
||||
| BulkPushRuleEvaluator |---->| PushRuleEvaluatorForEvent |
|
||||
+-----------------------+ +---------------------------+
|
||||
|
|
||||
v
|
||||
+-----------------------------+
|
||||
| EventPushActionsWorkerStore |
|
||||
+-----------------------------+
|
||||
|
||||
The notifier notifies the pusher pool of the new event, which checks for affected
|
||||
users. Each user-configured pusher of the affected users then performs the
|
||||
previously calculated action.
|
||||
|
||||
The general interaction of the classes are:
|
||||
|
||||
+----------+
|
||||
| Notifier |
|
||||
+----------+
|
||||
|
|
||||
v
|
||||
+------------+ +--------------+
|
||||
| PusherPool |---->| PusherConfig |
|
||||
+------------+ +--------------+
|
||||
|
|
||||
| +---------------+
|
||||
+<--->| PusherFactory |
|
||||
| +---------------+
|
||||
v
|
||||
+------------------------+ +-----------------------------------------------+
|
||||
| EmailPusher/HttpPusher |---->| EventPushActionsWorkerStore/PusherWorkerStore |
|
||||
+------------------------+ +-----------------------------------------------+
|
||||
|
|
||||
v
|
||||
+-------------------------+
|
||||
| Mailer/SimpleHttpClient |
|
||||
+-------------------------+
|
||||
|
||||
The Pusher instance also calls out to various utilities for generating payloads
|
||||
(or email templates), but those interactions are not detailed in this diagram
|
||||
(and are specific to the type of pusher).
|
||||
|
||||
"""
|
||||
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
|
||||
@@ -40,5 +40,9 @@ class ActionGenerator:
|
||||
async def handle_push_actions_for_event(
|
||||
self, event: EventBase, context: EventContext
|
||||
) -> None:
|
||||
if event.internal_metadata.is_outlier():
|
||||
# This can happen due to out of band memberships
|
||||
return
|
||||
|
||||
with Measure(self.clock, "action_for_event_by_user"):
|
||||
await self.bulk_evaluator.action_for_event_by_user(event, context)
|
||||
|
||||
@@ -208,8 +208,6 @@ class BulkPushRuleEvaluator:
|
||||
event, len(room_members), sender_power_level, power_levels
|
||||
)
|
||||
|
||||
condition_cache: Dict[str, bool] = {}
|
||||
|
||||
# If the event is not a state event check if any users ignore the sender.
|
||||
if not event.is_state():
|
||||
ignorers = await self.store.ignored_by(event.sender)
|
||||
@@ -247,8 +245,8 @@ class BulkPushRuleEvaluator:
|
||||
if "enabled" in rule and not rule["enabled"]:
|
||||
continue
|
||||
|
||||
matches = _condition_checker(
|
||||
evaluator, rule["conditions"], uid, display_name, condition_cache
|
||||
matches = evaluator.check_conditions(
|
||||
rule["conditions"], uid, display_name
|
||||
)
|
||||
if matches:
|
||||
actions = [x for x in rule["actions"] if x != "dont_notify"]
|
||||
@@ -267,32 +265,6 @@ class BulkPushRuleEvaluator:
|
||||
)
|
||||
|
||||
|
||||
def _condition_checker(
|
||||
evaluator: PushRuleEvaluatorForEvent,
|
||||
conditions: List[dict],
|
||||
uid: str,
|
||||
display_name: Optional[str],
|
||||
cache: Dict[str, bool],
|
||||
) -> bool:
|
||||
for cond in conditions:
|
||||
_cache_key = cond.get("_cache_key", None)
|
||||
if _cache_key:
|
||||
res = cache.get(_cache_key, None)
|
||||
if res is False:
|
||||
return False
|
||||
elif res is True:
|
||||
continue
|
||||
|
||||
res = evaluator.matches(cond, uid, display_name)
|
||||
if _cache_key:
|
||||
cache[_cache_key] = bool(res)
|
||||
|
||||
if not res:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
MemberMap = Dict[str, Optional[EventIdMembership]]
|
||||
Rule = Dict[str, dict]
|
||||
RulesByUser = Dict[str, List[Rule]]
|
||||
|
||||
@@ -129,9 +129,55 @@ class PushRuleEvaluatorForEvent:
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
|
||||
# Maps cache keys to final values.
|
||||
self._condition_cache: Dict[str, bool] = {}
|
||||
|
||||
def check_conditions(
|
||||
self, conditions: List[dict], uid: str, display_name: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if a user's conditions/user ID/display name match the event.
|
||||
|
||||
Args:
|
||||
conditions: The user's conditions to match.
|
||||
uid: The user's MXID.
|
||||
display_name: The display name.
|
||||
|
||||
Returns:
|
||||
True if all conditions match the event, False otherwise.
|
||||
"""
|
||||
for cond in conditions:
|
||||
_cache_key = cond.get("_cache_key", None)
|
||||
if _cache_key:
|
||||
res = self._condition_cache.get(_cache_key, None)
|
||||
if res is False:
|
||||
return False
|
||||
elif res is True:
|
||||
continue
|
||||
|
||||
res = self.matches(cond, uid, display_name)
|
||||
if _cache_key:
|
||||
self._condition_cache[_cache_key] = bool(res)
|
||||
|
||||
if not res:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def matches(
|
||||
self, condition: Dict[str, Any], user_id: str, display_name: Optional[str]
|
||||
) -> bool:
|
||||
"""
|
||||
Returns true if a user's condition/user ID/display name match the event.
|
||||
|
||||
Args:
|
||||
condition: The user's condition to match.
|
||||
uid: The user's MXID.
|
||||
display_name: The display name, or None if there is not one.
|
||||
|
||||
Returns:
|
||||
True if the condition matches the event, False otherwise.
|
||||
"""
|
||||
if condition["kind"] == "event_match":
|
||||
return self._event_match(condition, user_id)
|
||||
elif condition["kind"] == "contains_display_name":
|
||||
@@ -146,6 +192,16 @@ class PushRuleEvaluatorForEvent:
|
||||
return True
|
||||
|
||||
def _event_match(self, condition: dict, user_id: str) -> bool:
|
||||
"""
|
||||
Check an "event_match" push rule condition.
|
||||
|
||||
Args:
|
||||
condition: The "event_match" push rule condition to match.
|
||||
user_id: The user's MXID.
|
||||
|
||||
Returns:
|
||||
True if the condition matches the event, False otherwise.
|
||||
"""
|
||||
pattern = condition.get("pattern", None)
|
||||
|
||||
if not pattern:
|
||||
@@ -167,13 +223,22 @@ class PushRuleEvaluatorForEvent:
|
||||
|
||||
return _glob_matches(pattern, body, word_boundary=True)
|
||||
else:
|
||||
haystack = self._get_value(condition["key"])
|
||||
haystack = self._value_cache.get(condition["key"], None)
|
||||
if haystack is None:
|
||||
return False
|
||||
|
||||
return _glob_matches(pattern, haystack)
|
||||
|
||||
def _contains_display_name(self, display_name: Optional[str]) -> bool:
|
||||
"""
|
||||
Check an "event_match" push rule condition.
|
||||
|
||||
Args:
|
||||
display_name: The display name, or None if there is not one.
|
||||
|
||||
Returns:
|
||||
True if the display name is found in the event body, False otherwise.
|
||||
"""
|
||||
if not display_name:
|
||||
return False
|
||||
|
||||
@@ -191,9 +256,6 @@ class PushRuleEvaluatorForEvent:
|
||||
|
||||
return bool(r.search(body))
|
||||
|
||||
def _get_value(self, dotted_key: str) -> Optional[str]:
|
||||
return self._value_cache.get(dotted_key, None)
|
||||
|
||||
|
||||
# Caches (string, is_glob, word_boundary) -> regex for push. See _glob_matches
|
||||
regex_cache: LruCache[Tuple[str, bool, bool], Pattern] = LruCache(
|
||||
|
||||
@@ -90,7 +90,7 @@ class ServerNoticesManager:
|
||||
)
|
||||
return event
|
||||
|
||||
@cached()
|
||||
@cached(prune_unread_entries=False)
|
||||
async def get_or_create_notice_room_for_user(self, user_id: str) -> str:
|
||||
"""Get the room for notices for a given user
|
||||
|
||||
|
||||
@@ -130,6 +130,7 @@ class StateHandler:
|
||||
self.state_store = hs.get_storage().state
|
||||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage = hs.get_storage()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
@@ -361,10 +362,10 @@ class StateHandler:
|
||||
|
||||
if not event.is_state():
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
state_group_before_event=state_group_before_event,
|
||||
state_group=state_group_before_event,
|
||||
current_state_ids=state_ids_before_event,
|
||||
prev_state_ids=state_ids_before_event,
|
||||
state_delta_due_to_event={},
|
||||
prev_group=state_group_before_event_prev_group,
|
||||
delta_ids=deltas_to_state_group_before_event,
|
||||
partial_state=partial_state,
|
||||
@@ -393,10 +394,10 @@ class StateHandler:
|
||||
)
|
||||
|
||||
return EventContext.with_state(
|
||||
storage=self._storage,
|
||||
state_group=state_group_after_event,
|
||||
state_group_before_event=state_group_before_event,
|
||||
current_state_ids=state_ids_after_event,
|
||||
prev_state_ids=state_ids_before_event,
|
||||
state_delta_due_to_event=delta_ids,
|
||||
prev_group=state_group_before_event,
|
||||
delta_ids=delta_ids,
|
||||
partial_state=partial_state,
|
||||
|
||||
@@ -36,7 +36,6 @@ from prometheus_client import Counter
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
from synapse.storage._base import db_to_json, make_in_list_sql_clause
|
||||
@@ -50,7 +49,7 @@ from synapse.storage.databases.main.search import SearchEntry
|
||||
from synapse.storage.engines.postgres import PostgresEngine
|
||||
from synapse.storage.util.id_generators import AbstractStreamIdGenerator
|
||||
from synapse.storage.util.sequence import SequenceGenerator
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.types import JsonDict, StateMap, get_domain_from_id
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.iterutils import batch_iter, sorted_topologically
|
||||
|
||||
@@ -129,7 +128,6 @@ class PersistEventsStore:
|
||||
self,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
*,
|
||||
current_state_for_room: Dict[str, StateMap[str]],
|
||||
state_delta_for_room: Dict[str, DeltaState],
|
||||
new_forward_extremities: Dict[str, Set[str]],
|
||||
use_negative_stream_ordering: bool = False,
|
||||
@@ -140,8 +138,6 @@ class PersistEventsStore:
|
||||
|
||||
Args:
|
||||
events_and_contexts:
|
||||
current_state_for_room: Map from room_id to the current state of
|
||||
the room based on forward extremities
|
||||
state_delta_for_room: Map from room_id to the delta to apply to
|
||||
room state
|
||||
new_forward_extremities: Map from room_id to set of event IDs
|
||||
@@ -216,9 +212,6 @@ class PersistEventsStore:
|
||||
|
||||
event_counter.labels(event.type, origin_type, origin_entity).inc()
|
||||
|
||||
for room_id, new_state in current_state_for_room.items():
|
||||
self.store.get_current_state_ids.prefill((room_id,), new_state)
|
||||
|
||||
for room_id, latest_event_ids in new_forward_extremities.items():
|
||||
self.store.get_latest_event_ids_in_room.prefill(
|
||||
(room_id,), list(latest_event_ids)
|
||||
@@ -236,7 +229,9 @@ class PersistEventsStore:
|
||||
"""
|
||||
results: List[str] = []
|
||||
|
||||
def _get_events_which_are_prevs_txn(txn, batch):
|
||||
def _get_events_which_are_prevs_txn(
|
||||
txn: LoggingTransaction, batch: Collection[str]
|
||||
) -> None:
|
||||
sql = """
|
||||
SELECT prev_event_id, internal_metadata
|
||||
FROM event_edges
|
||||
@@ -286,7 +281,9 @@ class PersistEventsStore:
|
||||
# and their prev events.
|
||||
existing_prevs = set()
|
||||
|
||||
def _get_prevs_before_rejected_txn(txn, batch):
|
||||
def _get_prevs_before_rejected_txn(
|
||||
txn: LoggingTransaction, batch: Collection[str]
|
||||
) -> None:
|
||||
to_recursively_check = batch
|
||||
|
||||
while to_recursively_check:
|
||||
@@ -516,7 +513,7 @@ class PersistEventsStore:
|
||||
@classmethod
|
||||
def _add_chain_cover_index(
|
||||
cls,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
@@ -810,7 +807,7 @@ class PersistEventsStore:
|
||||
|
||||
@staticmethod
|
||||
def _allocate_chain_ids(
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
db_pool: DatabasePool,
|
||||
event_chain_id_gen: SequenceGenerator,
|
||||
event_to_room_id: Dict[str, str],
|
||||
@@ -944,7 +941,7 @@ class PersistEventsStore:
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
) -> None:
|
||||
"""Persist the mapping from transaction IDs to event IDs (if defined)."""
|
||||
|
||||
to_insert = []
|
||||
@@ -998,7 +995,7 @@ class PersistEventsStore:
|
||||
txn: LoggingTransaction,
|
||||
state_delta_by_room: Dict[str, DeltaState],
|
||||
stream_id: int,
|
||||
):
|
||||
) -> None:
|
||||
for room_id, delta_state in state_delta_by_room.items():
|
||||
to_delete = delta_state.to_delete
|
||||
to_insert = delta_state.to_insert
|
||||
@@ -1156,7 +1153,7 @@ class PersistEventsStore:
|
||||
txn, room_id, members_changed
|
||||
)
|
||||
|
||||
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str):
|
||||
def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
|
||||
"""Update the room version in the database based off current state
|
||||
events.
|
||||
|
||||
@@ -1190,7 +1187,7 @@ class PersistEventsStore:
|
||||
txn: LoggingTransaction,
|
||||
new_forward_extremities: Dict[str, Set[str]],
|
||||
max_stream_order: int,
|
||||
):
|
||||
) -> None:
|
||||
for room_id in new_forward_extremities.keys():
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
|
||||
@@ -1255,9 +1252,9 @@ class PersistEventsStore:
|
||||
|
||||
def _update_room_depths_txn(
|
||||
self,
|
||||
txn,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
):
|
||||
) -> None:
|
||||
"""Update min_depth for each room
|
||||
|
||||
Args:
|
||||
@@ -1386,7 +1383,7 @@ class PersistEventsStore:
|
||||
# nothing to do here
|
||||
return
|
||||
|
||||
def event_dict(event):
|
||||
def event_dict(event: EventBase) -> JsonDict:
|
||||
d = event.get_dict()
|
||||
d.pop("redacted", None)
|
||||
d.pop("redacted_because", None)
|
||||
@@ -1477,18 +1474,20 @@ class PersistEventsStore:
|
||||
),
|
||||
)
|
||||
|
||||
def _store_rejected_events_txn(self, txn, events_and_contexts):
|
||||
def _store_rejected_events_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> List[Tuple[EventBase, EventContext]]:
|
||||
"""Add rows to the 'rejections' table for received events which were
|
||||
rejected
|
||||
|
||||
Args:
|
||||
txn (twisted.enterprise.adbapi.Connection): db connection
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
txn: db connection
|
||||
events_and_contexts: events we are persisting
|
||||
|
||||
Returns:
|
||||
list[(EventBase, EventContext)] new list, without the rejected
|
||||
events.
|
||||
new list, without the rejected events.
|
||||
"""
|
||||
# Remove the rejected events from the list now that we've added them
|
||||
# to the events table and the events_json table.
|
||||
@@ -1509,7 +1508,7 @@ class PersistEventsStore:
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
):
|
||||
) -> None:
|
||||
"""Update all the miscellaneous tables for new events
|
||||
|
||||
Args:
|
||||
@@ -1600,15 +1599,14 @@ class PersistEventsStore:
|
||||
inhibit_local_membership_updates=inhibit_local_membership_updates,
|
||||
)
|
||||
|
||||
# Insert event_reference_hashes table.
|
||||
self._store_event_reference_hashes_txn(
|
||||
txn, [event for event, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
# Prefill the event cache
|
||||
self._add_to_cache(txn, events_and_contexts)
|
||||
|
||||
def _add_to_cache(self, txn, events_and_contexts):
|
||||
def _add_to_cache(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
to_prefill = []
|
||||
|
||||
rows = []
|
||||
@@ -1639,7 +1637,7 @@ class PersistEventsStore:
|
||||
if not row["rejects"] and not row["redacts"]:
|
||||
to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
|
||||
|
||||
def prefill():
|
||||
def prefill() -> None:
|
||||
for cache_entry in to_prefill:
|
||||
self.store._get_event_cache.set(
|
||||
(cache_entry.event.event_id,), cache_entry
|
||||
@@ -1669,19 +1667,24 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def insert_labels_for_event_txn(
|
||||
self, txn, event_id, labels, room_id, topological_ordering
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
event_id: str,
|
||||
labels: List[str],
|
||||
room_id: str,
|
||||
topological_ordering: int,
|
||||
) -> None:
|
||||
"""Store the mapping between an event's ID and its labels, with one row per
|
||||
(event_id, label) tuple.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to execute.
|
||||
event_id (str): The event's ID.
|
||||
labels (list[str]): A list of text labels.
|
||||
room_id (str): The ID of the room the event was sent to.
|
||||
topological_ordering (int): The position of the event in the room's topology.
|
||||
txn: The transaction to execute.
|
||||
event_id: The event's ID.
|
||||
labels: A list of text labels.
|
||||
room_id: The ID of the room the event was sent to.
|
||||
topological_ordering: The position of the event in the room's topology.
|
||||
"""
|
||||
return self.db_pool.simple_insert_many_txn(
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="event_labels",
|
||||
keys=("event_id", "label", "room_id", "topological_ordering"),
|
||||
@@ -1690,44 +1693,32 @@ class PersistEventsStore:
|
||||
],
|
||||
)
|
||||
|
||||
def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
|
||||
def _insert_event_expiry_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, expiry_ts: int
|
||||
) -> None:
|
||||
"""Save the expiry timestamp associated with a given event ID.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The database transaction to use.
|
||||
event_id (str): The event ID the expiry timestamp is associated with.
|
||||
expiry_ts (int): The timestamp at which to expire (delete) the event.
|
||||
txn: The database transaction to use.
|
||||
event_id: The event ID the expiry timestamp is associated with.
|
||||
expiry_ts: The timestamp at which to expire (delete) the event.
|
||||
"""
|
||||
return self.db_pool.simple_insert_txn(
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn=txn,
|
||||
table="event_expiry",
|
||||
values={"event_id": event_id, "expiry_ts": expiry_ts},
|
||||
)
|
||||
|
||||
def _store_event_reference_hashes_txn(self, txn, events):
|
||||
"""Store a hash for a PDU
|
||||
Args:
|
||||
txn (cursor):
|
||||
events (list): list of Events.
|
||||
"""
|
||||
|
||||
vals = []
|
||||
for event in events:
|
||||
ref_alg, ref_hash_bytes = compute_event_reference_hash(event)
|
||||
vals.append((event.event_id, ref_alg, memoryview(ref_hash_bytes)))
|
||||
|
||||
self.db_pool.simple_insert_many_txn(
|
||||
txn,
|
||||
table="event_reference_hashes",
|
||||
keys=("event_id", "algorithm", "hash"),
|
||||
values=vals,
|
||||
)
|
||||
|
||||
def _store_room_members_txn(
|
||||
self, txn, events, *, inhibit_local_membership_updates: bool = False
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events: List[EventBase],
|
||||
*,
|
||||
inhibit_local_membership_updates: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Store a room member in the database.
|
||||
|
||||
Args:
|
||||
txn: The transaction to use.
|
||||
events: List of events to store.
|
||||
@@ -1767,6 +1758,7 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
for event in events:
|
||||
assert event.internal_metadata.stream_ordering is not None
|
||||
txn.call_after(
|
||||
self.store._membership_stream_cache.entity_has_changed,
|
||||
event.state_key,
|
||||
@@ -1863,7 +1855,9 @@ class PersistEventsStore:
|
||||
(parent_id, event.sender),
|
||||
)
|
||||
|
||||
def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _handle_insertion_event(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
"""Handles keeping track of insertion events and edges/connections.
|
||||
Part of MSC2716.
|
||||
|
||||
@@ -1924,7 +1918,7 @@ class PersistEventsStore:
|
||||
},
|
||||
)
|
||||
|
||||
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _handle_batch_event(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
"""Handles inserting the batch edges/connections between the batch event
|
||||
and an insertion event. Part of MSC2716.
|
||||
|
||||
@@ -2024,25 +2018,29 @@ class PersistEventsStore:
|
||||
txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
|
||||
)
|
||||
|
||||
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_topic_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
if isinstance(event.content.get("topic"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.topic", event.content["topic"]
|
||||
)
|
||||
|
||||
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_name_txn(self, txn: LoggingTransaction, event: EventBase) -> None:
|
||||
if isinstance(event.content.get("name"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.name", event.content["name"]
|
||||
)
|
||||
|
||||
def _store_room_message_txn(self, txn: LoggingTransaction, event: EventBase):
|
||||
def _store_room_message_txn(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
if isinstance(event.content.get("body"), str):
|
||||
self.store_event_search_txn(
|
||||
txn, event, "content.body", event.content["body"]
|
||||
)
|
||||
|
||||
def _store_retention_policy_for_room_txn(self, txn, event):
|
||||
def _store_retention_policy_for_room_txn(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> None:
|
||||
if not event.is_state():
|
||||
logger.debug("Ignoring non-state m.room.retention event")
|
||||
return
|
||||
@@ -2102,8 +2100,11 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(
|
||||
self, txn, events_and_contexts, all_events_and_contexts
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
all_events_and_contexts: List[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
"""Handles moving push actions from staging table to main
|
||||
event_push_actions table for all events in `events_and_contexts`.
|
||||
|
||||
@@ -2111,12 +2112,10 @@ class PersistEventsStore:
|
||||
from the push action staging area.
|
||||
|
||||
Args:
|
||||
events_and_contexts (list[(EventBase, EventContext)]): events
|
||||
we are persisting
|
||||
all_events_and_contexts (list[(EventBase, EventContext)]): all
|
||||
events that we were going to persist. This includes events
|
||||
we've already persisted, etc, that wouldn't appear in
|
||||
events_and_context.
|
||||
events_and_contexts: events we are persisting
|
||||
all_events_and_contexts: all events that we were going to persist.
|
||||
This includes events we've already persisted, etc, that wouldn't
|
||||
appear in events_and_context.
|
||||
"""
|
||||
|
||||
# Only non outlier events will have push actions associated with them,
|
||||
@@ -2185,7 +2184,9 @@ class PersistEventsStore:
|
||||
),
|
||||
)
|
||||
|
||||
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
|
||||
def _remove_push_actions_for_event_id_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, event_id: str
|
||||
) -> None:
|
||||
# Sad that we have to blow away the cache for the whole room here
|
||||
txn.call_after(
|
||||
self.store.get_unread_event_push_actions_by_room_for_user.invalidate,
|
||||
@@ -2196,7 +2197,9 @@ class PersistEventsStore:
|
||||
(room_id, event_id),
|
||||
)
|
||||
|
||||
def _store_rejections_txn(self, txn, event_id, reason):
|
||||
def _store_rejections_txn(
|
||||
self, txn: LoggingTransaction, event_id: str, reason: str
|
||||
) -> None:
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="rejections",
|
||||
@@ -2208,8 +2211,10 @@ class PersistEventsStore:
|
||||
)
|
||||
|
||||
def _store_event_state_mappings_txn(
|
||||
self, txn, events_and_contexts: Iterable[Tuple[EventBase, EventContext]]
|
||||
):
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
events_and_contexts: Collection[Tuple[EventBase, EventContext]],
|
||||
) -> None:
|
||||
state_groups = {}
|
||||
for event, context in events_and_contexts:
|
||||
if event.internal_metadata.is_outlier():
|
||||
@@ -2266,7 +2271,9 @@ class PersistEventsStore:
|
||||
state_group_id,
|
||||
)
|
||||
|
||||
def _update_min_depth_for_room_txn(self, txn, room_id, depth):
|
||||
def _update_min_depth_for_room_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, depth: int
|
||||
) -> None:
|
||||
min_depth = self.store._get_min_depth_interaction(txn, room_id)
|
||||
|
||||
if min_depth is not None and depth >= min_depth:
|
||||
@@ -2279,7 +2286,9 @@ class PersistEventsStore:
|
||||
values={"min_depth": depth},
|
||||
)
|
||||
|
||||
def _handle_mult_prev_events(self, txn, events):
|
||||
def _handle_mult_prev_events(
|
||||
self, txn: LoggingTransaction, events: List[EventBase]
|
||||
) -> None:
|
||||
"""
|
||||
For the given event, update the event edges table and forward and
|
||||
backward extremities tables.
|
||||
@@ -2297,7 +2306,9 @@ class PersistEventsStore:
|
||||
|
||||
self._update_backward_extremeties(txn, events)
|
||||
|
||||
def _update_backward_extremeties(self, txn, events):
|
||||
def _update_backward_extremeties(
|
||||
self, txn: LoggingTransaction, events: List[EventBase]
|
||||
) -> None:
|
||||
"""Updates the event_backward_extremities tables based on the new/updated
|
||||
events being persisted.
|
||||
|
||||
|
||||
@@ -69,7 +69,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
# event_forward_extremities
|
||||
# event_json
|
||||
# event_push_actions
|
||||
# event_reference_hashes
|
||||
# event_relations
|
||||
# event_search
|
||||
# event_to_state_groups
|
||||
@@ -220,7 +219,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
"event_auth",
|
||||
"event_edges",
|
||||
"event_forward_extremities",
|
||||
"event_reference_hashes",
|
||||
"event_relations",
|
||||
"event_search",
|
||||
"rejections",
|
||||
@@ -369,7 +367,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
"event_edges",
|
||||
"event_json",
|
||||
"event_push_actions_staging",
|
||||
"event_reference_hashes",
|
||||
"event_relations",
|
||||
"event_to_state_groups",
|
||||
"event_auth_chains",
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set
|
||||
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@@ -27,7 +27,7 @@ from synapse.storage.database import (
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -149,7 +149,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
self.EVENT_SEARCH_DELETE_NON_STRINGS, self._background_delete_non_strings
|
||||
)
|
||||
|
||||
async def _background_reindex_search(self, progress, batch_size):
|
||||
async def _background_reindex_search(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
# we work through the events table from highest stream id to lowest
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
@@ -157,7 +159,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
TYPES = ["m.room.name", "m.room.message", "m.room.topic"]
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
def reindex_search_txn(txn: LoggingTransaction) -> int:
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id, room_id, type, json, "
|
||||
" origin_server_ts FROM events"
|
||||
@@ -255,12 +257,14 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
return result
|
||||
|
||||
async def _background_reindex_gin_search(self, progress, batch_size):
|
||||
async def _background_reindex_gin_search(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
"""This handles old synapses which used GIST indexes, if any;
|
||||
converting them back to be GIN as per the actual schema.
|
||||
"""
|
||||
|
||||
def create_index(conn):
|
||||
def create_index(conn: LoggingDatabaseConnection) -> None:
|
||||
conn.rollback()
|
||||
|
||||
# we have to set autocommit, because postgres refuses to
|
||||
@@ -299,7 +303,9 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
)
|
||||
return 1
|
||||
|
||||
async def _background_reindex_search_order(self, progress, batch_size):
|
||||
async def _background_reindex_search_order(
|
||||
self, progress: JsonDict, batch_size: int
|
||||
) -> int:
|
||||
target_min_stream_id = progress["target_min_stream_id_inclusive"]
|
||||
max_stream_id = progress["max_stream_id_exclusive"]
|
||||
rows_inserted = progress.get("rows_inserted", 0)
|
||||
@@ -307,7 +313,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
|
||||
if not have_added_index:
|
||||
|
||||
def create_index(conn):
|
||||
def create_index(conn: LoggingDatabaseConnection) -> None:
|
||||
conn.rollback()
|
||||
conn.set_session(autocommit=True)
|
||||
c = conn.cursor()
|
||||
@@ -336,7 +342,7 @@ class SearchBackgroundUpdateStore(SearchWorkerStore):
|
||||
pg,
|
||||
)
|
||||
|
||||
def reindex_search_txn(txn):
|
||||
def reindex_search_txn(txn: LoggingTransaction) -> Tuple[int, bool]:
|
||||
sql = (
|
||||
"UPDATE event_search AS es SET stream_ordering = e.stream_ordering,"
|
||||
" origin_server_ts = e.origin_server_ts"
|
||||
@@ -644,7 +650,8 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
else:
|
||||
raise Exception("Unrecognized database engine")
|
||||
|
||||
args.append(limit)
|
||||
# mypy expects to append only a `str`, not an `int`
|
||||
args.append(limit) # type: ignore[arg-type]
|
||||
|
||||
results = await self.db_pool.execute(
|
||||
"search_rooms", self.db_pool.cursor_to_dict, sql, *args
|
||||
@@ -705,7 +712,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
A set of strings.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
def f(txn: LoggingTransaction) -> Set[str]:
|
||||
highlight_words = set()
|
||||
for event in events:
|
||||
# As a hack we simply join values of all possible keys. This is
|
||||
@@ -759,11 +766,11 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
return await self.db_pool.runInteraction("_find_highlights", f)
|
||||
|
||||
|
||||
def _to_postgres_options(options_dict):
|
||||
def _to_postgres_options(options_dict: JsonDict) -> str:
|
||||
return "'%s'" % (",".join("%s=%s" % (k, v) for k, v in options_dict.items()),)
|
||||
|
||||
|
||||
def _parse_query(database_engine, search_term):
|
||||
def _parse_query(database_engine: BaseDatabaseEngine, search_term: str) -> str:
|
||||
"""Takes a plain unicode string from the user and converts it into a form
|
||||
that can be passed to database.
|
||||
We use this so that we can add prefix matching, which isn't something
|
||||
|
||||
@@ -785,22 +785,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return None
|
||||
|
||||
async def get_current_room_stream_token_for_room_id(
|
||||
self, room_id: Optional[str] = None
|
||||
self, room_id: str
|
||||
) -> RoomStreamToken:
|
||||
"""Returns the current position of the rooms stream.
|
||||
|
||||
By default, it returns a live token with the current global stream
|
||||
token. Specifying a `room_id` causes it to return a historic token with
|
||||
the room specific topological token.
|
||||
"""
|
||||
"""Returns the current position of the rooms stream (historic token)."""
|
||||
stream_ordering = self.get_room_max_stream_ordering()
|
||||
if room_id is None:
|
||||
return RoomStreamToken(None, stream_ordering)
|
||||
else:
|
||||
topo = await self.db_pool.runInteraction(
|
||||
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
||||
)
|
||||
return RoomStreamToken(topo, stream_ordering)
|
||||
topo = await self.db_pool.runInteraction(
|
||||
"_get_max_topological_txn", self._get_max_topological_txn, room_id
|
||||
)
|
||||
return RoomStreamToken(topo, stream_ordering)
|
||||
|
||||
def get_stream_id_for_event_txn(
|
||||
self,
|
||||
@@ -870,7 +862,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
rows = txn.fetchall()
|
||||
return rows[0][0] if rows else 0
|
||||
# An aggregate function like MAX() will always return one row per group
|
||||
# so we can safely rely on the lookup here. For example, when a we
|
||||
# lookup a `room_id` which does not exist, `rows` will look like
|
||||
# `[(None,)]`
|
||||
return rows[0][0] if rows[0][0] is not None else 0
|
||||
|
||||
@staticmethod
|
||||
def _set_before_and_after(
|
||||
|
||||
@@ -487,12 +487,6 @@ class EventsPersistenceStorage:
|
||||
# extremities in each room
|
||||
new_forward_extremities: Dict[str, Set[str]] = {}
|
||||
|
||||
# map room_id->(type,state_key)->event_id tracking the full
|
||||
# state in each room after adding these events.
|
||||
# This is simply used to prefill the get_current_state_ids
|
||||
# cache
|
||||
current_state_for_room: Dict[str, StateMap[str]] = {}
|
||||
|
||||
# map room_id->(to_delete, to_insert) where to_delete is a list
|
||||
# of type/state keys to remove from current state, and to_insert
|
||||
# is a map (type,key)->event_id giving the state delta in each
|
||||
@@ -628,14 +622,8 @@ class EventsPersistenceStorage:
|
||||
|
||||
state_delta_for_room[room_id] = delta
|
||||
|
||||
# If we have the current_state then lets prefill
|
||||
# the cache with it.
|
||||
if current_state is not None:
|
||||
current_state_for_room[room_id] = current_state
|
||||
|
||||
await self.persist_events_store._persist_events_and_state_updates(
|
||||
chunk,
|
||||
current_state_for_room=current_state_for_room,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
new_forward_extremities=new_forward_extremities,
|
||||
use_negative_stream_ordering=backfilled,
|
||||
@@ -733,7 +721,8 @@ class EventsPersistenceStorage:
|
||||
|
||||
The first state map is the full new current state and the second
|
||||
is the delta to the existing current state. If both are None then
|
||||
there has been no change.
|
||||
there has been no change. Either or neither can be None if there
|
||||
has been a change.
|
||||
|
||||
The function may prune some old entries from the set of new
|
||||
forward extremities if it's safe to do so.
|
||||
@@ -743,9 +732,6 @@ class EventsPersistenceStorage:
|
||||
the new current state is only returned if we've already calculated
|
||||
it.
|
||||
"""
|
||||
# map from state_group to ((type, key) -> event_id) state map
|
||||
state_groups_map = {}
|
||||
|
||||
# Map from (prev state group, new state group) -> delta state dict
|
||||
state_group_deltas = {}
|
||||
|
||||
@@ -759,16 +745,6 @@ class EventsPersistenceStorage:
|
||||
)
|
||||
continue
|
||||
|
||||
if ctx.state_group in state_groups_map:
|
||||
continue
|
||||
|
||||
# We're only interested in pulling out state that has already
|
||||
# been cached in the context. We'll pull stuff out of the DB later
|
||||
# if necessary.
|
||||
current_state_ids = ctx.get_cached_current_state_ids()
|
||||
if current_state_ids is not None:
|
||||
state_groups_map[ctx.state_group] = current_state_ids
|
||||
|
||||
if ctx.prev_group:
|
||||
state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
|
||||
|
||||
@@ -826,18 +802,14 @@ class EventsPersistenceStorage:
|
||||
delta_ids = state_group_deltas.get((old_state_group, new_state_group), None)
|
||||
if delta_ids is not None:
|
||||
# We have a delta from the existing to new current state,
|
||||
# so lets just return that. If we happen to already have
|
||||
# the current state in memory then lets also return that,
|
||||
# but it doesn't matter if we don't.
|
||||
new_state = state_groups_map.get(new_state_group)
|
||||
return new_state, delta_ids, new_latest_event_ids
|
||||
# so lets just return that.
|
||||
return None, delta_ids, new_latest_event_ids
|
||||
|
||||
# Now that we have calculated new_state_groups we need to get
|
||||
# their state IDs so we can resolve to a single state set.
|
||||
missing_state = new_state_groups - set(state_groups_map)
|
||||
if missing_state:
|
||||
group_to_state = await self.state_store._get_state_for_groups(missing_state)
|
||||
state_groups_map.update(group_to_state)
|
||||
state_groups_map = await self.state_store._get_state_for_groups(
|
||||
new_state_groups
|
||||
)
|
||||
|
||||
if len(new_state_groups) == 1:
|
||||
# If there is only one state group, then we know what the current
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
SCHEMA_VERSION = 69 # remember to update the list below when updating
|
||||
SCHEMA_VERSION = 70 # remember to update the list below when updating
|
||||
"""Represents the expectations made by the codebase about the database schema
|
||||
|
||||
This should be incremented whenever the codebase changes its requirements on the
|
||||
@@ -62,6 +62,9 @@ Changes in SCHEMA_VERSION = 68:
|
||||
Changes in SCHEMA_VERSION = 69:
|
||||
- We now write to `device_lists_changes_in_room` table.
|
||||
- Use sequence to generate future `application_services_txns.txn_id`s
|
||||
|
||||
Changes in SCHEMA_VERSION = 70:
|
||||
- event_reference_hashes is no longer written to.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -148,7 +148,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
prev_event.internal_metadata.outlier = True
|
||||
persistence = self.hs.get_storage().persistence
|
||||
self.get_success(
|
||||
persistence.persist_event(prev_event, EventContext.for_outlier())
|
||||
persistence.persist_event(
|
||||
prev_event, EventContext.for_outlier(self.hs.get_storage())
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
||||
|
||||
13
tests/http/server/__init__.py
Normal file
13
tests/http/server/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
100
tests/http/server/_base.py
Normal file
100
tests/http/server/_base.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright 2022 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unles4s required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Callable, Optional, Union
|
||||
from unittest import mock
|
||||
|
||||
from twisted.internet.error import ConnectionDone
|
||||
|
||||
from synapse.http.server import (
|
||||
HTTP_STATUS_REQUEST_CANCELLED,
|
||||
respond_with_html_bytes,
|
||||
respond_with_json,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel, ThreadedMemoryReactorClock
|
||||
|
||||
|
||||
class EndpointCancellationTestHelperMixin(unittest.TestCase):
|
||||
"""Provides helper methods for testing cancellation of endpoints."""
|
||||
|
||||
def _test_disconnect(
|
||||
self,
|
||||
reactor: ThreadedMemoryReactorClock,
|
||||
channel: FakeChannel,
|
||||
expect_cancellation: bool,
|
||||
expected_body: Union[bytes, JsonDict],
|
||||
expected_code: Optional[int] = None,
|
||||
) -> None:
|
||||
"""Disconnects an in-flight request and checks the response.
|
||||
|
||||
Args:
|
||||
reactor: The twisted reactor running the request handler.
|
||||
channel: The `FakeChannel` for the request.
|
||||
expect_cancellation: `True` if request processing is expected to be
|
||||
cancelled, `False` if the request should run to completion.
|
||||
expected_body: The expected response for the request.
|
||||
expected_code: The expected status code for the request. Defaults to `200`
|
||||
or `499` depending on `expect_cancellation`.
|
||||
"""
|
||||
# Determine the expected status code.
|
||||
if expected_code is None:
|
||||
if expect_cancellation:
|
||||
expected_code = HTTP_STATUS_REQUEST_CANCELLED
|
||||
else:
|
||||
expected_code = HTTPStatus.OK
|
||||
|
||||
request = channel.request
|
||||
self.assertFalse(
|
||||
channel.is_finished(),
|
||||
"Request finished before we could disconnect - "
|
||||
"was `await_result=False` passed to `make_request`?",
|
||||
)
|
||||
|
||||
# We're about to disconnect the request. This also disconnects the channel, so
|
||||
# we have to rely on mocks to extract the response.
|
||||
respond_method: Callable[..., Any]
|
||||
if isinstance(expected_body, bytes):
|
||||
respond_method = respond_with_html_bytes
|
||||
else:
|
||||
respond_method = respond_with_json
|
||||
|
||||
with mock.patch(
|
||||
f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
|
||||
) as respond_mock:
|
||||
# Disconnect the request.
|
||||
request.connectionLost(reason=ConnectionDone())
|
||||
|
||||
if expect_cancellation:
|
||||
# An immediate cancellation is expected.
|
||||
respond_mock.assert_called_once()
|
||||
args, _kwargs = respond_mock.call_args
|
||||
code, body = args[1], args[2]
|
||||
self.assertEqual(code, expected_code)
|
||||
self.assertEqual(request.code, expected_code)
|
||||
self.assertEqual(body, expected_body)
|
||||
else:
|
||||
respond_mock.assert_not_called()
|
||||
|
||||
# The handler is expected to run to completion.
|
||||
reactor.pump([1.0])
|
||||
respond_mock.assert_called_once()
|
||||
args, _kwargs = respond_mock.call_args
|
||||
code, body = args[1], args[2]
|
||||
self.assertEqual(code, expected_code)
|
||||
self.assertEqual(request.code, expected_code)
|
||||
self.assertEqual(body, expected_body)
|
||||
@@ -109,6 +109,17 @@ class FakeChannel:
|
||||
_ip: str = "127.0.0.1"
|
||||
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
|
||||
resource_usage: Optional[ContextResourceUsage] = None
|
||||
_request: Optional[Request] = None
|
||||
|
||||
@property
|
||||
def request(self) -> Request:
|
||||
assert self._request is not None
|
||||
return self._request
|
||||
|
||||
@request.setter
|
||||
def request(self, request: Request) -> None:
|
||||
assert self._request is None
|
||||
self._request = request
|
||||
|
||||
@property
|
||||
def json_body(self):
|
||||
@@ -322,6 +333,8 @@ def make_request(
|
||||
channel = FakeChannel(site, reactor, ip=client_ip)
|
||||
|
||||
req = request(channel, site)
|
||||
channel.request = req
|
||||
|
||||
req.content = BytesIO(content)
|
||||
# Twisted expects to be at the end of the content when parsing the request.
|
||||
req.content.seek(0, SEEK_END)
|
||||
|
||||
@@ -393,7 +393,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
|
||||
# We need to persist the events to the events and state_events
|
||||
# tables.
|
||||
persist_events_store._store_event_txn(
|
||||
txn, [(e, EventContext()) for e in events]
|
||||
txn, [(e, EventContext(self.hs.get_storage())) for e in events]
|
||||
)
|
||||
|
||||
# Actually call the function that calculates the auth chain stuff.
|
||||
|
||||
@@ -58,15 +58,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
|
||||
(room_id, event_id),
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
(
|
||||
"INSERT INTO event_reference_hashes "
|
||||
"(event_id, algorithm, hash) "
|
||||
"VALUES (?, 'sha256', ?)"
|
||||
),
|
||||
(event_id, bytearray(b"ffff")),
|
||||
)
|
||||
|
||||
for i in range(0, 20):
|
||||
self.get_success(
|
||||
self.store.db_pool.runInteraction("insert", insert_event, i)
|
||||
|
||||
@@ -88,6 +88,9 @@ class _DummyStore:
|
||||
|
||||
return groups
|
||||
|
||||
async def get_state_ids_for_group(self, state_group):
|
||||
return self._group_to_state[state_group]
|
||||
|
||||
async def store_state_group(
|
||||
self, event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
):
|
||||
|
||||
@@ -234,7 +234,9 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
|
||||
event = self.get_success(builder.build(prev_event_ids=[], auth_event_ids=[]))
|
||||
event.internal_metadata.outlier = True
|
||||
self.get_success(
|
||||
self.storage.persistence.persist_event(event, EventContext.for_outlier())
|
||||
self.storage.persistence.persist_event(
|
||||
event, EventContext.for_outlier(self.storage)
|
||||
)
|
||||
)
|
||||
return event
|
||||
|
||||
|
||||
Reference in New Issue
Block a user