1
0

Compare commits

...

18 Commits

Author SHA1 Message Date
Patrick Cloke
426144c9bd Disable push summaries. 2022-08-05 09:10:33 -04:00
Patrick Cloke
7fb3b8405e Merge remote-tracking branch 'origin/clokep/thread-api' into clokep/thread-poc 2022-08-05 09:09:00 -04:00
Patrick Cloke
f0ab6a7f4c Add test script. 2022-08-05 08:18:31 -04:00
Patrick Cloke
0cbddc632e Add an experimental config flag. 2022-08-05 08:18:31 -04:00
Patrick Cloke
80dcb911dc Return thread receipts down sync. 2022-08-05 08:18:31 -04:00
Patrick Cloke
fd972df8f9 Mark thread notifications as read. 2022-08-05 08:18:31 -04:00
Patrick Cloke
fbd6727760 Send the thread ID over replication. 2022-08-05 08:18:31 -04:00
Patrick Cloke
08620b3f28 Add a thread ID to receipts. 2022-08-05 08:18:31 -04:00
Patrick Cloke
d56296aa57 Add a sync flag for unread thread notifications 2022-08-05 08:18:08 -04:00
Patrick Cloke
e0ed95a45b Add an experimental config option. 2022-08-05 08:18:08 -04:00
Patrick Cloke
dfd921d421 Return thread notification counts down sync. 2022-08-05 08:18:08 -04:00
Patrick Cloke
2c7a5681b4 Extract the thread ID when processing push rules. 2022-08-05 08:18:08 -04:00
Patrick Cloke
759366e5e6 Merge remote-tracking branch 'origin/develop' into clokep/thread-api 2022-08-01 10:26:15 -04:00
Patrick Cloke
6c2e08ed6f Merge remote-tracking branch 'origin/develop' into clokep/thread-api 2022-08-01 09:05:23 -04:00
Patrick Cloke
f6267b1abe Add an enum. 2022-08-01 09:05:04 -04:00
Patrick Cloke
d510975b2f Test ignored users. 2022-07-27 12:39:43 -04:00
Patrick Cloke
8dcdb4efa9 Allow limiting threads by participation. 2022-07-27 12:39:17 -04:00
Patrick Cloke
e9a649ec31 Add an API for listing threads in a room. 2022-07-27 11:46:22 -04:00
32 changed files with 1304 additions and 182 deletions

View File

@@ -0,0 +1 @@
Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).

View File

@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

View File

@@ -0,0 +1 @@
Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.

View File

@@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"},
"unread_thread_notifications": {"type": "boolean"},
# Include or exclude events with the provided labels.
# cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}},
@@ -240,6 +241,9 @@ class FilterCollection:
def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members
def unread_thread_notifications(self) -> bool:
return self._room_timeline_filter.unread_thread_notifications
async def filter_presence(
self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]:
@@ -304,6 +308,9 @@ class Filter:
self.include_redundant_members = filter_json.get(
"include_redundant_members", False
)
self.unread_thread_notifications = filter_json.get(
"unread_thread_notifications", False
)
self.types = filter_json.get("types", None)
self.not_types = filter_json.get("not_types", [])

View File

@@ -82,11 +82,18 @@ class ExperimentalConfig(Config):
# MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
# MSC3771: Thread read receipts
self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
# MSC3772: A push rule for mutual relations.
self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
# MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
# MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
# MSC3856: Threads list API
self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)

View File

@@ -97,6 +97,7 @@ class ReceiptsHandler:
receipt_type=receipt_type,
user_id=user_id,
event_ids=user_values["event_ids"],
thread_id=None, # TODO
data=user_values.get("data", {}),
)
)
@@ -114,6 +115,7 @@ class ReceiptsHandler:
receipt.receipt_type,
receipt.user_id,
receipt.event_ids,
receipt.thread_id,
receipt.data,
)
@@ -146,7 +148,12 @@ class ReceiptsHandler:
return True
async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str
self,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
) -> None:
"""Called when a client tells us a local user has read up to the given
event_id in the room.
@@ -156,6 +163,7 @@ class ReceiptsHandler:
receipt_type=receipt_type,
user_id=user_id,
event_ids=[event_id],
thread_id=thread_id,
data={"ts": int(self.clock.time_msec())},
)

View File

@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -31,6 +32,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""
all = "all"
participated = "participated"
@attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation:
# The latest event in the thread.
@@ -482,3 +490,84 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit
return results
async def get_threads(
self,
requester: Requester,
room_id: str,
include: ThreadsListInclude,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
Args:
requester: The user requesting the relations.
room_id: The room the event belongs to.
include: One of "all" or "participated" to indicate which threads should
be returned.
limit: Only fetch the most recent `limit` events.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_token = await self._main_store.get_threads(
room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
)
events = await self._main_store.get_events_as_list(thread_roots)
if include == ThreadsListInclude.participated:
# Pre-seed thread participation with whether the requester sent the event.
participated = {event.event_id: event.sender == user_id for event in events}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)
)
# Limit the returned threads to those the user has participated in.
events = [event for event in events if participated[event.event_id]]
events = await filter_events_for_client(
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value: JsonDict = {"chunk": serialized_events}
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value

View File

@@ -115,6 +115,7 @@ class JoinedSyncResult:
ephemeral: List[JsonDict]
account_data: List[JsonDict]
unread_notifications: JsonDict
unread_thread_notifications: JsonDict
summary: Optional[JsonDict]
unread_count: int
@@ -265,6 +266,8 @@ class SyncHandler:
self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
self._msc3773_enabled = hs.config.experimental.msc3773_enabled
async def wait_for_sync_for_user(
self,
requester: Requester,
@@ -1053,7 +1056,7 @@ class SyncHandler:
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> NotifCounts:
) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -2115,17 +2118,46 @@ class SyncHandler:
ephemeral=ephemeral,
account_data=account_data_events,
unread_notifications=unread_notifications,
unread_thread_notifications={},
summary=summary,
unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
notifs, thread_notifs = await self.unread_notifs_for_room_id(
room_id, sync_config
)
unread_notifications["notification_count"] = notifs.notify_count
unread_notifications["highlight_count"] = notifs.highlight_count
# Notifications for the main timeline.
notify_count = notifs.notify_count
highlight_count = notifs.highlight_count
unread_count = notifs.unread_count
room_sync.unread_count = notifs.unread_count
# Check the sync configuration.
if (
self._msc3773_enabled
and sync_config.filter_collection.unread_thread_notifications()
):
# And add info for each thread.
room_sync.unread_thread_notifications = {
thread_id: {
"notification_count": tnotifs.notify_count,
"highlight_count": tnotifs.highlight_count,
}
for thread_id, tnotifs in thread_notifs.items()
if thread_id is not None
}
else:
# Combine the unread counts for all threads and main timeline.
for tnotifs in thread_notifs.values():
notify_count += tnotifs.notify_count
highlight_count += tnotifs.highlight_count
unread_count += tnotifs.unread_count
unread_notifications["notification_count"] = notify_count
unread_notifications["highlight_count"] = highlight_count
room_sync.unread_count = unread_count
sync_result_builder.joined.append(room_sync)

View File

@@ -186,7 +186,7 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations(
self, event: EventBase, rules: Iterable[Dict[str, Any]]
self, parent_id: str, rules: Iterable[Dict[str, Any]]
) -> Dict[str, Set[Tuple[str, str]]]:
"""
Fetch event metadata for events which related to the same event as the given event.
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
If the given event has no relation information, returns an empty dictionary.
Args:
event_id: The event ID which is targeted by relations.
parent_id: The event ID which is targeted by relations.
rules: The push rules which will be processed for this event.
Returns:
@@ -208,12 +208,6 @@ class BulkPushRuleEvaluator:
if not self._relations_match_enabled:
return {}
# If the event does not have a relation, then cannot have any mutual
# relations.
relation = relation_from_event(event)
if not relation:
return {}
# Pre-filter to figure out which relation types are interesting.
rel_types = set()
for rule in rules:
@@ -235,9 +229,7 @@ class BulkPushRuleEvaluator:
return {}
# If any valid rules were found, fetch the mutual relations.
return await self.store.get_mutual_event_relations(
relation.parent_id, rel_types
)
return await self.store.get_mutual_event_relations(parent_id, rel_types)
@measure_func("action_for_event_by_user")
async def action_for_event_by_user(
@@ -265,9 +257,18 @@ class BulkPushRuleEvaluator:
sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context)
relations = await self._get_mutual_relations(
event, itertools.chain(*rules_by_user.values())
)
relation = relation_from_event(event)
# If the event does not have a relation, then cannot have any mutual
# relations or thread ID.
relations = {}
thread_id = None
if relation:
relations = await self._get_mutual_relations(
relation.parent_id, itertools.chain(*rules_by_user.values())
)
# XXX Does this need to point to a valid parent ID or anything?
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
evaluator = PushRuleEvaluatorForEvent(
event,
@@ -338,6 +339,7 @@ class BulkPushRuleEvaluator:
event.event_id,
actions_by_user,
count_as_unread,
thread_id,
)

View File

@@ -26,13 +26,18 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites)
for room_id in joins:
notifs = await (
notifs, thread_notifs = await (
store.get_unread_event_push_actions_by_room_for_user(
room_id,
user_id,
)
)
if notifs.notify_count == 0:
# Combine the counts from all the threads.
notify_count = notifs.notify_count + sum(
n.notify_count for n in thread_notifs.values()
)
if notify_count == 0:
continue
if group_by_room:
@@ -40,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1
else:
# increment the badge count by the number of unread messages in the room
badge += notifs.notify_count
badge += notify_count
return badge

View File

@@ -423,7 +423,8 @@ class FederationSenderHandler:
receipt.receipt_type,
receipt.user_id,
[receipt.event_id],
receipt.data,
thread_id=receipt.thread_id,
data=receipt.data,
)
await self.federation_sender.send_read_receipt(receipt_info)

View File

@@ -361,6 +361,7 @@ class ReceiptsStream(Stream):
receipt_type: str
user_id: str
event_id: str
thread_id: Optional[str]
data: dict
NAME = "receipts"

View File

@@ -81,6 +81,7 @@ class ReadMarkerRestServlet(RestServlet):
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
thread_id=None, # TODO
)
return 200, {}

View File

