1
0

Compare commits

...

13 Commits

Author SHA1 Message Date
H. Shay
4809a0cb35 merge in develop 2023-03-16 14:55:49 -07:00
H. Shay
ad325833e2 Merge branch 'develop' into shay/rework_module 2023-03-16 14:49:47 -07:00
H. Shay
d4ed0a48c1 requested changes 2023-03-16 14:31:10 -07:00
H. Shay
13676fb097 more develop merge fix 2023-03-06 12:53:52 -08:00
H. Shay
7fc487421f Merge branch 'develop' into shay/rework_module 2023-03-06 12:48:59 -08:00
H. Shay
2ab6cece75 add clearer return values 2023-03-06 12:21:27 -08:00
H. Shay
a9b0093d3a update changelog 2023-02-21 14:53:26 -08:00
H. Shay
9b702df296 newsfragment 2023-02-21 14:31:52 -08:00
H. Shay
7b610fca1a update docs with information on v2 callback 2023-02-21 14:27:53 -08:00
H. Shay
b564f29fe2 add a new test to check sending an additional event into room 2023-02-21 14:27:38 -08:00
H. Shay
5cebb3767b update tests to reflect new function signature 2023-02-21 14:27:19 -08:00
H. Shay
2e14bc3745 change callsites to reflect new function signature 2023-02-21 14:26:29 -08:00
H. Shay
aab5fb622e add check_event_allowed_v2 2023-02-21 14:25:49 -08:00
22 changed files with 373 additions and 100 deletions

1
changelog.d/15131.misc Normal file
View File

@@ -0,0 +1 @@
Add a new third party callback `check_event_allowed_v2` that is compatible with new batch persisting mechanisms.

View File

@@ -10,6 +10,75 @@ The available third party rules callbacks are:
### `check_event_allowed` ### `check_event_allowed`
_First introduced in Synapse v1.7x.x
```python
async def check_event_allowed_v2(
event: "synapse.events.EventBase",
state_events: "synapse.types.StateMap",
) -> Tuple[bool, Optional[dict], Optional[dict]]
```
**<span style="color:red">
This callback is very experimental and can and will break without notice. Module developers
are encouraged to implement `check_event_for_spam` from the spam checker category instead.
</span>**
Returns:
- A tuple consisting of:
- a boolean representing whether or not the event is allowed
- an optional dict to form the basis of a replacement event for the event
- an optional dict to form the basis of an additional event to be sent into the
room
Called when processing any incoming event, with the event and a `StateMap`
representing the current state of the room the event is being sent into. A `StateMap` is
a dictionary that maps tuples containing an event type and a state key to the
corresponding state event. For example retrieving the room's `m.room.create` event from
the `state_events` argument would look like this: `state_events.get(("m.room.create", ""))`.
The module must return a boolean indicating whether the event can be allowed.
Note that this callback function processes incoming events coming via federation
traffic (on top of client traffic). This means denying an event might cause the local
copy of the room's history to diverge from that of remote servers. This may cause
federation issues in the room. It is strongly recommended to only deny events using this
callback function if the sender is a local user, or in a private federation in which all
servers are using the same module, with the same configuration.
If the boolean returned by the module is `True`, it may tell Synapse to replace the
event with new data by returning the new event's data as a dictionary. In order to do
that, it is recommended the module calls `event.get_dict()` to get the current event as a
dictionary, and modify the returned dictionary accordingly.
Module writers may also wish to use this check to send a second event into the room along
with the event being checked, if this is the case the module writer must provide a dict that
will form the basis of the event that is to be added to the room and it must be returned by `check_event_allowed_v2`.
This dict will then be turned into an event at the appropriate time and it will be persisted after the event
that triggered it, and if the event that triggered it is in a batch of events for persisting, it will be added to the
end of that batch. Note that the event MAY NOT be a membership event.
If `check_event_allowed_v2` raises an exception, the module is assumed to have failed.
The event will not be accepted but is not treated as explicitly rejected, either.
An HTTP request causing the module check will likely result in a 500 Internal
Server Error.
When the boolean returned by the module is `False`, the event is rejected.
(Module developers should not use exceptions for rejection.)
Note that replacing the event or adding an event only works for events sent by local users, not for events
received over federation.
If multiple modules implement this callback, they will be considered in order. If a
callback returns `True`, Synapse falls through to the next one. The value of the first
callback that does not return `True` will be used. If this happens, Synapse will not call
any of the subsequent implementations of this callback. This callback cannot be used in conjunction with `check_event_allowed`,
only one of these callbacks may be operational at a time - if both `check_event_allowed` and `check_event_allowed_v2`
active only `check_event_allowed` will be executed.
### `check_event_allowed`
_First introduced in Synapse v1.39.0_ _First introduced in Synapse v1.39.0_
```python ```python

View File

