
Compare commits

...

3 Commits

Author         SHA1        Message                                       Date
Erik Johnston  1e05b033af  Persisted up to command                       2020-09-29 14:45:42 +01:00
Erik Johnston  4499d81adf  Wire up token                                 2020-09-29 14:43:28 +01:00
Erik Johnston  a4dde1f23c  Reduce usages of RoomStreamToken constructor  2020-09-29 14:43:28 +01:00
15 changed files with 389 additions and 60 deletions

View File

@@ -29,7 +29,6 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
-    RoomStreamToken,
     StreamToken,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
@@ -113,8 +112,7 @@ class DeviceWorkerHandler(BaseHandler):
         set_tag("user_id", user_id)
         set_tag("from_token", from_token)

-        now_room_id = self.store.get_room_max_stream_ordering()
-        now_room_key = RoomStreamToken(None, now_room_id)
+        now_room_key = self.store.get_room_max_token()

         room_ids = await self.store.get_rooms_for_user(user_id)

View File

@@ -1141,7 +1141,7 @@ class RoomEventSource:
         return (events, end_key)

     def get_current_key(self) -> RoomStreamToken:
-        return RoomStreamToken(None, self.store.get_room_max_stream_ordering())
+        return self.store.get_room_max_token()

     def get_current_key_for_room(self, room_id: str) -> Awaitable[str]:
         return self.store.get_room_events_max_id(room_id)

View File

@@ -163,7 +163,7 @@ class _NotifierUserStream:
         """
         # Immediately wake up stream if something has already happened
         # since their last token.
-        if self.last_notified_token.is_after(token):
+        if self.last_notified_token != token:
             return _NotificationListener(defer.succeed(self.current_token))
         else:
             return _NotificationListener(self.notify_deferred.observe())
@@ -470,7 +470,7 @@ class Notifier:
         async def check_for_updates(
             before_token: StreamToken, after_token: StreamToken
         ) -> EventStreamResult:
-            if not after_token.is_after(before_token):
+            if after_token == before_token:
                 return EventStreamResult([], (from_token, from_token))

             events = []  # type: List[EventBase]
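
The move from `is_after` checks to plain (in)equality above is deliberate: once a token carries per-writer positions, two tokens can each contain events the other lacks, so "after" is no longer a total order. A minimal illustrative sketch (plain dicts standing in for `RoomStreamToken.instance_map`; the helper and worker names are made up):

def contains_events_not_in(a: dict, b: dict) -> bool:
    # True if any writer's position in `a` is ahead of `b`'s for that writer.
    return any(pos > b.get(writer, 0) for writer, pos in a.items())


x = {"worker1": 2, "worker2": 1}
y = {"worker1": 1, "worker2": 2}

assert contains_events_not_in(x, y)  # x is ahead on worker1...
assert contains_events_not_in(y, x)  # ...and y is ahead on worker2.
# Neither token is simply "after" the other, so the stream wakes up on any
# change (!=) rather than on a strict ordering check.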

View File

@@ -77,6 +77,7 @@ REQUIREMENTS = [
     "Jinja2>=2.9",
     "bleach>=1.4.3",
     "typing-extensions>=3.7.4",
+    "cbor2",
 ]

 CONDITIONAL_REQUIREMENTS = {

View File

@@ -29,7 +29,7 @@ from synapse.replication.tcp.streams.events import (
     EventsStreamEventRow,
     EventsStreamRow,
 )
-from synapse.types import PersistedEventPosition, RoomStreamToken, UserID
+from synapse.types import PersistedEventPosition, UserID
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
@@ -152,9 +152,7 @@ class ReplicationDataHandler:
             if event.type == EventTypes.Member:
                 extra_users = (UserID.from_string(event.state_key),)

-            max_token = RoomStreamToken(
-                None, self.store.get_room_max_stream_ordering()
-            )
+            max_token = self.store.get_room_max_token()
             event_pos = PersistedEventPosition(instance_name, token)
             self.notifier.on_new_room_event(
                 event, event_pos, max_token, extra_users

View File

@@ -171,6 +171,37 @@ class PositionCommand(Command):
         return " ".join((self.stream_name, self.instance_name, str(self.token)))


+class PersistedToCommand(Command):
+    """Sent by a writer to inform others that it has persisted up to the
+    included token.
+
+    The included `token` will *not* have been persisted by the instance.
+
+    Format::
+
+        PERSISTED_TO <stream_name> <instance_name> <token>
+
+    On receipt the client should mark that the given instance has persisted
+    everything up to the given token. Note: this does *not* mean that other
+    instances have also persisted all their rows up to that point.
+    """
+
+    NAME = "PERSISTED_TO"
+
+    def __init__(self, stream_name, instance_name, token):
+        self.stream_name = stream_name
+        self.instance_name = instance_name
+        self.token = token
+
+    @classmethod
+    def from_line(cls, line):
+        stream_name, instance_name, token = line.split(" ", 2)
+        return cls(stream_name, instance_name, int(token))
+
+    def to_line(self):
+        return " ".join((self.stream_name, self.instance_name, str(self.token)))
+
+
 class ErrorCommand(_SimpleCommand):
     """Sent by either side if there was an ERROR. The data is a string describing
     the error.
@@ -405,6 +436,7 @@ _COMMANDS = (
     UserIpCommand,
     RemoteServerUpCommand,
     ClearUserSyncsCommand,
+    PersistedToCommand,
 )  # type: Tuple[Type[Command], ...]

 # Map of command name to command type.
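
For illustration, here is a self-contained round-trip of the wire format documented above. `PersistedToLine` is a hypothetical stand-in for the real `PersistedToCommand` (whose framing comes from `Command`); on the wire the serialised line is prefixed with the command `NAME`.

class PersistedToLine:
    """Round-trips the body of one PERSISTED_TO command line."""

    NAME = "PERSISTED_TO"

    def __init__(self, stream_name: str, instance_name: str, token: int):
        self.stream_name = stream_name
        self.instance_name = instance_name
        self.token = token

    @classmethod
    def from_line(cls, line: str) -> "PersistedToLine":
        # Mirrors line.split(" ", 2) above: exactly three space-separated fields.
        stream_name, instance_name, token = line.split(" ", 2)
        return cls(stream_name, instance_name, int(token))

    def to_line(self) -> str:
        return " ".join((self.stream_name, self.instance_name, str(self.token)))


# On the wire the full line would read "PERSISTED_TO events worker1 1234";
# the protocol layer strips the NAME before calling from_line.
cmd = PersistedToLine.from_line("events worker1 1234")
assert cmd.to_line() == "events worker1 1234"
assert cmd.token == 1234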

View File

@@ -47,6 +47,7 @@ from synapse.replication.tcp.commands import (
     ReplicateCommand,
     UserIpCommand,
     UserSyncCommand,
+    PersistedToCommand,
 )
 from synapse.replication.tcp.protocol import AbstractConnection
 from synapse.replication.tcp.streams import (
@@ -387,6 +388,9 @@ class ReplicationCommandHandler:
         assert self._server_notices_sender is not None
         await self._server_notices_sender.on_user_ip(cmd.user_id)

+    def on_PERSISTED_TO(self, conn: AbstractConnection, cmd: PersistedToCommand):
+        pass
+
     def on_RDATA(self, conn: AbstractConnection, cmd: RdataCommand):
         if cmd.instance_name == self._instance_name:
             # Ignore RDATA that are just our own echoes

View File

@@ -24,6 +24,7 @@ from twisted.internet.protocol import Factory

 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol
+from synapse.replication.tcp.streams import EventsStream
 from synapse.util.metrics import Measure

 stream_updates_counter = Counter(
@@ -84,6 +85,9 @@ class ReplicationStreamer:
         # Set of streams to replicate.
         self.streams = self.command_handler.get_streams_to_replicate()

+        if self.streams:
+            self.clock.looping_call(self.on_notifier_poke, 1000.0)
+
     def on_notifier_poke(self):
         """Checks if there is actually any new data and sends it to the
         connections if there are.
@@ -126,9 +130,7 @@ class ReplicationStreamer:
                 random.shuffle(all_streams)

             for stream in all_streams:
-                if stream.last_token == stream.current_token(
-                    self._instance_name
-                ):
+                if not stream.has_updates():
                     continue

                 if self._replication_torture_level:
@@ -174,6 +176,11 @@ class ReplicationStreamer:
                     except Exception:
                         logger.exception("Failed to replicate")

+                    # for command in stream.extra_commands(
+                    #     sent_updates=bool(updates)
+                    # ):
+                    #     self.command_handler.send_command(command)
+
             logger.debug("No more pending updates, breaking poke loop")
         finally:
             self.pending_updates = False

View File

@@ -31,6 +31,7 @@ from typing import (
 import attr

 from synapse.replication.http.streams import ReplicationGetStreamUpdates
+from synapse.replication.tcp.commands import Command

 if TYPE_CHECKING:
     import synapse.server
@@ -187,6 +188,12 @@ class Stream:
         )
         return updates, upto_token, limited

+    def has_updates(self) -> bool:
+        return self.current_token(self.local_instance_name) != self.last_token
+
+    def extra_commands(self, sent_updates: bool) -> List[Command]:
+        return []
+

 def current_token_without_instance(
     current_token: Callable[[], int]

View File

@@ -19,7 +19,8 @@ from typing import List, Tuple, Type

 import attr

-from ._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.streams._base import Stream, StreamUpdateResult, Token
+from synapse.replication.tcp.commands import Command, PersistedToCommand

 """Handling of the 'events' replication stream
@@ -222,3 +223,18 @@ class EventsStream(Stream):
         (typ, data) = row
         data = TypeToRow[typ].from_data(data)
         return EventsStreamRow(typ, data)
+
+    def has_updates(self) -> bool:
+        return True
+
+    def extra_commands(self, sent_updates: bool) -> List[Command]:
+        if sent_updates:
+            return []
+
+        return [
+            PersistedToCommand(
+                self.NAME,
+                self.local_instance_name,
+                self._store._stream_id_gen.get_max_persisted_position_for_self(),
+            )
+        ]
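
A hedged sketch of the contract these two overrides set up: the events stream always asks to be polled, and only advertises its persisted position when a poll sent no rows. The stub names (`FakeEventsStream`, `PersistedTo`) are illustrative, not Synapse classes.

from typing import List, NamedTuple


class PersistedTo(NamedTuple):
    stream_name: str
    instance_name: str
    token: int


class FakeEventsStream:
    NAME = "events"

    def __init__(self, instance_name: str, max_persisted: int):
        self.local_instance_name = instance_name
        self._max_persisted = max_persisted

    def has_updates(self) -> bool:
        # Always let the poke loop visit this stream, even when last_token
        # looks current, so extra_commands below gets a chance to run.
        return True

    def extra_commands(self, sent_updates: bool) -> List[PersistedTo]:
        if sent_updates:
            # RDATA already told clients how far this writer has got.
            return []
        return [PersistedTo(self.NAME, self.local_instance_name, self._max_persisted)]


stream = FakeEventsStream("worker1", 1234)
assert stream.extra_commands(sent_updates=True) == []
assert stream.extra_commands(sent_updates=False)[0].token == 1234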

View File

@@ -178,6 +178,8 @@ class PersistEventsStore:
         )
         persist_event_counter.inc(len(events_and_contexts))

+        logger.debug("Finished persisting 1")
+
         if not backfilled:
             # backfilled events have negative stream orderings, so we don't
             # want to set the event_persisted_position to that.
@@ -185,6 +187,8 @@ class PersistEventsStore:
                 events_and_contexts[-1][0].internal_metadata.stream_ordering
             )

+        logger.debug("Finished persisting 2")
+
         for event, context in events_and_contexts:
             if context.app_service:
                 origin_type = "local"
@@ -198,6 +202,8 @@ class PersistEventsStore:
             event_counter.labels(event.type, origin_type, origin_entity).inc()

+        logger.debug("Finished persisting 3")
+
         for room_id, new_state in current_state_for_room.items():
             self.store.get_current_state_ids.prefill((room_id,), new_state)
@@ -206,6 +212,9 @@ class PersistEventsStore:
             (room_id,), list(latest_event_ids)
         )

+        logger.debug("Finished persisting 4")
+
+        logger.debug("Finished persisting 5")

     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
         """Filter the supplied list of event_ids to get those which are prev_events of
         existing (non-outlier/rejected) events.

View File

@@ -35,11 +35,10 @@ what sort order was used:
 - topological tokens: "t%d-%d", where the integers map to the topological
   and stream ordering columns respectively.
 """

 import abc
 import logging
 from collections import namedtuple
-from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple

 from twisted.internet import defer
@@ -54,6 +53,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import Collection, RoomStreamToken
 from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -76,6 +76,18 @@ _EventDictReturn = namedtuple(
 )


+def _filter_result(
+    instance_name: str,
+    stream_id: int,
+    from_token: RoomStreamToken,
+    to_token: RoomStreamToken,
+) -> bool:
+    from_id = from_token.instance_map.get(instance_name, from_token.stream)
+    to_id = to_token.instance_map.get(instance_name, to_token.stream)
+
+    return from_id < stream_id <= to_id
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -209,6 +221,71 @@ def _make_generic_sql_bound(
 )


+def _make_instance_filter_clause(
+    direction: str,
+    from_token: Optional[RoomStreamToken],
+    to_token: Optional[RoomStreamToken],
+) -> Tuple[str, List[Any]]:
+    if from_token and from_token.topological:
+        from_token = None
+
+    if to_token and to_token.topological:
+        to_token = None
+
+    if not from_token and not to_token:
+        return "", []
+
+    from_bound = ">=" if direction == "b" else "<"
+    to_bound = "<" if direction == "b" else ">="
+
+    filter_clauses = []
+    filter_args = []  # type: List[Any]
+
+    from_map = from_token.instance_map if from_token else {}
+    to_map = to_token.instance_map if to_token else {}
+
+    default_from = from_token.stream if from_token else None
+    default_to = to_token.stream if to_token else None
+
+    if default_from and default_to:
+        filter_clauses.append(
+            "(? %s stream_ordering AND ? %s stream_ordering)" % (from_bound, to_bound)
+        )
+        filter_args.extend((default_from, default_to,))
+    elif default_from:
+        filter_clauses.append("(? %s stream_ordering)" % (from_bound,))
+        filter_args.extend((default_from,))
+    elif default_to:
+        filter_clauses.append("(? %s stream_ordering)" % (to_bound,))
+        filter_args.extend((default_to,))
+
+    for instance in set(from_map).union(to_map):
+        from_id = from_map.get(instance, default_from)
+        to_id = to_map.get(instance, default_to)
+
+        if from_id and to_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering AND ? %s stream_ordering)"
+                % (from_bound, to_bound)
+            )
+            filter_args.extend((instance, from_id, to_id,))
+        elif from_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering)" % (from_bound,)
+            )
+            filter_args.extend((instance, from_id,))
+        elif to_id:
+            filter_clauses.append(
+                "(instance_name = ? AND ? %s stream_ordering)" % (to_bound,)
+            )
+            filter_args.extend((instance, to_id,))
+
+    filter_clause = ""
+    if filter_clauses:
+        filter_clause = "(%s)" % (" OR ".join(filter_clauses),)
+
+    return filter_clause, filter_args
+
+
 def filter_to_clause(event_filter: Optional[Filter]) -> Tuple[str, List[str]]:
     # NB: This may create SQL clauses that don't optimise well (and we don't
     # have indices on all possible clauses). E.g. it may create
@@ -305,6 +382,22 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     def get_room_min_stream_ordering(self) -> int:
         raise NotImplementedError()

+    def get_room_max_token(self) -> RoomStreamToken:
+        min_pos = self._stream_id_gen.get_current_token()
+
+        positions = {}
+        if isinstance(self._stream_id_gen, MultiWriterIdGenerator):
+            positions = {
+                i: p
+                for i, p in self._stream_id_gen.get_positions().items()
+                if p >= min_pos
+            }
+
+            if set(positions.values()) == {min_pos}:
+                positions = {}
+
+        return RoomStreamToken(None, min_pos, positions)
+
     async def get_room_events_stream_for_rooms(
         self,
         room_ids: Collection[str],
@@ -402,25 +495,50 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         if from_key == to_key:
             return [], from_key

-        from_id = from_key.stream
-        to_id = to_key.stream
-
-        has_changed = self._events_stream_cache.has_entity_changed(room_id, from_id)
+        has_changed = self._events_stream_cache.has_entity_changed(
+            room_id, from_key.stream
+        )

         if not has_changed:
             return [], from_key

         def f(txn):
-            sql = (
-                "SELECT event_id, stream_ordering FROM events WHERE"
-                " room_id = ?"
-                " AND not outlier"
-                " AND stream_ordering > ? AND stream_ordering <= ?"
-                " ORDER BY stream_ordering %s LIMIT ?"
-            ) % (order,)
-            txn.execute(sql, (room_id, from_id, to_id, limit))
-
-            rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
+            filter_clause, filter_args = _make_instance_filter_clause(
+                "f", from_key, to_key
+            )
+            if filter_clause:
+                filter_clause = " AND " + filter_clause
+
+            min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
+            max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+            sql = """
+                SELECT event_id, instance_name, stream_ordering
+                FROM events
+                WHERE
+                    room_id = ?
+                    AND not outlier
+                    AND stream_ordering > ? AND stream_ordering <= ?
+                    %s
+                ORDER BY stream_ordering %s LIMIT ?
+            """ % (
+                filter_clause,
+                order,
+            )
+
+            args = [room_id, min_from_id, max_to_id]
+            args.extend(filter_args)
+            args.append(limit)
+
+            txn.execute(sql, args)
+
+            # rows = [
+            #     _EventDictReturn(event_id, None, stream_ordering)
+            #     for event_id, instance_name, stream_ordering in txn
+            #     if _filter_result(instance_name, stream_ordering, from_key, to_key)
+            # ]
+            rows = [
+                _EventDictReturn(event_id, None, stream_ordering)
+                for event_id, instance_name, stream_ordering in txn
+            ]
+
             return rows

         rows = await self.db_pool.runInteraction("get_room_events_stream_for_room", f)
@@ -429,7 +547,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             [r.event_id for r in rows], get_prev_content=True
         )

-        self._set_before_and_after(ret, rows, topo_order=from_id is None)
+        self._set_before_and_after(ret, rows, topo_order=from_key.stream is None)

         if order.lower() == "desc":
             ret.reverse()
@@ -446,29 +564,40 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
     async def get_membership_changes_for_user(
         self, user_id: str, from_key: RoomStreamToken, to_key: RoomStreamToken
     ) -> List[EventBase]:
-        from_id = from_key.stream
-        to_id = to_key.stream
-
         if from_key == to_key:
             return []

-        if from_id:
+        if from_key:
             has_changed = self._membership_stream_cache.has_entity_changed(
-                user_id, int(from_id)
+                user_id, int(from_key.stream)
             )
             if not has_changed:
                 return []

         def f(txn):
-            sql = (
-                "SELECT m.event_id, stream_ordering FROM events AS e,"
-                " room_memberships AS m"
-                " WHERE e.event_id = m.event_id"
-                " AND m.user_id = ?"
-                " AND e.stream_ordering > ? AND e.stream_ordering <= ?"
-                " ORDER BY e.stream_ordering ASC"
-            )
-            txn.execute(sql, (user_id, from_id, to_id))
+            filter_clause, filter_args = _make_instance_filter_clause(
+                "f", from_key, to_key
+            )
+            if filter_clause:
+                filter_clause = " AND " + filter_clause
+
+            min_from_id = min(from_key.instance_map.values(), default=from_key.stream)
+            max_to_id = max(to_key.instance_map.values(), default=to_key.stream)
+
+            sql = """
+                SELECT m.event_id, stream_ordering
+                FROM events AS e, room_memberships AS m
+                WHERE e.event_id = m.event_id
+                    AND m.user_id = ?
+                    AND e.stream_ordering > ? AND e.stream_ordering <= ?
+                    %s
+                ORDER BY e.stream_ordering ASC
+            """ % (
+                filter_clause,
+            )
+
+            args = [user_id, min_from_id, max_to_id]
+            args.extend(filter_args)
+
+            txn.execute(sql, args)

             rows = [_EventDictReturn(row[0], None, row[1]) for row in txn]
@@ -975,11 +1104,39 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
         else:
             order = "ASC"

+        if from_token.topological is not None:
+            from_bound = from_token.as_tuple()
+        elif direction == "b":
+            from_bound = (
+                None,
+                max(from_token.instance_map.values(), default=from_token.stream),
+            )
+        else:
+            from_bound = (
+                None,
+                min(from_token.instance_map.values(), default=from_token.stream),
+            )
+
+        to_bound = None
+        if to_token:
+            if to_token.topological is not None:
+                to_bound = to_token.as_tuple()
+            elif direction == "b":
+                to_bound = (
+                    None,
+                    min(to_token.instance_map.values(), default=to_token.stream),
+                )
+            else:
+                to_bound = (
+                    None,
+                    max(to_token.instance_map.values(), default=to_token.stream),
+                )
+
         bounds = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.as_tuple(),
-            to_token=to_token.as_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
@@ -989,6 +1146,13 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore, metaclass=abc.ABCMeta):
             bounds += " AND " + filter_clause
             args.extend(filter_args)

+        stream_filter_clause, stream_filter_args = _make_instance_filter_clause(
+            direction, from_token, to_token
+        )
+        if stream_filter_clause:
+            bounds += " AND " + stream_filter_clause
+            args.extend(stream_filter_args)
+
         args.append(int(limit))

         select_keywords = "SELECT"
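
The range semantics behind `_filter_result` and `_make_instance_filter_clause` above can be seen in miniature below: a multi-writer token maps each writer to a position, falling back to the token's default stream position, and an event is in scope when it lies in the half-open range (from, to] for its own writer. A sketch with plain tuples and dicts standing in for `RoomStreamToken` (worker names are made up):

def in_range(instance: str, stream_id: int, from_tok, to_tok) -> bool:
    from_default, from_map = from_tok
    to_default, to_map = to_tok
    from_id = from_map.get(instance, from_default)
    to_id = to_map.get(instance, to_default)
    # Events strictly after the from token, and at or before the to token.
    return from_id < stream_id <= to_id


from_tok = (10, {"worker1": 12})  # worker1 is ahead of the default position
to_tok = (20, {"worker2": 15})    # worker2 lags behind the default position

assert in_range("worker1", 13, from_tok, to_tok)      # 12 < 13 <= 20
assert not in_range("worker1", 12, from_tok, to_tok)  # already covered by from
assert not in_range("worker2", 16, from_tok, to_tok)  # past worker2's bound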

View File

@@ -229,7 +229,7 @@ class EventsPersistenceStorage:
             defer.gatherResults(deferreds, consumeErrors=True)
         )

-        return RoomStreamToken(None, self.main_store.get_current_events_token())
+        return self.main_store.get_room_max_token()

     async def persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
@@ -247,11 +247,10 @@ class EventsPersistenceStorage:
         await make_deferred_yieldable(deferred)

-        max_persisted_id = self.main_store.get_current_events_token()
         event_stream_id = event.internal_metadata.stream_ordering

         pos = PersistedEventPosition(self._instance_name, event_stream_id)
-        return pos, RoomStreamToken(None, max_persisted_id)
+        return pos, self.main_store.get_room_max_token()

     def _maybe_start_persisting(self, room_id: str):
         async def persisting_queue(item):

View File

@@ -217,6 +217,7 @@ class MultiWriterIdGenerator:
         self._instance_name = instance_name
         self._positive = positive
         self._writers = writers
+        self._sequence_name = sequence_name
         self._return_factor = 1 if positive else -1

         # We lock as some functions may be called from DB threads.
@@ -227,6 +228,8 @@ class MultiWriterIdGenerator:
         # return them.
         self._current_positions = {}  # type: Dict[str, int]

+        self._max_persisted_positions = dict(self._current_positions)
+
         # Set of local IDs that we're still processing. The current position
         # should be less than the minimum of this set (if not empty).
         self._unfinished_ids = set()  # type: Set[int]
@@ -404,6 +407,12 @@ class MultiWriterIdGenerator:
         current position if possible.
         """

+        logger.debug(
+            "Mark as finished 1 _current_positions %s: %s",
+            self._sequence_name,
+            self._current_positions,
+        )
+
         with self._lock:
             self._unfinished_ids.discard(next_id)
             self._finished_ids.add(next_id)
@@ -439,6 +448,16 @@ class MultiWriterIdGenerator:
             if new_cur:
                 curr = self._current_positions.get(self._instance_name, 0)
                 self._current_positions[self._instance_name] = max(curr, new_cur)

+                self._max_persisted_positions[self._instance_name] = max(
+                    self._current_positions[self._instance_name],
+                    self._max_persisted_positions.get(self._instance_name, 0),
+                )
+
+            logger.debug(
+                "Mark as finished _current_positions %s: %s",
+                self._sequence_name,
+                self._current_positions,
+            )
+
             self._add_persisted_position(next_id)
@@ -454,6 +473,11 @@ class MultiWriterIdGenerator:
         """
         with self._lock:
+            logger.debug(
+                "get_current_token_for_writer %s: %s",
+                self._sequence_name,
+                self._current_positions,
+            )
             return self._return_factor * self._current_positions.get(instance_name, 0)

     def get_positions(self) -> Dict[str, int]:
@@ -478,6 +502,12 @@ class MultiWriterIdGenerator:
                 new_id, self._current_positions.get(instance_name, 0)
             )

+            self._max_persisted_positions[instance_name] = max(
+                new_id,
+                self._current_positions.get(instance_name, 0),
+                self._max_persisted_positions.get(instance_name, 0),
+            )
+
             self._add_persisted_position(new_id)

     def get_persisted_upto_position(self) -> int:
@@ -492,10 +522,29 @@ class MultiWriterIdGenerator:
         with self._lock:
             return self._return_factor * self._persisted_upto_position

+    def get_max_persisted_position_for_self(self) -> int:
+        with self._lock:
+            if self._unfinished_ids:
+                return self.get_current_token_for_writer(self._instance_name)
+
+            return self._return_factor * max(
+                self._current_positions.values(), default=1
+            )
+
+    def advance_persisted_to(self, instance_name: str, new_id: int):
+        new_id *= self._return_factor
+
+        with self._lock:
+            self._max_persisted_positions[instance_name] = max(
+                new_id,
+                self._current_positions.get(instance_name, 0),
+                self._max_persisted_positions.get(instance_name, 0),
+            )
+
     def _add_persisted_position(self, new_id: int):
         """Record that we have persisted a position.

-        This is used to keep the `_current_positions` up to date.
+        This is used to keep the `_persisted_upto_position` up to date.
         """

         # We require that the lock is locked by caller
@@ -506,7 +555,7 @@ class MultiWriterIdGenerator:
         # We move the current min position up if the minimum current positions
         # of all instances is higher (since by definition all positions less
         # than that have been persisted).
-        min_curr = min(self._current_positions.values(), default=0)
+        min_curr = min(self._max_persisted_positions.values(), default=0)
         self._persisted_upto_position = max(min_curr, self._persisted_upto_position)

         # We now iterate through the seen positions, discarding those that are
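
A toy model (not the real `MultiWriterIdGenerator`) of the bookkeeping change above: `_persisted_upto_position` now advances to the minimum over each writer's maximum persisted position, and never moves backwards. Writer names are made up; positions are seeded at zero, mirroring how `_max_persisted_positions` starts as a copy of `_current_positions`:

# instance_name -> highest position known to be persisted by that writer.
max_persisted_positions = {"worker1": 0, "worker2": 0}
persisted_upto = 0


def advance(instance: str, new_id: int) -> int:
    global persisted_upto
    max_persisted_positions[instance] = max(
        new_id, max_persisted_positions.get(instance, 0)
    )
    # Everything at or below the slowest writer's max is definitely persisted.
    min_curr = min(max_persisted_positions.values(), default=0)
    persisted_upto = max(min_curr, persisted_upto)
    return persisted_upto


advance("worker1", 5)
assert persisted_upto == 0  # worker2 has not reported anything yet
advance("worker2", 3)
assert persisted_upto == 3  # bounded by worker2, the laggard
advance("worker2", 9)
assert persisted_upto == 5  # now bounded by worker1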

View File

@@ -21,8 +21,9 @@ from collections import namedtuple
 from typing import Any, Dict, Mapping, MutableMapping, Optional, Tuple, Type, TypeVar

 import attr
+import cbor2
 from signedjson.key import decode_verify_key_bytes
-from unpaddedbase64 import decode_base64
+from unpaddedbase64 import decode_base64, encode_base64

 from synapse.api.errors import Codes, SynapseError
@@ -362,7 +363,7 @@ def map_username_to_mxid_localpart(username, case_sensitive=False):
     return username.decode("ascii")


-@attr.s(frozen=True, slots=True)
+@attr.s(frozen=True, slots=True, cmp=False)
 class RoomStreamToken:
     """Tokens are positions between events. The token "s1" comes after event 1.
@@ -392,6 +393,8 @@ class RoomStreamToken:
     )

     stream = attr.ib(type=int, validator=attr.validators.instance_of(int))

+    instance_map = attr.ib(type=Dict[str, int], factory=dict)
+
     @classmethod
     def parse(cls, string: str) -> "RoomStreamToken":
         try:
@@ -400,6 +403,11 @@ class RoomStreamToken:
             if string[0] == "t":
                 parts = string[1:].split("-", 1)
                 return cls(topological=int(parts[0]), stream=int(parts[1]))
+            if string[0] == "m":
+                payload = cbor2.loads(decode_base64(string[1:]))
+                return cls(
+                    topological=None, stream=payload["s"], instance_map=payload["p"],
+                )
         except Exception:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))
@@ -413,15 +421,49 @@ class RoomStreamToken:
             pass
         raise SynapseError(400, "Invalid token %r" % (string,))

+    def copy_and_advance(self, other: "RoomStreamToken") -> "RoomStreamToken":
+        if self.topological or other.topological:
+            raise Exception("Can't advance topological tokens")
+
+        max_stream = max(self.stream, other.stream)
+
+        instance_map = {
+            instance: max(
+                self.instance_map.get(instance, self.stream),
+                other.instance_map.get(instance, other.stream),
+            )
+            for instance in set(self.instance_map).union(other.instance_map)
+        }
+
+        return RoomStreamToken(None, max_stream, instance_map)
+
     def as_tuple(self) -> Tuple[Optional[int], int]:
         return (self.topological, self.stream)

     def __str__(self) -> str:
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
+        elif self.instance_map:
+            return "m" + encode_base64(
+                cbor2.dumps({"s": self.stream, "p": self.instance_map}),
+            )
         else:
             return "s%d" % (self.stream,)

+    def __lt__(self, other: "RoomStreamToken"):
+        if self.stream != other.stream:
+            return self.stream < other.stream
+
+        for instance in set(self.instance_map).union(other.instance_map):
+            if self.instance_map.get(instance, self.stream) != other.instance_map.get(
+                instance, other.stream
+            ):
+                return self.instance_map.get(
+                    instance, self.stream
+                ) < other.instance_map.get(instance, other.stream)
+
+        return False
+

 @attr.s(slots=True, frozen=True)
 class StreamToken:
@@ -461,7 +503,7 @@ class StreamToken:
     def is_after(self, other):
         """Does this token contain events that the other doesn't?"""
         return (
-            (other.room_stream_id < self.room_stream_id)
+            (other.room_key < self.room_key)
             or (int(other.presence_key) < int(self.presence_key))
             or (int(other.typing_key) < int(self.typing_key))
             or (int(other.receipt_key) < int(self.receipt_key))
@@ -476,13 +518,16 @@ class StreamToken:
         """Advance the given key in the token to a new value if and only if the
         new value is after the old value.
         """
-        new_token = self.copy_and_replace(key, new_value)
         if key == "room_key":
-            new_id = new_token.room_stream_id
-            old_id = self.room_stream_id
-        else:
+            new_token = self.copy_and_replace(
+                "room_key", self.room_key.copy_and_advance(new_value)
+            )
+            return new_token
+
+        new_token = self.copy_and_replace(key, new_value)
+        new_id = int(getattr(new_token, key))
+        old_id = int(getattr(self, key))

         if old_id < new_id:
             return new_token
         else:
@@ -507,7 +552,7 @@ class PersistedEventPosition:
     stream = attr.ib(type=int)

     def persisted_after(self, token: RoomStreamToken) -> bool:
-        return token.stream < self.stream
+        return token.instance_map.get(self.instance_name, token.stream) < self.stream


 class ThirdPartyInstanceID(
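
Finally, a standalone sketch of the new "m" token wire format introduced above: the default stream position and the per-instance map are CBOR-encoded (hence the new `cbor2` dependency) and serialised with unpadded base64. The helper names here are hypothetical; the "s"/"p" payload keys are the ones used in the diff. Requires the `cbor2` and `unpaddedbase64` packages:

import cbor2
from unpaddedbase64 import decode_base64, encode_base64


def encode_token(stream: int, instance_map: dict) -> str:
    # Single-writer tokens keep the old "s<stream>" form.
    if not instance_map:
        return "s%d" % (stream,)
    return "m" + encode_base64(cbor2.dumps({"s": stream, "p": instance_map}))


def decode_token(string: str) -> tuple:
    if string[0] == "s":
        return int(string[1:]), {}
    if string[0] == "m":
        payload = cbor2.loads(decode_base64(string[1:]))
        return payload["s"], payload["p"]
    raise ValueError("Invalid token %r" % (string,))


token = encode_token(100, {"worker1": 105, "worker2": 103})
assert token[0] == "m"
assert decode_token(token) == (100, {"worker1": 105, "worker2": 103})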