1
0

Merge commit 'dc22090a6' into anoa/dinsic_release_1_21_x

* commit 'dc22090a6':
  Add type hints to synapse.handlers.room (#8090)
  Remove some unused database functions. (#8085)
  Convert misc database code to async (#8087)
  Remove a space at the start of a changelog entry.
This commit is contained in:
Andrew Morgan
2020-10-19 18:25:03 +01:00
21 changed files with 128 additions and 348 deletions

View File

@@ -1 +1 @@
Convert various parts of the codebase to async/await.
Convert various parts of the codebase to async/await.

1
changelog.d/8085.misc Normal file
View File

@@ -0,0 +1 @@
Remove some unused database functions.

1
changelog.d/8087.misc Normal file
View File

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

1
changelog.d/8090.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to `synapse.handlers.room`.

View File

@@ -22,7 +22,7 @@ import logging
import math
import string
from collections import OrderedDict
from typing import Awaitable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
from synapse.api.constants import (
EventTypes,
@@ -32,11 +32,14 @@ from synapse.api.constants import (
RoomEncryptionAlgorithms,
)
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.filtering import Filter
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.events import EventBase
from synapse.events.utils import copy_power_levels_contents
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
Requester,
RoomAlias,
RoomID,
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
class RoomCreationHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super(RoomCreationHandler, self).__init__(hs)
self.spam_checker = hs.get_spam_checker()
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
"guest_can_join": False,
"power_level_content_override": {},
},
}
} # type: Dict[str, Dict[str, Any]]
# Modify presets to selectively enable encryption by default per homeserver config
for preset_name, preset_config in self._presets_dict.items():
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
old_room_state = await tombstone_context.get_current_state_ids()
# We know the tombstone event isn't an outlier so it has current state.
assert old_room_state is not None
# update any aliases
await self._move_aliases_to_new_room(
requester, old_room_id, new_room_id, old_room_state
@@ -540,17 +549,21 @@ class RoomCreationHandler(BaseHandler):
logger.error("Unable to send updated alias events in new room: %s", e)
async def create_room(
self, requester, config, ratelimit=True, creator_join_profile=None
self,
requester: Requester,
config: JsonDict,
ratelimit: bool = True,
creator_join_profile: Optional[JsonDict] = None,
) -> Tuple[dict, int]:
""" Creates a new room.
Args:
requester (synapse.types.Requester):
requester:
The user who requested the room creation.
config (dict) : A dict of configuration options.
ratelimit (bool): set to False to disable the rate limiter
config : A dict of configuration options.
ratelimit: set to False to disable the rate limiter
creator_join_profile (dict|None):
creator_join_profile:
Set to override the displayname and avatar for the creating
user in this room. If unset, displayname and avatar will be
derived from the user's profile. If set, should contain the
@@ -619,6 +632,7 @@ class RoomCreationHandler(BaseHandler):
Codes.UNSUPPORTED_ROOM_VERSION,
)
room_alias = None
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
@@ -629,8 +643,6 @@ class RoomCreationHandler(BaseHandler):
if mapping:
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
else:
room_alias = None
for i in invite_list:
try:
@@ -797,23 +809,30 @@ class RoomCreationHandler(BaseHandler):
async def _send_events_for_new_room(
self,
creator, # A Requester object.
room_id,
preset_config,
invite_list,
initial_state,
creation_content,
room_alias=None,
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
creator_join_profile=None,
creator: Requester,
room_id: str,
preset_config: str,
invite_list: List[str],
initial_state: StateMap,
creation_content: JsonDict,
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,
creator_join_profile: Optional[JsonDict] = None,
) -> int:
"""Sends the initial events into a new room.
`power_level_content_override` doesn't apply when initial state has
power level state event content.
Returns:
The stream_id of the last event persisted.
"""
def create(etype, content, **kwargs):
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
e = {"type": etype, "content": content}
e.update(event_keys)
@@ -821,7 +840,7 @@ class RoomCreationHandler(BaseHandler):
return e
async def send(etype, content, **kwargs) -> int:
async def send(etype: str, content: JsonDict, **kwargs) -> int:
event = create(etype, content, **kwargs)
logger.debug("Sending %s in new room", etype)
(
@@ -834,10 +853,6 @@ class RoomCreationHandler(BaseHandler):
config = self._presets_dict[preset_config]
creator_id = creator.user.to_string()
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
creation_content.update({"creator": creator_id})
await send(etype=EventTypes.Create, content=creation_content)
@@ -879,7 +894,7 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 50,
}
} # type: JsonDict
if config["original_invitees_have_ops"]:
for invitee in invite_list:
@@ -933,7 +948,7 @@ class RoomCreationHandler(BaseHandler):
return last_sent_stream_id
async def _generate_room_id(
self, creator_id: str, is_public: str, room_version: RoomVersion,
self, creator_id: str, is_public: bool, room_version: RoomVersion,
):
# autogen room IDs and try to create it. We may clash, so just
# try a few times till one goes through, giving up eventually.
@@ -957,23 +972,30 @@ class RoomCreationHandler(BaseHandler):
class RoomContextHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.storage = hs.get_storage()
self.state_store = self.storage.state
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
async def get_event_context(
self,
user: UserID,
room_id: str,
event_id: str,
limit: int,
event_filter: Optional[Filter],
) -> Optional[JsonDict]:
"""Retrieves events, pagination tokens and state around a given event
in a room.
Args:
user (UserID)
room_id (str)
event_id (str)
limit (int): The maximum number of events to return in total
user
room_id
event_id
limit: The maximum number of events to return in total
(excluding state).
event_filter (Filter|None): the filter to apply to the events returned
event_filter: the filter to apply to the events returned
(excluding the target event_id)
Returns:
@@ -1060,12 +1082,18 @@ class RoomContextHandler(object):
class RoomEventSource(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def get_new_events(
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
):
self,
user: UserID,
from_key: str,
limit: int,
room_ids: List[str],
is_guest: bool,
explicit_room_id: Optional[str] = None,
) -> Tuple[List[EventBase], str]:
# We just ignore the key for now.
to_key = self.get_current_key()
@@ -1123,7 +1151,7 @@ class RoomShutdownHandler(object):
)
DEFAULT_ROOM_NAME = "Content Violation Notification"
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.room_member_handler = hs.get_room_member_handler()
self._room_creation_handler = hs.get_room_creation_handler()

View File

@@ -18,8 +18,6 @@ from typing import Optional
from canonicaljson import json
from twisted.internet import defer
from synapse.metrics.background_process_metrics import run_as_background_process
from . import engines
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
update_name (str): Name of update
"""
@defer.inlineCallbacks
def noop_update(progress, batch_size):
yield self._end_background_update(update_name)
async def noop_update(progress, batch_size):
await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, noop_update)
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
else:
runner = create_index_sqlite
@defer.inlineCallbacks
def updater(progress, batch_size):
async def updater(progress, batch_size):
if runner is not None:
logger.info("Adding index %s to %s", index_name, table)
yield self.db_pool.runWithConnection(runner)
yield self._end_background_update(update_name)
await self.db_pool.runWithConnection(runner)
await self._end_background_update(update_name)
return 1
self.register_background_update_handler(update_name, updater)

View File

@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_device_list_last_stream_id_for_remote",
list_name="user_ids",
inlineCallbacks=True,
)
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = yield self.db_pool.simple_select_many_batch(
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
rows = await self.db_pool.simple_select_many_batch(
table="device_lists_remote_extremeties",
column="user_id",
iterable=user_ids,

View File

@@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
)
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
@@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
def _get_oldest_events_in_room_txn(self, txn, room_id):
return self.db_pool.simple_select_onecol_txn(
txn,
table="event_backward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
def get_prev_events_for_room(self, room_id: str):
"""
Gets a subset of the current forward extremities in the given room.

View File

@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
from synapse.storage.database import DatabasePool
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cachedInlineCallbacks
from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
self._rotate_delay = 3
self._rotate_count = 10000
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
def get_unread_event_push_actions_by_room_for_user(
@cached(num_args=3, tree=True, max_entries=5000)
async def get_unread_event_push_actions_by_room_for_user(
self, room_id, user_id, last_read_event_id
):
ret = yield self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_unread_event_push_actions_by_room",
self._get_unread_counts_by_receipt_txn,
room_id,
user_id,
last_read_event_id,
)
return ret
def _get_unread_counts_by_receipt_txn(
self, txn, room_id, user_id, last_read_event_id

View File

@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
def get_received_ts_by_stream_pos(self, stream_ordering):
"""Given a stream ordering get an approximate timestamp of when it
happened.
This is done by simply taking the received ts of the first event that
has a stream ordering greater than or equal to the given stream pos.
If none exists returns the current time, on the assumption that it must
have happened recently.
Args:
stream_ordering (int)
Returns:
Deferred[int]
"""
def _get_approximate_received_ts_txn(txn):
sql = """
SELECT received_ts FROM events
WHERE stream_ordering >= ?
LIMIT 1
"""
txn.execute(sql, (stream_ordering,))
row = txn.fetchone()
if row and row[0]:
ts = row[0]
else:
ts = self.clock.time_msec()
return ts
return self.db_pool.runInteraction(
"get_approximate_received_ts", _get_approximate_received_ts_txn
)
@defer.inlineCallbacks
def get_event(
self,
@@ -923,36 +887,6 @@ class EventsWorkerStore(SQLBaseStore):
)
return results
def _get_total_state_event_counts_txn(self, txn, room_id):
"""
See get_total_state_event_counts.
"""
# We join against the events table as that has an index on room_id
sql = """
SELECT COUNT(*) FROM state_events
INNER JOIN events USING (room_id, event_id)
WHERE room_id=?
"""
txn.execute(sql, (room_id,))
row = txn.fetchone()
return row[0] if row else 0
def get_total_state_event_counts(self, room_id):
"""
Gets the total number of state events in a room.
Args:
room_id (str)
Returns:
Deferred[int]
"""
return self.db_pool.runInteraction(
"get_total_state_event_counts",
self._get_total_state_event_counts_txn,
room_id,
)
def _get_current_state_event_counts_txn(self, txn, room_id):
"""
See get_current_state_event_counts.
@@ -1222,97 +1156,6 @@ class EventsWorkerStore(SQLBaseStore):
return rows, to_token, True
@cached(num_args=5, max_entries=10)
def get_all_new_events(
self,
last_backfill_id,
last_forward_id,
current_backfill_id,
current_forward_id,
limit,
):
"""Get all the new events that have arrived at the server either as
new events or as backfilled events"""
have_backfill_events = last_backfill_id != current_backfill_id
have_forward_events = last_forward_id != current_forward_id
if not have_backfill_events and not have_forward_events:
return defer.succeed(AllNewEventsResult([], [], [], [], []))
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? < stream_ordering AND stream_ordering <= ?"
" ORDER BY stream_ordering ASC"
" LIMIT ?"
)
if have_forward_events:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
" state_key, redacts"
" FROM events AS e"
" LEFT JOIN redactions USING (event_id)"
" LEFT JOIN state_events USING (event_id)"
" WHERE ? > stream_ordering AND stream_ordering >= ?"
" ORDER BY stream_ordering DESC"
" LIMIT ?"
)
if have_backfill_events:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return AllNewEventsResult(
new_forward_events,
new_backfill_events,
forward_ex_outliers,
backward_ex_outliers,
)
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
async def is_event_after(self, event_id1, event_id2):
"""Returns True if event_id1 is after event_id2 in the stream
"""
@@ -1357,14 +1200,3 @@ class EventsWorkerStore(SQLBaseStore):
return self.db_pool.runInteraction(
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
)
AllNewEventsResult = namedtuple(
"AllNewEventsResult",
[
"new_forward_events",
"new_backfill_events",
"forward_ex_outliers",
"backward_ex_outliers",
],
)

View File

@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_presence_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
)
def get_presence_for_users(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_presence_for_users(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="presence_stream",
column="user_id",
iterable=user_ids,
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
def get_current_presence_token(self):
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_insert(
table="presence_allow_inbound",
values={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="allow_presence_visible",
or_ignore=True,
)
def disallow_presence_visible(self, observed_localpart, observer_userid):
return self.db_pool.simple_delete_one(
table="presence_allow_inbound",
keyvalues={
"observed_user_id": observed_localpart,
"observer_user_id": observer_userid,
},
desc="disallow_presence_visible",
)

View File

@@ -170,18 +170,15 @@ class PushRulesWorkerStore(
)
@cachedList(
cached_method_name="get_push_rules_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
)
def bulk_get_push_rules(self, user_ids):
async def bulk_get_push_rules(self, user_ids):
if not user_ids:
return {}
results = {user_id: [] for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules",
column="user_name",
iterable=user_ids,
@@ -194,7 +191,7 @@ class PushRulesWorkerStore(
for row in rows:
results.setdefault(row["user_name"], []).append(row)
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
for user_id, rules in results.items():
use_new_defaults = user_id in self._users_new_default_push_rules
@@ -260,15 +257,14 @@ class PushRulesWorkerStore(
cached_method_name="get_push_rules_enabled_for_user",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
)
def bulk_get_push_rules_enabled(self, user_ids):
async def bulk_get_push_rules_enabled(self, user_ids):
if not user_ids:
return {}
results = {user_id: {} for user_id in user_ids}
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="push_rules_enable",
column="user_name",
iterable=user_ids,

View File

@@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
raise NotImplementedError()
@cachedList(
cached_method_name="get_if_user_has_pusher",
list_name="user_ids",
num_args=1,
inlineCallbacks=True,
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
)
def get_if_users_have_pushers(self, user_ids):
rows = yield self.db_pool.simple_select_many_batch(
async def get_if_users_have_pushers(self, user_ids):
rows = await self.db_pool.simple_select_many_batch(
table="pushers",
column="user_name",
iterable=user_ids,

View File

@@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
cached_method_name="_get_linearized_receipts_for_room",
list_name="room_ids",
num_args=3,
inlineCallbacks=True,
)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
return {}
@@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
return self.db_pool.cursor_to_dict(txn)
txn_results = yield self.db_pool.runInteraction(
txn_results = await self.db_pool.runInteraction(
"_get_linearized_receipts_for_rooms", f
)

View File

@@ -1424,43 +1424,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
"validate_threepid_session_txn", validate_threepid_session_txn
)
def upsert_threepid_validation_session(
self,
medium,
address,
client_secret,
send_attempt,
session_id,
validated_at=None,
):
"""Upsert a threepid validation session
Args:
medium (str): The medium of the 3PID
address (str): The address of the 3PID
client_secret (str): A unique string provided by the client to
help identify this validation attempt
send_attempt (int): The latest send_attempt on this session
session_id (str): The id of this validation session
validated_at (int|None): The unix timestamp in milliseconds of
when the session was marked as valid
"""
insertion_values = {
"medium": medium,
"address": address,
"client_secret": client_secret,
}
if validated_at:
insertion_values["validated_at"] = validated_at
return self.db_pool.simple_upsert(
table="threepid_validation_session",
keyvalues={"session_id": session_id},
values={"last_send_attempt": send_attempt},
insertion_values=insertion_values,
desc="upsert_threepid_validation_session",
)
def start_or_continue_validation_session(
self,
medium,

View File

@@ -37,10 +37,6 @@ from synapse.util.caches.descriptors import cached
logger = logging.getLogger(__name__)
OpsLevel = collections.namedtuple(
"OpsLevel", ("ban_level", "kick_level", "redact_level")
)
RatelimitOverride = collections.namedtuple(
"RatelimitOverride", ("messages_per_second", "burst_count")
)

View File

@@ -17,8 +17,6 @@
import logging
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
lambda: self._known_servers_count,
)
@defer.inlineCallbacks
def _count_known_servers(self):
async def _count_known_servers(self):
"""
Count the servers that this server knows about.
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
txn.execute(query)
return list(txn)[0][0]
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
count = await self.db_pool.runInteraction("get_known_servers", _transact)
# We always know about ourselves, even if we have nothing in
# room_memberships (for example, the server is new).
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
)
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
event_ids: The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,

View File

@@ -0,0 +1,17 @@
/* Copyright 2020 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- This table is no longer used.
DROP TABLE IF EXISTS presence_allow_inbound;

View File

@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
cached_method_name="_get_state_group_for_event",
list_name="event_ids",
num_args=1,
inlineCallbacks=True,
)
def _get_state_group_for_events(self, event_ids):
async def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="event_to_state_groups",
column="event_id",
iterable=event_ids,

View File

@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
desc="is_user_erased",
).addCallback(operator.truth)
@cachedList(
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
)
def are_users_erased(self, user_ids):
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
async def are_users_erased(self, user_ids):
"""
Checks which users in a list have requested erasure
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
user_ids (iterable[str]): full user id to check
Returns:
Deferred[dict[str, bool]]:
dict[str, bool]:
for each user, whether the user has requested erasure.
"""
# this serves the dual purpose of (a) making sure we can do len and
# iterate it multiple times, and (b) avoiding duplicates.
user_ids = tuple(set(user_ids))
rows = yield self.db_pool.simple_select_many_batch(
rows = await self.db_pool.simple_select_many_batch(
table="erased_users",
column="user_id",
iterable=user_ids,
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
)
erased_users = {row["user_id"] for row in rows}
res = {u: u in erased_users for u in user_ids}
return res
return {u: u in erased_users for u in user_ids}
class UserErasureStore(UserErasureWorkerStore):

View File

@@ -170,7 +170,7 @@ commands=
skip_install = True
deps =
{[base]deps}
mypy==0.750
mypy==0.782
mypy-zope
env =
MYPYPATH = stubs/
@@ -191,6 +191,7 @@ commands = mypy \
synapse/handlers/message.py \
synapse/handlers/oidc_handler.py \
synapse/handlers/presence.py \
synapse/handlers/room.py \
synapse/handlers/room_member.py \
synapse/handlers/room_member_worker.py \
synapse/handlers/saml_handler.py \