@@ -13,10 +13,10 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError
from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest
@@ -34,7 +34,8 @@ class ReceiptRestServlet(RestServlet):
PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$"
"/(?P<event_id>[^/]*)"
"(/(?P<thread_id>[^/]*))?$"
)
def __init__(self, hs: "HomeServer"):
@@ -50,8 +51,15 @@ class ReceiptRestServlet(RestServlet):
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
)
self._msc3771_enabled = hs.config.experimental.msc3771_enabled
async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
self,
request: SynapseRequest,
room_id: str,
receipt_type: str,
event_id: str,
thread_id: Optional[str],
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
@@ -61,6 +69,9 @@ class ReceiptRestServlet(RestServlet):
f"Receipt type must be {', '.join(self._known_receipt_types)}",
)
if thread_id and not self._msc3771_enabled:
raise UnrecognizedRequestError()
parse_json_object_from_request(request, allow_empty_body=False)
await self.presence_handler.bump_presence_active_time(requester.user)
@@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet):
receipt_type,
user_id=requester.user.to_string(),
event_id=event_id,
thread_id=thread_id,
)
return 200, {}

View File

@@ -13,8 +13,10 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
@@ -91,5 +93,55 @@ class RelationPaginationServlet(RestServlet):
return 200, result
class ThreadsServlet(RestServlet):
PATTERNS = (
re.compile(
"^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
),
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
include = parse_string(
request,
"include",
default=ThreadsListInclude.all.value,
allowed_values=[v.value for v in ThreadsListInclude],
)
# Return the relations
from_token = None
if from_token_str:
from_token = await StreamToken.from_string(self.store, from_token_str)
to_token = None
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self._relations_handler.get_threads(
requester=requester,
room_id=room_id,
include=ThreadsListInclude(include),
limit=limit,
from_token=from_token,
to_token=to_token,
)
return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server)
if hs.config.experimental.msc3856_enabled:
ThreadsServlet(hs).register(http_server)

View File

@@ -509,6 +509,8 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
if room.unread_thread_notifications:
result["unread_thread_notifications"] = room.unread_thread_notifications
result["summary"] = room.summary
if self._msc2654_enabled:
result["org.matrix.msc2654.unread_count"] = room.unread_count

View File

