Compare commits: clokep/sta...clokep/thr

18 commits:

426144c9bd, 7fb3b8405e, f0ab6a7f4c, 0cbddc632e, 80dcb911dc, fd972df8f9,
fbd6727760, 08620b3f28, d56296aa57, e0ed95a45b, dfd921d421, 2c7a5681b4,
759366e5e6, 6c2e08ed6f, f6267b1abe, d510975b2f, 8dcdb4efa9, e9a649ec31
changelog.d/13181.feature (new file)
@@ -0,0 +1 @@
+Experimental support for thread-specific notifications ([MSC3773](https://github.com/matrix-org/matrix-spec-proposals/pull/3773)).

changelog.d/13202.feature (new file)
@@ -0,0 +1 @@
+Experimental support for thread-specific receipts ([MSC3771](https://github.com/matrix-org/matrix-spec-proposals/pull/3771)).

changelog.d/13394.feature (new file)
@@ -0,0 +1 @@
+Experimental support for [MSC3856](https://github.com/matrix-org/matrix-spec-proposals/pull/3856): threads list API.
synapse/api/filtering.py
@@ -84,6 +84,7 @@ ROOM_EVENT_FILTER_SCHEMA = {
         "contains_url": {"type": "boolean"},
         "lazy_load_members": {"type": "boolean"},
         "include_redundant_members": {"type": "boolean"},
+        "unread_thread_notifications": {"type": "boolean"},
         # Include or exclude events with the provided labels.
         # cf https://github.com/matrix-org/matrix-doc/pull/2326
         "org.matrix.labels": {"type": "array", "items": {"type": "string"}},
@@ -240,6 +241,9 @@ class FilterCollection:
     def include_redundant_members(self) -> bool:
         return self._room_state_filter.include_redundant_members

+    def unread_thread_notifications(self) -> bool:
+        return self._room_timeline_filter.unread_thread_notifications
+
     async def filter_presence(
         self, events: Iterable[UserPresenceState]
     ) -> List[UserPresenceState]:
@@ -304,6 +308,9 @@ class Filter:
         self.include_redundant_members = filter_json.get(
             "include_redundant_members", False
         )
+        self.unread_thread_notifications = filter_json.get(
+            "unread_thread_notifications", False
+        )

         self.types = filter_json.get("types", None)
         self.not_types = filter_json.get("not_types", [])
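
For reference, a client opts into the per-thread counts by setting the new flag in the room timeline section of a sync filter. A minimal sketch (the surrounding filter shape is standard Matrix filter JSON; only "unread_thread_notifications" is the new, experimental flag added above):

# Sketch: a sync filter opting into MSC3773 per-thread notification counts.
sync_filter = {
    "room": {
        "timeline": {
            "limit": 10,
            "unread_thread_notifications": True,
        },
    },
}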
synapse/config/experimental.py
@@ -82,11 +82,18 @@ class ExperimentalConfig(Config):
         # MSC3786 (Add a default push rule to ignore m.room.server_acl events)
         self.msc3786_enabled: bool = experimental.get("msc3786_enabled", False)

+        # MSC3771: Thread read receipts
+        self.msc3771_enabled: bool = experimental.get("msc3771_enabled", False)
         # MSC3772: A push rule for mutual relations.
         self.msc3772_enabled: bool = experimental.get("msc3772_enabled", False)
+        # MSC3773: Thread notifications
+        self.msc3773_enabled: bool = experimental.get("msc3773_enabled", False)

         # MSC3715: dir param on /relations.
         self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)

         # MSC3848: Introduce errcodes for specific event sending failures
         self.msc3848_enabled: bool = experimental.get("msc3848_enabled", False)
+
+        # MSC3856: Threads list API
+        self.msc3856_enabled: bool = experimental.get("msc3856_enabled", False)
synapse/handlers/receipts.py
@@ -97,6 +97,7 @@ class ReceiptsHandler:
                     receipt_type=receipt_type,
                     user_id=user_id,
                     event_ids=user_values["event_ids"],
+                    thread_id=None,  # TODO
                     data=user_values.get("data", {}),
                 )
             )
@@ -114,6 +115,7 @@ class ReceiptsHandler:
                 receipt.receipt_type,
                 receipt.user_id,
                 receipt.event_ids,
+                receipt.thread_id,
                 receipt.data,
             )

@@ -146,7 +148,12 @@ class ReceiptsHandler:
         return True

     async def received_client_receipt(
-        self, room_id: str, receipt_type: str, user_id: str, event_id: str
+        self,
+        room_id: str,
+        receipt_type: str,
+        user_id: str,
+        event_id: str,
+        thread_id: Optional[str],
     ) -> None:
         """Called when a client tells us a local user has read up to the given
         event_id in the room.
@@ -156,6 +163,7 @@ class ReceiptsHandler:
             receipt_type=receipt_type,
             user_id=user_id,
             event_ids=[event_id],
+            thread_id=thread_id,
             data={"ts": int(self.clock.time_msec())},
         )
synapse/handlers/relations.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple

@@ -31,6 +32,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)


+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -482,3 +490,84 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit

         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> JsonDict:
+        """Get related events of an event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, user_id, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_token = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token, to_token=to_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        now = self._clock.time_msec()
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_token:
+            return_value["next_batch"] = await next_token.to_string(self._main_store)
+
+        if from_token:
+            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+
+        return return_value
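
The handler returns a standard pagination chunk. An illustrative response shape (the event content and tokens here are made up; only the "chunk"/"next_batch"/"prev_batch" keys are established by the code above):

# Illustrative shape of the value returned by get_threads.
example_threads_response = {
    "chunk": [
        {
            "event_id": "$thread_root_event",  # hypothetical event ID
            "type": "m.room.message",
            "sender": "@alice:example.org",  # hypothetical user
            "content": {"body": "thread root"},
        }
    ],
    "next_batch": "t123-456",  # opaque stream token (illustrative)
}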
synapse/handlers/sync.py
@@ -115,6 +115,7 @@ class JoinedSyncResult:
     ephemeral: List[JsonDict]
     account_data: List[JsonDict]
     unread_notifications: JsonDict
+    unread_thread_notifications: JsonDict
     summary: Optional[JsonDict]
     unread_count: int

@@ -265,6 +266,8 @@ class SyncHandler:

         self.rooms_to_exclude = hs.config.server.rooms_to_exclude_from_sync

+        self._msc3773_enabled = hs.config.experimental.msc3773_enabled
+
     async def wait_for_sync_for_user(
         self,
         requester: Requester,
@@ -1053,7 +1056,7 @@ class SyncHandler:

     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> NotifCounts:
+    ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
         with Measure(self.clock, "unread_notifs_for_room_id"):

             return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -2115,17 +2118,46 @@ class SyncHandler:
             ephemeral=ephemeral,
             account_data=account_data_events,
             unread_notifications=unread_notifications,
+            unread_thread_notifications={},
             summary=summary,
             unread_count=0,
         )

         if room_sync or always_include:
-            notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
+            notifs, thread_notifs = await self.unread_notifs_for_room_id(
+                room_id, sync_config
+            )

-            unread_notifications["notification_count"] = notifs.notify_count
-            unread_notifications["highlight_count"] = notifs.highlight_count
-
-            room_sync.unread_count = notifs.unread_count
+            # Notifications for the main timeline.
+            notify_count = notifs.notify_count
+            highlight_count = notifs.highlight_count
+            unread_count = notifs.unread_count
+
+            # Check the sync configuration.
+            if (
+                self._msc3773_enabled
+                and sync_config.filter_collection.unread_thread_notifications()
+            ):
+                # And add info for each thread.
+                room_sync.unread_thread_notifications = {
+                    thread_id: {
+                        "notification_count": tnotifs.notify_count,
+                        "highlight_count": tnotifs.highlight_count,
+                    }
+                    for thread_id, tnotifs in thread_notifs.items()
+                    if thread_id is not None
+                }
+
+            else:
+                # Combine the unread counts for all threads and main timeline.
+                for tnotifs in thread_notifs.values():
+                    notify_count += tnotifs.notify_count
+                    highlight_count += tnotifs.highlight_count
+                    unread_count += tnotifs.unread_count
+
+            unread_notifications["notification_count"] = notify_count
+            unread_notifications["highlight_count"] = highlight_count
+            room_sync.unread_count = unread_count

             sync_result_builder.joined.append(room_sync)
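
In short: with the filter flag set (and MSC3773 enabled) per-thread counts are reported separately; otherwise they are folded into the room-level totals. A self-contained sketch of that split/fold decision, using a stand-in Counts type and an illustrative "threads" key rather than Synapse's NotifCounts and field names:

from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class Counts:
    # Stand-in for Synapse's NotifCounts attrs class.
    notify_count: int = 0
    highlight_count: int = 0
    unread_count: int = 0


def room_counts(
    main: Counts, threads: Dict[Optional[str], Counts], per_thread: bool
) -> dict:
    # Mirrors the handler logic above: report per-thread counts when the
    # client opted in, otherwise fold thread counts into the room totals.
    if per_thread:
        return {
            "notification_count": main.notify_count,
            "highlight_count": main.highlight_count,
            "threads": {
                tid: {
                    "notification_count": t.notify_count,
                    "highlight_count": t.highlight_count,
                }
                for tid, t in threads.items()
                if tid is not None
            },
        }
    notify = main.notify_count + sum(t.notify_count for t in threads.values())
    highlight = main.highlight_count + sum(t.highlight_count for t in threads.values())
    return {"notification_count": notify, "highlight_count": highlight}


merged = room_counts(Counts(1, 0, 1), {"$root": Counts(2, 1, 2)}, per_thread=False)
assert merged == {"notification_count": 3, "highlight_count": 1}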
synapse/push/bulk_push_rule_evaluator.py
@@ -186,7 +186,7 @@ class BulkPushRuleEvaluator:
         return pl_event.content if pl_event else {}, sender_level

     async def _get_mutual_relations(
-        self, event: EventBase, rules: Iterable[Dict[str, Any]]
+        self, parent_id: str, rules: Iterable[Dict[str, Any]]
     ) -> Dict[str, Set[Tuple[str, str]]]:
         """
         Fetch event metadata for events which related to the same event as the given event.
@@ -194,7 +194,7 @@ class BulkPushRuleEvaluator:
         If the given event has no relation information, returns an empty dictionary.

         Args:
-            event_id: The event ID which is targeted by relations.
+            parent_id: The event ID which is targeted by relations.
             rules: The push rules which will be processed for this event.

         Returns:
@@ -208,12 +208,6 @@ class BulkPushRuleEvaluator:
         if not self._relations_match_enabled:
             return {}

-        # If the event does not have a relation, then cannot have any mutual
-        # relations.
-        relation = relation_from_event(event)
-        if not relation:
-            return {}
-
         # Pre-filter to figure out which relation types are interesting.
         rel_types = set()
         for rule in rules:
@@ -235,9 +229,7 @@ class BulkPushRuleEvaluator:
             return {}

         # If any valid rules were found, fetch the mutual relations.
-        return await self.store.get_mutual_event_relations(
-            relation.parent_id, rel_types
-        )
+        return await self.store.get_mutual_event_relations(parent_id, rel_types)

     @measure_func("action_for_event_by_user")
     async def action_for_event_by_user(
@@ -265,9 +257,18 @@ class BulkPushRuleEvaluator:
             sender_power_level,
         ) = await self._get_power_levels_and_sender_level(event, context)

-        relations = await self._get_mutual_relations(
-            event, itertools.chain(*rules_by_user.values())
-        )
+        relation = relation_from_event(event)
+        # If the event does not have a relation, then cannot have any mutual
+        # relations or thread ID.
+        relations = {}
+        thread_id = None
+        if relation:
+            relations = await self._get_mutual_relations(
+                relation.parent_id, itertools.chain(*rules_by_user.values())
+            )
+            # XXX Does this need to point to a valid parent ID or anything?
+            if relation.rel_type == RelationTypes.THREAD:
+                thread_id = relation.parent_id

         evaluator = PushRuleEvaluatorForEvent(
             event,
@@ -338,6 +339,7 @@ class BulkPushRuleEvaluator:
             event.event_id,
             actions_by_user,
             count_as_unread,
+            thread_id,
         )
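
For context, relation_from_event inspects the event's m.relates_to content. A simplified stand-in (hypothetical helper, not Synapse's implementation) showing how the thread ID is derived from raw event content:

from typing import Optional


def thread_id_from_content(content: dict) -> Optional[str]:
    # Simplified stand-in for relation_from_event plus the rel_type check
    # above: return the thread root's event ID for a threaded reply.
    relates_to = content.get("m.relates_to")
    if not isinstance(relates_to, dict):
        return None
    if relates_to.get("rel_type") != "m.thread":
        return None
    return relates_to.get("event_id")


# A threaded reply relates to its thread root:
assert thread_id_from_content(
    {"m.relates_to": {"rel_type": "m.thread", "event_id": "$root"}}
) == "$root"
# A plain message has no thread ID:
assert thread_id_from_content({"body": "hello"}) is None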
synapse/push/push_tools.py
@@ -26,13 +26,18 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
     badge = len(invites)

     for room_id in joins:
-        notifs = await (
+        notifs, thread_notifs = await (
             store.get_unread_event_push_actions_by_room_for_user(
                 room_id,
                 user_id,
             )
         )
-        if notifs.notify_count == 0:
+        # Combine the counts from all the threads.
+        notify_count = notifs.notify_count + sum(
+            n.notify_count for n in thread_notifs.values()
+        )
+
+        if notify_count == 0:
             continue

         if group_by_room:
@@ -40,7 +45,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
             badge += 1
         else:
             # increment the badge count by the number of unread messages in the room
-            badge += notifs.notify_count
+            badge += notify_count
     return badge
synapse/replication/tcp/client.py
@@ -423,7 +423,8 @@ class FederationSenderHandler:
                 receipt.receipt_type,
                 receipt.user_id,
                 [receipt.event_id],
-                receipt.data,
+                thread_id=receipt.thread_id,
+                data=receipt.data,
             )
             await self.federation_sender.send_read_receipt(receipt_info)
synapse/replication/tcp/streams/_base.py
@@ -361,6 +361,7 @@ class ReceiptsStream(Stream):
         receipt_type: str
         user_id: str
         event_id: str
+        thread_id: Optional[str]
         data: dict

     NAME = "receipts"
synapse/rest/client/read_marker.py
@@ -81,6 +81,7 @@ class ReadMarkerRestServlet(RestServlet):
                     receipt_type,
                     user_id=requester.user.to_string(),
                     event_id=event_id,
+                    thread_id=None,  # TODO
                 )

         return 200, {}
synapse/rest/client/receipts.py
@@ -13,10 +13,10 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Optional, Tuple

 from synapse.api.constants import ReceiptTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import SynapseError, UnrecognizedRequestError
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_json_object_from_request
 from synapse.http.site import SynapseRequest
@@ -34,7 +34,8 @@ class ReceiptRestServlet(RestServlet):
     PATTERNS = client_patterns(
         "/rooms/(?P<room_id>[^/]*)"
         "/receipt/(?P<receipt_type>[^/]*)"
-        "/(?P<event_id>[^/]*)$"
+        "/(?P<event_id>[^/]*)"
+        "(/(?P<thread_id>[^/]*))?$"
     )

     def __init__(self, hs: "HomeServer"):
@@ -50,8 +51,15 @@ class ReceiptRestServlet(RestServlet):
                 (ReceiptTypes.READ_PRIVATE, ReceiptTypes.FULLY_READ)
             )

+        self._msc3771_enabled = hs.config.experimental.msc3771_enabled
+
     async def on_POST(
-        self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
+        self,
+        request: SynapseRequest,
+        room_id: str,
+        receipt_type: str,
+        event_id: str,
+        thread_id: Optional[str],
     ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)

@@ -61,6 +69,9 @@ class ReceiptRestServlet(RestServlet):
                 f"Receipt type must be {', '.join(self._known_receipt_types)}",
             )

+        if thread_id and not self._msc3771_enabled:
+            raise UnrecognizedRequestError()
+
         parse_json_object_from_request(request, allow_empty_body=False)

         await self.presence_handler.bump_presence_active_time(requester.user)
@@ -77,6 +88,7 @@ class ReceiptRestServlet(RestServlet):
             receipt_type,
             user_id=requester.user.to_string(),
             event_id=event_id,
+            thread_id=thread_id,
        )

         return 200, {}
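
With the extended pattern, the thread root's event ID rides along as an optional extra path segment (rejected unless msc3771_enabled is set). A sketch of a client request against this endpoint, assuming the `requests` package; the host, token, and IDs are placeholders:

import requests

HOMESERVER = "https://matrix.example.org"  # placeholder
ROOM_ID = "!room:example.org"              # placeholder
EVENT_ID = "$event"                        # placeholder
THREAD_ROOT = "$thread_root"               # placeholder

# Per the updated PATTERNS, the optional trailing segment carries the thread ID.
resp = requests.post(
    f"{HOMESERVER}/_matrix/client/v3/rooms/{ROOM_ID}"
    f"/receipt/m.read/{EVENT_ID}/{THREAD_ROOT}",
    headers={"Authorization": "Bearer ACCESS_TOKEN"},
    json={},
)
resp.raise_for_status()  # the servlet replies 200 {} on success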
synapse/rest/client/relations.py
@@ -13,8 +13,10 @@
 # limitations under the License.

 import logging
+import re
 from typing import TYPE_CHECKING, Optional, Tuple

+from synapse.handlers.relations import ThreadsListInclude
 from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
@@ -91,5 +93,55 @@ class RelationPaginationServlet(RestServlet):
         return 200, result


+class ThreadsServlet(RestServlet):
+    PATTERNS = (
+        re.compile(
+            "^/_matrix/client/unstable/org.matrix.msc3856/rooms/(?P<room_id>[^/]*)/threads"
+        ),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+        self._relations_handler = hs.get_relations_handler()
+
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        limit = parse_integer(request, "limit", default=5)
+        from_token_str = parse_string(request, "from")
+        to_token_str = parse_string(request, "to")
+        include = parse_string(
+            request,
+            "include",
+            default=ThreadsListInclude.all.value,
+            allowed_values=[v.value for v in ThreadsListInclude],
+        )
+
+        # Return the relations
+        from_token = None
+        if from_token_str:
+            from_token = await StreamToken.from_string(self.store, from_token_str)
+        to_token = None
+        if to_token_str:
+            to_token = await StreamToken.from_string(self.store, to_token_str)
+
+        result = await self._relations_handler.get_threads(
+            requester=requester,
+            room_id=room_id,
+            include=ThreadsListInclude(include),
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
+
+        return 200, result
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     RelationPaginationServlet(hs).register(http_server)
+    if hs.config.experimental.msc3856_enabled:
+        ThreadsServlet(hs).register(http_server)
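
Assuming a homeserver with msc3856_enabled set, the unstable endpoint paginates like any other chunked API. A sketch using the `requests` package (host, token, and room are placeholders):

import requests

HOMESERVER = "https://matrix.example.org"  # placeholder
ROOM_ID = "!room:example.org"              # placeholder

resp = requests.get(
    f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3856"
    f"/rooms/{ROOM_ID}/threads",
    headers={"Authorization": "Bearer ACCESS_TOKEN"},
    # Only threads the requesting user has participated in, 5 at a time.
    params={"include": "participated", "limit": 5},
)
body = resp.json()
thread_roots = body["chunk"]         # serialized thread-root events
next_batch = body.get("next_batch")  # pass back as ?from= to page further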
synapse/rest/client/sync.py
@@ -509,6 +509,8 @@ class SyncRestServlet(RestServlet):
         ephemeral_events = room.ephemeral
         result["ephemeral"] = {"events": ephemeral_events}
         result["unread_notifications"] = room.unread_notifications
+        if room.unread_thread_notifications:
+            result["unread_thread_notifications"] = room.unread_thread_notifications
         result["summary"] = room.summary
         if self._msc2654_enabled:
             result["org.matrix.msc2654.unread_count"] = room.unread_count
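
Putting the sync-side changes together, a joined room in the /sync response gains a sibling of unread_notifications, keyed by thread root event ID. An illustrative fragment (the IDs and counts are made up):

# Illustrative /sync fragment for one joined room; only the
# "unread_thread_notifications" key is new.
joined_room_fragment = {
    "unread_notifications": {"notification_count": 2, "highlight_count": 0},
    "unread_thread_notifications": {
        "$thread_root_a": {"notification_count": 3, "highlight_count": 1},
        "$thread_root_b": {"notification_count": 1, "highlight_count": 0},
    },
}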
synapse/storage/databases/main/event_push_actions.py
@@ -78,7 +78,6 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast

 import attr

-from synapse.api.constants import ReceiptTypes
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage._base import SQLBaseStore, db_to_json
 from synapse.storage.database import (
@@ -206,10 +205,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas

         self._rotate_count = 10000
         self._doing_notif_rotation = False
-        if hs.config.worker.run_background_tasks:
-            self._rotate_notif_loop = self._clock.looping_call(
-                self._rotate_notifs, 30 * 1000
-            )
+        # XXX Do not rotate summaries for now, they're broken.
+        # if hs.config.worker.run_background_tasks:
+        #     self._rotate_notif_loop = self._clock.looping_call(
+        #         self._rotate_notifs, 30 * 1000
+        #     )

         self.db_pool.updates.register_background_index_update(
             "event_push_summary_unique_index",
@@ -220,12 +220,21 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             replaces_index="event_push_summary_user_rm",
         )

-    @cached(tree=True, max_entries=5000)
+        self.db_pool.updates.register_background_index_update(
+            "event_push_summary_unique_index2",
+            index_name="event_push_summary_unique_index2",
+            table="event_push_summary",
+            columns=["user_id", "room_id", "thread_id"],
+            unique=True,
+            replaces_index="event_push_summary_unique_index",
+        )
+
+    @cached(tree=True, max_entries=5000, iterable=True)
     async def get_unread_event_push_actions_by_room_for_user(
         self,
         room_id: str,
         user_id: str,
-    ) -> NotifCounts:
+    ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
         """Get the notification count, the highlight count and the unread message count
         for a given user in a given room after the given read receipt.
@@ -254,31 +263,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             txn: LoggingTransaction,
             room_id: str,
             user_id: str,
-        ) -> NotifCounts:
-            result = self.get_last_receipt_for_user_txn(
-                txn,
-                user_id,
-                room_id,
-                receipt_types=(ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE),
-            )
-
-            stream_ordering = None
-            if result:
-                _, stream_ordering = result
-
-            if stream_ordering is None:
-                # Either last_read_event_id is None, or it's an event we don't have (e.g.
-                # because it's been purged), in which case retrieve the stream ordering for
-                # the latest membership event from this user in this room (which we assume is
-                # a join).
-                event_id = self.db_pool.simple_select_one_onecol_txn(
-                    txn=txn,
-                    table="local_current_membership",
-                    keyvalues={"room_id": room_id, "user_id": user_id},
-                    retcol="event_id",
-                )
-
-                stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
+        ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
+            # Either last_read_event_id is None, or it's an event we don't have (e.g.
+            # because it's been purged), in which case retrieve the stream ordering for
+            # the latest membership event from this user in this room (which we assume is
+            # a join).
+            event_id = self.db_pool.simple_select_one_onecol_txn(
+                txn=txn,
+                table="local_current_membership",
+                keyvalues={"room_id": room_id, "user_id": user_id},
+                retcol="event_id",
+            )
+
+            stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)

             return self._get_unread_counts_by_pos_txn(
                 txn, room_id, user_id, stream_ordering
@@ -286,12 +283,20 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas

     def _get_unread_counts_by_pos_txn(
         self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
-    ) -> NotifCounts:
+    ) -> Tuple[NotifCounts, Dict[str, NotifCounts]]:
         """Get the number of unread messages for a user/room that have happened
         since the given stream ordering.
+
+        Returns:
+            A tuple of:
+                The unread messages for the main timeline
+
+                A dictionary of thread ID to unread messages for that thread.
+                    Only contains threads with unread messages.
         """

         counts = NotifCounts()
+        thread_counts = {}

         # First we pull the counts from the summary table.
         #
@@ -306,51 +311,92 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # date (as the row was written by an older version of Synapse that
         # updated `event_push_summary` synchronously when persisting a new read
         # receipt).
-        txn.execute(
-            """
-                SELECT stream_ordering, notif_count, COALESCE(unread_count, 0)
-                FROM event_push_summary
-                WHERE room_id = ? AND user_id = ?
-                AND (
-                    (last_receipt_stream_ordering IS NULL AND stream_ordering > ?)
-                    OR last_receipt_stream_ordering = ?
-                )
-            """,
-            (room_id, user_id, stream_ordering, stream_ordering),
-        )
-        row = txn.fetchone()
-
-        summary_stream_ordering = 0
-        if row:
-            summary_stream_ordering = row[0]
-            counts.notify_count += row[1]
-            counts.unread_count += row[2]
+        # XXX event_push_summary is not currently filled in. broken.
+        # txn.execute(
+        #     """
+        #         SELECT notif_count, COALESCE(unread_count, 0), thread_id, MAX(events.stream_ordering)
+        #         FROM event_push_summary
+        #         LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+        #         LEFT JOIN events ON (
+        #             events.room_id = receipts_linearized.room_id AND
+        #             events.event_id = receipts_linearized.event_id
+        #         )
+        #         WHERE event_push_summary.room_id = ? AND user_id = ?
+        #         AND (
+        #             (
+        #                 last_receipt_stream_ordering IS NULL
+        #                 AND event_push_summary.stream_ordering > COALESCE(events.stream_ordering, ?)
+        #             )
+        #             OR last_receipt_stream_ordering = COALESCE(events.stream_ordering, ?)
+        #         )
+        #         AND (receipt_type = 'm.read' OR receipt_type = 'org.matrix.msc2285.read.private')
+        #     """,
+        #     (room_id, user_id, stream_ordering, stream_ordering),
+        # )
+        # for notif_count, unread_count, thread_id, _ in txn:
+        #     # XXX Why are these returned? Related to MAX(...) aggregation.
+        #     if notif_count is None:
+        #         continue
+        #
+        #     if not thread_id:
+        #         counts = NotifCounts(
+        #             notify_count=notif_count, unread_count=unread_count
+        #         )
+        #     # TODO Delete zeroed out threads completely from the database.
+        #     elif notif_count or unread_count:
+        #         thread_counts[thread_id] = NotifCounts(
+        #             notify_count=notif_count, unread_count=unread_count
+        #         )

         # Next we need to count highlights, which aren't summarised
         sql = """
-            SELECT COUNT(*) FROM event_push_actions
+            SELECT COUNT(*), thread_id, MAX(events.stream_ordering) FROM event_push_actions
+            LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+            LEFT JOIN events ON (
+                events.room_id = receipts_linearized.room_id AND
+                events.event_id = receipts_linearized.event_id
+            )
             WHERE user_id = ?
-                AND room_id = ?
-                AND stream_ordering > ?
+                AND event_push_actions.room_id = ?
+                AND event_push_actions.stream_ordering > COALESCE(events.stream_ordering, ?)
                 AND highlight = 1
+            GROUP BY thread_id
         """
         txn.execute(sql, (user_id, room_id, stream_ordering))
-        row = txn.fetchone()
-        if row:
-            counts.highlight_count += row[0]
+        for highlight_count, thread_id, _ in txn:
+            if not thread_id:
+                counts.highlight_count += highlight_count
+            elif highlight_count:
+                if thread_id in thread_counts:
+                    thread_counts[thread_id].highlight_count += highlight_count
+                else:
+                    thread_counts[thread_id] = NotifCounts(
+                        notify_count=0, unread_count=0, highlight_count=highlight_count
+                    )

         # Finally we need to count push actions that aren't included in the
         # summary returned above, e.g. recent events that haven't been
         # summarised yet, or the summary is empty due to a recent read receipt.
-        stream_ordering = max(stream_ordering, summary_stream_ordering)
-        notify_count, unread_count = self._get_notif_unread_count_for_user_room(
+        unread_counts = self._get_notif_unread_count_for_user_room(
             txn, room_id, user_id, stream_ordering
         )

-        counts.notify_count += notify_count
-        counts.unread_count += unread_count
+        for notif_count, unread_count, thread_id in unread_counts:
+            if not thread_id:
+                counts.notify_count += notif_count
+                counts.unread_count += unread_count
+            elif thread_id in thread_counts:
+                thread_counts[thread_id].notify_count += notif_count
+                thread_counts[thread_id].unread_count += unread_count
+            else:
+                thread_counts[thread_id] = NotifCounts(
+                    notify_count=notif_count,
+                    unread_count=unread_count,
+                    highlight_count=0,
+                )

-        return counts
+        return counts, thread_counts

     def _get_notif_unread_count_for_user_room(
         self,
@@ -359,7 +405,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         user_id: str,
         stream_ordering: int,
         max_stream_ordering: Optional[int] = None,
-    ) -> Tuple[int, int]:
+    ) -> List[Tuple[int, int, str]]:
         """Returns the notify and unread counts from `event_push_actions` for
         the given user/room in the given range.

@@ -380,8 +426,10 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas

         # If there have been no events in the room since the stream ordering,
         # there can't be any push actions either.
-        if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
-            return 0, 0
+        #
+        # XXX
+        # if not self._events_stream_cache.has_entity_changed(room_id, stream_ordering):
+        #     return []

         clause = ""
         args = [user_id, room_id, stream_ordering]
@@ -389,29 +437,29 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             clause = "AND ea.stream_ordering <= ?"
             args.append(max_stream_ordering)

-            # If the max stream ordering is less than the min stream ordering,
-            # then obviously there are zero push actions in that range.
-            if max_stream_ordering <= stream_ordering:
-                return 0, 0
-
         sql = f"""
             SELECT
                 COUNT(CASE WHEN notif = 1 THEN 1 END),
-                COUNT(CASE WHEN unread = 1 THEN 1 END)
-            FROM event_push_actions ea
-            WHERE user_id = ?
-                AND room_id = ?
-                AND ea.stream_ordering > ?
+                COUNT(CASE WHEN unread = 1 THEN 1 END),
+                thread_id,
+                MAX(events.stream_ordering)
+            FROM event_push_actions ea
+            LEFT JOIN receipts_linearized USING (room_id, user_id, thread_id)
+            LEFT JOIN events ON (
+                events.room_id = receipts_linearized.room_id AND
+                events.event_id = receipts_linearized.event_id
+            )
+            WHERE user_id = ?
+                AND ea.room_id = ?
+                AND ea.stream_ordering > COALESCE(events.stream_ordering, ?)
                 {clause}
+            GROUP BY thread_id
         """

         txn.execute(sql, args)
-        row = txn.fetchone()
-
-        if row:
-            return cast(Tuple[int, int], row)
-
-        return 0, 0
+        # The max stream ordering is simply there to select the latest receipt,
+        # it doesn't need to be returned.
+        return [cast(Tuple[int, int, str], row[:3]) for row in txn.fetchall()]

     async def get_push_action_users_in_range(
         self, min_stream_ordering: int, max_stream_ordering: int
@@ -680,6 +728,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         event_id: str,
         user_id_actions: Dict[str, List[Union[dict, str]]],
         count_as_unread: bool,
+        thread_id: Optional[str],
     ) -> None:
         """Add the push actions for the event to the push action staging area.

@@ -688,6 +737,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             user_id_actions: A mapping of user_id to list of push actions, where
                 an action can either be a string or dict.
             count_as_unread: Whether this event should increment unread counts.
+            thread_id: The thread this event is part of, if applicable.
         """
         if not user_id_actions:
             return
@@ -696,7 +746,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # can be used to insert into the `event_push_actions_staging` table.
         def _gen_entry(
             user_id: str, actions: List[Union[dict, str]]
-        ) -> Tuple[str, str, str, int, int, int]:
+        ) -> Tuple[str, str, str, int, int, int, Optional[str]]:
             is_highlight = 1 if _action_has_highlight(actions) else 0
             notif = 1 if "notify" in actions else 0
             return (
@@ -706,6 +756,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
                 notif,  # notif column
                 is_highlight,  # highlight column
                 int(count_as_unread),  # unread column
+                thread_id or "",  # thread_id column
             )

         def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
@@ -714,8 +765,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas

             sql = """
                 INSERT INTO event_push_actions_staging
-                    (event_id, user_id, actions, notif, highlight, unread)
-                VALUES (?, ?, ?, ?, ?, ?)
+                    (event_id, user_id, actions, notif, highlight, unread, thread_id)
+                VALUES (?, ?, ?, ?, ?, ?, ?)
             """

             txn.execute_batch(
@@ -955,7 +1006,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )

         sql = """
-            SELECT r.stream_id, r.room_id, r.user_id, e.stream_ordering
+            SELECT r.stream_id, r.room_id, r.user_id, r.thread_id, e.stream_ordering
             FROM receipts_linearized AS r
             INNER JOIN events AS e USING (event_id)
             WHERE ? < r.stream_id AND r.stream_id <= ? AND user_id LIKE ?
@@ -978,41 +1029,83 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
             )
             rows = txn.fetchall()

-            # For each new read receipt we delete push actions from before it and
-            # recalculate the summary.
-            for _, room_id, user_id, stream_ordering in rows:
+            # Group the rows by room ID / user ID.
+            rows_by_room_user: Dict[Tuple[str, str], List[Tuple[str, str, int]]] = {}
+            for stream_id, room_id, user_id, thread_id, stream_ordering in rows:
                 # Only handle our own read receipts.
                 if not self.hs.is_mine_id(user_id):
                     continue

-                txn.execute(
-                    """
-                    DELETE FROM event_push_actions
-                    WHERE room_id = ?
-                        AND user_id = ?
-                        AND stream_ordering <= ?
-                        AND highlight = 0
-                    """,
-                    (room_id, user_id, stream_ordering),
+                rows_by_room_user.setdefault((room_id, user_id), []).append(
+                    (stream_id, thread_id, stream_ordering)
                 )

+            # For each new read receipt we delete push actions from before it and
+            # recalculate the summary.
+            for (room_id, user_id), room_rows in rows_by_room_user.items():
+                for _, thread_id, stream_ordering in room_rows:
+                    txn.execute(
+                        """
+                        DELETE FROM event_push_actions
+                        WHERE room_id = ?
+                            AND user_id = ?
+                            AND thread_id = ?
+                            AND stream_ordering <= ?
+                            AND highlight = 0
+                        """,
+                        (room_id, user_id, thread_id, stream_ordering),
+                    )
+
                 # Fetch the notification counts between the stream ordering of the
                 # latest receipt and what was previously summarised.
-                notif_count, unread_count = self._get_notif_unread_count_for_user_room(
-                    txn, room_id, user_id, stream_ordering, old_rotate_stream_ordering
+                earliest_stream_ordering = min(r[2] for r in room_rows)
+                unread_counts = self._get_notif_unread_count_for_user_room(
+                    txn,
+                    room_id,
+                    user_id,
+                    earliest_stream_ordering,
+                    old_rotate_stream_ordering,
                 )

-                # Replace the previous summary with the new counts.
-                self.db_pool.simple_upsert_txn(
+                # Updated threads get their notification count and unread count updated.
+                self.db_pool.simple_upsert_many_txn(
                     txn,
                     table="event_push_summary",
-                    keyvalues={"room_id": room_id, "user_id": user_id},
-                    values={
-                        "notif_count": notif_count,
-                        "unread_count": unread_count,
-                        "stream_ordering": old_rotate_stream_ordering,
-                        "last_receipt_stream_ordering": stream_ordering,
-                    },
+                    key_names=("room_id", "user_id", "thread_id"),
+                    key_values=[(room_id, user_id, row[2]) for row in unread_counts],
+                    value_names=(
+                        "notif_count",
+                        "unread_count",
+                        "stream_ordering",
+                        "last_receipt_stream_ordering",
+                    ),
+                    value_values=[
+                        # XXX Stream ordering.
+                        (
+                            row[0],
+                            row[1],
+                            old_rotate_stream_ordering,
+                            earliest_stream_ordering,
+                        )
+                        for row in unread_counts
+                    ],
                 )

+                # XXX WTF?
+                # Other threads should be marked as reset at the old stream ordering.
+                txn.execute(
+                    """
+                    UPDATE event_push_summary SET notif_count = 0, unread_count = 0, stream_ordering = ?, last_receipt_stream_ordering = ?
+                    WHERE user_id = ? AND room_id = ? AND
+                        stream_ordering <= ?
+                    """,
+                    (
+                        old_rotate_stream_ordering,
+                        min_receipts_stream_id,
+                        user_id,
+                        room_id,
+                        old_rotate_stream_ordering,
+                    ),
+                )
+
         # We always update `event_push_summary_last_receipt_stream_id` to
@@ -1102,23 +1195,23 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas

         # Calculate the new counts that should be upserted into event_push_summary
         sql = """
-            SELECT user_id, room_id,
+            SELECT user_id, room_id, thread_id,
                 coalesce(old.%s, 0) + upd.cnt,
                 upd.stream_ordering
             FROM (
-                SELECT user_id, room_id, count(*) as cnt,
+                SELECT user_id, room_id, thread_id, count(*) as cnt,
                     max(ea.stream_ordering) as stream_ordering
                 FROM event_push_actions AS ea
-                LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+                LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
                 WHERE ? < ea.stream_ordering AND ea.stream_ordering <= ?
                     AND (
                         old.last_receipt_stream_ordering IS NULL
                         OR old.last_receipt_stream_ordering < ea.stream_ordering
                     )
                     AND %s = 1
-                GROUP BY user_id, room_id
+                GROUP BY user_id, room_id, thread_id
             ) AS upd
-            LEFT JOIN event_push_summary AS old USING (user_id, room_id)
+            LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
         """

         # First get the count of unread messages.
@@ -1132,11 +1225,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         # object because we might not have the same amount of rows in each of them. To do
         # this, we use a dict indexed on the user ID and room ID to make it easier to
         # populate.
-        summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
+        summaries: Dict[Tuple[str, str, str], _EventPushSummary] = {}
         for row in txn:
-            summaries[(row[0], row[1])] = _EventPushSummary(
-                unread_count=row[2],
-                stream_ordering=row[3],
+            summaries[(row[0], row[1], row[2])] = _EventPushSummary(
+                unread_count=row[3],
+                stream_ordering=row[4],
                 notif_count=0,
             )

@@ -1147,17 +1240,17 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         )

         for row in txn:
-            if (row[0], row[1]) in summaries:
-                summaries[(row[0], row[1])].notif_count = row[2]
+            if (row[0], row[1], row[2]) in summaries:
+                summaries[(row[0], row[1], row[2])].notif_count = row[3]
             else:
                 # Because the rules on notifying are different than the rules on marking
                 # a message unread, we might end up with messages that notify but aren't
                 # marked unread, so we might not have a summary for this (user, room)
                 # tuple to complete.
-                summaries[(row[0], row[1])] = _EventPushSummary(
+                summaries[(row[0], row[1], row[2])] = _EventPushSummary(
                     unread_count=0,
-                    stream_ordering=row[3],
-                    notif_count=row[2],
+                    stream_ordering=row[4],
+                    notif_count=row[3],
                 )

         logger.info("Rotating notifications, handling %d rows", len(summaries))
@@ -1165,8 +1258,11 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
         self.db_pool.simple_upsert_many_txn(
             txn,
             table="event_push_summary",
-            key_names=("user_id", "room_id"),
-            key_values=[(user_id, room_id) for user_id, room_id in summaries],
+            key_names=("user_id", "room_id", "thread_id"),
+            key_values=[
+                (user_id, room_id, thread_id)
+                for user_id, room_id, thread_id in summaries
+            ],
             value_names=("notif_count", "unread_count", "stream_ordering"),
             value_values=[
                 (
synapse/storage/databases/main/events.py
@@ -1594,7 +1594,7 @@ class PersistEventsStore:
             )

             # Remove from relations table.
-            self._handle_redact_relations(txn, event.redacts)
+            self._handle_redact_relations(txn, event.room_id, event.redacts)

         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1909,6 +1909,7 @@ class PersistEventsStore:
                 self.store.get_thread_participated.invalidate,
                 (relation.parent_id, event.sender),
             )
+            txn.call_after(self.store.get_threads.invalidate, (event.room_id,))

     def _handle_insertion_event(
         self, txn: LoggingTransaction, event: EventBase
@@ -2033,13 +2034,14 @@ class PersistEventsStore:
         txn.execute(sql, (batch_id,))

     def _handle_redact_relations(
-        self, txn: LoggingTransaction, redacted_event_id: str
+        self, txn: LoggingTransaction, room_id: str, redacted_event_id: str
     ) -> None:
         """Handles receiving a redaction and checking whether the redacted event
         has any relations which must be removed from the database.

         Args:
             txn
+            room_id: The room ID of the event that was redacted.
             redacted_event_id: The event that was redacted.
         """

@@ -2068,6 +2070,7 @@ class PersistEventsStore:
             self.store._invalidate_cache_and_stream(
                 txn, self.store.get_thread_participated, (redacted_relates_to,)
             )
+            txn.call_after(self.store.get_threads.invalidate, (room_id,))
         self.store._invalidate_cache_and_stream(
             txn,
             self.store.get_mutual_event_relations_for_rel_type,
@@ -2190,9 +2193,9 @@ class PersistEventsStore:
         sql = """
             INSERT INTO event_push_actions (
                 room_id, event_id, user_id, actions, stream_ordering,
-                topological_ordering, notif, highlight, unread
+                topological_ordering, notif, highlight, unread, thread_id
             )
-            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
+            SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
             FROM event_push_actions_staging
            WHERE event_id = ?
         """
synapse/storage/databases/main/receipts.py
@@ -117,6 +117,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         """Get the current max stream ID for receipts stream"""
         return self._receipts_id_gen.get_current_token()

+    # XXX MOVE TO TESTS
     async def get_last_receipt_event_id_for_user(
         self, user_id: str, room_id: str, receipt_types: Collection[str]
     ) -> Optional[str]:
@@ -411,6 +412,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             "_get_linearized_receipts_for_rooms", f
         )

+        # Map of room ID to a dictionary in the form that sync wants it.
         results: JsonDict = {}
         for row in txn_results:
             # We want a single event per room, since we want to batch the
@@ -426,6 +428,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type = event_entry.setdefault(row["receipt_type"], {})

             receipt_type[row["user_id"]] = db_to_json(row["data"])
+            if row["thread_id"]:
+                receipt_type[row["user_id"]]["thread_id"] = row["thread_id"]

         results = {
             room_id: [results[room_id]] if room_id in results else []
@@ -522,7 +526,7 @@ class ReceiptsWorkerStore(SQLBaseStore):

     async def get_all_updated_receipts(
         self, instance_name: str, last_id: int, current_id: int, limit: int
-    ) -> Tuple[List[Tuple[int, list]], int, bool]:
+    ) -> Tuple[List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool]:
         """Get updates for receipts replication stream.

         Args:
@@ -549,9 +553,11 @@ class ReceiptsWorkerStore(SQLBaseStore):

         def get_all_updated_receipts_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[Tuple[int, list]], int, bool]:
+        ) -> Tuple[
+            List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]], int, bool
+        ]:
             sql = """
-                SELECT stream_id, room_id, receipt_type, user_id, event_id, data
+                SELECT stream_id, room_id, receipt_type, user_id, event_id, thread_id, data
                 FROM receipts_linearized
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
@@ -560,8 +566,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
             txn.execute(sql, (last_id, current_id, limit))

             updates = cast(
-                List[Tuple[int, list]],
-                [(r[0], r[1:5] + (db_to_json(r[5]),)) for r in txn],
+                List[Tuple[int, Tuple[str, str, str, str, str, JsonDict]]],
+                [(r[0], r[1:6] + (db_to_json(r[6]),)) for r in txn],
             )

             limited = False
@@ -613,6 +619,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_id: str,
+        thread_id: Optional[str],
         data: JsonDict,
         stream_id: int,
     ) -> Optional[int]:
@@ -636,15 +643,18 @@ class ReceiptsWorkerStore(SQLBaseStore):
         stream_ordering = int(res["stream_ordering"]) if res else None
         rx_ts = res["received_ts"] if res else 0

+        # Convert None to a blank string.
+        thread_id = thread_id or ""
+
         # We don't want to clobber receipts for more recent events, so we
         # have to compare orderings of existing receipts
         if stream_ordering is not None:
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
-                " INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
-                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
+                " INNER JOIN receipts_linearized as r USING (event_id, room_id)"
+                " WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ? AND r.thread_id = ?"
             )
-            txn.execute(sql, (room_id, receipt_type, user_id))
+            txn.execute(sql, (room_id, receipt_type, user_id, thread_id))

             for so, eid in txn:
                 if int(so) >= stream_ordering:
@@ -656,6 +666,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             )
             return None

+        # TODO
         txn.call_after(
             self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
         )
@@ -671,6 +682,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "room_id": room_id,
                 "receipt_type": receipt_type,
                 "user_id": user_id,
+                "thread_id": thread_id,
             },
             values={
                 "stream_id": stream_id,
@@ -678,7 +690,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "data": json_encoder.encode(data),
             },
             # receipts_linearized has a unique constraint on
-            # (user_id, room_id, receipt_type), so no need to lock
+            # (user_id, room_id, receipt_type, key), so no need to lock
             lock=False,
         )

@@ -728,6 +740,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: dict,
     ) -> Optional[Tuple[int, int]]:
         """Insert a receipt, either from local client or remote server.
@@ -760,6 +773,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 receipt_type,
                 user_id,
                 linearized_event_id,
+                thread_id,
                 data,
                 stream_id=stream_id,
                 # Read committed is actually beneficial here because we check for a receipt with
@@ -774,7 +788,8 @@ class ReceiptsWorkerStore(SQLBaseStore):

         now = self._clock.time_msec()
         logger.debug(
-            "RR for event %s in %s (%i ms old)",
+            "Receipt %s for event %s in %s (%i ms old)",
+            receipt_type,
             linearized_event_id,
             room_id,
             now - event_ts,
@@ -787,6 +802,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             receipt_type,
             user_id,
             event_ids,
+            thread_id,
             data,
         )
@@ -801,6 +817,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
         receipt_type: str,
         user_id: str,
         event_ids: List[str],
+        thread_id: Optional[str],
         data: JsonDict,
     ) -> None:
         assert self._can_write_to_receipts
@@ -812,6 +829,9 @@ class ReceiptsWorkerStore(SQLBaseStore):
         # FIXME: This shouldn't invalidate the whole cache
         txn.call_after(self._get_linearized_receipts_for_room.invalidate, (room_id,))

+        # Convert None to a blank string.
+        thread_id = thread_id or ""
+
         self.db_pool.simple_delete_txn(
             txn,
             table="receipts_graph",
@@ -819,6 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "room_id": room_id,
                 "receipt_type": receipt_type,
                 "user_id": user_id,
+                "thread_id": thread_id,
             },
         )
         self.db_pool.simple_insert_txn(
@@ -829,6 +850,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
                 "receipt_type": receipt_type,
                 "user_id": user_id,
                 "event_ids": json_encoder.encode(event_ids),
+                "thread_id": thread_id,
                 "data": json_encoder.encode(data),
             },
         )
synapse/storage/databases/main/relations.py
@@ -814,6 +814,93 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_event_relations", _get_event_relations
         )

+    @cached(tree=True)
+    async def get_threads(
+        self,
+        room_id: str,
+        limit: int = 5,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
+    ) -> Tuple[List[str], Optional[StreamToken]]:
+        """Get a list of thread IDs, ordered by topological ordering of their
+        latest reply.
+
+        Args:
+            room_id: The room the event belongs to.
+            limit: Only fetch the most recent `limit` threads.
+            from_token: Fetch rows from the given token, or from the start if None.
+            to_token: Fetch rows up to the given token, or up to the end if None.
+
+        Returns:
+            A tuple of:
+                A list of thread root event IDs.
+
+                The next stream token, if one exists.
+        """
+        pagination_clause = generate_pagination_where_clause(
+            direction="b",
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+            engine=self.database_engine,
+        )
+
+        if pagination_clause:
+            pagination_clause = "AND " + pagination_clause
+
+        sql = f"""
+            SELECT relates_to_id, MAX(topological_ordering), MAX(stream_ordering)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE
+                room_id = ? AND
+                relation_type = '{RelationTypes.THREAD}'
+                {pagination_clause}
+            GROUP BY relates_to_id
+            ORDER BY MAX(topological_ordering) DESC, MAX(stream_ordering) DESC
+            LIMIT ?
+        """
+
+        def _get_threads_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[str], Optional[StreamToken]]:
+            txn.execute(sql, [room_id, limit + 1])
+
+            last_topo_id = None
+            last_stream_id = None
+            thread_ids = []
+            for thread_id, topo_id, stream_id in txn:
+                thread_ids.append(thread_id)
+                last_topo_id = topo_id
+                last_stream_id = stream_id
+
+            # If there are more events, generate the next pagination key.
+            next_token = None
+            if len(thread_ids) > limit and last_topo_id and last_stream_id:
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, next_key
+                    )
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
+
+            return thread_ids[:limit], next_token
+
+        return await self.db_pool.runInteraction("get_threads", _get_threads_txn)
+

 class RelationsStore(RelationsWorkerStore):
     pass
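
The storage method uses the common "fetch limit + 1 rows" trick to detect whether another page exists. A self-contained sketch of that pattern in isolation:

from typing import List, Optional, Tuple


def paginate(rows: List[str], limit: int) -> Tuple[List[str], Optional[str]]:
    # Fetch-limit-plus-one pagination, as in _get_threads_txn above: ask for
    # `limit + 1` rows; if more than `limit` come back, a next page exists,
    # keyed off the last row actually returned to the caller.
    page = rows[: limit + 1]  # stands in for `LIMIT ?` bound to limit + 1
    if len(page) > limit:
        next_token = page[limit - 1]  # last row of the truncated page
        return page[:limit], next_token
    return page, None


assert paginate(["a", "b", "c"], 2) == (["a", "b"], "b")
assert paginate(["a", "b"], 2) == (["a", "b"], None)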
New schema delta (new file, PostgreSQL):
@@ -0,0 +1,25 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Allow multiple receipts per user per room via a nullable thread_id column.
+ALTER TABLE receipts_linearized ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
+ALTER TABLE receipts_graph ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';
+
+-- Rebuild the unique constraint with the thread_id.
+ALTER TABLE receipts_linearized DROP CONSTRAINT IF EXISTS receipts_linearized_uniqueness;
+ALTER TABLE receipts_linearized ADD CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);
+
+ALTER TABLE receipts_graph DROP CONSTRAINT IF EXISTS receipts_graph_uniqueness;
+ALTER TABLE receipts_graph ADD CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id);
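
Worth noting: although the comment says "nullable", the column is declared NOT NULL DEFAULT ''. That matters for the unique constraints, since SQL treats NULLs as distinct from each other. A quick, self-contained demonstration using Python's built-in sqlite3:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts (user_id TEXT, thread_id TEXT, "
    "UNIQUE (user_id, thread_id))"
)

# Two NULL thread_ids do NOT collide: NULL is distinct from NULL, so a
# genuinely nullable column would allow duplicate "unthreaded" receipts.
conn.execute("INSERT INTO receipts VALUES ('@alice:example.org', NULL)")
conn.execute("INSERT INTO receipts VALUES ('@alice:example.org', NULL)")  # ok!

# With '' as the sentinel for "no thread", duplicates are rejected.
conn.execute("INSERT INTO receipts VALUES ('@bob:example.org', '')")
try:
    conn.execute("INSERT INTO receipts VALUES ('@bob:example.org', '')")
except sqlite3.IntegrityError:
    print("duplicate unthreaded receipt rejected")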
@@ -0,0 +1,67 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

-- Allow multiple receipts per user per room via a thread_id column (non-null; the empty string means an unthreaded receipt).
--
-- SQLite doesn't support modifying constraints on an existing table, so it must
-- be recreated.

-- Create the new tables.
CREATE TABLE receipts_graph_new (
    room_id TEXT NOT NULL,
    receipt_type TEXT NOT NULL,
    user_id TEXT NOT NULL,
    event_ids TEXT NOT NULL,
    thread_id TEXT NOT NULL DEFAULT '',
    data TEXT NOT NULL,
    CONSTRAINT receipts_graph_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);

CREATE TABLE receipts_linearized_new (
    stream_id BIGINT NOT NULL,
    room_id TEXT NOT NULL,
    receipt_type TEXT NOT NULL,
    user_id TEXT NOT NULL,
    event_id TEXT NOT NULL,
    thread_id TEXT NOT NULL DEFAULT '',
    data TEXT NOT NULL,
    CONSTRAINT receipts_linearized_uniqueness UNIQUE (room_id, receipt_type, user_id, thread_id)
);

-- Drop the old indexes.
DROP INDEX IF EXISTS receipts_linearized_id;
DROP INDEX IF EXISTS receipts_linearized_room_stream;
DROP INDEX IF EXISTS receipts_linearized_user;

-- Copy the data.
INSERT INTO receipts_graph_new (room_id, receipt_type, user_id, event_ids, data)
    SELECT room_id, receipt_type, user_id, event_ids, data
    FROM receipts_graph;
INSERT INTO receipts_linearized_new (stream_id, room_id, receipt_type, user_id, event_id, data)
    SELECT stream_id, room_id, receipt_type, user_id, event_id, data
    FROM receipts_linearized;

-- Drop the old tables.
DROP TABLE receipts_graph;
DROP TABLE receipts_linearized;

-- Rename the tables.
ALTER TABLE receipts_graph_new RENAME TO receipts_graph;
ALTER TABLE receipts_linearized_new RENAME TO receipts_linearized;

-- Create the indices.
CREATE INDEX receipts_linearized_id ON receipts_linearized(stream_id);
CREATE INDEX receipts_linearized_room_stream ON receipts_linearized(room_id, stream_id);
CREATE INDEX receipts_linearized_user ON receipts_linearized(user_id);

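A note on the schema choice shared by both migrations: thread_id is `NOT NULL DEFAULT ''` rather than genuinely nullable, presumably because SQL UNIQUE constraints treat NULLs as distinct, so a nullable column could never enforce "at most one unthreaded receipt per user". A quick demonstration with the standard-library sqlite3 module (the table and values here are illustrative only):

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE receipts (user_id TEXT, thread_id TEXT,"
    " UNIQUE (user_id, thread_id))"
)

# NULL thread_ids never collide: UNIQUE treats each NULL as distinct.
conn.execute("INSERT INTO receipts VALUES ('@a:hs', NULL)")
conn.execute("INSERT INTO receipts VALUES ('@a:hs', NULL)")  # also succeeds

# The empty-string sentinel does collide, which is the desired behaviour.
conn.execute("INSERT INTO receipts VALUES ('@b:hs', '')")
try:
    conn.execute("INSERT INTO receipts VALUES ('@b:hs', '')")
except sqlite3.IntegrityError as e:
    print(e)  # UNIQUE constraint failed: receipts.user_id, receipts.thread_id
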
@@ -0,0 +1,27 @@
/* Copyright 2022 The Matrix.org Foundation C.I.C
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

ALTER TABLE event_push_actions_staging
    ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';

ALTER TABLE event_push_actions
    ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';

ALTER TABLE event_push_summary
    ADD COLUMN thread_id TEXT NOT NULL DEFAULT '';

-- Update the unique index for `event_push_summary`.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
    (7003, 'event_push_summary_unique_index2', '{}');

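The `INSERT INTO background_updates` row only schedules the work; a handler registered elsewhere under the same update name rebuilds the index asynchronously after the server starts. The registration call is not part of this diff, but in Synapse's storage layer it would typically look something like this sketch (the column list is an assumption here, chosen to extend the old per-user-per-room index with thread_id):

# Inside the relevant store's __init__ (sketch; not shown in this diff):
self.db_pool.updates.register_background_index_update(
    "event_push_summary_unique_index2",  # matches the update_name above
    index_name="event_push_summary_unique_index2",
    table="event_push_summary",
    columns=["user_id", "room_id", "thread_id"],  # assumed columns
    unique=True,
)
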
@@ -830,6 +830,7 @@ class ReadReceipt:
    receipt_type: str
    user_id: str
    event_ids: List[str]
    thread_id: Optional[str]
    data: JsonDict

@@ -49,7 +49,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):

        sender = self.hs.get_federation_sender()
        receipt = ReadReceipt(
-            "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+            "room_id",
+            "m.read",
+            "user_id",
+            ["event_id"],
+            thread_id=None,
+            data={"ts": 1234},
        )
        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))

@@ -89,7 +94,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):

        sender = self.hs.get_federation_sender()
        receipt = ReadReceipt(
-            "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234}
+            "room_id",
+            "m.read",
+            "user_id",
+            ["event_id"],
+            thread_id=None,
+            data={"ts": 1234},
        )
        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))

@@ -121,7 +131,12 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):

        # send the second RR
        receipt = ReadReceipt(
-            "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234}
+            "room_id",
+            "m.read",
+            "user_id",
+            ["other_id"],
+            thread_id=None,
+            data={"ts": 1234},
        )
        self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
        self.pump()

@@ -447,6 +447,7 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
                receipt_type="m.read",
                user_id=self.local_user,
                event_ids=[f"$eventid_{i}"],
                thread_id=None,
                data={},
            )
        )

@@ -171,14 +171,14 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        if send_receipt:
            self.get_success(
                self.master_store.insert_receipt(
-                    ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+                    ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {}
                )
            )

        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
+            (NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {}),
        )

        self.persist(
@@ -191,7 +191,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
+            (NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {}),
        )

        self.persist(
@@ -206,7 +206,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
        self.check(
            "get_unread_event_push_actions_by_room_for_user",
            [ROOM_ID, USER_ID_2],
-            NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
+            (NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {}),
        )

    def test_get_rooms_for_user_with_stream_ordering(self):
@@ -404,6 +404,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
                event.event_id,
                {user_id: actions for user_id, actions in push_actions},
                False,
                None,
            )
        )
        return event, context

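As the updated assertions show, `get_unread_event_push_actions_by_room_for_user` now returns a pair: the room-level `NotifCounts` plus a mapping of thread ID to that thread's own counts. A caller wanting a single badge number would combine them along these lines (a hedged sketch; the variable names are illustrative):

counts, thread_counts = await store.get_unread_event_push_actions_by_room_for_user(
    room_id, user_id
)
# Main-timeline and per-thread notifications are tracked separately,
# so a total badge count sums the two.
total_notifs = counts.notify_count + sum(
    tc.notify_count for tc in thread_counts.values()
)
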
@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
        # tell the master to send a new receipt
        self.get_success(
            self.hs.get_datastores().main.insert_receipt(
-                "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
+                "!room:blue",
+                "m.read",
+                USER_ID,
+                ["$event:blue"],
+                thread_id=None,
+                data={"a": 1},
            )
        )
        self.replicate()

@@ -57,7 +62,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):

        self.get_success(
            self.hs.get_datastores().main.insert_receipt(
-                "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+                "!room2:blue",
+                "m.read",
+                USER_ID,
+                ["$event2:foo"],
+                thread_id=None,
+                data={"a": 2},
            )
        )
        self.replicate()

@@ -1679,3 +1679,154 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
            related_event_id,
        )


class ThreadsTestCase(BaseRelationsTestCase):
    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
    def test_threads(self) -> None:
        """Create threads and ensure they are ordered by their latest event."""
        # Create 2 threads.
        thread_1 = self.parent_id
        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
        thread_2 = res["event_id"]

        self._send_relation(RelationTypes.THREAD, "m.room.test")
        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)

        # Request the threads in the room.
        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_2, thread_1])

        # Update the first thread, the ordering should swap.
        self._send_relation(RelationTypes.THREAD, "m.room.test")

        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_1, thread_2])

    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
    def test_pagination(self) -> None:
        """Create threads and paginate through them."""
        # Create 2 threads.
        thread_1 = self.parent_id
        res = self.helper.send(self.room, body="Thread Root!", tok=self.user_token)
        thread_2 = res["event_id"]

        self._send_relation(RelationTypes.THREAD, "m.room.test")
        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)

        # Request the threads in the room.
        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_2])

        # Make sure next_batch has something in it that looks like it could be a
        # valid token.
        next_batch = channel.json_body.get("next_batch")
        self.assertIsInstance(next_batch, str, channel.json_body)

        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?limit=1&from={next_batch}",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_1], channel.json_body)

        self.assertNotIn("next_batch", channel.json_body, channel.json_body)

    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
    def test_include(self) -> None:
        """Filtering threads to all, or only those participated in, should work."""
        # Thread 1 has the user as the root event.
        thread_1 = self.parent_id
        self._send_relation(
            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
        )

        # Thread 2 has the user replying.
        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
        thread_2 = res["event_id"]
        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)

        # The user does not participate in thread 3.
        res = self.helper.send(self.room, body="Another thread!", tok=self.user2_token)
        thread_3 = res["event_id"]
        self._send_relation(
            RelationTypes.THREAD,
            "m.room.test",
            access_token=self.user2_token,
            parent_id=thread_3,
        )

        # All threads in the room.
        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(
            thread_roots, [thread_3, thread_2, thread_1], channel.json_body
        )

        # Only participated threads.
        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads?include=participated",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_2, thread_1], channel.json_body)

    @unittest.override_config({"experimental_features": {"msc3856_enabled": True}})
    def test_ignored_user(self) -> None:
        """Events from ignored users should be ignored."""
        # Thread 1 has a reply from an ignored user.
        thread_1 = self.parent_id
        self._send_relation(
            RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
        )

        # Thread 2 is created by an ignored user.
        res = self.helper.send(self.room, body="Thread Root!", tok=self.user2_token)
        thread_2 = res["event_id"]
        self._send_relation(RelationTypes.THREAD, "m.room.test", parent_id=thread_2)

        # Ignore user2.
        self.get_success(
            self.store.add_account_data_for_user(
                self.user_id,
                AccountDataTypes.IGNORED_USER_LIST,
                {"ignored_users": {self.user2_id: {}}},
            )
        )

        # Only thread 1 is returned.
        channel = self.make_request(
            "GET",
            f"/_matrix/client/unstable/org.matrix.msc3856/rooms/{self.room}/threads",
            access_token=self.user_token,
        )
        self.assertEquals(200, channel.code, channel.json_body)
        thread_roots = [ev["event_id"] for ev in channel.json_body["chunk"]]
        self.assertEqual(thread_roots, [thread_1], channel.json_body)

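For reference, the endpoint exercised above returns a `chunk` of thread-root events plus a `next_batch` token when more pages exist; a client pages by passing the token back as `from`. A small sketch with `requests` (the homeserver URL and token are placeholders, mirroring the test script later in this diff):

import requests

HOMESERVER = "http://localhost:8080"  # placeholder
HEADERS = {"Authorization": "Bearer <access_token>"}  # placeholder

def fetch_all_thread_roots(room_id: str, limit: int = 10) -> list:
    """Collect every thread root in a room via the MSC3856 endpoint."""
    url = (
        f"{HOMESERVER}/_matrix/client/unstable/org.matrix.msc3856"
        f"/rooms/{room_id}/threads"
    )
    params = {"limit": limit}
    roots = []
    while True:
        result = requests.get(url, headers=HEADERS, params=params)
        result.raise_for_status()
        body = result.json()
        roots.extend(ev["event_id"] for ev in body["chunk"])
        if "next_batch" not in body:
            return roots
        # Feed the token back as `from`, as test_pagination does above.
        params["from"] = body["next_batch"]
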
@@ -12,12 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

from twisted.test.proto_helpers import MemoryReactor

from synapse.rest import admin
from synapse.rest.client import login, room
from synapse.server import HomeServer
from synapse.storage.databases.main.event_push_actions import NotifCounts
from synapse.types import JsonDict
from synapse.util import Clock

from tests.unittest import HomeserverTestCase

@@ -70,7 +73,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
        def _assert_counts(
            notif_count: int, unread_count: int, highlight_count: int
        ) -> None:
-            counts = self.get_success(
+            counts, thread_counts = self.get_success(
                self.store.db_pool.runInteraction(
                    "get-unread-counts",
                    self.store._get_unread_counts_by_receipt_txn,

@@ -86,6 +89,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                    highlight_count=highlight_count,
                ),
            )
            self.assertEqual(thread_counts, {})

        def _create_event(highlight: bool = False) -> str:
            result = self.helper.send_event(

@@ -108,6 +112,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
                    "m.read",
                    user_id=user_id,
                    event_ids=[event_id],
                    thread_id=None,
                    data={},
                )
            )

@@ -131,6 +136,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
        _assert_counts(0, 0, 0)

        _create_event()
        _assert_counts(1, 1, 0)
        _rotate()
        _assert_counts(1, 1, 0)

@@ -166,6 +172,184 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
        _rotate()
        _assert_counts(0, 0, 0)

    def test_count_aggregation_threads(self) -> None:
        # Create a user to receive notifications and send receipts.
        user_id = self.register_user("user1235", "pass")
        token = self.login("user1235", "pass")

        # And another user to send events.
        other_id = self.register_user("other", "pass")
        other_token = self.login("other", "pass")

        # Create a room and put both users in it.
        room_id = self.helper.create_room_as(user_id, tok=token)
        self.helper.join(room_id, other_id, tok=other_token)
        thread_id: str

        last_event_id: str

        def _assert_counts(
            notif_count: int,
            unread_count: int,
            highlight_count: int,
            thread_notif_count: int,
            thread_unread_count: int,
            thread_highlight_count: int,
        ) -> None:
            counts, thread_counts = self.get_success(
                self.store.db_pool.runInteraction(
                    "get-unread-counts",
                    self.store._get_unread_counts_by_receipt_txn,
                    room_id,
                    user_id,
                )
            )
            self.assertEqual(
                counts,
                NotifCounts(
                    notify_count=notif_count,
                    unread_count=unread_count,
                    highlight_count=highlight_count,
                ),
            )
            if thread_notif_count or thread_unread_count or thread_highlight_count:
                self.assertEqual(
                    thread_counts,
                    {
                        thread_id: NotifCounts(
                            notify_count=thread_notif_count,
                            unread_count=thread_unread_count,
                            highlight_count=thread_highlight_count,
                        ),
                    },
                )
            else:
                self.assertEqual(thread_counts, {})

        def _create_event(
            highlight: bool = False, thread_id: Optional[str] = None
        ) -> str:
            content: JsonDict = {
                "msgtype": "m.text",
                "body": user_id if highlight else "",
            }
            if thread_id:
                content["m.relates_to"] = {
                    "rel_type": "m.thread",
                    "event_id": thread_id,
                }

            result = self.helper.send_event(
                room_id,
                type="m.room.message",
                content=content,
                tok=other_token,
            )
            nonlocal last_event_id
            last_event_id = result["event_id"]
            return last_event_id

        def _rotate() -> None:
            self.get_success(self.store._rotate_notifs())

        def _mark_read(event_id: str, thread_id: Optional[str] = None) -> None:
            self.get_success(
                self.store.insert_receipt(
                    room_id,
                    "m.read",
                    user_id=user_id,
                    event_ids=[event_id],
                    thread_id=thread_id,
                    data={},
                )
            )

        _assert_counts(0, 0, 0, 0, 0, 0)
        thread_id = _create_event()
        _assert_counts(1, 0, 0, 0, 0, 0)
        _rotate()
        _assert_counts(1, 0, 0, 0, 0, 0)

        _create_event(thread_id=thread_id)
        _assert_counts(1, 0, 0, 1, 0, 0)
        _rotate()
        _assert_counts(1, 0, 0, 1, 0, 0)

        _create_event()
        _assert_counts(2, 0, 0, 1, 0, 0)
        _rotate()
        _assert_counts(2, 0, 0, 1, 0, 0)

        event_id = _create_event(thread_id=thread_id)
        _assert_counts(2, 0, 0, 2, 0, 0)
        _rotate()
        _assert_counts(2, 0, 0, 2, 0, 0)

        _create_event()
        _create_event(thread_id=thread_id)
        _mark_read(event_id)
        _assert_counts(1, 0, 0, 3, 0, 0)
        _mark_read(event_id, thread_id)
        _assert_counts(1, 0, 0, 1, 0, 0)

        _mark_read(last_event_id)
        _mark_read(last_event_id, thread_id)
        _assert_counts(0, 0, 0, 0, 0, 0)

        _create_event()
        _create_event(thread_id=thread_id)
        _assert_counts(1, 0, 0, 1, 0, 0)
        _rotate()
        _assert_counts(1, 0, 0, 1, 0, 0)

        # Delete old event push actions, this should not affect the (summarised) count.
        self.get_success(self.store._remove_old_push_actions_that_have_rotated())
        _assert_counts(1, 0, 0, 1, 0, 0)

        _mark_read(last_event_id)
        _mark_read(last_event_id, thread_id)
        _assert_counts(0, 0, 0, 0, 0, 0)

        _create_event(True)
        _assert_counts(1, 1, 1, 0, 0, 0)
        _rotate()
        _assert_counts(1, 1, 1, 0, 0, 0)

        event_id = _create_event(True, thread_id)
        _assert_counts(1, 1, 1, 1, 1, 1)
        _rotate()
        _assert_counts(1, 1, 1, 1, 1, 1)

        # Check that adding another notification and rotating after highlight
        # works.
        _create_event()
        _rotate()
        _assert_counts(2, 1, 1, 1, 1, 1)

        _create_event(thread_id=thread_id)
        _rotate()
        _assert_counts(2, 1, 1, 2, 1, 1)

        # Check that sending read receipts at different points results in the
        # right counts.
        _mark_read(event_id)
        _assert_counts(1, 0, 0, 2, 1, 1)
        _mark_read(event_id, thread_id)
        _assert_counts(1, 0, 0, 1, 0, 0)
        _mark_read(last_event_id)
        _assert_counts(0, 0, 0, 1, 0, 0)
        _mark_read(last_event_id, thread_id)
        _assert_counts(0, 0, 0, 0, 0, 0)

        _create_event(True)
        _create_event(True, thread_id)
        _assert_counts(1, 1, 1, 1, 1, 1)
        _mark_read(last_event_id)
        _mark_read(last_event_id, thread_id)
        _assert_counts(0, 0, 0, 0, 0, 0)
        _rotate()
        _assert_counts(0, 0, 0, 0, 0, 0)

    def test_find_first_stream_ordering_after_ts(self) -> None:
        def add_event(so: int, ts: int) -> None:
            self.get_success(

@@ -120,13 +120,18 @@ class ReceiptTestCase(HomeserverTestCase):
        # Send public read receipt for the first event
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
            )
        )
        # Send private read receipt for the second event
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1,
+                ReceiptTypes.READ_PRIVATE,
+                OUR_USER_ID,
+                [event1_2_id],
+                None,
+                {},
            )
        )

@@ -153,7 +158,7 @@ class ReceiptTestCase(HomeserverTestCase):
        # Test receipt updating
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
            )
        )
        res = self.get_success(

@@ -169,7 +174,12 @@ class ReceiptTestCase(HomeserverTestCase):
        # Test new room is reflected in what the method returns
        self.get_success(
            self.store.insert_receipt(
-                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+                self.room_id2,
+                ReceiptTypes.READ_PRIVATE,
+                OUR_USER_ID,
+                [event2_1_id],
+                None,
+                {},
            )
        )
        res = self.get_success(

@@ -191,13 +201,18 @@ class ReceiptTestCase(HomeserverTestCase):
        # Send public read receipt for the first event
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], None, {}
            )
        )
        # Send private read receipt for the second event
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1,
+                ReceiptTypes.READ_PRIVATE,
+                OUR_USER_ID,
+                [event1_2_id],
+                None,
+                {},
            )
        )

@@ -230,7 +245,7 @@ class ReceiptTestCase(HomeserverTestCase):
        # Test receipt updating
        self.get_success(
            self.store.insert_receipt(
-                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
+                self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], None, {}
            )
        )
        res = self.get_success(

@@ -248,7 +263,12 @@ class ReceiptTestCase(HomeserverTestCase):
        # Test new room is reflected in what the method returns
        self.get_success(
            self.store.insert_receipt(
-                self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
+                self.room_id2,
+                ReceiptTypes.READ_PRIVATE,
+                OUR_USER_ID,
+                [event2_1_id],
+                None,
+                {},
            )
        )
        res = self.get_success(

190 thread_test.py Normal file
@@ -0,0 +1,190 @@
import json
from time import monotonic

import requests

HOMESERVER = "http://localhost:8080"

USER_1_TOK = "syt_dGVzdGVy_AywuFarQjsYrHuPkOUvg_25XLNK"
USER_1_HEADERS = {"Authorization": f"Bearer {USER_1_TOK}"}

USER_2_TOK = "syt_b3RoZXI_jtiTnwtlBjMGMixlHIBM_4cxesB"
USER_2_HEADERS = {"Authorization": f"Bearer {USER_2_TOK}"}


def _check_for_status(result):
    # Similar to raise_for_status, but prints the error body first.
    if 400 <= result.status_code:
        print(result.json())
        result.raise_for_status()


def _sync_and_show(room_id):
    print("Syncing . . .")
    result = requests.get(
        f"{HOMESERVER}/_matrix/client/v3/sync",
        headers=USER_1_HEADERS,
        params={
            "filter": json.dumps(
                {
                    "room": {
                        "timeline": {"limit": 30, "unread_thread_notifications": True}
                    }
                }
            )
        },
    )
    _check_for_status(result)
    sync_response = result.json()

    room = sync_response["rooms"]["join"][room_id]

    # Find read receipts (this assumes non-overlapping).
    read_receipts = {}  # thread -> event ID -> users
    for event in room["ephemeral"]["events"]:
        if event["type"] != "m.receipt":
            continue

        for event_id, content in event["content"].items():
            for mxid, receipt in content["m.read"].items():
                print(mxid, receipt)
                # Just care about the localpart of the MXID.
                mxid = mxid.split(":", 1)[0]
                read_receipts.setdefault(receipt.get("thread_id"), {}).setdefault(
                    event_id, []
                ).append(mxid)

    print(room["unread_notifications"])
    print(room.get("unread_thread_notifications"))
    print()

    # Convert events to their threads.
    threads = {}
    for event in room["timeline"]["events"]:
        if event["type"] != "m.room.message":
            continue

        event_id = event["event_id"]

        parent_id = event["content"].get("m.relates_to", {}).get("event_id")
        if parent_id:
            threads[parent_id][1].append(event)
        else:
            threads[event_id] = (event, [])

    for root_event_id, (root, thread) in threads.items():
        msg = root["content"]["body"]
        print(f"{root_event_id}: {msg}")

        for event in thread:
            thread_event_id = event["event_id"]

            msg = event["content"]["body"]
            print(f"\t{thread_event_id}: {msg}")

            if thread_event_id in read_receipts.get(root_event_id, {}):
                user_ids = ", ".join(read_receipts[root_event_id][thread_event_id])
                print(f"\t^--------- {user_ids} ---------^")

        if root_event_id in read_receipts.get(None, {}):
            user_ids = ", ".join(read_receipts[None][root_event_id])
            print(f"^--------- {user_ids} ---------^")

        print()
    print()


def _send_event(room_id, body, thread_id=None):
    content = {
        "msgtype": "m.text",
        "body": body,
    }
    if thread_id:
        content["m.relates_to"] = {
            "rel_type": "m.thread",
            "event_id": thread_id,
        }

    # Send a msg to the room.
    result = requests.put(
        f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/msg{monotonic()}",
        json=content,
        headers=USER_2_HEADERS,
    )
    _check_for_status(result)
    return result.json()["event_id"]


def main():
    # Create a new room as user 2, add a bunch of messages.
    result = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/createRoom",
        json={"visibility": "public", "name": f"Thread Read Receipts ({monotonic()})"},
        headers=USER_2_HEADERS,
    )
    _check_for_status(result)
    room_id = result.json()["room_id"]

    # User 1 joins the room.
    result = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/join", headers=USER_1_HEADERS
    )
    _check_for_status(result)

    # Sync user 1.
    _sync_and_show(room_id)

    # User 2 sends some messages.
    event_ids = []

    def _send_and_append(body, thread_id=None):
        event_id = _send_event(room_id, body, thread_id)
        event_ids.append(event_id)
        return event_id

    for msg in range(5):
        root_message_id = _send_and_append(f"Message {msg}")
    for msg in range(10):
        if msg % 2:
            _send_and_append(f"More message {msg}")
        else:
            _send_and_append(f"Thread Message {msg}", root_message_id)

    # User 2 sends a read receipt.
    print("@second reads main timeline")
    result = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[3]}",
        headers=USER_2_HEADERS,
        json={},
    )
    _check_for_status(result)

    _sync_and_show(room_id)

    # User 1 sends a read receipt.
    print("@test reads main timeline")
    result = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-5]}",
        headers=USER_1_HEADERS,
        json={},
    )
    _check_for_status(result)

    _sync_and_show(room_id)

    # User 1 sends another read receipt, this time for the thread.
    print("@test reads thread")
    result = requests.post(
        f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-4]}/{root_message_id}",
        headers=USER_1_HEADERS,
        json={},
    )
    _check_for_status(result)

    _sync_and_show(room_id)


if __name__ == "__main__":
    main()