diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index e0a84e56d3..573763f270 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -96,12 +96,6 @@ class ExperimentalConfig(Config): # MSC3720 (Account status endpoint) self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False) - # MSC2654: Unread counts - # - # Note that enabling this will result in an incorrect unread count for - # previously calculated push actions. - self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False) - # MSC2815 (allow room moderators to view redacted event content) self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False) diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index 320084f5f5..a2d00645f0 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -27,6 +27,7 @@ from typing import ( Union, ) +import attr from prometheus_client import Counter from synapse.api.constants import ( @@ -107,6 +108,17 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool: return False +@attr.s(slots=True, auto_attribs=True) +class ActionsForUser: + """ + A class to hold the actions for a given event and whether the event should + increment the unread count. + """ + + actions: Collection[Union[Mapping, str]] + count_as_unread: bool + + class BulkPushRuleEvaluator: """Calculates the outcome of push rules for an event for all users in the room at once. @@ -336,15 +348,8 @@ class BulkPushRuleEvaluator: # (historical messages persisted in reverse-chronological order). return - # Disable counting as unread unless the experimental configuration is - # enabled, as it can cause additional (unwanted) rows to be added to the - # event_push_actions table. - count_as_unread = False - if self.hs.config.experimental.msc2654_enabled: - count_as_unread = _should_count_as_unread(event, context) - rules_by_user = await self._get_rules_for_event(event) - actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {} + actions_by_user: Dict[str, ActionsForUser] = {} room_member_count = await self.store.get_number_joined_users_in_room( event.room_id @@ -429,17 +434,19 @@ class BulkPushRuleEvaluator: if not isinstance(display_name, str): display_name = None - if count_as_unread: - # Add an element for the current user if the event needs to be marked as - # unread, so that add_push_actions_to_staging iterates over it. - # If the event shouldn't be marked as unread but should notify the - # current user, it'll be added to the dict later. - actions_by_user[uid] = [] - actions = evaluator.run(rules, uid, display_name) - if "notify" in actions: - # Push rules say we should notify the user of this event - actions_by_user[uid] = actions + + # check whether unread counts are enabled for this user + unread_enabled = await self.store.get_feature_enabled(uid, "msc2654") + if unread_enabled: + count_as_unread = _should_count_as_unread(event, context) + else: + count_as_unread = False + + if "notify" in actions or count_as_unread: + # Push rules say we should notify the user of this event or the event should + # increment the unread count + actions_by_user[uid] = ActionsForUser(actions, count_as_unread) # If there aren't any actions then we can skip the rest of the # processing. @@ -467,7 +474,6 @@ class BulkPushRuleEvaluator: await self.store.add_push_actions_to_staging( event.event_id, actions_by_user, - count_as_unread, thread_id, ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 03b0578945..7ae96493cb 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -100,7 +100,6 @@ class SyncRestServlet(RestServlet): self.presence_handler = hs.get_presence_handler() self._server_notices_sender = hs.get_server_notices_sender() self._event_serializer = hs.get_event_client_serializer() - self._msc2654_enabled = hs.config.experimental.msc2654_enabled self._msc3773_enabled = hs.config.experimental.msc3773_enabled async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: @@ -261,7 +260,7 @@ class SyncRestServlet(RestServlet): ) joined = await self.encode_joined( - sync_result.joined, time_now, serialize_options + sync_result.joined, time_now, serialize_options, requester ) invited = await self.encode_invited( @@ -273,7 +272,7 @@ class SyncRestServlet(RestServlet): ) archived = await self.encode_archived( - sync_result.archived, time_now, serialize_options + sync_result.archived, time_now, serialize_options, requester ) logger.debug("building sync response dict") @@ -344,6 +343,7 @@ class SyncRestServlet(RestServlet): rooms: List[JoinedSyncResult], time_now: int, serialize_options: SerializeEventConfig, + requester: Requester, ) -> JsonDict: """ Encode the joined rooms in a sync result @@ -352,13 +352,18 @@ class SyncRestServlet(RestServlet): rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options + requester: The requester of the sync Returns: The joined rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=True, serialize_options=serialize_options + room, + time_now, + joined=True, + serialize_options=serialize_options, + requester=requester, ) return joined @@ -449,6 +454,7 @@ class SyncRestServlet(RestServlet): rooms: List[ArchivedSyncResult], time_now: int, serialize_options: SerializeEventConfig, + requester: Requester, ) -> JsonDict: """ Encode the archived rooms in a sync result @@ -457,13 +463,18 @@ class SyncRestServlet(RestServlet): rooms: list of sync results for rooms this user is joined to time_now: current time - used as a baseline for age calculations serialize_options: Event serializer options + requester: the requester of the sync Returns: The archived rooms list, in our response format """ joined = {} for room in rooms: joined[room.room_id] = await self.encode_room( - room, time_now, joined=False, serialize_options=serialize_options + room, + time_now, + joined=False, + serialize_options=serialize_options, + requester=requester, ) return joined @@ -474,6 +485,7 @@ class SyncRestServlet(RestServlet): time_now: int, joined: bool, serialize_options: SerializeEventConfig, + requester: Requester, ) -> JsonDict: """ Args: @@ -486,6 +498,7 @@ class SyncRestServlet(RestServlet): only_fields: Optional. The list of event fields to include. event_formatter: function to convert from federation format to client format + Requester: The requester of the sync Returns: The room, encoded in our response format """ @@ -539,7 +552,9 @@ class SyncRestServlet(RestServlet): "org.matrix.msc3773.unread_thread_notifications" ] = room.unread_thread_notifications result["summary"] = room.summary - if self._msc2654_enabled: + if await self.store.get_feature_enabled( + requester.user.to_string(), "msc2654" + ): result["org.matrix.msc2654.unread_count"] = room.unread_count return result diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index eeccf5db24..63f5a96809 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -105,6 +105,7 @@ from synapse.util import json_encoder from synapse.util.caches.descriptors import cached if TYPE_CHECKING: + from synapse.push.bulk_push_rule_evaluator import ActionsForUser from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -1215,8 +1216,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas async def add_push_actions_to_staging( self, event_id: str, - user_id_actions: Dict[str, Collection[Union[Mapping, str]]], - count_as_unread: bool, + user_id_actions: Dict[str, "ActionsForUser"], thread_id: str, ) -> None: """Add the push actions for the event to the push action staging area. @@ -1234,17 +1234,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas # This is a helper function for generating the necessary tuple that # can be used to insert into the `event_push_actions_staging` table. def _gen_entry( - user_id: str, actions: Collection[Union[Mapping, str]] + user_id: str, actions_by_user: "ActionsForUser" ) -> Tuple[str, str, str, int, int, int, str, int]: - is_highlight = 1 if _action_has_highlight(actions) else 0 - notif = 1 if "notify" in actions else 0 + is_highlight = 1 if _action_has_highlight(actions_by_user.actions) else 0 + notif = 1 if "notify" in actions_by_user.actions else 0 return ( event_id, # event_id column user_id, # user_id column - _serialize_action(actions, bool(is_highlight)), # actions column + _serialize_action( + actions_by_user.actions, bool(is_highlight) + ), # actions column notif, # notif column is_highlight, # highlight column - int(count_as_unread), # unread column + int(actions_by_user.count_as_unread), # unread column thread_id, # thread_id column self._clock.time_msec(), # inserted_ts column ) @@ -1262,8 +1264,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas "inserted_ts", ), values=[ - _gen_entry(user_id, actions) - for user_id, actions in user_id_actions.items() + _gen_entry(user_id, actions_by_user) + for user_id, actions_by_user in user_id_actions.items() ], desc="add_push_actions_to_staging", ) diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py index b2125b1fea..aea25ed6df 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py @@ -24,6 +24,7 @@ from synapse.api.room_versions import RoomVersions from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict from synapse.events.snapshot import EventContext from synapse.handlers.room import RoomEventSource +from synapse.push.bulk_push_rule_evaluator import ActionsForUser from synapse.server import HomeServer from synapse.storage.databases.main.event_push_actions import ( NotifCounts, @@ -412,8 +413,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): self.get_success( self.master_store.add_push_actions_to_staging( event.event_id, - dict(push_actions), - False, + { + user_id: ActionsForUser(actions, False) + for user_id, actions in push_actions + }, "main", ) ) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 9c876c7a32..118b894773 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -538,9 +538,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def default_config(self) -> JsonDict: config = super().default_config() - config["experimental_features"] = { - "msc2654_enabled": True, - } return config def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: @@ -588,6 +585,9 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase): def test_unread_counts(self) -> None: """Tests that /sync returns the right value for the unread count (MSC2654).""" + # add per-user flag to the DB + ex_handler = self.hs.get_experimental_features_manager() + self.get_success(ex_handler.set_feature_for_user(self.user_id, "msc2654", True)) # Check that our own messages don't increase the unread count. self.helper.send(self.room_id, "hello", tok=self.tok)