Merge pull request #6301 from matrix-org/babolivier/msc2326
* commit 'f496d2587': Incorporate review Factor out an _AsyncEventContextImpl (#6298) Update synapse/storage/data_stores/main/schema/delta/56/event_labels.sql Add more data to the event_labels table and fix the indexes Add unstable feature flag Lint Incorporate review Lint Changelog Add integration tests for /messages Add more integration testing Add integration tests for sync Add unit tests Add index on label Implement filtering Store labels for new events Add database table for keeping track of labels on events
This commit is contained in:
1
changelog.d/6298.misc
Normal file
1
changelog.d/6298.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor EventContext for clarity.
|
||||
1
changelog.d/6301.feature
Normal file
1
changelog.d/6301.feature
Normal file
@@ -0,0 +1 @@
|
||||
Implement label-based filtering on `/sync` and `/messages` ([MSC2326](https://github.com/matrix-org/matrix-doc/pull/2326)).
|
||||
@@ -141,3 +141,10 @@ class LimitBlockingTypes(object):
|
||||
|
||||
MONTHLY_ACTIVE_USER = "monthly_active_user"
|
||||
HS_DISABLED = "hs_disabled"
|
||||
|
||||
|
||||
class EventContentFields(object):
|
||||
"""Fields found in events' content, regardless of type."""
|
||||
|
||||
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
LABELS = "org.matrix.labels"
|
||||
|
||||
@@ -20,6 +20,7 @@ from jsonschema import FormatChecker
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import RoomID, UserID
|
||||
@@ -66,6 +67,10 @@ ROOM_EVENT_FILTER_SCHEMA = {
|
||||
"contains_url": {"type": "boolean"},
|
||||
"lazy_load_members": {"type": "boolean"},
|
||||
"include_redundant_members": {"type": "boolean"},
|
||||
# Include or exclude events with the provided labels.
|
||||
# cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
|
||||
"org.matrix.not_labels": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -259,6 +264,9 @@ class Filter(object):
|
||||
|
||||
self.contains_url = self.filter_json.get("contains_url", None)
|
||||
|
||||
self.labels = self.filter_json.get("org.matrix.labels", None)
|
||||
self.not_labels = self.filter_json.get("org.matrix.not_labels", [])
|
||||
|
||||
def filters_all_types(self):
|
||||
return "*" in self.not_types
|
||||
|
||||
@@ -282,6 +290,7 @@ class Filter(object):
|
||||
room_id = None
|
||||
ev_type = "m.presence"
|
||||
contains_url = False
|
||||
labels = []
|
||||
else:
|
||||
sender = event.get("sender", None)
|
||||
if not sender:
|
||||
@@ -300,10 +309,11 @@ class Filter(object):
|
||||
content = event.get("content", {})
|
||||
# check if there is a string url field in the content for filtering purposes
|
||||
contains_url = isinstance(content.get("url"), text_type)
|
||||
labels = content.get(EventContentFields.LABELS, [])
|
||||
|
||||
return self.check_fields(room_id, sender, ev_type, contains_url)
|
||||
return self.check_fields(room_id, sender, ev_type, labels, contains_url)
|
||||
|
||||
def check_fields(self, room_id, sender, event_type, contains_url):
|
||||
def check_fields(self, room_id, sender, event_type, labels, contains_url):
|
||||
"""Checks whether the filter matches the given event fields.
|
||||
|
||||
Returns:
|
||||
@@ -313,6 +323,7 @@ class Filter(object):
|
||||
"rooms": lambda v: room_id == v,
|
||||
"senders": lambda v: sender == v,
|
||||
"types": lambda v: _matches_wildcard(event_type, v),
|
||||
"labels": lambda v: v in labels,
|
||||
}
|
||||
|
||||
for name, match_func in literal_keys.items():
|
||||
|
||||
@@ -37,9 +37,6 @@ class EventContext:
|
||||
delta_ids (dict[(str, str), str]): Delta from ``prev_group``.
|
||||
(type, state_key) -> event_id. ``None`` for an outlier.
|
||||
|
||||
prev_state_events (?): XXX: is this ever set to anything other than
|
||||
the empty list?
|
||||
|
||||
app_service: FIXME
|
||||
|
||||
_current_state_ids (dict[(str, str), str]|None):
|
||||
@@ -51,36 +48,16 @@ class EventContext:
|
||||
The current state map excluding the current event. None if outlier
|
||||
or we haven't fetched the state from DB yet.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||
been calculated. None if we haven't started calculating yet
|
||||
|
||||
_event_type (str): The type of the event the context is associated with.
|
||||
Only set when state has not been fetched yet.
|
||||
|
||||
_event_state_key (str|None): The state_key of the event the context is
|
||||
associated with. Only set when state has not been fetched yet.
|
||||
|
||||
_prev_state_id (str|None): If the event associated with the context is
|
||||
a state event, then `_prev_state_id` is the event_id of the state
|
||||
that was replaced.
|
||||
Only set when state has not been fetched yet.
|
||||
"""
|
||||
|
||||
state_group = attr.ib(default=None)
|
||||
rejected = attr.ib(default=False)
|
||||
prev_group = attr.ib(default=None)
|
||||
delta_ids = attr.ib(default=None)
|
||||
prev_state_events = attr.ib(default=attr.Factory(list))
|
||||
app_service = attr.ib(default=None)
|
||||
|
||||
_current_state_ids = attr.ib(default=None)
|
||||
_prev_state_ids = attr.ib(default=None)
|
||||
_prev_state_id = attr.ib(default=None)
|
||||
|
||||
_event_type = attr.ib(default=None)
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
_current_state_ids = attr.ib(default=None)
|
||||
|
||||
@staticmethod
|
||||
def with_state(
|
||||
@@ -90,7 +67,6 @@ class EventContext:
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
state_group=state_group,
|
||||
fetching_state_deferred=defer.succeed(None),
|
||||
prev_group=prev_group,
|
||||
delta_ids=delta_ids,
|
||||
)
|
||||
@@ -125,7 +101,6 @@ class EventContext:
|
||||
"rejected": self.rejected,
|
||||
"prev_group": self.prev_group,
|
||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||
"prev_state_events": self.prev_state_events,
|
||||
"app_service_id": self.app_service.id if self.app_service else None,
|
||||
}
|
||||
|
||||
@@ -141,7 +116,7 @@ class EventContext:
|
||||
Returns:
|
||||
EventContext
|
||||
"""
|
||||
context = EventContext(
|
||||
context = _AsyncEventContextImpl(
|
||||
# We use the state_group and prev_state_id stuff to pull the
|
||||
# current_state_ids out of the DB and construct prev_state_ids.
|
||||
prev_state_id=input["prev_state_id"],
|
||||
@@ -151,7 +126,6 @@ class EventContext:
|
||||
prev_group=input["prev_group"],
|
||||
delta_ids=_decode_state_dict(input["delta_ids"]),
|
||||
rejected=input["rejected"],
|
||||
prev_state_events=input["prev_state_events"],
|
||||
)
|
||||
|
||||
app_service_id = input["app_service_id"]
|
||||
@@ -170,14 +144,7 @@ class EventContext:
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
yield self._ensure_fetched(store)
|
||||
return self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -190,14 +157,7 @@ class EventContext:
|
||||
Maps a (type, state_key) to the event ID of the state event matching
|
||||
this tuple.
|
||||
"""
|
||||
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
yield self._ensure_fetched(store)
|
||||
return self._prev_state_ids
|
||||
|
||||
def get_cached_current_state_ids(self):
|
||||
@@ -211,6 +171,44 @@ class EventContext:
|
||||
|
||||
return self._current_state_ids
|
||||
|
||||
def _ensure_fetched(self, store):
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
@attr.s(slots=True)
|
||||
class _AsyncEventContextImpl(EventContext):
|
||||
"""
|
||||
An implementation of EventContext which fetches _current_state_ids and
|
||||
_prev_state_ids from the database on demand.
|
||||
|
||||
Attributes:
|
||||
|
||||
_fetching_state_deferred (Deferred|None): Resolves when *_state_ids have
|
||||
been calculated. None if we haven't started calculating yet
|
||||
|
||||
_event_type (str): The type of the event the context is associated with.
|
||||
|
||||
_event_state_key (str): The state_key of the event the context is
|
||||
associated with.
|
||||
|
||||
_prev_state_id (str|None): If the event associated with the context is
|
||||
a state event, then `_prev_state_id` is the event_id of the state
|
||||
that was replaced.
|
||||
"""
|
||||
|
||||
_prev_state_id = attr.ib(default=None)
|
||||
_event_type = attr.ib(default=None)
|
||||
_event_state_key = attr.ib(default=None)
|
||||
_fetching_state_deferred = attr.ib(default=None)
|
||||
|
||||
def _ensure_fetched(self, store):
|
||||
if not self._fetching_state_deferred:
|
||||
self._fetching_state_deferred = run_in_background(
|
||||
self._fill_out_state, store
|
||||
)
|
||||
|
||||
return make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fill_out_state(self, store):
|
||||
"""Called to populate the _current_state_ids and _prev_state_ids
|
||||
@@ -228,27 +226,6 @@ class EventContext:
|
||||
else:
|
||||
self._prev_state_ids = self._current_state_ids
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_state(
|
||||
self, state_group, prev_state_ids, current_state_ids, prev_group, delta_ids
|
||||
):
|
||||
"""Replace the state in the context
|
||||
"""
|
||||
|
||||
# We need to make sure we wait for any ongoing fetching of state
|
||||
# to complete so that the updated state doesn't get clobbered
|
||||
if self._fetching_state_deferred:
|
||||
yield make_deferred_yieldable(self._fetching_state_deferred)
|
||||
|
||||
self.state_group = state_group
|
||||
self._prev_state_ids = prev_state_ids
|
||||
self.prev_group = prev_group
|
||||
self._current_state_ids = current_state_ids
|
||||
self.delta_ids = delta_ids
|
||||
|
||||
# We need to ensure that that we've marked as having fetched the state
|
||||
self._fetching_state_deferred = defer.succeed(None)
|
||||
|
||||
|
||||
def _encode_state_dict(state_dict):
|
||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||
|
||||
@@ -45,6 +45,7 @@ from synapse.api.errors import (
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.logging.context import (
|
||||
make_deferred_yieldable,
|
||||
@@ -1878,14 +1879,7 @@ class FederationHandler(BaseHandler):
|
||||
if c and c.type == EventTypes.Create:
|
||||
auth_events[(c.type, c.state_key)] = c
|
||||
|
||||
try:
|
||||
yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning(
|
||||
"[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg
|
||||
)
|
||||
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
context = yield self.do_auth(origin, event, context, auth_events=auth_events)
|
||||
|
||||
if not context.rejected:
|
||||
yield self._check_for_soft_fail(event, state, backfilled)
|
||||
@@ -2054,12 +2048,12 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
Also NB that this function adds entries to it.
|
||||
Returns:
|
||||
defer.Deferred[None]
|
||||
defer.Deferred[EventContext]: updated context object
|
||||
"""
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
try:
|
||||
yield self._update_auth_events_and_context_for_auth(
|
||||
context = yield self._update_auth_events_and_context_for_auth(
|
||||
origin, event, context, auth_events
|
||||
)
|
||||
except Exception:
|
||||
@@ -2077,7 +2071,9 @@ class FederationHandler(BaseHandler):
|
||||
event_auth.check(room_version, event, auth_events=auth_events)
|
||||
except AuthError as e:
|
||||
logger.warning("Failed auth resolution for %r because %s", event, e)
|
||||
raise e
|
||||
context.rejected = RejectedReason.AUTH_ERROR
|
||||
|
||||
return context
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_auth_events_and_context_for_auth(
|
||||
@@ -2101,7 +2097,7 @@ class FederationHandler(BaseHandler):
|
||||
auth_events (dict[(str, str)->synapse.events.EventBase]):
|
||||
|
||||
Returns:
|
||||
defer.Deferred[None]
|
||||
defer.Deferred[EventContext]: updated context
|
||||
"""
|
||||
event_auth_events = set(event.auth_event_ids())
|
||||
|
||||
@@ -2140,7 +2136,7 @@ class FederationHandler(BaseHandler):
|
||||
# The other side isn't around or doesn't implement the
|
||||
# endpoint, so lets just bail out.
|
||||
logger.info("Failed to get event auth from remote: %s", e)
|
||||
return
|
||||
return context
|
||||
|
||||
seen_remotes = yield self.store.have_seen_events(
|
||||
[e.event_id for e in remote_auth_chain]
|
||||
@@ -2181,7 +2177,7 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
if event.internal_metadata.is_outlier():
|
||||
logger.info("Skipping auth_event fetch for outlier")
|
||||
return
|
||||
return context
|
||||
|
||||
# FIXME: Assumes we have and stored all the state for all the
|
||||
# prev_events
|
||||
@@ -2190,7 +2186,7 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
if not different_auth:
|
||||
return
|
||||
return context
|
||||
|
||||
logger.info(
|
||||
"auth_events refers to events which are not in our calculated auth "
|
||||
@@ -2237,10 +2233,12 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
auth_events.update(new_state)
|
||||
|
||||
yield self._update_context_for_auth_events(
|
||||
context = yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_context_for_auth_events(self, event, context, auth_events, event_key):
|
||||
"""Update the state_ids in an event context after auth event resolution,
|
||||
@@ -2249,14 +2247,16 @@ class FederationHandler(BaseHandler):
|
||||
Args:
|
||||
event (Event): The event we're handling the context for
|
||||
|
||||
context (synapse.events.snapshot.EventContext): event context
|
||||
to be updated
|
||||
context (synapse.events.snapshot.EventContext): initial event context
|
||||
|
||||
auth_events (dict[(str, str)->str]): Events to update in the event
|
||||
context.
|
||||
|
||||
event_key ((str, str)): (type, state_key) for the current event.
|
||||
this will not be included in the current_state in the context.
|
||||
|
||||
Returns:
|
||||
Deferred[EventContext]: new event context
|
||||
"""
|
||||
state_updates = {
|
||||
k: a.event_id for k, a in iteritems(auth_events) if k != event_key
|
||||
@@ -2281,7 +2281,7 @@ class FederationHandler(BaseHandler):
|
||||
current_state_ids=current_state_ids,
|
||||
)
|
||||
|
||||
yield context.update_state(
|
||||
return EventContext.with_state(
|
||||
state_group=state_group,
|
||||
current_state_ids=current_state_ids,
|
||||
prev_state_ids=prev_state_ids,
|
||||
|
||||
@@ -65,6 +65,9 @@ class VersionsRestServlet(RestServlet):
|
||||
"m.require_identity_server": False,
|
||||
# as per MSC2290
|
||||
"m.separate_add_and_bind": True,
|
||||
# Implements support for label-based filtering as described in
|
||||
# MSC2326.
|
||||
"org.matrix.label_based_filtering": True,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
@@ -29,7 +29,7 @@ from prometheus_client import Counter
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.events import EventBase # noqa: F401
|
||||
from synapse.events.snapshot import EventContext # noqa: F401
|
||||
@@ -935,6 +935,13 @@ class EventsStore(
|
||||
|
||||
self._handle_event_relations(txn, event)
|
||||
|
||||
# Store the labels for this event.
|
||||
labels = event.content.get(EventContentFields.LABELS)
|
||||
if labels:
|
||||
self.insert_labels_for_event_txn(
|
||||
txn, event.event_id, labels, event.room_id, event.depth
|
||||
)
|
||||
|
||||
# Insert into the room_memberships table.
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
@@ -1920,6 +1927,33 @@ class EventsStore(
|
||||
get_all_updated_current_state_deltas_txn,
|
||||
)
|
||||
|
||||
def insert_labels_for_event_txn(
|
||||
self, txn, event_id, labels, room_id, topological_ordering
|
||||
):
|
||||
"""Store the mapping between an event's ID and its labels, with one row per
|
||||
(event_id, label) tuple.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to execute.
|
||||
event_id (str): The event's ID.
|
||||
labels (list[str]): A list of text labels.
|
||||
room_id (str): The ID of the room the event was sent to.
|
||||
topological_ordering (int): The position of the event in the room's topology.
|
||||
"""
|
||||
return self._simple_insert_many_txn(
|
||||
txn=txn,
|
||||
table="event_labels",
|
||||
values=[
|
||||
{
|
||||
"event_id": event_id,
|
||||
"label": label,
|
||||
"room_id": room_id,
|
||||
"topological_ordering": topological_ordering,
|
||||
}
|
||||
for label in labels
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- room_id and topoligical_ordering are denormalised from the events table in order to
|
||||
-- make the index work.
|
||||
CREATE TABLE IF NOT EXISTS event_labels (
|
||||
event_id TEXT,
|
||||
label TEXT,
|
||||
room_id TEXT NOT NULL,
|
||||
topological_ordering BIGINT NOT NULL,
|
||||
PRIMARY KEY(event_id, label)
|
||||
);
|
||||
|
||||
|
||||
-- This index enables an event pagination looking for a particular label to index the
|
||||
-- event_labels table first, which is much quicker than scanning the events table and then
|
||||
-- filtering by label, if the label is rarely used relative to the size of the room.
|
||||
CREATE INDEX event_labels_room_id_label_idx ON event_labels(room_id, label, topological_ordering);
|
||||
@@ -229,6 +229,14 @@ def filter_to_clause(event_filter):
|
||||
clauses.append("contains_url = ?")
|
||||
args.append(event_filter.contains_url)
|
||||
|
||||
# We're only applying the "labels" filter on the database query, because applying the
|
||||
# "not_labels" filter via a SQL query is non-trivial. Instead, we let
|
||||
# event_filter.check_fields apply it, which is not as efficient but makes the
|
||||
# implementation simpler.
|
||||
if event_filter.labels:
|
||||
clauses.append("(%s)" % " OR ".join("label = ?" for _ in event_filter.labels))
|
||||
args.extend(event_filter.labels)
|
||||
|
||||
return " AND ".join(clauses), args
|
||||
|
||||
|
||||
@@ -864,8 +872,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
args.append(int(limit))
|
||||
|
||||
sql = (
|
||||
"SELECT event_id, topological_ordering, stream_ordering"
|
||||
"SELECT DISTINCT event_id, topological_ordering, stream_ordering"
|
||||
" FROM events"
|
||||
" LEFT JOIN event_labels USING (event_id, room_id, topological_ordering)"
|
||||
" WHERE outlier = ? AND room_id = ? AND %(bounds)s"
|
||||
" ORDER BY topological_ordering %(order)s,"
|
||||
" stream_ordering %(order)s LIMIT ?"
|
||||
|
||||
@@ -19,6 +19,7 @@ import jsonschema
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventContentFields
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import FrozenEvent
|
||||
@@ -95,6 +96,8 @@ class FilteringTestCase(unittest.TestCase):
|
||||
"types": ["m.room.message"],
|
||||
"not_rooms": ["!726s6s6q:example.com"],
|
||||
"not_senders": ["@spam:example.com"],
|
||||
"org.matrix.labels": ["#fun"],
|
||||
"org.matrix.not_labels": ["#work"],
|
||||
},
|
||||
"ephemeral": {
|
||||
"types": ["m.receipt", "m.typing"],
|
||||
@@ -320,6 +323,46 @@ class FilteringTestCase(unittest.TestCase):
|
||||
)
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_filter_labels(self):
|
||||
definition = {"org.matrix.labels": ["#fun"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#fun"]},
|
||||
)
|
||||
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#notfun"]},
|
||||
)
|
||||
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
def test_filter_not_labels(self):
|
||||
definition = {"org.matrix.not_labels": ["#fun"]}
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#fun"]},
|
||||
)
|
||||
|
||||
self.assertFalse(Filter(definition).check(event))
|
||||
|
||||
event = MockEvent(
|
||||
sender="@foo:bar",
|
||||
type="m.room.message",
|
||||
room_id="!secretbase:unknown",
|
||||
content={EventContentFields.LABELS: ["#notfun"]},
|
||||
)
|
||||
|
||||
self.assertTrue(Filter(definition).check(event))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_filter_presence_match(self):
|
||||
user_filter_json = {"presence": {"types": ["m.*"]}}
|
||||
|
||||
@@ -24,7 +24,7 @@ from six.moves.urllib import parse as urlparse
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.constants import EventContentFields, EventTypes, Membership
|
||||
from synapse.rest.client.v1 import login, profile, room
|
||||
|
||||
from tests import unittest
|
||||
@@ -811,6 +811,105 @@ class RoomMessageListTestCase(RoomBase):
|
||||
self.assertTrue("chunk" in channel.json_body)
|
||||
self.assertTrue("end" in channel.json_body)
|
||||
|
||||
def test_filter_labels(self):
|
||||
"""Test that we can filter by a label."""
|
||||
message_filter = json.dumps(
|
||||
{"types": [EventTypes.Message], "org.matrix.labels": ["#fun"]}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(message_filter)
|
||||
|
||||
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||
|
||||
def test_filter_not_labels(self):
|
||||
"""Test that we can filter by the absence of a label."""
|
||||
message_filter = json.dumps(
|
||||
{"types": [EventTypes.Message], "org.matrix.not_labels": ["#fun"]}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(message_filter)
|
||||
|
||||
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||
self.assertEqual(
|
||||
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||
)
|
||||
|
||||
def test_filter_labels_not_labels(self):
|
||||
"""Test that we can filter by both a label and the absence of another label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#work"],
|
||||
"org.matrix.not_labels": ["#notfun"],
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||
|
||||
def _test_filter_labels(self, message_filter):
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "without label"},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with wrong label",
|
||||
EventContentFields.LABELS: ["#work"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with two wrong labels",
|
||||
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||
},
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
)
|
||||
|
||||
token = "s0_0_0_0_0_0_0_0_0"
|
||||
request, channel = self.make_request(
|
||||
"GET",
|
||||
"/rooms/%s/messages?access_token=x&from=%s&filter=%s"
|
||||
% (self.room_id, token, message_filter),
|
||||
)
|
||||
self.render(request)
|
||||
|
||||
return channel.json_body["chunk"]
|
||||
|
||||
|
||||
class RoomSearchTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
|
||||
@@ -106,13 +106,22 @@ class RestHelper(object):
|
||||
self.auth_user_id = temp_id
|
||||
|
||||
def send(self, room_id, body=None, txn_id=None, tok=None, expect_code=200):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
if body is None:
|
||||
body = "body_text_here"
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/m.room.message/%s" % (room_id, txn_id)
|
||||
content = {"msgtype": "m.text", "body": body}
|
||||
|
||||
return self.send_event(
|
||||
room_id, "m.room.message", content, txn_id, tok, expect_code
|
||||
)
|
||||
|
||||
def send_event(
|
||||
self, room_id, type, content={}, txn_id=None, tok=None, expect_code=200
|
||||
):
|
||||
if txn_id is None:
|
||||
txn_id = "m%s" % (str(time.time()))
|
||||
|
||||
path = "/_matrix/client/r0/rooms/%s/send/%s/%s" % (room_id, type, txn_id)
|
||||
if tok:
|
||||
path = path + "?access_token=%s" % tok
|
||||
|
||||
|
||||
@@ -12,10 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
|
||||
from mock import Mock
|
||||
|
||||
import synapse.rest.admin
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.rest.client.v1 import login, room
|
||||
from synapse.rest.client.v2_alpha import sync
|
||||
|
||||
@@ -26,7 +28,12 @@ from tests.server import TimedOutException
|
||||
class FilterTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
user_id = "@apple:test"
|
||||
servlets = [sync.register_servlets]
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
@@ -70,6 +77,140 @@ class FilterTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
|
||||
class SyncFilterTestCase(unittest.HomeserverTestCase):
|
||||
servlets = [
|
||||
synapse.rest.admin.register_servlets_for_client_rest_resource,
|
||||
room.register_servlets,
|
||||
login.register_servlets,
|
||||
sync.register_servlets,
|
||||
]
|
||||
|
||||
def test_sync_filter_labels(self):
|
||||
"""Test that we can filter by a label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#fun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 2, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with right label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with right label", events[1])
|
||||
|
||||
def test_sync_filter_not_labels(self):
|
||||
"""Test that we can filter by the absence of a label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.not_labels": ["#fun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 3, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "without label", events[0])
|
||||
self.assertEqual(events[1]["content"]["body"], "with wrong label", events[1])
|
||||
self.assertEqual(
|
||||
events[2]["content"]["body"], "with two wrong labels", events[2]
|
||||
)
|
||||
|
||||
def test_sync_filter_labels_not_labels(self):
|
||||
"""Test that we can filter by both a label and the absence of another label."""
|
||||
sync_filter = json.dumps(
|
||||
{
|
||||
"room": {
|
||||
"timeline": {
|
||||
"types": [EventTypes.Message],
|
||||
"org.matrix.labels": ["#work"],
|
||||
"org.matrix.not_labels": ["#notfun"],
|
||||
}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
events = self._test_sync_filter_labels(sync_filter)
|
||||
|
||||
self.assertEqual(len(events), 1, [event["content"] for event in events])
|
||||
self.assertEqual(events[0]["content"]["body"], "with wrong label", events[0])
|
||||
|
||||
def _test_sync_filter_labels(self, sync_filter):
|
||||
user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
|
||||
room_id = self.helper.create_room_as(user_id, tok=tok)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={"msgtype": "m.text", "body": "without label"},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with wrong label",
|
||||
EventContentFields.LABELS: ["#work"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with two wrong labels",
|
||||
EventContentFields.LABELS: ["#work", "#notfun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
self.helper.send_event(
|
||||
room_id=room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "with right label",
|
||||
EventContentFields.LABELS: ["#fun"],
|
||||
},
|
||||
tok=tok,
|
||||
)
|
||||
|
||||
request, channel = self.make_request(
|
||||
"GET", "/sync?filter=%s" % sync_filter, access_token=tok
|
||||
)
|
||||
self.render(request)
|
||||
self.assertEqual(channel.code, 200, channel.result)
|
||||
|
||||
return channel.json_body["rooms"]["join"][room_id]["timeline"]["events"]
|
||||
|
||||
|
||||
class SyncTypingTests(unittest.HomeserverTestCase):
|
||||
|
||||
servlets = [
|
||||
|
||||
@@ -59,7 +59,9 @@ class MessageAcceptTests(unittest.TestCase):
|
||||
)
|
||||
|
||||
self.handler = self.homeserver.get_handlers().federation_handler
|
||||
self.handler.do_auth = lambda *a, **b: succeed(True)
|
||||
self.handler.do_auth = lambda origin, event, context, auth_events: succeed(
|
||||
context
|
||||
)
|
||||
self.client = self.homeserver.get_federation_client()
|
||||
self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
|
||||
pdus
|
||||
|
||||
Reference in New Issue
Block a user