
Merge commit '5bf8e5f55' into anoa/dinsic_release_1_21_x

* commit '5bf8e5f55':
  Convert the well known resolver to async (#8214)
  Convert additional databases to async/await part 2 (#8200)
  Make MultiWriterIDGenerator work for streams that use negative stream IDs (#8203)
  Do not install setuptools 50.0. (#8212)
  Move and rename `get_devices_with_keys_by_user` (#8204)
  Rename `get_e2e_device_keys` to better reflect its purpose (#8205)
  Add a comment about _LimitedHostnameResolver
Author: Andrew Morgan
Date: 2020-10-20 17:42:29 +01:00
36 changed files with 392 additions and 195 deletions


@@ -73,7 +73,7 @@ mkdir -p ~/synapse
virtualenv -p python3 ~/synapse/env
source ~/synapse/env/bin/activate
pip install --upgrade pip
pip install --upgrade setuptools
pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions
pip install matrix-synapse
```

changelog.d/8200.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.

changelog.d/8203.misc (new file)

@@ -0,0 +1 @@
Make `MultiWriterIDGenerator` work for streams that use negative values.

changelog.d/8204.misc (new file)

@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

changelog.d/8205.misc (new file)

@@ -0,0 +1 @@
Refactor queries for device keys and cross-signatures.

changelog.d/8212.bugfix (new file)

@@ -0,0 +1 @@
Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions.

changelog.d/8214.misc (new file)

@@ -0,0 +1 @@
Convert various parts of the codebase to async/await.


@@ -28,6 +28,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/federation/well_known_resolver.py,
synapse/http/server.py,
synapse/http/site.py,
synapse/logging/,


@@ -334,6 +334,13 @@ def install_dns_limiter(reactor, max_dns_requests_in_flight=100):
This is to work around https://twistedmatrix.com/trac/ticket/9620, where we
can run out of file descriptors and loop infinitely if we attempt to do too
many DNS queries at once.
XXX: I'm confused by this. reactor.nameResolver does not use twisted.names unless
you explicitly install twisted.names as the resolver; rather it uses a GAIResolver
backed by the reactor's default threadpool (which is limited to 10 threads). So
(a) I don't understand why twisted ticket 9620 is relevant, and (b) I don't
understand why we would run out of FDs if we did too many lookups at once.
-- richvdh 2020/08/29
"""
new_resolver = _LimitedHostnameResolver(
reactor.nameResolver, max_dns_requests_in_flight


@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from typing import Any, Dict, List, Optional, Tuple, Union
import attr
from nacl.signing import SigningKey
@@ -97,14 +97,14 @@ class EventBuilder(object):
def is_state(self):
return self._state_key is not None
async def build(self, prev_event_ids):
async def build(self, prev_event_ids: List[str]) -> EventBase:
"""Transform into a fully signed and hashed event
Args:
prev_event_ids (list[str]): The event IDs to use as the prev events
prev_event_ids: The event IDs to use as the prev events
Returns:
FrozenEvent
The signed and hashed event.
"""
state_ids = await self._state.get_current_state_ids(
@@ -114,8 +114,13 @@ class EventBuilder(object):
format_version = self.room_version.event_format
if format_version == EventFormatVersions.V1:
auth_events = await self._store.add_event_hashes(auth_ids)
prev_events = await self._store.add_event_hashes(prev_event_ids)
# The types of auth/prev events changes between event versions.
auth_events = await self._store.add_event_hashes(
auth_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
prev_events = await self._store.add_event_hashes(
prev_event_ids
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
else:
auth_events = auth_ids
prev_events = prev_event_ids
@@ -138,7 +143,7 @@ class EventBuilder(object):
"unsigned": self.unsigned,
"depth": depth,
"prev_state": [],
}
} # type: Dict[str, Any]
if self.is_state():
event_dict["state_key"] = self._state_key


@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
return result
async def on_federation_query_user_devices(self, user_id):
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
user_id
)
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
self_signing_key = await self.store.get_e2e_cross_signing_key(
user_id, "self_signing"


@@ -353,7 +353,7 @@ class E2eKeysHandler(object):
# make sure that each queried user appears in the result dict
result_dict[user_id] = {}
results = await self.store.get_e2e_device_keys(local_query)
results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
# Build the result structure
for user_id, device_keys in results.items():
@@ -734,7 +734,7 @@ class E2eKeysHandler(object):
# fetch our stored devices. This is used to 1. verify
# signatures on the master key, and 2. to compare with what
# was sent if the device was signed
devices = await self.store.get_e2e_device_keys([(user_id, None)])
devices = await self.store.get_e2e_device_keys_for_cs_api([(user_id, None)])
if user_id not in devices:
raise NotFoundError("No device keys found")


@@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
from synapse.types import (
Collection,
Requester,
RoomAlias,
StreamToken,
UserID,
create_requester,
)
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
@@ -449,7 +442,7 @@ class EventCreationHandler(object):
event_dict: dict,
token_id: Optional[str] = None,
txn_id: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
require_consent: bool = True,
) -> Tuple[EventBase, EventContext]:
"""
@@ -789,7 +782,7 @@ class EventCreationHandler(object):
self,
builder: EventBuilder,
requester: Optional[Requester] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
) -> Tuple[EventBase, EventContext]:
"""Create a new event for a local client


@@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
from synapse.events.snapshot import EventContext
from synapse.events.validator import EventValidator
from synapse.storage.roommember import RoomsForUser
from synapse.types import (
Collection,
JsonDict,
Requester,
RoomAlias,
RoomID,
StateMap,
UserID,
)
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
from synapse.util.async_helpers import Linearizer
from synapse.util.distributor import user_joined_room, user_left_room
@@ -185,7 +177,7 @@ class RoomMemberHandler(object):
target: UserID,
room_id: str,
membership: str,
prev_event_ids: Collection[str],
prev_event_ids: List[str],
txn_id: Optional[str] = None,
ratelimit: bool = True,
content: Optional[dict] = None,


@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
and not _is_ip_literal(parsed_uri.hostname)
and not parsed_uri.port
):
well_known_result = yield self._well_known_resolver.get_well_known(
parsed_uri.hostname
well_known_result = yield defer.ensureDeferred(
self._well_known_resolver.get_well_known(parsed_uri.hostname)
)
delegated_server = well_known_result.delegated_server
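`get_well_known` is now a coroutine, but the agent's `request` method still uses `@defer.inlineCallbacks`, so the call is wrapped in `defer.ensureDeferred`, which turns a coroutine into a Deferred that can be `yield`ed. A minimal, self-contained sketch of that bridging pattern; the `lookup` coroutine is a stand-in, not Synapse code:

```python
# Minimal sketch of the async-to-inlineCallbacks bridge used above.
from twisted.internet import defer, task


async def lookup(name: bytes) -> bytes:
    # A real implementation would do I/O here.
    return b"delegated." + name


@defer.inlineCallbacks
def caller(reactor):
    # A bare coroutine cannot be yielded from inlineCallbacks code;
    # ensureDeferred converts it into a Deferred first.
    result = yield defer.ensureDeferred(lookup(b"example.com"))
    print(result)


if __name__ == "__main__":
    task.react(caller)
```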


@@ -16,6 +16,7 @@
import logging
import random
import time
from typing import Callable, Dict, Optional, Tuple
import attr
@@ -23,6 +24,7 @@ from twisted.internet import defer
from twisted.web.client import RedirectAgent, readBody
from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
@@ -99,15 +101,14 @@ class WellKnownResolver(object):
self._well_known_agent = RedirectAgent(agent)
self.user_agent = user_agent
@defer.inlineCallbacks
def get_well_known(self, server_name):
async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
"""Attempt to fetch and parse a .well-known file for the given server
Args:
server_name (bytes): name of the server, from the requested url
server_name: name of the server, from the requested url
Returns:
Deferred[WellKnownLookupResult]: The result of the lookup
The result of the lookup
"""
try:
prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
@@ -124,7 +125,9 @@ class WellKnownResolver(object):
# requests for the same server in parallel?
try:
with Measure(self._clock, "get_well_known"):
result, cache_period = yield self._fetch_well_known(server_name)
result, cache_period = await self._fetch_well_known(
server_name
) # type: Tuple[Optional[bytes], float]
except _FetchWellKnownFailure as e:
if prev_result and e.temporary:
@@ -153,18 +156,17 @@ class WellKnownResolver(object):
return WellKnownLookupResult(delegated_server=result)
@defer.inlineCallbacks
def _fetch_well_known(self, server_name):
async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
"""Actually fetch and parse a .well-known, without checking the cache
Args:
server_name (bytes): name of the server, from the requested url
server_name: name of the server, from the requested url
Raises:
_FetchWellKnownFailure if we fail to lookup a result
Returns:
Deferred[Tuple[bytes,int]]: The lookup result and cache period.
The lookup result and cache period.
"""
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -172,7 +174,7 @@ class WellKnownResolver(object):
# We do this in two steps to differentiate between possibly transient
# errors (e.g. can't connect to host, 503 response) and more permanent
# errors (such as getting a 404 response).
response, body = yield self._make_well_known_request(
response, body = await self._make_well_known_request(
server_name, retry=had_valid_well_known
)
@@ -215,20 +217,20 @@ class WellKnownResolver(object):
return result, cache_period
@defer.inlineCallbacks
def _make_well_known_request(self, server_name, retry):
async def _make_well_known_request(
self, server_name: bytes, retry: bool
) -> Tuple[IResponse, bytes]:
"""Make the well known request.
This will retry the request if requested and it fails (because it was
unable to connect, or received a 5xx error).
Args:
server_name (bytes)
retry (bool): Whether to retry the request if it fails.
server_name: name of the server, from the requested url
retry: Whether to retry the request if it fails.
Returns:
Deferred[tuple[IResponse, bytes]] Returns the response object and
body. Response may be a non-200 response.
Returns the response object and body. Response may be a non-200 response.
"""
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
uri_str = uri.decode("ascii")
@@ -243,12 +245,12 @@ class WellKnownResolver(object):
logger.info("Fetching %s", uri_str)
try:
response = yield make_deferred_yieldable(
response = await make_deferred_yieldable(
self._well_known_agent.request(
b"GET", uri, headers=Headers(headers)
)
)
body = yield make_deferred_yieldable(readBody(response))
body = await make_deferred_yieldable(readBody(response))
if 500 <= response.code < 600:
raise Exception("Non-200 response %s" % (response.code,))
@@ -265,21 +267,24 @@ class WellKnownResolver(object):
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
# Sleep briefly in the hopes that they come back up
yield self._clock.sleep(0.5)
await self._clock.sleep(0.5)
def _cache_period_from_headers(headers, time_now=time.time):
def _cache_period_from_headers(
headers: Headers, time_now: Callable[[], float] = time.time
) -> Optional[float]:
cache_controls = _parse_cache_control(headers)
if b"no-store" in cache_controls:
return 0
if b"max-age" in cache_controls:
try:
max_age = int(cache_controls[b"max-age"])
return max_age
except ValueError:
pass
max_age = cache_controls[b"max-age"]
if max_age:
try:
return int(max_age)
except ValueError:
pass
expires = headers.getRawHeaders(b"expires")
if expires is not None:
@@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
return None
def _parse_cache_control(headers):
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
cache_controls = {}
for hdr in headers.getRawHeaders(b"cache-control", []):
for directive in hdr.split(b","):
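`_cache_period_from_headers` derives a TTL from the parsed directives: `no-store` forces zero, otherwise `max-age` wins, with an `Expires` header as fallback in code not shown in this hunk. A standalone approximation of that logic, operating on a raw header value rather than Twisted's `Headers` object:

```python
# Standalone approximation of the cache-period logic above, working on a
# raw Cache-Control value instead of twisted.web.http_headers.Headers.
from typing import Dict, Optional


def parse_cache_control(header: bytes) -> Dict[bytes, Optional[bytes]]:
    directives = {}
    for directive in header.split(b","):
        splits = [x.strip() for x in directive.split(b"=", 1)]
        directives[splits[0].lower()] = splits[1] if len(splits) > 1 else None
    return directives


def cache_period(header: bytes) -> Optional[float]:
    directives = parse_cache_control(header)
    if b"no-store" in directives:
        return 0
    max_age = directives.get(b"max-age")
    if max_age:
        try:
            return int(max_age)
        except ValueError:
            pass
    return None


assert cache_period(b"max-age=3600") == 3600
assert cache_period(b"no-store, max-age=3600") == 0
assert cache_period(b"private") is None
```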


@@ -74,6 +74,10 @@ REQUIREMENTS = [
"Jinja2>=2.9",
"bleach>=1.4.3",
"typing-extensions>=3.7.4",
# setuptools is required by a variety of dependencies, unfortunately version
# 50.0 is incompatible with older Python versions, see
# https://github.com/pypa/setuptools/issues/2352
"setuptools!=50.0",
]
CONDITIONAL_REQUIREMENTS = {


@@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
"DeviceListFederationStreamChangeCache", device_list_max
)
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
if stream_name == DeviceListsStream.NAME:
self._device_list_id_gen.advance(instance_name, token)


@@ -264,6 +264,9 @@ class DataStore(
# Used in _generate_user_daily_visits to keep track of progress
self._last_user_visit_update = self._get_start_of_day()
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
def take_presence_startup_info(self):
active_on_startup = self._presence_on_startup
self._presence_on_startup = None


@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
self._batch_row_update[key] = (user_agent, device_id, now)
@wrap_as_background_process("update_client_ips")
def _update_client_ips_batch(self):
async def _update_client_ips_batch(self) -> None:
# If the DB pool has already terminated, don't try updating
if not self.db_pool.is_running():
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
to_update = self._batch_row_update
self._batch_row_update = {}
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
)
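This is the conversion pattern applied throughout the merge: a method that used to hand back the Deferred from `runInteraction` becomes an `async def` that awaits it, making control flow explicit and allowing precise return-type annotations. A runnable sketch of the resulting shape, using invented names (`ThingStore`, `get_thing`) and a fake `db_pool` in place of Synapse's `DatabasePool`:

```python
import asyncio
from typing import Any, Callable, Optional


class FakeDBPool:
    """Stand-in for Synapse's DatabasePool; runInteraction here just calls
    the transaction function, with no real connection or transaction."""

    async def runInteraction(self, desc: str, func: Callable[..., Any], *args):
        return func(None, *args)


class ThingStore:
    def __init__(self):
        self.db_pool = FakeDBPool()

    def _get_thing_txn(self, txn, thing_id: str) -> Optional[dict]:
        return {"id": thing_id}

    # After the conversion: async def + await, instead of returning the
    # Deferred from runInteraction directly.
    async def get_thing(self, thing_id: str) -> Optional[dict]:
        return await self.db_pool.runInteraction(
            "get_thing", self._get_thing_txn, thing_id
        )


print(asyncio.run(ThingStore().get_thing("abc")))
```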


@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
import logging
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
update included in the response), and the list of updates, where
each update is a pair of EDU type and EDU contents.
"""
now_stream_id = self._device_list_id_gen.get_current_token()
now_stream_id = self.get_device_stream_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
@@ -412,8 +413,10 @@ class DeviceWorkerStore(SQLBaseStore):
},
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
return self._device_list_id_gen.get_current_token()
"""Get the current stream id from the _device_list_id_gen"""
...
@trace
async def get_user_devices_from_cache(
@@ -481,51 +484,6 @@ class DeviceWorkerStore(SQLBaseStore):
device["device_id"]: db_to_json(device["content"]) for device in devices
}
def get_devices_with_keys_by_user(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn,
user_id,
)
def _get_devices_with_keys_by_user_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
async def get_users_whose_devices_changed(
self, from_key: str, user_ids: Iterable[str]
) -> Set[str]:
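`get_device_stream_token` is declared abstract here because the worker store's shared queries need the current device stream position, while the master process (`DataStore`) and the replicated worker (`SlavedDeviceStore`) track it with different ID generators; both now supply their own concrete implementation, as the earlier hunks show. A sketch of the pattern with invented class names:

```python
# Sketch of the abstract-method pattern introduced here, with invented
# class names; the real stores are DeviceWorkerStore, DataStore, etc.
import abc


class WorkerStore(metaclass=abc.ABCMeta):
    @abc.abstractmethod
    def get_device_stream_token(self) -> int:
        """Get the current stream id from the device list ID generator."""
        ...

    def describe(self) -> str:
        # Shared worker code can rely on the abstract method.
        return "device stream is at %d" % (self.get_device_stream_token(),)


class MasterStore(WorkerStore):
    def get_device_stream_token(self) -> int:
        return 42  # the real store asks its ID generator


print(MasterStore().describe())
# Instantiating WorkerStore() directly would raise TypeError.
```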


@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
return room_id
def update_aliases_for_room(
async def update_aliases_for_room(
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
):
) -> None:
"""Repoint all of the aliases for a given room, to a different room.
Args:
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
txn, self.get_aliases_for_room, (new_room_id,)
)
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
)


@@ -14,6 +14,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import abc
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
from canonicaljson import encode_canonical_json
@@ -22,7 +23,8 @@ from twisted.enterprise.adbapi import Connection
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.storage.database import make_in_list_sql_clause
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
from synapse.types import JsonDict
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.iterutils import batch_iter
@@ -32,18 +34,58 @@ if TYPE_CHECKING:
class EndToEndKeyWorkerStore(SQLBaseStore):
def get_e2e_device_keys_for_federation_query(self, user_id: str):
"""Get all devices (with any device keys) for a user
Returns:
Deferred which resolves to (stream_id, devices)
"""
return self.db_pool.runInteraction(
"get_e2e_device_keys_for_federation_query",
self._get_e2e_device_keys_for_federation_query_txn,
user_id,
)
def _get_e2e_device_keys_for_federation_query_txn(
self, txn: LoggingTransaction, user_id: str
) -> Tuple[int, List[JsonDict]]:
now_stream_id = self.get_device_stream_token()
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.items():
result = {"device_id": device_id}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = db_to_json(key_json)
if "signatures" in device:
for sig_user_id, sigs in device["signatures"].items():
result["keys"].setdefault("signatures", {}).setdefault(
sig_user_id, {}
).update(sigs)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
@trace
async def get_e2e_device_keys(
self, query_list, include_all_devices=False, include_deleted_devices=False
):
"""Fetch a list of device keys.
async def get_e2e_device_keys_for_cs_api(
self, query_list: List[Tuple[str, Optional[str]]]
) -> Dict[str, Dict[str, JsonDict]]:
"""Fetch a list of device keys, formatted suitably for the C/S API.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
include_deleted_devices (bool): whether to include null entries for
devices which no longer exist (but were in the query_list).
This option only takes effect if include_all_devices is true.
Returns:
Dict mapping from user-id to dict mapping from device_id to
key data. The key data will be a dict in the same format as the
@@ -54,11 +96,7 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
return {}
results = await self.db_pool.runInteraction(
"get_e2e_device_keys",
self._get_e2e_device_keys_txn,
query_list,
include_all_devices,
include_deleted_devices,
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list,
)
# Build the result structure, un-jsonify the results, and add the
@@ -541,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
_get_all_user_signature_changes_for_remotes_txn,
)
@abc.abstractmethod
def get_device_stream_token(self) -> int:
"""Get the current stream id from the _device_list_id_gen"""
...
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
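For reference, the per-device dicts assembled by `_get_e2e_device_keys_for_federation_query_txn` look roughly like the following; every identifier, key and signature here is fabricated:

```python
# Illustrative shape of one entry in the federation-query result built
# above; all IDs, keys and signatures are made up.
device_result = {
    "device_id": "ABCDEFGHIJ",
    "keys": {
        "user_id": "@alice:example.com",
        "device_id": "ABCDEFGHIJ",
        "algorithms": ["m.olm.v1.curve25519-aes-sha2", "m.megolm.v1.aes-sha2"],
        "keys": {
            "curve25519:ABCDEFGHIJ": "<base64 key>",
            "ed25519:ABCDEFGHIJ": "<base64 key>",
        },
        # cross-signing signatures are folded into the stored key JSON:
        "signatures": {"@alice:example.com": {"ed25519:ABCDEFGHIJ": "<base64 sig>"}},
    },
    "device_display_name": "alice's phone",
}
```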


@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
from synapse.api.errors import Codes, SynapseError
from synapse.storage._base import SQLBaseStore, db_to_json
from synapse.types import JsonDict
from synapse.util.caches.descriptors import cached
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
return db_to_json(def_json)
def add_user_filter(self, user_localpart, user_filter):
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
def_json = encode_canonical_json(user_filter)
# Need an atomic transaction to SELECT the maximal ID so far then
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
return filter_id
return self.db_pool.runInteraction("add_user_filter", _do_txn)
return await self.db_pool.runInteraction("add_user_filter", _do_txn)


@@ -1,3 +1,5 @@
from typing import Optional
from synapse.storage._base import SQLBaseStore
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
desc="insert_open_id_token",
)
def get_user_id_for_open_id_token(self, token, ts_now_ms):
async def get_user_id_for_open_id_token(
self, token: str, ts_now_ms: int
) -> Optional[str]:
def get_user_id_for_token_txn(txn):
sql = (
"SELECT user_id FROM open_id_tokens"
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
else:
return rows[0][0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_user_id_for_token", get_user_id_for_token_txn
)


@@ -252,7 +252,9 @@ class ProfileStore(ProfileWorkerStore):
desc="delete_remote_profile_cache",
)
def get_remote_profile_cache_entries_that_expire(self, last_checked):
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> Dict[str, str]:
"""Get all users who haven't been checked since `last_checked`
"""
@@ -267,7 +269,7 @@ class ProfileStore(ProfileWorkerStore):
return self.db_pool.cursor_to_dict(txn)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)


@@ -18,8 +18,6 @@ import abc
import logging
from typing import List, Tuple, Union
from twisted.internet import defer
from synapse.push.baserules import list_with_base_rules
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.storage._base import SQLBaseStore, db_to_json
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
)
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
def have_push_rules_changed_for_user(self, user_id, last_id):
async def have_push_rules_changed_for_user(
self, user_id: str, last_id: int
) -> bool:
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
return defer.succeed(False)
return False
else:
def have_push_rules_changed_txn(txn):
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
(count,) = txn.fetchone()
return bool(count)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"have_push_rules_changed", have_push_rules_changed_txn
)
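Note why the `defer.succeed(False)` disappears: the old Deferred-returning method had to wrap its early exit so callers always received a Deferred, whereas an async function's plain `return False` is already awaitable by the caller. A tiny sketch, with `has_changed` standing in for the stream-change cache:

```python
import asyncio


def has_changed(user_id: str, last_id: int) -> bool:
    return False  # stand-in for the push_rules_stream_cache check


async def have_push_rules_changed_for_user(user_id: str, last_id: int) -> bool:
    if not has_changed(user_id, last_id):
        # Previously this had to be `return defer.succeed(False)` so the
        # caller always got a Deferred; in a coroutine a plain early
        # return is already awaitable.
        return False
    return True  # the real code runs a database interaction here


print(asyncio.run(have_push_rules_changed_for_user("@user:test", 10)))
```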


@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
allow_none=True,
)
def get_room_with_stats(self, room_id: str):
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
"""Retrieve room with statistics.
Args:
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
res["public"] = bool(res["public"])
return res
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_room_with_stats", get_room_with_stats_txn, room_id
)
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
desc="get_public_room_ids",
)
def count_public_rooms(self, network_tuple, ignore_non_federatable):
async def count_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
ignore_non_federatable: bool,
) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state tables.
Args:
network_tuple (ThirdPartyInstanceID|None)
ignore_non_federatable (bool): If true filters out non-federatable rooms
network_tuple
ignore_non_federatable: If true filters out non-federatable rooms
"""
def _count_public_rooms_txn(txn):
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
txn.execute(sql, query_args)
return txn.fetchone()[0]
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"count_public_rooms", _count_public_rooms_txn
)
@@ -608,15 +612,14 @@ class RoomWorkerStore(SQLBaseStore):
return row
def get_media_mxcs_in_room(self, room_id):
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
"""Retrieves all the local and remote media MXC URIs in a given room
Args:
room_id (str)
room_id
Returns:
The local and remote media as a lists of tuples where the key is
the hostname and the value is the media ID.
The local and remote media as lists of media IDs.
"""
def _get_media_mxcs_in_room_txn(txn):
@@ -632,11 +635,13 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
)
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
async def quarantine_media_ids_in_room(
self, room_id: str, quarantined_by: str
) -> int:
"""For a room loops through all events with media and quarantines
the associated media
"""
@@ -649,7 +654,7 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_in_room", _quarantine_media_in_room_txn
)
@@ -712,9 +717,9 @@ class RoomWorkerStore(SQLBaseStore):
return local_media_mxcs, remote_media_mxcs
def quarantine_media_by_id(
async def quarantine_media_by_id(
self, server_name: str, media_id: str, quarantined_by: str,
):
) -> int:
"""quarantines a single local or remote media id
Args:
@@ -733,11 +738,13 @@ class RoomWorkerStore(SQLBaseStore):
txn, local_mxcs, remote_mxcs, quarantined_by
)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_id_txn
)
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
async def quarantine_media_ids_by_user(
self, user_id: str, quarantined_by: str
) -> int:
"""quarantines all local media associated with a single user
Args:
@@ -749,7 +756,7 @@ class RoomWorkerStore(SQLBaseStore):
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
return self.db_pool.runInteraction(
return await self.db_pool.runInteraction(
"quarantine_media_by_user", _quarantine_media_by_user_txn
)
@@ -1306,8 +1313,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
)
self.hs.get_notifier().on_new_replication_data()
def get_room_count(self):
"""Retrieve a list of all rooms
async def get_room_count(self) -> int:
"""Retrieve the total number of rooms.
"""
def f(txn):
@@ -1316,7 +1323,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
row = txn.fetchone()
return row[0] or 0
return self.db_pool.runInteraction("get_rooms", f)
return await self.db_pool.runInteraction("get_rooms", f)
async def add_event_report(
self,


@@ -13,9 +13,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Iterable, List, Tuple
from unpaddedbase64 import encode_base64
from synapse.storage._base import SQLBaseStore
from synapse.storage.types import Cursor
from synapse.util.caches.descriptors import cached, cachedList
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
@cachedList(
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
)
def get_event_reference_hashes(self, event_ids):
async def get_event_reference_hashes(
self, event_ids: Iterable[str]
) -> Dict[str, Dict[str, bytes]]:
"""Get all hashes for given events.
Args:
event_ids: The event IDs to get hashes for.
Returns:
A mapping of event ID to a mapping of algorithm to hash.
"""
def f(txn):
return {
event_id: self._get_event_reference_hashes_txn(txn, event_id)
for event_id in event_ids
}
return self.db_pool.runInteraction("get_event_reference_hashes", f)
return await self.db_pool.runInteraction("get_event_reference_hashes", f)
async def add_event_hashes(self, event_ids):
async def add_event_hashes(
self, event_ids: Iterable[str]
) -> List[Tuple[str, Dict[str, str]]]:
"""
Args:
event_ids: The event IDs
Returns:
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
"""
hashes = await self.get_event_reference_hashes(event_ids)
hashes = {
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
return list(hashes.items())
def _get_event_reference_hashes_txn(self, txn, event_id):
def _get_event_reference_hashes_txn(
self, txn: Cursor, event_id: str
) -> Dict[str, bytes]:
"""Get all the hashes for a given PDU.
Args:
txn (cursor):
event_id (str): Id for the Event.
txn:
event_id: Id for the Event.
Returns:
A dict[unicode, bytes] of algorithm -> hash.
A mapping of algorithm -> hash.
"""
query = (
"SELECT algorithm, hash"


@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
class UIAuthStore(UIAuthWorkerStore):
def delete_old_ui_auth_sessions(self, expiration_time: int):
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
"""
Remove sessions which were last used earlier than the expiration time.
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
This is an epoch time in milliseconds.
"""
return self.db_pool.runInteraction(
await self.db_pool.runInteraction(
"delete_old_ui_auth_sessions",
self._delete_old_ui_auth_sessions_txn,
expiration_time,


@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
class UserErasureStore(UserErasureWorkerStore):
def mark_user_erased(self, user_id: str) -> None:
async def mark_user_erased(self, user_id: str) -> None:
"""Indicate that user_id wishes their message history to be erased.
Args:
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db_pool.runInteraction("mark_user_erased", f)
await self.db_pool.runInteraction("mark_user_erased", f)
def mark_user_not_erased(self, user_id: str) -> None:
async def mark_user_not_erased(self, user_id: str) -> None:
"""Indicate that user_id is no longer erased.
Args:
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
return self.db_pool.runInteraction("mark_user_not_erased", f)
await self.db_pool.runInteraction("mark_user_not_erased", f)


@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
id_column: Column that stores the stream ID.
sequence_name: The name of the postgres sequence used to generate new
IDs.
positive: Whether the IDs are positive (true) or negative (false).
When using negative IDs we go backwards from -1 to -2, -3, etc.
"""
def __init__(
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
instance_column: str,
id_column: str,
sequence_name: str,
positive: bool = True,
):
self._db = db
self._instance_name = instance_name
self._positive = positive
self._return_factor = 1 if positive else -1
# We lock as some functions may be called from DB threads.
self._lock = threading.Lock()
# Note: If we are a negative stream then we still store all the IDs as
# positive to make life easier for us, and simply negate the IDs when we
# return them.
self._current_positions = self._load_current_ids(
db_conn, table, instance_column, id_column
)
@@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
def _load_current_ids(
self, db_conn, table: str, instance_column: str, id_column: str
) -> Dict[str, int]:
# For a positive stream, aggregate via MAX. For a negative stream, use MIN
# *and* negate the result to get a positive number.
sql = """
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
GROUP BY %(instance)s
""" % {
"instance": instance_column,
"id": id_column,
"table": table,
"agg": "MAX" if self._positive else "-MIN",
}
cur = db_conn.cursor()
@@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id
with self._lock:
assert self._current_positions.get(self._instance_name, 0) < next_id
self._unfinished_ids.add(next_id)
@contextlib.contextmanager
def manager():
try:
yield next_id
# Multiply by the return factor so that the ID has correct sign.
yield self._return_factor * next_id
finally:
self._mark_id_as_finished(next_id)
@@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than any ID we've already
# seen. If not, then the sequence and table have got out of sync
# somehow.
assert max(self.get_positions().values(), default=0) < min(next_ids)
with self._lock:
assert max(self._current_positions.values(), default=0) < min(next_ids)
self._unfinished_ids.update(next_ids)
@contextlib.contextmanager
def manager():
try:
yield next_ids
yield [self._return_factor * i for i in next_ids]
finally:
for i in next_ids:
self._mark_id_as_finished(i)
@@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
txn.call_after(self._mark_id_as_finished, next_id)
txn.call_on_exception(self._mark_id_as_finished, next_id)
return next_id
return self._return_factor * next_id
def _mark_id_as_finished(self, next_id: int):
"""The ID has finished being processed so we should advance the
@@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
"""
with self._lock:
return self._current_positions.get(instance_name, 0)
return self._return_factor * self._current_positions.get(instance_name, 0)
def get_positions(self) -> Dict[str, int]:
"""Get a copy of the current positon map.
"""
with self._lock:
return dict(self._current_positions)
return {
name: self._return_factor * i
for name, i in self._current_positions.items()
}
def advance(self, instance_name: str, new_id: int):
"""Advance the postion of the named writer to the given ID, if greater
than existing entry.
"""
new_id *= self._return_factor
with self._lock:
self._current_positions[instance_name] = max(
new_id, self._current_positions.get(instance_name, 0)
@@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
"""
with self._lock:
return self._persisted_upto_position
return self._return_factor * self._persisted_upto_position
def _add_persisted_position(self, new_id: int):
"""Record that we have persisted a position.


@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
def test_well_known_cache(self):
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
well_known_server.loseConnection()
# repeat the request: it should hit the cache
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
r = self.successResultOf(fetch_d)
self.assertEqual(r.delegated_server, b"target-server")
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((1000.0,))
# now it should connect again
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.lookups["testserv"] = "1.2.3.4"
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
# there should be an attempt to connect on port 443 for the .well-known
clients = self.reactor.tcpClients
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
# another lookup.
self.reactor.pump((900.0,))
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
# The resolver may retry a few times, so fail all requests that come along
attempts = 0
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
self.reactor.pump((10000.0,))
# Repeat the request; this time it should fail if the lookup fails.
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
fetch_d = defer.ensureDeferred(
self.well_known_resolver.get_well_known(b"testserv")
)
clients = self.reactor.tcpClients
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)


@@ -37,7 +37,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user", "device"),))
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -76,7 +76,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user", "device"),))
self.store.get_e2e_device_keys_for_cs_api((("user", "device"),))
)
self.assertIn("user", res)
self.assertIn("device", res["user"])
@@ -108,7 +108,9 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
)
res = yield defer.ensureDeferred(
self.store.get_e2e_device_keys((("user1", "device1"), ("user2", "device2")))
self.store.get_e2e_device_keys_for_cs_api(
(("user1", "device1"), ("user2", "device2"))
)
)
self.assertIn("user1", res)
self.assertIn("device1", res["user1"])


@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
# We assume that so long as `get_next` does correctly advance the
# `persisted_upto_position` in this case, then it will be correct in the
# other cases that are tested above (since they'll hit the same code).
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
"""
if not USE_POSTGRES_FOR_TESTS:
skip = "Requires Postgres"
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.db_pool = self.store.db_pool # type: DatabasePool
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
def _setup_db(self, txn):
txn.execute("CREATE SEQUENCE foobar_seq")
txn.execute(
"""
CREATE TABLE foobar (
stream_id BIGINT NOT NULL,
instance_name TEXT NOT NULL,
data TEXT
);
"""
)
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
def _create(conn):
return MultiWriterIdGenerator(
conn,
self.db_pool,
instance_name=instance_name,
table="foobar",
instance_column="instance_name",
id_column="stream_id",
sequence_name="foobar_seq",
positive=False,
)
return self.get_success(self.db_pool.runWithConnection(_create))
def _insert_row(self, instance_name: str, stream_id: int):
"""Insert one row as the given instance with given stream_id.
"""
def _insert(txn):
txn.execute(
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
)
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
def test_single_instance(self):
"""Test that reads and writes from a single process are handled
correctly.
"""
id_gen = self._create_id_generator()
with self.get_success(id_gen.get_next()) as stream_id:
self._insert_row("master", stream_id)
self.assertEqual(id_gen.get_positions(), {"master": -1})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
for stream_id in stream_ids:
self._insert_row("master", stream_id)
self.assertEqual(id_gen.get_positions(), {"master": -4})
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
# Test loading from DB by creating a second ID gen
second_id_gen = self._create_id_generator()
self.assertEqual(second_id_gen.get_positions(), {"master": -4})
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
def test_multiple_instance(self):
"""Tests that having multiple instances that get advanced over
federation works correctly.
"""
id_gen_1 = self._create_id_generator("first")
id_gen_2 = self._create_id_generator("second")
with self.get_success(id_gen_1.get_next()) as stream_id:
self._insert_row("first", stream_id)
id_gen_2.advance("first", stream_id)
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
with self.get_success(id_gen_2.get_next()) as stream_id:
self._insert_row("second", stream_id)
id_gen_1.advance("second", stream_id)
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)


@@ -13,14 +13,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple
from typing import List, Optional, Tuple
import synapse.server
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.types import Collection
"""
Utility functions for poking events into the storage of the server under test.
@@ -58,7 +57,7 @@ async def inject_member_event(
async def inject_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> EventBase:
"""Inject a generic event into a room
@@ -80,7 +79,7 @@ async def inject_event(
async def create_event(
hs: synapse.server.HomeServer,
room_version: Optional[str] = None,
prev_event_ids: Optional[Collection[str]] = None,
prev_event_ids: Optional[List[str]] = None,
**kwargs
) -> Tuple[EventBase, EventContext]:
if room_version is None: