Add ephemeral messages support (MSC2228) (#6409)
* commit '54dd5dc12': Add ephemeral messages support (MSC2228) (#6409)
This commit is contained in:
1
changelog.d/6409.feature
Normal file
1
changelog.d/6409.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add ephemeral messages support by partially implementing [MSC2228](https://github.com/matrix-org/matrix-doc/pull/2228).
|
||||
@@ -148,3 +148,7 @@ class EventContentFields(object):
|
||||
|
||||
# Labels for the event, cf https://github.com/matrix-org/matrix-doc/pull/2326
|
||||
LABELS = "org.matrix.labels"
|
||||
|
||||
# Timestamp to delete the event after
|
||||
# cf https://github.com/matrix-org/matrix-doc/pull/2228
|
||||
SELF_DESTRUCT_AFTER = "org.matrix.self_destruct_after"
|
||||
|
||||
@@ -502,6 +502,8 @@ class ServerConfig(Config):
|
||||
"cleanup_extremities_with_dummy_events", True
|
||||
)
|
||||
|
||||
self.enable_ephemeral_messages = config.get("enable_ephemeral_messages", False)
|
||||
|
||||
def has_tls_listener(self) -> bool:
|
||||
return any(l["tls"] for l in self.listeners)
|
||||
|
||||
|
||||
@@ -121,6 +121,7 @@ class FederationHandler(BaseHandler):
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self._message_handler = hs.get_message_handler()
|
||||
self._server_notices_mxid = hs.config.server_notices_mxid
|
||||
self.config = hs.config
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
@@ -141,6 +142,8 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
self.third_party_event_rules = hs.get_third_party_event_rules()
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
|
||||
""" Process a PDU received via a federation /send/ transaction, or
|
||||
@@ -2722,6 +2725,11 @@ class FederationHandler(BaseHandler):
|
||||
event_and_contexts, backfilled=backfilled
|
||||
)
|
||||
|
||||
if self._ephemeral_messages_enabled:
|
||||
for (event, context) in event_and_contexts:
|
||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
self._message_handler.maybe_schedule_expiry(event)
|
||||
|
||||
if not backfilled: # Never notify for backfilled events
|
||||
for event, _ in event_and_contexts:
|
||||
yield self._notify_persisted_event(event, max_stream_id)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from six import iteritems, itervalues, string_types
|
||||
|
||||
@@ -22,9 +23,16 @@ from canonicaljson import encode_canonical_json, json
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import succeed
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes, Membership, RelationTypes, UserTypes
|
||||
from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
Membership,
|
||||
RelationTypes,
|
||||
UserTypes,
|
||||
)
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
@@ -62,6 +70,17 @@ class MessageHandler(object):
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
|
||||
self._is_worker_app = bool(hs.config.worker_app)
|
||||
|
||||
# The scheduled call to self._expire_event. None if no call is currently
|
||||
# scheduled.
|
||||
self._scheduled_expiry = None # type: Optional[IDelayedCall]
|
||||
|
||||
if not hs.config.worker_app:
|
||||
run_as_background_process(
|
||||
"_schedule_next_expiry", self._schedule_next_expiry
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(
|
||||
@@ -225,6 +244,100 @@ class MessageHandler(object):
|
||||
for user_id, profile in iteritems(users_with_profile)
|
||||
}
|
||||
|
||||
def maybe_schedule_expiry(self, event):
|
||||
"""Schedule the expiry of an event if there's not already one scheduled,
|
||||
or if the one running is for an event that will expire after the provided
|
||||
timestamp.
|
||||
|
||||
This function needs to invalidate the event cache, which is only possible on
|
||||
the master process, and therefore needs to be run on there.
|
||||
|
||||
Args:
|
||||
event (EventBase): The event to schedule the expiry of.
|
||||
"""
|
||||
assert not self._is_worker_app
|
||||
|
||||
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
||||
if not isinstance(expiry_ts, int) or event.is_state():
|
||||
return
|
||||
|
||||
# _schedule_expiry_for_event won't actually schedule anything if there's already
|
||||
# a task scheduled for a timestamp that's sooner than the provided one.
|
||||
self._schedule_expiry_for_event(event.event_id, expiry_ts)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _schedule_next_expiry(self):
|
||||
"""Retrieve the ID and the expiry timestamp of the next event to be expired,
|
||||
and schedule an expiry task for it.
|
||||
|
||||
If there's no event left to expire, set _expiry_scheduled to None so that a
|
||||
future call to save_expiry_ts can schedule a new expiry task.
|
||||
"""
|
||||
# Try to get the expiry timestamp of the next event to expire.
|
||||
res = yield self.store.get_next_event_to_expire()
|
||||
if res:
|
||||
event_id, expiry_ts = res
|
||||
self._schedule_expiry_for_event(event_id, expiry_ts)
|
||||
|
||||
def _schedule_expiry_for_event(self, event_id, expiry_ts):
|
||||
"""Schedule an expiry task for the provided event if there's not already one
|
||||
scheduled at a timestamp that's sooner than the provided one.
|
||||
|
||||
Args:
|
||||
event_id (str): The ID of the event to expire.
|
||||
expiry_ts (int): The timestamp at which to expire the event.
|
||||
"""
|
||||
if self._scheduled_expiry:
|
||||
# If the provided timestamp refers to a time before the scheduled time of the
|
||||
# next expiry task, cancel that task and reschedule it for this timestamp.
|
||||
next_scheduled_expiry_ts = self._scheduled_expiry.getTime() * 1000
|
||||
if expiry_ts < next_scheduled_expiry_ts:
|
||||
self._scheduled_expiry.cancel()
|
||||
else:
|
||||
return
|
||||
|
||||
# Figure out how many seconds we need to wait before expiring the event.
|
||||
now_ms = self.clock.time_msec()
|
||||
delay = (expiry_ts - now_ms) / 1000
|
||||
|
||||
# callLater doesn't support negative delays, so trim the delay to 0 if we're
|
||||
# in that case.
|
||||
if delay < 0:
|
||||
delay = 0
|
||||
|
||||
logger.info("Scheduling expiry for event %s in %.3fs", event_id, delay)
|
||||
|
||||
self._scheduled_expiry = self.clock.call_later(
|
||||
delay,
|
||||
run_as_background_process,
|
||||
"_expire_event",
|
||||
self._expire_event,
|
||||
event_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _expire_event(self, event_id):
|
||||
"""Retrieve and expire an event that needs to be expired from the database.
|
||||
|
||||
If the event doesn't exist in the database, log it and delete the expiry date
|
||||
from the database (so that we don't try to expire it again).
|
||||
"""
|
||||
assert self._ephemeral_events_enabled
|
||||
|
||||
self._scheduled_expiry = None
|
||||
|
||||
logger.info("Expiring event %s", event_id)
|
||||
|
||||
try:
|
||||
# Expire the event if we know about it. This function also deletes the expiry
|
||||
# date from the database in the same database transaction.
|
||||
yield self.store.expire_event(event_id)
|
||||
except Exception as e:
|
||||
logger.error("Could not expire event %s: %r", event_id, e)
|
||||
|
||||
# Schedule the expiry of the next event to expire.
|
||||
yield self._schedule_next_expiry()
|
||||
|
||||
|
||||
# The duration (in ms) after which rooms should be removed
|
||||
# `_rooms_to_exclude_from_dummy_event_insertion` (with the effect that we will try
|
||||
@@ -295,6 +408,10 @@ class EventCreationHandler(object):
|
||||
5 * 60 * 1000,
|
||||
)
|
||||
|
||||
self._message_handler = hs.get_message_handler()
|
||||
|
||||
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_event(
|
||||
self,
|
||||
@@ -877,6 +994,10 @@ class EventCreationHandler(object):
|
||||
event, context=context
|
||||
)
|
||||
|
||||
if self._ephemeral_events_enabled:
|
||||
# If there's an expiry timestamp on the event, schedule its expiry.
|
||||
self._message_handler.maybe_schedule_expiry(event)
|
||||
|
||||
yield self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
|
||||
|
||||
def _notify():
|
||||
|
||||
@@ -130,6 +130,8 @@ class EventsStore(
|
||||
if self.hs.config.redaction_retention_period is not None:
|
||||
hs.get_clock().looping_call(_censor_redactions, 5 * 60 * 1000)
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _read_forward_extremities(self):
|
||||
def fetch(txn):
|
||||
@@ -940,6 +942,12 @@ class EventsStore(
|
||||
txn, event.event_id, labels, event.room_id, event.depth
|
||||
)
|
||||
|
||||
if self._ephemeral_messages_enabled:
|
||||
# If there's an expiry timestamp on the event, store it.
|
||||
expiry_ts = event.content.get(EventContentFields.SELF_DESTRUCT_AFTER)
|
||||
if isinstance(expiry_ts, int) and not event.is_state():
|
||||
self._insert_event_expiry_txn(txn, event.event_id, expiry_ts)
|
||||
|
||||
# Insert into the room_memberships table.
|
||||
self._store_room_members_txn(
|
||||
txn,
|
||||
@@ -1101,12 +1109,7 @@ class EventsStore(
|
||||
def _update_censor_txn(txn):
|
||||
for redaction_id, event_id, pruned_json in updates:
|
||||
if pruned_json:
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event_id},
|
||||
updatevalues={"json": pruned_json},
|
||||
)
|
||||
self._censor_event_txn(txn, event_id, pruned_json)
|
||||
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
@@ -1117,6 +1120,22 @@ class EventsStore(
|
||||
|
||||
yield self.runInteraction("_update_censor_txn", _update_censor_txn)
|
||||
|
||||
def _censor_event_txn(self, txn, event_id, pruned_json):
|
||||
"""Censor an event by replacing its JSON in the event_json table with the
|
||||
provided pruned JSON.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The database transaction.
|
||||
event_id (str): The ID of the event to censor.
|
||||
pruned_json (str): The pruned JSON
|
||||
"""
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event_id},
|
||||
updatevalues={"json": pruned_json},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def count_daily_messages(self):
|
||||
"""
|
||||
@@ -1957,6 +1976,101 @@ class EventsStore(
|
||||
],
|
||||
)
|
||||
|
||||
def _insert_event_expiry_txn(self, txn, event_id, expiry_ts):
|
||||
"""Save the expiry timestamp associated with a given event ID.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The database transaction to use.
|
||||
event_id (str): The event ID the expiry timestamp is associated with.
|
||||
expiry_ts (int): The timestamp at which to expire (delete) the event.
|
||||
"""
|
||||
return self._simple_insert_txn(
|
||||
txn=txn,
|
||||
table="event_expiry",
|
||||
values={"event_id": event_id, "expiry_ts": expiry_ts},
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def expire_event(self, event_id):
|
||||
"""Retrieve and expire an event that has expired, and delete its associated
|
||||
expiry timestamp. If the event can't be retrieved, delete its associated
|
||||
timestamp so we don't try to expire it again in the future.
|
||||
|
||||
Args:
|
||||
event_id (str): The ID of the event to delete.
|
||||
"""
|
||||
# Try to retrieve the event's content from the database or the event cache.
|
||||
event = yield self.get_event(event_id)
|
||||
|
||||
def delete_expired_event_txn(txn):
|
||||
# Delete the expiry timestamp associated with this event from the database.
|
||||
self._delete_event_expiry_txn(txn, event_id)
|
||||
|
||||
if not event:
|
||||
# If we can't find the event, log a warning and delete the expiry date
|
||||
# from the database so that we don't try to expire it again in the
|
||||
# future.
|
||||
logger.warning(
|
||||
"Can't expire event %s because we don't have it.", event_id
|
||||
)
|
||||
return
|
||||
|
||||
# Prune the event's dict then convert it to JSON.
|
||||
pruned_json = encode_json(prune_event_dict(event.get_dict()))
|
||||
|
||||
# Update the event_json table to replace the event's JSON with the pruned
|
||||
# JSON.
|
||||
self._censor_event_txn(txn, event.event_id, pruned_json)
|
||||
|
||||
# We need to invalidate the event cache entry for this event because we
|
||||
# changed its content in the database. We can't call
|
||||
# self._invalidate_cache_and_stream because self.get_event_cache isn't of the
|
||||
# right type.
|
||||
txn.call_after(self._get_event_cache.invalidate, (event.event_id,))
|
||||
# Send that invalidation to replication so that other workers also invalidate
|
||||
# the event cache.
|
||||
self._send_invalidation_to_replication(
|
||||
txn, "_get_event_cache", (event.event_id,)
|
||||
)
|
||||
|
||||
yield self.runInteraction("delete_expired_event", delete_expired_event_txn)
|
||||
|
||||
def _delete_event_expiry_txn(self, txn, event_id):
|
||||
"""Delete the expiry timestamp associated with an event ID without deleting the
|
||||
actual event.
|
||||
|
||||
Args:
|
||||
txn (LoggingTransaction): The transaction to use to perform the deletion.
|
||||
event_id (str): The event ID to delete the associated expiry timestamp of.
|
||||
"""
|
||||
return self._simple_delete_txn(
|
||||
txn=txn, table="event_expiry", keyvalues={"event_id": event_id}
|
||||
)
|
||||
|
||||
def get_next_event_to_expire(self):
|
||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||
table, or None if there's no more event to expire.
|
||||
|
||||
Returns: Deferred[Optional[Tuple[str, int]]]
|
||||
A tuple containing the event ID as its first element and an expiry timestamp
|
||||
as its second one, if there's at least one row in the event_expiry table.
|
||||
None otherwise.
|
||||
"""
|
||||
|
||||
def get_next_event_to_expire_txn(txn):
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT event_id, expiry_ts FROM event_expiry
|
||||
ORDER BY expiry_ts ASC LIMIT 1
|
||||
"""
|
||||
)
|
||||
|
||||
return txn.fetchone()
|
||||
|
||||
return self.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
/* Copyright 2019 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE IF NOT EXISTS event_expiry (
|
||||
event_id TEXT PRIMARY KEY,
|
||||
expiry_ts BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX event_expiry_expiry_ts_idx ON event_expiry(expiry_ts);
|
||||
101
tests/rest/client/test_ephemeral_message.py
Normal file
101
tests/rest/client/test_ephemeral_message.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.api.constants import EventContentFields, EventTypes
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client.v1 import room
|
||||
|
||||
from tests import unittest
|
||||
|
||||
|
||||
class EphemeralMessageTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
user_id = "@user:test"
|
||||
|
||||
servlets = [
|
||||
admin.register_servlets,
|
||||
room.register_servlets,
|
||||
]
|
||||
|
||||
def make_homeserver(self, reactor, clock):
|
||||
config = self.default_config()
|
||||
|
||||
config["enable_ephemeral_messages"] = True
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
return self.hs
|
||||
|
||||
def prepare(self, reactor, clock, homeserver):
|
||||
self.room_id = self.helper.create_room_as(self.user_id)
|
||||
|
||||
def test_message_expiry_no_delay(self):
|
||||
"""Tests that sending a message sent with a m.self_destruct_after field set to the
|
||||
past results in that event being deleted right away.
|
||||
"""
|
||||
# Send a message in the room that has expired. From here, the reactor clock is
|
||||
# at 200ms, so 0 is in the past, and even if that wasn't the case and the clock
|
||||
# is at 0ms the code path is the same if the event's expiry timestamp is the
|
||||
# current timestamp.
|
||||
res = self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "hello",
|
||||
EventContentFields.SELF_DESTRUCT_AFTER: 0,
|
||||
},
|
||||
)
|
||||
event_id = res["event_id"]
|
||||
|
||||
# Check that we can't retrieve the content of the event.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
self.assertFalse(bool(event_content), event_content)
|
||||
|
||||
def test_message_expiry_delay(self):
|
||||
"""Tests that sending a message with a m.self_destruct_after field set to the
|
||||
future results in that event not being deleted right away, but advancing the
|
||||
clock to after that expiry timestamp causes the event to be deleted.
|
||||
"""
|
||||
# Send a message in the room that'll expire in 1s.
|
||||
res = self.helper.send_event(
|
||||
room_id=self.room_id,
|
||||
type=EventTypes.Message,
|
||||
content={
|
||||
"msgtype": "m.text",
|
||||
"body": "hello",
|
||||
EventContentFields.SELF_DESTRUCT_AFTER: self.clock.time_msec() + 1000,
|
||||
},
|
||||
)
|
||||
event_id = res["event_id"]
|
||||
|
||||
# Check that we can retrieve the content of the event before it has expired.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
self.assertTrue(bool(event_content), event_content)
|
||||
|
||||
# Advance the clock to after the deletion.
|
||||
self.reactor.advance(1)
|
||||
|
||||
# Check that we can't retrieve the content of the event anymore.
|
||||
event_content = self.get_event(self.room_id, event_id)["content"]
|
||||
self.assertFalse(bool(event_content), event_content)
|
||||
|
||||
def get_event(self, room_id, event_id, expected_code=200):
|
||||
url = "/_matrix/client/r0/rooms/%s/event/%s" % (room_id, event_id)
|
||||
|
||||
request, channel = self.make_request("GET", url)
|
||||
self.render(request)
|
||||
|
||||
self.assertEqual(channel.code, expected_code, channel.result)
|
||||
|
||||
return channel.json_body
|
||||
Reference in New Issue
Block a user