1
0

Compare commits

...

18 Commits

Author SHA1 Message Date
Patrick Cloke 426144c9bd Disable push summaries. 2022-08-05 09:10:33 -04:00
Patrick Cloke 7fb3b8405e Merge remote-tracking branch 'origin/clokep/thread-api' into clokep/thread-poc 2022-08-05 09:09:00 -04:00
Patrick Cloke f0ab6a7f4c Add test script. 2022-08-05 08:18:31 -04:00
Patrick Cloke 0cbddc632e Add an experimental config flag. 2022-08-05 08:18:31 -04:00
Patrick Cloke 80dcb911dc Return thread receipts down sync. 2022-08-05 08:18:31 -04:00
Patrick Cloke fd972df8f9 Mark thread notifications as read. 2022-08-05 08:18:31 -04:00
Patrick Cloke fbd6727760 Send the thread ID over replication. 2022-08-05 08:18:31 -04:00
Patrick Cloke 08620b3f28 Add a thread ID to receipts. 2022-08-05 08:18:31 -04:00
Patrick Cloke d56296aa57 Add a sync flag for unread thread notifications 2022-08-05 08:18:08 -04:00
Patrick Cloke e0ed95a45b Add an experimental config option. 2022-08-05 08:18:08 -04:00
Patrick Cloke dfd921d421 Return thread notification counts down sync. 2022-08-05 08:18:08 -04:00
Patrick Cloke 2c7a5681b4 Extract the thread ID when processing push rules. 2022-08-05 08:18:08 -04:00
Patrick Cloke 759366e5e6 Merge remote-tracking branch 'origin/develop' into clokep/thread-api 2022-08-01 10:26:15 -04:00
Patrick Cloke 6c2e08ed6f Merge remote-tracking branch 'origin/develop' into clokep/thread-api 2022-08-01 09:05:23 -04:00
Patrick Cloke f6267b1abe Add an enum. 2022-08-01 09:05:04 -04:00
Patrick Cloke d510975b2f Test ignored users. 2022-07-27 12:39:43 -04:00
Patrick Cloke 8dcdb4efa9 Allow limiting threads by participation. 2022-07-27 12:39:17 -04:00
Patrick Cloke e9a649ec31 Add an API for listing threads in a room. 2022-07-27 11:46:22 -04:00
32 changed files with 1304 additions and 182 deletions
+1
View File
@@ -0,0 +1 @@
Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).
+1
View File
@@ -0,0 +1 @@
Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).
+1
View File
@@ -0,0 +1 @@
Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
+7
View File
@@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = {
"contains_url": {"type": "boolean"}, "contains_url": {"type": "boolean"},
"lazy_load_members": {"type": "boolean"}, "lazy_load_members": {"type": "boolean"},
"include_redundant_members": {"type": "boolean"}, "include_redundant_members": {"type": "boolean"},
"unread_thread_notifications": {"type": "boolean"},
# Include or exclude events with the provided labels. # Include or exclude events with the provided labels.
# cf https://github.com/matrix-org/matrix-doc/pull/2326 # cf https://github.com/matrix-org/matrix-doc/pull/2326
"org.matrix.labels": {"type": "array", "items": {"type": "string"}}, "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
@@ -240,6 +241,9 @@ class FilterCollection:
def include_redundant_members(self) -> bool: def include_redundant_members(self) -> bool:
return self._room_state_filter.include_redundant_members return self._room_state_filter.include_redundant_members
def unread_thread_notifications(self) -> bool:
return self._room_timeline_filter.unread_thread_notifications
async def filter_presence( async def filter_presence(
self, events: Iterable[UserPresenceState] self, events: Iterable[UserPresenceState]
) -> List[UserPresenceState]: ) -> List[UserPresenceState]:
@@ -304,6 +308,9 @@ class Filter:
self.include_redundant_members = filter_json.get( self.include_redundant_members = filter_json.get(
"include_redundant_members", False "include_redundant_members", False
) )
self.unread_thread_notifications = filter_json.get(
"unread_thread_notifications", False
)
self.types = filter_json.get("types", None) self.types = filter_json.get("types", None)
self.not_types = filter_json.get("not_types", []) self.not_types = filter_json.get("not_types", [])
+7
View File
@@ -82,11 +82,18 @@ class ExperimentalConfig(Config):
# MSC3786 (Add a default push rule to ignore m.room.server_acl events) # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False) self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)
# MSC3771: Thread read receipts
self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
# MSC3772: A push rule for mutual relations. # MSC3772: A push rule for mutual relations.
self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False) self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
# MSC3773: Thread notifications
self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)
# MSC3715: dir param on /relations. # MSC3715: dir param on /relations.
self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False) self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
# MSC3848: Introduce errcodes for specific event sending failures # MSC3848: Introduce errcodes for specific event sending failures
self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False) self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
# MSC3856: Threads list API
self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
+9 -1
View File
@@ -97,6 +97,7 @@ class ReceiptsHandler:
receipt_type=receipt_type, receipt_type=receipt_type,
user_id=user_id, user_id=user_id,
event_ids=user_values["event_ids"], event_ids=user_values["event_ids"],
thread_id=None, # TODO
data=user_values.get("data", {}), data=user_values.get("data", {}),
) )
) )
@@ -114,6 +115,7 @@ class ReceiptsHandler:
receipt.receipt_type, receipt.receipt_type,
receipt.user_id, receipt.user_id,
receipt.event_ids, receipt.event_ids,
receipt.thread_id,
receipt.data, receipt.data,
) )
@@ -146,7 +148,12 @@ class ReceiptsHandler:
return True return True
async def received_client_receipt( async def received_client_receipt(
self, room_id: str, receipt_type: str, user_id: str, event_id: str self,
room_id: str,
receipt_type: str,
user_id: str,
event_id: str,
thread_id: Optional[str],
) -> None: ) -> None:
"""Called when a client tells us a local user has read up to the given """Called when a client tells us a local user has read up to the given
event_id in the room. event_id in the room.
@@ -156,6 +163,7 @@ class ReceiptsHandler:
receipt_type=receipt_type, receipt_type=receipt_type,
user_id=user_id, user_id=user_id,
event_ids=[event_id], event_ids=[event_id],
thread_id=thread_id,
data={"ts": int(self.clock.time_msec())}, data={"ts": int(self.clock.time_msec())},
) )
+89
View File
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import enum
import logging import logging
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
@@ -31,6 +32,13 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class ThreadsListInclude(str, enum.Enum):
"""Valid values for the 'include' flag of /threads."""
all = "all"
participated = "participated"
@attr.s(slots=True, frozen=True, auto_attribs=True) @attr.s(slots=True, frozen=True, auto_attribs=True)
class _ThreadAggregation: class _ThreadAggregation:
# The latest event in the thread. # The latest event in the thread.
@@ -482,3 +490,84 @@ class RelationsHandler:
results.setdefault(event_id, BundledAggregations()).replace = edit results.setdefault(event_id, BundledAggregations()).replace = edit
return results return results
async def get_threads(
self,
requester: Requester,
room_id: str,
include: ThreadsListInclude,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.
Args:
requester: The user requesting the relations.
room_id: The room the event belongs to.
include: One of "all" or "participated" to indicate which threads should
be returned.
limit: Only fetch the most recent `limit` events.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
The pagination chunk.
"""
user_id = requester.user.to_string()
# TODO Properly handle a user leaving a room.
(_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
room_id, user_id, allow_departed_users=True
)
# Note that ignored users are not passed into get_relations_for_event
# below. Ignored users are handled in filter_events_for_client (and by
# not passing them in here we should get a better cache hit rate).
thread_roots, next_token = await self._main_store.get_threads(
room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
)
events = await self._main_store.get_events_as_list(thread_roots)
if include == ThreadsListInclude.participated:
# Pre-seed thread participation with whether the requester sent the event.
participated = {event.event_id: event.sender == user_id for event in events}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[eid for eid, p in participated.items() if not p],
user_id,
)
)
# Limit the returned threads to those the user has participated in.
events = [event for event in events if participated[event.event_id]]
events = await filter_events_for_client(
self._storage_controllers,
user_id,
events,
is_peeking=(member_event_id is None),
)
now = self._clock.time_msec()
aggregations = await self.get_bundled_aggregations(
events, requester.user.to_string()
)
serialized_events = self._event_serializer.serialize_events(
events, now, bundle_aggregations=aggregations
)
return_value: JsonDict = {"chunk": serialized_events}
if next_token:
return_value["next_batch"] = await next_token.to_string(self._main_store)
if from_token:
return_value["prev_batch"] = await from_token.to_string(self._main_store)
return return_value
+37 -5
View File
@@ -115,6 +115,7 @@ class JoinedSyncResult:
ephemeral: List[JsonDict] ephemeral: List[JsonDict]
account_data: List[JsonDict] account_data: List[JsonDict]
unread_notifications: JsonDict unread_notifications: JsonDict
unread_thread_notifications: JsonDict
summary: Optional[JsonDict] summary: Optional[JsonDict]
unread_count: int unread_count: int
@@ -265,6 +266,8 @@ class SyncHandler:
self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync
self._msc3773_enabled = hs.config.experimental.msc3773_enabled
async def wait_for_sync_for_user( async def wait_for_sync_for_user(
self, self,
requester: Requester, requester: Requester,
@@ -1053,7 +1056,7 @@ class SyncHandler:
async def unread_notifs_for_room_id( async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig self, room_id: str, sync_config: SyncConfig
) -> NotifCounts: ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
with Measure(self.clock, "unread_notifs_for_room_id"): with Measure(self.clock, "unread_notifs_for_room_id"):
return await self.store.get_unread_event_push_actions_by_room_for_user( return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -2115,17 +2118,46 @@ class SyncHandler:
ephemeral=ephemeral, ephemeral=ephemeral,
account_data=account_data_events, account_data=account_data_events,
unread_notifications=unread_notifications, unread_notifications=unread_notifications,
unread_thread_notifications={},
summary=summary, summary=summary,
unread_count=0, unread_count=0,
) )
if room_sync or always_include: if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config) notifs, thread_notifs = await self.unread_notifs_for_room_id(
room_id, sync_config
)
unread_notifications["notification_count"] = notifs.notify_count # Notifications for the main timeline.
unread_notifications["highlight_count"] = notifs.highlight_count notify_count = notifs.notify_count
highlight_count = notifs.highlight_count
unread_count = notifs.unread_count
room_sync.unread_count = notifs.unread_count # Check the sync configuration.
if (
self._msc3773_enabled
and sync_config.filter_collection.unread_thread_notifications()
):
# And add info for each thread.
room_sync.unread_thread_notifications = {
thread_id: {
"notification_count": tnotifs.notify_count,
"highlight_count": tnotifs.highlight_count,
}
for thread_id, tnotifs in thread_notifs.items()
if thread_id is not None
}
else:
# Combine the unread counts for all threads and main timeline.
for tnotifs in thread_notifs.values():
notify_count += tnotifs.notify_count
highlight_count += tnotifs.highlight_count
unread_count += tnotifs.unread_count
unread_notifications["notification_count"] = notify_count
unread_notifications["highlight_count"] = highlight_count
room_sync.unread_count = unread_count
sync_result_builder.joined.append(room_sync) sync_result_builder.joined.append(room_sync)
+16 -14
View File
@@ -186,7 +186,7 @@ class BulkPushRuleEvaluator:
return pl_event.content if pl_event else {}, sender_level return pl_event.content if pl_event else {}, sender_level
async def _get_mutual_relations( async def _get_mutual_relations(
self, event: EventBase, rules: Iterable[Dict[str, Any]] self, parent_id: str, rules: Iterable[Dict[str, Any]]
) -> Dict[str, Set[Tuple[str, str]]]: ) -> Dict[str, Set[Tuple[str, str]]]:
""" """
Fetch event metadata for events which related to the same event as the given event. Fetch event metadata for events which related to the same event as the given event.
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
If the given event has no relation information, returns an empty dictionary. If the given event has no relation information, returns an empty dictionary.
Args: Args:
event_id: The event ID which is targeted by relations. parent_id: The event ID which is targeted by relations.
rules: The push rules which will be processed for this event. rules: The push rules which will be processed for this event.
Returns: Returns:
@@ -208,12 +208,6 @@ class BulkPushRuleEvaluator:
if not self._relations_match_enabled: if not self._relations_match_enabled:
return {} return {}
# If the event does not have a relation, then cannot have any mutual
# relations.
relation = relation_from_event(event)
if not relation:
return {}
# Pre-filter to figure out which relation types are interesting. # Pre-filter to figure out which relation types are interesting.
rel_types = set() rel_types = set()
for rule in rules: for rule in rules:
@@ -235,9 +229,7 @@ class BulkPushRuleEvaluator:
return {} return {}
# If any valid rules were found, fetch the mutual relations. # If any valid rules were found, fetch the mutual relations.
return await self.store.get_mutual_event_relations( return await self.store.get_mutual_event_relations(parent_id, rel_types)
relation.parent_id, rel_types
)
@measure_func("action_for_event_by_user") @measure_func("action_for_event_by_user")
async def action_for_event_by_user( async def action_for_event_by_user(
@@ -265,9 +257,18 @@ class BulkPushRuleEvaluator:
sender_power_level, sender_power_level,
) = await self._get_power_levels_and_sender_level(event, context) ) = await self._get_power_levels_and_sender_level(event, context)
relations = await self._get_mutual_relations( relation = relation_from_event(event)
event, itertools.chain(*rules_by_user.values()) # If the event does not have a relation, then cannot have any mutual
) # relations or thread ID.
relations = {}
thread_id = None
if relation:
relations = await self._get_mutual_relations(
relation.parent_id, itertools.chain(*rules_by_user.values())
)
# XXX Does this need to point to a valid parent ID or anything?
if relation.rel_type == RelationTypes.THREAD:
thread_id = relation.parent_id
evaluator = PushRuleEvaluatorForEvent( evaluator = PushRuleEvaluatorForEvent(
event, event,
@@ -338,6 +339,7 @@ class BulkPushRuleEvaluator:
event.event_id, event.event_id,
actions_by_user, actions_by_user,
count_as_unread, count_as_unread,
thread_id,
) )
+8 -3
View File
@@ -26,13 +26,18 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge = len(invites) badge = len(invites)
for room_id in joins: for room_id in joins:
notifs = await ( notifs, thread_notifs = await (
store.get_unread_event_push_actions_by_room_for_user( store.get_unread_event_push_actions_by_room_for_user(
room_id, room_id,
user_id, user_id,
) )
) )
if notifs.notify_count == 0: # Combine the counts from all the threads.
notify_count = notifs.notify_count + sum(
n.notify_count for n in thread_notifs.values()
)
if notify_count == 0:
continue continue
if group_by_room: if group_by_room:
@@ -40,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
badge += 1 badge += 1
else: else:
# increment the badge count by the number of unread messages in the room # increment the badge count by the number of unread messages in the room
badge += notifs.notify_count badge += notify_count
return badge return badge
+2 -1
View File
@@ -423,7 +423,8 @@ class FederationSenderHandler:
receipt.receipt_type, receipt.receipt_type,
receipt.user_id, receipt.user_id,
[receipt.event_id], [receipt.event_id],
receipt.data, thread_id=receipt.thread_id,
data=receipt.data,
) )
await self.federation_sender.send_read_receipt(receipt_info) await self.federation_sender.send_read_receipt(receipt_info)
+1
View File
@@ -361,6 +361,7 @@ class ReceiptsStream(Stream):
receipt_type: str receipt_type: str
user_id: str user_id: str
event_id: str event_id: str
thread_id: Optional[str]
data: dict data: dict
NAME = "receipts" NAME = "receipts"
+1
View File
@@ -81,6 +81,7 @@ class ReadMarkerRestServlet(RestServlet):
receipt_type, receipt_type,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
event_id=event_id, event_id=event_id,
thread_id=None, # TODO
) )
return 200, {} return 200, {}
+16 -4
View File
@@ -13,10 +13,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
from typing import TYPE_CHECKING, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from synapse.api.constants import ReceiptTypes from synapse.api.constants import ReceiptTypes
from synapse.api.errors import SynapseError from synapse.api.errors import SynapseError, UnrecognizedRequestError
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_json_object_from_request from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@@ -34,7 +34,8 @@ class ReceiptRestServlet(RestServlet):
PATTERNS = client_patterns( PATTERNS = client_patterns(
"/rooms/(?P<room_id>[^/]*)" "/rooms/(?P<room_id>[^/]*)"
"/receipt/(?P<receipt_type>[^/]*)" "/receipt/(?P<receipt_type>[^/]*)"
"/(?P<event_id>[^/]*)$" "/(?P<event_id>[^/]*)"
"(/(?P<thread_id>[^/]*))?$"
) )
def __init__(self, hs: "HomeServer"): def __init__(self, hs: "HomeServer"):
@@ -50,8 +51,15 @@ class ReceiptRestServlet(RestServlet):
(ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ) (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
) )
self._msc3771_enabled = hs.config.experimental.msc3771_enabled
async def on_POST( async def on_POST(
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str self,
request: SynapseRequest,
room_id: str,
receipt_type: str,
event_id: str,
thread_id: Optional[str],
) -> Tuple[int, JsonDict]: ) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request) requester = await self.auth.get_user_by_req(request)
@@ -61,6 +69,9 @@ class ReceiptRestServlet(RestServlet):
f"Receipt type must be {', '.join(self._known_receipt_types)}", f"Receipt type must be {', '.join(self._known_receipt_types)}",
) )
if thread_id and not self._msc3771_enabled:
raise UnrecognizedRequestError()
parse_json_object_from_request(request, allow_empty_body=False) parse_json_object_from_request(request, allow_empty_body=False)
await self.presence_handler.bump_presence_active_time(requester.user) await self.presence_handler.bump_presence_active_time(requester.user)
@@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet):
receipt_type, receipt_type,
user_id=requester.user.to_string(), user_id=requester.user.to_string(),
event_id=event_id, event_id=event_id,
thread_id=thread_id,
) )
return 200, {} return 200, {}
+52
View File
@@ -13,8 +13,10 @@
# limitations under the License. # limitations under the License.
import logging import logging
import re
from typing import TYPE_CHECKING, Optional, Tuple from typing import TYPE_CHECKING, Optional, Tuple
from synapse.handlers.relations import ThreadsListInclude
from synapse.http.server import HttpServer from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest from synapse.http.site import SynapseRequest
@@ -91,5 +93,55 @@ class RelationPaginationServlet(RestServlet):
return 200, result return 200, result
class ThreadsServlet(RestServlet):
PATTERNS = (
re.compile(
"^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
),
)
def __init__(self, hs: "HomeServer"):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._relations_handler = hs.get_relations_handler()
async def on_GET(
self, request: SynapseRequest, room_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
limit = parse_integer(request, "limit", default=5)
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
include = parse_string(
request,
"include",
default=ThreadsListInclude.all.value,
allowed_values=[v.value for v in ThreadsListInclude],
)
# Return the relations
from_token = None
if from_token_str:
from_token = await StreamToken.from_string(self.store, from_token_str)
to_token = None
if to_token_str:
to_token = await StreamToken.from_string(self.store, to_token_str)
result = await self._relations_handler.get_threads(
requester=requester,
room_id=room_id,
include=ThreadsListInclude(include),
limit=limit,
from_token=from_token,
to_token=to_token,
)
return 200, result
def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
RelationPaginationServlet(hs).register(http_server) RelationPaginationServlet(hs).register(http_server)
if hs.config.experimental.msc3856_enabled:
ThreadsServlet(hs).register(http_server)
+2
View File
@@ -509,6 +509,8 @@ class SyncRestServlet(RestServlet):
ephemeral_events = room.ephemeral ephemeral_events = room.ephemeral
result["ephemeral"] = {"events": ephemeral_events} result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications result["unread_notifications"] = room.unread_notifications
if room.unread_thread_notifications:
result["unread_thread_notifications"] = room.unread_thread_notifications
result["summary"] = room.summary result["summary"] = room.summary
if self._msc2654_enabled: if self._msc2654_enabled:
result["org.matrix.msc2654.unread_count"] = room.unread_count result["org.matrix.msc2654.unread_count"] = room.unread_count
@@ -78,7 +78,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast
import attr import attr
from synapse.api.constants import ReceiptTypes
from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.storage._base import SQLBaseStore, db_to_json from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import ( from synapse.storage.database import (
@@ -206,10 +205,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self._rotate_count = 10000 self._rotate_count = 10000
self._doing_notif_rotation = False self._doing_notif_rotation = False
if hs.config.worker.run_background_tasks: # XXX Do not rotate summaries for now, they're broken.
self._rotate_notif_loop = self._clock.looping_call( # if hs.config.worker.run_background_tasks:
self._rotate_notifs, 30 * 1000 # self._rotate_notif_loop = self._clock.looping_call(
) # self._rotate_notifs, 30 * 1000
# )
self.db_pool.updates.register_background_index_update( self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index", "event_push_summary_unique_index",
@@ -220,12 +220,21 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
replaces_index="event_push_summary_user_rm", replaces_index="event_push_summary_user_rm",
) )
@cached(tree=True, max_entries=5000) self.db_pool.updates.register_background_index_update(
"event_push_summary_unique_index2",
index_name="event_push_summary_unique_index2",
table="event_push_summary",
columns=["user_id", "room_id", "thread_id"],
unique=True,
replaces_index="event_push_summary_unique_index",
)
@cached(tree=True, max_entries=5000, iterable=True)
async def get_unread_event_push_actions_by_room_for_user( async def get_unread_event_push_actions_by_room_for_user(
self, self,
room_id: str, room_id: str,
user_id: str, user_id: str,
) -> NotifCounts: ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
"""Get the notification count, the highlight count and the unread message count """Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt. for a given user in a given room after the given read receipt.
@@ -254,31 +263,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
txn: LoggingTransaction, txn: LoggingTransaction,
room_id: str, room_id: str,
user_id: str, user_id: str,
) -> NotifCounts: ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
result = self.get_last_receipt_for_user_txn( # Either last_read_event_id is None, or it's an event we don't have (e.g.
txn, # because it's been purged), in which case retrieve the stream ordering for
user_id, # the latest membership event from this user in this room (which we assume is
room_id, # a join).
receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE), event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
) )
stream_ordering = None stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
if result:
_, stream_ordering = result
if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn( return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering txn, room_id, user_id, stream_ordering
@@ -286,12 +283,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
def _get_unread_counts_by_pos_txn( def _get_unread_counts_by_pos_txn(
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
) -> NotifCounts: ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
"""Get the number of unread messages for a user/room that have happened """Get the number of unread messages for a user/room that have happened
since the given stream ordering. since the given stream ordering.
Returns:
A tuple of:
The unread messages for the main timeline
A dictionary of thread ID to unread messages for that thread.
Only contains threads with unread messages.
""" """
counts = NotifCounts() counts = NotifCounts()
thread_counts = {}
# First we pull the counts from the summary table. # First we pull the counts from the summary table.
# #
@@ -306,51 +311,92 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# date (as the row was written by an older version of Synapse that # date (as the row was written by an older version of Synapse that
# updated `event_push_summary` synchronously when persisting a new read # updated `event_push_summary` synchronously when persisting a new read
# receipt). # receipt).
txn.execute(
"""
SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
FROM event_push_summary
WHERE room_id = ? AND user_id = ?
AND (
(last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
OR last_receipt_stream_ordering = ?
)
""",
(room_id, user_id, stream_ordering, stream_ordering),
)
row = txn.fetchone()
summary_stream_ordering = 0 # XXX event_push_summary is not currently filled in. broken.
if row: # txn.execute(
summary_stream_ordering = row[0] # """
counts.notify_count += row[1] # SELECT notif_count, COALESCE(unread_count, 0), thread_id, MAX(events.stream_ordering)
counts.unread_count += row[2] # FROM event_push_summary
# LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
# LEFT JOIN events ON (
# events.room_id = receipts_linearized.room_id AND
# events.event_id = receipts_linearized.event_id
# )
# WHERE event_push_summary.room_id = ? AND user_id = ?
# AND (
# (
# last_receipt_stream_ordering IS NULL
# AND event_push_summary.stream_ordering > COALESCE(events.stream_ordering, ?)
# )
# OR last_receipt_stream_ordering = COALESCE(events.stream_ordering, ?)
# )
# AND (receipt_type = 'm.read' OR receipt_type = 'org.matrix.msc2285.read.private')
# """,
# (room_id, user_id, stream_ordering, stream_ordering),
# )
# for notif_count, unread_count, thread_id, _ in txn:
# # XXX Why are these returned? Related to MAX(...) aggregation.
# if notif_count is None:
# continue
#
# if not thread_id:
# counts = NotifCounts(
# notify_count=notif_count, unread_count=unread_count
# )
# # TODO Delete zeroed out threads completely from the database.
# elif notif_count or unread_count:
# thread_counts[thread_id] = NotifCounts(
# notify_count=notif_count, unread_count=unread_count
# )
# Next we need to count highlights, which aren't summarised # Next we need to count highlights, which aren't summarised
sql = """ sql = """
SELECT COUNT(*) FROM event_push_actions SELECT COUNT(*), thread_id, MAX(events.stream_ordering) FROM event_push_actions
LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
LEFT JOIN events ON (
events.room_id = receipts_linearized.room_id AND
events.event_id = receipts_linearized.event_id
)
WHERE user_id = ? WHERE user_id = ?
AND room_id = ? AND event_push_actions.room_id = ?
AND stream_ordering > ? AND event_push_actions.stream_ordering > COALESCE(events.stream_ordering, ?)
AND highlight = 1 AND highlight = 1
GROUP BY thread_id
""" """
txn.execute(sql, (user_id, room_id, stream_ordering)) txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone() for highlight_count, thread_id, _ in txn:
if row: if not thread_id:
counts.highlight_count += row[0] counts.highlight_count += highlight_count
elif highlight_count:
if thread_id in thread_counts:
thread_counts[thread_id].highlight_count += highlight_count
else:
thread_counts[thread_id] = NotifCounts(
notify_count=0, unread_count=0, highlight_count=highlight_count
)
# Finally we need to count push actions that aren't included in the # Finally we need to count push actions that aren't included in the
# summary returned above, e.g. recent events that haven't been # summary returned above, e.g. recent events that haven't been
# summarised yet, or the summary is empty due to a recent read receipt. # summarised yet, or the summary is empty due to a recent read receipt.
stream_ordering = max(stream_ordering, summary_stream_ordering) unread_counts = self._get_notif_unread_count_for_user_room(
notify_count, unread_count = self._get_notif_unread_count_for_user_room(
txn, room_id, user_id, stream_ordering txn, room_id, user_id, stream_ordering
) )
counts.notify_count += notify_count for notif_count, unread_count, thread_id in unread_counts:
counts.unread_count += unread_count if not thread_id:
counts.notify_count += notif_count
counts.unread_count += unread_count
elif thread_id in thread_counts:
thread_counts[thread_id].notify_count += notif_count
thread_counts[thread_id].unread_count += unread_count
else:
thread_counts[thread_id] = NotifCounts(
notify_count=notif_count,
unread_count=unread_count,
highlight_count=0,
)
return counts return counts, thread_counts
def _get_notif_unread_count_for_user_room( def _get_notif_unread_count_for_user_room(
self, self,
@@ -359,7 +405,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id: str, user_id: str,
stream_ordering: int, stream_ordering: int,
max_stream_ordering: Optional[int] = None, max_stream_ordering: Optional[int] = None,
) -> Tuple[int, int]: ) -> List[Tuple[int, int, str]]:
"""Returns the notify and unread counts from `event_push_actions` for """Returns the notify and unread counts from `event_push_actions` for
the given user/room in the given range. the given user/room in the given range.
@@ -380,8 +426,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# If there have been no events in the room since the stream ordering, # If there have been no events in the room since the stream ordering,
# there can't be any push actions either. # there can't be any push actions either.
if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering): #
return 0, 0 # XXX
# if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
# return []
clause = "" clause = ""
args = [user_id, room_id, stream_ordering] args = [user_id, room_id, stream_ordering]
@@ -389,29 +437,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
clause = "AND ea.stream_ordering <= ?" clause = "AND ea.stream_ordering <= ?"
args.append(max_stream_ordering) args.append(max_stream_ordering)
# If the max stream ordering is less than the min stream ordering,
# then obviously there are zero push actions in that range.
if max_stream_ordering <= stream_ordering:
return 0, 0
sql = f""" sql = f"""
SELECT SELECT
COUNT(CASE WHEN notif = 1 THEN 1 END), COUNT(CASE WHEN notif = 1 THEN 1 END),
COUNT(CASE WHEN unread = 1 THEN 1 END) COUNT(CASE WHEN unread = 1 THEN 1 END),
FROM event_push_actions ea thread_id,
WHERE user_id = ? MAX(events.stream_ordering)
AND room_id = ? FROM event_push_actions ea
AND ea.stream_ordering > ? LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
LEFT JOIN events ON (
events.room_id = receipts_linearized.room_id AND
events.event_id = receipts_linearized.event_id
)
WHERE user_id = ?
AND ea.room_id = ?
AND ea.stream_ordering > COALESCE(events.stream_ordering, ?)
{clause} {clause}
GROUP BY thread_id
""" """
txn.execute(sql, args) txn.execute(sql, args)
row = txn.fetchone() # The max stream ordering is simply there to select the latest receipt,
# it doesn't need to be returned.
if row: return [cast(Tuple[int, int, str], row[:3]) for row in txn.fetchall()]
return cast(Tuple[int, int], row)
return 0, 0
async def get_push_action_users_in_range( async def get_push_action_users_in_range(
self, min_stream_ordering: int, max_stream_ordering: int self, min_stream_ordering: int, max_stream_ordering: int
@@ -680,6 +728,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
event_id: str, event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]], user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool, count_as_unread: bool,
thread_id: Optional[str],
) -> None: ) -> None:
"""Add the push actions for the event to the push action staging area. """Add the push actions for the event to the push action staging area.
@@ -688,6 +737,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
user_id_actions: A mapping of user_id to list of push actions, where user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict. an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts. count_as_unread: Whether this event should increment unread counts.
thread_id: The thread this event is parent of, if applicable.
""" """
if not user_id_actions: if not user_id_actions:
return return
@@ -696,7 +746,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# can be used to insert into the `event_push_actions_staging` table. # can be used to insert into the `event_push_actions_staging` table.
def _gen_entry( def _gen_entry(
user_id: str, actions: List[Union[dict, str]] user_id: str, actions: List[Union[dict, str]]
) -> Tuple[str, str, str, int, int, int]: ) -> Tuple[str, str, str, int, int, int, Optional[str]]:
is_highlight = 1 if _action_has_highlight(actions) else 0 is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0 notif = 1 if "notify" in actions else 0
return ( return (
@@ -706,6 +756,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
notif, # notif column notif, # notif column
is_highlight, # highlight column is_highlight, # highlight column
int(count_as_unread), # unread column int(count_as_unread), # unread column
thread_id or "", # thread_id column
) )
def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None: def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -714,8 +765,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
sql = """ sql = """
INSERT INTO event_push_actions_staging INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight, unread) (event_id, user_id, actions, notif, highlight, unread, thread_id)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?, ?, ?)
""" """
txn.execute_batch( txn.execute_batch(
@@ -955,7 +1006,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) )
sql = """ sql = """
SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
FROM receipts_linearized AS r FROM receipts_linearized AS r
INNER JOIN events AS e USING (event_id) INNER JOIN events AS e USING (event_id)
WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ? WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -978,41 +1029,83 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) )
rows = txn.fetchall() rows = txn.fetchall()
# For each new read receipt we delete push actions from before it and # Group the rows by room ID / user ID.
# recalculate the summary. rows_by_room_user: Dict[Tuple[str, str], List[Tuple[str, str, int]]] = {}
for _, room_id, user_id, stream_ordering in rows: for stream_id, room_id, user_id, thread_id, stream_ordering in rows:
# Only handle our own read receipts. # Only handle our own read receipts.
if not self.hs.is_mine_id(user_id): if not self.hs.is_mine_id(user_id):
continue continue
txn.execute( rows_by_room_user.setdefault((room_id, user_id), []).append(
""" (stream_id, thread_id, stream_ordering)
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND stream_ordering <= ?
AND highlight = 0
""",
(room_id, user_id, stream_ordering),
) )
# For each new read receipt we delete push actions from before it and
# recalculate the summary.
for (room_id, user_id), room_rows in rows_by_room_user.items():
for _, thread_id, stream_ordering in room_rows:
txn.execute(
"""
DELETE FROM event_push_actions
WHERE room_id = ?
AND user_id = ?
AND thread_id = ?
AND stream_ordering <= ?
AND highlight = 0
""",
(room_id, user_id, thread_id, stream_ordering),
)
# Fetch the notification counts between the stream ordering of the # Fetch the notification counts between the stream ordering of the
# latest receipt and what was previously summarised. # latest receipt and what was previously summarised.
notif_count, unread_count = self._get_notif_unread_count_for_user_room( earliest_stream_ordering = min(r[2] for r in room_rows)
txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering unread_counts = self._get_notif_unread_count_for_user_room(
txn,
room_id,
user_id,
earliest_stream_ordering,
old_rotate_stream_ordering,
) )
# Replace the previous summary with the new counts. # Updated threads get their notification count and unread count updated.
self.db_pool.simple_upsert_txn( self.db_pool.simple_upsert_many_txn(
txn, txn,
table="event_push_summary", table="event_push_summary",
keyvalues={"room_id": room_id, "user_id": user_id}, key_names=("room_id", "user_id", "thread_id"),
values={ key_values=[(room_id, user_id, row[2]) for row in unread_counts],
"notif_count": notif_count, value_names=(
"unread_count": unread_count, "notif_count",
"stream_ordering": old_rotate_stream_ordering, "unread_count",
"last_receipt_stream_ordering": stream_ordering, "stream_ordering",
}, "last_receipt_stream_ordering",
),
value_values=[
# XXX Stream ordering.
(
row[0],
row[1],
old_rotate_stream_ordering,
earliest_stream_ordering,
)
for row in unread_counts
],
)
# XXX WTF?
# Other threads should be marked as reset at the old stream ordering.
txn.execute(
"""
UPDATE event_push_summary SET notif_count = 0, unread_count = 0, stream_ordering = ?, last_receipt_stream_ordering = ?
WHERE user_id = ? AND room_id = ? AND
stream_ordering <= ?
""",
(
old_rotate_stream_ordering,
min_receipts_stream_id,
user_id,
room_id,
old_rotate_stream_ordering,
),
) )
# We always update `event_push_summary_last_receipt_stream_id` to # We always update `event_push_summary_last_receipt_stream_id` to
@@ -1102,23 +1195,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# Calculate the new counts that should be upserted into event_push_summary # Calculate the new counts that should be upserted into event_push_summary
sql = """ sql = """
SELECT user_id, room_id, SELECT user_id, room_id, thread_id,
coalesce(old.%s, 0) + upd.cnt, coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering upd.stream_ordering
FROM ( FROM (
SELECT user_id, room_id, count(*) as cnt, SELECT user_id, room_id, thread_id, count(*) as cnt,
max(ea.stream_ordering) as stream_ordering max(ea.stream_ordering) as stream_ordering
FROM event_push_actions AS ea FROM event_push_actions AS ea
LEFT JOIN event_push_summary AS old USING (user_id, room_id) LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ? WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
AND ( AND (
old.last_receipt_stream_ordering IS NULL old.last_receipt_stream_ordering IS NULL
OR old.last_receipt_stream_ordering < ea.stream_ordering OR old.last_receipt_stream_ordering < ea.stream_ordering
) )
AND %s = 1 AND %s = 1
GROUP BY user_id, room_id GROUP BY user_id, room_id, thread_id
) AS upd ) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id) LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
""" """
# First get the count of unread messages. # First get the count of unread messages.
@@ -1132,11 +1225,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
# object because we might not have the same amount of rows in each of them. To do # object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to # this, we use a dict indexed on the user ID and room ID to make it easier to
# populate. # populate.
summaries: Dict[Tuple[str, str], _EventPushSummary] = {} summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
for row in txn: for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary( summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=row[2], unread_count=row[3],
stream_ordering=row[3], stream_ordering=row[4],
notif_count=0, notif_count=0,
) )
@@ -1147,17 +1240,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
) )
for row in txn: for row in txn:
if (row[0], row[1]) in summaries: if (row[0], row[1], row[2]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2] summaries[(row[0], row[1], row[2])].notif_count = row[3]
else: else:
# Because the rules on notifying are different than the rules on marking # Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't # a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room) # marked unread, so we might not have a summary for this (user, room)
# tuple to complete. # tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary( summaries[(row[0], row[1], row[2])] = _EventPushSummary(
unread_count=0, unread_count=0,
stream_ordering=row[3], stream_ordering=row[4],
notif_count=row[2], notif_count=row[3],
) )
logger.info("Rotating notifications, handling %d rows", len(summaries)) logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -1165,8 +1258,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
self.db_pool.simple_upsert_many_txn( self.db_pool.simple_upsert_many_txn(
txn, txn,
table="event_push_summary", table="event_push_summary",
key_names=("user_id", "room_id"), key_names=("user_id", "room_id", "thread_id"),
key_values=[(user_id, room_id) for user_id, room_id in summaries], key_values=[
(user_id, room_id, thread_id)
for user_id, room_id, thread_id in summaries
],
value_names=("notif_count", "unread_count", "stream_ordering"), value_names=("notif_count", "unread_count", "stream_ordering"),
value_values=[ value_values=[
( (
+7 -4
View File
@@ -1594,7 +1594,7 @@ class PersistEventsStore:
) )
# Remove from relations table. # Remove from relations table.
self._handle_redact_relations(txn, event.redacts) self._handle_redact_relations(txn, event.room_id, event.redacts)
# Update the event_forward_extremities, event_backward_extremities and # Update the event_forward_extremities, event_backward_extremities and
# event_edges tables. # event_edges tables.
@@ -1909,6 +1909,7 @@ class PersistEventsStore:
self.store.get_thread_participated.invalidate, self.store.get_thread_participated.invalidate,
(relation.parent_id, event.sender), (relation.parent_id, event.sender),
) )
txn.call_after(self.store.get_threads.invalidate, (event.room_id,))
def _handle_insertion_event( def _handle_insertion_event(
self, txn: LoggingTransaction, event: EventBase self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ class PersistEventsStore:
txn.execute(sql, (batch_id,)) txn.execute(sql, (batch_id,))
def _handle_redact_relations( def _handle_redact_relations(
self, txn: LoggingTransaction, redacted_event_id: str self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
) -> None: ) -> None:
"""Handles receiving a redaction and checking whether the redacted event """Handles receiving a redaction and checking whether the redacted event
has any relations which must be removed from the database. has any relations which must be removed from the database.
Args: Args:
txn txn
room_id: The room ID of the event that was redacted.
redacted_event_id: The event that was redacted. redacted_event_id: The event that was redacted.
""" """
@@ -2068,6 +2070,7 @@ class PersistEventsStore:
self.store._invalidate_cache_and_stream( self.store._invalidate_cache_and_stream(
txn, self.store.get_thread_participated, (redacted_relates_to,) txn, self.store.get_thread_participated, (redacted_relates_to,)
) )
txn.call_after(self.store.get_threads.invalidate, (room_id,))
self.store._invalidate_cache_and_stream( self.store._invalidate_cache_and_stream(
txn, txn,
self.store.get_mutual_event_relations_for_rel_type, self.store.get_mutual_event_relations_for_rel_type,
@@ -2190,9 +2193,9 @@ class PersistEventsStore:
sql = """ sql = """
INSERT INTO event_push_actions ( INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering, room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight, unread topological_ordering, notif, highlight, unread, thread_id
) )
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
FROM event_push_actions_staging FROM event_push_actions_staging
WHERE event_id = ? WHERE event_id = ?
""" """
+32 -10
View File
@@ -117,6 +117,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"""Get the current max stream ID for receipts stream""" """Get the current max stream ID for receipts stream"""
return self._receipts_id_gen.get_current_token() return self._receipts_id_gen.get_current_token()
# XXX MOVE TO TESTS
async def get_last_receipt_event_id_for_user( async def get_last_receipt_event_id_for_user(
self, user_id: str, room_id: str, receipt_types: Collection[str] self, user_id: str, room_id: str, receipt_types: Collection[str]
) -> Optional[str]: ) -> Optional[str]:
@@ -411,6 +412,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"_get_linearized_receipts_for_rooms", f "_get_linearized_receipts_for_rooms", f
) )
# Map of room ID to a dictionary in the form that sync wants it.
results: JsonDict = {} results: JsonDict = {}
for row in txn_results: for row in txn_results:
# We want a single event per room, since we want to batch the # We want a single event per room, since we want to batch the
@@ -426,6 +428,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type = event_entry.setdefault(row["receipt_type"], {}) receipt_type = event_entry.setdefault(row["receipt_type"], {})
receipt_type[row["user_id"]] = db_to_json(row["data"]) receipt_type[row["user_id"]] = db_to_json(row["data"])
if row["thread_id"]:
receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]
results = { results = {
room_id: [results[room_id]] if room_id in results else [] room_id: [results[room_id]] if room_id in results else []
@@ -522,7 +526,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
async def get_all_updated_receipts( async def get_all_updated_receipts(
self, instance_name: str, last_id: int, current_id: int, limit: int self, instance_name: str, last_id: int, current_id: int, limit: int
) -> Tuple[List[Tuple[int, list]], int, bool]: ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool]:
"""Get updates for receipts replication stream. """Get updates for receipts replication stream.
Args: Args:
@@ -549,9 +553,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
def get_all_updated_receipts_txn( def get_all_updated_receipts_txn(
txn: LoggingTransaction, txn: LoggingTransaction,
) -> Tuple[List[Tuple[int, list]], int, bool]: ) -> Tuple[
List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool
]:
sql = """ sql = """
SELECT stream_id, room_id, receipt_type, user_id, event_id, data SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
FROM receipts_linearized FROM receipts_linearized
WHERE ? < stream_id AND stream_id <= ? WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC ORDER BY stream_id ASC
@@ -560,8 +566,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
txn.execute(sql, (last_id, current_id, limit)) txn.execute(sql, (last_id, current_id, limit))
updates = cast( updates = cast(
List[Tuple[int, list]], List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]],
[(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn], [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
) )
limited = False limited = False
@@ -613,6 +619,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str, receipt_type: str,
user_id: str, user_id: str,
event_id: str, event_id: str,
thread_id: Optional[str],
data: JsonDict, data: JsonDict,
stream_id: int, stream_id: int,
) -> Optional[int]: ) -> Optional[int]:
@@ -636,15 +643,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
stream_ordering = int(res["stream_ordering"]) if res else None stream_ordering = int(res["stream_ordering"]) if res else None
rx_ts = res["received_ts"] if res else 0 rx_ts = res["received_ts"] if res else 0
# Convert None to a blank string.
thread_id = thread_id or ""
# We don't want to clobber receipts for more recent events, so we # We don't want to clobber receipts for more recent events, so we
# have to compare orderings of existing receipts # have to compare orderings of existing receipts
if stream_ordering is not None: if stream_ordering is not None:
sql = ( sql = (
"SELECT stream_ordering, event_id FROM events" "SELECT stream_ordering, event_id FROM events"
" INNER JOIN receipts_linearized AS r USING (event_id, room_id)" " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?" " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND r.thread_id = ?"
) )
txn.execute(sql, (room_id, receipt_type, user_id)) txn.execute(sql, (room_id, receipt_type, user_id, thread_id))
for so, eid in txn: for so, eid in txn:
if int(so) >= stream_ordering: if int(so) >= stream_ordering:
@@ -656,6 +666,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
) )
return None return None
# TODO
txn.call_after( txn.call_after(
self.invalidate_caches_for_receipt, room_id, receipt_type, user_id self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
) )
@@ -671,6 +682,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id, "user_id": user_id,
"thread_id": thread_id,
}, },
values={ values={
"stream_id": stream_id, "stream_id": stream_id,
@@ -678,7 +690,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"data": json_encoder.encode(data), "data": json_encoder.encode(data),
}, },
# receipts_linearized has a unique constraint on # receipts_linearized has a unique constraint on
# (user_id, room_id, receipt_type), so no need to lock # (user_id, room_id, receipt_type, key), so no need to lock
lock=False, lock=False,
) )
@@ -728,6 +740,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str, receipt_type: str,
user_id: str, user_id: str,
event_ids: List[str], event_ids: List[str],
thread_id: Optional[str],
data: dict, data: dict,
) -> Optional[Tuple[int, int]]: ) -> Optional[Tuple[int, int]]:
"""Insert a receipt, either from local client or remote server. """Insert a receipt, either from local client or remote server.
@@ -760,6 +773,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type, receipt_type,
user_id, user_id,
linearized_event_id, linearized_event_id,
thread_id,
data, data,
stream_id=stream_id, stream_id=stream_id,
# Read committed is actually beneficial here because we check for a receipt with # Read committed is actually beneficial here because we check for a receipt with
@@ -774,7 +788,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
now = self._clock.time_msec() now = self._clock.time_msec()
logger.debug( logger.debug(
"RR for event %s in %s (%i ms old)", "Receipt %s for event %s in %s (%i ms old)",
receipt_type,
linearized_event_id, linearized_event_id,
room_id, room_id,
now - event_ts, now - event_ts,
@@ -787,6 +802,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type, receipt_type,
user_id, user_id,
event_ids, event_ids,
thread_id,
data, data,
) )
@@ -801,6 +817,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
receipt_type: str, receipt_type: str,
user_id: str, user_id: str,
event_ids: List[str], event_ids: List[str],
thread_id: Optional[str],
data: JsonDict, data: JsonDict,
) -> None: ) -> None:
assert self._can_write_to_receipts assert self._can_write_to_receipts
@@ -812,6 +829,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
# FIXME: This shouldn't invalidate the whole cache # FIXME: This shouldn't invalidate the whole cache
txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,)) txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))
# Convert None to a blank string.
thread_id = thread_id or ""
self.db_pool.simple_delete_txn( self.db_pool.simple_delete_txn(
txn, txn,
table="receipts_graph", table="receipts_graph",
@@ -819,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"room_id": room_id, "room_id": room_id,
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id, "user_id": user_id,
"thread_id": thread_id,
}, },
) )
self.db_pool.simple_insert_txn( self.db_pool.simple_insert_txn(
@@ -829,6 +850,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
"receipt_type": receipt_type, "receipt_type": receipt_type,
"user_id": user_id, "user_id": user_id,
"event_ids": json_encoder.encode(event_ids), "event_ids": json_encoder.encode(event_ids),
"thread_id": thread_id,
"data": json_encoder.encode(data), "data": json_encoder.encode(data),
}, },
) )
@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
"get_event_relations", _get_event_relations "get_event_relations", _get_event_relations
) )
@cached(tree=True)
async def get_threads(
self,
room_id: str,
limit: int = 5,
from_token: Optional[StreamToken] = None,
to_token: Optional[StreamToken] = None,
) -> Tuple[List[str], Optional[StreamToken]]:
"""Get a list of thread IDs, ordered by topological ordering of their
latest reply.
Args:
room_id: The room the event belongs to.
limit: Only fetch the most recent `limit` threads.
from_token: Fetch rows from the given token, or from the start if None.
to_token: Fetch rows up to the given token, or up to the end if None.
Returns:
A tuple of:
A list of thread root event IDs.
The next stream token, if one exists.
"""
pagination_clause = generate_pagination_where_clause(
direction="b",
column_names=("topological_ordering", "stream_ordering"),
from_token=from_token.room_key.as_historical_tuple()
if from_token
else None,
to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)
if pagination_clause:
pagination_clause = "AND " + pagination_clause
sql = f"""
SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
FROM event_relations
INNER JOIN events USING (event_id)
WHERE
room_id = ? AND
relation_type = '{RelationTypes.THREAD}'
{pagination_clause}
GROUP BY relates_to_id
ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
LIMIT ?
"""
def _get_threads_txn(
txn: LoggingTransaction,
) -> Tuple[List[str], Optional[StreamToken]]:
txn.execute(sql, [room_id, limit + 1])
last_topo_id = None
last_stream_id = None
thread_ids = []
for thread_id, topo_id, stream_id in txn:
thread_ids.append(thread_id)
last_topo_id = topo_id
last_stream_id = stream_id
# If there are more events, generate the next pagination key.
next_token = None
if len(thread_ids) > limit and last_topo_id and last_stream_id:
next_key = RoomStreamToken(last_topo_id, last_stream_id)
if from_token:
next_token = from_token.copy_and_replace(
StreamKeyType.ROOM, next_key
)
else:
next_token = StreamToken(
room_key=next_key,
presence_key=0,
typing_key=0,
receipt_key=0,
account_data_key=0,
push_rules_key=0,
to_device_key=0,
device_list_key=0,
groups_key=0,
)
return thread_ids[:limit], next_token
return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
class RelationsStore(RelationsWorkerStore): class RelationsStore(RelationsWorkerStore):
pass pass
@@ -0,0 +1,25 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Allow multiple receipts per user per room via a nullable thread_id column.
ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
-- Rebuild the unique constraint with the thread_id.
ALTER TABLE receipts_linearized DROP CONSTRAINT IF EXISTS receipts_linearized_uniqueness;
ALTER TABLE receipts_linearized ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);
ALTER TABLE receipts_graph DROP CONSTRAINT IF EXISTS receipts_graph_uniqueness;
ALTER TABLE receipts_graph ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);
@@ -0,0 +1,67 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Allow multiple receipts per user per room via a nullable thread_id column.
--
-- SQLite doesn't support modifying constraints to an existing table, so it must
-- be recreated.
-- Create the new tables.
CREATE TABLE receipts_graph_new (
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_ids TEXT NOT NULL,
thread_id TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL,
CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);
CREATE TABLE receipts_linearized_new (
stream_id BIGINT NOT NULL,
room_id TEXT NOT NULL,
receipt_type TEXT NOT NULL,
user_id TEXT NOT NULL,
event_id TEXT NOT NULL,
thread_id TEXT NOT NULL DEFAULT '',
data TEXT NOT NULL,
CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);
-- Drop the old indexes.
DROP INDEX IF EXISTS receipts_linearized_id;
DROP INDEX IF EXISTS receipts_linearized_room_stream;
DROP INDEX IF EXISTS receipts_linearized_user;
-- Copy the data.
INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data)
SELECT room_id, receipt_type, user_id, event_ids, data
FROM receipts_graph;
INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data)
SELECT stream_id, room_id, receipt_type, user_id, event_id, data
FROM receipts_linearized;
-- Drop the old tables.
DROP TABLE receipts_graph;
DROP TABLE receipts_linearized;
-- Rename the tables.
ALTER TABLE receipts_graph_new RENAME TO receipts_graph;
ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized;
-- Create the indices.
CREATE INDEX receipts_linearized_id ON receipts_linearized( stream_id );
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized( room_id, stream_id );
CREATE INDEX receipts_linearized_user ON receipts_linearized( user_id );
@@ -0,0 +1,27 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
ALTER TABLE event_push_actions_staging
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE event_push_actions
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
ALTER TABLE event_push_summary
ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
-- Update the unique index for `event_push_summary`
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(7003, 'event_push_summary_unique_index2', '{}');
+1
View File
@@ -830,6 +830,7 @@ class ReadReceipt:
receipt_type: str receipt_type: str
user_id: str user_id: str
event_ids: List[str] event_ids: List[str]
thread_id: Optional[str]
data: JsonDict data: JsonDict
+18 -3
View File
@@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} "room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
) )
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
sender = self.hs.get_federation_sender() sender = self.hs.get_federation_sender()
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} "room_id",
"m.read",
"user_id",
["event_id"],
thread_id=None,
data={"ts": 1234},
) )
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
@@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
# send the second RR # send the second RR
receipt = ReadReceipt( receipt = ReadReceipt(
"room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} "room_id",
"m.read",
"user_id",
["other_id"],
thread_id=None,
data={"ts": 1234},
) )
self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
self.pump() self.pump()
+1
View File
@@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
receipt_type="m.read", receipt_type="m.read",
user_id=self.local_user, user_id=self.local_user,
event_ids=[f"$eventid_{i}"], event_ids=[f"$eventid_{i}"],
thread_id=None,
data={}, data={},
) )
) )
@@ -171,14 +171,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if send_receipt: if send_receipt:
self.get_success( self.get_success(
self.master_store.insert_receipt( self.master_store.insert_receipt(
ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
) )
) )
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=0), (NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}),
) )
self.persist( self.persist(
@@ -191,7 +191,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=1), (NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}),
) )
self.persist( self.persist(
@@ -206,7 +206,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check( self.check(
"get_unread_event_push_actions_by_room_for_user", "get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2], [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=1, unread_count=0, notify_count=2), (NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}),
) )
def test_get_rooms_for_user_with_stream_ordering(self): def test_get_rooms_for_user_with_stream_ordering(self):
@@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
event.event_id, event.event_id,
{user_id: actions for user_id, actions in push_actions}, {user_id: actions for user_id, actions in push_actions},
False, False,
None,
) )
) )
return event, context return event, context
+12 -2
View File
@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
# tell the master to send a new receipt # tell the master to send a new receipt
self.get_success( self.get_success(
self.hs.get_datastores().main.insert_receipt( self.hs.get_datastores().main.insert_receipt(
"!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} "!room:blue",
"m.read",
USER_ID,
["$event:blue"],
thread_id=None,
data={"a": 1},
) )
) )
self.replicate() self.replicate()
@@ -57,7 +62,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.get_success( self.get_success(
self.hs.get_datastores().main.insert_receipt( self.hs.get_datastores().main.insert_receipt(
"!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} "!room2:blue",
"m.read",
USER_ID,
["$event2:foo"],
thread_id=None,
data={"a": 2},
) )
) )
self.replicate() self.replicate()
+151
View File
@@ -1679,3 +1679,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
relations[RelationTypes.THREAD]["latest_event"]["event_id"], relations[RelationTypes.THREAD]["latest_event"]["event_id"],
related_event_id, related_event_id,
) )
class ThreadsTestCase(BaseRelationsTestCase):
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_threads(self) -> None:
"""Create threads and ensure the ordering is due to their latest event."""
# Create 2 threads.
thread_1 = self.parent_id
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test")
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Request the threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1])
# Update the first thread, the ordering should swap.
self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1, thread_2])
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_pagination(self) -> None:
"""Create threads and paginate through them."""
# Create 2 threads.
thread_1 = self.parent_id
res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test")
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Request the threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2])
# Make sure next_batch has something in it that looks like it could be a
# valid token.
next_batch = channel.json_body.get("next_batch")
self.assertIsInstance(next_batch, str, channel.json_body)
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
self.assertNotIn("next_batch", channel.json_body, channel.json_body)
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_include(self) -> None:
"""Filtering threads to all or participated in should work."""
# Thread 1 has the user as the root event.
thread_1 = self.parent_id
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
# Thread 2 has the user replying.
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Thread 3 has the user not participating in.
res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
thread_3 = res["event_id"]
self._send_relation(
RelationTypes.THREAD,
"m.room.test",
access_token=self.user2_token,
parent_id=thread_3,
)
# All threads in the room.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(
thread_roots, [thread_3, thread_2, thread_1], channel.json_body
)
# Only participated threads.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)
@unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
def test_ignored_user(self) -> None:
"""Events from ignored users should be ignored."""
# Thread 1 has a reply from an ignored user.
thread_1 = self.parent_id
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
# Thread 2 is created by an ignored user.
res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
thread_2 = res["event_id"]
self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)
# Ignore user2.
self.get_success(
self.store.add_account_data_for_user(
self.user_id,
AccountDataTypes.IGNORED_USER_LIST,
{"ignored_users": {self.user2_id: {}}},
)
)
# Only thread 1 is returned.
channel = self.make_request(
"GET",
f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
self.assertEqual(thread_roots, [thread_1], channel.json_body)
+185 -1
View File
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Optional
from twisted.test.proto_helpers import MemoryReactor from twisted.test.proto_helpers import MemoryReactor
from synapse.rest import admin from synapse.rest import admin
from synapse.rest.client import login, room from synapse.rest.client import login, room
from synapse.server import HomeServer from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.types import JsonDict
from synapse.util import Clock from synapse.util import Clock
from tests.unittest import HomeserverTestCase from tests.unittest import HomeserverTestCase
@@ -70,7 +73,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
def _assert_counts( def _assert_counts(
noitf_count: int, unread_count: int, highlight_count: int noitf_count: int, unread_count: int, highlight_count: int
) -> None: ) -> None:
counts = self.get_success( counts, thread_counts = self.get_success(
self.store.db_pool.runInteraction( self.store.db_pool.runInteraction(
"get-unread-counts", "get-unread-counts",
self.store._get_unread_counts_by_receipt_txn, self.store._get_unread_counts_by_receipt_txn,
@@ -86,6 +89,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
highlight_count=highlight_count, highlight_count=highlight_count,
), ),
) )
self.assertEqual(thread_counts, {})
def _create_event(highlight: bool = False) -> str: def _create_event(highlight: bool = False) -> str:
result = self.helper.send_event( result = self.helper.send_event(
@@ -108,6 +112,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
"m.read", "m.read",
user_id=user_id, user_id=user_id,
event_ids=[event_id], event_ids=[event_id],
thread_id=None,
data={}, data={},
) )
) )
@@ -131,6 +136,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_assert_counts(0, 0, 0) _assert_counts(0, 0, 0)
_create_event() _create_event()
_assert_counts(1, 1, 0)
_rotate() _rotate()
_assert_counts(1, 1, 0) _assert_counts(1, 1, 0)
@@ -166,6 +172,184 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
_rotate() _rotate()
_assert_counts(0, 0, 0) _assert_counts(0, 0, 0)
def test_count_aggregation_threads(self) -> None:
# Create a user to receive notifications and send receipts.
user_id = self.register_user("user1235", "pass")
token = self.login("user1235", "pass")
# And another users to send events.
other_id = self.register_user("other", "pass")
other_token = self.login("other", "pass")
# Create a room and put both users in it.
room_id = self.helper.create_room_as(user_id, tok=token)
self.helper.join(room_id, other_id, tok=other_token)
thread_id: str
last_event_id: str
def _assert_counts(
noitf_count: int,
unread_count: int,
highlight_count: int,
thread_notif_count: int,
thread_unread_count: int,
thread_highlight_count: int,
) -> None:
counts, thread_counts = self.get_success(
self.store.db_pool.runInteraction(
"get-unread-counts",
self.store._get_unread_counts_by_receipt_txn,
room_id,
user_id,
)
)
self.assertEqual(
counts,
NotifCounts(
notify_count=noitf_count,
unread_count=unread_count,
highlight_count=highlight_count,
),
)
if thread_notif_count or thread_unread_count or thread_highlight_count:
self.assertEqual(
thread_counts,
{
thread_id: NotifCounts(
notify_count=thread_notif_count,
unread_count=thread_unread_count,
highlight_count=thread_highlight_count,
),
},
)
else:
self.assertEqual(thread_counts, {})
def _create_event(
highlight: bool = False, thread_id: Optional[str] = None
) -> str:
content: JsonDict = {
"msgtype": "m.text",
"body": user_id if highlight else "",
}
if thread_id:
content["m.relates_to"] = {
"rel_type": "m.thread",
"event_id": thread_id,
}
result = self.helper.send_event(
room_id,
type="m.room.message",
content=content,
tok=other_token,
)
nonlocal last_event_id
last_event_id = result["event_id"]
return last_event_id
def _rotate() -> None:
self.get_success(self.store._rotate_notifs())
def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
self.get_success(
self.store.insert_receipt(
room_id,
"m.read",
user_id=user_id,
event_ids=[event_id],
thread_id=thread_id,
data={},
)
)
_assert_counts(0, 0, 0, 0, 0, 0)
thread_id = _create_event()
_assert_counts(1, 0, 0, 0, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 0, 0, 0)
_create_event(thread_id=thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 1, 0, 0)
_create_event()
_assert_counts(2, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(2, 0, 0, 1, 0, 0)
event_id = _create_event(thread_id=thread_id)
_assert_counts(2, 0, 0, 2, 0, 0)
_rotate()
_assert_counts(2, 0, 0, 2, 0, 0)
_create_event()
_create_event(thread_id=thread_id)
_mark_read(event_id)
_assert_counts(1, 0, 0, 3, 0, 0)
_mark_read(event_id, thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event()
_create_event(thread_id=thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_rotate()
_assert_counts(1, 0, 0, 1, 0, 0)
# Delete old event push actions, this should not affect the (summarised) count.
self.get_success(self.store._remove_old_push_actions_that_have_rotated())
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event(True)
_assert_counts(1, 1, 1, 0, 0, 0)
_rotate()
_assert_counts(1, 1, 1, 0, 0, 0)
event_id = _create_event(True, thread_id)
_assert_counts(1, 1, 1, 1, 1, 1)
_rotate()
_assert_counts(1, 1, 1, 1, 1, 1)
# Check that adding another notification and rotating after highlight
# works.
_create_event()
_rotate()
_assert_counts(2, 1, 1, 1, 1, 1)
_create_event(thread_id=thread_id)
_rotate()
_assert_counts(2, 1, 1, 2, 1, 1)
# Check that sending read receipts at different points results in the
# right counts.
_mark_read(event_id)
_assert_counts(1, 0, 0, 2, 1, 1)
_mark_read(event_id, thread_id)
_assert_counts(1, 0, 0, 1, 0, 0)
_mark_read(last_event_id)
_assert_counts(0, 0, 0, 1, 0, 0)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_create_event(True)
_create_event(True, thread_id)
_assert_counts(1, 1, 1, 1, 1, 1)
_mark_read(last_event_id)
_mark_read(last_event_id, thread_id)
_assert_counts(0, 0, 0, 0, 0, 0)
_rotate()
_assert_counts(0, 0, 0, 0, 0, 0)
def test_find_first_stream_ordering_after_ts(self) -> None: def test_find_first_stream_ordering_after_ts(self) -> None:
def add_event(so: int, ts: int) -> None: def add_event(so: int, ts: int) -> None:
self.get_success( self.get_success(
+28 -8
View File
@@ -120,13 +120,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event # Send public read receipt for the first event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
) )
) )
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event1_2_id],
None,
{},
) )
) )
@@ -153,7 +158,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating # Test receipt updating
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
) )
) )
res = self.get_success( res = self.get_success(
@@ -169,7 +174,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event2_1_id],
None,
{},
) )
) )
res = self.get_success( res = self.get_success(
@@ -191,13 +201,18 @@ class ReceiptTestCase(HomeserverTestCase):
# Send public read receipt for the first event # Send public read receipt for the first event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
) )
) )
# Send private read receipt for the second event # Send private read receipt for the second event
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} self.room_id1,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event1_2_id],
None,
{},
) )
) )
@@ -230,7 +245,7 @@ class ReceiptTestCase(HomeserverTestCase):
# Test receipt updating # Test receipt updating
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
) )
) )
res = self.get_success( res = self.get_success(
@@ -248,7 +263,12 @@ class ReceiptTestCase(HomeserverTestCase):
# Test new room is reflected in what the method returns # Test new room is reflected in what the method returns
self.get_success( self.get_success(
self.store.insert_receipt( self.store.insert_receipt(
self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} self.room_id2,
ReceiptTypes.READ_PRIVATE,
OUR_USER_ID,
[event2_1_id],
None,
{},
) )
) )
res = self.get_success( res = self.get_success(
+190
View File
@@ -0,0 +1,190 @@
import json
from time import monotonic
import requests
HOMESERVER = "http://localhost:8080"
USER_1_TOK = "syt_dGVzdGVy_AywuFarQjsYrHuPkOUvg_25XLNK"
USER_1_HEADERS = {"Authorization": f"Bearer {USER_1_TOK}"}
USER_2_TOK = "syt_b3RoZXI_jtiTnwtlBjMGMixlHIBM_4cxesB"
USER_2_HEADERS = {"Authorization": f"Bearer {USER_2_TOK}"}
def _check_for_status(result):
# Similar to raise_for_status, but prints the error.
if 400 <= result.status_code:
error_msg = result.json()
result.raise_for_status()
print(error_msg)
exit(0)
def _sync_and_show(room_id):
print("Syncing . . .")
result = requests.get(
f"{HOMESERVER}/_matrix/client/v3/sync",
headers=USER_1_HEADERS,
params={
"filter": json.dumps(
{
"room": {
"timeline": {"limit": 30, "unread_thread_notifications": True}
}
}
)
},
)
_check_for_status(result)
sync_response = result.json()
room = sync_response["rooms"]["join"][room_id]
# Find read receipts (this assumes non-overlapping).
read_receipts = {} # thread -> event ID -> users
for event in room["ephemeral"]["events"]:
if event["type"] != "m.receipt":
continue
for event_id, content in event["content"].items():
for mxid, receipt in content["m.read"].items():
print(mxid, receipt)
# Just care about the localpart of the MXID.
mxid = mxid.split(":", 1)[0]
read_receipts.setdefault(receipt.get("thread_id"), {}).setdefault(
event_id, []
).append(mxid)
print(room["unread_notifications"])
print(room.get("unread_thread_notifications"))
print()
# Convert events to their threads.
threads = {}
for event in room["timeline"]["events"]:
if event["type"] != "m.room.message":
continue
event_id = event["event_id"]
parent_id = event["content"].get("m.relates_to", {}).get("event_id")
if parent_id:
threads[parent_id][1].append(event)
else:
threads[event_id] = (event, [])
for root_event_id, (root, thread) in threads.items():
msg = root["content"]["body"]
print(f"{root_event_id}: {msg}")
for event in thread:
thread_event_id = event["event_id"]
msg = event["content"]["body"]
print(f"\t{thread_event_id}: {msg}")
if thread_event_id in read_receipts.get(root_event_id, {}):
user_ids = ", ".join(read_receipts[root_event_id][thread_event_id])
print(f"\t^--------- {user_ids} ---------^")
if root_event_id in read_receipts[None]:
user_ids = ", ".join(read_receipts[None][root_event_id])
print(f"^--------- {user_ids} ---------^")
print()
print()
def _send_event(room_id, body, thread_id=None):
content = {
"msgtype": "m.text",
"body": body,
}
if thread_id:
content["m.relates_to"] = {
"rel_type": "m.thread",
"event_id": thread_id,
}
# Send a msg to the room.
result = requests.put(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/msg{monotonic()}",
json=content,
headers=USER_2_HEADERS,
)
_check_for_status(result)
return result.json()["event_id"]
def main():
# Create a new room as user 2, add a bunch of messages.
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/createRoom",
json={"visibility": "public", "name": f"Thread Read Receipts ({monotonic()})"},
headers=USER_2_HEADERS,
)
_check_for_status(result)
room_id = result.json()["room_id"]
# Second user joins the room.
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/join", headers=USER_1_HEADERS
)
_check_for_status(result)
# Sync user 1.
_sync_and_show(room_id)
# User 2 sends some messages.
event_ids = []
def _send_and_append(body, thread_id=None):
event_id = _send_event(room_id, body, thread_id)
event_ids.append(event_id)
return event_id
for msg in range(5):
root_message_id = _send_and_append(f"Message {msg}")
for msg in range(10):
if msg % 2:
_send_and_append(f"More message {msg}")
else:
_send_and_append(f"Thread Message {msg}", root_message_id)
# User 2 sends a read receipt.
print("@second reads main timeline")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[3]}",
headers=USER_2_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
# User 1 sends a read receipt.
print("@test reads main timeline")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-5]}",
headers=USER_1_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
# User 1 sends another read receipt.
print("@test reads thread")
result = requests.post(
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-4]}/{root_message_id}",
headers=USER_1_HEADERS,
json={},
)
_check_for_status(result)
_sync_and_show(room_id)
if __name__ == "__main__":
main()