@@ -32,6 +32,10 @@ logger = logging.getLogger(__name__)
CHECK_EVENT_ALLOWED_CALLBACK = Callable[ CHECK_EVENT_ALLOWED_CALLBACK = Callable[
[EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]] [EventBase, StateMap[EventBase]], Awaitable[Tuple[bool, Optional[dict]]]
] ]
CHECK_EVENT_ALLOWED_V2_CALLBACK = Callable[
[EventBase, StateMap[EventBase]],
Awaitable[Tuple[bool, Optional[dict], Optional[dict]]],
]
ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable] ON_CREATE_ROOM_CALLBACK = Callable[[Requester, dict, bool], Awaitable]
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[ CHECK_THREEPID_CAN_BE_INVITED_CALLBACK = Callable[
[str, str, StateMap[EventBase]], Awaitable[bool] [str, str, StateMap[EventBase]], Awaitable[bool]
@@ -155,6 +159,9 @@ class ThirdPartyEventRules:
self._storage_controllers = hs.get_storage_controllers() self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._check_event_allowed_v2_callbacks: List[
CHECK_EVENT_ALLOWED_V2_CALLBACK
] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
self._check_threepid_can_be_invited_callbacks: List[ self._check_threepid_can_be_invited_callbacks: List[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -184,6 +191,7 @@ class ThirdPartyEventRules:
def register_third_party_rules_callbacks( def register_third_party_rules_callbacks(
self, self,
check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None, check_event_allowed: Optional[CHECK_EVENT_ALLOWED_CALLBACK] = None,
check_event_allowed_v2: Optional[CHECK_EVENT_ALLOWED_V2_CALLBACK] = None,
on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None, on_create_room: Optional[ON_CREATE_ROOM_CALLBACK] = None,
check_threepid_can_be_invited: Optional[ check_threepid_can_be_invited: Optional[
CHECK_THREEPID_CAN_BE_INVITED_CALLBACK CHECK_THREEPID_CAN_BE_INVITED_CALLBACK
@@ -210,6 +218,9 @@ class ThirdPartyEventRules:
if check_event_allowed is not None: if check_event_allowed is not None:
self._check_event_allowed_callbacks.append(check_event_allowed) self._check_event_allowed_callbacks.append(check_event_allowed)
if check_event_allowed_v2 is not None:
self._check_event_allowed_v2_callbacks.append(check_event_allowed_v2)
if on_create_room is not None: if on_create_room is not None:
self._on_create_room_callbacks.append(on_create_room) self._on_create_room_callbacks.append(on_create_room)
@@ -256,7 +267,7 @@ class ThirdPartyEventRules:
self, self,
event: EventBase, event: EventBase,
context: UnpersistedEventContextBase, context: UnpersistedEventContextBase,
) -> Tuple[bool, Optional[dict]]: ) -> Tuple[bool, Optional[dict], Optional[dict]]:
"""Check if a provided event should be allowed in the given context. """Check if a provided event should be allowed in the given context.
The module can return: The module can return:
@@ -264,7 +275,8 @@ class ThirdPartyEventRules:
* False: the event is not allowed, and should be rejected with M_FORBIDDEN. * False: the event is not allowed, and should be rejected with M_FORBIDDEN.
If the event is allowed, the module can also return a dictionary to use as a If the event is allowed, the module can also return a dictionary to use as a
replacement for the event. replacement for the event, and/or return a dictionary to use as the basis for
another event to be sent into the room.
Args: Args:
event: The event to be checked. event: The event to be checked.
@@ -274,8 +286,11 @@ class ThirdPartyEventRules:
The result from the ThirdPartyRules module, as above. The result from the ThirdPartyRules module, as above.
""" """
# Bail out early without hitting the store if we don't have any callbacks to run. # Bail out early without hitting the store if we don't have any callbacks to run.
if len(self._check_event_allowed_callbacks) == 0: if (
return True, None len(self._check_event_allowed_callbacks) == 0
and len(self._check_event_allowed_v2_callbacks) == 0
):
return True, None, None
prev_state_ids = await context.get_prev_state_ids() prev_state_ids = await context.get_prev_state_ids()
@@ -288,35 +303,63 @@ class ThirdPartyEventRules:
# the hashes and signatures. # the hashes and signatures.
event.freeze() event.freeze()
for callback in self._check_event_allowed_callbacks: if len(self._check_event_allowed_callbacks) != 0:
try: for callback in self._check_event_allowed_callbacks:
res, replacement_data = await delay_cancellation( try:
callback(event, state_events) res, replacement_data = await delay_cancellation(
) callback(event, state_events)
except CancelledError: )
raise except CancelledError:
except SynapseError as e: raise
# FIXME: Being able to throw SynapseErrors is relied upon by except SynapseError as e:
# some modules. PR #10386 accidentally broke this ability. # FIXME: Being able to throw SynapseErrors is relied upon by
# That said, we aren't keen on exposing this implementation detail # some modules. PR #10386 accidentally broke this ability.
# to modules and we should one day have a proper way to do what # That said, we aren't keen on exposing this implementation detail
# is wanted. # to modules and we should one day have a proper way to do what
# This module callback needs a rework so that hacks such as # is wanted.
# this one are not necessary. # This module callback needs a rework so that hacks such as
raise e # this one are not necessary.
except Exception: raise e
raise ModuleFailedException( except Exception:
"Failed to run `check_event_allowed` module API callback" raise ModuleFailedException(
) "Failed to run `check_event_allowed` module API callback"
)
# Return if the event shouldn't be allowed or if the module came up with a # Return if the event shouldn't be allowed or if the module came up with a
# replacement dict for the event. # replacement dict for the event.
if res is False: if res is False:
return res, None return res, None, None
elif isinstance(replacement_data, dict): elif isinstance(replacement_data, dict):
return True, replacement_data return True, replacement_data, None
else:
for v2_callback in self._check_event_allowed_v2_callbacks:
try:
res, replacement_data, new_event = await delay_cancellation(
v2_callback(event, state_events)
)
except CancelledError:
raise
except SynapseError as e:
# FIXME: Being able to throw SynapseErrors is relied upon by
# some modules. PR #10386 accidentally broke this ability.
# That said, we aren't keen on exposing this implementation detail
# to modules and we should one day have a proper way to do what
# is wanted.
# This module callback needs a rework so that hacks such as
# this one are not necessary.
raise e
except Exception:
raise ModuleFailedException(
"Failed to run `check_event_allowed_v2` module API callback"
)
return True, None # Return if the event shouldn't be allowed, if the module came up with a
# replacement dict for the event, or if the module wants to send a new event
if res is False:
return res, None, None
else:
return True, replacement_data, new_event
return True, None, None
async def on_create_room( async def on_create_room(
self, requester: Requester, config: dict, is_requester_admin: bool self, requester: Requester, config: dict, is_requester_admin: bool

View File

@@ -1007,6 +1007,7 @@ class FederationHandler:
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event( ) = await self.event_creation_handler.create_new_client_event(
builder=builder, builder=builder,
prev_event_ids=prev_event_ids, prev_event_ids=prev_event_ids,
@@ -1198,7 +1199,7 @@ class FederationHandler:
}, },
) )
event, _ = await self.event_creation_handler.create_new_client_event( event, _, _ = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
@@ -1251,9 +1252,10 @@ class FederationHandler:
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(builder=builder) ) = await self.event_creation_handler.create_new_client_event(builder=builder)
event_allowed, _ = await self.third_party_event_rules.check_event_allowed( event_allowed, _, _ = await self.third_party_event_rules.check_event_allowed(
event, unpersisted_context event, unpersisted_context
) )
if not event_allowed: if not event_allowed:
@@ -1446,6 +1448,7 @@ class FederationHandler:
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event( ) = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
@@ -1528,6 +1531,7 @@ class FederationHandler:
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event( ) = await self.event_creation_handler.create_new_client_event(
builder=builder builder=builder
) )
@@ -1610,6 +1614,7 @@ class FederationHandler:
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await self.event_creation_handler.create_new_client_event(builder=builder) ) = await self.event_creation_handler.create_new_client_event(builder=builder)
EventValidator().validate_new(event, self.config) EventValidator().validate_new(event, self.config)

View File

@@ -404,9 +404,11 @@ class FederationEventHandler:
# for knock events, we run the third-party event rules. It's not entirely clear # for knock events, we run the third-party event rules. It's not entirely clear
# why we don't do this for other sorts of membership events. # why we don't do this for other sorts of membership events.
if event.membership == Membership.KNOCK: if event.membership == Membership.KNOCK:
event_allowed, _ = await self._third_party_event_rules.check_event_allowed( (
event, context event_allowed,
) _,
_,
) = await self._third_party_event_rules.check_event_allowed(event, context)
if not event_allowed: if not event_allowed:
logger.info("Sending of knock %s forbidden by third-party rules", event) logger.info("Sending of knock %s forbidden by third-party rules", event)
raise SynapseError( raise SynapseError(

View File

@@ -16,6 +16,7 @@
# limitations under the License. # limitations under the License.
import logging import logging
import random import random
from builtins import dict
from http import HTTPStatus from http import HTTPStatus
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
@@ -577,7 +578,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None, state_map: Optional[StateMap[str]] = None,
for_batch: bool = False, for_batch: bool = False,
current_state_group: Optional[int] = None, current_state_group: Optional[int] = None,
) -> Tuple[EventBase, UnpersistedEventContextBase]: ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
""" """
Given a dict from a client, create a new event. If bool for_batch is true, will Given a dict from a client, create a new event. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for create an event using the prev_event_ids, and will create an event context for
@@ -649,7 +650,9 @@ class EventCreationHandler:
exceeded exceeded
Returns: Returns:
Tuple of created event, Context Tuple of created event, Context, and an optional event dict to form the basis
of a new event if third_party_rules would like to send an additional event as a
consequence of this event.
""" """
await self.auth_blocking.check_auth_blocking(requester=requester) await self.auth_blocking.check_auth_blocking(requester=requester)
@@ -711,7 +714,7 @@ class EventCreationHandler:
builder.internal_metadata.historical = historical builder.internal_metadata.historical = historical
event, unpersisted_context = await self.create_new_client_event( event, unpersisted_context, new_event = await self.create_new_client_event(
builder=builder, builder=builder,
requester=requester, requester=requester,
allow_no_prev_events=allow_no_prev_events, allow_no_prev_events=allow_no_prev_events,
@@ -765,7 +768,7 @@ class EventCreationHandler:
) )
self.validator.validate_new(event, self.config) self.validator.validate_new(event, self.config)
return event, unpersisted_context return event, unpersisted_context, new_event
async def _is_exempt_from_privacy_policy( async def _is_exempt_from_privacy_policy(
self, builder: EventBuilder, requester: Requester self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1008,11 @@ class EventCreationHandler:
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, unpersisted_context = await self.create_event( (
event,
unpersisted_context,
third_party_event_dict,
) = await self.create_event(
requester, requester,
event_dict, event_dict,
txn_id=txn_id, txn_id=txn_id,
@@ -1054,9 +1061,24 @@ class EventCreationHandler:
Codes.FORBIDDEN, Codes.FORBIDDEN,
) )
events_and_context = [(event, context)]
if third_party_event_dict:
(
third_party_event,
unpersisted_third_party_context,
_,
) = await self.create_event(
requester,
third_party_event_dict,
)
third_party_context = await unpersisted_third_party_context.persist(
third_party_event
)
events_and_context.append((third_party_event, third_party_context))
ev = await self.handle_new_client_event( ev = await self.handle_new_client_event(
requester=requester, requester=requester,
events_and_context=[(event, context)], events_and_context=events_and_context,
ratelimit=ratelimit, ratelimit=ratelimit,
ignore_shadow_ban=ignore_shadow_ban, ignore_shadow_ban=ignore_shadow_ban,
) )
@@ -1086,7 +1108,7 @@ class EventCreationHandler:
state_map: Optional[StateMap[str]] = None, state_map: Optional[StateMap[str]] = None,
for_batch: bool = False, for_batch: bool = False,
current_state_group: Optional[int] = None, current_state_group: Optional[int] = None,
) -> Tuple[EventBase, UnpersistedEventContextBase]: ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event for a local client. If bool for_batch is true, will """Create a new event for a local client. If bool for_batch is true, will
create an event using the prev_event_ids, and will create an event context for create an event using the prev_event_ids, and will create an event context for
the event using the parameters state_map and current_state_group, thus these parameters the event using the parameters state_map and current_state_group, thus these parameters
@@ -1135,7 +1157,9 @@ class EventCreationHandler:
batch persisting batch persisting
Returns: Returns:
Tuple of created event, UnpersistedEventContext Tuple of created event, UnpersistedEventContext, and an optional event dict
to form the basis of a new event if third_party_rules would like to send an
additional event as a consequence of this event.
""" """
# Strip down the state_event_ids to only what we need to auth the event. # Strip down the state_event_ids to only what we need to auth the event.
# For example, we don't need extra m.room.member that don't match event.sender # For example, we don't need extra m.room.member that don't match event.sender
@@ -1269,9 +1293,11 @@ class EventCreationHandler:
if requester: if requester:
context.app_service = requester.app_service context.app_service = requester.app_service
res, new_content = await self.third_party_event_rules.check_event_allowed( (
event, context res,
) new_content,
new_event,
) = await self.third_party_event_rules.check_event_allowed(event, context)
if res is False: if res is False:
logger.info( logger.info(
"Event %s forbidden by third-party rules", "Event %s forbidden by third-party rules",
@@ -1291,7 +1317,7 @@ class EventCreationHandler:
await self._validate_event_relation(event) await self._validate_event_relation(event)
logger.debug("Created event %s", event.event_id) logger.debug("Created event %s", event.event_id)
return event, context return event, context, new_event
async def _validate_event_relation(self, event: EventBase) -> None: async def _validate_event_relation(self, event: EventBase) -> None:
""" """
@@ -2046,7 +2072,7 @@ class EventCreationHandler:
max_retries = 5 max_retries = 5
for i in range(max_retries): for i in range(max_retries):
try: try:
event, unpersisted_context = await self.create_event( event, unpersisted_context, _ = await self.create_event(
requester, requester,
{ {
"type": EventTypes.Dummy, "type": EventTypes.Dummy,

View File

@@ -213,6 +213,7 @@ class RoomCreationHandler:
( (
tombstone_event, tombstone_event,
tombstone_unpersisted_context, tombstone_unpersisted_context,
_,
) = await self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
{ {
@@ -1066,7 +1067,11 @@ class RoomCreationHandler:
content: JsonDict, content: JsonDict,
for_batch: bool, for_batch: bool,
**kwargs: Any, **kwargs: Any,
) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]: ) -> Tuple[
EventBase,
synapse.events.snapshot.UnpersistedEventContextBase,
Optional[dict],
]:
""" """
Creates an event and associated event context. Creates an event and associated event context.
Args: Args:
@@ -1088,6 +1093,7 @@ class RoomCreationHandler:
( (
new_event, new_event,
new_unpersisted_context, new_unpersisted_context,
third_party_event,
) = await self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
creator, creator,
event_dict, event_dict,
@@ -1103,7 +1109,7 @@ class RoomCreationHandler:
prev_event = [new_event.event_id] prev_event = [new_event.event_id]
state_map[(new_event.type, new_event.state_key)] = new_event.event_id state_map[(new_event.type, new_event.state_key)] = new_event.event_id
return new_event, new_unpersisted_context return new_event, new_unpersisted_context, third_party_event
visibility = room_config.get("visibility", "private") visibility = room_config.get("visibility", "private")
preset_config = room_config.get( preset_config = room_config.get(
@@ -1121,7 +1127,7 @@ class RoomCreationHandler:
) )
creation_content.update({"creator": creator_id}) creation_content.update({"creator": creator_id})
creation_event, unpersisted_creation_context = await create_event( creation_event, unpersisted_creation_context, _ = await create_event(
EventTypes.Create, creation_content, False EventTypes.Create, creation_content, False
) )
creation_context = await unpersisted_creation_context.persist(creation_event) creation_context = await unpersisted_creation_context.persist(creation_event)
@@ -1161,14 +1167,17 @@ class RoomCreationHandler:
current_state_group = event_to_state[member_event_id] current_state_group = event_to_state[member_event_id]
events_to_send = [] events_to_send = []
third_party_events_to_append = []
# We treat the power levels override specially as this needs to be one # We treat the power levels override specially as this needs to be one
# of the first events that get sent into a room. # of the first events that get sent into a room.
pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None) pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
if pl_content is not None: if pl_content is not None:
power_event, power_context = await create_event( power_event, power_context, power_tp_event = await create_event(
EventTypes.PowerLevels, pl_content, True EventTypes.PowerLevels, pl_content, True
) )
events_to_send.append((power_event, power_context)) events_to_send.append((power_event, power_context))
if power_tp_event:
third_party_events_to_append.append(power_tp_event)
else: else:
power_level_content: JsonDict = { power_level_content: JsonDict = {
"users": {creator_id: 100}, "users": {creator_id: 100},
@@ -1211,76 +1220,114 @@ class RoomCreationHandler:
# apply those. # apply those.
if power_level_content_override: if power_level_content_override:
power_level_content.update(power_level_content_override) power_level_content.update(power_level_content_override)
pl_event, pl_context = await create_event( pl_event, pl_context, pl_tp_event = await create_event(
EventTypes.PowerLevels, EventTypes.PowerLevels,
power_level_content, power_level_content,
True, True,
) )
events_to_send.append((pl_event, pl_context)) events_to_send.append((pl_event, pl_context))
if pl_tp_event:
third_party_events_to_append.append(pl_tp_event)
if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state: if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
room_alias_event, room_alias_context = await create_event( room_alias_event, room_alias_context, ra_tp_event = await create_event(
EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
) )
events_to_send.append((room_alias_event, room_alias_context)) events_to_send.append((room_alias_event, room_alias_context))
if ra_tp_event:
third_party_events_to_append.append(ra_tp_event)
if (EventTypes.JoinRules, "") not in initial_state: if (EventTypes.JoinRules, "") not in initial_state:
join_rules_event, join_rules_context = await create_event( join_rules_event, join_rules_context, jr_tp_event = await create_event(
EventTypes.JoinRules, EventTypes.JoinRules,
{"join_rule": config["join_rules"]}, {"join_rule": config["join_rules"]},
True, True,
) )
events_to_send.append((join_rules_event, join_rules_context)) events_to_send.append((join_rules_event, join_rules_context))
if jr_tp_event:
third_party_events_to_append.append(jr_tp_event)
if (EventTypes.RoomHistoryVisibility, "") not in initial_state: if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
visibility_event, visibility_context = await create_event( visibility_event, visibility_context, vis_tp_event = await create_event(
EventTypes.RoomHistoryVisibility, EventTypes.RoomHistoryVisibility,
{"history_visibility": config["history_visibility"]}, {"history_visibility": config["history_visibility"]},
True, True,
) )
events_to_send.append((visibility_event, visibility_context)) events_to_send.append((visibility_event, visibility_context))
if vis_tp_event:
third_party_events_to_append.append(vis_tp_event)
if config["guest_can_join"]: if config["guest_can_join"]:
if (EventTypes.GuestAccess, "") not in initial_state: if (EventTypes.GuestAccess, "") not in initial_state:
guest_access_event, guest_access_context = await create_event( (
guest_access_event,
guest_access_context,
ga_tp_event,
) = await create_event(
EventTypes.GuestAccess, EventTypes.GuestAccess,
{EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN}, {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
True, True,
) )
events_to_send.append((guest_access_event, guest_access_context)) events_to_send.append((guest_access_event, guest_access_context))
if ga_tp_event:
third_party_events_to_append.append(ga_tp_event)
for (etype, state_key), content in initial_state.items(): for (etype, state_key), content in initial_state.items():
event, context = await create_event( event, context, tp_event = await create_event(
etype, content, True, state_key=state_key etype, content, True, state_key=state_key
) )
events_to_send.append((event, context)) events_to_send.append((event, context))
if tp_event:
third_party_events_to_append.append(tp_event)
if config["encrypted"]: if config["encrypted"]:
encryption_event, encryption_context = await create_event( encryption_event, encryption_context, encrypt_tp_event = await create_event(
EventTypes.RoomEncryption, EventTypes.RoomEncryption,
{"algorithm": RoomEncryptionAlgorithms.DEFAULT}, {"algorithm": RoomEncryptionAlgorithms.DEFAULT},
True, True,
state_key="", state_key="",
) )
events_to_send.append((encryption_event, encryption_context)) events_to_send.append((encryption_event, encryption_context))
if encrypt_tp_event:
third_party_events_to_append.append(encrypt_tp_event)
if "name" in room_config: if "name" in room_config:
name = room_config["name"] name = room_config["name"]
name_event, name_context = await create_event( name_event, name_context, name_tp_event = await create_event(
EventTypes.Name, EventTypes.Name,
{"name": name}, {"name": name},
True, True,
) )
events_to_send.append((name_event, name_context)) events_to_send.append((name_event, name_context))
if name_tp_event:
third_party_events_to_append.append(name_tp_event)
if "topic" in room_config: if "topic" in room_config:
topic = room_config["topic"] topic = room_config["topic"]
topic_event, topic_context = await create_event( topic_event, topic_context, topic_tp_event = await create_event(
EventTypes.Topic, EventTypes.Topic,
{"topic": topic}, {"topic": topic},
True, True,
) )
events_to_send.append((topic_event, topic_context)) events_to_send.append((topic_event, topic_context))
if topic_tp_event:
third_party_events_to_append.append(topic_tp_event)
for event_dict in third_party_events_to_append:
(
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
creator,
event_dict,
prev_event_ids=prev_event,
state_map=state_map,
for_batch=True,
current_state_group=current_state_group,
)
context = await unpersisted_context.persist(event)
events_to_send.append((event, context))
datastore = self.hs.get_datastores().state datastore = self.hs.get_datastores().state
events_and_context = ( events_and_context = (

View File

@@ -327,7 +327,11 @@ class RoomBatchHandler:
# Mark all events as historical # Mark all events as historical
event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
event, unpersisted_context = await self.event_creation_handler.create_event( (
event,
unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
await self.create_requester_for_user_id_from_app_service( await self.create_requester_for_user_id_from_app_service(
ev["sender"], app_service_requester.app_service ev["sender"], app_service_requester.app_service
), ),

View File

@@ -418,6 +418,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
( (
event, event,
unpersisted_context, unpersisted_context,
third_party_event,
) = await self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
{ {
@@ -472,6 +473,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
ratelimit=ratelimit, ratelimit=ratelimit,
) )
) )
if third_party_event:
(
tp_event,
tp_unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
requester,
third_party_event,
prev_event_ids=[result_event.event_id],
)
tp_context = await tp_unpersisted_context.persist(tp_event)
await self.event_creation_handler.handle_new_client_event(
requester, events_and_context=[(tp_event, tp_context)]
)
if event.membership == Membership.LEAVE: if event.membership == Membership.LEAVE:
if prev_member_event_id: if prev_member_event_id:
@@ -1951,6 +1966,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
( (
event, event,
unpersisted_context, unpersisted_context,
third_party_event_dict,
) = await self.event_creation_handler.create_event( ) = await self.event_creation_handler.create_event(
requester, requester,
event_dict, event_dict,
@@ -1962,10 +1978,24 @@ class RoomMemberMasterHandler(RoomMemberHandler):
context = await unpersisted_context.persist(event) context = await unpersisted_context.persist(event)
event.internal_metadata.out_of_band_membership = True event.internal_metadata.out_of_band_membership = True
events_and_context = [(event, context)]
if third_party_event_dict:
(
third_party_event,
third_party_unpersisted_context,
_,
) = await self.event_creation_handler.create_event(
requester, third_party_event_dict
)
third_party_context = await third_party_unpersisted_context.persist(
event
)
events_and_context.append((third_party_event, third_party_context))
result_event = ( result_event = (
await self.event_creation_handler.handle_new_client_event( await self.event_creation_handler.handle_new_client_event(
requester, requester,
events_and_context=[(event, context)], events_and_context=events_and_context,
extra_users=[UserID.from_string(target_user)], extra_users=[UserID.from_string(target_user)],
) )
) )

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import logging import logging
from typing import Tuple from typing import Optional, Tuple
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
@@ -81,7 +81,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
def _create_duplicate_event( def _create_duplicate_event(
self, txn_id: str self, txn_id: str
) -> Tuple[EventBase, UnpersistedEventContextBase]: ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
"""Create a new event with the given transaction ID. All events produced """Create a new event with the given transaction ID. All events produced
by this method will be considered duplicates. by this method will be considered duplicates.
""" """
@@ -109,7 +109,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_suitably_random" txn_id = "something_suitably_random"
event1, unpersisted_context = self._create_duplicate_event(txn_id) event1, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event1)) context = self.get_success(unpersisted_context.persist(event1))
ret_event1 = self.get_success( ret_event1 = self.get_success(
@@ -122,7 +122,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
self.assertEqual(event1.event_id, ret_event1.event_id) self.assertEqual(event1.event_id, ret_event1.event_id)
event2, unpersisted_context = self._create_duplicate_event(txn_id) event2, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event2)) context = self.get_success(unpersisted_context.persist(event2))
# We want to test that the deduplication at the persit event end works, # We want to test that the deduplication at the persit event end works,
@@ -144,7 +144,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_event` directly also does the right # Let's test that calling `persist_event` directly also does the right
# thing. # thing.
event3, unpersisted_context = self._create_duplicate_event(txn_id) event3, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event3)) context = self.get_success(unpersisted_context.persist(event3))
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
@@ -160,8 +160,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
# Let's test that calling `persist_events` directly also does the right # Let's test that calling `persist_events` directly also does the right
# thing. # thing.
event4, unpersisted_context = self._create_duplicate_event(txn_id) event4, unpersisted_context, _ = self._create_duplicate_event(txn_id)
context = self.get_success(unpersisted_context.persist(event4)) context = self.get_success(unpersisted_context.persist(event4))
self.assertNotEqual(event1.event_id, event3.event_id) self.assertNotEqual(event1.event_id, event3.event_id)
events, _ = self.get_success( events, _ = self.get_success(
@@ -181,9 +182,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
txn_id = "something_else_suitably_random" txn_id = "something_else_suitably_random"
# Create two duplicate events to persist at the same time # Create two duplicate events to persist at the same time
event1, unpersisted_context1 = self._create_duplicate_event(txn_id) event1, unpersisted_context1, _ = self._create_duplicate_event(txn_id)
context1 = self.get_success(unpersisted_context1.persist(event1)) context1 = self.get_success(unpersisted_context1.persist(event1))
event2, unpersisted_context2 = self._create_duplicate_event(txn_id) event2, unpersisted_context2, _ = self._create_duplicate_event(txn_id)
context2 = self.get_success(unpersisted_context2.persist(event2)) context2 = self.get_success(unpersisted_context2.persist(event2))
# Ensure their event IDs are different to start with # Ensure their event IDs are different to start with
@@ -209,7 +210,7 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
memberEvent, _ = self._create_and_persist_member_event() memberEvent, _ = self._create_and_persist_member_event()
# Try to create the event with empty prev_events bit with some auth_events # Try to create the event with empty prev_events bit with some auth_events
event, _ = self.get_success( event, _, _ = self.get_success(
self.handler.create_event( self.handler.create_event(
self.requester, self.requester,
{ {

View File

@@ -507,7 +507,8 @@ class RegistrationTestCase(unittest.HomeserverTestCase):
# Lower the permissions of the inviter. # Lower the permissions of the inviter.
event_creation_handler = self.hs.get_event_creation_handler() event_creation_handler = self.hs.get_event_creation_handler()
requester = create_requester(inviter) requester = create_requester(inviter)
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_event( event_creation_handler.create_event(
requester, requester,
{ {

View File

@@ -965,7 +965,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )

View File

@@ -130,7 +130,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
# Create a new message event, and try to evaluate it under the dodgy # Create a new message event, and try to evaluate it under the dodgy
# power level event. # power level event.
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@@ -171,7 +171,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
"""Ensure that push rules are not calculated when disabled in the config""" """Ensure that push rules are not calculated when disabled in the config"""
# Create a new message event which should cause a notification. # Create a new message event which should cause a notification.
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@@ -202,7 +202,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
) -> bool: ) -> bool:
"""Returns true iff the `mentions` trigger an event push action.""" """Returns true iff the `mentions` trigger an event push action."""
# Create a new message event which should cause a notification. # Create a new message event which should cause a notification.
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {
@@ -378,7 +378,7 @@ class TestBulkPushRuleEvaluator(HomeserverTestCase):
bulk_evaluator = BulkPushRuleEvaluator(self.hs) bulk_evaluator = BulkPushRuleEvaluator(self.hs)
# Create & persist an event to use as the parent of the relation. # Create & persist an event to use as the parent of the relation.
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_event( self.event_creation_handler.create_event(
self.requester, self.requester,
{ {

View File

@@ -2935,7 +2935,7 @@ class UserMembershipRestTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
event_creation_handler.create_new_client_event(builder) event_creation_handler.create_new_client_event(builder)
) )

View File

@@ -275,6 +275,46 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
ev = channel.json_body ev = channel.json_body
self.assertEqual(ev["content"]["x"], "y") self.assertEqual(ev["content"]["x"], "y")
def test_add_event(self) -> None:
# needs checking of combo of return conditions, ie replace event and send event
async def check(
ev: EventBase, state: StateMap[EventBase]
) -> Tuple[bool, Optional[JsonDict], Optional[dict]]:
event_dict = {
"type": "m.room.test",
"room_id": self.room_id,
"sender": self.user_id,
"content": {
"creator": "test_user",
"body": "message",
"msgtype": "message",
},
}
if ev.type == "message":
return True, None, event_dict
else:
return True, None, None
self.hs.get_third_party_event_rules()._check_event_allowed_v2_callbacks = [
check
]
channel = self.make_request(
"PUT",
"/_matrix/client/r0/rooms/%s/send/message/1" % self.room_id,
{"x": "x"},
access_token=self.tok,
)
self.assertEqual(channel.code, 200, channel.result)
events = self.get_success(
self.hs.get_datastores().main.get_forward_extremities_for_room(self.room_id)
)
event = events[1]
e = self.get_success(self.hs.get_datastores().main.get_event(event["event_id"]))
self.assertEqual("m.room.test", e.type)
def test_message_edit(self) -> None: def test_message_edit(self) -> None:
"""Ensure that the module doesn't cause issues with edited messages.""" """Ensure that the module doesn't cause issues with edited messages."""

View File

@@ -522,7 +522,8 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
latest_event_ids = self.get_success( latest_event_ids = self.get_success(
self.store.get_prev_events_for_room(room_id) self.store.get_prev_events_for_room(room_id)
) )
event, unpersisted_context = self.get_success(
event, unpersisted_context, _ = self.get_success(
event_handler.create_event( event_handler.create_event(
self.requester, self.requester,
{ {
@@ -545,7 +546,7 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
assert state_ids1 is not None assert state_ids1 is not None
state1 = set(state_ids1.values()) state1 = set(state_ids1.values())
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
event_handler.create_event( event_handler.create_event(
self.requester, self.requester,
{ {

View File

@@ -74,7 +74,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
@@ -98,7 +98,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
@@ -123,7 +123,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
@@ -265,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
def internal_metadata(self) -> _EventInternalMetadata: def internal_metadata(self) -> _EventInternalMetadata:
return self._base_builder.internal_metadata return self._base_builder.internal_metadata
event_1, unpersisted_context_1 = self.get_success( event_1, unpersisted_context_1, _ = self.get_success(
self.event_creation_handler.create_new_client_event( self.event_creation_handler.create_new_client_event(
cast( cast(
EventBuilder, EventBuilder,
@@ -290,7 +290,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
self.get_success(self._persistence.persist_event(event_1, context_1)) self.get_success(self._persistence.persist_event(event_1, context_1))
event_2, unpersisted_context_2 = self.get_success( event_2, unpersisted_context_2, _ = self.get_success(
self.event_creation_handler.create_new_client_event( self.event_creation_handler.create_new_client_event(
cast( cast(
EventBuilder, EventBuilder,
@@ -431,7 +431,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
}, },
) )
redaction_event, unpersisted_context = self.get_success( redaction_event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )

View File

@@ -67,7 +67,7 @@ class StateStoreTestCase(HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
@@ -521,7 +521,7 @@ class StateStoreTestCase(HomeserverTestCase):
}, },
) )
event1, unpersisted_context1 = self.get_success( event1, unpersisted_context1, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
events_and_context.append((event1, unpersisted_context1)) events_and_context.append((event1, unpersisted_context1))
@@ -537,7 +537,7 @@ class StateStoreTestCase(HomeserverTestCase):
}, },
) )
event2, unpersisted_context2 = self.get_success( event2, unpersisted_context2, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder2) self.event_creation_handler.create_new_client_event(builder2)
) )
events_and_context.append((event2, unpersisted_context2)) events_and_context.append((event2, unpersisted_context2))
@@ -552,7 +552,7 @@ class StateStoreTestCase(HomeserverTestCase):
}, },
) )
event3, unpersisted_context3 = self.get_success( event3, unpersisted_context3, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder3) self.event_creation_handler.create_new_client_event(builder3)
) )
events_and_context.append((event3, unpersisted_context3)) events_and_context.append((event3, unpersisted_context3))
@@ -568,7 +568,7 @@ class StateStoreTestCase(HomeserverTestCase):
}, },
) )
event4, unpersisted_context4 = self.get_success( event4, unpersisted_context4, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder4) self.event_creation_handler.create_new_client_event(builder4)
) )
events_and_context.append((event4, unpersisted_context4)) events_and_context.append((event4, unpersisted_context4))

View File

@@ -95,6 +95,7 @@ async def create_event(
( (
event, event,
unpersisted_context, unpersisted_context,
_,
) = await hs.get_event_creation_handler().create_new_client_event( ) = await hs.get_event_creation_handler().create_new_client_event(
builder, prev_event_ids=prev_event_ids builder, prev_event_ids=prev_event_ids
) )

View File

@@ -207,7 +207,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
@@ -233,7 +233,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))
@@ -256,7 +256,7 @@ class FilterEventsForServerTestCase(unittest.HomeserverTestCase):
}, },
) )
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
self.event_creation_handler.create_new_client_event(builder) self.event_creation_handler.create_new_client_event(builder)
) )
context = self.get_success(unpersisted_context.persist(event)) context = self.get_success(unpersisted_context.persist(event))

View File

@@ -723,7 +723,7 @@ class HomeserverTestCase(TestCase):
event_creator = self.hs.get_event_creation_handler() event_creator = self.hs.get_event_creation_handler()
requester = create_requester(user) requester = create_requester(user)
event, unpersisted_context = self.get_success( event, unpersisted_context, _ = self.get_success(
event_creator.create_event( event_creator.create_event(
requester, requester,
{ {

View File

@@ -335,9 +335,11 @@ async def create_room(hs: HomeServer, room_id: str, creator_id: str) -> None:
}, },
) )
event, unpersisted_context = await event_creation_handler.create_new_client_event( (
builder event,
) unpersisted_context,
_,
) = await event_creation_handler.create_new_client_event(builder)
context = await unpersisted_context.persist(event) context = await unpersisted_context.persist(event)
await persistence_store.persist_event(event, context) await persistence_store.persist_event(event, context)