1
0

Merge commit '112266eaf' into anoa/dinsic_release_1_21_x

* commit '112266eaf':
  Add StreamStore to mypy (#8232)
  Re-implement unread counts (again) (#8059)
This commit is contained in:
Andrew Morgan
2020-10-20 17:51:52 +01:00
16 changed files with 523 additions and 142 deletions

1
changelog.d/8059.feature Normal file
View File

@@ -0,0 +1 @@
Add unread messages count to sync responses, as specified in [MSC2654](https://github.com/matrix-org/matrix-doc/pull/2654).

1
changelog.d/8232.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to `StreamStore`.

View File

@@ -43,6 +43,7 @@ files =
synapse/server_notices,
synapse/spam_checker_api,
synapse/state,
synapse/storage/databases/main/stream.py,
synapse/storage/databases/main/ui_auth.py,
synapse/storage/database.py,
synapse/storage/engines,

View File

@@ -18,7 +18,7 @@
import abc
import os
from distutils.util import strtobool
from typing import Dict, Optional, Type
from typing import Dict, Optional, Tuple, Type
from unpaddedbase64 import encode_base64
@@ -120,7 +120,7 @@ class _EventInternalMetadata(object):
# be here
before = DictProperty("before") # type: str
after = DictProperty("after") # type: str
order = DictProperty("order") # type: int
order = DictProperty("order") # type: Tuple[int, int]
def get_dict(self) -> JsonDict:
return dict(self._dict)

View File

@@ -95,7 +95,12 @@ class TimelineBatch:
__bool__ = __nonzero__ # python3
@attr.s(slots=True, frozen=True)
# We can't freeze this class, because we need to update it after it's instantiated to
# update its unread count. This is because we calculate the unread count for a room only
# if there are updates for it, which we check after the instance has been created.
# This should not be a big deal because we update the notification counts afterwards as
# well anyway.
@attr.s(slots=True)
class JoinedSyncResult:
room_id = attr.ib(type=str)
timeline = attr.ib(type=TimelineBatch)
@@ -104,6 +109,7 @@ class JoinedSyncResult:
account_data = attr.ib(type=List[JsonDict])
unread_notifications = attr.ib(type=JsonDict)
summary = attr.ib(type=Optional[JsonDict])
unread_count = attr.ib(type=int)
def __nonzero__(self) -> bool:
"""Make the result appear empty if there are no updates. This is used
@@ -931,7 +937,7 @@ class SyncHandler(object):
async def unread_notifs_for_room_id(
self, room_id: str, sync_config: SyncConfig
) -> Optional[Dict[str, str]]:
) -> Dict[str, int]:
with Measure(self.clock, "unread_notifs_for_room_id"):
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
user_id=sync_config.user.to_string(),
@@ -939,15 +945,10 @@ class SyncHandler(object):
receipt_type="m.read",
)
if last_unread_event_id:
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
# There is no new information in this period, so your notification
# count is whatever it was last time.
return None
notifs = await self.store.get_unread_event_push_actions_by_room_for_user(
room_id, sync_config.user.to_string(), last_unread_event_id
)
return notifs
async def generate_sync_result(
self,
@@ -1886,7 +1887,7 @@ class SyncHandler(object):
)
if room_builder.rtype == "joined":
unread_notifications = {} # type: Dict[str, str]
unread_notifications = {} # type: Dict[str, int]
room_sync = JoinedSyncResult(
room_id=room_id,
timeline=batch,
@@ -1895,14 +1896,16 @@ class SyncHandler(object):
account_data=account_data_events,
unread_notifications=unread_notifications,
summary=summary,
unread_count=0,
)
if room_sync or always_include:
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
if notifs is not None:
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
unread_notifications["notification_count"] = notifs["notify_count"]
unread_notifications["highlight_count"] = notifs["highlight_count"]
room_sync.unread_count = notifs["unread_count"]
sync_result_builder.joined.append(room_sync)

View File

@@ -19,8 +19,10 @@ from collections import namedtuple
from prometheus_client import Counter
from synapse.api.constants import EventTypes, Membership
from synapse.api.constants import EventTypes, Membership, RelationTypes
from synapse.event_auth import get_user_power_level
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.state import POWER_KEY
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import register_cache
@@ -51,6 +53,48 @@ push_rules_delta_state_cache_metric = register_cache(
)
STATE_EVENT_TYPES_TO_MARK_UNREAD = {
EventTypes.Topic,
EventTypes.Name,
EventTypes.RoomAvatar,
EventTypes.Tombstone,
}
def _should_count_as_unread(event: EventBase, context: EventContext) -> bool:
# Exclude rejected and soft-failed events.
if context.rejected or event.internal_metadata.is_soft_failed():
return False
# Exclude notices.
if (
not event.is_state()
and event.type == EventTypes.Message
and event.content.get("msgtype") == "m.notice"
):
return False
# Exclude edits.
relates_to = event.content.get("m.relates_to", {})
if relates_to.get("rel_type") == RelationTypes.REPLACE:
return False
# Mark events that have a non-empty string body as unread.
body = event.content.get("body")
if isinstance(body, str) and body:
return True
# Mark some state events as unread.
if event.is_state() and event.type in STATE_EVENT_TYPES_TO_MARK_UNREAD:
return True
# Mark encrypted events as unread.
if not event.is_state() and event.type == EventTypes.Encrypted:
return True
return False
class BulkPushRuleEvaluator(object):
"""Calculates the outcome of push rules for an event for all users in the
room at once.
@@ -133,9 +177,12 @@ class BulkPushRuleEvaluator(object):
return pl_event.content if pl_event else {}, sender_level
async def action_for_event_by_user(self, event, context) -> None:
"""Given an event and context, evaluate the push rules and insert the
results into the event_push_actions_staging table.
"""Given an event and context, evaluate the push rules, check if the message
should increment the unread count, and insert the results into the
event_push_actions_staging table.
"""
count_as_unread = _should_count_as_unread(event, context)
rules_by_user = await self._get_rules_for_event(event, context)
actions_by_user = {}
@@ -172,6 +219,8 @@ class BulkPushRuleEvaluator(object):
if event.type == EventTypes.Member and event.state_key == uid:
display_name = event.content.get("displayname", None)
actions_by_user[uid] = []
for rule in rules:
if "enabled" in rule and not rule["enabled"]:
continue
@@ -189,7 +238,9 @@ class BulkPushRuleEvaluator(object):
# Mark in the DB staging area the push actions for users who should be
# notified for this event. (This will then get handled when we persist
# the event)
await self.store.add_push_actions_to_staging(event.event_id, actions_by_user)
await self.store.add_push_actions_to_staging(
event.event_id, actions_by_user, count_as_unread,
)
def _condition_checker(evaluator, conditions, uid, display_name, cache):
@@ -369,8 +420,8 @@ class RulesForRoom(object):
Args:
ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
updated with any new rules.
member_event_ids (list): List of event ids for membership events that
have happened since the last time we filled rules_by_user
member_event_ids (dict): Dict of user id to event id for membership events
that have happened since the last time we filled rules_by_user
state_group: The state group we are currently computing push rules
for. Used when updating the cache.
"""
@@ -390,34 +441,19 @@ class RulesForRoom(object):
if logger.isEnabledFor(logging.DEBUG):
logger.debug("Found members %r: %r", self.room_id, members.values())
interested_in_user_ids = {
user_ids = {
user_id
for user_id, membership in members.values()
if membership == Membership.JOIN
}
logger.debug("Joined: %r", interested_in_user_ids)
logger.debug("Joined: %r", user_ids)
if_users_with_pushers = await self.store.get_if_users_have_pushers(
interested_in_user_ids, on_invalidate=self.invalidate_all_cb
)
user_ids = {
uid for uid, have_pusher in if_users_with_pushers.items() if have_pusher
}
logger.debug("With pushers: %r", user_ids)
users_with_receipts = await self.store.get_users_with_read_receipts_in_room(
self.room_id, on_invalidate=self.invalidate_all_cb
)
logger.debug("With receipts: %r", users_with_receipts)
# any users with pushers must be ours: they have pushers
for uid in users_with_receipts:
if uid in interested_in_user_ids:
user_ids.add(uid)
# Previously we only considered users with pushers or read receipts in that
# room. We can't do this anymore because we use push actions to calculate unread
# counts, which don't rely on the user having pushers or sent a read receipt into
# the room. Therefore we just need to filter for local users here.
user_ids = list(filter(self.is_mine_id, user_ids))
rules_by_user = await self.store.bulk_get_push_rules(
user_ids, on_invalidate=self.invalidate_all_cb

View File

@@ -36,7 +36,7 @@ async def get_badge_count(store, user_id):
)
# return one badge count per conversation, as count per
# message is so noisy as to be almost useless
badge += 1 if notifs["notify_count"] else 0
badge += 1 if notifs["unread_count"] else 0
return badge

View File

@@ -425,6 +425,7 @@ class SyncRestServlet(RestServlet):
result["ephemeral"] = {"events": ephemeral_events}
result["unread_notifications"] = room.unread_notifications
result["summary"] = room.summary
result["org.matrix.msc2654.unread_count"] = room.unread_count
return result

View File

@@ -604,6 +604,18 @@ class DatabasePool(object):
results = [dict(zip(col_headers, row)) for row in cursor]
return results
@overload
async def execute(
self, desc: str, decoder: Literal[None], query: str, *args: Any
) -> List[Tuple[Any, ...]]:
...
@overload
async def execute(
self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
) -> R:
...
async def execute(
self,
desc: str,
@@ -1088,6 +1100,28 @@ class DatabasePool(object):
desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
)
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
desc: str = "simple_select_one_onecol",
) -> Any:
...
@overload
async def simple_select_one_onecol(
self,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
desc: str = "simple_select_one_onecol",
) -> Optional[Any]:
...
async def simple_select_one_onecol(
self,
table: str,

View File

@@ -15,7 +15,9 @@
# limitations under the License.
import logging
from typing import Dict, List, Union
from typing import Dict, List, Optional, Tuple, Union
import attr
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
@@ -88,8 +90,26 @@ class EventPushActionsWorkerStore(SQLBaseStore):
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
self, room_id: str, user_id: str, last_read_event_id: Optional[str],
) -> Dict[str, int]:
"""Get the notification count, the highlight count and the unread message count
for a given user in a given room after the given read receipt.
Note that this function assumes the user to be a current member of the room,
since it's either called by the sync handler to handle joined room entries, or by
the HTTP pusher to calculate the badge of unread joined rooms.
Args:
room_id: The room to retrieve the counts in.
user_id: The user to retrieve the counts for.
last_read_event_id: The event associated with the latest read receipt for
this user in this room. None if no receipt for this user in this room.
Returns
A dict containing the counts mentioned earlier in this docstring,
respectively under the keys "notify_count", "highlight_count" and
"unread_count".
"""
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
@@ -99,69 +119,71 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id
self, txn, room_id, user_id, last_read_event_id,
):
sql = (
"SELECT stream_ordering"
" FROM events"
" WHERE room_id = ? AND event_id = ?"
)
txn.execute(sql, (room_id, last_read_event_id))
results = txn.fetchall()
if len(results) == 0:
return {"notify_count": 0, "highlight_count": 0}
stream_ordering = None
stream_ordering = results[0][0]
if last_read_event_id is not None:
stream_ordering = self.get_stream_id_for_event_txn(
txn, last_read_event_id, allow_none=True,
)
if stream_ordering is None:
# Either last_read_event_id is None, or it's an event we don't have (e.g.
# because it's been purged), in which case retrieve the stream ordering for
# the latest membership event from this user in this room (which we assume is
# a join).
event_id = self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="local_current_membership",
keyvalues={"room_id": room_id, "user_id": user_id},
retcol="event_id",
)
stream_ordering = self.get_stream_id_for_event_txn(txn, event_id)
return self._get_unread_counts_by_pos_txn(
txn, room_id, user_id, stream_ordering
)
def _get_unread_counts_by_pos_txn(self, txn, room_id, user_id, stream_ordering):
# First get number of notifications.
# We don't need to put a notif=1 clause as all rows always have
# notif=1
sql = (
"SELECT count(*)"
"SELECT"
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
" FROM event_push_actions ea"
" WHERE"
" user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
" WHERE user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
notify_count = row[0] if row else 0
(notif_count, highlight_count, unread_count) = (0, 0, 0)
if row:
(notif_count, highlight_count, unread_count) = row
txn.execute(
"""
SELECT notif_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""",
SELECT notif_count, unread_count FROM event_push_summary
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
""",
(room_id, user_id, stream_ordering),
)
rows = txn.fetchall()
if rows:
notify_count += rows[0][0]
# Now get the number of highlights
sql = (
"SELECT count(*)"
" FROM event_push_actions ea"
" WHERE"
" highlight = 1"
" AND user_id = ?"
" AND room_id = ?"
" AND stream_ordering > ?"
)
txn.execute(sql, (user_id, room_id, stream_ordering))
row = txn.fetchone()
highlight_count = row[0] if row else 0
return {"notify_count": notify_count, "highlight_count": highlight_count}
if row:
notif_count += row[0]
unread_count += row[1]
return {
"notify_count": notif_count,
"unread_count": unread_count,
"highlight_count": highlight_count,
}
async def get_push_action_users_in_range(
self, min_stream_ordering, max_stream_ordering
@@ -222,6 +244,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -250,6 +273,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering ASC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -324,6 +348,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -352,6 +377,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
" AND ep.user_id = ?"
" AND ep.stream_ordering > ?"
" AND ep.stream_ordering <= ?"
" AND ep.notif = 1"
" ORDER BY ep.stream_ordering DESC LIMIT ?"
)
args = [user_id, user_id, min_stream_ordering, max_stream_ordering, limit]
@@ -402,7 +428,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
def _get_if_maybe_push_in_range_for_user_txn(txn):
sql = """
SELECT 1 FROM event_push_actions
WHERE user_id = ? AND stream_ordering > ?
WHERE user_id = ? AND stream_ordering > ? AND notif = 1
LIMIT 1
"""
@@ -415,7 +441,10 @@ class EventPushActionsWorkerStore(SQLBaseStore):
)
async def add_push_actions_to_staging(
self, event_id: str, user_id_actions: Dict[str, List[Union[dict, str]]]
self,
event_id: str,
user_id_actions: Dict[str, List[Union[dict, str]]],
count_as_unread: bool,
) -> None:
"""Add the push actions for the event to the push action staging area.
@@ -423,21 +452,23 @@ class EventPushActionsWorkerStore(SQLBaseStore):
event_id
user_id_actions: A mapping of user_id to list of push actions, where
an action can either be a string or dict.
count_as_unread: Whether this event should increment unread counts.
"""
if not user_id_actions:
return
# This is a helper function for generating the necessary tuple that
# can be used to inert into the `event_push_actions_staging` table.
# can be used to insert into the `event_push_actions_staging` table.
def _gen_entry(user_id, actions):
is_highlight = 1 if _action_has_highlight(actions) else 0
notif = 1 if "notify" in actions else 0
return (
event_id, # event_id column
user_id, # user_id column
_serialize_action(actions, is_highlight), # actions column
1, # notif column
notif, # notif column
is_highlight, # highlight column
int(count_as_unread), # unread column
)
def _add_push_actions_to_staging_txn(txn):
@@ -446,8 +477,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
sql = """
INSERT INTO event_push_actions_staging
(event_id, user_id, actions, notif, highlight)
VALUES (?, ?, ?, ?, ?)
(event_id, user_id, actions, notif, highlight, unread)
VALUES (?, ?, ?, ?, ?, ?)
"""
txn.executemany(
@@ -811,24 +842,63 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
# Calculate the new counts that should be upserted into event_push_summary
sql = """
SELECT user_id, room_id,
coalesce(old.notif_count, 0) + upd.notif_count,
coalesce(old.%s, 0) + upd.cnt,
upd.stream_ordering,
old.user_id
FROM (
SELECT user_id, room_id, count(*) as notif_count,
SELECT user_id, room_id, count(*) as cnt,
max(stream_ordering) as stream_ordering
FROM event_push_actions
WHERE ? <= stream_ordering AND stream_ordering < ?
AND highlight = 0
AND %s = 1
GROUP BY user_id, room_id
) AS upd
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
"""
txn.execute(sql, (old_rotate_stream_ordering, rotate_to_stream_ordering))
rows = txn.fetchall()
# First get the count of unread messages.
txn.execute(
sql % ("unread_count", "unread"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
logger.info("Rotating notifications, handling %d rows", len(rows))
# We need to merge results from the two requests (the one that retrieves the
# unread count and the one that retrieves the notifications count) into a single
# object because we might not have the same amount of rows in each of them. To do
# this, we use a dict indexed on the user ID and room ID to make it easier to
# populate.
summaries = {} # type: Dict[Tuple[str, str], _EventPushSummary]
for row in txn:
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=row[2],
stream_ordering=row[3],
old_user_id=row[4],
notif_count=0,
)
# Then get the count of notifications.
txn.execute(
sql % ("notif_count", "notif"),
(old_rotate_stream_ordering, rotate_to_stream_ordering),
)
for row in txn:
if (row[0], row[1]) in summaries:
summaries[(row[0], row[1])].notif_count = row[2]
else:
# Because the rules on notifying are different than the rules on marking
# a message unread, we might end up with messages that notify but aren't
# marked unread, so we might not have a summary for this (user, room)
# tuple to complete.
summaries[(row[0], row[1])] = _EventPushSummary(
unread_count=0,
stream_ordering=row[3],
old_user_id=row[4],
notif_count=row[2],
)
logger.info("Rotating notifications, handling %d rows", len(summaries))
# If the `old.user_id` above is NULL then we know there isn't already an
# entry in the table, so we simply insert it. Otherwise we update the
@@ -838,22 +908,34 @@ class EventPushActionsStore(EventPushActionsWorkerStore):
table="event_push_summary",
values=[
{
"user_id": row[0],
"room_id": row[1],
"notif_count": row[2],
"stream_ordering": row[3],
"user_id": user_id,
"room_id": room_id,
"notif_count": summary.notif_count,
"unread_count": summary.unread_count,
"stream_ordering": summary.stream_ordering,
}
for row in rows
if row[4] is None
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is None
],
)
txn.executemany(
"""
UPDATE event_push_summary SET notif_count = ?, stream_ordering = ?
UPDATE event_push_summary
SET notif_count = ?, unread_count = ?, stream_ordering = ?
WHERE user_id = ? AND room_id = ?
""",
((row[2], row[3], row[0], row[1]) for row in rows if row[4] is not None),
(
(
summary.notif_count,
summary.unread_count,
summary.stream_ordering,
user_id,
room_id,
)
for ((user_id, room_id), summary) in summaries.items()
if summary.old_user_id is not None
),
)
txn.execute(
@@ -879,3 +961,15 @@ def _action_has_highlight(actions):
pass
return False
@attr.s
class _EventPushSummary:
"""Summary of pending event push actions for a given user in a given room.
Used in _rotate_notifs_before_txn to manipulate results from event_push_actions.
"""
unread_count = attr.ib(type=int)
stream_ordering = attr.ib(type=int)
old_user_id = attr.ib(type=str)
notif_count = attr.ib(type=int)

View File

@@ -1298,9 +1298,9 @@ class PersistEventsStore:
sql = """
INSERT INTO event_push_actions (
room_id, event_id, user_id, actions, stream_ordering,
topological_ordering, notif, highlight
topological_ordering, notif, highlight, unread
)
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
FROM event_push_actions_staging
WHERE event_id = ?
"""

View File

@@ -0,0 +1,26 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- We're hijacking the push actions to store unread messages and unread counts (specified
-- in MSC2654) because doing otherwise would result in either performance issues or
-- reimplementing a consequent bit of the push actions.
-- Add columns to event_push_actions and event_push_actions_staging to track unread
-- messages and calculate unread counts.
ALTER TABLE event_push_actions_staging ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
ALTER TABLE event_push_actions ADD COLUMN unread SMALLINT NOT NULL DEFAULT 0;
-- Add column to event_push_summary
ALTER TABLE event_push_summary ADD COLUMN unread_count BIGINT NOT NULL DEFAULT 0;

View File

@@ -39,7 +39,7 @@ what sort order was used:
import abc
import logging
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
from twisted.internet import defer
@@ -47,12 +47,19 @@ from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import DatabasePool, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingTransaction,
make_in_list_sql_clause,
)
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
from synapse.types import RoomStreamToken
from synapse.types import Collection, RoomStreamToken
from synapse.util.caches.stream_change_cache import StreamChangeCache
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -202,7 +209,7 @@ def _make_generic_sql_bound(
)
def filter_to_clause(event_filter: Filter) -> Tuple[str, List[str]]:
def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
# NB: This may create SQL clauses that don't optimise well (and we don't
# have indices on all possible clauses). E.g. it may create
# "room_id == X AND room_id != X", which postgres doesn't optimise.
@@ -260,7 +267,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
__metaclass__ = abc.ABCMeta
def __init__(self, database: DatabasePool, db_conn, hs):
def __init__(self, database: DatabasePool, db_conn, hs: "HomeServer"):
super(StreamWorkerStore, self).__init__(database, db_conn, hs)
self._instance_name = hs.get_instance_name()
@@ -293,16 +300,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self._stream_order_on_start = self.get_room_max_stream_ordering()
@abc.abstractmethod
def get_room_max_stream_ordering(self):
def get_room_max_stream_ordering(self) -> int:
raise NotImplementedError()
@abc.abstractmethod
def get_room_min_stream_ordering(self):
def get_room_min_stream_ordering(self) -> int:
raise NotImplementedError()
async def get_room_events_stream_for_rooms(
self,
room_ids: Iterable[str],
room_ids: Collection[str],
from_key: str,
to_key: str,
limit: int = 0,
@@ -356,19 +363,21 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def get_rooms_that_changed(self, room_ids, from_key):
def get_rooms_that_changed(
self, room_ids: Collection[str], from_key: str
) -> Set[str]:
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
room_ids
from_key: The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
from_id = RoomStreamToken.parse_stream_token(from_key).stream
return {
room_id
for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
if self._events_stream_cache.has_entity_changed(room_id, from_id)
}
async def get_room_events_stream_for_room(
@@ -440,7 +449,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
return ret, key
async def get_membership_changes_for_user(self, user_id, from_key, to_key):
async def get_membership_changes_for_user(
self, user_id: str, from_key: str, to_key: str
) -> List[EventBase]:
from_id = RoomStreamToken.parse_stream_token(from_key).stream
to_id = RoomStreamToken.parse_stream_token(to_key).stream
@@ -593,8 +604,19 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
Returns:
A stream ID.
"""
return await self.db_pool.simple_select_one_onecol(
table="events", keyvalues={"event_id": event_id}, retcol="stream_ordering"
return await self.db_pool.runInteraction(
"get_stream_id_for_event", self.get_stream_id_for_event_txn, event_id,
)
def get_stream_id_for_event_txn(
self, txn: LoggingTransaction, event_id: str, allow_none=False,
) -> int:
return self.db_pool.simple_select_one_onecol_txn(
txn=txn,
table="events",
keyvalues={"event_id": event_id},
retcol="stream_ordering",
allow_none=allow_none,
)
async def get_stream_token_for_event(self, event_id: str) -> str:
@@ -646,7 +668,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
)
return row[0][0] if row else 0
def _get_max_topological_txn(self, txn, room_id):
def _get_max_topological_txn(self, txn: LoggingTransaction, room_id: str) -> int:
txn.execute(
"SELECT MAX(topological_ordering) FROM events WHERE room_id = ?",
(room_id,),
@@ -719,7 +741,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _get_events_around_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
event_id: str,
before_limit: int,
@@ -747,6 +769,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
retcols=["stream_ordering", "topological_ordering"],
)
# This cannot happen as `allow_none=False`.
assert results is not None
# Paginating backwards includes the event at the token, but paginating
# forward doesn't.
before_token = RoomStreamToken(
@@ -856,7 +881,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
desc="update_federation_out_pos",
)
def _reset_federation_positions_txn(self, txn) -> None:
def _reset_federation_positions_txn(self, txn: LoggingTransaction) -> None:
"""Fiddles with the `federation_stream_position` table to make it match
the configured federation sender instances during start up.
"""
@@ -895,7 +920,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
GROUP BY type
"""
txn.execute(sql)
min_positions = dict(txn) # Map from type -> min position
min_positions = {typ: pos for typ, pos in txn} # Map from type -> min position
# Ensure we do actually have some values here
assert set(min_positions) == {"federation", "events"}
@@ -922,7 +947,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
def _paginate_room_events_txn(
self,
txn,
txn: LoggingTransaction,
room_id: str,
from_token: RoomStreamToken,
to_token: Optional[RoomStreamToken] = None,

View File

@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 0},
{"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 0, "notify_count": 1},
{"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
{"highlight_count": 1, "notify_count": 2},
{"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -368,7 +368,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.get_success(
self.master_store.add_push_actions_to_staging(
event.event_id, {user_id: actions for user_id, actions in push_actions}
event.event_id,
{user_id: actions for user_id, actions in push_actions},
False,
)
)
return event, context

View File

@@ -16,9 +16,9 @@
import json
import synapse.rest.admin
from synapse.api.constants import EventContentFields, EventTypes
from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
from synapse.rest.client.v1 import login, room
from synapse.rest.client.v2_alpha import sync
from synapse.rest.client.v2_alpha import read_marker, sync
from tests import unittest
from tests.server import TimedOutException
@@ -324,3 +324,156 @@ class SyncTypingTests(unittest.HomeserverTestCase):
"GET", sync_url % (access_token, next_batch)
)
self.assertRaises(TimedOutException, self.render, request)
class UnreadMessagesTestCase(unittest.HomeserverTestCase):
servlets = [
synapse.rest.admin.register_servlets,
login.register_servlets,
read_marker.register_servlets,
room.register_servlets,
sync.register_servlets,
]
def prepare(self, reactor, clock, hs):
self.url = "/sync?since=%s"
self.next_batch = "s0"
# Register the first user (used to check the unread counts).
self.user_id = self.register_user("kermit", "monkey")
self.tok = self.login("kermit", "monkey")
# Create the room we'll check unread counts for.
self.room_id = self.helper.create_room_as(self.user_id, tok=self.tok)
# Register the second user (used to send events to the room).
self.user2 = self.register_user("kermit2", "monkey")
self.tok2 = self.login("kermit2", "monkey")
# Change the power levels of the room so that the second user can send state
# events.
self.helper.send_state(
self.room_id,
EventTypes.PowerLevels,
{
"users": {self.user_id: 100, self.user2: 100},
"users_default": 0,
"events": {
"m.room.name": 50,
"m.room.power_levels": 100,
"m.room.history_visibility": 100,
"m.room.canonical_alias": 50,
"m.room.avatar": 50,
"m.room.tombstone": 100,
"m.room.server_acl": 100,
"m.room.encryption": 100,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50,
"invite": 0,
},
tok=self.tok,
)
def test_unread_counts(self):
"""Tests that /sync returns the right value for the unread count (MSC2654)."""
# Check that our own messages don't increase the unread count.
self.helper.send(self.room_id, "hello", tok=self.tok)
self._check_unread_count(0)
# Join the new user and check that this doesn't increase the unread count.
self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
self._check_unread_count(0)
# Check that the new user sending a message increases our unread count.
res = self.helper.send(self.room_id, "hello", tok=self.tok2)
self._check_unread_count(1)
# Send a read receipt to tell the server we've read the latest event.
body = json.dumps({"m.read": res["event_id"]}).encode("utf8")
request, channel = self.make_request(
"POST",
"/rooms/%s/read_markers" % self.room_id,
body,
access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
# Check that the unread counter is back to 0.
self._check_unread_count(0)
# Check that room name changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.name", {"name": "my super room"}, tok=self.tok2,
)
self._check_unread_count(1)
# Check that room topic changes increase the unread counter.
self.helper.send_state(
self.room_id, "m.room.topic", {"topic": "welcome!!!"}, tok=self.tok2,
)
self._check_unread_count(2)
# Check that encrypted messages increase the unread counter.
self.helper.send_event(self.room_id, EventTypes.Encrypted, {}, tok=self.tok2)
self._check_unread_count(3)
# Check that custom events with a body increase the unread counter.
self.helper.send_event(
self.room_id, "org.matrix.custom_type", {"body": "hello"}, tok=self.tok2,
)
self._check_unread_count(4)
# Check that edits don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={
"body": "hello",
"msgtype": "m.text",
"m.relates_to": {"rel_type": RelationTypes.REPLACE},
},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that notices don't increase the unread counter.
self.helper.send_event(
room_id=self.room_id,
type=EventTypes.Message,
content={"body": "hello", "msgtype": "m.notice"},
tok=self.tok2,
)
self._check_unread_count(4)
# Check that tombstone events changes increase the unread counter.
self.helper.send_state(
self.room_id,
EventTypes.Tombstone,
{"replacement_room": "!someroom:test"},
tok=self.tok2,
)
self._check_unread_count(5)
def _check_unread_count(self, expected_count: True):
"""Syncs and compares the unread count with the expected value."""
request, channel = self.make_request(
"GET", self.url % self.next_batch, access_token=self.tok,
)
self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
room_entry = channel.json_body["rooms"]["join"][self.room_id]
self.assertEqual(
room_entry["org.matrix.msc2654.unread_count"], expected_count, room_entry,
)
# Store the next batch for the next request.
self.next_batch = channel.json_body["next_batch"]

View File

@@ -67,7 +67,11 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
)
self.assertEquals(
counts,
{"notify_count": noitf_count, "highlight_count": highlight_count},
{
"notify_count": noitf_count,
"unread_count": 0, # Unread counts are tested in the sync tests.
"highlight_count": highlight_count,
},
)
@defer.inlineCallbacks
@@ -80,7 +84,7 @@ class EventPushActionsStoreTestCase(tests.unittest.TestCase):
yield defer.ensureDeferred(
self.store.add_push_actions_to_staging(
event.event_id, {user_id: action}
event.event_id, {user_id: action}, False,
)
)
yield defer.ensureDeferred(