@@ -78,7 +78,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import attr
from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import (
@@ -206,10 +205,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self._rotate_count = 10000
self._doing_notif_rotation = False
if hs.config.worker.run_background_tasks:
self._rotate_notif_loop = self._clock.looping_call(
self._rotate_notifs, 30 * 1000
)
# XXX Do not rotate summaries for now, they're broken.
# if hs.config.worker.run_background_tasks:
# self._rotate_notif_loop = self._clock.looping_call(
# self._rotate_notifs, 30 * 1000
# )
self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index",
@@ -220,12 +220,21 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
replaces_index="event_push_summary_user_rm",
)
@cached(tree=True, max_entries=5000)
self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index2",
index_name="event_push_summary_unique_index2",
table="event_push_summary",
columns=["user_id", "room_id", "thread_id"],
unique=True,
replaces_index="event_push_summary_unique_index",
)
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user(
self,
room_id: str,
user_id: str,
) -> NotifCounts:
) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
@@ -254,31 +263,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction,
room_id: str,
user_id: str,
) -> NotifCounts:
result = self.get_last_receipt_for_user_txn(
txn,
user_id,
room_id,
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = None
if result:
_, stream_ordering = result
if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
@@ -286,12 +283,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_unread_counts_by_pos_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> NotifCounts:
) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
"""Get the number of unread messages for a user/room that have happened
since the given stream ordering.
Returns:
A tuple of:
The unread messages for the main timeline
A dictionary of thread ID to unread messages for that thread.
Only contains threads with unread messages.
"""
counts = NotifCounts()
thread_counts = {}
# First we pull the counts from the summary table.
#
@@ -306,51 +311,92 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# date (as the row was written by an older version of Synapse that
# updated `event_push_summary` synchronously when persisting a new read
# receipt).
txn.execute(
"""
SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
FROM event_push_summary
WHERE room_id = ? AND user_id = ?
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
OR last_receipt_stream_ordering = ?
)
""",
(room_id, user_id, stream_ordering, stream_ordering),
)
row = txn.fetchone()
summary_stream_ordering = 0
if row:
summary_stream_ordering = row[0]
counts.notify_count += row[1]
counts.unread_count += row[2]
# XXX event_push_summary is not currently filled in. broken.
# txn.execute(
# """
# SELECT notif_count, COALESCE(unread_count, 0), thread_id, MAX(events.stream_ordering)
# FROM event_push_summary
# LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
# LEFT JOIN events ON (
# events.room_id = receipts_linearized.room_id AND
# events.event_id = receipts_linearized.event_id
# )
# WHERE event_push_summary.room_id = ? AND user_id = ?
# AND (
# (
# last_receipt_stream_ordering IS NULL
# AND event_push_summary.stream_ordering > COALESCE(events.stream_ordering, ?)
# )
# OR last_receipt_stream_ordering = COALESCE(events.stream_ordering, ?)
# )
# AND (receipt_type = 'm.read' OR receipt_type = 'org.matrix.msc2285.read.private')
# """,
# (room_id, user_id, stream_ordering, stream_ordering),
# )
# for notif_count, unread_count, thread_id, _ in txn:
# # XXX Why are these returned? Related to MAX(...) aggregation.
# if notif_count is None:
# continue
#
# if not thread_id:
# counts = NotifCounts(
# notify_count=notif_count, unread_count=unread_count
# )
# # TODO Delete zeroed out threads completely from the database.
# elif notif_count or unread_count:
# thread_counts[thread_id] = NotifCounts(
# notify_count=notif_count, unread_count=unread_count
# )
# Next we need to count highlights, which aren't summarised
sql = """
SELECT COUNT(*) FROM event_push_actions
SELECT COUNT(*), thread_id, MAX(events.stream_ordering) FROM event_push_actions
LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
LEFT JOIN events ON (
events.room_id = receipts_linearized.room_id AND
events.event_id = receipts_linearized.event_id
)
WHERE user_id = ?
AND room_id = ?
AND stream_ordering > ?
AND event_push_actions.room_id = ?
AND event_push_actions.stream_ordering > COALESCE(events.stream_ordering, ?)
AND highlight = 1
GROUP BY thread_id
"""
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
if row:
counts.highlight_count += row[0]
for highlight_count, thread_id, _ in txn:
if not thread_id:
counts.highlight_count += highlight_count
elif highlight_count:
if thread_id in thread_counts:
thread_counts[thread_id].highlight_count += highlight_count
else:
thread_counts[thread_id] = NotifCounts(
notify_count=0, unread_count=0, highlight_count=highlight_count
)
# Finally we need to count push actions that aren't included in the
# summary returned above, e.g. recent events that haven't been
# summarised yet, or the summary is empty due to a recent read receipt.
stream_ordering = max(stream_ordering, summary_stream_ordering)
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
unread_counts = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering
)
counts.notify_count += notify_count
counts.unread_count += unread_count
for notif_count, unread_count, thread_id in unread_counts:
if not thread_id:
counts.notify_count += notif_count
counts.unread_count += unread_count
elif thread_id in thread_counts:
thread_counts[thread_id].notify_count += notif_count
thread_counts[thread_id].unread_count += unread_count
else:
thread_counts[thread_id] = NotifCounts(
notify_count=notif_count,
unread_count=unread_count,
highlight_count=0,
)
return counts
return counts, thread_counts
def _get_notif_unread_count_for_user_room(
self,
@@ -359,7 +405,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str,
stream_ordering: int,
max_stream_ordering: Optional[int] = None,
) -> Tuple[int, int]:
) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range.
@@ -380,8 +426,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# If there have been no events in the room since the stream ordering,
# there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
return 0, 0
#
# XXX
# if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
# return []
clause = ""
args = [user_id, room_id, stream_ordering]
@@ -389,29 +437,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering)
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
return 0, 0
sql = f"""
SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END),
COUNT(CASE WHEN unread = 1 THEN 1 END)
FROM event_push_actions ea
WHERE user_id = ?
AND room_id = ?
AND ea.stream_ordering > ?
COUNT(CASE WHEN unread = 1 THEN 1 END),
thread_id,
MAX(events.stream_ordering)
FROM event_push_actions ea
LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
LEFT JOIN events ON (
events.room_id = receipts_linearized.room_id AND
events.event_id = receipts_linearized.event_id
)
WHERE user_id = ?
AND ea.room_id = ?
AND ea.stream_ordering > COALESCE(events.stream_ordering, ?)
{clause}
GROUP BY thread_id
"""
txn.execute(sql, args)
row = txn.fetchone()
if row:
return cast(Tuple[int, int], row)
return 0, 0
# The max stream ordering is simply there to select the latest receipt,
# it doesn't need to be returned.
return [cast(Tuple[int, int, str], row[:3]) for row in txn.fetchall()]
async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int
@@ -680,6 +728,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool,
thread_id: Optional[str],
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -688,6 +737,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
thread_id: The thread this event is parent of, if applicable.
"""
if not user_id_actions:
return
@@ -696,7 +746,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(
user_id: str, actions: List[Union[dict, str]]
) -> Tuple[str, str, str, int, int, int]:
) -> Tuple[str, str, str, int, int, int, Optional[str]]:
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
@@ -706,6 +756,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
thread_id or "", # thread_id column
)
def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -714,8 +765,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
sql = """
INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight, unread)
VALUES (?, ?, ?, ?, ?, ?)
(event_id, user_id, actions, notif, highlight, unread, thread_id)
VALUES (?, ?, ?, ?, ?, ?, ?)
"""
txn.execute_batch(
@@ -955,7 +1006,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
sql = """
SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
FROM receipts_linearized AS r
INNER JOIN events AS e USING (event_id)
WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -978,41 +1029,83 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
rows = txn.fetchall()
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
for _, room_id, user_id, stream_ordering in rows:
# Group the rows by room ID / user ID.
rows_by_room_user: Dict[Tuple[str, str], List[Tuple[str, str, int]]] = {}
for stream_id, room_id, user_id, thread_id, stream_ordering in rows:
# Only handle our own read receipts.
if not self.hs.is_mine_id(user_id):
continue
txn.execute(
"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND stream_ordering <= ?
AND highlight = 0
""",
(room_id, user_id, stream_ordering),
rows_by_room_user.setdefault((room_id, user_id), []).append(
(stream_id, thread_id, stream_ordering)
)
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
for (room_id, user_id), room_rows in rows_by_room_user.items():
for _, thread_id, stream_ordering in room_rows:
txn.execute(
"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND thread_id = ?
AND stream_ordering <= ?
AND highlight = 0
""",
(room_id, user_id, thread_id, stream_ordering),
)
# Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
earliest_stream_ordering = min(r[2] for r in room_rows)
unread_counts = self._get_notif_unread_count_for_user_room(
txn,
room_id,
user_id,
earliest_stream_ordering,
old_rotate_stream_ordering,
)
# Replace the previous summary with the new counts.
self.db_pool.simple_upsert_txn(
# Updated threads get their notification count and unread count updated.
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
keyvalues={"room_id": room_id, "user_id": user_id},
values={
"notif_count": notif_count,
"unread_count": unread_count,
"stream_ordering": old_rotate_stream_ordering,
"last_receipt_stream_ordering": stream_ordering,
},
key_names=("room_id", "user_id", "thread_id"),
key_values=[(room_id, user_id, row[2]) for row in unread_counts],
value_names=(
"notif_count",
"unread_count",
"stream_ordering",
"last_receipt_stream_ordering",
),
value_values=[
# XXX Stream ordering.
(
row[0],
row[1],
old_rotate_stream_ordering,
earliest_stream_ordering,
)
for row in unread_counts
],
)
# XXX WTF?
# Other threads should be marked as reset at the old stream ordering.
txn.execute(
"""
UPDATE event_push_summary SET notif_count = 0, unread_count = 0, stream_ordering = ?, last_receipt_stream_ordering = ?
WHERE user_id = ? AND room_id = ? AND
stream_ordering <= ?
""",
(
old_rotate_stream_ordering,
min_receipts_stream_id,
user_id,
room_id,
old_rotate_stream_ordering,
),
)
# We always update `event_push_summary_last_receipt_stream_id` to
@@ -1102,23 +1195,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering
FROM (
SELECT user_id, room_id, count(*) as cnt,
SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND (
old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering
)
AND %s = 1
GROUP BY user_id, room_id
GROUP BY user_id, room_id, thread_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
"""
# First get the count of unread messages.
@@ -1132,11 +1225,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=row[3],
stream_ordering=row[4],
notif_count=0,
)
@@ -1147,17 +1240,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
if (row[0], row[1], row[2]) in summaries:
summaries[(row[0], row[1], row[2])].notif_count = row[3]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
notif_count=row[2],
stream_ordering=row[4],
notif_count=row[3],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -1165,8 +1258,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self.db_pool.simple_upsert_many_txn(
txn,
table="event_push_summary",
key_names=("user_id", "room_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries],
key_names=("user_id", "room_id", "thread_id"),
key_values=[
(user_id, room_id, thread_id)
for user_id, room_id, thread_id in summaries
],
value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[
(

View File

@@ -1594,7 +1594,7 @@ class PersistEventsStore:
)
# Remove from relations table.
self._handle_redact_relations(txn, event.redacts)
self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and
# event_edges tables.
@@ -1909,6 +1909,7 @@ class PersistEventsStore:
self.store.get_thread_participated.invalidate,
(relation.parent_id, event.sender),
)
txn.call_after(self.store.get_threads.invalidate, (event.room_id,))
def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,))
def _handle_redact_relations(
self, txn: LoggingTransaction, redacted_event_id: str
self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None:
"""Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database.
Args:
txn
room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted.
"""
@@ -2068,6 +2070,7 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,)
)
txn.call_after(self.store.get_threads.invalidate, (room_id,))
self.store._invalidate_cache_and_stream(
txn,
self.store.get_mutual_event_relations_for_rel_type,
@@ -2190,9 +2193,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight, unread
topological_ordering, notif, highlight, unread, thread_id
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
FROM event_push_actions_staging
WHERE event_id = ?
"""

View File

@@ -117,6 +117,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token()
# XXX MOVE TO TESTS
async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]:
@@ -411,6 +412,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f
)
# Map of room ID to a dictionary in the form that sync wants it.
results: JsonDict = {}
for row in txn_results:
# We want a single event per room, since we want to batch the
@@ -426,6 +428,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"])
if row["thread_id"]:
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
results = {
room_id: [results[room_id]] if room_id in results else []
@@ -522,7 +526,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_all_updated_receipts(
self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool]:
"""Get updates for receipts replication stream.
Args:
@@ -549,9 +553,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_all_updated_receipts_txn(
txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]:
) -> Tuple[
List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool
]:
sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
@@ -560,8 +566,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit))
updates = cast(
List[Tuple[int, list]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]],
[(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
)
limited = False
@@ -613,6 +619,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
data: JsonDict,
stream_id: int,
) -> Optional[int]:
@@ -636,15 +643,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0
# Convert None to a blank string.
thread_id = thread_id or ""
# We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts
if stream_ordering is not None:
sql = (
"SELECT stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
" INNER JOIN receipts_linearized as r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND r.thread_id = ?"
)
txn.execute(sql, (room_id, receipt_type, user_id))
txn.execute(sql, (room_id, receipt_type, user_id, thread_id))
for so, eid in txn:
if int(so) >= stream_ordering:
@@ -656,6 +666,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
)
return None
# TODO
txn.call_after(
self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
)
@@ -671,6 +682,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"thread_id": thread_id,
},
values={
"stream_id": stream_id,
@@ -678,7 +690,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"data": json_encoder.encode(data),
},
# receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock
# (user_id, room_id, receipt_type, key), so no need to lock
lock=False,
)
@@ -728,6 +740,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: dict,
) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server.
@@ -760,6 +773,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type,
user_id,
linearized_event_id,
thread_id,
data,
stream_id=stream_id,
# Read committed is actually beneficial here because we check for a receipt with
@@ -774,7 +788,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
now = self._clock.time_msec()
logger.debug(
"RR for event %s in %s (%i ms old)",
"Receipt %s for event %s in %s (%i ms old)",
receipt_type,
linearized_event_id,
room_id,
now - event_ts,
@@ -787,6 +802,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type,
user_id,
event_ids,
thread_id,
data,
)
@@ -801,6 +817,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str,
user_id: str,
event_ids: List[str],
thread_id: Optional[str],
data: JsonDict,
) -> None:
assert self._can_write_to_receipts
@@ -812,6 +829,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
# Convert None to a blank string.
thread_id = thread_id or ""
self.db_pool.simple_delete_txn(
txn,
table="receipts_graph",
@@ -819,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"thread_id": thread_id,
},
)
self.db_pool.simple_insert_txn(
@@ -829,6 +850,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": json_encoder.encode(event_ids),
"thread_id": thread_id,
"data": json_encoder.encode(data),
},
)

View File

@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations
)
@cached(tree=True)
async def get_threads(
self,
room_id: str,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[str], Optional[StreamToken]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.
Args:
room_id: The room the event belongs to.
limit: Only fetch the most recent `limit` threads.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
A tuple of:
A list of thread root event IDs.
The next stream token, if one exists.
"""
pagination_clause = generate_pagination_where_clause(
direction="b",
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token.room_key.as_historical_tuple()
if from_token
else None,
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)
if pagination_clause:
pagination_clause = "AND " + pagination_clause
sql = f"""
SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
room_id = ? AND
relation_type = '{RelationTypes.THREAD}'
{pagination_clause}
GROUP BY relates_to_id
ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
LIMIT ?
"""
def _get_threads_txn(
txn: LoggingTransaction,
) -> Tuple[List[str], Optional[StreamToken]]:
txn.execute(sql, [room_id, limit + 1])
last_topo_id = None
last_stream_id = None
thread_ids = []
for thread_id, topo_id, stream_id in txn:
thread_ids.append(thread_id)
last_topo_id = topo_id
last_stream_id = stream_id
# If there are more events, generate the next pagination key.
next_token = None
if len(thread_ids) > limit and last_topo_id and last_stream_id:
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
next_token = from_token.copy_and_replace(
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)
return thread_ids[:limit], next_token
return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
class RelationsStore(RelationsWorkerStore):
pass

View File

@@ -0,0 +1,25 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Allow multiple receipts per user per room via a nullable thread_id column.
ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
-- Rebuild the unique constraint with the thread_id.
ALTER TABLE receipts_linearized DROP CONSTRAINT IF EXISTS receipts_linearized_uniqueness;
ALTER TABLE receipts_linearized ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);
ALTER TABLE receipts_graph DROP CONSTRAINT IF EXISTS receipts_graph_uniqueness;
ALTER TABLE receipts_graph ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);

