Merge commit '208e1d3eb' into anoa/dinsic_release_1_21_x
* commit '208e1d3eb': Fix typing for `@cached` wrapped functions (#8240) Remove useless changelog about reverting a #8239. Revert pinning of setuptools (#8239) Fix typing for SyncHandler (#8237) wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method (#8231) Add an overload for simple_select_one_onecol_txn. (#8235)
This commit is contained in:
@@ -73,7 +73,7 @@ mkdir -p ~/synapse
|
||||
virtualenv -p python3 ~/synapse/env
|
||||
source ~/synapse/env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions
|
||||
pip install --upgrade setuptools
|
||||
pip install matrix-synapse
|
||||
```
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions.
|
||||
1
changelog.d/8231.misc
Normal file
1
changelog.d/8231.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor queries for device keys and cross-signatures.
|
||||
1
changelog.d/8235.misc
Normal file
1
changelog.d/8235.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to `StreamStore`.
|
||||
1
changelog.d/8237.misc
Normal file
1
changelog.d/8237.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix type hints in `SyncHandler`.
|
||||
1
changelog.d/8240.misc
Normal file
1
changelog.d/8240.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix type hints for functions decorated with `@cached`.
|
||||
3
mypy.ini
3
mypy.ini
@@ -1,6 +1,6 @@
|
||||
[mypy]
|
||||
namespace_packages = True
|
||||
plugins = mypy_zope:plugin
|
||||
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
|
||||
follow_imports = silent
|
||||
check_untyped_defs = True
|
||||
show_error_codes = True
|
||||
@@ -51,6 +51,7 @@ files =
|
||||
synapse/storage/util,
|
||||
synapse/streams,
|
||||
synapse/types.py,
|
||||
synapse/util/caches/descriptors.py,
|
||||
synapse/util/caches/stream_change_cache.py,
|
||||
synapse/util/metrics.py,
|
||||
tests/replication,
|
||||
|
||||
85
scripts-dev/mypy_synapse_plugin.py
Normal file
85
scripts-dev/mypy_synapse_plugin.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2020 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This is a mypy plugin for Synpase to deal with some of the funky typing that
|
||||
can crop up, e.g the cache descriptors.
|
||||
"""
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from mypy.plugin import MethodSigContext, Plugin
|
||||
from mypy.typeops import bind_self
|
||||
from mypy.types import CallableType
|
||||
|
||||
|
||||
class SynapsePlugin(Plugin):
|
||||
def get_method_signature_hook(
|
||||
self, fullname: str
|
||||
) -> Optional[Callable[[MethodSigContext], CallableType]]:
|
||||
if fullname.startswith(
|
||||
"synapse.util.caches.descriptors._CachedFunction.__call__"
|
||||
):
|
||||
return cached_function_method_signature
|
||||
return None
|
||||
|
||||
|
||||
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
|
||||
"""Fixes the `_CachedFunction.__call__` signature to be correct.
|
||||
|
||||
It already has *almost* the correct signature, except:
|
||||
|
||||
1. the `self` argument needs to be marked as "bound"; and
|
||||
2. any `cache_context` argument should be removed.
|
||||
"""
|
||||
|
||||
# First we mark this as a bound function signature.
|
||||
signature = bind_self(ctx.default_signature)
|
||||
|
||||
# Secondly, we remove any "cache_context" args.
|
||||
#
|
||||
# Note: We should be only doing this if `cache_context=True` is set, but if
|
||||
# it isn't then the code will raise an exception when its called anyway, so
|
||||
# its not the end of the world.
|
||||
context_arg_index = None
|
||||
for idx, name in enumerate(signature.arg_names):
|
||||
if name == "cache_context":
|
||||
context_arg_index = idx
|
||||
break
|
||||
|
||||
if context_arg_index:
|
||||
arg_types = list(signature.arg_types)
|
||||
arg_types.pop(context_arg_index)
|
||||
|
||||
arg_names = list(signature.arg_names)
|
||||
arg_names.pop(context_arg_index)
|
||||
|
||||
arg_kinds = list(signature.arg_kinds)
|
||||
arg_kinds.pop(context_arg_index)
|
||||
|
||||
signature = signature.copy_modified(
|
||||
arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
|
||||
)
|
||||
|
||||
return signature
|
||||
|
||||
|
||||
def plugin(version: str):
|
||||
# This is the entry point of the plugin, and let's us deal with the fact
|
||||
# that the mypy plugin interface is *not* stable by looking at the version
|
||||
# string.
|
||||
#
|
||||
# However, since we pin the version of mypy Synapse uses in CI, we don't
|
||||
# really care.
|
||||
return SynapsePlugin
|
||||
@@ -443,11 +443,11 @@ class FederationHandler(BaseHandler):
|
||||
if not prevs - seen:
|
||||
return
|
||||
|
||||
latest = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
latest_list = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(latest)
|
||||
latest = set(latest_list)
|
||||
latest |= seen
|
||||
|
||||
logger.info(
|
||||
@@ -784,7 +784,7 @@ class FederationHandler(BaseHandler):
|
||||
# keys across all devices.
|
||||
current_keys = [
|
||||
key
|
||||
for device in cached_devices
|
||||
for device in cached_devices.values()
|
||||
for key in device.get("keys", {}).get("keys", {}).values()
|
||||
]
|
||||
|
||||
@@ -2129,8 +2129,8 @@ class FederationHandler(BaseHandler):
|
||||
if backfilled or event.internal_metadata.is_outlier():
|
||||
return
|
||||
|
||||
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids)
|
||||
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
|
||||
extrem_ids = set(extrem_ids_list)
|
||||
prev_event_ids = set(event.prev_event_ids())
|
||||
|
||||
if extrem_ids == prev_event_ids:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
|
||||
|
||||
import attr
|
||||
from prometheus_client import Counter
|
||||
@@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
|
||||
@@ -244,7 +247,7 @@ class SyncResult:
|
||||
|
||||
|
||||
class SyncHandler(object):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs_config = hs.config
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
@@ -717,9 +720,8 @@ class SyncHandler(object):
|
||||
]
|
||||
|
||||
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
|
||||
missing_hero_state = missing_hero_state.values()
|
||||
|
||||
for s in missing_hero_state:
|
||||
for s in missing_hero_state.values():
|
||||
cache.set(s.state_key, s.event_id)
|
||||
state[(EventTypes.Member, s.state_key)] = s
|
||||
|
||||
@@ -1771,7 +1773,7 @@ class SyncHandler(object):
|
||||
ignored_users: Set[str],
|
||||
room_builder: "RoomSyncResultBuilder",
|
||||
ephemeral: List[JsonDict],
|
||||
tags: Optional[List[JsonDict]],
|
||||
tags: Optional[Dict[str, Dict[str, Any]]],
|
||||
account_data: Dict[str, JsonDict],
|
||||
always_include: bool = False,
|
||||
):
|
||||
|
||||
@@ -74,10 +74,6 @@ REQUIREMENTS = [
|
||||
"Jinja2>=2.9",
|
||||
"bleach>=1.4.3",
|
||||
"typing-extensions>=3.7.4",
|
||||
# setuptools is required by a variety of dependencies, unfortunately version
|
||||
# 50.0 is incompatible with older Python versions, see
|
||||
# https://github.com/pypa/setuptools/issues/2352
|
||||
"setuptools!=50.0",
|
||||
]
|
||||
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
||||
@@ -1149,6 +1149,30 @@ class DatabasePool(object):
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[False] = False,
|
||||
) -> Any:
|
||||
...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
txn: LoggingTransaction,
|
||||
table: str,
|
||||
keyvalues: Dict[str, Any],
|
||||
retcol: Iterable[str],
|
||||
allow_none: Literal[True] = True,
|
||||
) -> Optional[Any]:
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def simple_select_one_onecol_txn(
|
||||
cls,
|
||||
|
||||
@@ -255,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
List of objects representing an device update EDU
|
||||
"""
|
||||
devices = (
|
||||
await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys_and_signatures_txn",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
await self.get_e2e_device_keys_and_signatures(
|
||||
query_map.keys(),
|
||||
include_all_devices=True,
|
||||
include_deleted_devices=True,
|
||||
|
||||
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
|
||||
|
||||
@attr.s
|
||||
class DeviceKeyLookupResult:
|
||||
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
|
||||
"""The type returned by get_e2e_device_keys_and_signatures"""
|
||||
|
||||
display_name = attr.ib(type=Optional[str])
|
||||
|
||||
@@ -60,11 +60,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
devices = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys_and_signatures_txn",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
[(user_id, None)],
|
||||
)
|
||||
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
@@ -108,11 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
if not query_list:
|
||||
return {}
|
||||
|
||||
results = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys_and_signatures_txn",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
query_list,
|
||||
)
|
||||
results = await self.get_e2e_device_keys_and_signatures(query_list)
|
||||
|
||||
# Build the result structure, un-jsonify the results, and add the
|
||||
# "unsigned" section
|
||||
@@ -135,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
return rv
|
||||
|
||||
@trace
|
||||
def _get_e2e_device_keys_and_signatures_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
async def get_e2e_device_keys_and_signatures(
|
||||
self,
|
||||
query_list: List[Tuple[str, Optional[str]]],
|
||||
include_all_devices: bool = False,
|
||||
include_deleted_devices: bool = False,
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
"""Fetch a list of device keys, together with their cross-signatures.
|
||||
|
||||
Args:
|
||||
query_list: List of pairs of user_ids and device_ids. Device id can be None
|
||||
to indicate "all devices for this user"
|
||||
|
||||
include_all_devices: whether to return devices without device keys
|
||||
|
||||
include_deleted_devices: whether to include null entries for
|
||||
devices which no longer exist (but were in the query_list).
|
||||
This option only takes effect if include_all_devices is true.
|
||||
|
||||
Returns:
|
||||
Dict mapping from user-id to dict mapping from device_id to
|
||||
key data.
|
||||
"""
|
||||
set_tag("include_all_devices", include_all_devices)
|
||||
set_tag("include_deleted_devices", include_deleted_devices)
|
||||
|
||||
result = await self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys",
|
||||
self._get_e2e_device_keys_and_signatures_txn,
|
||||
query_list,
|
||||
include_all_devices,
|
||||
include_deleted_devices,
|
||||
)
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
def _get_e2e_device_keys_and_signatures_txn(
|
||||
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
|
||||
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
|
||||
query_clauses = []
|
||||
query_params = []
|
||||
signature_query_clauses = []
|
||||
@@ -230,7 +255,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
)
|
||||
signing_user_signatures[signing_key_id] = signature
|
||||
|
||||
log_kv(result)
|
||||
return result
|
||||
|
||||
async def get_e2e_one_time_keys(
|
||||
|
||||
@@ -298,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
return None
|
||||
|
||||
async def get_rooms_for_local_user_where_membership_is(
|
||||
self, user_id: str, membership_list: List[str]
|
||||
) -> Optional[List[RoomsForUser]]:
|
||||
self, user_id: str, membership_list: Collection[str]
|
||||
) -> List[RoomsForUser]:
|
||||
"""Get all the rooms for this *local* user where the membership for this user
|
||||
matches one in the membership list.
|
||||
|
||||
@@ -314,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
The RoomsForUser that the user matches the membership types.
|
||||
"""
|
||||
if not membership_list:
|
||||
return None
|
||||
return []
|
||||
|
||||
rooms = await self.db_pool.runInteraction(
|
||||
"get_rooms_for_local_user_where_membership_is",
|
||||
|
||||
@@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
|
||||
)
|
||||
|
||||
tags_by_room = {}
|
||||
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
|
||||
for row in rows:
|
||||
room_tags = tags_by_room.setdefault(row["room_id"], {})
|
||||
room_tags[row["tag"]] = db_to_json(row["content"])
|
||||
@@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
|
||||
|
||||
async def get_updated_tags(
|
||||
self, user_id: str, stream_id: int
|
||||
) -> Dict[str, List[str]]:
|
||||
) -> Dict[str, Dict[str, JsonDict]]:
|
||||
"""Get all the tags for the rooms where the tags have changed since the
|
||||
given version
|
||||
|
||||
|
||||
@@ -18,11 +18,10 @@ import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any, Tuple, Union, cast
|
||||
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
|
||||
from weakref import WeakValueDictionary
|
||||
|
||||
from prometheus_client import Gauge
|
||||
from typing_extensions import Protocol
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
CacheKey = Union[Tuple, Any]
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
class _CachedFunction(Protocol):
|
||||
|
||||
class _CachedFunction(Generic[F]):
|
||||
invalidate = None # type: Any
|
||||
invalidate_all = None # type: Any
|
||||
invalidate_many = None # type: Any
|
||||
@@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
|
||||
cache = None # type: Any
|
||||
num_args = None # type: Any
|
||||
|
||||
def __name__(self):
|
||||
...
|
||||
__name__ = None # type: str
|
||||
|
||||
# Note: This function signature is actually fiddled with by the synapse mypy
|
||||
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
|
||||
__call__ = None # type: F
|
||||
|
||||
|
||||
cache_pending_metric = Gauge(
|
||||
@@ -123,7 +127,7 @@ class Cache(object):
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
self.thread = None
|
||||
self.thread = None # type: Optional[threading.Thread]
|
||||
self.metrics = register_cache(
|
||||
"cache",
|
||||
name,
|
||||
@@ -662,9 +666,13 @@ class _CacheContext:
|
||||
|
||||
|
||||
def cached(
|
||||
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
|
||||
):
|
||||
return lambda orig: CacheDescriptor(
|
||||
max_entries: int = 1000,
|
||||
num_args: Optional[int] = None,
|
||||
tree: bool = False,
|
||||
cache_context: bool = False,
|
||||
iterable: bool = False,
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
func = lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
@@ -673,8 +681,12 @@ def cached(
|
||||
iterable=iterable,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
|
||||
def cachedList(cached_method_name, list_name, num_args=None):
|
||||
|
||||
def cachedList(
|
||||
cached_method_name: str, list_name: str, num_args: Optional[int] = None
|
||||
) -> Callable[[F], _CachedFunction[F]]:
|
||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. A single argument
|
||||
@@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
|
||||
cache.
|
||||
|
||||
Args:
|
||||
cached_method_name (str): The name of the single-item lookup method.
|
||||
cached_method_name: The name of the single-item lookup method.
|
||||
This is only used to find the cache to use.
|
||||
list_name (str): The name of the argument that is the list to use to
|
||||
list_name: The name of the argument that is the list to use to
|
||||
do batch lookups in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache
|
||||
num_args: Number of arguments to use as the key in the cache
|
||||
(including list_name). Defaults to all named parameters.
|
||||
|
||||
Example:
|
||||
@@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
|
||||
def batch_do_something(self, first_arg, second_args):
|
||||
...
|
||||
"""
|
||||
return lambda orig: CacheListDescriptor(
|
||||
func = lambda orig: CacheListDescriptor(
|
||||
orig,
|
||||
cached_method_name=cached_method_name,
|
||||
list_name=list_name,
|
||||
num_args=num_args,
|
||||
)
|
||||
|
||||
return cast(Callable[[F], _CachedFunction[F]], func)
|
||||
|
||||
Reference in New Issue
Block a user