From aa07c37cf0a3b812e6aa1bb2d97d543e6925c8e2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 1 Sep 2020 12:41:21 +0100 Subject: [PATCH 1/6] Move and rename `get_devices_with_keys_by_user` (#8204) * Move `get_devices_with_keys_by_user` to `EndToEndKeyWorkerStore` this seems a better fit for it. This commit simply moves the existing code: no other changes at all. * Rename `get_devices_with_keys_by_user` to better reflect what it does. * get_device_stream_token abstract method To avoid referencing fields which are declared in the derived classes, make `get_device_stream_token` abstract, and define that in the classes which define `_device_list_id_gen`. --- changelog.d/8204.misc | 1 + synapse/handlers/device.py | 4 +- synapse/replication/slave/storage/devices.py | 3 ++ synapse/storage/databases/main/__init__.py | 3 ++ synapse/storage/databases/main/devices.py | 52 ++---------------- .../storage/databases/main/end_to_end_keys.py | 53 ++++++++++++++++++- 6 files changed, 67 insertions(+), 49 deletions(-) create mode 100644 changelog.d/8204.misc diff --git a/changelog.d/8204.misc b/changelog.d/8204.misc new file mode 100644 index 0000000000..979c8b227b --- /dev/null +++ b/changelog.d/8204.misc @@ -0,0 +1 @@ +Refactor queries for device keys and cross-signatures. diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index db417d60de..ee4666337a 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler): return result async def on_federation_query_user_devices(self, user_id): - stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id) + stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query( + user_id + ) master_key = await self.store.get_e2e_cross_signing_key(user_id, "master") self_signing_key = await self.store.get_e2e_cross_signing_key( user_id, "self_signing" diff --git a/synapse/replication/slave/storage/devices.py b/synapse/replication/slave/storage/devices.py index 596c72eb92..3b788c9625 100644 --- a/synapse/replication/slave/storage/devices.py +++ b/synapse/replication/slave/storage/devices.py @@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto "DeviceListFederationStreamChangeCache", device_list_max ) + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def process_replication_rows(self, stream_name, instance_name, token, rows): if stream_name == DeviceListsStream.NAME: self._device_list_id_gen.advance(instance_name, token) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index 70cf15dd7f..e6536c8456 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -264,6 +264,9 @@ class DataStore( # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() + def get_device_stream_token(self) -> int: + return self._device_list_id_gen.get_current_token() + def take_presence_startup_info(self): active_on_startup = self._presence_on_startup self._presence_on_startup = None diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index def96637a2..e8379c73c4 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import logging from typing import Any, Dict, Iterable, List, Optional, Set, Tuple @@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore): update included in the response), and the list of updates, where each update is a pair of EDU type and EDU contents. """ - now_stream_id = self._device_list_id_gen.get_current_token() + now_stream_id = self.get_device_stream_token() has_changed = self._device_list_federation_stream_cache.has_entity_changed( destination, int(from_stream_id) @@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore): }, ) + @abc.abstractmethod def get_device_stream_token(self) -> int: - return self._device_list_id_gen.get_current_token() + """Get the current stream id from the _device_list_id_gen""" + ... @trace async def get_user_devices_from_cache( @@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore): device["device_id"]: db_to_json(device["content"]) for device in devices } - def get_devices_with_keys_by_user(self, user_id: str): - """Get all devices (with any device keys) for a user - - Returns: - Deferred which resolves to (stream_id, devices) - """ - return self.db_pool.runInteraction( - "get_devices_with_keys_by_user", - self._get_devices_with_keys_by_user_txn, - user_id, - ) - - def _get_devices_with_keys_by_user_txn( - self, txn: LoggingTransaction, user_id: str - ) -> Tuple[int, List[JsonDict]]: - now_stream_id = self._device_list_id_gen.get_current_token() - - devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) - - if devices: - user_devices = devices[user_id] - results = [] - for device_id, device in user_devices.items(): - result = {"device_id": device_id} - - key_json = device.get("key_json", None) - if key_json: - result["keys"] = db_to_json(key_json) - - if "signatures" in device: - for sig_user_id, sigs in device["signatures"].items(): - result["keys"].setdefault("signatures", {}).setdefault( - sig_user_id, {} - ).update(sigs) - - device_display_name = device.get("device_display_name", None) - if device_display_name: - result["device_display_name"] = device_display_name - - results.append(result) - - return now_stream_id, results - - return now_stream_id, [] - async def get_users_whose_devices_changed( self, from_key: str, user_ids: Iterable[str] ) -> Set[str]: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 50ecddf7fa..fb3b1f94de 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -14,6 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple from canonicaljson import encode_canonical_json @@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.storage._base import SQLBaseStore, db_to_json -from synapse.storage.database import make_in_list_sql_clause +from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause from synapse.types import JsonDict from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList @@ -33,6 +34,51 @@ if TYPE_CHECKING: class EndToEndKeyWorkerStore(SQLBaseStore): + def get_e2e_device_keys_for_federation_query(self, user_id: str): + """Get all devices (with any device keys) for a user + + Returns: + Deferred which resolves to (stream_id, devices) + """ + return self.db_pool.runInteraction( + "get_e2e_device_keys_for_federation_query", + self._get_e2e_device_keys_for_federation_query_txn, + user_id, + ) + + def _get_e2e_device_keys_for_federation_query_txn( + self, txn: LoggingTransaction, user_id: str + ) -> Tuple[int, List[JsonDict]]: + now_stream_id = self.get_device_stream_token() + + devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)]) + + if devices: + user_devices = devices[user_id] + results = [] + for device_id, device in user_devices.items(): + result = {"device_id": device_id} + + key_json = device.get("key_json", None) + if key_json: + result["keys"] = db_to_json(key_json) + + if "signatures" in device: + for sig_user_id, sigs in device["signatures"].items(): + result["keys"].setdefault("signatures", {}).setdefault( + sig_user_id, {} + ).update(sigs) + + device_display_name = device.get("device_display_name", None) + if device_display_name: + result["device_display_name"] = device_display_name + + results.append(result) + + return now_stream_id, results + + return now_stream_id, [] + @trace async def get_e2e_device_keys_for_cs_api( self, query_list: List[Tuple[str, Optional[str]]] @@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore): _get_all_user_signature_changes_for_remotes_txn, ) + @abc.abstractmethod + def get_device_stream_token(self) -> int: + """Get the current stream id from the _device_list_id_gen""" + ... + class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys): From 318245eaa6d37a27ca72168356198fdd90abfbb7 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:16:58 -0400 Subject: [PATCH 2/6] Do not install setuptools 50.0. (#8212) This is due to compatibility issues with old Python versions. --- INSTALL.md | 2 +- changelog.d/8212.bugfix | 1 + synapse/python_dependencies.py | 4 ++++ 3 files changed, 6 insertions(+), 1 deletion(-) create mode 100644 changelog.d/8212.bugfix diff --git a/INSTALL.md b/INSTALL.md index 22f7b7c029..bdb7769fe9 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -73,7 +73,7 @@ mkdir -p ~/synapse virtualenv -p python3 ~/synapse/env source ~/synapse/env/bin/activate pip install --upgrade pip -pip install --upgrade setuptools +pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions pip install matrix-synapse ``` diff --git a/changelog.d/8212.bugfix b/changelog.d/8212.bugfix new file mode 100644 index 0000000000..0f8c0aed92 --- /dev/null +++ b/changelog.d/8212.bugfix @@ -0,0 +1 @@ +Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions. diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py index 2d995ec456..d666f22674 100644 --- a/synapse/python_dependencies.py +++ b/synapse/python_dependencies.py @@ -74,6 +74,10 @@ REQUIREMENTS = [ "Jinja2>=2.9", "bleach>=1.4.3", "typing-extensions>=3.7.4", + # setuptools is required by a variety of dependencies, unfortunately version + # 50.0 is incompatible with older Python versions, see + # https://github.com/pypa/setuptools/issues/2352 + "setuptools!=50.0", ] CONDITIONAL_REQUIREMENTS = { From bbb3c8641ca1214702dbddd88acfe81cc4fc44ae Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 1 Sep 2020 13:36:25 +0100 Subject: [PATCH 3/6] Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203) This is so that we can use it for the backfill events stream. --- changelog.d/8203.misc | 1 + synapse/storage/util/id_generators.py | 39 +++++++--- tests/storage/test_id_generators.py | 105 ++++++++++++++++++++++++++ 3 files changed, 134 insertions(+), 11 deletions(-) create mode 100644 changelog.d/8203.misc diff --git a/changelog.d/8203.misc b/changelog.d/8203.misc new file mode 100644 index 0000000000..9fe2224aaa --- /dev/null +++ b/changelog.d/8203.misc @@ -0,0 +1 @@ +Make `MultiWriterIDGenerator` work for streams that use negative values. diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py index b27a4843d0..9f3d23f0a5 100644 --- a/synapse/storage/util/id_generators.py +++ b/synapse/storage/util/id_generators.py @@ -185,6 +185,8 @@ class MultiWriterIdGenerator: id_column: Column that stores the stream ID. sequence_name: The name of the postgres sequence used to generate new IDs. + positive: Whether the IDs are positive (true) or negative (false). + When using negative IDs we go backwards from -1 to -2, -3, etc. """ def __init__( @@ -196,13 +198,19 @@ class MultiWriterIdGenerator: instance_column: str, id_column: str, sequence_name: str, + positive: bool = True, ): self._db = db self._instance_name = instance_name + self._positive = positive + self._return_factor = 1 if positive else -1 # We lock as some functions may be called from DB threads. self._lock = threading.Lock() + # Note: If we are a negative stream then we still store all the IDs as + # positive to make life easier for us, and simply negate the IDs when we + # return them. self._current_positions = self._load_current_ids( db_conn, table, instance_column, id_column ) @@ -233,13 +241,16 @@ class MultiWriterIdGenerator: def _load_current_ids( self, db_conn, table: str, instance_column: str, id_column: str ) -> Dict[str, int]: + # If positive stream aggregate via MAX. For negative stream use MIN + # *and* negate the result to get a positive number. sql = """ - SELECT %(instance)s, MAX(%(id)s) FROM %(table)s + SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s GROUP BY %(instance)s """ % { "instance": instance_column, "id": id_column, "table": table, + "agg": "MAX" if self._positive else "-MIN", } cur = db_conn.cursor() @@ -269,15 +280,16 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than what we currently # believe the ID to be. If not, then the sequence and table have got # out of sync somehow. - assert self.get_current_token_for_writer(self._instance_name) < next_id - with self._lock: + assert self._current_positions.get(self._instance_name, 0) < next_id + self._unfinished_ids.add(next_id) @contextlib.contextmanager def manager(): try: - yield next_id + # Multiply by the return factor so that the ID has correct sign. + yield self._return_factor * next_id finally: self._mark_id_as_finished(next_id) @@ -296,15 +308,15 @@ class MultiWriterIdGenerator: # Assert the fetched ID is actually greater than any ID we've already # seen. If not, then the sequence and table have got out of sync # somehow. - assert max(self.get_positions().values(), default=0) < min(next_ids) - with self._lock: + assert max(self._current_positions.values(), default=0) < min(next_ids) + self._unfinished_ids.update(next_ids) @contextlib.contextmanager def manager(): try: - yield next_ids + yield [self._return_factor * i for i in next_ids] finally: for i in next_ids: self._mark_id_as_finished(i) @@ -327,7 +339,7 @@ class MultiWriterIdGenerator: txn.call_after(self._mark_id_as_finished, next_id) txn.call_on_exception(self._mark_id_as_finished, next_id) - return next_id + return self._return_factor * next_id def _mark_id_as_finished(self, next_id: int): """The ID has finished being processed so we should advance the @@ -359,20 +371,25 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._current_positions.get(instance_name, 0) + return self._return_factor * self._current_positions.get(instance_name, 0) def get_positions(self) -> Dict[str, int]: """Get a copy of the current positon map. """ with self._lock: - return dict(self._current_positions) + return { + name: self._return_factor * i + for name, i in self._current_positions.items() + } def advance(self, instance_name: str, new_id: int): """Advance the postion of the named writer to the given ID, if greater than existing entry. """ + new_id *= self._return_factor + with self._lock: self._current_positions[instance_name] = max( new_id, self._current_positions.get(instance_name, 0) @@ -390,7 +407,7 @@ class MultiWriterIdGenerator: """ with self._lock: - return self._persisted_upto_position + return self._return_factor * self._persisted_upto_position def _add_persisted_position(self, new_id: int): """Record that we have persisted a position. diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py index 14ce21c786..f0a8e32f1e 100644 --- a/tests/storage/test_id_generators.py +++ b/tests/storage/test_id_generators.py @@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase): # We assume that so long as `get_next` does correctly advance the # `persisted_upto_position` in this case, then it will be correct in the # other cases that are tested above (since they'll hit the same code). + + +class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase): + """Tests MultiWriterIdGenerator that produce *negative* stream IDs. + """ + + if not USE_POSTGRES_FOR_TESTS: + skip = "Requires Postgres" + + def prepare(self, reactor, clock, hs): + self.store = hs.get_datastore() + self.db_pool = self.store.db_pool # type: DatabasePool + + self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db)) + + def _setup_db(self, txn): + txn.execute("CREATE SEQUENCE foobar_seq") + txn.execute( + """ + CREATE TABLE foobar ( + stream_id BIGINT NOT NULL, + instance_name TEXT NOT NULL, + data TEXT + ); + """ + ) + + def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator: + def _create(conn): + return MultiWriterIdGenerator( + conn, + self.db_pool, + instance_name=instance_name, + table="foobar", + instance_column="instance_name", + id_column="stream_id", + sequence_name="foobar_seq", + positive=False, + ) + + return self.get_success(self.db_pool.runWithConnection(_create)) + + def _insert_row(self, instance_name: str, stream_id: int): + """Insert one row as the given instance with given stream_id. + """ + + def _insert(txn): + txn.execute( + "INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,), + ) + + self.get_success(self.db_pool.runInteraction("_insert_row", _insert)) + + def test_single_instance(self): + """Test that reads and writes from a single process are handled + correctly. + """ + id_gen = self._create_id_generator() + + with self.get_success(id_gen.get_next()) as stream_id: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -1}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -1) + self.assertEqual(id_gen.get_persisted_upto_position(), -1) + + with self.get_success(id_gen.get_next_mult(3)) as stream_ids: + for stream_id in stream_ids: + self._insert_row("master", stream_id) + + self.assertEqual(id_gen.get_positions(), {"master": -4}) + self.assertEqual(id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(id_gen.get_persisted_upto_position(), -4) + + # Test loading from DB by creating a second ID gen + second_id_gen = self._create_id_generator() + + self.assertEqual(second_id_gen.get_positions(), {"master": -4}) + self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4) + self.assertEqual(second_id_gen.get_persisted_upto_position(), -4) + + def test_multiple_instance(self): + """Tests that having multiple instances that get advanced over + federation works corretly. + """ + id_gen_1 = self._create_id_generator("first") + id_gen_2 = self._create_id_generator("second") + + with self.get_success(id_gen_1.get_next()) as stream_id: + self._insert_row("first", stream_id) + id_gen_2.advance("first", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -1) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -1) + + with self.get_success(id_gen_2.get_next()) as stream_id: + self._insert_row("second", stream_id) + id_gen_1.advance("second", stream_id) + + self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2}) + self.assertEqual(id_gen_1.get_persisted_upto_position(), -2) + self.assertEqual(id_gen_2.get_persisted_upto_position(), -2) From da77520cd1c414c9341da287967feb1bab14cbec Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 08:39:04 -0400 Subject: [PATCH 4/6] Convert additional databases to async/await part 2 (#8200) --- changelog.d/8200.misc | 1 + synapse/events/builder.py | 19 ++++--- synapse/handlers/message.py | 13 ++--- synapse/handlers/room_member.py | 12 +---- synapse/storage/databases/main/client_ips.py | 4 +- synapse/storage/databases/main/directory.py | 6 +-- synapse/storage/databases/main/filtering.py | 5 +- synapse/storage/databases/main/openid.py | 8 ++- synapse/storage/databases/main/profile.py | 6 ++- synapse/storage/databases/main/push_rule.py | 10 ++-- synapse/storage/databases/main/room.py | 49 +++++++++++-------- synapse/storage/databases/main/signatures.py | 40 ++++++++++++--- synapse/storage/databases/main/ui_auth.py | 4 +- .../databases/main/user_erasure_store.py | 8 +-- tests/test_utils/event_injection.py | 7 ++- 15 files changed, 111 insertions(+), 81 deletions(-) create mode 100644 changelog.d/8200.misc diff --git a/changelog.d/8200.misc b/changelog.d/8200.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8200.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/events/builder.py b/synapse/events/builder.py index 9ed24380dd..7878cd7044 100644 --- a/synapse/events/builder.py +++ b/synapse/events/builder.py @@ -12,7 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import Any, Dict, List, Optional, Tuple, Union import attr from nacl.signing import SigningKey @@ -97,14 +97,14 @@ class EventBuilder(object): def is_state(self): return self._state_key is not None - async def build(self, prev_event_ids): + async def build(self, prev_event_ids: List[str]) -> EventBase: """Transform into a fully signed and hashed event Args: - prev_event_ids (list[str]): The event IDs to use as the prev events + prev_event_ids: The event IDs to use as the prev events Returns: - FrozenEvent + The signed and hashed event. """ state_ids = await self._state.get_current_state_ids( @@ -114,8 +114,13 @@ class EventBuilder(object): format_version = self.room_version.event_format if format_version == EventFormatVersions.V1: - auth_events = await self._store.add_event_hashes(auth_ids) - prev_events = await self._store.add_event_hashes(prev_event_ids) + # The types of auth/prev events changes between event versions. + auth_events = await self._store.add_event_hashes( + auth_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] + prev_events = await self._store.add_event_hashes( + prev_event_ids + ) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]] else: auth_events = auth_ids prev_events = prev_event_ids @@ -138,7 +143,7 @@ class EventBuilder(object): "unsigned": self.unsigned, "depth": depth, "prev_state": [], - } + } # type: Dict[str, Any] if self.is_state(): event_dict["state_key"] = self._state_key diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 9d0c38f4df..72bb638167 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.send_event import ReplicationSendEventRestServlet from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.state import StateFilter -from synapse.types import ( - Collection, - Requester, - RoomAlias, - StreamToken, - UserID, - create_requester, -) +from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester from synapse.util import json_decoder from synapse.util.async_helpers import Linearizer from synapse.util.frozenutils import frozendict_json_encoder @@ -446,7 +439,7 @@ class EventCreationHandler(object): event_dict: dict, token_id: Optional[str] = None, txn_id: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, require_consent: bool = True, ) -> Tuple[EventBase, EventContext]: """ @@ -786,7 +779,7 @@ class EventCreationHandler(object): self, builder: EventBuilder, requester: Optional[Requester] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, ) -> Tuple[EventBase, EventContext]: """Create a new event for a local client diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index cae4d013b8..a7962b0ada 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict from synapse.events.snapshot import EventContext from synapse.events.validator import EventValidator from synapse.storage.roommember import RoomsForUser -from synapse.types import ( - Collection, - JsonDict, - Requester, - RoomAlias, - RoomID, - StateMap, - UserID, -) +from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID from synapse.util.async_helpers import Linearizer from synapse.util.distributor import user_joined_room, user_left_room @@ -184,7 +176,7 @@ class RoomMemberHandler(object): target: UserID, room_id: str, membership: str, - prev_event_ids: Collection[str], + prev_event_ids: List[str], txn_id: Optional[str] = None, ratelimit: bool = True, content: Optional[dict] = None, diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index 216a5925fc..c2fc847fbc 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): self._batch_row_update[key] = (user_agent, device_id, now) @wrap_as_background_process("update_client_ips") - def _update_client_ips_batch(self): + async def _update_client_ips_batch(self) -> None: # If the DB pool has already terminated, don't try updating if not self.db_pool.is_running(): @@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore): to_update = self._batch_row_update self._batch_row_update = {} - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_client_ips_batch", self._update_client_ips_batch_txn, to_update ) diff --git a/synapse/storage/databases/main/directory.py b/synapse/storage/databases/main/directory.py index 405b5eafa5..e5060d4c46 100644 --- a/synapse/storage/databases/main/directory.py +++ b/synapse/storage/databases/main/directory.py @@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore): return room_id - def update_aliases_for_room( + async def update_aliases_for_room( self, old_room_id: str, new_room_id: str, creator: Optional[str] = None, - ): + ) -> None: """Repoint all of the aliases for a given room, to a different room. Args: @@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore): txn, self.get_aliases_for_room, (new_room_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "_update_aliases_for_room_txn", _update_aliases_for_room_txn ) diff --git a/synapse/storage/databases/main/filtering.py b/synapse/storage/databases/main/filtering.py index 45a1760170..d2f5b9a502 100644 --- a/synapse/storage/databases/main/filtering.py +++ b/synapse/storage/databases/main/filtering.py @@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json from synapse.api.errors import Codes, SynapseError from synapse.storage._base import SQLBaseStore, db_to_json +from synapse.types import JsonDict from synapse.util.caches.descriptors import cached @@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore): return db_to_json(def_json) - def add_user_filter(self, user_localpart, user_filter): + async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str: def_json = encode_canonical_json(user_filter) # Need an atomic transaction to SELECT the maximal ID so far then @@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore): return filter_id - return self.db_pool.runInteraction("add_user_filter", _do_txn) + return await self.db_pool.runInteraction("add_user_filter", _do_txn) diff --git a/synapse/storage/databases/main/openid.py b/synapse/storage/databases/main/openid.py index 4db8949da7..2aac64901b 100644 --- a/synapse/storage/databases/main/openid.py +++ b/synapse/storage/databases/main/openid.py @@ -1,3 +1,5 @@ +from typing import Optional + from synapse.storage._base import SQLBaseStore @@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore): desc="insert_open_id_token", ) - def get_user_id_for_open_id_token(self, token, ts_now_ms): + async def get_user_id_for_open_id_token( + self, token: str, ts_now_ms: int + ) -> Optional[str]: def get_user_id_for_token_txn(txn): sql = ( "SELECT user_id FROM open_id_tokens" @@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore): else: return rows[0][0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_user_id_for_token", get_user_id_for_token_txn ) diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index 301875a672..d2e0685e9e 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore): desc="delete_remote_profile_cache", ) - def get_remote_profile_cache_entries_that_expire(self, last_checked): + async def get_remote_profile_cache_entries_that_expire( + self, last_checked: int + ) -> Dict[str, str]: """Get all users who haven't been checked since `last_checked` """ @@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore): return self.db_pool.cursor_to_dict(txn) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_remote_profile_cache_entries_that_expire", _get_remote_profile_cache_entries_that_expire_txn, ) diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py index 2fb5b02d7d..0de802a86b 100644 --- a/synapse/storage/databases/main/push_rule.py +++ b/synapse/storage/databases/main/push_rule.py @@ -18,8 +18,6 @@ import abc import logging from typing import List, Tuple, Union -from twisted.internet import defer - from synapse.push.baserules import list_with_base_rules from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker from synapse.storage._base import SQLBaseStore, db_to_json @@ -149,9 +147,11 @@ class PushRulesWorkerStore( ) return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results} - def have_push_rules_changed_for_user(self, user_id, last_id): + async def have_push_rules_changed_for_user( + self, user_id: str, last_id: int + ) -> bool: if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id): - return defer.succeed(False) + return False else: def have_push_rules_changed_txn(txn): @@ -163,7 +163,7 @@ class PushRulesWorkerStore( (count,) = txn.fetchone() return bool(count) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "have_push_rules_changed", have_push_rules_changed_txn ) diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index a92641c339..717df97301 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore): allow_none=True, ) - def get_room_with_stats(self, room_id: str): + async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]: """Retrieve room with statistics. Args: @@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore): res["public"] = bool(res["public"]) return res - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_room_with_stats", get_room_with_stats_txn, room_id ) @@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore): desc="get_public_room_ids", ) - def count_public_rooms(self, network_tuple, ignore_non_federatable): + async def count_public_rooms( + self, + network_tuple: Optional[ThirdPartyInstanceID], + ignore_non_federatable: bool, + ) -> int: """Counts the number of public rooms as tracked in the room_stats_current and room_stats_state table. Args: - network_tuple (ThirdPartyInstanceID|None) - ignore_non_federatable (bool): If true filters out non-federatable rooms + network_tuple + ignore_non_federatable: If true filters out non-federatable rooms """ def _count_public_rooms_txn(txn): @@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore): txn.execute(sql, query_args) return txn.fetchone()[0] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_public_rooms", _count_public_rooms_txn ) @@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore): return row - def get_media_mxcs_in_room(self, room_id): + async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]: """Retrieves all the local and remote media MXC URIs in a given room Args: - room_id (str) + room_id Returns: - The local and remote media as a lists of tuples where the key is - the hostname and the value is the media ID. + The local and remote media as a lists of the media IDs. """ def _get_media_mxcs_in_room_txn(txn): @@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_media_ids_in_room", _get_media_mxcs_in_room_txn ) - def quarantine_media_ids_in_room(self, room_id, quarantined_by): + async def quarantine_media_ids_in_room( + self, room_id: str, quarantined_by: str + ) -> int: """For a room loops through all events with media and quarantines the associated media """ @@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_in_room", _quarantine_media_in_room_txn ) @@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore): return local_media_mxcs, remote_media_mxcs - def quarantine_media_by_id( + async def quarantine_media_by_id( self, server_name: str, media_id: str, quarantined_by: str, - ): + ) -> int: """quarantines a single local or remote media id Args: @@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore): txn, local_mxcs, remote_mxcs, quarantined_by ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_id_txn ) - def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str): + async def quarantine_media_ids_by_user( + self, user_id: str, quarantined_by: str + ) -> int: """quarantines all local media associated with a single user Args: @@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore): local_media_ids = self._get_media_ids_by_user_txn(txn, user_id) return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "quarantine_media_by_user", _quarantine_media_by_user_txn ) @@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): ) self.hs.get_notifier().on_new_replication_data() - def get_room_count(self): - """Retrieve a list of all rooms + async def get_room_count(self) -> int: + """Retrieve the total number of rooms. """ def f(txn): @@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore): row = txn.fetchone() return row[0] or 0 - return self.db_pool.runInteraction("get_rooms", f) + return await self.db_pool.runInteraction("get_rooms", f) async def add_event_report( self, diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py index be191dd870..c8c67953e4 100644 --- a/synapse/storage/databases/main/signatures.py +++ b/synapse/storage/databases/main/signatures.py @@ -13,9 +13,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import Dict, Iterable, List, Tuple + from unpaddedbase64 import encode_base64 from synapse.storage._base import SQLBaseStore +from synapse.storage.types import Cursor from synapse.util.caches.descriptors import cached, cachedList @@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore): @cachedList( cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1 ) - def get_event_reference_hashes(self, event_ids): + async def get_event_reference_hashes( + self, event_ids: Iterable[str] + ) -> Dict[str, Dict[str, bytes]]: + """Get all hashes for given events. + + Args: + event_ids: The event IDs to get hashes for. + + Returns: + A mapping of event ID to a mapping of algorithm to hash. + """ + def f(txn): return { event_id: self._get_event_reference_hashes_txn(txn, event_id) for event_id in event_ids } - return self.db_pool.runInteraction("get_event_reference_hashes", f) + return await self.db_pool.runInteraction("get_event_reference_hashes", f) - async def add_event_hashes(self, event_ids): + async def add_event_hashes( + self, event_ids: Iterable[str] + ) -> List[Tuple[str, Dict[str, str]]]: + """ + + Args: + event_ids: The event IDs + + Returns: + A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash. + """ hashes = await self.get_event_reference_hashes(event_ids) hashes = { e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"} @@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore): return list(hashes.items()) - def _get_event_reference_hashes_txn(self, txn, event_id): + def _get_event_reference_hashes_txn( + self, txn: Cursor, event_id: str + ) -> Dict[str, bytes]: """Get all the hashes for a given PDU. Args: - txn (cursor): - event_id (str): Id for the Event. + txn: + event_id: Id for the Event. Returns: - A dict[unicode, bytes] of algorithm -> hash. + A mapping of algorithm -> hash. """ query = ( "SELECT algorithm, hash" diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py index 9eef8e57c5..b89668d561 100644 --- a/synapse/storage/databases/main/ui_auth.py +++ b/synapse/storage/databases/main/ui_auth.py @@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore): class UIAuthStore(UIAuthWorkerStore): - def delete_old_ui_auth_sessions(self, expiration_time: int): + async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None: """ Remove sessions which were last used earlier than the expiration time. @@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore): This is an epoch time in milliseconds. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_old_ui_auth_sessions", self._delete_old_ui_auth_sessions_txn, expiration_time, diff --git a/synapse/storage/databases/main/user_erasure_store.py b/synapse/storage/databases/main/user_erasure_store.py index e3547e53b3..2f7c95fc74 100644 --- a/synapse/storage/databases/main/user_erasure_store.py +++ b/synapse/storage/databases/main/user_erasure_store.py @@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore): class UserErasureStore(UserErasureWorkerStore): - def mark_user_erased(self, user_id: str) -> None: + async def mark_user_erased(self, user_id: str) -> None: """Indicate that user_id wishes their message history to be erased. Args: @@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_erased", f) + await self.db_pool.runInteraction("mark_user_erased", f) - def mark_user_not_erased(self, user_id: str) -> None: + async def mark_user_not_erased(self, user_id: str) -> None: """Indicate that user_id is no longer erased. Args: @@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore): self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,)) - return self.db_pool.runInteraction("mark_user_not_erased", f) + await self.db_pool.runInteraction("mark_user_not_erased", f) diff --git a/tests/test_utils/event_injection.py b/tests/test_utils/event_injection.py index 8522c6fc09..fb1ca90336 100644 --- a/tests/test_utils/event_injection.py +++ b/tests/test_utils/event_injection.py @@ -13,14 +13,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional, Tuple +from typing import List, Optional, Tuple import synapse.server from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase from synapse.events.snapshot import EventContext -from synapse.types import Collection """ Utility functions for poking events into the storage of the server under test. @@ -58,7 +57,7 @@ async def inject_member_event( async def inject_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> EventBase: """Inject a generic event into a room @@ -80,7 +79,7 @@ async def inject_event( async def create_event( hs: synapse.server.HomeServer, room_version: Optional[str] = None, - prev_event_ids: Optional[Collection[str]] = None, + prev_event_ids: Optional[List[str]] = None, **kwargs ) -> Tuple[EventBase, EventContext]: if room_version is None: From 5bf8e5f55b49f9e46a7fe7d7872e6b16d38bffd3 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:15:22 -0400 Subject: [PATCH 5/6] Convert the well known resolver to async (#8214) --- changelog.d/8214.misc | 1 + mypy.ini | 1 + .../federation/matrix_federation_agent.py | 4 +- .../http/federation/well_known_resolver.py | 57 ++++++++++--------- .../test_matrix_federation_agent.py | 24 ++++++-- 5 files changed, 53 insertions(+), 34 deletions(-) create mode 100644 changelog.d/8214.misc diff --git a/changelog.d/8214.misc b/changelog.d/8214.misc new file mode 100644 index 0000000000..e26764dea1 --- /dev/null +++ b/changelog.d/8214.misc @@ -0,0 +1 @@ + Convert various parts of the codebase to async/await. diff --git a/mypy.ini b/mypy.ini index 4213e31b03..21c6f523a0 100644 --- a/mypy.ini +++ b/mypy.ini @@ -28,6 +28,7 @@ files = synapse/handlers/saml_handler.py, synapse/handlers/sync.py, synapse/handlers/ui_auth, + synapse/http/federation/well_known_resolver.py, synapse/http/server.py, synapse/http/site.py, synapse/logging/, diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 369bf9c2fc..782d39d4ca 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -134,8 +134,8 @@ class MatrixFederationAgent(object): and not _is_ip_literal(parsed_uri.hostname) and not parsed_uri.port ): - well_known_result = yield self._well_known_resolver.get_well_known( - parsed_uri.hostname + well_known_result = yield defer.ensureDeferred( + self._well_known_resolver.get_well_known(parsed_uri.hostname) ) delegated_server = well_known_result.delegated_server diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index f794315deb..cdb6bec56e 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -16,6 +16,7 @@ import logging import random import time +from typing import Callable, Dict, Optional, Tuple import attr @@ -23,6 +24,7 @@ from twisted.internet import defer from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -99,15 +101,14 @@ class WellKnownResolver(object): self._well_known_agent = RedirectAgent(agent) self.user_agent = user_agent - @defer.inlineCallbacks - def get_well_known(self, server_name): + async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult: """Attempt to fetch and parse a .well-known file for the given server Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Returns: - Deferred[WellKnownLookupResult]: The result of the lookup + The result of the lookup """ try: prev_result, expiry, ttl = self._well_known_cache.get_with_expiry( @@ -124,7 +125,9 @@ class WellKnownResolver(object): # requests for the same server in parallel? try: with Measure(self._clock, "get_well_known"): - result, cache_period = yield self._fetch_well_known(server_name) + result, cache_period = await self._fetch_well_known( + server_name + ) # type: Tuple[Optional[bytes], float] except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -153,18 +156,17 @@ class WellKnownResolver(object): return WellKnownLookupResult(delegated_server=result) - @defer.inlineCallbacks - def _fetch_well_known(self, server_name): + async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]: """Actually fetch and parse a .well-known, without checking the cache Args: - server_name (bytes): name of the server, from the requested url + server_name: name of the server, from the requested url Raises: _FetchWellKnownFailure if we fail to lookup a result Returns: - Deferred[Tuple[bytes,int]]: The lookup result and cache period. + The lookup result and cache period. """ had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) @@ -172,7 +174,7 @@ class WellKnownResolver(object): # We do this in two steps to differentiate between possibly transient # errors (e.g. can't connect to host, 503 response) and more permenant # errors (such as getting a 404 response). - response, body = yield self._make_well_known_request( + response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known ) @@ -215,20 +217,20 @@ class WellKnownResolver(object): return result, cache_period - @defer.inlineCallbacks - def _make_well_known_request(self, server_name, retry): + async def _make_well_known_request( + self, server_name: bytes, retry: bool + ) -> Tuple[IResponse, bytes]: """Make the well known request. This will retry the request if requested and it fails (with unable to connect or receives a 5xx error). Args: - server_name (bytes) - retry (bool): Whether to retry the request if it fails. + server_name: name of the server, from the requested url + retry: Whether to retry the request if it fails. Returns: - Deferred[tuple[IResponse, bytes]] Returns the response object and - body. Response may be a non-200 response. + Returns the response object and body. Response may be a non-200 response. """ uri = b"https://%s/.well-known/matrix/server" % (server_name,) uri_str = uri.decode("ascii") @@ -243,12 +245,12 @@ class WellKnownResolver(object): logger.info("Fetching %s", uri_str) try: - response = yield make_deferred_yieldable( + response = await make_deferred_yieldable( self._well_known_agent.request( b"GET", uri, headers=Headers(headers) ) ) - body = yield make_deferred_yieldable(readBody(response)) + body = await make_deferred_yieldable(readBody(response)) if 500 <= response.code < 600: raise Exception("Non-200 response %s" % (response.code,)) @@ -265,21 +267,24 @@ class WellKnownResolver(object): logger.info("Error fetching %s: %s. Retrying", uri_str, e) # Sleep briefly in the hopes that they come back up - yield self._clock.sleep(0.5) + await self._clock.sleep(0.5) -def _cache_period_from_headers(headers, time_now=time.time): +def _cache_period_from_headers( + headers: Headers, time_now: Callable[[], float] = time.time +) -> Optional[float]: cache_controls = _parse_cache_control(headers) if b"no-store" in cache_controls: return 0 if b"max-age" in cache_controls: - try: - max_age = int(cache_controls[b"max-age"]) - return max_age - except ValueError: - pass + max_age = cache_controls[b"max-age"] + if max_age: + try: + return int(max_age) + except ValueError: + pass expires = headers.getRawHeaders(b"expires") if expires is not None: @@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time): return None -def _parse_cache_control(headers): +def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]: cache_controls = {} for hdr in headers.getRawHeaders(b"cache-control", []): for directive in hdr.split(b","): diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index 69945a8f98..eb78ab412a 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase): def test_well_known_cache(self): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase): well_known_server.loseConnection() # repeat the request: it should hit the cache - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) r = self.successResultOf(fetch_d) self.assertEqual(r.delegated_server, b"target-server") @@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((1000.0,)) # now it should connect again - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) self.assertEqual(len(clients), 1) (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) @@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.lookups["testserv"] = "1.2.3.4" - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # there should be an attempt to connect on port 443 for the .well-known clients = self.reactor.tcpClients @@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase): # another lookup. self.reactor.pump((900.0,)) - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) # The resolver may retry a few times, so fonx all requests that come along attempts = 0 @@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase): self.reactor.pump((10000.0,)) # Repated the request, this time it should fail if the lookup fails. - fetch_d = self.well_known_resolver.get_well_known(b"testserv") + fetch_d = defer.ensureDeferred( + self.well_known_resolver.get_well_known(b"testserv") + ) clients = self.reactor.tcpClients (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) From 54f8d73c005cf0401d05fc90e857da253f9d1168 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 1 Sep 2020 09:21:48 -0400 Subject: [PATCH 6/6] Convert additional databases to async/await (#8199) --- changelog.d/8199.misc | 1 + synapse/storage/databases/main/__init__.py | 50 +++++---- synapse/storage/databases/main/devices.py | 38 +++---- .../storage/databases/main/events_worker.py | 48 ++++---- .../storage/databases/main/purge_events.py | 30 ++--- synapse/storage/databases/main/receipts.py | 14 ++- synapse/storage/databases/main/relations.py | 103 ++++++++---------- 7 files changed, 147 insertions(+), 137 deletions(-) create mode 100644 changelog.d/8199.misc diff --git a/changelog.d/8199.misc b/changelog.d/8199.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8199.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index e6536c8456..99890ffbf3 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -18,7 +18,7 @@ import calendar import logging import time -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Tuple from synapse.api.constants import PresenceState from synapse.config.homeserver import HomeServerConfig @@ -294,16 +294,16 @@ class DataStore( return [UserPresenceState(**row) for row in rows] - def count_daily_users(self): + async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) - def count_monthly_users(self): + async def count_monthly_users(self) -> int: """ Counts the number of users who used this homeserver in the last 30 days. Note this method is intended for phonehome metrics only and is different @@ -311,7 +311,7 @@ class DataStore( amongst other things, includes a 3 day grace period before a user counts. """ thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -330,15 +330,15 @@ class DataStore( (count,) = txn.fetchone() return count - def count_r30_users(self): + async def count_r30_users(self) -> Dict[str, int]: """ Counts the number of 30 day retained users, defined as:- * Users who have created their accounts more than 30 days ago * Where last seen at most 30 days ago * Where account creation and last_seen are > 30 days apart - Returns counts globaly for a given user as well as breaking - by platform + Returns: + A mapping of counts globally as well as broken out by platform. """ def _count_r30_users(txn): @@ -411,7 +411,7 @@ class DataStore( return results - return self.db_pool.runInteraction("count_r30_users", _count_r30_users) + return await self.db_pool.runInteraction("count_r30_users", _count_r30_users) def _get_start_of_day(self): """ @@ -421,7 +421,7 @@ class DataStore( today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 - def generate_user_daily_visits(self): + async def generate_user_daily_visits(self) -> None: """ Generates daily visit data for use in cohort/ retention analysis """ @@ -476,7 +476,7 @@ class DataStore( # frequently self._last_user_visit_update = now - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "generate_user_daily_visits", _generate_user_daily_visits ) @@ -500,22 +500,28 @@ class DataStore( desc="get_users", ) - def get_users_paginate( - self, start, limit, user_id=None, name=None, guests=True, deactivated=False - ): + async def get_users_paginate( + self, + start: int, + limit: int, + user_id: Optional[str] = None, + name: Optional[str] = None, + guests: bool = True, + deactivated: bool = False, + ) -> Tuple[List[Dict[str, Any]], int]: """Function to retrieve a paginated list of users from users list. This will return a json list of users and the total number of users matching the filter criteria. Args: - start (int): start number to begin the query from - limit (int): number of rows to retrieve - user_id (string): search for user_id. ignored if name is not None - name (string): search for local part of user_id or display name - guests (bool): whether to in include guest users - deactivated (bool): whether to include deactivated users + start: start number to begin the query from + limit: number of rows to retrieve + user_id: search for user_id. ignored if name is not None + name: search for local part of user_id or display name + guests: whether to in include guest users + deactivated: whether to include deactivated users Returns: - defer.Deferred: resolves to list[dict[str, Any]], int + A tuple of a list of mappings from user to information and a count of total users. """ def get_users_paginate_txn(txn): @@ -558,7 +564,7 @@ class DataStore( users = self.db_pool.cursor_to_dict(txn) return users, count - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_paginate_txn", get_users_paginate_txn ) diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index e8379c73c4..a29157d979 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -313,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore): return results - def _get_last_device_update_for_remote_user( + async def _get_last_device_update_for_remote_user( self, destination: str, user_id: str, from_stream_id: int - ): + ) -> int: def f(txn): prev_sent_id_sql = """ SELECT coalesce(max(stream_id), 0) as stream_id @@ -326,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore): rows = txn.fetchall() return rows[0][0] - return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f) + return await self.db_pool.runInteraction( + "get_last_device_update_for_remote_user", f + ) - def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int): + async def mark_as_sent_devices_by_remote( + self, destination: str, stream_id: int + ) -> None: """Mark that updates have successfully been sent to the destination. """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn, destination, @@ -684,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore): desc="make_remote_user_device_cache_as_stale", ) - def mark_remote_user_device_list_as_unsubscribed(self, user_id: str): + async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None: """Mark that we no longer track device lists for remote user. """ @@ -698,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore): txn, self.get_device_list_last_stream_id_for_remote, (user_id,) ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "mark_remote_user_device_list_as_unsubscribed", _mark_remote_user_device_list_as_unsubscribed_txn, ) @@ -959,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): desc="update_device", ) - def update_remote_device_list_cache_entry( + async def update_remote_device_list_cache_entry( self, user_id: str, device_id: str, content: JsonDict, stream_id: int - ): + ) -> None: """Updates a single device in the cache of a remote user's devicelist. Note: assumes that we are the only thread that can be updating this user's @@ -972,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): device_id: ID of decivice being updated content: new data on this device stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache_entry", self._update_remote_device_list_cache_entry_txn, user_id, @@ -1028,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): lock=False, ) - def update_remote_device_list_cache( + async def update_remote_device_list_cache( self, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: """Replace the entire cache of the remote user's devices. Note: assumes that we are the only thread that can be updating this user's @@ -1040,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): user_id: User to update device list for devices: list of device objects supplied over federation stream_id: the version of the device list - - Returns: - Deferred[None] """ - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_remote_device_list_cache", self._update_remote_device_list_cache_txn, user_id, @@ -1054,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore): def _update_remote_device_list_cache_txn( self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int - ): + ) -> None: self.db_pool.simple_delete_txn( txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id} ) diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index e6247d682d..a7a73cc3d8 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore): return event_dict - def _maybe_redact_event_row(self, original_ev, redactions, event_map): + def _maybe_redact_event_row( + self, + original_ev: EventBase, + redactions: Iterable[str], + event_map: Dict[str, EventBase], + ) -> Optional[EventBase]: """Given an event object and a list of possible redacting event ids, determine whether to honour any of those redactions and if so return a redacted event. Args: - original_ev (EventBase): - redactions (iterable[str]): list of event ids of potential redaction events - event_map (dict[str, EventBase]): other events which have been fetched, in - which we can look up the redaaction events. Map from event id to event. + original_ev: The original event. + redactions: list of event ids of potential redaction events + event_map: other events which have been fetched, in which we can + look up the redaaction events. Map from event id to event. Returns: - Deferred[EventBase|None]: if the event should be redacted, a pruned - event object. Otherwise, None. + If the event should be redacted, a pruned event object. Otherwise, None. """ if original_ev.type == "m.room.create": # we choose to ignore redactions of m.room.create events. @@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore): row = txn.fetchone() return row[0] if row else 0 - def get_current_state_event_counts(self, room_id): + async def get_current_state_event_counts(self, room_id: str) -> int: """ Gets the current number of state events in a room. Args: - room_id (str) + room_id: The room ID to query. Returns: - Deferred[int] + The current number of state events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_current_state_event_counts", self._get_current_state_event_counts_txn, room_id, @@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore): """The current maximum token that events have reached""" return self._stream_id_gen.get_current_token() - def get_all_new_forward_event_rows(self, last_id, current_id, limit): + async def get_all_new_forward_event_rows( + self, last_id: int, current_id: int, limit: int + ) -> List[Tuple]: """Returns new events, for the Events replication stream Args: @@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore): current_id: the maximum stream_id to return up to limit: the maximum number of rows to return - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id, limit)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_all_new_forward_event_rows", get_all_new_forward_event_rows ) - def get_ex_outlier_stream_rows(self, last_id, current_id): + async def get_ex_outlier_stream_rows( + self, last_id: int, current_id: int + ) -> List[Tuple]: """Returns de-outliered events, for the Events replication stream Args: last_id: the last stream_id from the previous batch. current_id: the maximum stream_id to return up to - Returns: Deferred[List[Tuple]] + Returns: a list of events stream rows. Each tuple consists of a stream id as the first element, followed by fields suitable for casting into an EventsStreamRow. @@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore): txn.execute(sql, (last_id, current_id)) return txn.fetchall() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn ) @@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore): return (int(res["topological_ordering"]), int(res["stream_ordering"])) - def get_next_event_to_expire(self): + async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]: """Retrieve the entry with the lowest expiry timestamp in the event_expiry table, or None if there's no more event to expire. - Returns: Deferred[Optional[Tuple[str, int]]] + Returns: A tuple containing the event ID as its first element and an expiry timestamp as its second one, if there's at least one row in the event_expiry table. None otherwise. @@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore): return txn.fetchone() - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( desc="get_next_event_to_expire", func=get_next_event_to_expire_txn ) diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 3526b6fd66..ea833829ae 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Tuple +from typing import Any, List, Set, Tuple from synapse.api.errors import SynapseError from synapse.storage._base import SQLBaseStore @@ -25,25 +25,24 @@ logger = logging.getLogger(__name__) class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): - def purge_history(self, room_id, token, delete_local_events): + async def purge_history( + self, room_id: str, token: str, delete_local_events: bool + ) -> Set[int]: """Deletes room history before a certain point Args: - room_id (str): - - token (str): A topological token to delete events before - - delete_local_events (bool): + room_id: + token: A topological token to delete events before + delete_local_events: if True, we will delete local events as well as remote ones (instead of just marking them as outliers and deleting their state groups). Returns: - Deferred[set[int]]: The set of state groups that are referenced by - deleted events. + The set of state groups that are referenced by deleted events. """ - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "purge_history", self._purge_history_txn, room_id, @@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore): return referenced_state_groups - def purge_room(self, room_id): + async def purge_room(self, room_id: str) -> List[int]: """Deletes all record of a room Args: - room_id (str) + room_id Returns: - Deferred[List[int]]: The list of state groups to delete. + The list of state groups to delete. """ - - return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id) + return await self.db_pool.runInteraction( + "purge_room", self._purge_room_txn, room_id + ) def _purge_room_txn(self, txn, room_id): # First we fetch all the state groups that should be deleted, before diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index 436f22ad2d..4a0d5a320e 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore): } return results - def get_users_sent_receipts_between(self, last_id: int, current_id: int): + async def get_users_sent_receipts_between( + self, last_id: int, current_id: int + ) -> List[str]: """Get all users who sent receipts between `last_id` exclusive and `current_id` inclusive. Returns: - Deferred[List[str]] + The list of users. """ if last_id == current_id: @@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore): return [r[0] for r in txn] - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_users_sent_receipts_between", _get_users_sent_receipts_between_txn ) @@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore): return stream_id, max_persisted_id - def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data): - return self.db_pool.runInteraction( + async def insert_graph_receipt( + self, room_id, receipt_type, user_id, event_ids, data + ): + return await self.db_pool.runInteraction( "insert_graph_receipt", self.insert_graph_receipt_txn, room_id, diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index a9ceffc20e..5cd61547f7 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -34,38 +34,33 @@ logger = logging.getLogger(__name__) class RelationsWorkerStore(SQLBaseStore): @cached(tree=True) - def get_relations_for_event( + async def get_relations_for_event( self, - event_id, - relation_type=None, - event_type=None, - aggregation_key=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + relation_type: Optional[str] = None, + event_type: Optional[str] = None, + aggregation_key: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[RelationPaginationToken] = None, + to_token: Optional[RelationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. Args: - event_id (str): Fetch events that relate to this event ID. - relation_type (str|None): Only fetch events with this relation - type, if given. - event_type (str|None): Only fetch events with this event type, if - given. - aggregation_key (str|None): Only fetch events with this aggregation - key, if given. - limit (int): Only fetch the most recent `limit` events. - direction (str): Whether to fetch the most recent first (`"b"`) or - the oldest first (`"f"`). - from_token (RelationPaginationToken|None): Fetch rows from the given - token, or from the start if None. - to_token (RelationPaginationToken|None): Fetch rows up to the given - token, or up to the end if None. + event_id: Fetch events that relate to this event ID. + relation_type: Only fetch events with this relation type, if given. + event_type: Only fetch events with this event type, if given. + aggregation_key: Only fetch events with this aggregation key, if given. + limit: Only fetch the most recent `limit` events. + direction: Whether to fetch the most recent first (`"b"`) or the + oldest first (`"f"`). + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of event IDs that match relations - requested. The rows are of the form `{"event_id": "..."}`. + List of event IDs that match relations requested. The rows are of + the form `{"event_id": "..."}`. """ where_clause = ["relates_to_id = ?"] @@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_recent_references_for_event", _get_recent_references_for_event_txn ) @cached(tree=True) - def get_aggregation_groups_for_event( + async def get_aggregation_groups_for_event( self, - event_id, - event_type=None, - limit=5, - direction="b", - from_token=None, - to_token=None, - ): + event_id: str, + event_type: Optional[str] = None, + limit: int = 5, + direction: str = "b", + from_token: Optional[AggregationPaginationToken] = None, + to_token: Optional[AggregationPaginationToken] = None, + ) -> PaginationChunk: """Get a list of annotations on the event, grouped by event type and aggregation key, sorted by count. @@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore): on an event. Args: - event_id (str): Fetch events that relate to this event ID. - event_type (str|None): Only fetch events with this event type, if - given. - limit (int): Only fetch the `limit` groups. - direction (str): Whether to fetch the highest count first (`"b"`) or + event_id: Fetch events that relate to this event ID. + event_type: Only fetch events with this event type, if given. + limit: Only fetch the `limit` groups. + direction: Whether to fetch the highest count first (`"b"`) or the lowest count first (`"f"`). - from_token (AggregationPaginationToken|None): Fetch rows from the - given token, or from the start if None. - to_token (AggregationPaginationToken|None): Fetch rows up to the - given token, or up to the end if None. - + from_token: Fetch rows from the given token, or from the start if None. + to_token: Fetch rows up to the given token, or up to the end if None. Returns: - Deferred[PaginationChunk]: List of groups of annotations that - match. Each row is a dict with `type`, `key` and `count` fields. + List of groups of annotations that match. Each row is a dict with + `type`, `key` and `count` fields. """ where_clause = ["relates_to_id = ?", "relation_type = ?"] @@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore): chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token ) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn ) @@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore): return await self.get_event(edit_id, allow_none=True) - def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender): + async def has_user_annotated_event( + self, parent_id: str, event_type: str, aggregation_key: str, sender: str + ) -> bool: """Check if a user has already annotated an event with the same key (e.g. already liked an event). Args: - parent_id (str): The event being annotated - event_type (str): The event type of the annotation - aggregation_key (str): The aggregation key of the annotation - sender (str): The sender of the annotation + parent_id: The event being annotated + event_type: The event type of the annotation + aggregation_key: The aggregation key of the annotation + sender: The sender of the annotation Returns: - Deferred[bool] + True if the event is already annotated. """ sql = """ @@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore): return bool(txn.fetchone()) - return self.db_pool.runInteraction( + return await self.db_pool.runInteraction( "get_if_user_has_annotated_event", _get_if_user_has_annotated_event )