View File

@@ -0,0 +1,67 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Allow multiple receipts per user per room via a nullable thread_id column.
--
-- SQLite doesn't support modifying constraints to an existing table, so it must
-- be recreated.
-- Create the new tables.
CREATE TABLE receipts_graph_new (
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_ids TEXT NOT NULL,
thread_id TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL,
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);
CREATE TABLE receipts_linearized_new (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
thread_id TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL,
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);
-- Drop the old indexes.
DROP INDEX IF EXISTS receipts_linearized_id;
DROP INDEX IF EXISTS receipts_linearized_room_stream;
DROP INDEX IF EXISTS receipts_linearized_user;
-- Copy the data.
INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data)
SELECT room_id, receipt_type, user_id, event_ids, data
FROM receipts_graph;
INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data)
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized;
-- Drop the old tables.
DROP TABLE receipts_graph;
DROP TABLE receipts_linearized;
-- Rename the tables.
ALTER TABLE receipts_graph_new RENAME TO receipts_graph;
ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized;
-- Create the indices.
CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id );

View File

@@ -0,0 +1,27 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE event_push_actions_staging
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE event_push_actions
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE event_push_summary
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
-- Update the unique index for `event_push_summary`
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7003, 'event_push_summary_unique_index2', '{}');

View File

