Compare commits
15 Commits
release-v1
...
clokep/ran
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f03935dcb7 | ||
|
|
cbbe77f620 | ||
|
|
3ca9a381ab | ||
|
|
a8a45921fb | ||
|
|
0c395fd8b9 | ||
|
|
ad6b7cf5c6 | ||
|
|
50ba4b34ff | ||
|
|
eebc6dfe38 | ||
|
|
580fbb740f | ||
|
|
7c320b79bf | ||
|
|
f43e0b4b1a | ||
|
|
82166cfa51 | ||
|
|
1ec3885aa9 | ||
|
|
48b2d6c9ef | ||
|
|
bc7e8a5e60 |
84
demo.py
Normal file
84
demo.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import json
|
||||
from time import monotonic, sleep
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
HOMESERVER = "http://localhost:8080"
|
||||
|
||||
USER_1_TOK = "syt_dGVzdA_JUXtKQUUMnolcOezckNz_2eyt3H"
|
||||
USER_1_HEADERS = {"Authorization": f"Bearer {USER_1_TOK}"}
|
||||
|
||||
USER_2_TOK = "syt_c2Vjb25k_ElKwbhaNqTgpfgFQcStD_2aiOcs"
|
||||
USER_2_HEADERS = {"Authorization": f"Bearer {USER_2_TOK}"}
|
||||
|
||||
|
||||
def _check_for_status(result):
|
||||
# Similar to raise_for_status, but prints the error.
|
||||
if 400 <= result.status_code:
|
||||
error_msg = result.json()
|
||||
result.raise_for_status()
|
||||
print(error_msg)
|
||||
exit(0)
|
||||
|
||||
|
||||
def _send_event(room_id, content):
|
||||
# Send a msg to the room.
|
||||
result = requests.put(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/msg{monotonic()}",
|
||||
json=content,
|
||||
headers=USER_2_HEADERS,
|
||||
)
|
||||
_check_for_status(result)
|
||||
return result.json()["event_id"]
|
||||
|
||||
|
||||
def main():
|
||||
# Create a new room as user 2, add a bunch of messages.
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/createRoom",
|
||||
json={"visibility": "public", "name": f"Ranged Read Receipts ({monotonic()})"},
|
||||
headers=USER_2_HEADERS,
|
||||
)
|
||||
_check_for_status(result)
|
||||
room_id = result.json()["room_id"]
|
||||
|
||||
# Second user joins the room.
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/join", headers=USER_1_HEADERS
|
||||
)
|
||||
_check_for_status(result)
|
||||
|
||||
# User 2 sends some messages.
|
||||
thread_event_id = None
|
||||
|
||||
def _send(body, thread_id = None):
|
||||
content = {
|
||||
"msgtype": "m.text",
|
||||
"body": body,
|
||||
}
|
||||
if thread_id:
|
||||
content["m.relates_to"] = {
|
||||
"rel_type": "m.thread",
|
||||
"event_id": thread_id,
|
||||
}
|
||||
return _send_event(room_id, content)
|
||||
|
||||
for msg in range(10):
|
||||
event_id = _send(f"Message {msg}")
|
||||
if msg % 5 == 0:
|
||||
sleep(3)
|
||||
thread_event_id = event_id
|
||||
|
||||
for msg in range(60):
|
||||
if msg % 3 == 0:
|
||||
_send(f"More message {msg}")
|
||||
else:
|
||||
_send(f"Thread message {msg}", thread_event_id)
|
||||
|
||||
if msg % 5 == 0:
|
||||
sleep(3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
174
rrr_test.py
Normal file
174
rrr_test.py
Normal file
@@ -0,0 +1,174 @@
|
||||
import json
|
||||
from time import monotonic, sleep
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
HOMESERVER = "http://localhost:8080"
|
||||
|
||||
USER_1_TOK = "syt_dGVzdA_JUXtKQUUMnolcOezckNz_2eyt3H"
|
||||
USER_1_HEADERS = {"Authorization": f"Bearer {USER_1_TOK}"}
|
||||
|
||||
USER_2_TOK = "syt_c2Vjb25k_ElKwbhaNqTgpfgFQcStD_2aiOcs"
|
||||
USER_2_HEADERS = {"Authorization": f"Bearer {USER_2_TOK}"}
|
||||
|
||||
|
||||
def _check_for_status(result):
|
||||
# Similar to raise_for_status, but prints the error.
|
||||
if 400 <= result.status_code:
|
||||
error_msg = result.json()
|
||||
result.raise_for_status()
|
||||
print(error_msg)
|
||||
exit(0)
|
||||
|
||||
|
||||
def _sync_and_show(room_id):
|
||||
print("Syncing . . .")
|
||||
result = requests.get(
|
||||
f"{HOMESERVER}/_matrix/client/v3/sync",
|
||||
headers=USER_1_HEADERS,
|
||||
params={"filter": json.dumps({"room": {"timeline": {"limit": 30}}})},
|
||||
)
|
||||
_check_for_status(result)
|
||||
sync_response = result.json()
|
||||
|
||||
room = sync_response["rooms"]["join"][room_id]
|
||||
|
||||
# Find read receipts (this assumes non-overlapping).
|
||||
read_receipt_starts = {} # start event -> users
|
||||
read_receipt_ends = {} # end event -> users
|
||||
for event in room["ephemeral"]["events"]:
|
||||
if event["type"] != "m.receipt":
|
||||
continue
|
||||
|
||||
for event_id, content in event["content"].items():
|
||||
for mxid, receipt in content["m.read"].items():
|
||||
# Just care about the localpart of the MXID.
|
||||
mxid = mxid.split(":", 1)[0]
|
||||
read_receipt_starts.setdefault(
|
||||
receipt.get("start_event_id"), []
|
||||
).append(mxid)
|
||||
read_receipt_ends.setdefault(event_id, []).append(mxid)
|
||||
|
||||
print(room["unread_notifications"])
|
||||
|
||||
if None in read_receipt_starts:
|
||||
user_ids = ", ".join(sorted(read_receipt_starts[None]))
|
||||
print(f"v--------- {user_ids} ---------v")
|
||||
|
||||
for event in room["timeline"]["events"]:
|
||||
event_id = event["event_id"]
|
||||
|
||||
if event_id in read_receipt_starts:
|
||||
user_ids = ", ".join(read_receipt_starts[event_id])
|
||||
print(f"v--------- {user_ids} ---------v")
|
||||
|
||||
if event["type"] == "m.room.message":
|
||||
msg = event["content"]["body"]
|
||||
print(msg)
|
||||
|
||||
if event_id in read_receipt_ends:
|
||||
user_ids = ", ".join(sorted(read_receipt_ends[event_id]))
|
||||
print(f"^--------- {user_ids} ---------^")
|
||||
|
||||
print()
|
||||
print()
|
||||
|
||||
return event_id
|
||||
|
||||
|
||||
def _send_event(room_id, body, prev_event_id = None):
|
||||
args = {"prev_event_id": prev_event_id}
|
||||
|
||||
# Send a msg to the room.
|
||||
result = requests.put(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/send/m.room.message/msg{monotonic()}",
|
||||
json={
|
||||
"msgtype": "m.text",
|
||||
"body": body,
|
||||
},
|
||||
params=args,
|
||||
headers=USER_2_HEADERS,
|
||||
)
|
||||
_check_for_status(result)
|
||||
return result.json()["event_id"]
|
||||
|
||||
|
||||
def main():
|
||||
# Create a new room as user 2, add a bunch of messages.
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/createRoom",
|
||||
json={"visibility": "public", "name": f"Ranged Read Receipts ({monotonic()})"},
|
||||
headers=USER_2_HEADERS,
|
||||
)
|
||||
_check_for_status(result)
|
||||
room_id = result.json()["room_id"]
|
||||
|
||||
# Second user joins the room.
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/join", headers=USER_1_HEADERS
|
||||
)
|
||||
_check_for_status(result)
|
||||
|
||||
# Sync user 1.
|
||||
last_event_id = first_event_id = _sync_and_show(room_id)
|
||||
|
||||
# User 2 sends some messages.
|
||||
event_ids = []
|
||||
|
||||
def _send_and_append(body, prev_message_id = None):
|
||||
event_id = _send_event(room_id, body, prev_message_id)
|
||||
event_ids.append(event_id)
|
||||
return event_id
|
||||
|
||||
prev_message_id = first_message_id = _send_and_append("Root")
|
||||
for msg in range(3):
|
||||
prev_message_id = _send_and_append(f"Fork 1 Message {msg}", prev_message_id)
|
||||
sleep(1)
|
||||
|
||||
# User 2 sends a read receipt.
|
||||
print("@second reads to end")
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-1]}",
|
||||
headers=USER_2_HEADERS,
|
||||
json={},
|
||||
)
|
||||
_check_for_status(result)
|
||||
|
||||
_sync_and_show(room_id)
|
||||
|
||||
# User 1 sends a read receipt.
|
||||
print("@test reads from fork 1")
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[3]}/{event_ids[1]}",
|
||||
headers=USER_1_HEADERS,
|
||||
json={},
|
||||
)
|
||||
_check_for_status(result)
|
||||
|
||||
_sync_and_show(room_id)
|
||||
|
||||
# Create a fork in the DAG.
|
||||
prev_message_id = first_message_id
|
||||
for msg in range(3):
|
||||
prev_message_id = _send_and_append(f"Fork 2 Message {msg}", prev_message_id)
|
||||
sleep(1)
|
||||
# # Join the forks.
|
||||
_send_and_append("Tail")
|
||||
|
||||
_sync_and_show(room_id)
|
||||
|
||||
# User 1 sends another read receipt.
|
||||
print("@test reads everything")
|
||||
result = requests.post(
|
||||
f"{HOMESERVER}/_matrix/client/v3/rooms/{room_id}/receipt/m.read/{event_ids[-1]}/{event_ids[0]}",
|
||||
headers=USER_1_HEADERS,
|
||||
json={},
|
||||
)
|
||||
_check_for_status(result)
|
||||
|
||||
_sync_and_show(room_id)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -19,7 +19,9 @@ from synapse.appservice import ApplicationService
|
||||
from synapse.streams import EventSource
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
RangedReadReceipt,
|
||||
ReadReceipt,
|
||||
Receipt,
|
||||
StreamKeyType,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
@@ -65,7 +67,7 @@ class ReceiptsHandler:
|
||||
|
||||
async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
|
||||
"""Called when we receive an EDU of type m.receipt from a remote HS."""
|
||||
receipts = []
|
||||
receipts: List[Receipt] = []
|
||||
for room_id, room_values in content.items():
|
||||
# If we're not in the room just ditch the event entirely. This is
|
||||
# probably an old server that has come back and thinks we're still in
|
||||
@@ -103,19 +105,13 @@ class ReceiptsHandler:
|
||||
|
||||
await self._handle_new_receipts(receipts)
|
||||
|
||||
async def _handle_new_receipts(self, receipts: List[ReadReceipt]) -> bool:
|
||||
async def _handle_new_receipts(self, receipts: List[Receipt]) -> bool:
|
||||
"""Takes a list of receipts, stores them and informs the notifier."""
|
||||
min_batch_id: Optional[int] = None
|
||||
max_batch_id: Optional[int] = None
|
||||
|
||||
for receipt in receipts:
|
||||
res = await self.store.insert_receipt(
|
||||
receipt.room_id,
|
||||
receipt.receipt_type,
|
||||
receipt.user_id,
|
||||
receipt.event_ids,
|
||||
receipt.data,
|
||||
)
|
||||
res = await self.store.insert_receipt(receipt)
|
||||
|
||||
if not res:
|
||||
# res will be None if this receipt is 'old'
|
||||
@@ -146,24 +142,45 @@ class ReceiptsHandler:
|
||||
return True
|
||||
|
||||
async def received_client_receipt(
|
||||
self, room_id: str, receipt_type: str, user_id: str, event_id: str
|
||||
self,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
end_event_id: str,
|
||||
start_event_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Called when a client tells us a local user has read up to the given
|
||||
event_id in the room.
|
||||
"""
|
||||
receipt = ReadReceipt(
|
||||
room_id=room_id,
|
||||
receipt_type=receipt_type,
|
||||
user_id=user_id,
|
||||
event_ids=[event_id],
|
||||
data={"ts": int(self.clock.time_msec())},
|
||||
)
|
||||
|
||||
if start_event_id:
|
||||
receipt: Receipt = RangedReadReceipt(
|
||||
room_id=room_id,
|
||||
receipt_type=receipt_type,
|
||||
user_id=user_id,
|
||||
start_event_id=start_event_id,
|
||||
end_event_id=end_event_id,
|
||||
data={"ts": int(self.clock.time_msec())},
|
||||
)
|
||||
else:
|
||||
receipt = ReadReceipt(
|
||||
room_id=room_id,
|
||||
receipt_type=receipt_type,
|
||||
user_id=user_id,
|
||||
event_ids=[end_event_id],
|
||||
data={"ts": int(self.clock.time_msec())},
|
||||
)
|
||||
|
||||
is_new = await self._handle_new_receipts([receipt])
|
||||
if not is_new:
|
||||
return
|
||||
|
||||
if self.federation_sender and receipt_type != ReceiptTypes.READ_PRIVATE:
|
||||
# XXX How to handle this for a ranged read receipt.
|
||||
if (
|
||||
isinstance(receipt, ReadReceipt)
|
||||
and self.federation_sender
|
||||
and receipt_type != ReceiptTypes.READ_PRIVATE
|
||||
):
|
||||
await self.federation_sender.send_read_receipt(receipt)
|
||||
|
||||
|
||||
|
||||
@@ -1052,7 +1052,7 @@ class SyncHandler:
|
||||
|
||||
async def unread_notifs_for_room_id(
|
||||
self, room_id: str, sync_config: SyncConfig
|
||||
) -> NotifCounts:
|
||||
) -> Dict[Optional[str], NotifCounts]:
|
||||
with Measure(self.clock, "unread_notifs_for_room_id"):
|
||||
last_unread_event_id = await self.store.get_last_receipt_event_id_for_user(
|
||||
user_id=sync_config.user.to_string(),
|
||||
@@ -2122,7 +2122,7 @@ class SyncHandler:
|
||||
)
|
||||
|
||||
if room_builder.rtype == "joined":
|
||||
unread_notifications: Dict[str, int] = {}
|
||||
unread_notifications: JsonDict = {}
|
||||
room_sync = JoinedSyncResult(
|
||||
room_id=room_id,
|
||||
timeline=batch,
|
||||
@@ -2137,10 +2137,18 @@ class SyncHandler:
|
||||
if room_sync or always_include:
|
||||
notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
|
||||
|
||||
unread_notifications["notification_count"] = notifs.notify_count
|
||||
unread_notifications["highlight_count"] = notifs.highlight_count
|
||||
# Notifications for the main timeline.
|
||||
main_notifs = notifs[None]
|
||||
unread_notifications.update(main_notifs.to_dict())
|
||||
|
||||
room_sync.unread_count = notifs.unread_count
|
||||
room_sync.unread_count = main_notifs.unread_count
|
||||
|
||||
# And add info for each thread.
|
||||
unread_notifications["unread_thread_notifications"] = {
|
||||
thread_id: thread_notifs.to_dict()
|
||||
for thread_id, thread_notifs in notifs.items()
|
||||
if thread_id is not None
|
||||
}
|
||||
|
||||
sync_result_builder.joined.append(room_sync)
|
||||
|
||||
|
||||
@@ -195,7 +195,7 @@ class BulkPushRuleEvaluator:
|
||||
return pl_event.content if pl_event else {}, sender_level
|
||||
|
||||
async def _get_mutual_relations(
|
||||
self, event: EventBase, rules: Iterable[Dict[str, Any]]
|
||||
self, parent_id: str, rules: Iterable[Dict[str, Any]]
|
||||
) -> Dict[str, Set[Tuple[str, str]]]:
|
||||
"""
|
||||
Fetch event metadata for events which related to the same event as the given event.
|
||||
@@ -203,7 +203,7 @@ class BulkPushRuleEvaluator:
|
||||
If the given event has no relation information, returns an empty dictionary.
|
||||
|
||||
Args:
|
||||
event_id: The event ID which is targeted by relations.
|
||||
parent_id: The event ID which is targeted by relations.
|
||||
rules: The push rules which will be processed for this event.
|
||||
|
||||
Returns:
|
||||
@@ -217,12 +217,6 @@ class BulkPushRuleEvaluator:
|
||||
if not self._relations_match_enabled:
|
||||
return {}
|
||||
|
||||
# If the event does not have a relation, then cannot have any mutual
|
||||
# relations.
|
||||
relation = relation_from_event(event)
|
||||
if not relation:
|
||||
return {}
|
||||
|
||||
# Pre-filter to figure out which relation types are interesting.
|
||||
rel_types = set()
|
||||
for rule in rules:
|
||||
@@ -244,9 +238,7 @@ class BulkPushRuleEvaluator:
|
||||
return {}
|
||||
|
||||
# If any valid rules were found, fetch the mutual relations.
|
||||
return await self.store.get_mutual_event_relations(
|
||||
relation.parent_id, rel_types
|
||||
)
|
||||
return await self.store.get_mutual_event_relations(parent_id, rel_types)
|
||||
|
||||
@measure_func("action_for_event_by_user")
|
||||
async def action_for_event_by_user(
|
||||
@@ -272,9 +264,18 @@ class BulkPushRuleEvaluator:
|
||||
sender_power_level,
|
||||
) = await self._get_power_levels_and_sender_level(event, context)
|
||||
|
||||
relations = await self._get_mutual_relations(
|
||||
event, itertools.chain(*rules_by_user.values())
|
||||
)
|
||||
relation = relation_from_event(event)
|
||||
# If the event does not have a relation, then cannot have any mutual
|
||||
# relations or thread ID.
|
||||
relations = {}
|
||||
thread_id = None
|
||||
if relation:
|
||||
relations = await self._get_mutual_relations(
|
||||
relation.parent_id, itertools.chain(*rules_by_user.values())
|
||||
)
|
||||
# XXX Does this need to point to a valid parent ID or anything?
|
||||
if relation.rel_type == RelationTypes.THREAD:
|
||||
thread_id = relation.parent_id
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event,
|
||||
@@ -339,6 +340,7 @@ class BulkPushRuleEvaluator:
|
||||
event.event_id,
|
||||
actions_by_user,
|
||||
count_as_unread,
|
||||
thread_id,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -39,7 +39,10 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||
room_id, user_id, last_unread_event_id
|
||||
)
|
||||
)
|
||||
if notifs.notify_count == 0:
|
||||
# Combine the counts from all the threads.
|
||||
notify_count = sum(n.notify_count for n in notifs.values())
|
||||
|
||||
if notify_count == 0:
|
||||
continue
|
||||
|
||||
if group_by_room:
|
||||
@@ -47,7 +50,7 @@ async def get_badge_count(store: DataStore, user_id: str, group_by_room: bool) -
|
||||
badge += 1
|
||||
else:
|
||||
# increment the badge count by the number of unread messages in the room
|
||||
badge += notifs.notify_count
|
||||
badge += notify_count
|
||||
return badge
|
||||
|
||||
|
||||
|
||||
@@ -407,8 +407,8 @@ class FederationSenderHandler:
|
||||
receipt.room_id,
|
||||
receipt.receipt_type,
|
||||
receipt.user_id,
|
||||
[receipt.event_id],
|
||||
receipt.data,
|
||||
event_ids=[receipt.event_id],
|
||||
data=receipt.data,
|
||||
)
|
||||
await self.federation_sender.send_read_receipt(receipt_info)
|
||||
|
||||
|
||||
@@ -71,7 +71,7 @@ class ReadMarkerRestServlet(RestServlet):
|
||||
room_id,
|
||||
ReceiptTypes.READ,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_event_id,
|
||||
end_event_id=read_event_id,
|
||||
)
|
||||
|
||||
read_private_event_id = body.get(ReceiptTypes.READ_PRIVATE, None)
|
||||
@@ -80,7 +80,7 @@ class ReadMarkerRestServlet(RestServlet):
|
||||
room_id,
|
||||
ReceiptTypes.READ_PRIVATE,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=read_private_event_id,
|
||||
end_event_id=read_private_event_id,
|
||||
)
|
||||
|
||||
read_marker_event_id = body.get(ReceiptTypes.FULLY_READ, None)
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
from typing import TYPE_CHECKING, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import ReceiptTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
@@ -34,7 +34,8 @@ class ReceiptRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)"
|
||||
"/receipt/(?P<receipt_type>[^/]*)"
|
||||
"/(?P<event_id>[^/]*)$"
|
||||
"/(?P<end_event_id>[^/]*)"
|
||||
"(/(?P<start_event_id>[^/]*))?$"
|
||||
)
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
@@ -46,7 +47,12 @@ class ReceiptRestServlet(RestServlet):
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
async def on_POST(
|
||||
self, request: SynapseRequest, room_id: str, receipt_type: str, event_id: str
|
||||
self,
|
||||
request: SynapseRequest,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
end_event_id: str,
|
||||
start_event_id: Optional[str] = None,
|
||||
) -> Tuple[int, JsonDict]:
|
||||
requester = await self.auth.get_user_by_req(request)
|
||||
|
||||
@@ -73,14 +79,15 @@ class ReceiptRestServlet(RestServlet):
|
||||
await self.read_marker_handler.received_client_read_marker(
|
||||
room_id,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=event_id,
|
||||
event_id=end_event_id,
|
||||
)
|
||||
else:
|
||||
await self.receipts_handler.received_client_receipt(
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id=requester.user.to_string(),
|
||||
event_id=event_id,
|
||||
end_event_id=end_event_id,
|
||||
start_event_id=start_event_id,
|
||||
)
|
||||
|
||||
return 200, {}
|
||||
|
||||
@@ -268,12 +268,15 @@ class RoomSendEventRestServlet(TransactionRestServlet):
|
||||
if b"ts" in request.args and requester.app_service:
|
||||
event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
|
||||
|
||||
# XXX Horrible hack.
|
||||
prev_event_ids = parse_strings_from_args(request.args, "prev_event_id")
|
||||
|
||||
try:
|
||||
(
|
||||
event,
|
||||
_,
|
||||
) = await self.event_creation_handler.create_and_send_nonmember_event(
|
||||
requester, event_dict, txn_id=txn_id
|
||||
requester, event_dict, txn_id=txn_id, prev_event_ids=prev_event_ids
|
||||
)
|
||||
event_id = event.event_id
|
||||
except ShadowBanError:
|
||||
|
||||
@@ -24,6 +24,7 @@ from synapse.storage.database import (
|
||||
LoggingDatabaseConnection,
|
||||
LoggingTransaction,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
@@ -79,7 +80,7 @@ class UserPushAction(EmailPushAction):
|
||||
profile_tag: str
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@attr.s(slots=True, auto_attribs=True)
|
||||
class NotifCounts:
|
||||
"""
|
||||
The per-user, per-room count of notifications. Used by sync and push.
|
||||
@@ -89,6 +90,12 @@ class NotifCounts:
|
||||
unread_count: int
|
||||
highlight_count: int
|
||||
|
||||
def to_dict(self) -> JsonDict:
|
||||
return {
|
||||
"notification_count": self.notify_count,
|
||||
"highlight_count": self.highlight_count,
|
||||
}
|
||||
|
||||
|
||||
def _serialize_action(actions: List[Union[dict, str]], is_highlight: bool) -> str:
|
||||
"""Custom serializer for actions. This allows us to "compress" common actions.
|
||||
@@ -148,13 +155,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
self._rotate_notifs, 30 * 60 * 1000
|
||||
)
|
||||
|
||||
@cached(num_args=3, tree=True, max_entries=5000)
|
||||
@cached(max_entries=5000, tree=True, iterable=True)
|
||||
async def get_unread_event_push_actions_by_room_for_user(
|
||||
self,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> NotifCounts:
|
||||
) -> Dict[Optional[str], NotifCounts]:
|
||||
"""Get the notification count, the highlight count and the unread message count
|
||||
for a given user in a given room after the given read receipt.
|
||||
|
||||
@@ -187,7 +194,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
last_read_event_id: Optional[str],
|
||||
) -> NotifCounts:
|
||||
) -> Dict[Optional[str], NotifCounts]:
|
||||
stream_ordering = None
|
||||
|
||||
if last_read_event_id is not None:
|
||||
@@ -217,49 +224,63 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
|
||||
def _get_unread_counts_by_pos_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
|
||||
) -> NotifCounts:
|
||||
sql = (
|
||||
"SELECT"
|
||||
" COUNT(CASE WHEN notif = 1 THEN 1 END),"
|
||||
" COUNT(CASE WHEN highlight = 1 THEN 1 END),"
|
||||
" COUNT(CASE WHEN unread = 1 THEN 1 END)"
|
||||
" FROM event_push_actions ea"
|
||||
" WHERE user_id = ?"
|
||||
" AND room_id = ?"
|
||||
" AND stream_ordering > ?"
|
||||
)
|
||||
) -> Dict[Optional[str], NotifCounts]:
|
||||
sql = """
|
||||
SELECT
|
||||
COUNT(CASE WHEN notif = 1 THEN 1 END),
|
||||
COUNT(CASE WHEN highlight = 1 THEN 1 END),
|
||||
COUNT(CASE WHEN unread = 1 THEN 1 END),
|
||||
thread_id
|
||||
FROM event_push_actions ea
|
||||
WHERE user_id = ?
|
||||
AND room_id = ?
|
||||
AND stream_ordering > ?
|
||||
GROUP BY thread_id
|
||||
"""
|
||||
|
||||
txn.execute(sql, (user_id, room_id, stream_ordering))
|
||||
row = txn.fetchone()
|
||||
rows = txn.fetchall()
|
||||
|
||||
(notif_count, highlight_count, unread_count) = (0, 0, 0)
|
||||
|
||||
if row:
|
||||
(notif_count, highlight_count, unread_count) = row
|
||||
notif_counts: Dict[Optional[str], NotifCounts] = {
|
||||
# Ensure the main timeline has notification counts.
|
||||
None: NotifCounts(
|
||||
notify_count=0,
|
||||
unread_count=0,
|
||||
highlight_count=0,
|
||||
)
|
||||
}
|
||||
for notif_count, highlight_count, unread_count, thread_id in rows:
|
||||
notif_counts[thread_id] = NotifCounts(
|
||||
notify_count=notif_count,
|
||||
unread_count=unread_count,
|
||||
highlight_count=highlight_count,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
SELECT notif_count, unread_count FROM event_push_summary
|
||||
SELECT notif_count, unread_count, thread_id FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering > ?
|
||||
""",
|
||||
(room_id, user_id, stream_ordering),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
rows = txn.fetchall()
|
||||
|
||||
if row:
|
||||
notif_count += row[0]
|
||||
for notif_count, unread_count, thread_id in rows:
|
||||
if unread_count is None:
|
||||
# The unread_count column of event_push_summary is NULLable.
|
||||
unread_count = 0
|
||||
|
||||
if row[1] is not None:
|
||||
# The unread_count column of event_push_summary is NULLable, so we need
|
||||
# to make sure we don't try increasing the unread counts if it's NULL
|
||||
# for this row.
|
||||
unread_count += row[1]
|
||||
if thread_id in notif_counts:
|
||||
notif_counts[thread_id].notify_count += notif_count
|
||||
notif_counts[thread_id].unread_count += unread_count
|
||||
else:
|
||||
notif_counts[thread_id] = NotifCounts(
|
||||
notify_count=notif_count,
|
||||
unread_count=unread_count,
|
||||
highlight_count=0,
|
||||
)
|
||||
|
||||
return NotifCounts(
|
||||
notify_count=notif_count,
|
||||
unread_count=unread_count,
|
||||
highlight_count=highlight_count,
|
||||
)
|
||||
return notif_counts
|
||||
|
||||
async def get_push_action_users_in_range(
|
||||
self, min_stream_ordering: int, max_stream_ordering: int
|
||||
@@ -528,6 +549,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
event_id: str,
|
||||
user_id_actions: Dict[str, List[Union[dict, str]]],
|
||||
count_as_unread: bool,
|
||||
thread_id: Optional[str],
|
||||
) -> None:
|
||||
"""Add the push actions for the event to the push action staging area.
|
||||
|
||||
@@ -536,6 +558,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
user_id_actions: A mapping of user_id to list of push actions, where
|
||||
an action can either be a string or dict.
|
||||
count_as_unread: Whether this event should increment unread counts.
|
||||
thread_id: The thread this event is parent of, if applicable.
|
||||
"""
|
||||
if not user_id_actions:
|
||||
return
|
||||
@@ -544,7 +567,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
# can be used to insert into the `event_push_actions_staging` table.
|
||||
def _gen_entry(
|
||||
user_id: str, actions: List[Union[dict, str]]
|
||||
) -> Tuple[str, str, str, int, int, int]:
|
||||
) -> Tuple[str, str, str, int, int, int, Optional[str]]:
|
||||
is_highlight = 1 if _action_has_highlight(actions) else 0
|
||||
notif = 1 if "notify" in actions else 0
|
||||
return (
|
||||
@@ -554,6 +577,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
notif, # notif column
|
||||
is_highlight, # highlight column
|
||||
int(count_as_unread), # unread column
|
||||
thread_id, # thread_id column
|
||||
)
|
||||
|
||||
def _add_push_actions_to_staging_txn(txn: LoggingTransaction) -> None:
|
||||
@@ -562,8 +586,8 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
|
||||
sql = """
|
||||
INSERT INTO event_push_actions_staging
|
||||
(event_id, user_id, actions, notif, highlight, unread)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
(event_id, user_id, actions, notif, highlight, unread, thread_id)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||
"""
|
||||
|
||||
txn.execute_batch(
|
||||
@@ -810,20 +834,20 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
|
||||
# Calculate the new counts that should be upserted into event_push_summary
|
||||
sql = """
|
||||
SELECT user_id, room_id,
|
||||
SELECT user_id, room_id, thread_id,
|
||||
coalesce(old.%s, 0) + upd.cnt,
|
||||
upd.stream_ordering,
|
||||
old.user_id
|
||||
FROM (
|
||||
SELECT user_id, room_id, count(*) as cnt,
|
||||
SELECT user_id, room_id, thread_id, count(*) as cnt,
|
||||
max(stream_ordering) as stream_ordering
|
||||
FROM event_push_actions
|
||||
WHERE ? <= stream_ordering AND stream_ordering < ?
|
||||
AND highlight = 0
|
||||
AND %s = 1
|
||||
GROUP BY user_id, room_id
|
||||
GROUP BY user_id, room_id, thread_id
|
||||
) AS upd
|
||||
LEFT JOIN event_push_summary AS old USING (user_id, room_id)
|
||||
LEFT JOIN event_push_summary AS old USING (user_id, room_id, thread_id)
|
||||
"""
|
||||
|
||||
# First get the count of unread messages.
|
||||
@@ -837,12 +861,12 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
# object because we might not have the same amount of rows in each of them. To do
|
||||
# this, we use a dict indexed on the user ID and room ID to make it easier to
|
||||
# populate.
|
||||
summaries: Dict[Tuple[str, str], _EventPushSummary] = {}
|
||||
summaries: Dict[Tuple[str, str, Optional[str]], _EventPushSummary] = {}
|
||||
for row in txn:
|
||||
summaries[(row[0], row[1])] = _EventPushSummary(
|
||||
unread_count=row[2],
|
||||
stream_ordering=row[3],
|
||||
old_user_id=row[4],
|
||||
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
|
||||
unread_count=row[3],
|
||||
stream_ordering=row[4],
|
||||
old_user_id=row[5],
|
||||
notif_count=0,
|
||||
)
|
||||
|
||||
@@ -853,18 +877,18 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
for row in txn:
|
||||
if (row[0], row[1]) in summaries:
|
||||
summaries[(row[0], row[1])].notif_count = row[2]
|
||||
if (row[0], row[1], row[2]) in summaries:
|
||||
summaries[(row[0], row[1], row[2])].notif_count = row[3]
|
||||
else:
|
||||
# Because the rules on notifying are different than the rules on marking
|
||||
# a message unread, we might end up with messages that notify but aren't
|
||||
# marked unread, so we might not have a summary for this (user, room)
|
||||
# tuple to complete.
|
||||
summaries[(row[0], row[1])] = _EventPushSummary(
|
||||
summaries[(row[0], row[1], row[2])] = _EventPushSummary(
|
||||
unread_count=0,
|
||||
stream_ordering=row[3],
|
||||
old_user_id=row[4],
|
||||
notif_count=row[2],
|
||||
stream_ordering=row[4],
|
||||
old_user_id=row[5],
|
||||
notif_count=row[3],
|
||||
)
|
||||
|
||||
logger.info("Rotating notifications, handling %d rows", len(summaries))
|
||||
@@ -881,6 +905,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
"notif_count",
|
||||
"unread_count",
|
||||
"stream_ordering",
|
||||
"thread_id",
|
||||
),
|
||||
values=[
|
||||
(
|
||||
@@ -889,8 +914,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
summary.notif_count,
|
||||
summary.unread_count,
|
||||
summary.stream_ordering,
|
||||
thread_id,
|
||||
)
|
||||
for ((user_id, room_id), summary) in summaries.items()
|
||||
for ((user_id, room_id, thread_id), summary) in summaries.items()
|
||||
if summary.old_user_id is None
|
||||
],
|
||||
)
|
||||
@@ -899,7 +925,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
UPDATE event_push_summary
|
||||
SET notif_count = ?, unread_count = ?, stream_ordering = ?
|
||||
WHERE user_id = ? AND room_id = ?
|
||||
WHERE user_id = ? AND room_id = ? AND thread_id = ?
|
||||
""",
|
||||
(
|
||||
(
|
||||
@@ -908,8 +934,9 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
summary.stream_ordering,
|
||||
user_id,
|
||||
room_id,
|
||||
thread_id,
|
||||
)
|
||||
for ((user_id, room_id), summary) in summaries.items()
|
||||
for ((user_id, room_id, thread_id), summary) in summaries.items()
|
||||
if summary.old_user_id is not None
|
||||
),
|
||||
)
|
||||
@@ -927,8 +954,13 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
(rotate_to_stream_ordering,),
|
||||
)
|
||||
|
||||
def _remove_old_push_actions_before_txn(
|
||||
self, txn: LoggingTransaction, room_id: str, user_id: str, stream_ordering: int
|
||||
def _remove_old_push_actions_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
end_stream_ordering: int,
|
||||
start_stream_ordering: Optional[int],
|
||||
) -> None:
|
||||
"""
|
||||
Purges old push actions for a user and room before a given
|
||||
@@ -957,20 +989,33 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
# Instead, we look up the stream ordering for the last event in that
|
||||
# room received before the threshold time and delete event_push_actions
|
||||
# in the room with a stream_odering before that.
|
||||
txn.execute(
|
||||
"DELETE FROM event_push_actions "
|
||||
" WHERE user_id = ? AND room_id = ? AND "
|
||||
" stream_ordering <= ?"
|
||||
" AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)",
|
||||
(user_id, room_id, stream_ordering, self.stream_ordering_month_ago),
|
||||
)
|
||||
if start_stream_ordering is None:
|
||||
stream_ordering_clause = "stream_ordering <= ?"
|
||||
stream_ordering_args: Tuple[int, ...] = (end_stream_ordering,)
|
||||
else:
|
||||
stream_ordering_clause = "stream_ordering >= ? AND stream_ordering <= ?"
|
||||
stream_ordering_args = (start_stream_ordering, end_stream_ordering)
|
||||
|
||||
txn.execute(
|
||||
"""
|
||||
f"""
|
||||
DELETE FROM event_push_actions
|
||||
WHERE user_id = ? AND room_id = ?
|
||||
AND ((stream_ordering < ? AND highlight = 1) or highlight = 0)
|
||||
AND {stream_ordering_clause}
|
||||
""",
|
||||
(user_id, room_id, self.stream_ordering_month_ago) + stream_ordering_args,
|
||||
)
|
||||
|
||||
# XXX What to do about these summaries? They're currently updated daily.
|
||||
# Deleting a chunk of them if any region overlaps seems suspect.
|
||||
# Maybe we can do a daily update to limit the damage? That would not
|
||||
# give true unread status per event, however.
|
||||
txn.execute(
|
||||
f"""
|
||||
DELETE FROM event_push_summary
|
||||
WHERE room_id = ? AND user_id = ? AND stream_ordering <= ?
|
||||
WHERE room_id = ? AND user_id = ? AND {stream_ordering_clause}
|
||||
""",
|
||||
(room_id, user_id, stream_ordering),
|
||||
(room_id, user_id) + stream_ordering_args,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from typing import (
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
cast,
|
||||
)
|
||||
|
||||
import attr
|
||||
@@ -195,6 +196,9 @@ class PersistEventsStore:
|
||||
)
|
||||
persist_event_counter.inc(len(events_and_contexts))
|
||||
|
||||
# Update any receipts for users in the rooms.
|
||||
await self._update_receipts(events_and_contexts)
|
||||
|
||||
if not use_negative_stream_ordering:
|
||||
# we don't want to set the event_persisted_position to a negative
|
||||
# stream_ordering.
|
||||
@@ -2099,6 +2103,174 @@ class PersistEventsStore:
|
||||
),
|
||||
)
|
||||
|
||||
def _get_receipts_to_update(
|
||||
self, txn: LoggingTransaction, event: EventBase
|
||||
) -> List[tuple]:
|
||||
# Find any receipt ranges that would be "broken" by this event.
|
||||
sql = """
|
||||
SELECT
|
||||
stream_id,
|
||||
receipts_ranged.room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
start_event_id,
|
||||
end_event_id,
|
||||
data,
|
||||
start_event.topological_ordering,
|
||||
end_event.topological_ordering
|
||||
FROM receipts_ranged
|
||||
LEFT JOIN events AS end_event ON (end_event.event_id = end_event_id)
|
||||
LEFT JOIN events AS start_event ON (start_event.event_id = start_event_id)
|
||||
WHERE
|
||||
receipts_ranged.room_id = ? AND
|
||||
(start_event.topological_ordering <= ? OR start_event_id IS NULL) AND
|
||||
? <= end_event.topological_ordering;
|
||||
"""
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(event.room_id, event.depth, event.depth),
|
||||
)
|
||||
return list(txn.fetchall())
|
||||
|
||||
def _split_receipt(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
event: EventBase,
|
||||
stream_id: int,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
start_event_id: str,
|
||||
end_event_id: str,
|
||||
data: JsonDict,
|
||||
start_topological_ordering: int,
|
||||
end_topological_ordering: int,
|
||||
stream_orderings: Tuple[int, ...],
|
||||
) -> None:
|
||||
# Upsert the current receipt to give it a new endpoint as the
|
||||
# latest event in the range before the new event.
|
||||
sql = """
|
||||
SELECT event_id FROM events
|
||||
WHERE room_id = ? AND topological_ordering <= ? AND stream_ordering < ?
|
||||
ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT 1;
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.room_id,
|
||||
event.depth,
|
||||
event.internal_metadata.stream_ordering,
|
||||
),
|
||||
)
|
||||
new_end_event_id = cast(Tuple[str], txn.fetchone())[0] # XXX Can this be None?
|
||||
# TODO Upsert?
|
||||
self.db_pool.simple_delete_one_txn(
|
||||
txn, table="receipts_ranged", keyvalues={"stream_id": stream_id}
|
||||
)
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_ranged",
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"receipt_type": receipt_type,
|
||||
"start_event_id": start_event_id,
|
||||
"end_event_id": new_end_event_id,
|
||||
"stream_id": stream_orderings[0],
|
||||
"data": data, # XXX Does it make sense to duplicate this?
|
||||
},
|
||||
)
|
||||
|
||||
# Insert a new receipt with a start point as the first event after
|
||||
# the new event and re-using the old endpoint.
|
||||
sql = """
|
||||
SELECT event_id FROM events
|
||||
WHERE room_id = ? AND topological_ordering > ? AND stream_ordering < ?
|
||||
ORDER BY topological_ordering, stream_ordering LIMIT 1;
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
event.room_id,
|
||||
event.depth,
|
||||
event.internal_metadata.stream_ordering,
|
||||
),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
# If there's no events topologically after the end event, the
|
||||
# second range is just for the single event.
|
||||
if row is not None:
|
||||
new_start_event_id = row[0]
|
||||
else:
|
||||
new_start_event_id = end_event_id
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_ranged",
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"receipt_type": receipt_type,
|
||||
"start_event_id": new_start_event_id,
|
||||
"end_event_id": end_event_id,
|
||||
"stream_id": stream_orderings[1],
|
||||
"data": data, # XXX Does it make sense to duplicate this?
|
||||
},
|
||||
)
|
||||
|
||||
txn.call_after(
|
||||
self.store.invalidate_caches_for_receipt,
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def _update_receipts(
|
||||
self, events_and_contexts: List[Tuple[EventBase, EventContext]]
|
||||
) -> None:
|
||||
# Only non-outlier events can have a receipt associated with them.
|
||||
# XXX Is this true?
|
||||
non_outlier_events = [
|
||||
event
|
||||
for event, _ in events_and_contexts
|
||||
if not event.internal_metadata.is_outlier()
|
||||
]
|
||||
|
||||
# XXX This is probably slow...
|
||||
for event in non_outlier_events:
|
||||
receipts = await self.db_pool.runInteraction(
|
||||
"update_receipts", self._get_receipts_to_update, event=event
|
||||
)
|
||||
|
||||
# Split each receipt in two by the new event.
|
||||
for (
|
||||
stream_id,
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
start_event_id,
|
||||
end_event_id,
|
||||
data,
|
||||
start_topological_ordering,
|
||||
end_topological_ordering,
|
||||
) in receipts:
|
||||
async with self.store._receipts_id_gen.get_next_mult(2) as stream_orderings: # type: ignore[attr-defined]
|
||||
await self.db_pool.runInteraction(
|
||||
"split_receipts",
|
||||
self._split_receipt,
|
||||
event=event,
|
||||
stream_id=stream_id,
|
||||
room_id=room_id,
|
||||
receipt_type=receipt_type,
|
||||
user_id=user_id,
|
||||
start_event_id=start_event_id,
|
||||
end_event_id=end_event_id,
|
||||
data=data,
|
||||
start_topological_ordering=start_topological_ordering,
|
||||
end_topological_ordering=end_topological_ordering,
|
||||
stream_orderings=stream_orderings,
|
||||
)
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(
|
||||
self,
|
||||
txn: LoggingTransaction,
|
||||
@@ -2130,9 +2302,9 @@ class PersistEventsStore:
|
||||
sql = """
|
||||
INSERT INTO event_push_actions (
|
||||
room_id, event_id, user_id, actions, stream_ordering,
|
||||
topological_ordering, notif, highlight, unread
|
||||
topological_ordering, notif, highlight, unread, thread_id
|
||||
)
|
||||
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread
|
||||
SELECT ?, event_id, user_id, actions, ?, ?, notif, highlight, unread, thread_id
|
||||
FROM event_push_actions_staging
|
||||
WHERE event_id = ?
|
||||
"""
|
||||
|
||||
@@ -417,6 +417,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
# "rooms" happens last, to keep the foreign keys in the other tables
|
||||
# happy
|
||||
"rooms",
|
||||
"receipts_ranged",
|
||||
):
|
||||
logger.info("[purge] removing %s from %s", room_id, table)
|
||||
txn.execute("DELETE FROM %s WHERE room_id=?" % (table,), (room_id,))
|
||||
|
||||
@@ -13,6 +13,38 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Receipts are stored as per-user ranges from a starting event to an ending event.
|
||||
If the starting event is missing than the range is considered to cover all events
|
||||
earlier in the room than the ending events.
|
||||
|
||||
Since events in a room are a DAG we need to linearise it before applying receipts.
|
||||
Synapse linearises the room by sorting events by (topological ordering, stream ordering).
|
||||
To ensure that receipts are non-overlapping and correct the following operations
|
||||
need to occur:
|
||||
|
||||
* When a new receipt is received from a client, we coalesce it with other receipts.
|
||||
* When new events are received, any receipt range which includes the event's
|
||||
topological ordering must be split into two receipts.
|
||||
|
||||
Given a simple linear room:
|
||||
|
||||
A--B--C--D
|
||||
|
||||
This is covered by a single receipt [A, D]
|
||||
|
||||
If a forked in the DAG occurs:
|
||||
|
||||
A--B--C--D which linearises to: A--B--E--C--F--D
|
||||
\ /
|
||||
E---F
|
||||
|
||||
The receipt from above must be split into component parts:
|
||||
[A, B]
|
||||
[C, C]
|
||||
[D, D]
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
@@ -42,7 +74,7 @@ from synapse.storage.util.id_generators import (
|
||||
MultiWriterIdGenerator,
|
||||
StreamIdGenerator,
|
||||
)
|
||||
from synapse.types import JsonDict
|
||||
from synapse.types import JsonDict, RangedReadReceipt, ReadReceipt, Receipt
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@@ -380,7 +412,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
def f(txn: LoggingTransaction) -> List[Dict[str, Any]]:
|
||||
if from_key:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT * FROM receipts_ranged WHERE
|
||||
stream_id > ? AND stream_id <= ? AND
|
||||
"""
|
||||
clause, args = make_in_list_sql_clause(
|
||||
@@ -390,7 +422,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql + clause, [from_key, to_key] + list(args))
|
||||
else:
|
||||
sql = """
|
||||
SELECT * FROM receipts_linearized WHERE
|
||||
SELECT * FROM receipts_ranged WHERE
|
||||
stream_id <= ? AND
|
||||
"""
|
||||
|
||||
@@ -417,10 +449,11 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
# The content is of the form:
|
||||
# {"$foo:bar": { "read": { "@user:host": <receipt> }, .. }, .. }
|
||||
event_entry = room_event["content"].setdefault(row["event_id"], {})
|
||||
event_entry = room_event["content"].setdefault(row["end_event_id"], {})
|
||||
receipt_type = event_entry.setdefault(row["receipt_type"], {})
|
||||
|
||||
receipt_type[row["user_id"]] = db_to_json(row["data"])
|
||||
receipt_type[row["user_id"]]["start_event_id"] = row["start_event_id"]
|
||||
|
||||
results = {
|
||||
room_id: [results[room_id]] if room_id in results else []
|
||||
@@ -604,7 +637,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_id: str,
|
||||
start_event_id: Optional[str],
|
||||
end_event_id: str,
|
||||
data: JsonDict,
|
||||
stream_id: int,
|
||||
) -> Optional[int]:
|
||||
@@ -617,37 +651,37 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
assert self._can_write_to_receipts
|
||||
|
||||
start_topo_ordering = None
|
||||
start_stream_ordering = None
|
||||
if start_event_id is not None:
|
||||
res = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=("topological_ordering", "stream_ordering"),
|
||||
keyvalues={"event_id": start_event_id},
|
||||
allow_none=True,
|
||||
)
|
||||
if res is not None:
|
||||
start_topo_ordering = int(res["topological_ordering"])
|
||||
start_stream_ordering = int(res["stream_ordering"])
|
||||
|
||||
res = self.db_pool.simple_select_one_txn(
|
||||
txn,
|
||||
table="events",
|
||||
retcols=["stream_ordering", "received_ts"],
|
||||
keyvalues={"event_id": event_id},
|
||||
retcols=("topological_ordering", "stream_ordering", "received_ts"),
|
||||
keyvalues={"event_id": end_event_id},
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
stream_ordering = int(res["stream_ordering"]) if res else None
|
||||
end_topo_ordering = (
|
||||
None # XXX When is it valid to not find this event? Federation?
|
||||
)
|
||||
end_stream_ordering = None
|
||||
if res is not None:
|
||||
end_topo_ordering = int(res["topological_ordering"])
|
||||
end_stream_ordering = int(res["stream_ordering"])
|
||||
# XXX This is just for logging in the caller, can it be removed.
|
||||
rx_ts = res["received_ts"] if res else 0
|
||||
|
||||
# We don't want to clobber receipts for more recent events, so we
|
||||
# have to compare orderings of existing receipts
|
||||
if stream_ordering is not None:
|
||||
sql = (
|
||||
"SELECT stream_ordering, event_id FROM events"
|
||||
" INNER JOIN receipts_linearized AS r USING (event_id, room_id)"
|
||||
" WHERE r.room_id = ? AND r.receipt_type = ? AND r.user_id = ?"
|
||||
)
|
||||
txn.execute(sql, (room_id, receipt_type, user_id))
|
||||
|
||||
for so, eid in txn:
|
||||
if int(so) >= stream_ordering:
|
||||
logger.debug(
|
||||
"Ignoring new receipt for %s in favour of existing "
|
||||
"one for later event %s",
|
||||
event_id,
|
||||
eid,
|
||||
)
|
||||
return None
|
||||
|
||||
txn.call_after(
|
||||
self.invalidate_caches_for_receipt, room_id, receipt_type, user_id
|
||||
)
|
||||
@@ -656,33 +690,113 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
self._receipts_stream_cache.entity_has_changed, room_id, stream_id
|
||||
)
|
||||
|
||||
self.db_pool.simple_upsert_txn(
|
||||
# Find all overlapping or adjacent receipts. These receipts are found by
|
||||
# searching for any receipts which:
|
||||
#
|
||||
# * Have an end topological ordering directly before or after the new
|
||||
# receipt's start topological ordering.
|
||||
# * Have a start topological ordering directly after or before the new
|
||||
# receipt's end topological ordering.
|
||||
#
|
||||
# E.g. the following would be found:
|
||||
#
|
||||
# * [1, 7] and [8, 10] should be combined.
|
||||
# * [1, 7] and [5, 10] should be combined.
|
||||
# * [None, 7] and [5, 10] should be combined.
|
||||
#
|
||||
# XXX Do we care about stream ordering here?
|
||||
#
|
||||
# XXX This doesn't handle a start_topo_ordering of None.
|
||||
sql = """
|
||||
SELECT
|
||||
stream_id,
|
||||
start_event_id,
|
||||
start_event.topological_ordering,
|
||||
end_event_id,
|
||||
end_event.topological_ordering
|
||||
FROM receipts_ranged
|
||||
LEFT JOIN events AS end_event ON (end_event.event_id = end_event_id)
|
||||
LEFT JOIN events AS start_event ON (start_event.event_id = start_event_id)
|
||||
WHERE
|
||||
receipts_ranged.room_id = ? AND
|
||||
user_id = ? AND
|
||||
receipt_type = ? AND
|
||||
end_event.topological_ordering >= ? AND
|
||||
start_event.topological_ordering <= ?;
|
||||
"""
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
room_id,
|
||||
user_id,
|
||||
receipt_type,
|
||||
start_topo_ordering - 1 if start_topo_ordering is not None else None,
|
||||
end_topo_ordering + 1 if end_topo_ordering is not None else None,
|
||||
),
|
||||
)
|
||||
overlapping_receipts = txn.fetchall()
|
||||
# Delete the overlapping receipts by stream ID.
|
||||
self.db_pool.simple_delete_many_txn(
|
||||
txn,
|
||||
table="receipts_linearized",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
"receipt_type": receipt_type,
|
||||
"user_id": user_id,
|
||||
},
|
||||
table="receipts_ranged",
|
||||
column="stream_id",
|
||||
values=[receipt[0] for receipt in overlapping_receipts],
|
||||
keyvalues={},
|
||||
)
|
||||
|
||||
# Potentially expand the start/end event based on overlapping receipts.
|
||||
for (
|
||||
_,
|
||||
overlapping_start_event_id,
|
||||
overlapping_start_topo_ordering,
|
||||
overlapping_end_event_id,
|
||||
overlapping_end_topo_ordering,
|
||||
) in overlapping_receipts:
|
||||
if (
|
||||
start_topo_ordering is not None
|
||||
and overlapping_start_topo_ordering < start_topo_ordering
|
||||
):
|
||||
start_topo_ordering = overlapping_start_topo_ordering
|
||||
start_event_id = overlapping_start_event_id
|
||||
|
||||
if end_topo_ordering < overlapping_end_topo_ordering:
|
||||
end_topo_ordering = overlapping_end_topo_ordering
|
||||
end_event_id = overlapping_end_event_id
|
||||
|
||||
# Insert the new receipt into the table.
|
||||
self.db_pool.simple_insert_txn(
|
||||
txn,
|
||||
table="receipts_ranged",
|
||||
values={
|
||||
"room_id": room_id,
|
||||
"user_id": user_id,
|
||||
"receipt_type": receipt_type,
|
||||
"start_event_id": start_event_id,
|
||||
"end_event_id": end_event_id,
|
||||
"stream_id": stream_id,
|
||||
"event_id": event_id,
|
||||
"data": json_encoder.encode(data),
|
||||
},
|
||||
# receipts_linearized has a unique constraint on
|
||||
# (user_id, room_id, receipt_type), so no need to lock
|
||||
lock=False,
|
||||
)
|
||||
|
||||
# XXX How do we migrate receipts_linearized or do we use one of non-ranged receipts?
|
||||
|
||||
# When updating a local users read receipt, remove any push actions
|
||||
# which resulted from the receipt's event and all earlier events.
|
||||
#
|
||||
# XXX Can the stream orderings from local users not be known? Maybe if
|
||||
# events are purged (retention?)
|
||||
#
|
||||
# XXX Do we need to differentiate between an unbounded start
|
||||
# (start_event_id == None) vs. being unable to find the event
|
||||
# (start_stream_ordering == None)?
|
||||
if (
|
||||
self.hs.is_mine_id(user_id)
|
||||
and receipt_type in (ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE)
|
||||
and stream_ordering is not None
|
||||
and (start_stream_ordering is not None or end_stream_ordering is not None)
|
||||
):
|
||||
self._remove_old_push_actions_before_txn( # type: ignore[attr-defined]
|
||||
txn, room_id=room_id, user_id=user_id, stream_ordering=stream_ordering
|
||||
# XXX Topo ordering?
|
||||
self._remove_old_push_actions_txn( # type: ignore[attr-defined]
|
||||
txn, room_id, user_id, end_stream_ordering, start_stream_ordering
|
||||
)
|
||||
|
||||
return rx_ts
|
||||
@@ -725,14 +839,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
else:
|
||||
raise RuntimeError("Unrecognized event_ids: %r" % (event_ids,))
|
||||
|
||||
async def insert_receipt(
|
||||
self,
|
||||
room_id: str,
|
||||
receipt_type: str,
|
||||
user_id: str,
|
||||
event_ids: List[str],
|
||||
data: dict,
|
||||
) -> Optional[Tuple[int, int]]:
|
||||
async def insert_receipt(self, receipt: Receipt) -> Optional[Tuple[int, int]]:
|
||||
"""Insert a receipt, either from local client or remote server.
|
||||
|
||||
Automatically does conversion between linearized and graph
|
||||
@@ -744,26 +851,38 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
assert self._can_write_to_receipts
|
||||
|
||||
if not event_ids:
|
||||
return None
|
||||
if isinstance(receipt, ReadReceipt):
|
||||
event_ids = receipt.event_ids
|
||||
if not event_ids:
|
||||
return None
|
||||
|
||||
if len(event_ids) == 1:
|
||||
linearized_event_id = event_ids[0]
|
||||
start_event_id = None
|
||||
if len(event_ids) == 1:
|
||||
end_event_id = event_ids[0]
|
||||
else:
|
||||
# we need to points in graph -> linearized form.
|
||||
end_event_id = await self.db_pool.runInteraction(
|
||||
"insert_receipt_conv",
|
||||
self._graph_to_linear,
|
||||
receipt.room_id,
|
||||
event_ids,
|
||||
)
|
||||
elif isinstance(receipt, RangedReadReceipt):
|
||||
start_event_id = receipt.start_event_id
|
||||
end_event_id = receipt.end_event_id
|
||||
else:
|
||||
# we need to points in graph -> linearized form.
|
||||
linearized_event_id = await self.db_pool.runInteraction(
|
||||
"insert_receipt_conv", self._graph_to_linear, room_id, event_ids
|
||||
)
|
||||
raise ValueError("Unexpected receipt type: %s", type(receipt))
|
||||
|
||||
async with self._receipts_id_gen.get_next() as stream_id: # type: ignore[attr-defined]
|
||||
event_ts = await self.db_pool.runInteraction(
|
||||
"insert_linearized_receipt",
|
||||
self._insert_linearized_receipt_txn,
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
linearized_event_id,
|
||||
data,
|
||||
receipt.room_id,
|
||||
receipt.receipt_type,
|
||||
receipt.user_id,
|
||||
start_event_id,
|
||||
end_event_id,
|
||||
receipt.data,
|
||||
stream_id=stream_id,
|
||||
# Read committed is actually beneficial here because we check for a receipt with
|
||||
# greater stream order, and checking the very latest data at select time is better
|
||||
@@ -778,20 +897,21 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
now = self._clock.time_msec()
|
||||
logger.debug(
|
||||
"RR for event %s in %s (%i ms old)",
|
||||
linearized_event_id,
|
||||
room_id,
|
||||
end_event_id, # XXX log start?
|
||||
receipt.room_id,
|
||||
now - event_ts,
|
||||
)
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self._insert_graph_receipt_txn,
|
||||
room_id,
|
||||
receipt_type,
|
||||
user_id,
|
||||
event_ids,
|
||||
data,
|
||||
)
|
||||
# XXX These aren't really used right now, go away.
|
||||
# await self.db_pool.runInteraction(
|
||||
# "insert_graph_receipt",
|
||||
# self._insert_graph_receipt_txn,
|
||||
# room_id,
|
||||
# receipt_type,
|
||||
# user_id,
|
||||
# event_ids,
|
||||
# data,
|
||||
# )
|
||||
|
||||
max_persisted_id = self._receipts_id_gen.get_current_token()
|
||||
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE receipts_ranged (
|
||||
stream_id bigint NOT NULL,
|
||||
room_id text NOT NULL,
|
||||
receipt_type text NOT NULL,
|
||||
user_id text NOT NULL,
|
||||
-- A null start means "everything before this".
|
||||
start_event_id text,
|
||||
end_event_id text NOT NULL,
|
||||
data text NOT NULL,
|
||||
instance_name text
|
||||
);
|
||||
|
||||
|
||||
CREATE INDEX receipts_ranged_id ON receipts_ranged (stream_id);
|
||||
CREATE INDEX receipts_ranged_room_type_user ON receipts_ranged (room_id, receipt_type, user_id);
|
||||
CREATE INDEX receipts_ranged_room_stream ON receipts_ranged (room_id, stream_id);
|
||||
CREATE INDEX receipts_ranged_user ON receipts_ranged (user_id);
|
||||
@@ -0,0 +1,23 @@
|
||||
/* Copyright 2022 The Matrix.org Foundation C.I.C
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE event_push_actions_staging
|
||||
ADD COLUMN thread_id TEXT DEFAULT NULL;
|
||||
|
||||
ALTER TABLE event_push_actions
|
||||
ADD COLUMN thread_id TEXT DEFAULT NULL;
|
||||
|
||||
ALTER TABLE event_push_summary
|
||||
ADD COLUMN thread_id TEXT DEFAULT NULL;
|
||||
@@ -822,16 +822,30 @@ class ThirdPartyInstanceID:
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ReadReceipt:
|
||||
"""Information about a read-receipt"""
|
||||
class Receipt:
|
||||
"""Information about a receipt"""
|
||||
|
||||
room_id: str
|
||||
receipt_type: str
|
||||
user_id: str
|
||||
event_ids: List[str]
|
||||
data: JsonDict
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class ReadReceipt(Receipt):
|
||||
"""Information about a read-receipt"""
|
||||
|
||||
event_ids: List[str]
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class RangedReadReceipt(Receipt):
|
||||
"""Information about a ranged read-receipt"""
|
||||
|
||||
start_event_id: str
|
||||
end_event_id: str
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
class DeviceListUpdates:
|
||||
"""
|
||||
|
||||
@@ -393,6 +393,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
event.event_id,
|
||||
{user_id: actions for user_id, actions in push_actions},
|
||||
False,
|
||||
None,
|
||||
)
|
||||
)
|
||||
return event, context
|
||||
|
||||
@@ -79,6 +79,7 @@ class EventPushActionsStoreTestCase(HomeserverTestCase):
|
||||
event.event_id,
|
||||
{user_id: action},
|
||||
False,
|
||||
None,
|
||||
)
|
||||
)
|
||||
self.get_success(
|
||||
|
||||
Reference in New Issue
Block a user