Merge commit '5c03134d0' into anoa/dinsic_release_1_21_x
* commit '5c03134d0': Convert additional database code to async/await. (#8195) Define StateMap as immutable and add a MutableStateMap type. (#8183) Move and refactor LoginRestServlet helper methods (#8182)
This commit is contained in:
1
changelog.d/8182.misc
Normal file
1
changelog.d/8182.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse.
|
||||
1
changelog.d/8183.misc
Normal file
1
changelog.d/8183.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to `synapse.state`.
|
||||
1
changelog.d/8195.misc
Normal file
1
changelog.d/8195.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
@@ -14,11 +14,16 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.types import GroupID, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -35,19 +40,19 @@ class AppServiceTransaction(object):
|
||||
self.id = id
|
||||
self.events = events
|
||||
|
||||
def send(self, as_api):
|
||||
async def send(self, as_api: ApplicationServiceApi) -> bool:
|
||||
"""Sends this transaction using the provided AS API interface.
|
||||
|
||||
Args:
|
||||
as_api(ApplicationServiceApi): The API to use to send.
|
||||
as_api: The API to use to send.
|
||||
Returns:
|
||||
An Awaitable which resolves to True if the transaction was sent.
|
||||
True if the transaction was sent.
|
||||
"""
|
||||
return as_api.push_bulk(
|
||||
return await as_api.push_bulk(
|
||||
service=self.service, events=self.events, txn_id=self.id
|
||||
)
|
||||
|
||||
def complete(self, store):
|
||||
async def complete(self, store: "DataStore") -> None:
|
||||
"""Completes this transaction as successful.
|
||||
|
||||
Marks this transaction ID on the application service and removes the
|
||||
@@ -55,10 +60,8 @@ class AppServiceTransaction(object):
|
||||
|
||||
Args:
|
||||
store: The database store to operate on.
|
||||
Returns:
|
||||
A Deferred which resolves to True if the transaction was completed.
|
||||
"""
|
||||
return store.complete_appservice_txn(service=self.service, txn_id=self.id)
|
||||
await store.complete_appservice_txn(service=self.service, txn_id=self.id)
|
||||
|
||||
|
||||
class ApplicationService(object):
|
||||
|
||||
@@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from synapse.federation.units import Transaction
|
||||
from synapse.logging.utils import log_function
|
||||
@@ -36,25 +37,27 @@ class TransactionActions(object):
|
||||
self.store = datastore
|
||||
|
||||
@log_function
|
||||
def have_responded(self, origin, transaction):
|
||||
""" Have we already responded to a transaction with the same id and
|
||||
async def have_responded(
|
||||
self, origin: str, transaction: Transaction
|
||||
) -> Optional[Tuple[int, JsonDict]]:
|
||||
"""Have we already responded to a transaction with the same id and
|
||||
origin?
|
||||
|
||||
Returns:
|
||||
Deferred: Results in `None` if we have not previously responded to
|
||||
this transaction or a 2-tuple of `(int, dict)` representing the
|
||||
response code and response body.
|
||||
`None` if we have not previously responded to this transaction or a
|
||||
2-tuple of `(int, dict)` representing the response code and response body.
|
||||
"""
|
||||
if not transaction.transaction_id:
|
||||
transaction_id = transaction.transaction_id # type: ignore
|
||||
if not transaction_id:
|
||||
raise RuntimeError("Cannot persist a transaction with no transaction_id")
|
||||
|
||||
return self.store.get_received_txn_response(transaction.transaction_id, origin)
|
||||
return await self.store.get_received_txn_response(transaction_id, origin)
|
||||
|
||||
@log_function
|
||||
async def set_response(
|
||||
self, origin: str, transaction: Transaction, code: int, response: JsonDict
|
||||
) -> None:
|
||||
""" Persist how we responded to a transaction.
|
||||
"""Persist how we responded to a transaction.
|
||||
"""
|
||||
transaction_id = transaction.transaction_id # type: ignore
|
||||
if not transaction_id:
|
||||
|
||||
@@ -42,8 +42,9 @@ from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import defer_to_thread
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.types import Requester, UserID
|
||||
from synapse.types import JsonDict, Requester, UserID
|
||||
from synapse.util import stringutils as stringutils
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
from ._base import BaseHandler
|
||||
@@ -51,6 +52,91 @@ from ._base import BaseHandler
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_client_dict_legacy_fields_to_identifier(
|
||||
submission: JsonDict,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Convert a legacy-formatted login submission to an identifier dict.
|
||||
|
||||
Legacy login submissions (used in both login and user-interactive authentication)
|
||||
provide user-identifying information at the top-level instead.
|
||||
|
||||
These are now deprecated and replaced with identifiers:
|
||||
https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
|
||||
|
||||
Args:
|
||||
submission: The client dict to convert
|
||||
|
||||
Returns:
|
||||
The matching identifier dict
|
||||
|
||||
Raises:
|
||||
SynapseError: If the format of the client dict is invalid
|
||||
"""
|
||||
identifier = submission.get("identifier", {})
|
||||
|
||||
# Generate an m.id.user identifier if "user" parameter is present
|
||||
user = submission.get("user")
|
||||
if user:
|
||||
identifier = {"type": "m.id.user", "user": user}
|
||||
|
||||
# Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present
|
||||
medium = submission.get("medium")
|
||||
address = submission.get("address")
|
||||
if medium and address:
|
||||
identifier = {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": medium,
|
||||
"address": address,
|
||||
}
|
||||
|
||||
# We've converted valid, legacy login submissions to an identifier. If the
|
||||
# submission still doesn't have an identifier, it's invalid
|
||||
if not identifier:
|
||||
raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM)
|
||||
|
||||
# Ensure the identifier has a type
|
||||
if "type" not in identifier:
|
||||
raise SynapseError(
|
||||
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
|
||||
)
|
||||
|
||||
return identifier
|
||||
|
||||
|
||||
def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
|
||||
"""
|
||||
Convert a phone login identifier type to a generic threepid identifier.
|
||||
|
||||
Args:
|
||||
identifier: Login identifier dict of type 'm.id.phone'
|
||||
|
||||
Returns:
|
||||
An equivalent m.id.thirdparty identifier dict
|
||||
"""
|
||||
if "country" not in identifier or (
|
||||
# The specification requires a "phone" field, while Synapse used to require a "number"
|
||||
# field. Accept both for backwards compatibility.
|
||||
"phone" not in identifier
|
||||
and "number" not in identifier
|
||||
):
|
||||
raise SynapseError(
|
||||
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
|
||||
)
|
||||
|
||||
# Accept both "phone" and "number" as valid keys in m.id.phone
|
||||
phone_number = identifier.get("phone", identifier["number"])
|
||||
|
||||
# Convert user-provided phone number to a consistent representation
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
|
||||
|
||||
return {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": "msisdn",
|
||||
"address": msisdn,
|
||||
}
|
||||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
|
||||
|
||||
|
||||
@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
|
||||
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
|
||||
from synapse.state import StateResolutionStore, resolve_events_with_store
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
MutableStateMap,
|
||||
StateMap,
|
||||
UserID,
|
||||
get_domain_from_id,
|
||||
)
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.distributor import user_joined_room
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
@@ -96,7 +102,7 @@ class _NewEventInfo:
|
||||
|
||||
event = attr.ib(type=EventBase)
|
||||
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
|
||||
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
|
||||
auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
|
||||
|
||||
|
||||
class FederationHandler(BaseHandler):
|
||||
@@ -1883,8 +1889,8 @@ class FederationHandler(BaseHandler):
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_min_depth_for_context(self, context):
|
||||
return self.store.get_min_depth(context)
|
||||
async def get_min_depth_for_context(self, context):
|
||||
return await self.store.get_min_depth(context)
|
||||
|
||||
async def _handle_new_event(
|
||||
self, origin, event, state=None, auth_events=None, backfilled=False
|
||||
@@ -2063,7 +2069,7 @@ class FederationHandler(BaseHandler):
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
state: Optional[Iterable[EventBase]],
|
||||
auth_events: Optional[StateMap[EventBase]],
|
||||
auth_events: Optional[MutableStateMap[EventBase]],
|
||||
backfilled: bool,
|
||||
) -> EventContext:
|
||||
context = await self.state_handler.compute_event_context(event, old_state=state)
|
||||
@@ -2147,7 +2153,9 @@ class FederationHandler(BaseHandler):
|
||||
current_states = await self.state_handler.resolve_events(
|
||||
room_version, state_sets, event
|
||||
)
|
||||
current_state_ids = {k: e.event_id for k, e in current_states.items()}
|
||||
current_state_ids = {
|
||||
k: e.event_id for k, e in current_states.items()
|
||||
} # type: StateMap[str]
|
||||
else:
|
||||
current_state_ids = await self.state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
@@ -2233,7 +2241,7 @@ class FederationHandler(BaseHandler):
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
auth_events: StateMap[EventBase],
|
||||
auth_events: MutableStateMap[EventBase],
|
||||
) -> EventContext:
|
||||
"""
|
||||
|
||||
@@ -2284,7 +2292,7 @@ class FederationHandler(BaseHandler):
|
||||
origin: str,
|
||||
event: EventBase,
|
||||
context: EventContext,
|
||||
auth_events: StateMap[EventBase],
|
||||
auth_events: MutableStateMap[EventBase],
|
||||
) -> EventContext:
|
||||
"""Helper for do_auth. See there for docs.
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
MutableStateMap,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
@@ -843,7 +844,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
room_id: str,
|
||||
preset_config: str,
|
||||
invite_list: List[str],
|
||||
initial_state: StateMap,
|
||||
initial_state: MutableStateMap,
|
||||
creation_content: JsonDict,
|
||||
room_alias: Optional[RoomAlias] = None,
|
||||
power_level_content_override: Optional[JsonDict] = None,
|
||||
|
||||
@@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
JsonDict,
|
||||
MutableStateMap,
|
||||
RoomStreamToken,
|
||||
StateMap,
|
||||
StreamToken,
|
||||
@@ -588,7 +589,7 @@ class SyncHandler(object):
|
||||
room_id: str,
|
||||
sync_config: SyncConfig,
|
||||
batch: TimelineBatch,
|
||||
state: StateMap[EventBase],
|
||||
state: MutableStateMap[EventBase],
|
||||
now_token: StreamToken,
|
||||
) -> Optional[JsonDict]:
|
||||
""" Works out a room summary block for this room, summarising the number
|
||||
@@ -736,7 +737,7 @@ class SyncHandler(object):
|
||||
since_token: Optional[StreamToken],
|
||||
now_token: StreamToken,
|
||||
full_state: bool,
|
||||
) -> StateMap[EventBase]:
|
||||
) -> MutableStateMap[EventBase]:
|
||||
""" Works out the difference in state between the start of the timeline
|
||||
and the previous sync.
|
||||
|
||||
|
||||
@@ -18,6 +18,10 @@ from typing import Awaitable, Callable, Dict, Optional
|
||||
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.api.ratelimiting import Ratelimiter
|
||||
from synapse.handlers.auth import (
|
||||
convert_client_dict_legacy_fields_to_identifier,
|
||||
login_id_phone_to_thirdparty,
|
||||
)
|
||||
from synapse.http.server import finish_request
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
@@ -28,56 +32,11 @@ from synapse.http.site import SynapseRequest
|
||||
from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.rest.well_known import WellKnownBuilder
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import canonicalise_email
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def login_submission_legacy_convert(submission):
|
||||
"""
|
||||
If the input login submission is an old style object
|
||||
(ie. with top-level user / medium / address) convert it
|
||||
to a typed object.
|
||||
"""
|
||||
if "user" in submission:
|
||||
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
|
||||
del submission["user"]
|
||||
|
||||
if "medium" in submission and "address" in submission:
|
||||
submission["identifier"] = {
|
||||
"type": "m.id.thirdparty",
|
||||
"medium": submission["medium"],
|
||||
"address": submission["address"],
|
||||
}
|
||||
del submission["medium"]
|
||||
del submission["address"]
|
||||
|
||||
|
||||
def login_id_thirdparty_from_phone(identifier):
|
||||
"""
|
||||
Convert a phone login identifier type to a generic threepid identifier
|
||||
Args:
|
||||
identifier(dict): Login identifier dict of type 'm.id.phone'
|
||||
|
||||
Returns: Login identifier dict of type 'm.id.threepid'
|
||||
"""
|
||||
if "country" not in identifier or (
|
||||
# The specification requires a "phone" field, while Synapse used to require a "number"
|
||||
# field. Accept both for backwards compatibility.
|
||||
"phone" not in identifier
|
||||
and "number" not in identifier
|
||||
):
|
||||
raise SynapseError(400, "Invalid phone-type identifier")
|
||||
|
||||
# Accept both "phone" and "number" as valid keys in m.id.phone
|
||||
phone_number = identifier.get("phone", identifier["number"])
|
||||
|
||||
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
|
||||
|
||||
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
|
||||
|
||||
|
||||
class LoginRestServlet(RestServlet):
|
||||
PATTERNS = client_patterns("/login$", v1=True)
|
||||
CAS_TYPE = "m.login.cas"
|
||||
@@ -194,18 +153,11 @@ class LoginRestServlet(RestServlet):
|
||||
login_submission.get("address"),
|
||||
login_submission.get("user"),
|
||||
)
|
||||
login_submission_legacy_convert(login_submission)
|
||||
|
||||
if "identifier" not in login_submission:
|
||||
raise SynapseError(400, "Missing param: identifier")
|
||||
|
||||
identifier = login_submission["identifier"]
|
||||
if "type" not in identifier:
|
||||
raise SynapseError(400, "Login identifier has no type")
|
||||
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
|
||||
|
||||
# convert phone type identifiers to generic threepids
|
||||
if identifier["type"] == "m.id.phone":
|
||||
identifier = login_id_thirdparty_from_phone(identifier)
|
||||
identifier = login_id_phone_to_thirdparty(identifier)
|
||||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
|
||||
@@ -25,6 +25,7 @@ from typing import (
|
||||
Sequence,
|
||||
Set,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
|
||||
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
|
||||
from synapse.state import v1, v2
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.roommember import ProfileInfo
|
||||
from synapse.types import Collection, StateMap
|
||||
from synapse.types import Collection, MutableStateMap, StateMap
|
||||
from synapse.util import Clock
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -205,7 +206,7 @@ class StateHandler(object):
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return dict(ret.state)
|
||||
return ret.state
|
||||
|
||||
async def get_current_users_in_room(
|
||||
self, room_id: str, latest_event_ids: Optional[List[str]] = None
|
||||
@@ -302,7 +303,7 @@ class StateHandler(object):
|
||||
# if we're given the state before the event, then we use that
|
||||
state_ids_before_event = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
} # type: StateMap[str]
|
||||
state_group_before_event = None
|
||||
state_group_before_event_prev_group = None
|
||||
deltas_to_state_group_before_event = None
|
||||
@@ -315,7 +316,7 @@ class StateHandler(object):
|
||||
event.room_id, event.prev_event_ids()
|
||||
)
|
||||
|
||||
state_ids_before_event = dict(entry.state)
|
||||
state_ids_before_event = entry.state
|
||||
state_group_before_event = entry.state_group
|
||||
state_group_before_event_prev_group = entry.prev_group
|
||||
deltas_to_state_group_before_event = entry.delta_ids
|
||||
@@ -540,7 +541,7 @@ class StateResolutionHandler(object):
|
||||
#
|
||||
# XXX: is this actually worthwhile, or should we just let
|
||||
# resolve_events_with_store do it?
|
||||
new_state = {}
|
||||
new_state = {} # type: MutableStateMap[str]
|
||||
conflicted_state = False
|
||||
for st in state_groups_ids.values():
|
||||
for key, e_id in st.items():
|
||||
@@ -554,13 +555,20 @@ class StateResolutionHandler(object):
|
||||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
list(state_groups_ids.values()),
|
||||
event_map=event_map,
|
||||
state_res_store=state_res_store,
|
||||
# resolve_events_with_store returns a StateMap, but we can
|
||||
# treat it as a MutableStateMap as it is above. It isn't
|
||||
# actually mutated anymore (and is frozen in
|
||||
# _make_state_cache_entry below).
|
||||
new_state = cast(
|
||||
MutableStateMap,
|
||||
await resolve_events_with_store(
|
||||
self.clock,
|
||||
room_id,
|
||||
room_version,
|
||||
list(state_groups_ids.values()),
|
||||
event_map=event_map,
|
||||
state_res_store=state_res_store,
|
||||
),
|
||||
)
|
||||
|
||||
# if the new state matches any of the input state groups, we can
|
||||
|
||||
@@ -32,7 +32,7 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -131,7 +131,7 @@ async def resolve_events_with_store(
|
||||
|
||||
def _seperate(
|
||||
state_sets: Iterable[StateMap[str]],
|
||||
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
|
||||
) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
@@ -152,7 +152,7 @@ def _seperate(
|
||||
"""
|
||||
state_set_iterator = iter(state_sets)
|
||||
unconflicted_state = dict(next(state_set_iterator))
|
||||
conflicted_state = {} # type: StateMap[Set[str]]
|
||||
conflicted_state = {} # type: MutableStateMap[Set[str]]
|
||||
|
||||
for state_set in state_set_iterator:
|
||||
for key, value in state_set.items():
|
||||
@@ -208,7 +208,7 @@ def _create_auth_events_from_maps(
|
||||
|
||||
|
||||
def _resolve_with_state(
|
||||
unconflicted_state_ids: StateMap[str],
|
||||
unconflicted_state_ids: MutableStateMap[str],
|
||||
conflicted_state_ids: StateMap[Set[str]],
|
||||
auth_event_ids: StateMap[str],
|
||||
state_map: Dict[str, EventBase],
|
||||
@@ -241,7 +241,7 @@ def _resolve_with_state(
|
||||
|
||||
|
||||
def _resolve_state_events(
|
||||
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
|
||||
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
|
||||
) -> StateMap[EventBase]:
|
||||
""" This is where we actually decide which of the conflicted state to
|
||||
use.
|
||||
|
||||
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.types import StateMap
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
from synapse.util import Clock
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -414,7 +414,7 @@ async def _iterative_auth_checks(
|
||||
base_state: StateMap[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
state_res_store: "synapse.state.StateResolutionStore",
|
||||
) -> StateMap[str]:
|
||||
) -> MutableStateMap[str]:
|
||||
"""Sequentially apply auth checks to each event in given list, updating the
|
||||
state as it goes along.
|
||||
|
||||
@@ -430,7 +430,7 @@ async def _iterative_auth_checks(
|
||||
Returns:
|
||||
Returns the final updated state
|
||||
"""
|
||||
resolved_state = base_state.copy()
|
||||
resolved_state = dict(base_state)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
for idx, event_id in enumerate(event_ids, start=1):
|
||||
|
||||
@@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
"application_services_state", {"as_id": service.id}, {"state": state}
|
||||
)
|
||||
|
||||
def create_appservice_txn(self, service, events):
|
||||
async def create_appservice_txn(self, service, events):
|
||||
"""Atomically creates a new transaction for this application service
|
||||
with the given list of events.
|
||||
|
||||
@@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
)
|
||||
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"create_appservice_txn", _create_appservice_txn
|
||||
)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
async def complete_appservice_txn(self, txn_id, service) -> None:
|
||||
"""Completes an application service transaction.
|
||||
|
||||
Args:
|
||||
txn_id(str): The transaction ID being completed.
|
||||
service(ApplicationService): The application service which was sent
|
||||
this transaction.
|
||||
Returns:
|
||||
A Deferred which resolves if this transaction was stored
|
||||
successfully.
|
||||
"""
|
||||
txn_id = int(txn_id)
|
||||
|
||||
@@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
{"txn_id": txn_id, "as_id": service.id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"complete_appservice_txn", _complete_appservice_txn
|
||||
)
|
||||
|
||||
@@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
|
||||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
def set_appservice_last_pos(self, pos):
|
||||
async def set_appservice_last_pos(self, pos) -> None:
|
||||
def set_appservice_last_pos_txn(txn):
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@trace
|
||||
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
|
||||
async def delete_device_msgs_for_remote(
|
||||
self, destination: str, up_to_stream_id: int
|
||||
) -> None:
|
||||
"""Used to delete messages when the remote destination acknowledges
|
||||
their receipt.
|
||||
|
||||
Args:
|
||||
destination(str): The destination server_name
|
||||
up_to_stream_id(int): Where to delete messages up to.
|
||||
Returns:
|
||||
A deferred that resolves when the messages have been deleted.
|
||||
destination: The destination server_name
|
||||
up_to_stream_id: Where to delete messages up to.
|
||||
"""
|
||||
|
||||
def delete_messages_for_remote_destination_txn(txn):
|
||||
@@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
|
||||
)
|
||||
txn.execute(sql, (destination, up_to_stream_id))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
return sessions
|
||||
|
||||
def get_e2e_room_keys_multi(self, user_id, version, room_keys):
|
||||
async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
|
||||
"""Get multiple room keys at a time. The difference between this function and
|
||||
get_e2e_room_keys is that this function can be used to retrieve
|
||||
multiple specific keys at a time, whereas get_e2e_room_keys is used for
|
||||
@@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
that we want to query
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
|
||||
dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_e2e_room_keys_multi",
|
||||
self._get_e2e_room_keys_multi_txn,
|
||||
user_id,
|
||||
@@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
raise StoreError(404, "No current backup version")
|
||||
return row[0]
|
||||
|
||||
def get_e2e_room_keys_version_info(self, user_id, version=None):
|
||||
async def get_e2e_room_keys_version_info(self, user_id, version=None):
|
||||
"""Get info metadata about a version of our room_keys backup.
|
||||
|
||||
Args:
|
||||
@@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
Raises:
|
||||
StoreError: with code 404 if there are no e2e_room_keys_versions present
|
||||
Returns:
|
||||
A deferred dict giving the info metadata for this backup version, with
|
||||
A dict giving the info metadata for this backup version, with
|
||||
fields including:
|
||||
version(str)
|
||||
algorithm(str)
|
||||
@@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
result["etag"] = 0
|
||||
return result
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
|
||||
)
|
||||
|
||||
@trace
|
||||
def create_e2e_room_keys_version(self, user_id, info):
|
||||
async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
|
||||
"""Atomically creates a new version of this user's e2e_room_keys store
|
||||
with the given version info.
|
||||
|
||||
@@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
info(dict): the info about the backup version to be created
|
||||
|
||||
Returns:
|
||||
A deferred string for the newly created version ID
|
||||
The newly created version ID
|
||||
"""
|
||||
|
||||
def _create_e2e_room_keys_version_txn(txn):
|
||||
@@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
return new_version
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
|
||||
)
|
||||
|
||||
@@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
)
|
||||
|
||||
@trace
|
||||
def delete_e2e_room_keys_version(self, user_id, version=None):
|
||||
async def delete_e2e_room_keys_version(
|
||||
self, user_id: str, version: Optional[str] = None
|
||||
) -> None:
|
||||
"""Delete a given backup version of the user's room keys.
|
||||
Doesn't delete their actual key data.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup version we're deleting
|
||||
version(str): Optional. the version ID of the backup version we're deleting
|
||||
user_id: the user whose backup version we're deleting
|
||||
version: Optional. the version ID of the backup version we're deleting
|
||||
If missing, we delete the current backup version info.
|
||||
Raises:
|
||||
StoreError: with code 404 if there are no e2e_room_keys_versions present,
|
||||
@@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
keyvalues={"user_id": user_id, "version": this_version},
|
||||
)
|
||||
|
||||
return self.db_pool.simple_update_one_txn(
|
||||
self.db_pool.simple_update_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={"user_id": user_id, "version": this_version},
|
||||
updatevalues={"deleted": 1},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
|
||||
)
|
||||
|
||||
@@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
include_given: include the given events in result
|
||||
|
||||
Returns:
|
||||
list of event_ids
|
||||
An awaitable which resolve to a list of event_ids
|
||||
"""
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_ids",
|
||||
@@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
|
||||
return list(results)
|
||||
|
||||
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
|
||||
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
|
||||
"""Given sets of state events figure out the auth chain difference (as
|
||||
per state res v2 algorithm).
|
||||
|
||||
@@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
chain.
|
||||
|
||||
Returns:
|
||||
Deferred[Set[str]]
|
||||
The set of the difference in auth chains.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_auth_chain_difference",
|
||||
self._get_auth_chain_difference_txn,
|
||||
state_sets,
|
||||
@@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
# Return all events where not all sets can reach them.
|
||||
return {eid for eid, n in event_to_missing_sets.items() if n}
|
||||
|
||||
def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
async def get_oldest_events_with_depth_in_room(self, room_id):
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_oldest_events_with_depth_in_room",
|
||||
self.get_oldest_events_with_depth_in_room_txn,
|
||||
room_id,
|
||||
@@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
else:
|
||||
return max(row["depth"] for row in rows)
|
||||
|
||||
def get_prev_events_for_room(self, room_id: str):
|
||||
async def get_prev_events_for_room(self, room_id: str) -> List[str]:
|
||||
"""
|
||||
Gets a subset of the current forward extremities in the given room.
|
||||
|
||||
@@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
events which refer to hundreds of prev_events.
|
||||
|
||||
Args:
|
||||
room_id (str): room_id
|
||||
room_id: room_id
|
||||
|
||||
Returns:
|
||||
Deferred[List[str]]: the event ids of the forward extremites
|
||||
The event ids of the forward extremities.
|
||||
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
|
||||
)
|
||||
|
||||
@@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
|
||||
return [row[0] for row in txn]
|
||||
|
||||
def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
|
||||
async def get_rooms_with_many_extremities(
|
||||
self, min_count: int, limit: int, room_id_filter: Iterable[str]
|
||||
) -> List[str]:
|
||||
"""Get the top rooms with at least N extremities.
|
||||
|
||||
Args:
|
||||
min_count (int): The minimum number of extremities
|
||||
limit (int): The maximum number of rooms to return.
|
||||
room_id_filter (iterable[str]): room_ids to exclude from the results
|
||||
min_count: The minimum number of extremities
|
||||
limit: The maximum number of rooms to return.
|
||||
room_id_filter: room_ids to exclude from the results
|
||||
|
||||
Returns:
|
||||
Deferred[list]: At most `limit` room IDs that have at least
|
||||
`min_count` extremities, sorted by extremity count.
|
||||
At most `limit` room IDs that have at least `min_count` extremities,
|
||||
sorted by extremity count.
|
||||
"""
|
||||
|
||||
def _get_rooms_with_many_extremities_txn(txn):
|
||||
@@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
txn.execute(sql, query_args)
|
||||
return [room_id for room_id, in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
|
||||
)
|
||||
|
||||
@@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
desc="get_latest_event_ids_in_room",
|
||||
)
|
||||
|
||||
def get_min_depth(self, room_id):
|
||||
""" For hte given room, get the minimum depth we have seen for it.
|
||||
async def get_min_depth(self, room_id: str) -> int:
|
||||
"""For the given room, get the minimum depth we have seen for it.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_min_depth", self._get_min_depth_interaction, room_id
|
||||
)
|
||||
|
||||
@@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
|
||||
return int(min_depth) if min_depth is not None else None
|
||||
|
||||
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||
async def get_forward_extremeties_for_room(
|
||||
self, room_id: str, stream_ordering: int
|
||||
) -> List[str]:
|
||||
"""For a given room_id and stream_ordering, return the forward
|
||||
extremeties of the room at that point in "time".
|
||||
|
||||
@@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
stream_orderings from that point.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
stream_ordering (int):
|
||||
room_id:
|
||||
stream_ordering:
|
||||
|
||||
Returns:
|
||||
deferred, which resolves to a list of event_ids
|
||||
A list of event_ids
|
||||
"""
|
||||
# We want to make the cache more effective, so we clamp to the last
|
||||
# change before the given ordering.
|
||||
@@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
if last_change > self.stream_ordering_month_ago:
|
||||
stream_ordering = min(last_change, stream_ordering)
|
||||
|
||||
return self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
||||
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
|
||||
|
||||
@cached(max_entries=5000, num_args=2)
|
||||
def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
|
||||
"""For a given room_id and stream_ordering, return the forward
|
||||
extremeties of the room at that point in "time".
|
||||
|
||||
@@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
txn.execute(sql, (stream_ordering, room_id))
|
||||
return [event_id for event_id, in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||
)
|
||||
|
||||
async def get_backfill_events(self, room_id, event_list, limit):
|
||||
async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
including) the events in event_list. Return a list of max size `limit`
|
||||
|
||||
Args:
|
||||
txn
|
||||
room_id (str)
|
||||
event_list (list)
|
||||
limit (int)
|
||||
room_id
|
||||
event_list
|
||||
limit
|
||||
"""
|
||||
event_ids = await self.db_pool.runInteraction(
|
||||
"get_backfill_events",
|
||||
@@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
|
||||
_delete_old_forward_extrem_cache_txn,
|
||||
)
|
||||
|
||||
def clean_room_for_join(self, room_id):
|
||||
return self.db_pool.runInteraction(
|
||||
async def clean_room_for_join(self, room_id):
|
||||
return await self.db_pool.runInteraction(
|
||||
"clean_room_for_join", self._clean_room_for_join_txn, room_id
|
||||
)
|
||||
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
@@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
desc="get_invited_users_in_group",
|
||||
)
|
||||
|
||||
def get_rooms_in_group(self, group_id: str, include_private: bool = False):
|
||||
async def get_rooms_in_group(
|
||||
self, group_id: str, include_private: bool = False
|
||||
) -> List[Dict[str, Union[str, bool]]]:
|
||||
"""Retrieve the rooms that belong to a given group. Does not return rooms that
|
||||
lack members.
|
||||
|
||||
@@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
include_private: Whether to return private rooms in results
|
||||
|
||||
Returns:
|
||||
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
|
||||
form of:
|
||||
A list of dictionaries, each in the form of:
|
||||
|
||||
{
|
||||
"room_id": "!a_room_id:example.com", # The ID of the room
|
||||
@@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
for room_id, is_public in txn
|
||||
]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_in_group", _get_rooms_in_group_txn
|
||||
)
|
||||
|
||||
def get_rooms_for_summary_by_category(
|
||||
async def get_rooms_for_summary_by_category(
|
||||
self, group_id: str, include_private: bool = False,
|
||||
):
|
||||
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||
"""Get the rooms and categories that should be included in a summary request
|
||||
|
||||
Args:
|
||||
@@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
include_private: Whether to return private rooms in results
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[List, Dict]]: A tuple containing:
|
||||
A tuple containing:
|
||||
|
||||
* A list of dictionaries with the keys:
|
||||
* "room_id": str, the room ID
|
||||
@@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
|
||||
return rooms, categories
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_rooms_for_summary", _get_rooms_for_summary_txn
|
||||
)
|
||||
|
||||
@@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
desc="get_local_groups_for_room",
|
||||
)
|
||||
|
||||
def get_users_for_summary_by_role(self, group_id, include_private=False):
|
||||
async def get_users_for_summary_by_role(self, group_id, include_private=False):
|
||||
"""Get the users and roles that should be included in a summary request
|
||||
|
||||
Returns ([users], [roles])
|
||||
Returns:
|
||||
([users], [roles])
|
||||
"""
|
||||
|
||||
def _get_users_for_summary_txn(txn):
|
||||
@@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
|
||||
return users, roles
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_for_summary_by_role", _get_users_for_summary_txn
|
||||
)
|
||||
|
||||
@@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def get_users_membership_info_in_group(self, group_id, user_id):
|
||||
async def get_users_membership_info_in_group(self, group_id, user_id):
|
||||
"""Get a dict describing the membership of a user in a group.
|
||||
|
||||
Example if joined:
|
||||
@@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
"is_privileged": False,
|
||||
}
|
||||
|
||||
Returns an empty dict if the user is not join/invite/etc
|
||||
Returns:
|
||||
An empty dict if the user is not join/invite/etc
|
||||
"""
|
||||
|
||||
def _get_users_membership_in_group_txn(txn):
|
||||
@@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
|
||||
return {}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
|
||||
)
|
||||
|
||||
@@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
desc="get_publicised_groups_for_user",
|
||||
)
|
||||
|
||||
def get_attestations_need_renewals(self, valid_until_ms):
|
||||
async def get_attestations_need_renewals(self, valid_until_ms):
|
||||
"""Get all attestations that need to be renewed until givent time
|
||||
"""
|
||||
|
||||
@@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (valid_until_ms,))
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
|
||||
)
|
||||
|
||||
@@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
desc="get_joined_groups",
|
||||
)
|
||||
|
||||
def get_all_groups_for_user(self, user_id, now_token):
|
||||
async def get_all_groups_for_user(self, user_id, now_token):
|
||||
def _get_all_groups_for_user_txn(txn):
|
||||
sql = """
|
||||
SELECT group_id, type, membership, u.content
|
||||
@@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
|
||||
for row in txn
|
||||
]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_groups_for_user", _get_all_groups_for_user_txn
|
||||
)
|
||||
|
||||
@@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
desc="set_group_join_policy",
|
||||
)
|
||||
|
||||
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
|
||||
return self.db_pool.runInteraction(
|
||||
async def add_room_to_summary(
|
||||
self,
|
||||
group_id: str,
|
||||
room_id: str,
|
||||
category_id: str,
|
||||
order: int,
|
||||
is_public: Optional[bool],
|
||||
) -> None:
|
||||
"""Add (or update) room's entry in summary.
|
||||
|
||||
Args:
|
||||
group_id
|
||||
room_id
|
||||
category_id: If not None then adds the category to the end of
|
||||
the summary if its not already there.
|
||||
order: If not None inserts the room at that position, e.g. an order
|
||||
of 1 will put the room first. Otherwise, the room gets added to
|
||||
the end.
|
||||
is_public
|
||||
"""
|
||||
await self.db_pool.runInteraction(
|
||||
"add_room_to_summary",
|
||||
self._add_room_to_summary_txn,
|
||||
group_id,
|
||||
@@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
)
|
||||
|
||||
def _add_room_to_summary_txn(
|
||||
self, txn, group_id, room_id, category_id, order, is_public
|
||||
):
|
||||
self,
|
||||
txn,
|
||||
group_id: str,
|
||||
room_id: str,
|
||||
category_id: str,
|
||||
order: int,
|
||||
is_public: Optional[bool],
|
||||
) -> None:
|
||||
"""Add (or update) room's entry in summary.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
room_id (str)
|
||||
category_id (str): If not None then adds the category to the end of
|
||||
the summary if its not already there. [Optional]
|
||||
order (int): If not None inserts the room at that position, e.g.
|
||||
an order of 1 will put the room first. Otherwise, the room gets
|
||||
added to the end.
|
||||
txn
|
||||
group_id
|
||||
room_id
|
||||
category_id: If not None then adds the category to the end of
|
||||
the summary if its not already there.
|
||||
order: If not None inserts the room at that position, e.g. an order
|
||||
of 1 will put the room first. Otherwise, the room gets added to
|
||||
the end.
|
||||
is_public
|
||||
"""
|
||||
room_in_group = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
@@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
desc="remove_group_role",
|
||||
)
|
||||
|
||||
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
|
||||
return self.db_pool.runInteraction(
|
||||
async def add_user_to_summary(
|
||||
self,
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
role_id: str,
|
||||
order: int,
|
||||
is_public: Optional[bool],
|
||||
) -> None:
|
||||
"""Add (or update) user's entry in summary.
|
||||
|
||||
Args:
|
||||
group_id
|
||||
user_id
|
||||
role_id: If not None then adds the role to the end of the summary if
|
||||
its not already there.
|
||||
order: If not None inserts the user at that position, e.g. an order
|
||||
of 1 will put the user first. Otherwise, the user gets added to
|
||||
the end.
|
||||
is_public
|
||||
"""
|
||||
await self.db_pool.runInteraction(
|
||||
"add_user_to_summary",
|
||||
self._add_user_to_summary_txn,
|
||||
group_id,
|
||||
@@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
)
|
||||
|
||||
def _add_user_to_summary_txn(
|
||||
self, txn, group_id, user_id, role_id, order, is_public
|
||||
self,
|
||||
txn,
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
role_id: str,
|
||||
order: int,
|
||||
is_public: Optional[bool],
|
||||
):
|
||||
"""Add (or update) user's entry in summary.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
user_id (str)
|
||||
role_id (str): If not None then adds the role to the end of
|
||||
the summary if its not already there. [Optional]
|
||||
order (int): If not None inserts the user at that position, e.g.
|
||||
an order of 1 will put the user first. Otherwise, the user gets
|
||||
added to the end.
|
||||
txn
|
||||
group_id
|
||||
user_id
|
||||
role_id: If not None then adds the role to the end of the summary if
|
||||
its not already there.
|
||||
order: If not None inserts the user at that position, e.g. an order
|
||||
of 1 will put the user first. Otherwise, the user gets added to
|
||||
the end.
|
||||
is_public
|
||||
"""
|
||||
user_in_group = self.db_pool.simple_select_one_onecol_txn(
|
||||
txn,
|
||||
@@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
desc="add_group_invite",
|
||||
)
|
||||
|
||||
def add_user_to_group(
|
||||
async def add_user_to_group(
|
||||
self,
|
||||
group_id,
|
||||
user_id,
|
||||
is_admin=False,
|
||||
is_public=True,
|
||||
local_attestation=None,
|
||||
remote_attestation=None,
|
||||
):
|
||||
group_id: str,
|
||||
user_id: str,
|
||||
is_admin: bool = False,
|
||||
is_public: bool = True,
|
||||
local_attestation: dict = None,
|
||||
remote_attestation: dict = None,
|
||||
) -> None:
|
||||
"""Add a user to the group server.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
user_id (str)
|
||||
is_admin (bool)
|
||||
is_public (bool)
|
||||
local_attestation (dict): The attestation the GS created to give
|
||||
to the remote server. Optional if the user and group are on the
|
||||
same server
|
||||
remote_attestation (dict): The attestation given to GS by remote
|
||||
group_id
|
||||
user_id
|
||||
is_admin
|
||||
is_public
|
||||
local_attestation: The attestation the GS created to give to the remote
|
||||
server. Optional if the user and group are on the same server
|
||||
remote_attestation: The attestation given to GS by remote server.
|
||||
Optional if the user and group are on the same server
|
||||
"""
|
||||
|
||||
def _add_user_to_group_txn(txn):
|
||||
@@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
|
||||
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
|
||||
|
||||
def remove_user_from_group(self, group_id, user_id):
|
||||
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
|
||||
def _remove_user_from_group_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
@@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
keyvalues={"group_id": group_id, "user_id": user_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_user_from_group", _remove_user_from_group_txn
|
||||
)
|
||||
|
||||
@@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
desc="update_room_in_group_visibility",
|
||||
)
|
||||
|
||||
def remove_room_from_group(self, group_id, room_id):
|
||||
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
|
||||
def _remove_room_from_group_txn(txn):
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn,
|
||||
@@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
keyvalues={"group_id": group_id, "room_id": room_id},
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"remove_room_from_group", _remove_room_from_group_txn
|
||||
)
|
||||
|
||||
@@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
def get_group_stream_token(self):
|
||||
return self._group_updates_id_gen.get_current_token()
|
||||
|
||||
def delete_group(self, group_id):
|
||||
async def delete_group(self, group_id: str) -> None:
|
||||
"""Deletes a group fully from the database.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
group_id: The group ID to delete.
|
||||
"""
|
||||
|
||||
def _delete_group_txn(txn):
|
||||
@@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
|
||||
txn, table=table, keyvalues={"group_id": group_id}
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction("delete_group", _delete_group_txn)
|
||||
await self.db_pool.runInteraction("delete_group", _delete_group_txn)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Iterable, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
|
||||
@@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
|
||||
@cachedList(
|
||||
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
|
||||
)
|
||||
def get_server_verify_keys(self, server_name_and_key_ids):
|
||||
async def get_server_verify_keys(
|
||||
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
|
||||
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
|
||||
"""
|
||||
Args:
|
||||
server_name_and_key_ids (iterable[Tuple[str, str]]):
|
||||
server_name_and_key_ids:
|
||||
iterable of (server_name, key-id) tuples to fetch keys for
|
||||
|
||||
Returns:
|
||||
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
|
||||
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
|
||||
unknown
|
||||
A map from (server_name, key_id) -> FetchKeyResult, or None if the
|
||||
key is unknown
|
||||
"""
|
||||
keys = {}
|
||||
|
||||
@@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
|
||||
_get_keys(txn, batch)
|
||||
return keys
|
||||
|
||||
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
|
||||
return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
|
||||
|
||||
async def store_server_verify_keys(
|
||||
self,
|
||||
@@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore):
|
||||
desc="store_server_keys_json",
|
||||
)
|
||||
|
||||
def get_server_keys_json(self, server_keys):
|
||||
async def get_server_keys_json(
|
||||
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
|
||||
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
|
||||
"""Retrive the key json for a list of server_keys and key ids.
|
||||
If no keys are found for a given server, key_id and source then
|
||||
that server, key_id, and source triplet entry will be an empty list.
|
||||
@@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore):
|
||||
Args:
|
||||
server_keys (list): List of (server_name, key_id, source) triplets.
|
||||
Returns:
|
||||
Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
|
||||
Dict mapping (server_name, key_id, source) triplets to lists of dicts
|
||||
A mapping from (server_name, key_id, source) triplets to a list of dicts
|
||||
"""
|
||||
|
||||
def _get_server_keys_json_txn(txn):
|
||||
@@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore):
|
||||
results[(server_name, key_id, from_server)] = rows
|
||||
return results
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_server_keys_json", _get_server_keys_json_txn
|
||||
)
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
@@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore):
|
||||
expiry_ms=5 * 60 * 1000,
|
||||
)
|
||||
|
||||
def get_received_txn_response(self, transaction_id, origin):
|
||||
async def get_received_txn_response(
|
||||
self, transaction_id: str, origin: str
|
||||
) -> Optional[Tuple[int, JsonDict]]:
|
||||
"""For an incoming transaction from a given origin, check if we have
|
||||
already responded to it. If so, return the response code and response
|
||||
body (as a dict).
|
||||
|
||||
Args:
|
||||
transaction_id (str)
|
||||
origin(str)
|
||||
transaction_id
|
||||
origin
|
||||
|
||||
Returns:
|
||||
tuple: None if we have not previously responded to
|
||||
this transaction or a 2-tuple of (int, dict)
|
||||
None if we have not previously responded to this transaction or a
|
||||
2-tuple of (int, dict)
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_received_txn_response",
|
||||
self._get_received_txn_response,
|
||||
transaction_id,
|
||||
@@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore):
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_destination_retry_timings(
|
||||
self, destination, failure_ts, retry_last_ts, retry_interval
|
||||
):
|
||||
async def set_destination_retry_timings(
|
||||
self,
|
||||
destination: str,
|
||||
failure_ts: Optional[int],
|
||||
retry_last_ts: int,
|
||||
retry_interval: int,
|
||||
) -> None:
|
||||
"""Sets the current retry timings for a given destination.
|
||||
Both timings should be zero if retrying is no longer occuring.
|
||||
|
||||
Args:
|
||||
destination (str)
|
||||
failure_ts (int|None) - when the server started failing (ms since epoch)
|
||||
retry_last_ts (int) - time of last retry attempt in unix epoch ms
|
||||
retry_interval (int) - how long until next retry in ms
|
||||
destination
|
||||
failure_ts: when the server started failing (ms since epoch)
|
||||
retry_last_ts: time of last retry attempt in unix epoch ms
|
||||
retry_interval: how long until next retry in ms
|
||||
"""
|
||||
|
||||
self._destination_retry_cache.pop(destination, None)
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"set_destination_retry_timings",
|
||||
self._set_destination_retry_timings,
|
||||
destination,
|
||||
@@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore):
|
||||
"cleanup_transactions", self._cleanup_transactions
|
||||
)
|
||||
|
||||
def _cleanup_transactions(self):
|
||||
async def _cleanup_transactions(self) -> None:
|
||||
now = self._clock.time_msec()
|
||||
month_ago = now - 30 * 24 * 60 * 60 * 1000
|
||||
|
||||
def _cleanup_transactions_txn(txn):
|
||||
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_cleanup_transactions", _cleanup_transactions_txn
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ import re
|
||||
import string
|
||||
import sys
|
||||
from collections import namedtuple
|
||||
from typing import Any, Dict, Tuple, Type, TypeVar
|
||||
from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
|
||||
|
||||
import attr
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
@@ -42,8 +42,9 @@ else:
|
||||
# Define a state map type from type/state_key to T (usually an event ID or
|
||||
# event)
|
||||
T = TypeVar("T")
|
||||
StateMap = Dict[Tuple[str, str], T]
|
||||
|
||||
StateKey = Tuple[str, str]
|
||||
StateMap = Mapping[StateKey, T]
|
||||
MutableStateMap = MutableMapping[StateKey, T]
|
||||
|
||||
# the type of a JSON-serialisable dict. This could be made stronger, but it will
|
||||
# do for now.
|
||||
|
||||
Reference in New Issue
Block a user