@@ -830,6 +830,7 @@ class ReadReceipt:
receipt_type: str
user_id: str
event_ids: List[str]
thread_id: Optional[str]
data: JsonDict

View File

@@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
# send the second RR
receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
"room_id",
"m.read",
"user_id",
["other_id"],
thread_id=None,
data={"ts": 1234},
)
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump()

View File

@@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipt_type="m.read",
user_id=self.local_user,
event_ids=[f"$eventid_{i}"],
thread_id=None,
data={},
)
)

View File

@@ -171,14 +171,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if send_receipt:
self.get_success(
self.master_store.insert_receipt(
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
)
)
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
(NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}),
)
self.persist(
@@ -191,7 +191,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
(NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}),
)
self.persist(
@@ -206,7 +206,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
(NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}),
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event.event_id,
{user_id: actions for user_id, actions in push_actions},
False,
None,
)
)
return event, context

View File

@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastores().main.insert_receipt(
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
"!room:blue",
"m.read",
USER_ID,
["$event:blue"],
thread_id=None,
data={"a": 1},
)
)
self.replicate()
@@ -57,7 +62,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.get_success(
self.hs.get_datastores().main.insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
"!room2:blue",
"m.read",
USER_ID,
["$event2:foo"],
thread_id=None,
data={"a": 2},
)
)
self.replicate()

