1
0

Merge commit '5c03134d0' into anoa/dinsic_release_1_21_x

* commit '5c03134d0':
  Convert additional database code to async/await. (#8195)
  Define StateMap as immutable and add a MutableStateMap type. (#8183)
  Move and refactor LoginRestServlet helper methods (#8182)
This commit is contained in:
Andrew Morgan
2020-10-20 17:42:11 +01:00
21 changed files with 392 additions and 262 deletions

1
changelog.d/8182.misc Normal file
View File

@@ -0,0 +1 @@
Refactor some of `LoginRestServlet`'s helper methods, and move them to `AuthHandler` for easier reuse.

1
changelog.d/8183.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to `synapse.state`.

1
changelog.d/8195.misc Normal file
View File

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

View File

@@ -14,11 +14,16 @@
# limitations under the License.
import logging
import re
from typing import TYPE_CHECKING
from synapse.api.constants import EventTypes
from synapse.appservice.api import ApplicationServiceApi
from synapse.types import GroupID, get_domain_from_id
from synapse.util.caches.descriptors import cached
if TYPE_CHECKING:
from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -35,19 +40,19 @@ class AppServiceTransaction(object):
self.id = id
self.events = events
def send(self, as_api):
async def send(self, as_api: ApplicationServiceApi) -> bool:
"""Sends this transaction using the provided AS API interface.
Args:
as_api(ApplicationServiceApi): The API to use to send.
as_api: The API to use to send.
Returns:
An Awaitable which resolves to True if the transaction was sent.
True if the transaction was sent.
"""
return as_api.push_bulk(
return await as_api.push_bulk(
service=self.service, events=self.events, txn_id=self.id
)
def complete(self, store):
async def complete(self, store: "DataStore") -> None:
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
@@ -55,10 +60,8 @@ class AppServiceTransaction(object):
Args:
store: The database store to operate on.
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(service=self.service, txn_id=self.id)
await store.complete_appservice_txn(service=self.service, txn_id=self.id)
class ApplicationService(object):

View File

@@ -20,6 +20,7 @@ These actions are mostly only used by the :py:mod:`.replication` module.
"""
import logging
from typing import Optional, Tuple
from synapse.federation.units import Transaction
from synapse.logging.utils import log_function
@@ -36,25 +37,27 @@ class TransactionActions(object):
self.store = datastore
@log_function
def have_responded(self, origin, transaction):
""" Have we already responded to a transaction with the same id and
async def have_responded(
self, origin: str, transaction: Transaction
) -> Optional[Tuple[int, JsonDict]]:
"""Have we already responded to a transaction with the same id and
origin?
Returns:
Deferred: Results in `None` if we have not previously responded to
this transaction or a 2-tuple of `(int, dict)` representing the
response code and response body.
`None` if we have not previously responded to this transaction or a
2-tuple of `(int, dict)` representing the response code and response body.
"""
if not transaction.transaction_id:
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:
raise RuntimeError("Cannot persist a transaction with no transaction_id")
return self.store.get_received_txn_response(transaction.transaction_id, origin)
return await self.store.get_received_txn_response(transaction_id, origin)
@log_function
async def set_response(
self, origin: str, transaction: Transaction, code: int, response: JsonDict
) -> None:
""" Persist how we responded to a transaction.
"""Persist how we responded to a transaction.
"""
transaction_id = transaction.transaction_id # type: ignore
if not transaction_id:

View File

@@ -42,8 +42,9 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import defer_to_thread
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.module_api import ModuleApi
from synapse.types import Requester, UserID
from synapse.types import JsonDict, Requester, UserID
from synapse.util import stringutils as stringutils
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
from ._base import BaseHandler
@@ -51,6 +52,91 @@ from ._base import BaseHandler
logger = logging.getLogger(__name__)
def convert_client_dict_legacy_fields_to_identifier(
submission: JsonDict,
) -> Dict[str, str]:
"""
Convert a legacy-formatted login submission to an identifier dict.
Legacy login submissions (used in both login and user-interactive authentication)
provide user-identifying information at the top-level instead.
These are now deprecated and replaced with identifiers:
https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types
Args:
submission: The client dict to convert
Returns:
The matching identifier dict
Raises:
SynapseError: If the format of the client dict is invalid
"""
identifier = submission.get("identifier", {})
# Generate an m.id.user identifier if "user" parameter is present
user = submission.get("user")
if user:
identifier = {"type": "m.id.user", "user": user}
# Generate an m.id.thirdparty identifier if "medium" and "address" parameters are present
medium = submission.get("medium")
address = submission.get("address")
if medium and address:
identifier = {
"type": "m.id.thirdparty",
"medium": medium,
"address": address,
}
# We've converted valid, legacy login submissions to an identifier. If the
# submission still doesn't have an identifier, it's invalid
if not identifier:
raise SynapseError(400, "Invalid login submission", Codes.INVALID_PARAM)
# Ensure the identifier has a type
if "type" not in identifier:
raise SynapseError(
400, "'identifier' dict has no key 'type'", errcode=Codes.MISSING_PARAM,
)
return identifier
def login_id_phone_to_thirdparty(identifier: JsonDict) -> Dict[str, str]:
"""
Convert a phone login identifier type to a generic threepid identifier.
Args:
identifier: Login identifier dict of type 'm.id.phone'
Returns:
An equivalent m.id.thirdparty identifier dict
"""
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(
400, "Invalid phone-type identifier", errcode=Codes.INVALID_PARAM
)
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier["number"])
# Convert user-provided phone number to a consistent representation
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {
"type": "m.id.thirdparty",
"medium": "msisdn",
"address": msisdn,
}
class AuthHandler(BaseHandler):
SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000

View File

@@ -72,7 +72,13 @@ from synapse.replication.http.federation import (
from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
from synapse.state import StateResolutionStore, resolve_events_with_store
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
from synapse.types import (
JsonDict,
MutableStateMap,
StateMap,
UserID,
get_domain_from_id,
)
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.distributor import user_joined_room
from synapse.util.retryutils import NotRetryingDestination
@@ -96,7 +102,7 @@ class _NewEventInfo:
event = attr.ib(type=EventBase)
state = attr.ib(type=Optional[Sequence[EventBase]], default=None)
auth_events = attr.ib(type=Optional[StateMap[EventBase]], default=None)
auth_events = attr.ib(type=Optional[MutableStateMap[EventBase]], default=None)
class FederationHandler(BaseHandler):
@@ -1883,8 +1889,8 @@ class FederationHandler(BaseHandler):
else:
return None
def get_min_depth_for_context(self, context):
return self.store.get_min_depth(context)
async def get_min_depth_for_context(self, context):
return await self.store.get_min_depth(context)
async def _handle_new_event(
self, origin, event, state=None, auth_events=None, backfilled=False
@@ -2063,7 +2069,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
auth_events: Optional[StateMap[EventBase]],
auth_events: Optional[MutableStateMap[EventBase]],
backfilled: bool,
) -> EventContext:
context = await self.state_handler.compute_event_context(event, old_state=state)
@@ -2147,7 +2153,9 @@ class FederationHandler(BaseHandler):
current_states = await self.state_handler.resolve_events(
room_version, state_sets, event
)
current_state_ids = {k: e.event_id for k, e in current_states.items()}
current_state_ids = {
k: e.event_id for k, e in current_states.items()
} # type: StateMap[str]
else:
current_state_ids = await self.state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
@@ -2233,7 +2241,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
auth_events: StateMap[EventBase],
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""
@@ -2284,7 +2292,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
context: EventContext,
auth_events: StateMap[EventBase],
auth_events: MutableStateMap[EventBase],
) -> EventContext:
"""Helper for do_auth. See there for docs.

View File

@@ -41,6 +41,7 @@ from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import (
JsonDict,
MutableStateMap,
Requester,
RoomAlias,
RoomID,
@@ -843,7 +844,7 @@ class RoomCreationHandler(BaseHandler):
room_id: str,
preset_config: str,
invite_list: List[str],
initial_state: StateMap,
initial_state: MutableStateMap,
creation_content: JsonDict,
room_alias: Optional[RoomAlias] = None,
power_level_content_override: Optional[JsonDict] = None,

View File

@@ -31,6 +31,7 @@ from synapse.storage.state import StateFilter
from synapse.types import (
Collection,
JsonDict,
MutableStateMap,
RoomStreamToken,
StateMap,
StreamToken,
@@ -588,7 +589,7 @@ class SyncHandler(object):
room_id: str,
sync_config: SyncConfig,
batch: TimelineBatch,
state: StateMap[EventBase],
state: MutableStateMap[EventBase],
now_token: StreamToken,
) -> Optional[JsonDict]:
""" Works out a room summary block for this room, summarising the number
@@ -736,7 +737,7 @@ class SyncHandler(object):
since_token: Optional[StreamToken],
now_token: StreamToken,
full_state: bool,
) -> StateMap[EventBase]:
) -> MutableStateMap[EventBase]:
""" Works out the difference in state between the start of the timeline
and the previous sync.

View File

@@ -18,6 +18,10 @@ from typing import Awaitable, Callable, Dict, Optional
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.api.ratelimiting import Ratelimiter
from synapse.handlers.auth import (
convert_client_dict_legacy_fields_to_identifier,
login_id_phone_to_thirdparty,
)
from synapse.http.server import finish_request
from synapse.http.servlet import (
RestServlet,
@@ -28,56 +32,11 @@ from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.rest.well_known import WellKnownBuilder
from synapse.types import JsonDict, UserID
from synapse.util.msisdn import phone_number_to_msisdn
from synapse.util.threepids import canonicalise_email
logger = logging.getLogger(__name__)
def login_submission_legacy_convert(submission):
"""
If the input login submission is an old style object
(ie. with top-level user / medium / address) convert it
to a typed object.
"""
if "user" in submission:
submission["identifier"] = {"type": "m.id.user", "user": submission["user"]}
del submission["user"]
if "medium" in submission and "address" in submission:
submission["identifier"] = {
"type": "m.id.thirdparty",
"medium": submission["medium"],
"address": submission["address"],
}
del submission["medium"]
del submission["address"]
def login_id_thirdparty_from_phone(identifier):
"""
Convert a phone login identifier type to a generic threepid identifier
Args:
identifier(dict): Login identifier dict of type 'm.id.phone'
Returns: Login identifier dict of type 'm.id.threepid'
"""
if "country" not in identifier or (
# The specification requires a "phone" field, while Synapse used to require a "number"
# field. Accept both for backwards compatibility.
"phone" not in identifier
and "number" not in identifier
):
raise SynapseError(400, "Invalid phone-type identifier")
# Accept both "phone" and "number" as valid keys in m.id.phone
phone_number = identifier.get("phone", identifier["number"])
msisdn = phone_number_to_msisdn(identifier["country"], phone_number)
return {"type": "m.id.thirdparty", "medium": "msisdn", "address": msisdn}
class LoginRestServlet(RestServlet):
PATTERNS = client_patterns("/login$", v1=True)
CAS_TYPE = "m.login.cas"
@@ -194,18 +153,11 @@ class LoginRestServlet(RestServlet):
login_submission.get("address"),
login_submission.get("user"),
)
login_submission_legacy_convert(login_submission)
if "identifier" not in login_submission:
raise SynapseError(400, "Missing param: identifier")
identifier = login_submission["identifier"]
if "type" not in identifier:
raise SynapseError(400, "Login identifier has no type")
identifier = convert_client_dict_legacy_fields_to_identifier(login_submission)
# convert phone type identifiers to generic threepids
if identifier["type"] == "m.id.phone":
identifier = login_id_thirdparty_from_phone(identifier)
identifier = login_id_phone_to_thirdparty(identifier)
# convert threepid identifiers to user IDs
if identifier["type"] == "m.id.thirdparty":

View File

@@ -25,6 +25,7 @@ from typing import (
Sequence,
Set,
Union,
cast,
overload,
)
@@ -41,7 +42,7 @@ from synapse.logging.utils import log_function
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.roommember import ProfileInfo
from synapse.types import Collection, StateMap
from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
@@ -205,7 +206,7 @@ class StateHandler(object):
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return dict(ret.state)
return ret.state
async def get_current_users_in_room(
self, room_id: str, latest_event_ids: Optional[List[str]] = None
@@ -302,7 +303,7 @@ class StateHandler(object):
# if we're given the state before the event, then we use that
state_ids_before_event = {
(s.type, s.state_key): s.event_id for s in old_state
}
} # type: StateMap[str]
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
@@ -315,7 +316,7 @@ class StateHandler(object):
event.room_id, event.prev_event_ids()
)
state_ids_before_event = dict(entry.state)
state_ids_before_event = entry.state
state_group_before_event = entry.state_group
state_group_before_event_prev_group = entry.prev_group
deltas_to_state_group_before_event = entry.delta_ids
@@ -540,7 +541,7 @@ class StateResolutionHandler(object):
#
# XXX: is this actually worthwhile, or should we just let
# resolve_events_with_store do it?
new_state = {}
new_state = {} # type: MutableStateMap[str]
conflicted_state = False
for st in state_groups_ids.values():
for key, e_id in st.items():
@@ -554,13 +555,20 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
new_state = await resolve_events_with_store(
self.clock,
room_id,
room_version,
list(state_groups_ids.values()),
event_map=event_map,
state_res_store=state_res_store,
# resolve_events_with_store returns a StateMap, but we can
# treat it as a MutableStateMap as it is above. It isn't
# actually mutated anymore (and is frozen in
# _make_state_cache_entry below).
new_state = cast(
MutableStateMap,
await resolve_events_with_store(
self.clock,
room_id,
room_version,
list(state_groups_ids.values()),
event_map=event_map,
state_res_store=state_res_store,
),
)
# if the new state matches any of the input state groups, we can

View File

@@ -32,7 +32,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
from synapse.types import StateMap
from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__)
@@ -131,7 +131,7 @@ async def resolve_events_with_store(
def _seperate(
state_sets: Iterable[StateMap[str]],
) -> Tuple[StateMap[str], StateMap[Set[str]]]:
) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
@@ -152,7 +152,7 @@ def _seperate(
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
conflicted_state = {} # type: StateMap[Set[str]]
conflicted_state = {} # type: MutableStateMap[Set[str]]
for state_set in state_set_iterator:
for key, value in state_set.items():
@@ -208,7 +208,7 @@ def _create_auth_events_from_maps(
def _resolve_with_state(
unconflicted_state_ids: StateMap[str],
unconflicted_state_ids: MutableStateMap[str],
conflicted_state_ids: StateMap[Set[str]],
auth_event_ids: StateMap[str],
state_map: Dict[str, EventBase],
@@ -241,7 +241,7 @@ def _resolve_with_state(
def _resolve_state_events(
conflicted_state: StateMap[List[EventBase]], auth_events: StateMap[EventBase]
conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.

View File

@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.types import StateMap
from synapse.types import MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -414,7 +414,7 @@ async def _iterative_auth_checks(
base_state: StateMap[str],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
) -> StateMap[str]:
) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
@@ -430,7 +430,7 @@ async def _iterative_auth_checks(
Returns:
Returns the final updated state
"""
resolved_state = base_state.copy()
resolved_state = dict(base_state)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for idx, event_id in enumerate(event_ids, start=1):

View File

@@ -172,7 +172,7 @@ class ApplicationServiceTransactionWorkerStore(
"application_services_state", {"as_id": service.id}, {"state": state}
)
def create_appservice_txn(self, service, events):
async def create_appservice_txn(self, service, events):
"""Atomically creates a new transaction for this application service
with the given list of events.
@@ -209,20 +209,17 @@ class ApplicationServiceTransactionWorkerStore(
)
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"create_appservice_txn", _create_appservice_txn
)
def complete_appservice_txn(self, txn_id, service):
async def complete_appservice_txn(self, txn_id, service) -> None:
"""Completes an application service transaction.
Args:
txn_id(str): The transaction ID being completed.
service(ApplicationService): The application service which was sent
this transaction.
Returns:
A Deferred which resolves if this transaction was stored
successfully.
"""
txn_id = int(txn_id)
@@ -258,7 +255,7 @@ class ApplicationServiceTransactionWorkerStore(
{"txn_id": txn_id, "as_id": service.id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"complete_appservice_txn", _complete_appservice_txn
)
@@ -312,13 +309,13 @@ class ApplicationServiceTransactionWorkerStore(
else:
return int(last_txn_id[0]) # select 'last_txn' col
def set_appservice_last_pos(self, pos):
async def set_appservice_last_pos(self, pos) -> None:
def set_appservice_last_pos_txn(txn):
txn.execute(
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"set_appservice_last_pos", set_appservice_last_pos_txn
)

View File

@@ -190,15 +190,15 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
@trace
def delete_device_msgs_for_remote(self, destination, up_to_stream_id):
async def delete_device_msgs_for_remote(
self, destination: str, up_to_stream_id: int
) -> None:
"""Used to delete messages when the remote destination acknowledges
their receipt.
Args:
destination(str): The destination server_name
up_to_stream_id(int): Where to delete messages up to.
Returns:
A deferred that resolves when the messages have been deleted.
destination: The destination server_name
up_to_stream_id: Where to delete messages up to.
"""
def delete_messages_for_remote_destination_txn(txn):
@@ -209,7 +209,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
txn.execute(sql, (destination, up_to_stream_id))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_device_msgs_for_remote", delete_messages_for_remote_destination_txn
)

View File

@@ -151,7 +151,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return sessions
def get_e2e_room_keys_multi(self, user_id, version, room_keys):
async def get_e2e_room_keys_multi(self, user_id, version, room_keys):
"""Get multiple room keys at a time. The difference between this function and
get_e2e_room_keys is that this function can be used to retrieve
multiple specific keys at a time, whereas get_e2e_room_keys is used for
@@ -166,10 +166,10 @@ class EndToEndRoomKeyStore(SQLBaseStore):
that we want to query
Returns:
Deferred[dict[str, dict[str, dict]]]: a map of room IDs to session IDs to room key
dict[str, dict[str, dict]]: a map of room IDs to session IDs to room key
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_room_keys_multi",
self._get_e2e_room_keys_multi_txn,
user_id,
@@ -283,7 +283,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
raise StoreError(404, "No current backup version")
return row[0]
def get_e2e_room_keys_version_info(self, user_id, version=None):
async def get_e2e_room_keys_version_info(self, user_id, version=None):
"""Get info metadata about a version of our room_keys backup.
Args:
@@ -293,7 +293,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present
Returns:
A deferred dict giving the info metadata for this backup version, with
A dict giving the info metadata for this backup version, with
fields including:
version(str)
algorithm(str)
@@ -324,12 +324,12 @@ class EndToEndRoomKeyStore(SQLBaseStore):
result["etag"] = 0
return result
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_e2e_room_keys_version_info", _get_e2e_room_keys_version_info_txn
)
@trace
def create_e2e_room_keys_version(self, user_id, info):
async def create_e2e_room_keys_version(self, user_id: str, info: dict) -> str:
"""Atomically creates a new version of this user's e2e_room_keys store
with the given version info.
@@ -338,7 +338,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
info(dict): the info about the backup version to be created
Returns:
A deferred string for the newly created version ID
The newly created version ID
"""
def _create_e2e_room_keys_version_txn(txn):
@@ -365,7 +365,7 @@ class EndToEndRoomKeyStore(SQLBaseStore):
return new_version
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
)
@@ -403,13 +403,15 @@ class EndToEndRoomKeyStore(SQLBaseStore):
)
@trace
def delete_e2e_room_keys_version(self, user_id, version=None):
async def delete_e2e_room_keys_version(
self, user_id: str, version: Optional[str] = None
) -> None:
"""Delete a given backup version of the user's room keys.
Doesn't delete their actual key data.
Args:
user_id(str): the user whose backup version we're deleting
version(str): Optional. the version ID of the backup version we're deleting
user_id: the user whose backup version we're deleting
version: Optional. the version ID of the backup version we're deleting
If missing, we delete the current backup version info.
Raises:
StoreError: with code 404 if there are no e2e_room_keys_versions present,
@@ -430,13 +432,13 @@ class EndToEndRoomKeyStore(SQLBaseStore):
keyvalues={"user_id": user_id, "version": this_version},
)
return self.db_pool.simple_update_one_txn(
self.db_pool.simple_update_one_txn(
txn,
table="e2e_room_keys_versions",
keyvalues={"user_id": user_id, "version": this_version},
updatevalues={"deleted": 1},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_e2e_room_keys_version", _delete_e2e_room_keys_version_txn
)

View File

@@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
include_given: include the given events in result
Returns:
list of event_ids
An awaitable which resolve to a list of event_ids
"""
return await self.db_pool.runInteraction(
"get_auth_chain_ids",
@@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return list(results)
def get_auth_chain_difference(self, state_sets: List[Set[str]]):
async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
chain.
Returns:
Deferred[Set[str]]
The set of the difference in auth chains.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
state_sets,
@@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
# Return all events where not all sets can reach them.
return {eid for eid, n in event_to_missing_sets.items() if n}
def get_oldest_events_with_depth_in_room(self, room_id):
return self.db_pool.runInteraction(
async def get_oldest_events_with_depth_in_room(self, room_id):
return await self.db_pool.runInteraction(
"get_oldest_events_with_depth_in_room",
self.get_oldest_events_with_depth_in_room_txn,
room_id,
@@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
return max(row["depth"] for row in rows)
def get_prev_events_for_room(self, room_id: str):
async def get_prev_events_for_room(self, room_id: str) -> List[str]:
"""
Gets a subset of the current forward extremities in the given room.
@@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
events which refer to hundreds of prev_events.
Args:
room_id (str): room_id
room_id: room_id
Returns:
Deferred[List[str]]: the event ids of the forward extremites
The event ids of the forward extremities.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
)
@@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return [row[0] for row in txn]
def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
async def get_rooms_with_many_extremities(
self, min_count: int, limit: int, room_id_filter: Iterable[str]
) -> List[str]:
"""Get the top rooms with at least N extremities.
Args:
min_count (int): The minimum number of extremities
limit (int): The maximum number of rooms to return.
room_id_filter (iterable[str]): room_ids to exclude from the results
min_count: The minimum number of extremities
limit: The maximum number of rooms to return.
room_id_filter: room_ids to exclude from the results
Returns:
Deferred[list]: At most `limit` room IDs that have at least
`min_count` extremities, sorted by extremity count.
At most `limit` room IDs that have at least `min_count` extremities,
sorted by extremity count.
"""
def _get_rooms_with_many_extremities_txn(txn):
@@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, query_args)
return [room_id for room_id, in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
)
@@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
desc="get_latest_event_ids_in_room",
)
def get_min_depth(self, room_id):
""" For hte given room, get the minimum depth we have seen for it.
async def get_min_depth(self, room_id: str) -> int:
"""For the given room, get the minimum depth we have seen for it.
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_min_depth", self._get_min_depth_interaction, room_id
)
@@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
return int(min_depth) if min_depth is not None else None
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
async def get_forward_extremeties_for_room(
self, room_id: str, stream_ordering: int
) -> List[str]:
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
stream_orderings from that point.
Args:
room_id (str):
stream_ordering (int):
room_id:
stream_ordering:
Returns:
deferred, which resolves to a list of event_ids
A list of event_ids
"""
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
@@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
if last_change > self.stream_ordering_month_ago:
stream_ordering = min(last_change, stream_ordering)
return self._get_forward_extremeties_for_room(room_id, stream_ordering)
return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
@cached(max_entries=5000, num_args=2)
def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
"""For a given room_id and stream_ordering, return the forward
extremeties of the room at that point in "time".
@@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
txn.execute(sql, (stream_ordering, room_id))
return [event_id for event_id, in txn]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
async def get_backfill_events(self, room_id, event_list, limit):
async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
Args:
txn
room_id (str)
event_list (list)
limit (int)
room_id
event_list
limit
"""
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
@@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
_delete_old_forward_extrem_cache_txn,
)
def clean_room_for_join(self, room_id):
return self.db_pool.runInteraction(
async def clean_room_for_join(self, room_id):
return await self.db_pool.runInteraction(
"clean_room_for_join", self._clean_room_for_join_txn, room_id
)

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
from synapse.api.errors import SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -70,7 +70,9 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_invited_users_in_group",
)
def get_rooms_in_group(self, group_id: str, include_private: bool = False):
async def get_rooms_in_group(
self, group_id: str, include_private: bool = False
) -> List[Dict[str, Union[str, bool]]]:
"""Retrieve the rooms that belong to a given group. Does not return rooms that
lack members.
@@ -79,8 +81,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
Deferred[List[Dict[str, str|bool]]]: A list of dictionaries, each in the
form of:
A list of dictionaries, each in the form of:
{
"room_id": "!a_room_id:example.com", # The ID of the room
@@ -117,13 +118,13 @@ class GroupServerWorkerStore(SQLBaseStore):
for room_id, is_public in txn
]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_in_group", _get_rooms_in_group_txn
)
def get_rooms_for_summary_by_category(
async def get_rooms_for_summary_by_category(
self, group_id: str, include_private: bool = False,
):
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
"""Get the rooms and categories that should be included in a summary request
Args:
@@ -131,7 +132,7 @@ class GroupServerWorkerStore(SQLBaseStore):
include_private: Whether to return private rooms in results
Returns:
Deferred[Tuple[List, Dict]]: A tuple containing:
A tuple containing:
* A list of dictionaries with the keys:
* "room_id": str, the room ID
@@ -207,7 +208,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return rooms, categories
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_rooms_for_summary", _get_rooms_for_summary_txn
)
@@ -281,10 +282,11 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_local_groups_for_room",
)
def get_users_for_summary_by_role(self, group_id, include_private=False):
async def get_users_for_summary_by_role(self, group_id, include_private=False):
"""Get the users and roles that should be included in a summary request
Returns ([users], [roles])
Returns:
([users], [roles])
"""
def _get_users_for_summary_txn(txn):
@@ -338,7 +340,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return users, roles
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_for_summary_by_role", _get_users_for_summary_txn
)
@@ -376,7 +378,7 @@ class GroupServerWorkerStore(SQLBaseStore):
allow_none=True,
)
def get_users_membership_info_in_group(self, group_id, user_id):
async def get_users_membership_info_in_group(self, group_id, user_id):
"""Get a dict describing the membership of a user in a group.
Example if joined:
@@ -387,7 +389,8 @@ class GroupServerWorkerStore(SQLBaseStore):
"is_privileged": False,
}
Returns an empty dict if the user is not join/invite/etc
Returns:
An empty dict if the user is not join/invite/etc
"""
def _get_users_membership_in_group_txn(txn):
@@ -419,7 +422,7 @@ class GroupServerWorkerStore(SQLBaseStore):
return {}
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_users_membership_info_in_group", _get_users_membership_in_group_txn
)
@@ -433,7 +436,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_publicised_groups_for_user",
)
def get_attestations_need_renewals(self, valid_until_ms):
async def get_attestations_need_renewals(self, valid_until_ms):
"""Get all attestations that need to be renewed until givent time
"""
@@ -445,7 +448,7 @@ class GroupServerWorkerStore(SQLBaseStore):
txn.execute(sql, (valid_until_ms,))
return self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_attestations_need_renewals", _get_attestations_need_renewals_txn
)
@@ -475,7 +478,7 @@ class GroupServerWorkerStore(SQLBaseStore):
desc="get_joined_groups",
)
def get_all_groups_for_user(self, user_id, now_token):
async def get_all_groups_for_user(self, user_id, now_token):
def _get_all_groups_for_user_txn(txn):
sql = """
SELECT group_id, type, membership, u.content
@@ -495,7 +498,7 @@ class GroupServerWorkerStore(SQLBaseStore):
for row in txn
]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_all_groups_for_user", _get_all_groups_for_user_txn
)
@@ -600,8 +603,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="set_group_join_policy",
)
def add_room_to_summary(self, group_id, room_id, category_id, order, is_public):
return self.db_pool.runInteraction(
async def add_room_to_summary(
self,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
Args:
group_id
room_id
category_id: If not None then adds the category to the end of
the summary if its not already there.
order: If not None inserts the room at that position, e.g. an order
of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_room_to_summary",
self._add_room_to_summary_txn,
group_id,
@@ -612,18 +634,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_room_to_summary_txn(
self, txn, group_id, room_id, category_id, order, is_public
):
self,
txn,
group_id: str,
room_id: str,
category_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) room's entry in summary.
Args:
group_id (str)
room_id (str)
category_id (str): If not None then adds the category to the end of
the summary if its not already there. [Optional]
order (int): If not None inserts the room at that position, e.g.
an order of 1 will put the room first. Otherwise, the room gets
added to the end.
txn
group_id
room_id
category_id: If not None then adds the category to the end of
the summary if its not already there.
order: If not None inserts the room at that position, e.g. an order
of 1 will put the room first. Otherwise, the room gets added to
the end.
is_public
"""
room_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -818,8 +848,27 @@ class GroupServerStore(GroupServerWorkerStore):
desc="remove_group_role",
)
def add_user_to_summary(self, group_id, user_id, role_id, order, is_public):
return self.db_pool.runInteraction(
async def add_user_to_summary(
self,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
) -> None:
"""Add (or update) user's entry in summary.
Args:
group_id
user_id
role_id: If not None then adds the role to the end of the summary if
its not already there.
order: If not None inserts the user at that position, e.g. an order
of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
"""
await self.db_pool.runInteraction(
"add_user_to_summary",
self._add_user_to_summary_txn,
group_id,
@@ -830,18 +879,26 @@ class GroupServerStore(GroupServerWorkerStore):
)
def _add_user_to_summary_txn(
self, txn, group_id, user_id, role_id, order, is_public
self,
txn,
group_id: str,
user_id: str,
role_id: str,
order: int,
is_public: Optional[bool],
):
"""Add (or update) user's entry in summary.
Args:
group_id (str)
user_id (str)
role_id (str): If not None then adds the role to the end of
the summary if its not already there. [Optional]
order (int): If not None inserts the user at that position, e.g.
an order of 1 will put the user first. Otherwise, the user gets
added to the end.
txn
group_id
user_id
role_id: If not None then adds the role to the end of the summary if
its not already there.
order: If not None inserts the user at that position, e.g. an order
of 1 will put the user first. Otherwise, the user gets added to
the end.
is_public
"""
user_in_group = self.db_pool.simple_select_one_onecol_txn(
txn,
@@ -963,27 +1020,26 @@ class GroupServerStore(GroupServerWorkerStore):
desc="add_group_invite",
)
def add_user_to_group(
async def add_user_to_group(
self,
group_id,
user_id,
is_admin=False,
is_public=True,
local_attestation=None,
remote_attestation=None,
):
group_id: str,
user_id: str,
is_admin: bool = False,
is_public: bool = True,
local_attestation: dict = None,
remote_attestation: dict = None,
) -> None:
"""Add a user to the group server.
Args:
group_id (str)
user_id (str)
is_admin (bool)
is_public (bool)
local_attestation (dict): The attestation the GS created to give
to the remote server. Optional if the user and group are on the
same server
remote_attestation (dict): The attestation given to GS by remote
group_id
user_id
is_admin
is_public
local_attestation: The attestation the GS created to give to the remote
server. Optional if the user and group are on the same server
remote_attestation: The attestation given to GS by remote server.
Optional if the user and group are on the same server
"""
def _add_user_to_group_txn(txn):
@@ -1026,9 +1082,9 @@ class GroupServerStore(GroupServerWorkerStore):
},
)
return self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
await self.db_pool.runInteraction("add_user_to_group", _add_user_to_group_txn)
def remove_user_from_group(self, group_id, user_id):
async def remove_user_from_group(self, group_id: str, user_id: str) -> None:
def _remove_user_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -1056,7 +1112,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "user_id": user_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_user_from_group", _remove_user_from_group_txn
)
@@ -1079,7 +1135,7 @@ class GroupServerStore(GroupServerWorkerStore):
desc="update_room_in_group_visibility",
)
def remove_room_from_group(self, group_id, room_id):
async def remove_room_from_group(self, group_id: str, room_id: str) -> None:
def _remove_room_from_group_txn(txn):
self.db_pool.simple_delete_txn(
txn,
@@ -1093,7 +1149,7 @@ class GroupServerStore(GroupServerWorkerStore):
keyvalues={"group_id": group_id, "room_id": room_id},
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"remove_room_from_group", _remove_room_from_group_txn
)
@@ -1286,14 +1342,11 @@ class GroupServerStore(GroupServerWorkerStore):
def get_group_stream_token(self):
return self._group_updates_id_gen.get_current_token()
def delete_group(self, group_id):
async def delete_group(self, group_id: str) -> None:
"""Deletes a group fully from the database.
Args:
group_id (str)
Returns:
Deferred
group_id: The group ID to delete.
"""
def _delete_group_txn(txn):
@@ -1317,4 +1370,4 @@ class GroupServerStore(GroupServerWorkerStore):
txn, table=table, keyvalues={"group_id": group_id}
)
return self.db_pool.runInteraction("delete_group", _delete_group_txn)
await self.db_pool.runInteraction("delete_group", _delete_group_txn)

