Compare commits
5 Commits
v1.99.0
...
erikj/shar
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2ecc2200f3 | ||
|
|
e894f67509 | ||
|
|
ab73bf1c85 | ||
|
|
98b7415fcd | ||
|
|
c4ca1de6a0 |
@@ -832,11 +832,28 @@ class ShardedWorkerHandlingConfig:
|
||||
def should_handle(self, instance_name: str, key: str) -> bool:
|
||||
"""Whether this instance is responsible for handling the given key.
|
||||
"""
|
||||
|
||||
# If multiple instances are not defined we always return true.
|
||||
# If multiple instances are not defined we always return true
|
||||
if not self.instances or len(self.instances) == 1:
|
||||
return True
|
||||
|
||||
return self.get_instance(key) == instance_name
|
||||
|
||||
def get_instance(self, key: str) -> str:
|
||||
"""Get the instance responsible for handling the given key.
|
||||
|
||||
Note: For things like federation sending the config for which instance
|
||||
is sending is known only to the sender instance if there is only one.
|
||||
Therefore `should_handle` should be used where possible.
|
||||
"""
|
||||
|
||||
# Note: For things like federation sending the config for which instance
|
||||
# is sending is known only to the sender instance if there is only one.
|
||||
if not self.instances:
|
||||
return "master"
|
||||
|
||||
if len(self.instances) == 1:
|
||||
return self.instances[0]
|
||||
|
||||
# We shard by taking the hash, modulo it by the number of instances and
|
||||
# then checking whether this instance matches the instance at that
|
||||
# index.
|
||||
@@ -846,7 +863,7 @@ class ShardedWorkerHandlingConfig:
|
||||
dest_hash = sha256(key.encode("utf8")).digest()
|
||||
dest_int = int.from_bytes(dest_hash, byteorder="little")
|
||||
remainder = dest_int % (len(self.instances))
|
||||
return self.instances[remainder] == instance_name
|
||||
return self.instances[remainder]
|
||||
|
||||
|
||||
__all__ = ["Config", "RootConfig", "ShardedWorkerHandlingConfig"]
|
||||
|
||||
@@ -142,3 +142,4 @@ class ShardedWorkerHandlingConfig:
|
||||
instances: List[str]
|
||||
def __init__(self, instances: List[str]) -> None: ...
|
||||
def should_handle(self, instance_name: str, key: str) -> bool: ...
|
||||
def get_instance(self, key: str) -> str: ...
|
||||
|
||||
@@ -13,12 +13,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Union
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Config, ConfigError, ShardedWorkerHandlingConfig
|
||||
from .server import ListenerConfig, parse_listener_def
|
||||
|
||||
|
||||
def _instance_to_list_converter(obj: Union[str, List[str]]) -> List[str]:
|
||||
"""Helper for allowing parsing a string or list of strings to a config
|
||||
option expecting a list of strings.
|
||||
"""
|
||||
|
||||
if isinstance(obj, str):
|
||||
return [obj]
|
||||
return obj
|
||||
|
||||
|
||||
@attr.s
|
||||
class InstanceLocationConfig:
|
||||
"""The host and port to talk to an instance via HTTP replication.
|
||||
@@ -33,11 +45,13 @@ class WriterLocations:
|
||||
"""Specifies the instances that write various streams.
|
||||
|
||||
Attributes:
|
||||
events: The instance that writes to the event and backfill streams.
|
||||
events: The instances that write to the event and backfill streams.
|
||||
events: The instance that writes to the typing stream.
|
||||
"""
|
||||
|
||||
events = attr.ib(default="master", type=str)
|
||||
events = attr.ib(
|
||||
default=["master"], type=List[str], converter=_instance_to_list_converter
|
||||
)
|
||||
typing = attr.ib(default="master", type=str)
|
||||
|
||||
|
||||
@@ -108,12 +122,15 @@ class WorkerConfig(Config):
|
||||
# Check that the configured writer for events and typing also appears in
|
||||
# `instance_map`.
|
||||
for stream in ("events", "typing"):
|
||||
instance = getattr(self.writers, stream)
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
instances = _instance_to_list_converter(getattr(self.writers, stream))
|
||||
for instance in instances:
|
||||
if instance != "master" and instance not in self.instance_map:
|
||||
raise ConfigError(
|
||||
"Instance %r is configured to write %s but does not appear in `instance_map` config."
|
||||
% (instance, stream)
|
||||
)
|
||||
|
||||
self.events_shard_config = ShardedWorkerHandlingConfig(self.writers.events)
|
||||
|
||||
def generate_config_section(self, config_dir_path, server_name, **kwargs):
|
||||
return """\
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
|
||||
"""Contains handlers for federation events."""
|
||||
|
||||
import collections
|
||||
import itertools
|
||||
import logging
|
||||
from collections.abc import Container
|
||||
@@ -917,7 +918,8 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
)
|
||||
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
if ev_infos:
|
||||
await self._handle_new_events(dest, ev_infos, backfilled=True)
|
||||
|
||||
# Step 2: Persist the rest of the events in the chunk one by one
|
||||
events.sort(key=lambda e: e.depth)
|
||||
@@ -1365,12 +1367,25 @@ class FederationHandler(BaseHandler):
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.config.worker.writers.events, "events", max_stream_id
|
||||
self.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
max_stream_id,
|
||||
)
|
||||
|
||||
predecessor = None
|
||||
for s in state:
|
||||
if s.type == EventTypes.Create:
|
||||
predecessor = s.content.get("predecessor", None)
|
||||
|
||||
# Ensure the key is a dictionary
|
||||
if not isinstance(predecessor, collections.abc.Mapping):
|
||||
predecessor = None
|
||||
|
||||
break
|
||||
|
||||
# Check whether this room is the result of an upgrade of a room we already know
|
||||
# about. If so, migrate over user information
|
||||
predecessor = await self.store.get_room_predecessor(room_id)
|
||||
# predecessor = await self.store.get_room_predecessor(room_id)
|
||||
if not predecessor or not isinstance(predecessor.get("room_id"), str):
|
||||
return event.event_id, max_stream_id
|
||||
old_room_id = predecessor["room_id"]
|
||||
@@ -2904,9 +2919,13 @@ class FederationHandler(BaseHandler):
|
||||
backfilled: Whether these events are a result of
|
||||
backfilling or not
|
||||
"""
|
||||
if self.config.worker.writers.events != self._instance_name:
|
||||
# FIXME:
|
||||
instance = self.config.worker.events_shard_config.get_instance(
|
||||
event_and_contexts[0][0].room_id
|
||||
)
|
||||
if instance != self._instance_name:
|
||||
result = await self._send_events(
|
||||
instance_name=self.config.worker.writers.events,
|
||||
instance_name=instance,
|
||||
store=self.store,
|
||||
event_and_contexts=event_and_contexts,
|
||||
backfilled=backfilled,
|
||||
|
||||
@@ -377,9 +377,8 @@ class EventCreationHandler(object):
|
||||
self.notifier = hs.get_notifier()
|
||||
self.config = hs.config
|
||||
self.require_membership_for_aliases = hs.config.require_membership_for_aliases
|
||||
self._is_event_writer = (
|
||||
self.config.worker.writers.events == hs.get_instance_name()
|
||||
)
|
||||
self._events_shard_config = self.config.worker.events_shard_config
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self.room_invite_state_types = self.hs.config.room_invite_state_types
|
||||
|
||||
@@ -874,9 +873,10 @@ class EventCreationHandler(object):
|
||||
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
if not self._is_event_writer:
|
||||
writer_instance = self._events_shard_config.get_instance(event.room_id)
|
||||
if writer_instance != self._instance_name:
|
||||
result = await self.send_event(
|
||||
instance_name=self.config.worker.writers.events,
|
||||
instance_name=writer_instance,
|
||||
event_id=event.event_id,
|
||||
store=self.store,
|
||||
requester=requester,
|
||||
@@ -944,7 +944,9 @@ class EventCreationHandler(object):
|
||||
|
||||
This should only be run on the instance in charge of persisting events.
|
||||
"""
|
||||
assert self._is_event_writer
|
||||
assert self._events_shard_config.should_handle(
|
||||
self._instance_name, event.room_id
|
||||
)
|
||||
|
||||
if ratelimit:
|
||||
# We check if this is a room admin redacting an event so that we
|
||||
|
||||
@@ -776,7 +776,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
# Always wait for room creation to progate before returning
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", last_stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
last_stream_id,
|
||||
)
|
||||
|
||||
return result, last_stream_id
|
||||
@@ -1233,7 +1235,9 @@ class RoomShutdownHandler(object):
|
||||
#
|
||||
# TODO: Currently the events stream is written to from master
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(new_room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
)
|
||||
else:
|
||||
new_room_id = None
|
||||
@@ -1263,7 +1267,9 @@ class RoomShutdownHandler(object):
|
||||
|
||||
# Wait for leave to come in over replication before trying to forget.
|
||||
await self._replication.wait_for_stream_position(
|
||||
self.hs.config.worker.writers.events, "events", stream_id
|
||||
self.hs.config.worker.events_shard_config.get_instance(room_id),
|
||||
"events",
|
||||
stream_id,
|
||||
)
|
||||
|
||||
await self.room_member_handler.forget(target_requester.user, room_id)
|
||||
|
||||
@@ -75,13 +75,6 @@ class RoomMemberHandler(object):
|
||||
self._enable_lookup = hs.config.enable_3pid_lookup
|
||||
self.allow_per_room_profiles = self.config.allow_per_room_profiles
|
||||
|
||||
self._event_stream_writer_instance = hs.config.worker.writers.events
|
||||
self._is_on_event_persistence_instance = (
|
||||
self._event_stream_writer_instance == hs.get_instance_name()
|
||||
)
|
||||
if self._is_on_event_persistence_instance:
|
||||
self.persist_event_storage = hs.get_storage().persistence
|
||||
|
||||
self._join_rate_limiter_local = Ratelimiter(
|
||||
clock=self.clock,
|
||||
rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
|
||||
|
||||
@@ -80,4 +80,4 @@ class SlavedEventStore(
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_room_min_stream_ordering(self):
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
return -self._backfill_id_gen.get_current_token()
|
||||
|
||||
@@ -109,7 +109,7 @@ class ReplicationCommandHandler:
|
||||
if isinstance(stream, (EventsStream, BackfillStream)):
|
||||
# Only add EventStream and BackfillStream as a source on the
|
||||
# instance in charge of event persistence.
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
self._streams_to_replicate.append(stream)
|
||||
|
||||
continue
|
||||
|
||||
@@ -19,7 +19,7 @@ from typing import List, Tuple, Type
|
||||
|
||||
import attr
|
||||
|
||||
from ._base import Stream, StreamUpdateResult, Token, current_token_without_instance
|
||||
from ._base import Stream, StreamUpdateResult, Token
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
@@ -117,7 +117,7 @@ class EventsStream(Stream):
|
||||
self._store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
current_token_without_instance(self._store.get_current_events_token),
|
||||
self._store._stream_id_gen.get_current_token_for_writer,
|
||||
self._update_function,
|
||||
)
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class Databases(object):
|
||||
|
||||
# If we're on a process that can persist events also
|
||||
# instantiate a `PersistEventsStore`
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
persist_events = PersistEventsStore(hs, database, main)
|
||||
|
||||
if "state" in database_config.databases:
|
||||
|
||||
@@ -440,7 +440,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
"""
|
||||
|
||||
if stream_ordering <= self.stream_ordering_month_ago:
|
||||
raise StoreError(400, "stream_ordering too old")
|
||||
raise StoreError(400, "stream_ordering too old %s" % (stream_ordering,))
|
||||
|
||||
sql = """
|
||||
SELECT event_id FROM stream_ordering_to_exterm
|
||||
|
||||
@@ -34,6 +34,7 @@ from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.databases.main.search import SearchEntry
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import StateMap, get_domain_from_id
|
||||
from synapse.util.async_helpers import maybe_awaitable
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.iterutils import batch_iter
|
||||
|
||||
@@ -97,6 +98,7 @@ class PersistEventsStore:
|
||||
self.store = main_data_store
|
||||
self.database_engine = db.engine
|
||||
self._clock = hs.get_clock()
|
||||
self._instance_name = hs.get_instance_name()
|
||||
|
||||
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -108,7 +110,7 @@ class PersistEventsStore:
|
||||
|
||||
# This should only exist on instances that are configured to write
|
||||
assert (
|
||||
hs.config.worker.writers.events == hs.get_instance_name()
|
||||
hs.get_instance_name() in hs.config.worker.writers.events
|
||||
), "Can only instantiate EventsStore on master"
|
||||
|
||||
async def _persist_events_and_state_updates(
|
||||
@@ -153,15 +155,18 @@ class PersistEventsStore:
|
||||
# Note: Multiple instances of this function cannot be in flight at
|
||||
# the same time for the same room.
|
||||
if backfilled:
|
||||
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
stream_ordering_manager = await maybe_awaitable(
|
||||
self._backfill_id_gen.get_next_mult(len(events_and_contexts))
|
||||
)
|
||||
else:
|
||||
stream_ordering_manager = self._stream_id_gen.get_next_mult(
|
||||
len(events_and_contexts)
|
||||
stream_ordering_manager = await maybe_awaitable(
|
||||
self._stream_id_gen.get_next_mult(len(events_and_contexts))
|
||||
)
|
||||
|
||||
with stream_ordering_manager as stream_orderings:
|
||||
if backfilled:
|
||||
stream_orderings = [-s for s in stream_orderings]
|
||||
|
||||
for (event, context), stream in zip(events_and_contexts, stream_orderings):
|
||||
event.internal_metadata.stream_ordering = stream
|
||||
|
||||
@@ -800,6 +805,7 @@ class PersistEventsStore:
|
||||
table="events",
|
||||
values=[
|
||||
{
|
||||
"instance_name": self._instance_name,
|
||||
"stream_ordering": event.internal_metadata.stream_ordering,
|
||||
"topological_ordering": event.depth,
|
||||
"depth": event.depth,
|
||||
|
||||
@@ -37,12 +37,12 @@ from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import BackfillStream
|
||||
from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.storage.engines import PostgresEngine
|
||||
from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
@@ -78,26 +78,60 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
def __init__(self, database: DatabasePool, db_conn, hs):
|
||||
super(EventsWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
if hs.config.worker.writers.events == hs.get_instance_name():
|
||||
if hs.get_instance_name() in hs.config.worker.writers.events:
|
||||
# We are the process in charge of generating stream ids for events,
|
||||
# so instantiate ID generators based on the database
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
if isinstance(database.engine, PostgresEngine):
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_stream_seq",
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="-stream_ordering",
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
)
|
||||
else:
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
)
|
||||
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn,
|
||||
"events",
|
||||
"-stream_ordering",
|
||||
# extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
else:
|
||||
# Another process is in charge of persisting events and generating
|
||||
# stream IDs: rely on the replication streams to let us know which
|
||||
# IDs we can process.
|
||||
self._stream_id_gen = SlavedIdTracker(db_conn, "events", "stream_ordering")
|
||||
self._backfill_id_gen = SlavedIdTracker(
|
||||
db_conn, "events", "stream_ordering", step=-1
|
||||
self._stream_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_ordering",
|
||||
sequence_name="events_stream_seq",
|
||||
)
|
||||
self._backfill_id_gen = MultiWriterIdGenerator(
|
||||
db_conn=db_conn,
|
||||
db=database,
|
||||
instance_name=hs.get_instance_name(),
|
||||
table="events",
|
||||
instance_column="instance_name",
|
||||
id_column="-stream_ordering",
|
||||
sequence_name="events_backfill_stream_seq",
|
||||
)
|
||||
|
||||
self._get_event_cache = Cache(
|
||||
@@ -113,9 +147,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(token)
|
||||
self._stream_id_gen.advance(instance_name, token)
|
||||
elif stream_name == BackfillStream.NAME:
|
||||
self._backfill_id_gen.advance(-token)
|
||||
self._backfill_id_gen.advance(instance_name, token)
|
||||
|
||||
super().process_replication_rows(stream_name, instance_name, token, rows)
|
||||
|
||||
@@ -966,7 +1000,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
def get_current_backfill_token(self):
|
||||
"""The current minimum token that backfilled events have reached"""
|
||||
return -self._backfill_id_gen.get_current_token()
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
|
||||
def get_current_events_token(self):
|
||||
"""The current maximum token that events have reached"""
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
ALTER TABLE events ADD COLUMN instance_name TEXT;
|
||||
@@ -0,0 +1,26 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_stream_seq;
|
||||
|
||||
SELECT setval('events_stream_seq', (
|
||||
SELECT COALESCE(MAX(stream_ordering), 1) FROM events
|
||||
));
|
||||
|
||||
CREATE SEQUENCE IF NOT EXISTS events_backfill_stream_seq;
|
||||
|
||||
SELECT setval('events_backfill_stream_seq', (
|
||||
SELECT COALESCE(-MIN(stream_ordering), 1) FROM events
|
||||
));
|
||||
@@ -1084,4 +1084,4 @@ class StreamStore(StreamWorkerStore):
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_room_min_stream_ordering(self) -> int:
|
||||
return self._backfill_id_gen.get_current_token()
|
||||
return -self._backfill_id_gen.get_current_token()
|
||||
|
||||
@@ -14,15 +14,19 @@
|
||||
# limitations under the License.
|
||||
|
||||
import contextlib
|
||||
import heapq
|
||||
import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set
|
||||
from typing import Dict, List, Set
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
from synapse.storage.database import DatabasePool, LoggingTransaction
|
||||
from synapse.storage.util.sequence import PostgresSequenceGenerator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IdGenerator(object):
|
||||
def __init__(self, db_conn, table, column):
|
||||
@@ -145,7 +149,7 @@ class StreamIdGenerator(object):
|
||||
|
||||
return manager()
|
||||
|
||||
def get_current_token(self):
|
||||
def get_current_token(self, instance_name=None):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
|
||||
@@ -198,6 +202,7 @@ class MultiWriterIdGenerator:
|
||||
):
|
||||
self._db = db
|
||||
self._instance_name = instance_name
|
||||
self._sequence_name = sequence_name
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
self._lock = threading.Lock()
|
||||
@@ -210,6 +215,15 @@ class MultiWriterIdGenerator:
|
||||
# should be less than the minimum of this set (if not empty).
|
||||
self._unfinished_ids = set() # type: Set[int]
|
||||
|
||||
# We track the max position where we know everything before has been
|
||||
# persisted. This is done by a) looking at the min across all instances
|
||||
# and b) noting that if we have seen a run of persisted positions
|
||||
# without gaps (e.g. 5, 6, 7) then we can skip forward (e.g. to 7).
|
||||
self._persisted_upto_position = (
|
||||
min(self._current_positions.values()) if self._current_positions else 1
|
||||
)
|
||||
self._known_persisted_positions = [] # type: List[int]
|
||||
|
||||
self._sequence_gen = PostgresSequenceGenerator(sequence_name)
|
||||
|
||||
def _load_current_ids(
|
||||
@@ -234,9 +248,12 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return current_positions
|
||||
|
||||
def _load_next_id_txn(self, txn):
|
||||
def _load_next_id_txn(self, txn) -> int:
|
||||
return self._sequence_gen.get_next_id_txn(txn)
|
||||
|
||||
def _load_next_mult_id_txn(self, txn, n: int) -> List[int]:
|
||||
return self._sequence_gen.get_next_mult_txn(txn, n)
|
||||
|
||||
async def get_next(self):
|
||||
"""
|
||||
Usage:
|
||||
@@ -262,6 +279,34 @@ class MultiWriterIdGenerator:
|
||||
|
||||
return manager()
|
||||
|
||||
async def get_next_mult(self, n: int):
|
||||
"""
|
||||
Usage:
|
||||
with await stream_id_gen.get_next_mult(5) as stream_ids:
|
||||
# ... persist event ...
|
||||
"""
|
||||
next_ids = await self._db.runInteraction(
|
||||
"_load_next_mult_id", self._load_next_mult_id_txn, n
|
||||
)
|
||||
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token() < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
|
||||
return manager()
|
||||
|
||||
def get_next_txn(self, txn: LoggingTransaction):
|
||||
"""
|
||||
Usage:
|
||||
@@ -301,7 +346,7 @@ class MultiWriterIdGenerator:
|
||||
|
||||
# Currently we don't support this operation, as it's not obvious how to
|
||||
# condense the stream positions of multiple writers into a single int.
|
||||
raise NotImplementedError()
|
||||
return self.get_persisted_upto_position()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
@@ -326,3 +371,53 @@ class MultiWriterIdGenerator:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
)
|
||||
|
||||
self._add_persisted_position(new_id)
|
||||
|
||||
def get_persisted_upto_position(self) -> int:
|
||||
"""Get the max position where all previous positions have been
|
||||
persisted.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
This is used to keep the `_current_positions` up to date.
|
||||
"""
|
||||
|
||||
# We require that the lock is locked by caller
|
||||
assert self._lock.locked()
|
||||
|
||||
heapq.heappush(self._known_persisted_positions, new_id)
|
||||
|
||||
# We move the current min position up if the minimum current positions
|
||||
# of all instances is higher (since by definition all positions less
|
||||
# that that have been persisted).
|
||||
min_curr = min(self._current_positions.values())
|
||||
self._persisted_upto_position = max(min_curr, self._persisted_upto_position)
|
||||
|
||||
# We now iterate through the seen positions, discarding those that are
|
||||
# less than the current min positions, and incrementing the min position
|
||||
# if its exactly one greater.
|
||||
while self._known_persisted_positions:
|
||||
if self._known_persisted_positions[0] <= self._persisted_upto_position:
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
elif (
|
||||
self._known_persisted_positions[0] == self._persisted_upto_position + 1
|
||||
):
|
||||
heapq.heappop(self._known_persisted_positions)
|
||||
self._persisted_upto_position += 1
|
||||
else:
|
||||
# There was a gap in seen positions, so there is nothing more to
|
||||
# do.
|
||||
break
|
||||
|
||||
logger.info(
|
||||
"Got new_id %s: %s, setting persited pos to %s",
|
||||
self._sequence_name,
|
||||
new_id,
|
||||
self._persisted_upto_position,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import threading
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, List, Optional
|
||||
|
||||
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
|
||||
from synapse.storage.types import Cursor
|
||||
@@ -39,6 +39,12 @@ class PostgresSequenceGenerator(SequenceGenerator):
|
||||
txn.execute("SELECT nextval(?)", (self._sequence_name,))
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
|
||||
txn.execute(
|
||||
"SELECT nextval(?) FROM generate_series(1, ?)", (self._sequence_name, n)
|
||||
)
|
||||
return [i for i, in txn]
|
||||
|
||||
|
||||
GetFirstCallbackType = Callable[[Cursor], int]
|
||||
|
||||
|
||||
@@ -182,3 +182,39 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_get_persisted_upto_position(self):
|
||||
"""Test that `get_persisted_upto_position` correctly tracks updates to
|
||||
positions.
|
||||
"""
|
||||
|
||||
self._insert_rows("first", 3)
|
||||
self._insert_rows("second", 5)
|
||||
|
||||
id_gen = self._create_id_generator("first")
|
||||
|
||||
# Min is 3 and there is a gap between 5, so we expect it to be 3.
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 3)
|
||||
|
||||
# We advance "first" straight to 6. Min is now 5 but there is no gap so
|
||||
# we expect it to be 6
|
||||
id_gen.advance("first", 6)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 6)
|
||||
|
||||
# No gap, so we expect 7.
|
||||
id_gen.advance("second", 7)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# We haven't seen 8 yet, so we expect 7 still.
|
||||
id_gen.advance("second", 9)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 7)
|
||||
|
||||
# Now that we've seen 7, 8 and 9 we can got straight to 9.
|
||||
id_gen.advance("first", 8)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 9)
|
||||
|
||||
# Jmp forward with gaps. The minimum is 11, even though we haven't seen
|
||||
# 10 we know that everything before 11 must be persisted.
|
||||
id_gen.advance("first", 11)
|
||||
id_gen.advance("second", 15)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), 11)
|
||||
|
||||
Reference in New Issue
Block a user