View File

@@ -1679,3 +1679,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations[RelationTypes.THREAD]["latest_event"]["event_id"],
related_event_id,
)
class ThreadsTestCase(BaseRelationsTestCase):
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_threads(self) -> None:
"""Create threads and ensure the ordering is due to their latest event."""
# Create 2 threads.
thread_1 = self.parent_id
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test")
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Request the threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1])
# Update the first thread, the ordering should swap.
self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1, thread_2])
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_pagination(self) -> None:
"""Create threads and paginate through them."""
# Create 2 threads.
thread_1 = self.parent_id
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test")
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Request the threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2])
# Make sure next_batch has something in it that looks like it could be a
# valid token.
next_batch = channel.json_body.get("next_batch")
self.assertIsInstance(next_batch, str, channel.json_body)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
self.assertNotIn("next_batch", channel.json_body, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_include(self) -> None:
"""Filtering threads to all or participated in should work."""
# Thread 1 has the user as the root event.
thread_1 = self.parent_id
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
# Thread 2 has the user replying.
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Thread 3 has the user not participating in.
res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
thread_3 = res["event_id"]
self._send_relation(
RelationTypes.THREAD,
"m.room.test",
access_token=self.user2_token,
parent_id=thread_3,
)
# All threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body
)
# Only participated threads.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_ignored_user(self) -> None:
"""Events from ignored users should be ignored."""
# Thread 1 has a reply from an ignored user.
thread_1 = self.parent_id
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
# Thread 2 is created by an ignored user.
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Ignore user2.
self.get_success(
self.store.add_account_data_for_user(
self.user_id,
AccountDataTypes.IGNORED_USER_LIST,
{"ignored_users": {self.user2_id: {}}},
)
)
# Only thread 1 is returned.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)

View File

