Merge commit 'dc22090a6' into anoa/dinsic_release_1_21_x
* commit 'dc22090a6': Add type hints to synapse.handlers.room (#8090) Remove some unused database functions. (#8085) Convert misc database code to async (#8087) Remove a space at the start of a changelog entry.
This commit is contained in:
@@ -1 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
Convert various parts of the codebase to async/await.
|
||||
|
||||
1
changelog.d/8085.misc
Normal file
1
changelog.d/8085.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove some unused database functions.
|
||||
1
changelog.d/8087.misc
Normal file
1
changelog.d/8087.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8090.misc
Normal file
1
changelog.d/8090.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.handlers.room`.
|
||||
@@ -22,7 +22,7 @@ import logging
|
||||
import math
|
||||
import string
|
||||
from collections import OrderedDict
|
||||
from typing import Awaitable, Optional, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Awaitable, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import (
|
||||
EventTypes,
|
||||
@@ -32,11 +32,14 @@ from synapse.api.constants import (
|
||||
RoomEncryptionAlgorithms,
|
||||
)
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.utils import copy_power_levels_contents
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
@@ -53,6 +56,9 @@ from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
id_server_scheme = "https://"
|
||||
@@ -61,7 +67,7 @@ FIVE_MINUTES_IN_MS = 5 * 60 * 1000
|
||||
|
||||
|
||||
class RoomCreationHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
@@ -92,7 +98,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"guest_can_join": False,
|
||||
"power_level_content_override": {},
|
||||
},
|
||||
}
|
||||
} # type: Dict[str, Dict[str, Any]]
|
||||
|
||||
# Modify presets to selectively enable encryption by default per homeserver config
|
||||
for preset_name, preset_config in self._presets_dict.items():
|
||||
@@ -215,6 +221,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
old_room_state = await tombstone_context.get_current_state_ids()
|
||||
|
||||
# We know the tombstone event isn't an outlier so it has current state.
|
||||
assert old_room_state is not None
|
||||
|
||||
# update any aliases
|
||||
await self._move_aliases_to_new_room(
|
||||
requester, old_room_id, new_room_id, old_room_state
|
||||
@@ -540,17 +549,21 @@ class RoomCreationHandler(BaseHandler):
|
||||
logger.error("Unable to send updated alias events in new room: %s", e)
|
||||
|
||||
async def create_room(
|
||||
self, requester, config, ratelimit=True, creator_join_profile=None
|
||||
self,
|
||||
requester: Requester,
|
||||
config: JsonDict,
|
||||
ratelimit: bool = True,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
) -> Tuple[dict, int]:
|
||||
""" Creates a new room.
|
||||
|
||||
Args:
|
||||
requester (synapse.types.Requester):
|
||||
requester:
|
||||
The user who requested the room creation.
|
||||
config (dict) : A dict of configuration options.
|
||||
ratelimit (bool): set to False to disable the rate limiter
|
||||
config : A dict of configuration options.
|
||||
ratelimit: set to False to disable the rate limiter
|
||||
|
||||
creator_join_profile (dict|None):
|
||||
creator_join_profile:
|
||||
Set to override the displayname and avatar for the creating
|
||||
user in this room. If unset, displayname and avatar will be
|
||||
derived from the user's profile. If set, should contain the
|
||||
@@ -619,6 +632,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
Codes.UNSUPPORTED_ROOM_VERSION,
|
||||
)
|
||||
|
||||
room_alias = None
|
||||
if "room_alias_name" in config:
|
||||
for wchar in string.whitespace:
|
||||
if wchar in config["room_alias_name"]:
|
||||
@@ -629,8 +643,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
if mapping:
|
||||
raise SynapseError(400, "Room alias already taken", Codes.ROOM_IN_USE)
|
||||
else:
|
||||
room_alias = None
|
||||
|
||||
for i in invite_list:
|
||||
try:
|
||||
@@ -797,23 +809,30 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
async def _send_events_for_new_room(
|
||||
self,
|
||||
creator, # A Requester object.
|
||||
room_id,
|
||||
preset_config,
|
||||
invite_list,
|
||||
initial_state,
|
||||
creation_content,
|
||||
room_alias=None,
|
||||
power_level_content_override=None, # Doesn't apply when initial state has power level state event content
|
||||
creator_join_profile=None,
|
||||
creator: Requester,
|
||||
room_id: str,
|
||||
preset_config: str,
|
||||
invite_list: List[str],
|
||||
initial_state: StateMap,
|
||||
creation_content: JsonDict,
|
||||
room_alias: Optional[RoomAlias] = None,
|
||||
power_level_content_override: Optional[JsonDict] = None,
|
||||
creator_join_profile: Optional[JsonDict] = None,
|
||||
) -> int:
|
||||
"""Sends the initial events into a new room.
|
||||
|
||||
`power_level_content_override` doesn't apply when initial state has
|
||||
power level state event content.
|
||||
|
||||
Returns:
|
||||
The stream_id of the last event persisted.
|
||||
"""
|
||||
|
||||
def create(etype, content, **kwargs):
|
||||
creator_id = creator.user.to_string()
|
||||
|
||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
def create(etype: str, content: JsonDict, **kwargs) -> JsonDict:
|
||||
e = {"type": etype, "content": content}
|
||||
|
||||
e.update(event_keys)
|
||||
@@ -821,7 +840,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
return e
|
||||
|
||||
async def send(etype, content, **kwargs) -> int:
|
||||
async def send(etype: str, content: JsonDict, **kwargs) -> int:
|
||||
event = create(etype, content, **kwargs)
|
||||
logger.debug("Sending %s in new room", etype)
|
||||
(
|
||||
@@ -834,10 +853,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
config = self._presets_dict[preset_config]
|
||||
|
||||
creator_id = creator.user.to_string()
|
||||
|
||||
event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
|
||||
|
||||
creation_content.update({"creator": creator_id})
|
||||
await send(etype=EventTypes.Create, content=creation_content)
|
||||
|
||||
@@ -879,7 +894,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"kick": 50,
|
||||
"redact": 50,
|
||||
"invite": 50,
|
||||
}
|
||||
} # type: JsonDict
|
||||
|
||||
if config["original_invitees_have_ops"]:
|
||||
for invitee in invite_list:
|
||||
@@ -933,7 +948,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
return last_sent_stream_id
|
||||
|
||||
async def _generate_room_id(
|
||||
self, creator_id: str, is_public: str, room_version: RoomVersion,
|
||||
self, creator_id: str, is_public: bool, room_version: RoomVersion,
|
||||
):
|
||||
# autogen room IDs and try to create it. We may clash, so just
|
||||
# try a few times till one goes through, giving up eventually.
|
||||
@@ -957,23 +972,30 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
|
||||
class RoomContextHandler(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.storage = hs.get_storage()
|
||||
self.state_store = self.storage.state
|
||||
|
||||
async def get_event_context(self, user, room_id, event_id, limit, event_filter):
|
||||
async def get_event_context(
|
||||
self,
|
||||
user: UserID,
|
||||
room_id: str,
|
||||
event_id: str,
|
||||
limit: int,
|
||||
event_filter: Optional[Filter],
|
||||
) -> Optional[JsonDict]:
|
||||
"""Retrieves events, pagination tokens and state around a given event
|
||||
in a room.
|
||||
|
||||
Args:
|
||||
user (UserID)
|
||||
room_id (str)
|
||||
event_id (str)
|
||||
limit (int): The maximum number of events to return in total
|
||||
user
|
||||
room_id
|
||||
event_id
|
||||
limit: The maximum number of events to return in total
|
||||
(excluding state).
|
||||
event_filter (Filter|None): the filter to apply to the events returned
|
||||
event_filter: the filter to apply to the events returned
|
||||
(excluding the target event_id)
|
||||
|
||||
Returns:
|
||||
@@ -1060,12 +1082,18 @@ class RoomContextHandler(object):
|
||||
|
||||
|
||||
class RoomEventSource(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def get_new_events(
|
||||
self, user, from_key, limit, room_ids, is_guest, explicit_room_id=None
|
||||
):
|
||||
self,
|
||||
user: UserID,
|
||||
from_key: str,
|
||||
limit: int,
|
||||
room_ids: List[str],
|
||||
is_guest: bool,
|
||||
explicit_room_id: Optional[str] = None,
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
# We just ignore the key for now.
|
||||
|
||||
to_key = self.get_current_key()
|
||||
@@ -1123,7 +1151,7 @@ class RoomShutdownHandler(object):
|
||||
)
|
||||
DEFAULT_ROOM_NAME = "Content Violation Notification"
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
self._room_creation_handler = hs.get_room_creation_handler()
|
||||
|
||||
@@ -18,8 +18,6 @@ from typing import Optional
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
|
||||
from . import engines
|
||||
@@ -308,9 +306,8 @@ class BackgroundUpdater(object):
|
||||
update_name (str): Name of update
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update(update_name)
|
||||
async def noop_update(progress, batch_size):
|
||||
await self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
self.register_background_update_handler(update_name, noop_update)
|
||||
@@ -409,12 +406,11 @@ class BackgroundUpdater(object):
|
||||
else:
|
||||
runner = create_index_sqlite
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def updater(progress, batch_size):
|
||||
async def updater(progress, batch_size):
|
||||
if runner is not None:
|
||||
logger.info("Adding index %s to %s", index_name, table)
|
||||
yield self.db_pool.runWithConnection(runner)
|
||||
yield self._end_background_update(update_name)
|
||||
await self.db_pool.runWithConnection(runner)
|
||||
await self._end_background_update(update_name)
|
||||
return 1
|
||||
|
||||
self.register_background_update_handler(update_name, updater)
|
||||
|
||||
@@ -671,10 +671,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
@cachedList(
|
||||
cached_method_name="get_device_list_last_stream_id_for_remote",
|
||||
list_name="user_ids",
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_device_list_last_stream_id_for_remotes(self, user_ids: str):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="device_lists_remote_extremeties",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
|
||||
@@ -257,11 +257,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
# Return all events where not all sets can reach them.
|
||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
def get_oldest_events_in_room(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
"get_oldest_events_in_room", self._get_oldest_events_in_room_txn, room_id
|
||||
)
|
||||
|
||||
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
"get_oldest_events_with_depth_in_room",
|
||||
@@ -303,14 +298,6 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
else:
|
||||
return max(row["depth"] for row in rows)
|
||||
|
||||
def _get_oldest_events_in_room_txn(self, txn, room_id):
|
||||
return self.db_pool.simple_select_onecol_txn(
|
||||
txn,
|
||||
table="event_backward_extremities",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="event_id",
|
||||
)
|
||||
|
||||
def get_prev_events_for_room(self, room_id: str):
|
||||
"""
|
||||
Gets a subset of the current forward extremities in the given room.
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -86,18 +86,17 @@ class EventPushActionsWorkerStore(SQLBaseStore):
|
||||
self._rotate_delay = 3
|
||||
self._rotate_count = 10000
|
||||
|
||||
@cachedInlineCallbacks(num_args=3, tree=True, max_entries=5000)
|
||||
def get_unread_event_push_actions_by_room_for_user(
|
||||
@cached(num_args=3, tree=True, max_entries=5000)
|
||||
async def get_unread_event_push_actions_by_room_for_user(
|
||||
self, room_id, user_id, last_read_event_id
|
||||
):
|
||||
ret = yield self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_unread_event_push_actions_by_room",
|
||||
self._get_unread_counts_by_receipt_txn,
|
||||
room_id,
|
||||
user_id,
|
||||
last_read_event_id,
|
||||
)
|
||||
return ret
|
||||
|
||||
def _get_unread_counts_by_receipt_txn(
|
||||
self, txn, room_id, user_id, last_read_event_id
|
||||
|
||||
@@ -43,7 +43,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached, cachedInlineCallbacks
|
||||
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -137,42 +137,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
desc="get_received_ts",
|
||||
)
|
||||
|
||||
def get_received_ts_by_stream_pos(self, stream_ordering):
|
||||
"""Given a stream ordering get an approximate timestamp of when it
|
||||
happened.
|
||||
|
||||
This is done by simply taking the received ts of the first event that
|
||||
has a stream ordering greater than or equal to the given stream pos.
|
||||
If none exists returns the current time, on the assumption that it must
|
||||
have happened recently.
|
||||
|
||||
Args:
|
||||
stream_ordering (int)
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
|
||||
def _get_approximate_received_ts_txn(txn):
|
||||
sql = """
|
||||
SELECT received_ts FROM events
|
||||
WHERE stream_ordering >= ?
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
txn.execute(sql, (stream_ordering,))
|
||||
row = txn.fetchone()
|
||||
if row and row[0]:
|
||||
ts = row[0]
|
||||
else:
|
||||
ts = self.clock.time_msec()
|
||||
|
||||
return ts
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
"get_approximate_received_ts", _get_approximate_received_ts_txn
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(
|
||||
self,
|
||||
@@ -923,36 +887,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
)
|
||||
return results
|
||||
|
||||
def _get_total_state_event_counts_txn(self, txn, room_id):
|
||||
"""
|
||||
See get_total_state_event_counts.
|
||||
"""
|
||||
# We join against the events table as that has an index on room_id
|
||||
sql = """
|
||||
SELECT COUNT(*) FROM state_events
|
||||
INNER JOIN events USING (room_id, event_id)
|
||||
WHERE room_id=?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
row = txn.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_total_state_event_counts(self, room_id):
|
||||
"""
|
||||
Gets the total number of state events in a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_total_state_event_counts",
|
||||
self._get_total_state_event_counts_txn,
|
||||
room_id,
|
||||
)
|
||||
|
||||
def _get_current_state_event_counts_txn(self, txn, room_id):
|
||||
"""
|
||||
See get_current_state_event_counts.
|
||||
@@ -1222,97 +1156,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return rows, to_token, True
|
||||
|
||||
@cached(num_args=5, max_entries=10)
|
||||
def get_all_new_events(
|
||||
self,
|
||||
last_backfill_id,
|
||||
last_forward_id,
|
||||
current_backfill_id,
|
||||
current_forward_id,
|
||||
limit,
|
||||
):
|
||||
"""Get all the new events that have arrived at the server either as
|
||||
new events or as backfilled events"""
|
||||
have_backfill_events = last_backfill_id != current_backfill_id
|
||||
have_forward_events = last_forward_id != current_forward_id
|
||||
|
||||
if not have_backfill_events and not have_forward_events:
|
||||
return defer.succeed(AllNewEventsResult([], [], [], [], []))
|
||||
|
||||
def get_all_new_events_txn(txn):
|
||||
sql = (
|
||||
"SELECT e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? < stream_ordering AND stream_ordering <= ?"
|
||||
" ORDER BY stream_ordering ASC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_forward_events:
|
||||
txn.execute(sql, (last_forward_id, current_forward_id, limit))
|
||||
new_forward_events = txn.fetchall()
|
||||
|
||||
if len(new_forward_events) == limit:
|
||||
upper_bound = new_forward_events[-1][0]
|
||||
else:
|
||||
upper_bound = current_forward_id
|
||||
|
||||
sql = (
|
||||
"SELECT event_stream_ordering, event_id, state_group"
|
||||
" FROM ex_outlier_stream"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (last_forward_id, upper_bound))
|
||||
forward_ex_outliers = txn.fetchall()
|
||||
else:
|
||||
new_forward_events = []
|
||||
forward_ex_outliers = []
|
||||
|
||||
sql = (
|
||||
"SELECT -e.stream_ordering, e.event_id, e.room_id, e.type,"
|
||||
" state_key, redacts"
|
||||
" FROM events AS e"
|
||||
" LEFT JOIN redactions USING (event_id)"
|
||||
" LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE ? > stream_ordering AND stream_ordering >= ?"
|
||||
" ORDER BY stream_ordering DESC"
|
||||
" LIMIT ?"
|
||||
)
|
||||
if have_backfill_events:
|
||||
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
|
||||
new_backfill_events = txn.fetchall()
|
||||
|
||||
if len(new_backfill_events) == limit:
|
||||
upper_bound = new_backfill_events[-1][0]
|
||||
else:
|
||||
upper_bound = current_backfill_id
|
||||
|
||||
sql = (
|
||||
"SELECT -event_stream_ordering, event_id, state_group"
|
||||
" FROM ex_outlier_stream"
|
||||
" WHERE ? > event_stream_ordering"
|
||||
" AND event_stream_ordering >= ?"
|
||||
" ORDER BY event_stream_ordering DESC"
|
||||
)
|
||||
txn.execute(sql, (-last_backfill_id, -upper_bound))
|
||||
backward_ex_outliers = txn.fetchall()
|
||||
else:
|
||||
new_backfill_events = []
|
||||
backward_ex_outliers = []
|
||||
|
||||
return AllNewEventsResult(
|
||||
new_forward_events,
|
||||
new_backfill_events,
|
||||
forward_ex_outliers,
|
||||
backward_ex_outliers,
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
||||
async def is_event_after(self, event_id1, event_id2):
|
||||
"""Returns True if event_id1 is after event_id2 in the stream
|
||||
"""
|
||||
@@ -1357,14 +1200,3 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
|
||||
AllNewEventsResult = namedtuple(
|
||||
"AllNewEventsResult",
|
||||
[
|
||||
"new_forward_events",
|
||||
"new_backfill_events",
|
||||
"forward_ex_outliers",
|
||||
"backward_ex_outliers",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -130,13 +130,10 @@ class PresenceStore(SQLBaseStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_presence_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="_get_presence_for_user", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def get_presence_for_users(self, user_ids):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_presence_for_users(self, user_ids):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="presence_stream",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
@@ -160,24 +157,3 @@ class PresenceStore(SQLBaseStore):
|
||||
|
||||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
def allow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self.db_pool.simple_insert(
|
||||
table="presence_allow_inbound",
|
||||
values={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="allow_presence_visible",
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
def disallow_presence_visible(self, observed_localpart, observer_userid):
|
||||
return self.db_pool.simple_delete_one(
|
||||
table="presence_allow_inbound",
|
||||
keyvalues={
|
||||
"observed_user_id": observed_localpart,
|
||||
"observer_user_id": observer_userid,
|
||||
},
|
||||
desc="disallow_presence_visible",
|
||||
)
|
||||
|
||||
@@ -170,18 +170,15 @@ class PushRulesWorkerStore(
|
||||
)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_push_rules_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="get_push_rules_for_user", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def bulk_get_push_rules(self, user_ids):
|
||||
async def bulk_get_push_rules(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: [] for user_id in user_ids}
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
@@ -194,7 +191,7 @@ class PushRulesWorkerStore(
|
||||
for row in rows:
|
||||
results.setdefault(row["user_name"], []).append(row)
|
||||
|
||||
enabled_map_by_user = yield self.bulk_get_push_rules_enabled(user_ids)
|
||||
enabled_map_by_user = await self.bulk_get_push_rules_enabled(user_ids)
|
||||
|
||||
for user_id, rules in results.items():
|
||||
use_new_defaults = user_id in self._users_new_default_push_rules
|
||||
@@ -260,15 +257,14 @@ class PushRulesWorkerStore(
|
||||
cached_method_name="get_push_rules_enabled_for_user",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def bulk_get_push_rules_enabled(self, user_ids):
|
||||
async def bulk_get_push_rules_enabled(self, user_ids):
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
results = {user_id: {} for user_id in user_ids}
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="push_rules_enable",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
|
||||
@@ -170,13 +170,10 @@ class PusherWorkerStore(SQLBaseStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="get_if_user_has_pusher",
|
||||
list_name="user_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="get_if_user_has_pusher", list_name="user_ids", num_args=1,
|
||||
)
|
||||
def get_if_users_have_pushers(self, user_ids):
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
async def get_if_users_have_pushers(self, user_ids):
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="pushers",
|
||||
column="user_name",
|
||||
iterable=user_ids,
|
||||
|
||||
@@ -212,9 +212,8 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
cached_method_name="_get_linearized_receipts_for_room",
|
||||
list_name="room_ids",
|
||||
num_args=3,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
async def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
|
||||
if not room_ids:
|
||||
return {}
|
||||
|
||||
@@ -243,7 +242,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
txn_results = yield self.db_pool.runInteraction(
|
||||
txn_results = await self.db_pool.runInteraction(
|
||||
"_get_linearized_receipts_for_rooms", f
|
||||
)
|
||||
|
||||
|
||||
@@ -1424,43 +1424,6 @@ class RegistrationStore(RegistrationBackgroundUpdateStore):
|
||||
"validate_threepid_session_txn", validate_threepid_session_txn
|
||||
)
|
||||
|
||||
def upsert_threepid_validation_session(
|
||||
self,
|
||||
medium,
|
||||
address,
|
||||
client_secret,
|
||||
send_attempt,
|
||||
session_id,
|
||||
validated_at=None,
|
||||
):
|
||||
"""Upsert a threepid validation session
|
||||
Args:
|
||||
medium (str): The medium of the 3PID
|
||||
address (str): The address of the 3PID
|
||||
client_secret (str): A unique string provided by the client to
|
||||
help identify this validation attempt
|
||||
send_attempt (int): The latest send_attempt on this session
|
||||
session_id (str): The id of this validation session
|
||||
validated_at (int|None): The unix timestamp in milliseconds of
|
||||
when the session was marked as valid
|
||||
"""
|
||||
insertion_values = {
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
"client_secret": client_secret,
|
||||
}
|
||||
|
||||
if validated_at:
|
||||
insertion_values["validated_at"] = validated_at
|
||||
|
||||
return self.db_pool.simple_upsert(
|
||||
table="threepid_validation_session",
|
||||
keyvalues={"session_id": session_id},
|
||||
values={"last_send_attempt": send_attempt},
|
||||
insertion_values=insertion_values,
|
||||
desc="upsert_threepid_validation_session",
|
||||
)
|
||||
|
||||
def start_or_continue_validation_session(
|
||||
self,
|
||||
medium,
|
||||
|
||||
@@ -37,10 +37,6 @@ from synapse.util.caches.descriptors import cached
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
OpsLevel = collections.namedtuple(
|
||||
"OpsLevel", ("ban_level", "kick_level", "redact_level")
|
||||
)
|
||||
|
||||
RatelimitOverride = collections.namedtuple(
|
||||
"RatelimitOverride", ("messages_per_second", "burst_count")
|
||||
)
|
||||
|
||||
@@ -17,8 +17,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Awaitable, Iterable, List, Optional, Set
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
@@ -92,8 +90,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
lambda: self._known_servers_count,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _count_known_servers(self):
|
||||
async def _count_known_servers(self):
|
||||
"""
|
||||
Count the servers that this server knows about.
|
||||
|
||||
@@ -121,7 +118,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
txn.execute(query)
|
||||
return list(txn)[0][0]
|
||||
|
||||
count = yield self.db_pool.runInteraction("get_known_servers", _transact)
|
||||
count = await self.db_pool.runInteraction("get_known_servers", _transact)
|
||||
|
||||
# We always know about ourselves, even if we have nothing in
|
||||
# room_memberships (for example, the server is new).
|
||||
@@ -589,11 +586,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
raise NotImplementedError()
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="_get_joined_profile_from_event_id",
|
||||
list_name="event_ids",
|
||||
inlineCallbacks=True,
|
||||
cached_method_name="_get_joined_profile_from_event_id", list_name="event_ids",
|
||||
)
|
||||
def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
async def _get_joined_profiles_from_event_ids(self, event_ids: Iterable[str]):
|
||||
"""For given set of member event_ids check if they point to a join
|
||||
event and if so return the associated user and profile info.
|
||||
|
||||
@@ -601,11 +596,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
event_ids: The member event IDs to lookup
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
|
||||
dict[str, Tuple[str, ProfileInfo]|None]: Map from event ID
|
||||
to `user_id` and ProfileInfo (or None if not join event).
|
||||
"""
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="room_memberships",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
||||
@@ -0,0 +1,17 @@
|
||||
/* Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
-- This table is no longer used.
|
||||
DROP TABLE IF EXISTS presence_allow_inbound;
|
||||
@@ -273,12 +273,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
cached_method_name="_get_state_group_for_event",
|
||||
list_name="event_ids",
|
||||
num_args=1,
|
||||
inlineCallbacks=True,
|
||||
)
|
||||
def _get_state_group_for_events(self, event_ids):
|
||||
async def _get_state_group_for_events(self, event_ids):
|
||||
"""Returns mapping event_id -> state_group
|
||||
"""
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="event_to_state_groups",
|
||||
column="event_id",
|
||||
iterable=event_ids,
|
||||
|
||||
@@ -38,10 +38,8 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
desc="is_user_erased",
|
||||
).addCallback(operator.truth)
|
||||
|
||||
@cachedList(
|
||||
cached_method_name="is_user_erased", list_name="user_ids", inlineCallbacks=True
|
||||
)
|
||||
def are_users_erased(self, user_ids):
|
||||
@cachedList(cached_method_name="is_user_erased", list_name="user_ids")
|
||||
async def are_users_erased(self, user_ids):
|
||||
"""
|
||||
Checks which users in a list have requested erasure
|
||||
|
||||
@@ -49,14 +47,14 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
user_ids (iterable[str]): full user id to check
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, bool]]:
|
||||
dict[str, bool]:
|
||||
for each user, whether the user has requested erasure.
|
||||
"""
|
||||
# this serves the dual purpose of (a) making sure we can do len and
|
||||
# iterate it multiple times, and (b) avoiding duplicates.
|
||||
user_ids = tuple(set(user_ids))
|
||||
|
||||
rows = yield self.db_pool.simple_select_many_batch(
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="erased_users",
|
||||
column="user_id",
|
||||
iterable=user_ids,
|
||||
@@ -65,8 +63,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
)
|
||||
erased_users = {row["user_id"] for row in rows}
|
||||
|
||||
res = {u: u in erased_users for u in user_ids}
|
||||
return res
|
||||
return {u: u in erased_users for u in user_ids}
|
||||
|
||||
|
||||
class UserErasureStore(UserErasureWorkerStore):
|
||||
|
||||
3
tox.ini
3
tox.ini
@@ -170,7 +170,7 @@ commands=
|
||||
skip_install = True
|
||||
deps =
|
||||
{[base]deps}
|
||||
mypy==0.750
|
||||
mypy==0.782
|
||||
mypy-zope
|
||||
env =
|
||||
MYPYPATH = stubs/
|
||||
@@ -191,6 +191,7 @@ commands = mypy \
|
||||
synapse/handlers/message.py \
|
||||
synapse/handlers/oidc_handler.py \
|
||||
synapse/handlers/presence.py \
|
||||
synapse/handlers/room.py \
|
||||
synapse/handlers/room_member.py \
|
||||
synapse/handlers/room_member_worker.py \
|
||||
synapse/handlers/saml_handler.py \
|
||||
|
||||
Reference in New Issue
Block a user