move experimental feature msc2654 (unread counts) to per-user flag
This commit is contained in:
@@ -96,12 +96,6 @@ class ExperimentalConfig(Config):
|
||||
# MSC3720 (Account status endpoint)
|
||||
self.msc3720_enabled: bool = experimental.get("msc3720_enabled", False)
|
||||
|
||||
# MSC2654: Unread counts
|
||||
#
|
||||
# Note that enabling this will result in an incorrect unread count for
|
||||
# previously calculated push actions.
|
||||
self.msc2654_enabled: bool = experimental.get("msc2654_enabled", False)
|
||||
|
||||
# MSC2815 (allow room moderators to view redacted event content)
|
||||
self.msc2815_enabled: bool = experimental.get("msc2815_enabled", False)
|
||||
|
||||
|
||||
@@ -27,6 +27,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.api.constants import (
|
||||
@@ -107,6 +108,17 @@ def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class ActionsForUser:
|
||||
"""
|
||||
A class to hold the actions for a given event and whether the event should
|
||||
increment the unread count.
|
||||
"""
|
||||
|
||||
actions: Collection[Union[Mapping, str]]
|
||||
count_as_unread: bool
|
||||
|
||||
|
||||
class BulkPushRuleEvaluator:
|
||||
"""Calculates the outcome of push rules for an event for all users in the
|
||||
room at once.
|
||||
@@ -336,15 +348,8 @@ class BulkPushRuleEvaluator:
|
||||
# (historical messages persisted in reverse-chronological order).
|
||||
return
|
||||
|
||||
# Disable counting as unread unless the experimental configuration is
|
||||
# enabled, as it can cause additional (unwanted) rows to be added to the
|
||||
# event_push_actions table.
|
||||
count_as_unread = False
|
||||
if self.hs.config.experimental.msc2654_enabled:
|
||||
count_as_unread = _should_count_as_unread(event, context)
|
||||
|
||||
rules_by_user = await self._get_rules_for_event(event)
|
||||
actions_by_user: Dict[str, Collection[Union[Mapping, str]]] = {}
|
||||
actions_by_user: Dict[str, ActionsForUser] = {}
|
||||
|
||||
room_member_count = await self.store.get_number_joined_users_in_room(
|
||||
event.room_id
|
||||
@@ -429,17 +434,19 @@ class BulkPushRuleEvaluator:
|
||||
if not isinstance(display_name, str):
|
||||
display_name = None
|
||||
|
||||
if count_as_unread:
|
||||
# Add an element for the current user if the event needs to be marked as
|
||||
# unread, so that add_push_actions_to_staging iterates over it.
|
||||
# If the event shouldn't be marked as unread but should notify the
|
||||
# current user, it'll be added to the dict later.
|
||||
actions_by_user[uid] = []
|
||||
|
||||
actions = evaluator.run(rules, uid, display_name)
|
||||
if "notify" in actions:
|
||||
# Push rules say we should notify the user of this event
|
||||
actions_by_user[uid] = actions
|
||||
|
||||
# check whether unread counts are enabled for this user
|
||||
unread_enabled = await self.store.get_feature_enabled(uid, "msc2654")
|
||||
if unread_enabled:
|
||||
count_as_unread = _should_count_as_unread(event, context)
|
||||
else:
|
||||
count_as_unread = False
|
||||
|
||||
if "notify" in actions or count_as_unread:
|
||||
# Push rules say we should notify the user of this event or the event should
|
||||
# increment the unread count
|
||||
actions_by_user[uid] = ActionsForUser(actions, count_as_unread)
|
||||
|
||||
# If there aren't any actions then we can skip the rest of the
|
||||
# processing.
|
||||
@@ -467,7 +474,6 @@ class BulkPushRuleEvaluator:
|
||||
await self.store.add_push_actions_to_staging(
|
||||
event.event_id,
|
||||
actions_by_user,
|
||||
count_as_unread,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -100,7 +100,6 @@ class SyncRestServlet(RestServlet):
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
self._server_notices_sender = hs.get_server_notices_sender()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._msc2654_enabled = hs.config.experimental.msc2654_enabled
|
||||
self._msc3773_enabled = hs.config.experimental.msc3773_enabled
|
||||
|
||||
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
|
||||
@@ -261,7 +260,7 @@ class SyncRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
joined = await self.encode_joined(
|
||||
sync_result.joined, time_now, serialize_options
|
||||
sync_result.joined, time_now, serialize_options, requester
|
||||
)
|
||||
|
||||
invited = await self.encode_invited(
|
||||
@@ -273,7 +272,7 @@ class SyncRestServlet(RestServlet):
|
||||
)
|
||||
|
||||
archived = await self.encode_archived(
|
||||
sync_result.archived, time_now, serialize_options
|
||||
sync_result.archived, time_now, serialize_options, requester
|
||||
)
|
||||
|
||||
logger.debug("building sync response dict")
|
||||
@@ -344,6 +343,7 @@ class SyncRestServlet(RestServlet):
|
||||
rooms: List[JoinedSyncResult],
|
||||
time_now: int,
|
||||
serialize_options: SerializeEventConfig,
|
||||
requester: Requester,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the joined rooms in a sync result
|
||||
@@ -352,13 +352,18 @@ class SyncRestServlet(RestServlet):
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
serialize_options: Event serializer options
|
||||
requester: The requester of the sync
|
||||
Returns:
|
||||
The joined rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = await self.encode_room(
|
||||
room, time_now, joined=True, serialize_options=serialize_options
|
||||
room,
|
||||
time_now,
|
||||
joined=True,
|
||||
serialize_options=serialize_options,
|
||||
requester=requester,
|
||||
)
|
||||
|
||||
return joined
|
||||
@@ -449,6 +454,7 @@ class SyncRestServlet(RestServlet):
|
||||
rooms: List[ArchivedSyncResult],
|
||||
time_now: int,
|
||||
serialize_options: SerializeEventConfig,
|
||||
requester: Requester,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Encode the archived rooms in a sync result
|
||||
@@ -457,13 +463,18 @@ class SyncRestServlet(RestServlet):
|
||||
rooms: list of sync results for rooms this user is joined to
|
||||
time_now: current time - used as a baseline for age calculations
|
||||
serialize_options: Event serializer options
|
||||
requester: the requester of the sync
|
||||
Returns:
|
||||
The archived rooms list, in our response format
|
||||
"""
|
||||
joined = {}
|
||||
for room in rooms:
|
||||
joined[room.room_id] = await self.encode_room(
|
||||
room, time_now, joined=False, serialize_options=serialize_options
|
||||
room,
|
||||
time_now,
|
||||
joined=False,
|
||||
serialize_options=serialize_options,
|
||||
requester=requester,
|
||||
)
|
||||
|
||||
return joined
|
||||
@@ -474,6 +485,7 @@ class SyncRestServlet(RestServlet):
|
||||
time_now: int,
|
||||
joined: bool,
|
||||
serialize_options: SerializeEventConfig,
|
||||
requester: Requester,
|
||||
) -> JsonDict:
|
||||
"""
|
||||
Args:
|
||||
@@ -486,6 +498,7 @@ class SyncRestServlet(RestServlet):
|
||||
only_fields: Optional. The list of event fields to include.
|
||||
event_formatter: function to convert from federation format
|
||||
to client format
|
||||
Requester: The requester of the sync
|
||||
Returns:
|
||||
The room, encoded in our response format
|
||||
"""
|
||||
@@ -539,7 +552,9 @@ class SyncRestServlet(RestServlet):
|
||||
"org.matrix.msc3773.unread_thread_notifications"
|
||||
] = room.unread_thread_notifications
|
||||
result["summary"] = room.summary
|
||||
if self._msc2654_enabled:
|
||||
if await self.store.get_feature_enabled(
|
||||
requester.user.to_string(), "msc2654"
|
||||
):
|
||||
result["org.matrix.msc2654.unread_count"] = room.unread_count
|
||||
|
||||
return result
|
||||
|
||||
@@ -105,6 +105,7 @@ from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.push.bulk_push_rule_evaluator import ActionsForUser
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -1215,8 +1216,7 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||
async def add_push_actions_to_staging(
|
||||
self,
|
||||
event_id: str,
|
||||
user_id_actions: Dict[str, Collection[Union[Mapping, str]]],
|
||||
count_as_unread: bool,
|
||||
user_id_actions: Dict[str, "ActionsForUser"],
|
||||
thread_id: str,
|
||||
) -> None:
|
||||
"""Add the push actions for the event to the push action staging area.
|
||||
@@ -1234,17 +1234,19 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||
# This is a helper function for generating the necessary tuple that
|
||||
# can be used to insert into the `event_push_actions_staging` table.
|
||||
def _gen_entry(
|
||||
user_id: str, actions: Collection[Union[Mapping, str]]
|
||||
user_id: str, actions_by_user: "ActionsForUser"
|
||||
) -> Tuple[str, str, str, int, int, int, str, int]:
|
||||
is_highlight = 1 if _action_has_highlight(actions) else 0
|
||||
notif = 1 if "notify" in actions else 0
|
||||
is_highlight = 1 if _action_has_highlight(actions_by_user.actions) else 0
|
||||
notif = 1 if "notify" in actions_by_user.actions else 0
|
||||
return (
|
||||
event_id, # event_id column
|
||||
user_id, # user_id column
|
||||
_serialize_action(actions, bool(is_highlight)), # actions column
|
||||
_serialize_action(
|
||||
actions_by_user.actions, bool(is_highlight)
|
||||
), # actions column
|
||||
notif, # notif column
|
||||
is_highlight, # highlight column
|
||||
int(count_as_unread), # unread column
|
||||
int(actions_by_user.count_as_unread), # unread column
|
||||
thread_id, # thread_id column
|
||||
self._clock.time_msec(), # inserted_ts column
|
||||
)
|
||||
@@ -1262,8 +1264,8 @@ class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBas
|
||||
"inserted_ts",
|
||||
),
|
||||
values=[
|
||||
_gen_entry(user_id, actions)
|
||||
for user_id, actions in user_id_actions.items()
|
||||
_gen_entry(user_id, actions_by_user)
|
||||
for user_id, actions_by_user in user_id_actions.items()
|
||||
],
|
||||
desc="add_push_actions_to_staging",
|
||||
)
|
||||
|
||||
@@ -24,6 +24,7 @@ from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import EventBase, _EventInternalMetadata, make_event_from_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.handlers.room import RoomEventSource
|
||||
from synapse.push.bulk_push_rule_evaluator import ActionsForUser
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.databases.main.event_push_actions import (
|
||||
NotifCounts,
|
||||
@@ -412,8 +413,10 @@ class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase):
|
||||
self.get_success(
|
||||
self.master_store.add_push_actions_to_staging(
|
||||
event.event_id,
|
||||
dict(push_actions),
|
||||
False,
|
||||
{
|
||||
user_id: ActionsForUser(actions, False)
|
||||
for user_id, actions in push_actions
|
||||
},
|
||||
"main",
|
||||
)
|
||||
)
|
||||
|
||||
@@ -538,9 +538,6 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def default_config(self) -> JsonDict:
|
||||
config = super().default_config()
|
||||
config["experimental_features"] = {
|
||||
"msc2654_enabled": True,
|
||||
}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
@@ -588,6 +585,9 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def test_unread_counts(self) -> None:
|
||||
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
|
||||
# add per-user flag to the DB
|
||||
ex_handler = self.hs.get_experimental_features_manager()
|
||||
self.get_success(ex_handler.set_feature_for_user(self.user_id, "msc2654", True))
|
||||
|
||||
# Check that our own messages don't increase the unread count.
|
||||
self.helper.send(self.room_id, "hello", tok=self.tok)
|
||||
|
||||
Reference in New Issue
Block a user