@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.types import JsonDict
from synapse.util import Clock
from tests.unittest import HomeserverTestCase
@@ -70,7 +73,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
def _assert_counts(
noitf_count: int, unread_count: int, highlight_count: int
) -> None:
counts = self.get_success(
counts, thread_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
@@ -86,6 +89,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
highlight_count=highlight_count,
),
)
self.assertEqual(thread_counts, {})
def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event(
@@ -108,6 +112,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=None,
data={},
)
)
@@ -131,6 +136,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_assert_counts(0, 0, 0)
_create_event()
_assert_counts(1, 1, 0)
_rotate()
_assert_counts(1, 1, 0)
@@ -166,6 +172,184 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate()
_assert_counts(0, 0, 0)
def test_count_aggregation_threads(self) -> None:
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")
# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
thread_id: str
last_event_id: str
def _assert_counts(
noitf_count: int,
unread_count: int,
highlight_count: int,
thread_notif_count: int,
thread_unread_count: int,
thread_highlight_count: int,
) -> None:
counts, thread_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts,
NotifCounts(
notify_count=noitf_count,
unread_count=unread_count,
highlight_count=highlight_count,
),
)
if thread_notif_count or thread_unread_count or thread_highlight_count:
self.assertEqual(
thread_counts,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=thread_unread_count,
highlight_count=thread_highlight_count,
),
},
)
else:
self.assertEqual(thread_counts, {})
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
content: JsonDict = {
"msgtype": "m.text",
"body": user_id if highlight else "",
}
if thread_id:
content["m.relates_to"] = {
"rel_type": "m.thread",
"event_id": thread_id,
}
result = self.helper.send_event(
room_id,
type="m.room.message",
content=content,
tok=other_token,
)
nonlocal last_event_id
last_event_id = result["event_id"]
return last_event_id
def _rotate() -> None:
self.get_success(self.store._rotate_notifs())
def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=thread_id,
data={},
)
)
_assert_counts(0, 0, 0, 0, 0, 0)
thread_id = _create_event()
_assert_counts(1, 0, 0, 0, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 0, 0, 0)
_create_event(thread_id=thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 1, 0, 0)
_create_event()
_assert_counts(2, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(2, 0, 0, 1, 0, 0)
event_id = _create_event(thread_id=thread_id)
_assert_counts(2, 0, 0, 2, 0, 0)
_rotate()
_assert_counts(2, 0, 0, 2, 0, 0)
_create_event()
_create_event(thread_id=thread_id)
_mark_read(event_id)
_assert_counts(1, 0, 0, 3, 0, 0)
_mark_read(event_id, thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event()
_create_event(thread_id=thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 1, 0, 0)
# Delete old event push actions, this should not affect the (summarised) count.
self.get_success(self.store._remove_old_push_actions_that_have_rotated())
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event(True)
_assert_counts(1, 1, 1, 0, 0, 0)
_rotate()
_assert_counts(1, 1, 1, 0, 0, 0)
event_id = _create_event(True, thread_id)
_assert_counts(1, 1, 1, 1, 1, 1)
_rotate()
_assert_counts(1, 1, 1, 1, 1, 1)
# Check that adding another notification and rotating after highlight
# works.
_create_event()
_rotate()
_assert_counts(2, 1, 1, 1, 1, 1)
_create_event(thread_id=thread_id)
_rotate()
_assert_counts(2, 1, 1, 2, 1, 1)
# Check that sending read receipts at different points results in the
# right counts.
_mark_read(event_id)
_assert_counts(1, 0, 0, 2, 1, 1)
_mark_read(event_id, thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_assert_counts(0, 0, 0, 1, 0, 0)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event(True)
_create_event(True, thread_id)
_assert_counts(1, 1, 1, 1, 1, 1)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_rotate()
_assert_counts(0, 0, 0, 0, 0, 0)
def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None:
self.get_success(

View File

@@ -120,13 +120,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
self.room_id1,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event1_2_id],
None,
{},
)
)
@@ -153,7 +158,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -169,7 +174,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
self.room_id2,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event2_1_id],
None,
{},
)
)
res = self.get_success(
@@ -191,13 +201,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
)
)
# Send private read receipt for the second event
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
self.room_id1,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event1_2_id],
None,
{},
)
)
@@ -230,7 +245,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating
self.get_success(
self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
)
)
res = self.get_success(
@@ -248,7 +263,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns
self.get_success(
self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
self.room_id2,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event2_1_id],
None,
{},
)
)
res = self.get_success(

190
thread_test.py Normal file
View File

@@ -0,0 +1,190 @@
import json
from time import monotonic
import requests
HOMESERVER = "http://localhost:8080"
USER_1_TOK = "syt_dGVzdGVy_AywuFarQjsYrHuPkOUvg_25XLNK"
USER_1_HEADERS = {"Authorization": f"Bearer {USER_1_TOK}"}
USER_2_TOK = "syt_b3RoZXI_jtiTnwtlBjMGMixlHIBM_4cxesB"
USER_2_HEADERS = {"Authorization": f"Bearer {USER_2_TOK}"}
def _check_for_status(result):
# Similar to raise_for_status, but prints the error.
if 400 <= result.status_code:
error_msg = result.json()
result.raise_for_status()
print(error_msg)
exit(0)
def _sync_and_show(room_id):
print("Syncing . . .")
result = requests.get(
f"{HOMESERVER}/_matrix/client/v3/sync",
headers=USER_1_HEADERS,
params={
"filter": json.dumps(
{
"room": {
"timeline": {"limit": 30, "unread_thread_notifications": True}
}
}
)
},
)
_check_for_status(result)
sync_response = result.json()
room = sync_response["rooms"]["join"][room_id]
# Find read receipts (this assumes non-overlapping).
read_receipts = {} # thread -> event ID -> users
for event in room["ephemeral"]["events"]:
if event["type"] != "m.receipt":
continue
for event_id, content in event["content"].items():
for mxid, receipt in content["m.read"].items():
print(mxid, receipt)
# Just care about the localpart of the MXID.
mxid = mxid.split(":", 1)[0]
read_receipts.setdefault(receipt.get("thread_id"), {}).setdefault(
event_id, []
).append(mxid)
print(room["unread_notifications"])
print(room.get("unread_thread_notifications"))
print()
# Convert events to their threads.
threads = {}
for event in room["timeline"]["events"]:
if event["type"] != "m.room.message":
continue
event_id = event["event_id"]
parent_id = event["content"].get("m.relates_to", {}).get("event_id")
if parent_id:
threads[parent_id][1].append(event)
else:
threads[event_id] = (event, [])
for root_event_id, (root, thread) in threads.items():
msg = root["content"]["body"]
print(f"{root_event_id}: {msg}")
for event in thread:
thread_event_id = event["event_id"]
msg = event["content"]["body"]
print(f"\t{thread_event_id}: {msg}")
if thread_event_id in read_receipts.get(root_event_id, {}):
user_ids = ", ".join(read_receipts[root_event_id][thread_event_id])
print(f"\t^--------- {user_ids} ---------^")
if root_event_id in read_receipts[None]:
user_ids = ", ".join(read_receipts[None][root_event_id])
print(f"^--------- {user_ids} ---------^")
print()
print()
def _send_event(room_id, body, thread_id=None):
content = {
"msgtype": "m.text",
"body": body,
}
if thread_id:
content["m.relates_to"] = {
"rel_type": "m.thread",
"event_id": thread_id,
}
# Send a msg to the room.
result = requests.put(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/msg{monotonic()}",
json=content,
headers=USER_2_HEADERS,
)
_check_for_status(result)
return result.json()["event_id"]
def main():
# Create a new room as user 2, add a bunch of messages.
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/createRoom",
json={"visibility": "public", "name": f"Thread Read Receipts ({monotonic()})"},
headers=USER_2_HEADERS,
)
_check_for_status(result)
room_id = result.json()["room_id"]
# Second user joins the room.
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/join", headers=USER_1_HEADERS
)
_check_for_status(result)
# Sync user 1.
_sync_and_show(room_id)
# User 2 sends some messages.
event_ids = []
def _send_and_append(body, thread_id=None):
event_id = _send_event(room_id, body, thread_id)
event_ids.append(event_id)
return event_id
for msg in range(5):
root_message_id = _send_and_append(f"Message {msg}")
for msg in range(10):
if msg % 2:
_send_and_append(f"More message {msg}")
else:
_send_and_append(f"Thread Message {msg}", root_message_id)
# User 2 sends a read receipt.
print("@second reads main timeline")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[3]}",
headers=USER_2_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
# User 1 sends a read receipt.
print("@test reads main timeline")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-5]}",
headers=USER_1_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
# User 1 sends another read receipt.
print("@test reads thread")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-4]}/{root_message_id}",
headers=USER_1_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
if __name__ == "__main__":
main()