View File

@@ -16,7 +16,7 @@
import itertools
import logging
from typing import Iterable, Tuple
from typing import Dict, Iterable, List, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -42,16 +42,17 @@ class KeyStore(SQLBaseStore):
@cachedList(
cached_method_name="_get_server_verify_key", list_name="server_name_and_key_ids"
)
def get_server_verify_keys(self, server_name_and_key_ids):
async def get_server_verify_keys(
self, server_name_and_key_ids: Iterable[Tuple[str, str]]
) -> Dict[Tuple[str, str], Optional[FetchKeyResult]]:
"""
Args:
server_name_and_key_ids (iterable[Tuple[str, str]]):
server_name_and_key_ids:
iterable of (server_name, key-id) tuples to fetch keys for
Returns:
Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
map from (server_name, key_id) -> FetchKeyResult, or None if the key is
unknown
A map from (server_name, key_id) -> FetchKeyResult, or None if the
key is unknown
"""
keys = {}
@@ -87,7 +88,7 @@ class KeyStore(SQLBaseStore):
_get_keys(txn, batch)
return keys
return self.db_pool.runInteraction("get_server_verify_keys", _txn)
return await self.db_pool.runInteraction("get_server_verify_keys", _txn)
async def store_server_verify_keys(
self,
@@ -179,7 +180,9 @@ class KeyStore(SQLBaseStore):
desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
async def get_server_keys_json(
self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
"""Retrive the key json for a list of server_keys and key ids.
If no keys are found for a given server, key_id and source then
that server, key_id, and source triplet entry will be an empty list.
@@ -188,8 +191,7 @@ class KeyStore(SQLBaseStore):
Args:
server_keys (list): List of (server_name, key_id, source) triplets.
Returns:
Deferred[dict[Tuple[str, str, str|None], list[dict]]]:
Dict mapping (server_name, key_id, source) triplets to lists of dicts
A mapping from (server_name, key_id, source) triplets to a list of dicts
"""
def _get_server_keys_json_txn(txn):
@@ -215,6 +217,6 @@ class KeyStore(SQLBaseStore):
results[(server_name, key_id, from_server)] = rows
return results
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_server_keys_json", _get_server_keys_json_txn
)

View File

@@ -15,6 +15,7 @@
import logging
from collections import namedtuple
from typing import Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -56,21 +57,23 @@ class TransactionStore(SQLBaseStore):
expiry_ms=5 * 60 * 1000,
)
def get_received_txn_response(self, transaction_id, origin):
async def get_received_txn_response(
self, transaction_id: str, origin: str
) -> Optional[Tuple[int, JsonDict]]:
"""For an incoming transaction from a given origin, check if we have
already responded to it. If so, return the response code and response
body (as a dict).
Args:
transaction_id (str)
origin(str)
transaction_id
origin
Returns:
tuple: None if we have not previously responded to
this transaction or a 2-tuple of (int, dict)
None if we have not previously responded to this transaction or a
2-tuple of (int, dict)
"""
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_received_txn_response",
self._get_received_txn_response,
transaction_id,
@@ -166,21 +169,25 @@ class TransactionStore(SQLBaseStore):
else:
return None
def set_destination_retry_timings(
self, destination, failure_ts, retry_last_ts, retry_interval
):
async def set_destination_retry_timings(
self,
destination: str,
failure_ts: Optional[int],
retry_last_ts: int,
retry_interval: int,
) -> None:
"""Sets the current retry timings for a given destination.
Both timings should be zero if retrying is no longer occuring.
Args:
destination (str)
failure_ts (int|None) - when the server started failing (ms since epoch)
retry_last_ts (int) - time of last retry attempt in unix epoch ms
retry_interval (int) - how long until next retry in ms
destination
failure_ts: when the server started failing (ms since epoch)
retry_last_ts: time of last retry attempt in unix epoch ms
retry_interval: how long until next retry in ms
"""
self._destination_retry_cache.pop(destination, None)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"set_destination_retry_timings",
self._set_destination_retry_timings,
destination,
@@ -256,13 +263,13 @@ class TransactionStore(SQLBaseStore):
"cleanup_transactions", self._cleanup_transactions
)
def _cleanup_transactions(self):
async def _cleanup_transactions(self) -> None:
now = self._clock.time_msec()
month_ago = now - 30 * 24 * 60 * 60 * 1000
def _cleanup_transactions_txn(txn):
txn.execute("DELETE FROM received_transactions WHERE ts < ?", (month_ago,))
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_cleanup_transactions", _cleanup_transactions_txn
)

View File

@@ -18,7 +18,7 @@ import re
import string
import sys
from collections import namedtuple
from typing import Any, Dict, Tuple, Type, TypeVar
from typing import Any, Dict, Mapping, MutableMapping, Tuple, Type, TypeVar
import attr
from signedjson.key import decode_verify_key_bytes
@@ -42,8 +42,9 @@ else:
# Define a state map type from type/state_key to T (usually an event ID or
# event)
T = TypeVar("T")
StateMap = Dict[Tuple[str, str], T]
StateKey = Tuple[str, str]
StateMap = Mapping[StateKey, T]
MutableStateMap = MutableMapping[StateKey, T]
# the type of a JSON-serialisable dict. This could be made stronger, but it will
# do for now.