1
0

Merge commit '208e1d3eb' into anoa/dinsic_release_1_21_x

* commit '208e1d3eb':
  Fix typing for `@cached` wrapped functions (#8240)
  Remove useless changelog about reverting a #8239.
  Revert pinning of setuptools (#8239)
  Fix typing for SyncHandler (#8237)
  wrap `_get_e2e_device_keys_and_signatures_txn` in a non-txn method (#8231)
  Add an overload for simple_select_one_onecol_txn. (#8235)
This commit is contained in:
Andrew Morgan
2020-10-20 17:52:08 +01:00
17 changed files with 200 additions and 53 deletions

View File

@@ -73,7 +73,7 @@ mkdir -p ~/synapse
virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions
pip install --upgrade setuptools
pip install matrix-synapse
```

View File

@@ -1 +0,0 @@
Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions.

1
changelog.d/8231.misc Normal file
View File

@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

1
changelog.d/8235.misc Normal file
View File

@@ -0,0 +1 @@
Add type hints to `StreamStore`.

1
changelog.d/8237.misc Normal file
View File

@@ -0,0 +1 @@
Fix type hints in `SyncHandler`.

1
changelog.d/8240.misc Normal file
View File

@@ -0,0 +1 @@
Fix type hints for functions decorated with `@cached`.

View File

@@ -1,6 +1,6 @@
[mypy]
namespace_packages = True
plugins = mypy_zope:plugin
plugins = mypy_zope:plugin, scripts-dev/mypy_synapse_plugin.py
follow_imports = silent
check_untyped_defs = True
show_error_codes = True
@@ -51,6 +51,7 @@ files =
synapse/storage/util,
synapse/streams,
synapse/types.py,
synapse/util/caches/descriptors.py,
synapse/util/caches/stream_change_cache.py,
synapse/util/metrics.py,
tests/replication,

View File

@@ -0,0 +1,85 @@
# -*- coding: utf-8 -*-
# Copyright 2020 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This is a mypy plugin for Synapse to deal with some of the funky typing that
can crop up, e.g. the cache descriptors.
"""
from typing import Callable, Optional
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType
class SynapsePlugin(Plugin):
    """Mypy plugin that installs a signature hook for `_CachedFunction.__call__`."""

    def get_method_signature_hook(
        self, fullname: str
    ) -> Optional[Callable[[MethodSigContext], CallableType]]:
        """Pick the signature hook (if any) for the method called `fullname`.

        Args:
            fullname: the fully qualified name of the method being checked.

        Returns:
            `cached_function_method_signature` when `fullname` starts with the
            qualified name of `_CachedFunction.__call__`, otherwise None.
        """
        cached_call_prefix = "synapse.util.caches.descriptors._CachedFunction.__call__"
        is_cached_call = fullname.startswith(cached_call_prefix)
        return cached_function_method_signature if is_cached_call else None
def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
    """Fixes the `_CachedFunction.__call__` signature to be correct.

    It already has *almost* the correct signature, except:

        1. the `self` argument needs to be marked as "bound"; and
        2. any `cache_context` argument should be removed.

    Args:
        ctx: the method signature context supplied by mypy; only
            `ctx.default_signature` is read.

    Returns:
        The bound signature, with the first argument named `cache_context`
        (if there is one) removed.
    """
    # First we mark this as a bound function signature.
    signature = bind_self(ctx.default_signature)

    # Secondly, we remove any "cache_context" args.
    #
    # Note: We should be only doing this if `cache_context=True` is set, but if
    # it isn't then the code will raise an exception when it's called anyway, so
    # it's not the end of the world.
    context_arg_index = None
    for idx, name in enumerate(signature.arg_names):
        if name == "cache_context":
            context_arg_index = idx
            break

    # Compare against None explicitly: index 0 is a valid match (it is falsy,
    # so a plain truthiness check would leave a leading `cache_context`
    # argument in the signature).
    if context_arg_index is not None:
        # The three lists are parallel, so drop the same position from each.
        arg_types = list(signature.arg_types)
        arg_types.pop(context_arg_index)

        arg_names = list(signature.arg_names)
        arg_names.pop(context_arg_index)

        arg_kinds = list(signature.arg_kinds)
        arg_kinds.pop(context_arg_index)

        signature = signature.copy_modified(
            arg_types=arg_types, arg_names=arg_names, arg_kinds=arg_kinds,
        )

    return signature
def plugin(version: str):
    """Entry point of the plugin: return the plugin class to use.

    The `version` argument lets us deal with the fact that the mypy plugin
    interface is *not* stable, by looking at the version string. However,
    since we pin the version of mypy Synapse uses in CI, we don't really care
    and the argument is ignored.
    """
    return SynapsePlugin

View File

@@ -443,11 +443,11 @@ class FederationHandler(BaseHandler):
if not prevs - seen:
return
latest = await self.store.get_latest_event_ids_in_room(room_id)
latest_list = await self.store.get_latest_event_ids_in_room(room_id)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(latest_list)
latest |= seen
logger.info(
@@ -784,7 +784,7 @@ class FederationHandler(BaseHandler):
# keys across all devices.
current_keys = [
key
for device in cached_devices
for device in cached_devices.values()
for key in device.get("keys", {}).get("keys", {}).values()
]
@@ -2129,8 +2129,8 @@ class FederationHandler(BaseHandler):
if backfilled or event.internal_metadata.is_outlier():
return
extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids)
extrem_ids_list = await self.store.get_latest_event_ids_in_room(event.room_id)
extrem_ids = set(extrem_ids_list)
prev_event_ids = set(event.prev_event_ids())
if extrem_ids == prev_event_ids:

View File

@@ -16,7 +16,7 @@
import itertools
import logging
from typing import Any, Dict, FrozenSet, List, Optional, Set, Tuple
from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
import attr
from prometheus_client import Counter
@@ -44,6 +44,9 @@ from synapse.util.caches.response_cache import ResponseCache
from synapse.util.metrics import Measure, measure_func
from synapse.visibility import filter_events_for_client
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
@@ -244,7 +247,7 @@ class SyncResult:
class SyncHandler(object):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs_config = hs.config
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
@@ -717,9 +720,8 @@ class SyncHandler(object):
]
missing_hero_state = await self.store.get_events(missing_hero_event_ids)
missing_hero_state = missing_hero_state.values()
for s in missing_hero_state:
for s in missing_hero_state.values():
cache.set(s.state_key, s.event_id)
state[(EventTypes.Member, s.state_key)] = s
@@ -1771,7 +1773,7 @@ class SyncHandler(object):
ignored_users: Set[str],
room_builder: "RoomSyncResultBuilder",
ephemeral: List[JsonDict],
tags: Optional[List[JsonDict]],
tags: Optional[Dict[str, Dict[str, Any]]],
account_data: Dict[str, JsonDict],
always_include: bool = False,
):

View File

@@ -74,10 +74,6 @@ REQUIREMENTS = [
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
# setuptools is required by a variety of dependencies, unfortunately version
# 50.0 is incompatible with older Python versions, see
# https://github.com/pypa/setuptools/issues/2352
"setuptools!=50.0",
]
CONDITIONAL_REQUIREMENTS = {

View File

@@ -1149,6 +1149,30 @@ class DatabasePool(object):
allow_none=allow_none,
)
@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[False] = False,
) -> Any:
...
@overload
@classmethod
def simple_select_one_onecol_txn(
cls,
txn: LoggingTransaction,
table: str,
keyvalues: Dict[str, Any],
retcol: Iterable[str],
allow_none: Literal[True] = True,
) -> Optional[Any]:
...
@classmethod
def simple_select_one_onecol_txn(
cls,

View File

@@ -255,9 +255,7 @@ class DeviceWorkerStore(SQLBaseStore):
List of objects representing a device update EDU
"""
devices = (
await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
await self.get_e2e_device_keys_and_signatures(
query_map.keys(),
include_all_devices=True,
include_deleted_devices=True,

View File

@@ -36,7 +36,7 @@ if TYPE_CHECKING:
@attr.s
class DeviceKeyLookupResult:
"""The type returned by _get_e2e_device_keys_and_signatures_txn"""
"""The type returned by get_e2e_device_keys_and_signatures"""
display_name = attr.ib(type=Optional[str])
@@ -60,11 +60,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
"""
now_stream_id = self.get_device_stream_token()
devices = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
[(user_id, None)],
)
devices = await self.get_e2e_device_keys_and_signatures([(user_id, None)])
if devices:
user_devices = devices[user_id]
@@ -108,11 +104,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
if not query_list:
return {}
results = await self.db_pool.runInteraction(
"get_e2e_device_keys_and_signatures_txn",
self._get_e2e_device_keys_and_signatures_txn,
query_list,
)
results = await self.get_e2e_device_keys_and_signatures(query_list)
# Build the result structure, un-jsonify the results, and add the
# "unsigned" section
@@ -135,12 +127,45 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return rv
@trace
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
async def get_e2e_device_keys_and_signatures(
    self,
    query_list: List[Tuple[str, Optional[str]]],
    include_all_devices: bool = False,
    include_deleted_devices: bool = False,
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
    """Look up device keys, along with their cross-signatures.

    This is a non-transaction wrapper: it hands
    `_get_e2e_device_keys_and_signatures_txn` to `db_pool.runInteraction`
    and returns whatever that produces.

    Args:
        query_list: pairs of (user_id, device_id). A device_id of None
            means "all devices for this user".
        include_all_devices: whether devices without device keys should be
            returned as well.
        include_deleted_devices: whether to include null entries for
            devices which no longer exist (but were in the query_list).
            Only takes effect if include_all_devices is true.

    Returns:
        Dict mapping from user-id to dict mapping from device_id to
        key data.
    """
    # Record both flags as tags, in the same order as before.
    flag_tags = (
        ("include_all_devices", include_all_devices),
        ("include_deleted_devices", include_deleted_devices),
    )
    for tag_name, tag_value in flag_tags:
        set_tag(tag_name, tag_value)

    keys_by_user = await self.db_pool.runInteraction(
        "get_e2e_device_keys",
        self._get_e2e_device_keys_and_signatures_txn,
        query_list,
        include_all_devices,
        include_deleted_devices,
    )

    log_kv(keys_by_user)
    return keys_by_user
def _get_e2e_device_keys_and_signatures_txn(
self, txn, query_list, include_all_devices=False, include_deleted_devices=False
) -> Dict[str, Dict[str, Optional[DeviceKeyLookupResult]]]:
query_clauses = []
query_params = []
signature_query_clauses = []
@@ -230,7 +255,6 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
)
signing_user_signatures[signing_key_id] = signature
log_kv(result)
return result
async def get_e2e_one_time_keys(

View File

@@ -298,8 +298,8 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return None
async def get_rooms_for_local_user_where_membership_is(
self, user_id: str, membership_list: List[str]
) -> Optional[List[RoomsForUser]]:
self, user_id: str, membership_list: Collection[str]
) -> List[RoomsForUser]:
"""Get all the rooms for this *local* user where the membership for this user
matches one in the membership list.
@@ -314,7 +314,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
The RoomsForUser that the user matches the membership types.
"""
if not membership_list:
return None
return []
rooms = await self.db_pool.runInteraction(
"get_rooms_for_local_user_where_membership_is",

View File

@@ -43,7 +43,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
"room_tags", {"user_id": user_id}, ["room_id", "tag", "content"]
)
tags_by_room = {}
tags_by_room = {} # type: Dict[str, Dict[str, JsonDict]]
for row in rows:
room_tags = tags_by_room.setdefault(row["room_id"], {})
room_tags[row["tag"]] = db_to_json(row["content"])
@@ -123,7 +123,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
async def get_updated_tags(
self, user_id: str, stream_id: int
) -> Dict[str, List[str]]:
) -> Dict[str, Dict[str, JsonDict]]:
"""Get all the tags for the rooms where the tags have changed since the
given version

View File

@@ -18,11 +18,10 @@ import functools
import inspect
import logging
import threading
from typing import Any, Tuple, Union, cast
from typing import Any, Callable, Generic, Optional, Tuple, TypeVar, Union, cast
from weakref import WeakValueDictionary
from prometheus_client import Gauge
from typing_extensions import Protocol
from twisted.internet import defer
@@ -38,8 +37,10 @@ logger = logging.getLogger(__name__)
CacheKey = Union[Tuple, Any]
F = TypeVar("F", bound=Callable[..., Any])
class _CachedFunction(Protocol):
class _CachedFunction(Generic[F]):
invalidate = None # type: Any
invalidate_all = None # type: Any
invalidate_many = None # type: Any
@@ -47,8 +48,11 @@ class _CachedFunction(Protocol):
cache = None # type: Any
num_args = None # type: Any
def __name__(self):
...
__name__ = None # type: str
# Note: This function signature is actually fiddled with by the synapse mypy
# plugin to a) make it a bound method, and b) remove any `cache_context` arg.
__call__ = None # type: F
cache_pending_metric = Gauge(
@@ -123,7 +127,7 @@ class Cache(object):
self.name = name
self.keylen = keylen
self.thread = None
self.thread = None # type: Optional[threading.Thread]
self.metrics = register_cache(
"cache",
name,
@@ -662,9 +666,13 @@ class _CacheContext:
def cached(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
max_entries: int = 1000,
num_args: Optional[int] = None,
tree: bool = False,
cache_context: bool = False,
iterable: bool = False,
) -> Callable[[F], _CachedFunction[F]]:
func = lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
@@ -673,8 +681,12 @@ def cached(
iterable=iterable,
)
return cast(Callable[[F], _CachedFunction[F]], func)
def cachedList(cached_method_name, list_name, num_args=None):
def cachedList(
cached_method_name: str, list_name: str, num_args: Optional[int] = None
) -> Callable[[F], _CachedFunction[F]]:
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -684,11 +696,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
cache.
Args:
cached_method_name (str): The name of the single-item lookup method.
cached_method_name: The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
list_name: The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
num_args: Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
Example:
@@ -702,9 +714,11 @@ def cachedList(cached_method_name, list_name, num_args=None):
def batch_do_something(self, first_arg, second_args):
...
"""
return lambda orig: CacheListDescriptor(
func = lambda orig: CacheListDescriptor(
orig,
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
)
return cast(Callable[[F], _CachedFunction[F]], func)