Merge remote-tracking branch 'origin/develop' into rei/2528_catchup_fed_outage
This commit is contained in:
@@ -73,7 +73,7 @@ mkdir -p ~/synapse
|
||||
virtualenv -p python3 ~/synapse/env
|
||||
source ~/synapse/env/bin/activate
|
||||
pip install --upgrade pip
|
||||
pip install --upgrade setuptools
|
||||
pip install --upgrade setuptools!=50.0 # setuptools==50.0 fails on some older Python versions
|
||||
pip install matrix-synapse
|
||||
```
|
||||
|
||||
|
||||
1
changelog.d/8199.misc
Normal file
1
changelog.d/8199.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8200.misc
Normal file
1
changelog.d/8200.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8203.misc
Normal file
1
changelog.d/8203.misc
Normal file
@@ -0,0 +1 @@
|
||||
Make `MultiWriterIDGenerator` work for streams that use negative values.
|
||||
1
changelog.d/8204.misc
Normal file
1
changelog.d/8204.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor queries for device keys and cross-signatures.
|
||||
1
changelog.d/8212.bugfix
Normal file
1
changelog.d/8212.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Do not install setuptools 50.0. It can lead to a broken configuration on some older Python versions.
|
||||
1
changelog.d/8214.misc
Normal file
1
changelog.d/8214.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
mypy.ini
1
mypy.ini
@@ -28,6 +28,7 @@ files =
|
||||
synapse/handlers/saml_handler.py,
|
||||
synapse/handlers/sync.py,
|
||||
synapse/handlers/ui_auth,
|
||||
synapse/http/federation/well_known_resolver.py,
|
||||
synapse/http/server.py,
|
||||
synapse/http/site.py,
|
||||
synapse/logging/,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import attr
|
||||
from nacl.signing import SigningKey
|
||||
@@ -97,14 +97,14 @@ class EventBuilder(object):
|
||||
def is_state(self):
|
||||
return self._state_key is not None
|
||||
|
||||
async def build(self, prev_event_ids):
|
||||
async def build(self, prev_event_ids: List[str]) -> EventBase:
|
||||
"""Transform into a fully signed and hashed event
|
||||
|
||||
Args:
|
||||
prev_event_ids (list[str]): The event IDs to use as the prev events
|
||||
prev_event_ids: The event IDs to use as the prev events
|
||||
|
||||
Returns:
|
||||
FrozenEvent
|
||||
The signed and hashed event.
|
||||
"""
|
||||
|
||||
state_ids = await self._state.get_current_state_ids(
|
||||
@@ -114,8 +114,13 @@ class EventBuilder(object):
|
||||
|
||||
format_version = self.room_version.event_format
|
||||
if format_version == EventFormatVersions.V1:
|
||||
auth_events = await self._store.add_event_hashes(auth_ids)
|
||||
prev_events = await self._store.add_event_hashes(prev_event_ids)
|
||||
# The types of auth/prev events changes between event versions.
|
||||
auth_events = await self._store.add_event_hashes(
|
||||
auth_ids
|
||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
prev_events = await self._store.add_event_hashes(
|
||||
prev_event_ids
|
||||
) # type: Union[List[str], List[Tuple[str, Dict[str, str]]]]
|
||||
else:
|
||||
auth_events = auth_ids
|
||||
prev_events = prev_event_ids
|
||||
@@ -138,7 +143,7 @@ class EventBuilder(object):
|
||||
"unsigned": self.unsigned,
|
||||
"depth": depth,
|
||||
"prev_state": [],
|
||||
}
|
||||
} # type: Dict[str, Any]
|
||||
|
||||
if self.is_state():
|
||||
event_dict["state_key"] = self._state_key
|
||||
|
||||
@@ -234,7 +234,9 @@ class DeviceWorkerHandler(BaseHandler):
|
||||
return result
|
||||
|
||||
async def on_federation_query_user_devices(self, user_id):
|
||||
stream_id, devices = await self.store.get_devices_with_keys_by_user(user_id)
|
||||
stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
|
||||
user_id
|
||||
)
|
||||
master_key = await self.store.get_e2e_cross_signing_key(user_id, "master")
|
||||
self_signing_key = await self.store.get_e2e_cross_signing_key(
|
||||
user_id, "self_signing"
|
||||
|
||||
@@ -49,14 +49,7 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
StreamToken,
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
@@ -446,7 +439,7 @@ class EventCreationHandler(object):
|
||||
event_dict: dict,
|
||||
token_id: Optional[str] = None,
|
||||
txn_id: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
require_consent: bool = True,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""
|
||||
@@ -786,7 +779,7 @@ class EventCreationHandler(object):
|
||||
self,
|
||||
builder: EventBuilder,
|
||||
requester: Optional[Requester] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
"""Create a new event for a local client
|
||||
|
||||
|
||||
@@ -38,15 +38,7 @@ from synapse.events.builder import create_local_event_from_event_dict
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import (
|
||||
Collection,
|
||||
JsonDict,
|
||||
Requester,
|
||||
RoomAlias,
|
||||
RoomID,
|
||||
StateMap,
|
||||
UserID,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.distributor import user_joined_room, user_left_room
|
||||
|
||||
@@ -184,7 +176,7 @@ class RoomMemberHandler(object):
|
||||
target: UserID,
|
||||
room_id: str,
|
||||
membership: str,
|
||||
prev_event_ids: Collection[str],
|
||||
prev_event_ids: List[str],
|
||||
txn_id: Optional[str] = None,
|
||||
ratelimit: bool = True,
|
||||
content: Optional[dict] = None,
|
||||
|
||||
@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
|
||||
and not _is_ip_literal(parsed_uri.hostname)
|
||||
and not parsed_uri.port
|
||||
):
|
||||
well_known_result = yield self._well_known_resolver.get_well_known(
|
||||
parsed_uri.hostname
|
||||
well_known_result = yield defer.ensureDeferred(
|
||||
self._well_known_resolver.get_well_known(parsed_uri.hostname)
|
||||
)
|
||||
delegated_server = well_known_result.delegated_server
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from typing import Callable, Dict, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@@ -23,6 +24,7 @@ from twisted.internet import defer
|
||||
from twisted.web.client import RedirectAgent, readBody
|
||||
from twisted.web.http import stringToDatetime
|
||||
from twisted.web.http_headers import Headers
|
||||
from twisted.web.iweb import IResponse
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import Clock, json_decoder
|
||||
@@ -99,15 +101,14 @@ class WellKnownResolver(object):
|
||||
self._well_known_agent = RedirectAgent(agent)
|
||||
self.user_agent = user_agent
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_well_known(self, server_name):
|
||||
async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
|
||||
"""Attempt to fetch and parse a .well-known file for the given server
|
||||
|
||||
Args:
|
||||
server_name (bytes): name of the server, from the requested url
|
||||
server_name: name of the server, from the requested url
|
||||
|
||||
Returns:
|
||||
Deferred[WellKnownLookupResult]: The result of the lookup
|
||||
The result of the lookup
|
||||
"""
|
||||
try:
|
||||
prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
|
||||
@@ -124,7 +125,9 @@ class WellKnownResolver(object):
|
||||
# requests for the same server in parallel?
|
||||
try:
|
||||
with Measure(self._clock, "get_well_known"):
|
||||
result, cache_period = yield self._fetch_well_known(server_name)
|
||||
result, cache_period = await self._fetch_well_known(
|
||||
server_name
|
||||
) # type: Tuple[Optional[bytes], float]
|
||||
|
||||
except _FetchWellKnownFailure as e:
|
||||
if prev_result and e.temporary:
|
||||
@@ -153,18 +156,17 @@ class WellKnownResolver(object):
|
||||
|
||||
return WellKnownLookupResult(delegated_server=result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _fetch_well_known(self, server_name):
|
||||
async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
|
||||
"""Actually fetch and parse a .well-known, without checking the cache
|
||||
|
||||
Args:
|
||||
server_name (bytes): name of the server, from the requested url
|
||||
server_name: name of the server, from the requested url
|
||||
|
||||
Raises:
|
||||
_FetchWellKnownFailure if we fail to lookup a result
|
||||
|
||||
Returns:
|
||||
Deferred[Tuple[bytes,int]]: The lookup result and cache period.
|
||||
The lookup result and cache period.
|
||||
"""
|
||||
|
||||
had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
|
||||
@@ -172,7 +174,7 @@ class WellKnownResolver(object):
|
||||
# We do this in two steps to differentiate between possibly transient
|
||||
# errors (e.g. can't connect to host, 503 response) and more permenant
|
||||
# errors (such as getting a 404 response).
|
||||
response, body = yield self._make_well_known_request(
|
||||
response, body = await self._make_well_known_request(
|
||||
server_name, retry=had_valid_well_known
|
||||
)
|
||||
|
||||
@@ -215,20 +217,20 @@ class WellKnownResolver(object):
|
||||
|
||||
return result, cache_period
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _make_well_known_request(self, server_name, retry):
|
||||
async def _make_well_known_request(
|
||||
self, server_name: bytes, retry: bool
|
||||
) -> Tuple[IResponse, bytes]:
|
||||
"""Make the well known request.
|
||||
|
||||
This will retry the request if requested and it fails (with unable
|
||||
to connect or receives a 5xx error).
|
||||
|
||||
Args:
|
||||
server_name (bytes)
|
||||
retry (bool): Whether to retry the request if it fails.
|
||||
server_name: name of the server, from the requested url
|
||||
retry: Whether to retry the request if it fails.
|
||||
|
||||
Returns:
|
||||
Deferred[tuple[IResponse, bytes]] Returns the response object and
|
||||
body. Response may be a non-200 response.
|
||||
Returns the response object and body. Response may be a non-200 response.
|
||||
"""
|
||||
uri = b"https://%s/.well-known/matrix/server" % (server_name,)
|
||||
uri_str = uri.decode("ascii")
|
||||
@@ -243,12 +245,12 @@ class WellKnownResolver(object):
|
||||
|
||||
logger.info("Fetching %s", uri_str)
|
||||
try:
|
||||
response = yield make_deferred_yieldable(
|
||||
response = await make_deferred_yieldable(
|
||||
self._well_known_agent.request(
|
||||
b"GET", uri, headers=Headers(headers)
|
||||
)
|
||||
)
|
||||
body = yield make_deferred_yieldable(readBody(response))
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 500 <= response.code < 600:
|
||||
raise Exception("Non-200 response %s" % (response.code,))
|
||||
@@ -265,21 +267,24 @@ class WellKnownResolver(object):
|
||||
logger.info("Error fetching %s: %s. Retrying", uri_str, e)
|
||||
|
||||
# Sleep briefly in the hopes that they come back up
|
||||
yield self._clock.sleep(0.5)
|
||||
await self._clock.sleep(0.5)
|
||||
|
||||
|
||||
def _cache_period_from_headers(headers, time_now=time.time):
|
||||
def _cache_period_from_headers(
|
||||
headers: Headers, time_now: Callable[[], float] = time.time
|
||||
) -> Optional[float]:
|
||||
cache_controls = _parse_cache_control(headers)
|
||||
|
||||
if b"no-store" in cache_controls:
|
||||
return 0
|
||||
|
||||
if b"max-age" in cache_controls:
|
||||
try:
|
||||
max_age = int(cache_controls[b"max-age"])
|
||||
return max_age
|
||||
except ValueError:
|
||||
pass
|
||||
max_age = cache_controls[b"max-age"]
|
||||
if max_age:
|
||||
try:
|
||||
return int(max_age)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
expires = headers.getRawHeaders(b"expires")
|
||||
if expires is not None:
|
||||
@@ -295,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
|
||||
return None
|
||||
|
||||
|
||||
def _parse_cache_control(headers):
|
||||
def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
|
||||
cache_controls = {}
|
||||
for hdr in headers.getRawHeaders(b"cache-control", []):
|
||||
for directive in hdr.split(b","):
|
||||
|
||||
@@ -74,6 +74,10 @@ REQUIREMENTS = [
|
||||
"Jinja2>=2.9",
|
||||
"bleach>=1.4.3",
|
||||
"typing-extensions>=3.7.4",
|
||||
# setuptools is required by a variety of dependencies, unfortunately version
|
||||
# 50.0 is incompatible with older Python versions, see
|
||||
# https://github.com/pypa/setuptools/issues/2352
|
||||
"setuptools!=50.0",
|
||||
]
|
||||
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
|
||||
@@ -48,6 +48,9 @@ class SlavedDeviceStore(EndToEndKeyWorkerStore, DeviceWorkerStore, BaseSlavedSto
|
||||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
if stream_name == DeviceListsStream.NAME:
|
||||
self._device_list_id_gen.advance(instance_name, token)
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
import calendar
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
@@ -264,6 +264,9 @@ class DataStore(
|
||||
# Used in _generate_user_daily_visits to keep track of progress
|
||||
self._last_user_visit_update = self._get_start_of_day()
|
||||
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
|
||||
def take_presence_startup_info(self):
|
||||
active_on_startup = self._presence_on_startup
|
||||
self._presence_on_startup = None
|
||||
@@ -291,16 +294,16 @@ class DataStore(
|
||||
|
||||
return [UserPresenceState(**row) for row in rows]
|
||||
|
||||
def count_daily_users(self):
|
||||
async def count_daily_users(self) -> int:
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 24 hours.
|
||||
"""
|
||||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_daily_users", self._count_users, yesterday
|
||||
)
|
||||
|
||||
def count_monthly_users(self):
|
||||
async def count_monthly_users(self) -> int:
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 30 days.
|
||||
Note this method is intended for phonehome metrics only and is different
|
||||
@@ -308,7 +311,7 @@ class DataStore(
|
||||
amongst other things, includes a 3 day grace period before a user counts.
|
||||
"""
|
||||
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_monthly_users", self._count_users, thirty_days_ago
|
||||
)
|
||||
|
||||
@@ -327,15 +330,15 @@ class DataStore(
|
||||
(count,) = txn.fetchone()
|
||||
return count
|
||||
|
||||
def count_r30_users(self):
|
||||
async def count_r30_users(self) -> Dict[str, int]:
|
||||
"""
|
||||
Counts the number of 30 day retained users, defined as:-
|
||||
* Users who have created their accounts more than 30 days ago
|
||||
* Where last seen at most 30 days ago
|
||||
* Where account creation and last_seen are > 30 days apart
|
||||
|
||||
Returns counts globaly for a given user as well as breaking
|
||||
by platform
|
||||
Returns:
|
||||
A mapping of counts globally as well as broken out by platform.
|
||||
"""
|
||||
|
||||
def _count_r30_users(txn):
|
||||
@@ -408,7 +411,7 @@ class DataStore(
|
||||
|
||||
return results
|
||||
|
||||
return self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
return await self.db_pool.runInteraction("count_r30_users", _count_r30_users)
|
||||
|
||||
def _get_start_of_day(self):
|
||||
"""
|
||||
@@ -418,7 +421,7 @@ class DataStore(
|
||||
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
|
||||
return today_start * 1000
|
||||
|
||||
def generate_user_daily_visits(self):
|
||||
async def generate_user_daily_visits(self) -> None:
|
||||
"""
|
||||
Generates daily visit data for use in cohort/ retention analysis
|
||||
"""
|
||||
@@ -473,7 +476,7 @@ class DataStore(
|
||||
# frequently
|
||||
self._last_user_visit_update = now
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"generate_user_daily_visits", _generate_user_daily_visits
|
||||
)
|
||||
|
||||
@@ -497,22 +500,28 @@ class DataStore(
|
||||
desc="get_users",
|
||||
)
|
||||
|
||||
def get_users_paginate(
|
||||
self, start, limit, user_id=None, name=None, guests=True, deactivated=False
|
||||
):
|
||||
async def get_users_paginate(
|
||||
self,
|
||||
start: int,
|
||||
limit: int,
|
||||
user_id: Optional[str] = None,
|
||||
name: Optional[str] = None,
|
||||
guests: bool = True,
|
||||
deactivated: bool = False,
|
||||
) -> Tuple[List[Dict[str, Any]], int]:
|
||||
"""Function to retrieve a paginated list of users from
|
||||
users list. This will return a json list of users and the
|
||||
total number of users matching the filter criteria.
|
||||
|
||||
Args:
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to retrieve
|
||||
user_id (string): search for user_id. ignored if name is not None
|
||||
name (string): search for local part of user_id or display name
|
||||
guests (bool): whether to in include guest users
|
||||
deactivated (bool): whether to include deactivated users
|
||||
start: start number to begin the query from
|
||||
limit: number of rows to retrieve
|
||||
user_id: search for user_id. ignored if name is not None
|
||||
name: search for local part of user_id or display name
|
||||
guests: whether to in include guest users
|
||||
deactivated: whether to include deactivated users
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]], int
|
||||
A tuple of a list of mappings from user to information and a count of total users.
|
||||
"""
|
||||
|
||||
def get_users_paginate_txn(txn):
|
||||
@@ -555,7 +564,7 @@ class DataStore(
|
||||
users = self.db_pool.cursor_to_dict(txn)
|
||||
return users, count
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_paginate_txn", get_users_paginate_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -396,7 +396,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
self._batch_row_update[key] = (user_agent, device_id, now)
|
||||
|
||||
@wrap_as_background_process("update_client_ips")
|
||||
def _update_client_ips_batch(self):
|
||||
async def _update_client_ips_batch(self) -> None:
|
||||
|
||||
# If the DB pool has already terminated, don't try updating
|
||||
if not self.db_pool.is_running():
|
||||
@@ -405,7 +405,7 @@ class ClientIpStore(ClientIpBackgroundUpdateStore):
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
@@ -101,7 +102,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
update included in the response), and the list of updates, where
|
||||
each update is a pair of EDU type and EDU contents.
|
||||
"""
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
|
||||
destination, int(from_stream_id)
|
||||
@@ -312,9 +313,9 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
|
||||
return results
|
||||
|
||||
def _get_last_device_update_for_remote_user(
|
||||
async def _get_last_device_update_for_remote_user(
|
||||
self, destination: str, user_id: str, from_stream_id: int
|
||||
):
|
||||
) -> int:
|
||||
def f(txn):
|
||||
prev_sent_id_sql = """
|
||||
SELECT coalesce(max(stream_id), 0) as stream_id
|
||||
@@ -325,12 +326,16 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
rows = txn.fetchall()
|
||||
return rows[0][0]
|
||||
|
||||
return self.db_pool.runInteraction("get_last_device_update_for_remote_user", f)
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_last_device_update_for_remote_user", f
|
||||
)
|
||||
|
||||
def mark_as_sent_devices_by_remote(self, destination: str, stream_id: int):
|
||||
async def mark_as_sent_devices_by_remote(
|
||||
self, destination: str, stream_id: int
|
||||
) -> None:
|
||||
"""Mark that updates have successfully been sent to the destination.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"mark_as_sent_devices_by_remote",
|
||||
self._mark_as_sent_devices_by_remote_txn,
|
||||
destination,
|
||||
@@ -412,8 +417,10 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_device_stream_token(self) -> int:
|
||||
return self._device_list_id_gen.get_current_token()
|
||||
"""Get the current stream id from the _device_list_id_gen"""
|
||||
...
|
||||
|
||||
@trace
|
||||
async def get_user_devices_from_cache(
|
||||
@@ -481,51 +488,6 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
device["device_id"]: db_to_json(device["content"]) for device in devices
|
||||
}
|
||||
|
||||
def get_devices_with_keys_by_user(self, user_id: str):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
Deferred which resolves to (stream_id, devices)
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_devices_with_keys_by_user",
|
||||
self._get_devices_with_keys_by_user_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_devices_with_keys_by_user_txn(
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
now_stream_id = self._device_list_id_gen.get_current_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
async def get_users_whose_devices_changed(
|
||||
self, from_key: str, user_ids: Iterable[str]
|
||||
) -> Set[str]:
|
||||
@@ -726,7 +688,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
desc="make_remote_user_device_cache_as_stale",
|
||||
)
|
||||
|
||||
def mark_remote_user_device_list_as_unsubscribed(self, user_id: str):
|
||||
async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
|
||||
"""Mark that we no longer track device lists for remote user.
|
||||
"""
|
||||
|
||||
@@ -740,7 +702,7 @@ class DeviceWorkerStore(SQLBaseStore):
|
||||
txn, self.get_device_list_last_stream_id_for_remote, (user_id,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"mark_remote_user_device_list_as_unsubscribed",
|
||||
_mark_remote_user_device_list_as_unsubscribed_txn,
|
||||
)
|
||||
@@ -1001,9 +963,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
desc="update_device",
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache_entry(
|
||||
async def update_remote_device_list_cache_entry(
|
||||
self, user_id: str, device_id: str, content: JsonDict, stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
"""Updates a single device in the cache of a remote user's devicelist.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
@@ -1014,11 +976,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
device_id: ID of decivice being updated
|
||||
content: new data on this device
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_remote_device_list_cache_entry",
|
||||
self._update_remote_device_list_cache_entry_txn,
|
||||
user_id,
|
||||
@@ -1070,9 +1029,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
lock=False,
|
||||
)
|
||||
|
||||
def update_remote_device_list_cache(
|
||||
async def update_remote_device_list_cache(
|
||||
self, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
"""Replace the entire cache of the remote user's devices.
|
||||
|
||||
Note: assumes that we are the only thread that can be updating this user's
|
||||
@@ -1082,11 +1041,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
user_id: User to update device list for
|
||||
devices: list of device objects supplied over federation
|
||||
stream_id: the version of the device list
|
||||
|
||||
Returns:
|
||||
Deferred[None]
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"update_remote_device_list_cache",
|
||||
self._update_remote_device_list_cache_txn,
|
||||
user_id,
|
||||
@@ -1096,7 +1052,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
|
||||
|
||||
def _update_remote_device_list_cache_txn(
|
||||
self, txn: LoggingTransaction, user_id: str, devices: List[dict], stream_id: int
|
||||
):
|
||||
) -> None:
|
||||
self.db_pool.simple_delete_txn(
|
||||
txn, table="device_lists_remote_cache", keyvalues={"user_id": user_id}
|
||||
)
|
||||
|
||||
@@ -159,9 +159,9 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
|
||||
return room_id
|
||||
|
||||
def update_aliases_for_room(
|
||||
async def update_aliases_for_room(
|
||||
self, old_room_id: str, new_room_id: str, creator: Optional[str] = None,
|
||||
):
|
||||
) -> None:
|
||||
"""Repoint all of the aliases for a given room, to a different room.
|
||||
|
||||
Args:
|
||||
@@ -189,6 +189,6 @@ class DirectoryStore(DirectoryWorkerStore):
|
||||
txn, self.get_aliases_for_room, (new_room_id,)
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"_update_aliases_for_room_txn", _update_aliases_for_room_txn
|
||||
)
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json
|
||||
@@ -22,7 +23,7 @@ from twisted.enterprise.adbapi import Connection
|
||||
|
||||
from synapse.logging.opentracing import log_kv, set_tag, trace
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.storage.database import make_in_list_sql_clause
|
||||
from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
@@ -33,6 +34,51 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
def get_e2e_device_keys_for_federation_query(self, user_id: str):
|
||||
"""Get all devices (with any device keys) for a user
|
||||
|
||||
Returns:
|
||||
Deferred which resolves to (stream_id, devices)
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
"get_e2e_device_keys_for_federation_query",
|
||||
self._get_e2e_device_keys_for_federation_query_txn,
|
||||
user_id,
|
||||
)
|
||||
|
||||
def _get_e2e_device_keys_for_federation_query_txn(
|
||||
self, txn: LoggingTransaction, user_id: str
|
||||
) -> Tuple[int, List[JsonDict]]:
|
||||
now_stream_id = self.get_device_stream_token()
|
||||
|
||||
devices = self._get_e2e_device_keys_txn(txn, [(user_id, None)])
|
||||
|
||||
if devices:
|
||||
user_devices = devices[user_id]
|
||||
results = []
|
||||
for device_id, device in user_devices.items():
|
||||
result = {"device_id": device_id}
|
||||
|
||||
key_json = device.get("key_json", None)
|
||||
if key_json:
|
||||
result["keys"] = db_to_json(key_json)
|
||||
|
||||
if "signatures" in device:
|
||||
for sig_user_id, sigs in device["signatures"].items():
|
||||
result["keys"].setdefault("signatures", {}).setdefault(
|
||||
sig_user_id, {}
|
||||
).update(sigs)
|
||||
|
||||
device_display_name = device.get("device_display_name", None)
|
||||
if device_display_name:
|
||||
result["device_display_name"] = device_display_name
|
||||
|
||||
results.append(result)
|
||||
|
||||
return now_stream_id, results
|
||||
|
||||
return now_stream_id, []
|
||||
|
||||
@trace
|
||||
async def get_e2e_device_keys_for_cs_api(
|
||||
self, query_list: List[Tuple[str, Optional[str]]]
|
||||
@@ -533,6 +579,11 @@ class EndToEndKeyWorkerStore(SQLBaseStore):
|
||||
_get_all_user_signature_changes_for_remotes_txn,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_device_stream_token(self) -> int:
|
||||
"""Get the current stream id from the _device_list_id_gen"""
|
||||
...
|
||||
|
||||
|
||||
class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
|
||||
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
|
||||
|
||||
@@ -823,20 +823,24 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return event_dict
|
||||
|
||||
def _maybe_redact_event_row(self, original_ev, redactions, event_map):
|
||||
def _maybe_redact_event_row(
|
||||
self,
|
||||
original_ev: EventBase,
|
||||
redactions: Iterable[str],
|
||||
event_map: Dict[str, EventBase],
|
||||
) -> Optional[EventBase]:
|
||||
"""Given an event object and a list of possible redacting event ids,
|
||||
determine whether to honour any of those redactions and if so return a redacted
|
||||
event.
|
||||
|
||||
Args:
|
||||
original_ev (EventBase):
|
||||
redactions (iterable[str]): list of event ids of potential redaction events
|
||||
event_map (dict[str, EventBase]): other events which have been fetched, in
|
||||
which we can look up the redaaction events. Map from event id to event.
|
||||
original_ev: The original event.
|
||||
redactions: list of event ids of potential redaction events
|
||||
event_map: other events which have been fetched, in which we can
|
||||
look up the redaaction events. Map from event id to event.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase|None]: if the event should be redacted, a pruned
|
||||
event object. Otherwise, None.
|
||||
If the event should be redacted, a pruned event object. Otherwise, None.
|
||||
"""
|
||||
if original_ev.type == "m.room.create":
|
||||
# we choose to ignore redactions of m.room.create events.
|
||||
@@ -946,17 +950,17 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
row = txn.fetchone()
|
||||
return row[0] if row else 0
|
||||
|
||||
def get_current_state_event_counts(self, room_id):
|
||||
async def get_current_state_event_counts(self, room_id: str) -> int:
|
||||
"""
|
||||
Gets the current number of state events in a room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id: The room ID to query.
|
||||
|
||||
Returns:
|
||||
Deferred[int]
|
||||
The current number of state events.
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_state_event_counts",
|
||||
self._get_current_state_event_counts_txn,
|
||||
room_id,
|
||||
@@ -991,7 +995,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
"""The current maximum token that events have reached"""
|
||||
return self._stream_id_gen.get_current_token()
|
||||
|
||||
def get_all_new_forward_event_rows(self, last_id, current_id, limit):
|
||||
async def get_all_new_forward_event_rows(
|
||||
self, last_id: int, current_id: int, limit: int
|
||||
) -> List[Tuple]:
|
||||
"""Returns new events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
@@ -999,7 +1005,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
current_id: the maximum stream_id to return up to
|
||||
limit: the maximum number of rows to return
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
Returns:
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
@@ -1020,18 +1026,20 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_id, current_id, limit))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_all_new_forward_event_rows", get_all_new_forward_event_rows
|
||||
)
|
||||
|
||||
def get_ex_outlier_stream_rows(self, last_id, current_id):
|
||||
async def get_ex_outlier_stream_rows(
|
||||
self, last_id: int, current_id: int
|
||||
) -> List[Tuple]:
|
||||
"""Returns de-outliered events, for the Events replication stream
|
||||
|
||||
Args:
|
||||
last_id: the last stream_id from the previous batch.
|
||||
current_id: the maximum stream_id to return up to
|
||||
|
||||
Returns: Deferred[List[Tuple]]
|
||||
Returns:
|
||||
a list of events stream rows. Each tuple consists of a stream id as
|
||||
the first element, followed by fields suitable for casting into an
|
||||
EventsStreamRow.
|
||||
@@ -1054,7 +1062,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_id, current_id))
|
||||
return txn.fetchall()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_ex_outlier_stream_rows", get_ex_outlier_stream_rows_txn
|
||||
)
|
||||
|
||||
@@ -1226,11 +1234,11 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return (int(res["topological_ordering"]), int(res["stream_ordering"]))
|
||||
|
||||
def get_next_event_to_expire(self):
|
||||
async def get_next_event_to_expire(self) -> Optional[Tuple[str, int]]:
|
||||
"""Retrieve the entry with the lowest expiry timestamp in the event_expiry
|
||||
table, or None if there's no more event to expire.
|
||||
|
||||
Returns: Deferred[Optional[Tuple[str, int]]]
|
||||
Returns:
|
||||
A tuple containing the event ID as its first element and an expiry timestamp
|
||||
as its second one, if there's at least one row in the event_expiry table.
|
||||
None otherwise.
|
||||
@@ -1246,6 +1254,6 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return txn.fetchone()
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
desc="get_next_event_to_expire", func=get_next_event_to_expire_txn
|
||||
)
|
||||
|
||||
@@ -17,6 +17,7 @@ from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util.caches.descriptors import cached
|
||||
|
||||
|
||||
@@ -40,7 +41,7 @@ class FilteringStore(SQLBaseStore):
|
||||
|
||||
return db_to_json(def_json)
|
||||
|
||||
def add_user_filter(self, user_localpart, user_filter):
|
||||
async def add_user_filter(self, user_localpart: str, user_filter: JsonDict) -> str:
|
||||
def_json = encode_canonical_json(user_filter)
|
||||
|
||||
# Need an atomic transaction to SELECT the maximal ID so far then
|
||||
@@ -71,4 +72,4 @@ class FilteringStore(SQLBaseStore):
|
||||
|
||||
return filter_id
|
||||
|
||||
return self.db_pool.runInteraction("add_user_filter", _do_txn)
|
||||
return await self.db_pool.runInteraction("add_user_filter", _do_txn)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
|
||||
|
||||
@@ -15,7 +17,9 @@ class OpenIdStore(SQLBaseStore):
|
||||
desc="insert_open_id_token",
|
||||
)
|
||||
|
||||
def get_user_id_for_open_id_token(self, token, ts_now_ms):
|
||||
async def get_user_id_for_open_id_token(
|
||||
self, token: str, ts_now_ms: int
|
||||
) -> Optional[str]:
|
||||
def get_user_id_for_token_txn(txn):
|
||||
sql = (
|
||||
"SELECT user_id FROM open_id_tokens"
|
||||
@@ -30,6 +34,6 @@ class OpenIdStore(SQLBaseStore):
|
||||
else:
|
||||
return rows[0][0]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_user_id_for_token", get_user_id_for_token_txn
|
||||
)
|
||||
|
||||
@@ -138,7 +138,9 @@ class ProfileStore(ProfileWorkerStore):
|
||||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||
async def get_remote_profile_cache_entries_that_expire(
|
||||
self, last_checked: int
|
||||
) -> Dict[str, str]:
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
|
||||
@@ -153,7 +155,7 @@ class ProfileStore(ProfileWorkerStore):
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_remote_profile_cache_entries_that_expire",
|
||||
_get_remote_profile_cache_entries_that_expire_txn,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Tuple
|
||||
from typing import Any, List, Set, Tuple
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
@@ -25,25 +25,24 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||
def purge_history(self, room_id, token, delete_local_events):
|
||||
async def purge_history(
|
||||
self, room_id: str, token: str, delete_local_events: bool
|
||||
) -> Set[int]:
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
token (str): A topological token to delete events before
|
||||
|
||||
delete_local_events (bool):
|
||||
room_id:
|
||||
token: A topological token to delete events before
|
||||
delete_local_events:
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
|
||||
Returns:
|
||||
Deferred[set[int]]: The set of state groups that are referenced by
|
||||
deleted events.
|
||||
The set of state groups that are referenced by deleted events.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_history",
|
||||
self._purge_history_txn,
|
||||
room_id,
|
||||
@@ -283,17 +282,18 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
|
||||
|
||||
return referenced_state_groups
|
||||
|
||||
def purge_room(self, room_id):
|
||||
async def purge_room(self, room_id: str) -> List[int]:
|
||||
"""Deletes all record of a room
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
|
||||
Returns:
|
||||
Deferred[List[int]]: The list of state groups to delete.
|
||||
The list of state groups to delete.
|
||||
"""
|
||||
|
||||
return self.db_pool.runInteraction("purge_room", self._purge_room_txn, room_id)
|
||||
return await self.db_pool.runInteraction(
|
||||
"purge_room", self._purge_room_txn, room_id
|
||||
)
|
||||
|
||||
def _purge_room_txn(self, txn, room_id):
|
||||
# First we fetch all the state groups that should be deleted, before
|
||||
|
||||
@@ -18,8 +18,6 @@ import abc
|
||||
import logging
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.push.baserules import list_with_base_rules
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json
|
||||
@@ -149,9 +147,11 @@ class PushRulesWorkerStore(
|
||||
)
|
||||
return {r["rule_id"]: False if r["enabled"] == 0 else True for r in results}
|
||||
|
||||
def have_push_rules_changed_for_user(self, user_id, last_id):
|
||||
async def have_push_rules_changed_for_user(
|
||||
self, user_id: str, last_id: int
|
||||
) -> bool:
|
||||
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):
|
||||
return defer.succeed(False)
|
||||
return False
|
||||
else:
|
||||
|
||||
def have_push_rules_changed_txn(txn):
|
||||
@@ -163,7 +163,7 @@ class PushRulesWorkerStore(
|
||||
(count,) = txn.fetchone()
|
||||
return bool(count)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"have_push_rules_changed", have_push_rules_changed_txn
|
||||
)
|
||||
|
||||
|
||||
@@ -276,12 +276,14 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
}
|
||||
return results
|
||||
|
||||
def get_users_sent_receipts_between(self, last_id: int, current_id: int):
|
||||
async def get_users_sent_receipts_between(
|
||||
self, last_id: int, current_id: int
|
||||
) -> List[str]:
|
||||
"""Get all users who sent receipts between `last_id` exclusive and
|
||||
`current_id` inclusive.
|
||||
|
||||
Returns:
|
||||
Deferred[List[str]]
|
||||
The list of users.
|
||||
"""
|
||||
|
||||
if last_id == current_id:
|
||||
@@ -296,7 +298,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
|
||||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_users_sent_receipts_between", _get_users_sent_receipts_between_txn
|
||||
)
|
||||
|
||||
@@ -553,8 +555,10 @@ class ReceiptsStore(ReceiptsWorkerStore):
|
||||
|
||||
return stream_id, max_persisted_id
|
||||
|
||||
def insert_graph_receipt(self, room_id, receipt_type, user_id, event_ids, data):
|
||||
return self.db_pool.runInteraction(
|
||||
async def insert_graph_receipt(
|
||||
self, room_id, receipt_type, user_id, event_ids, data
|
||||
):
|
||||
return await self.db_pool.runInteraction(
|
||||
"insert_graph_receipt",
|
||||
self.insert_graph_receipt_txn,
|
||||
room_id,
|
||||
|
||||
@@ -34,38 +34,33 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class RelationsWorkerStore(SQLBaseStore):
|
||||
@cached(tree=True)
|
||||
def get_relations_for_event(
|
||||
async def get_relations_for_event(
|
||||
self,
|
||||
event_id,
|
||||
relation_type=None,
|
||||
event_type=None,
|
||||
aggregation_key=None,
|
||||
limit=5,
|
||||
direction="b",
|
||||
from_token=None,
|
||||
to_token=None,
|
||||
):
|
||||
event_id: str,
|
||||
relation_type: Optional[str] = None,
|
||||
event_type: Optional[str] = None,
|
||||
aggregation_key: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[RelationPaginationToken] = None,
|
||||
to_token: Optional[RelationPaginationToken] = None,
|
||||
) -> PaginationChunk:
|
||||
"""Get a list of relations for an event, ordered by topological ordering.
|
||||
|
||||
Args:
|
||||
event_id (str): Fetch events that relate to this event ID.
|
||||
relation_type (str|None): Only fetch events with this relation
|
||||
type, if given.
|
||||
event_type (str|None): Only fetch events with this event type, if
|
||||
given.
|
||||
aggregation_key (str|None): Only fetch events with this aggregation
|
||||
key, if given.
|
||||
limit (int): Only fetch the most recent `limit` events.
|
||||
direction (str): Whether to fetch the most recent first (`"b"`) or
|
||||
the oldest first (`"f"`).
|
||||
from_token (RelationPaginationToken|None): Fetch rows from the given
|
||||
token, or from the start if None.
|
||||
to_token (RelationPaginationToken|None): Fetch rows up to the given
|
||||
token, or up to the end if None.
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
relation_type: Only fetch events with this relation type, if given.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
aggregation_key: Only fetch events with this aggregation key, if given.
|
||||
limit: Only fetch the most recent `limit` events.
|
||||
direction: Whether to fetch the most recent first (`"b"`) or the
|
||||
oldest first (`"f"`).
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
Deferred[PaginationChunk]: List of event IDs that match relations
|
||||
requested. The rows are of the form `{"event_id": "..."}`.
|
||||
List of event IDs that match relations requested. The rows are of
|
||||
the form `{"event_id": "..."}`.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?"]
|
||||
@@ -131,20 +126,20 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_recent_references_for_event", _get_recent_references_for_event_txn
|
||||
)
|
||||
|
||||
@cached(tree=True)
|
||||
def get_aggregation_groups_for_event(
|
||||
async def get_aggregation_groups_for_event(
|
||||
self,
|
||||
event_id,
|
||||
event_type=None,
|
||||
limit=5,
|
||||
direction="b",
|
||||
from_token=None,
|
||||
to_token=None,
|
||||
):
|
||||
event_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
limit: int = 5,
|
||||
direction: str = "b",
|
||||
from_token: Optional[AggregationPaginationToken] = None,
|
||||
to_token: Optional[AggregationPaginationToken] = None,
|
||||
) -> PaginationChunk:
|
||||
"""Get a list of annotations on the event, grouped by event type and
|
||||
aggregation key, sorted by count.
|
||||
|
||||
@@ -152,21 +147,17 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
on an event.
|
||||
|
||||
Args:
|
||||
event_id (str): Fetch events that relate to this event ID.
|
||||
event_type (str|None): Only fetch events with this event type, if
|
||||
given.
|
||||
limit (int): Only fetch the `limit` groups.
|
||||
direction (str): Whether to fetch the highest count first (`"b"`) or
|
||||
event_id: Fetch events that relate to this event ID.
|
||||
event_type: Only fetch events with this event type, if given.
|
||||
limit: Only fetch the `limit` groups.
|
||||
direction: Whether to fetch the highest count first (`"b"`) or
|
||||
the lowest count first (`"f"`).
|
||||
from_token (AggregationPaginationToken|None): Fetch rows from the
|
||||
given token, or from the start if None.
|
||||
to_token (AggregationPaginationToken|None): Fetch rows up to the
|
||||
given token, or up to the end if None.
|
||||
|
||||
from_token: Fetch rows from the given token, or from the start if None.
|
||||
to_token: Fetch rows up to the given token, or up to the end if None.
|
||||
|
||||
Returns:
|
||||
Deferred[PaginationChunk]: List of groups of annotations that
|
||||
match. Each row is a dict with `type`, `key` and `count` fields.
|
||||
List of groups of annotations that match. Each row is a dict with
|
||||
`type`, `key` and `count` fields.
|
||||
"""
|
||||
|
||||
where_clause = ["relates_to_id = ?", "relation_type = ?"]
|
||||
@@ -225,7 +216,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
|
||||
)
|
||||
|
||||
@@ -279,18 +270,20 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
return await self.get_event(edit_id, allow_none=True)
|
||||
|
||||
def has_user_annotated_event(self, parent_id, event_type, aggregation_key, sender):
|
||||
async def has_user_annotated_event(
|
||||
self, parent_id: str, event_type: str, aggregation_key: str, sender: str
|
||||
) -> bool:
|
||||
"""Check if a user has already annotated an event with the same key
|
||||
(e.g. already liked an event).
|
||||
|
||||
Args:
|
||||
parent_id (str): The event being annotated
|
||||
event_type (str): The event type of the annotation
|
||||
aggregation_key (str): The aggregation key of the annotation
|
||||
sender (str): The sender of the annotation
|
||||
parent_id: The event being annotated
|
||||
event_type: The event type of the annotation
|
||||
aggregation_key: The aggregation key of the annotation
|
||||
sender: The sender of the annotation
|
||||
|
||||
Returns:
|
||||
Deferred[bool]
|
||||
True if the event is already annotated.
|
||||
"""
|
||||
|
||||
sql = """
|
||||
@@ -319,7 +312,7 @@ class RelationsWorkerStore(SQLBaseStore):
|
||||
|
||||
return bool(txn.fetchone())
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_if_user_has_annotated_event", _get_if_user_has_annotated_event
|
||||
)
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
def get_room_with_stats(self, room_id: str):
|
||||
async def get_room_with_stats(self, room_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Retrieve room with statistics.
|
||||
|
||||
Args:
|
||||
@@ -121,7 +121,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
res["public"] = bool(res["public"])
|
||||
return res
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_room_with_stats", get_room_with_stats_txn, room_id
|
||||
)
|
||||
|
||||
@@ -133,13 +133,17 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
desc="get_public_room_ids",
|
||||
)
|
||||
|
||||
def count_public_rooms(self, network_tuple, ignore_non_federatable):
|
||||
async def count_public_rooms(
|
||||
self,
|
||||
network_tuple: Optional[ThirdPartyInstanceID],
|
||||
ignore_non_federatable: bool,
|
||||
) -> int:
|
||||
"""Counts the number of public rooms as tracked in the room_stats_current
|
||||
and room_stats_state table.
|
||||
|
||||
Args:
|
||||
network_tuple (ThirdPartyInstanceID|None)
|
||||
ignore_non_federatable (bool): If true filters out non-federatable rooms
|
||||
network_tuple
|
||||
ignore_non_federatable: If true filters out non-federatable rooms
|
||||
"""
|
||||
|
||||
def _count_public_rooms_txn(txn):
|
||||
@@ -183,7 +187,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, query_args)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"count_public_rooms", _count_public_rooms_txn
|
||||
)
|
||||
|
||||
@@ -586,15 +590,14 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return row
|
||||
|
||||
def get_media_mxcs_in_room(self, room_id):
|
||||
async def get_media_mxcs_in_room(self, room_id: str) -> Tuple[List[str], List[str]]:
|
||||
"""Retrieves all the local and remote media MXC URIs in a given room
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
room_id
|
||||
|
||||
Returns:
|
||||
The local and remote media as a lists of tuples where the key is
|
||||
the hostname and the value is the media ID.
|
||||
The local and remote media as a lists of the media IDs.
|
||||
"""
|
||||
|
||||
def _get_media_mxcs_in_room_txn(txn):
|
||||
@@ -610,11 +613,13 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_media_ids_in_room", _get_media_mxcs_in_room_txn
|
||||
)
|
||||
|
||||
def quarantine_media_ids_in_room(self, room_id, quarantined_by):
|
||||
async def quarantine_media_ids_in_room(
|
||||
self, room_id: str, quarantined_by: str
|
||||
) -> int:
|
||||
"""For a room loops through all events with media and quarantines
|
||||
the associated media
|
||||
"""
|
||||
@@ -627,7 +632,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_in_room", _quarantine_media_in_room_txn
|
||||
)
|
||||
|
||||
@@ -690,9 +695,9 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
|
||||
return local_media_mxcs, remote_media_mxcs
|
||||
|
||||
def quarantine_media_by_id(
|
||||
async def quarantine_media_by_id(
|
||||
self, server_name: str, media_id: str, quarantined_by: str,
|
||||
):
|
||||
) -> int:
|
||||
"""quarantines a single local or remote media id
|
||||
|
||||
Args:
|
||||
@@ -711,11 +716,13 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
txn, local_mxcs, remote_mxcs, quarantined_by
|
||||
)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_id_txn
|
||||
)
|
||||
|
||||
def quarantine_media_ids_by_user(self, user_id: str, quarantined_by: str):
|
||||
async def quarantine_media_ids_by_user(
|
||||
self, user_id: str, quarantined_by: str
|
||||
) -> int:
|
||||
"""quarantines all local media associated with a single user
|
||||
|
||||
Args:
|
||||
@@ -727,7 +734,7 @@ class RoomWorkerStore(SQLBaseStore):
|
||||
local_media_ids = self._get_media_ids_by_user_txn(txn, user_id)
|
||||
return self._quarantine_media_txn(txn, local_media_ids, [], quarantined_by)
|
||||
|
||||
return self.db_pool.runInteraction(
|
||||
return await self.db_pool.runInteraction(
|
||||
"quarantine_media_by_user", _quarantine_media_by_user_txn
|
||||
)
|
||||
|
||||
@@ -1284,8 +1291,8 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
)
|
||||
self.hs.get_notifier().on_new_replication_data()
|
||||
|
||||
def get_room_count(self):
|
||||
"""Retrieve a list of all rooms
|
||||
async def get_room_count(self) -> int:
|
||||
"""Retrieve the total number of rooms.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
@@ -1294,7 +1301,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore, SearchStore):
|
||||
row = txn.fetchone()
|
||||
return row[0] or 0
|
||||
|
||||
return self.db_pool.runInteraction("get_rooms", f)
|
||||
return await self.db_pool.runInteraction("get_rooms", f)
|
||||
|
||||
async def add_event_report(
|
||||
self,
|
||||
|
||||
@@ -13,9 +13,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Iterable, List, Tuple
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.types import Cursor
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
|
||||
|
||||
@@ -29,16 +32,37 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
@cachedList(
|
||||
cached_method_name="get_event_reference_hash", list_name="event_ids", num_args=1
|
||||
)
|
||||
def get_event_reference_hashes(self, event_ids):
|
||||
async def get_event_reference_hashes(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> Dict[str, Dict[str, bytes]]:
|
||||
"""Get all hashes for given events.
|
||||
|
||||
Args:
|
||||
event_ids: The event IDs to get hashes for.
|
||||
|
||||
Returns:
|
||||
A mapping of event ID to a mapping of algorithm to hash.
|
||||
"""
|
||||
|
||||
def f(txn):
|
||||
return {
|
||||
event_id: self._get_event_reference_hashes_txn(txn, event_id)
|
||||
for event_id in event_ids
|
||||
}
|
||||
|
||||
return self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||
return await self.db_pool.runInteraction("get_event_reference_hashes", f)
|
||||
|
||||
async def add_event_hashes(self, event_ids):
|
||||
async def add_event_hashes(
|
||||
self, event_ids: Iterable[str]
|
||||
) -> List[Tuple[str, Dict[str, str]]]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
event_ids: The event IDs
|
||||
|
||||
Returns:
|
||||
A list of tuples of event ID and a mapping of algorithm to base-64 encoded hash.
|
||||
"""
|
||||
hashes = await self.get_event_reference_hashes(event_ids)
|
||||
hashes = {
|
||||
e_id: {k: encode_base64(v) for k, v in h.items() if k == "sha256"}
|
||||
@@ -47,13 +71,15 @@ class SignatureWorkerStore(SQLBaseStore):
|
||||
|
||||
return list(hashes.items())
|
||||
|
||||
def _get_event_reference_hashes_txn(self, txn, event_id):
|
||||
def _get_event_reference_hashes_txn(
|
||||
self, txn: Cursor, event_id: str
|
||||
) -> Dict[str, bytes]:
|
||||
"""Get all the hashes for a given PDU.
|
||||
Args:
|
||||
txn (cursor):
|
||||
event_id (str): Id for the Event.
|
||||
txn:
|
||||
event_id: Id for the Event.
|
||||
Returns:
|
||||
A dict[unicode, bytes] of algorithm -> hash.
|
||||
A mapping of algorithm -> hash.
|
||||
"""
|
||||
query = (
|
||||
"SELECT algorithm, hash"
|
||||
|
||||
@@ -290,7 +290,7 @@ class UIAuthWorkerStore(SQLBaseStore):
|
||||
|
||||
|
||||
class UIAuthStore(UIAuthWorkerStore):
|
||||
def delete_old_ui_auth_sessions(self, expiration_time: int):
|
||||
async def delete_old_ui_auth_sessions(self, expiration_time: int) -> None:
|
||||
"""
|
||||
Remove sessions which were last used earlier than the expiration time.
|
||||
|
||||
@@ -299,7 +299,7 @@ class UIAuthStore(UIAuthWorkerStore):
|
||||
This is an epoch time in milliseconds.
|
||||
|
||||
"""
|
||||
return self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_old_ui_auth_sessions",
|
||||
self._delete_old_ui_auth_sessions_txn,
|
||||
expiration_time,
|
||||
|
||||
@@ -66,7 +66,7 @@ class UserErasureWorkerStore(SQLBaseStore):
|
||||
|
||||
|
||||
class UserErasureStore(UserErasureWorkerStore):
|
||||
def mark_user_erased(self, user_id: str) -> None:
|
||||
async def mark_user_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id wishes their message history to be erased.
|
||||
|
||||
Args:
|
||||
@@ -84,9 +84,9 @@ class UserErasureStore(UserErasureWorkerStore):
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("mark_user_erased", f)
|
||||
await self.db_pool.runInteraction("mark_user_erased", f)
|
||||
|
||||
def mark_user_not_erased(self, user_id: str) -> None:
|
||||
async def mark_user_not_erased(self, user_id: str) -> None:
|
||||
"""Indicate that user_id is no longer erased.
|
||||
|
||||
Args:
|
||||
@@ -106,4 +106,4 @@ class UserErasureStore(UserErasureWorkerStore):
|
||||
|
||||
self._invalidate_cache_and_stream(txn, self.is_user_erased, (user_id,))
|
||||
|
||||
return self.db_pool.runInteraction("mark_user_not_erased", f)
|
||||
await self.db_pool.runInteraction("mark_user_not_erased", f)
|
||||
|
||||
@@ -185,6 +185,8 @@ class MultiWriterIdGenerator:
|
||||
id_column: Column that stores the stream ID.
|
||||
sequence_name: The name of the postgres sequence used to generate new
|
||||
IDs.
|
||||
positive: Whether the IDs are positive (true) or negative (false).
|
||||
When using negative IDs we go backwards from -1 to -2, -3, etc.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -196,13 +198,19 @@ class MultiWriterIdGenerator:
|
||||
instance_column: str,
|
||||
id_column: str,
|
||||
sequence_name: str,
|
||||
positive: bool = True,
|
||||
):
|
||||
self._db = db
|
||||
self._instance_name = instance_name
|
||||
self._positive = positive
|
||||
self._return_factor = 1 if positive else -1
|
||||
|
||||
# We lock as some functions may be called from DB threads.
|
||||
self._lock = threading.Lock()
|
||||
|
||||
# Note: If we are a negative stream then we still store all the IDs as
|
||||
# positive to make life easier for us, and simply negate the IDs when we
|
||||
# return them.
|
||||
self._current_positions = self._load_current_ids(
|
||||
db_conn, table, instance_column, id_column
|
||||
)
|
||||
@@ -233,13 +241,16 @@ class MultiWriterIdGenerator:
|
||||
def _load_current_ids(
|
||||
self, db_conn, table: str, instance_column: str, id_column: str
|
||||
) -> Dict[str, int]:
|
||||
# If positive stream aggregate via MAX. For negative stream use MIN
|
||||
# *and* negate the result to get a positive number.
|
||||
sql = """
|
||||
SELECT %(instance)s, MAX(%(id)s) FROM %(table)s
|
||||
SELECT %(instance)s, %(agg)s(%(id)s) FROM %(table)s
|
||||
GROUP BY %(instance)s
|
||||
""" % {
|
||||
"instance": instance_column,
|
||||
"id": id_column,
|
||||
"table": table,
|
||||
"agg": "MAX" if self._positive else "-MIN",
|
||||
}
|
||||
|
||||
cur = db_conn.cursor()
|
||||
@@ -269,15 +280,16 @@ class MultiWriterIdGenerator:
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token_for_writer(self._instance_name) < next_id
|
||||
|
||||
with self._lock:
|
||||
assert self._current_positions.get(self._instance_name, 0) < next_id
|
||||
|
||||
self._unfinished_ids.add(next_id)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_id
|
||||
# Multiply by the return factor so that the ID has correct sign.
|
||||
yield self._return_factor * next_id
|
||||
finally:
|
||||
self._mark_id_as_finished(next_id)
|
||||
|
||||
@@ -296,15 +308,15 @@ class MultiWriterIdGenerator:
|
||||
# Assert the fetched ID is actually greater than any ID we've already
|
||||
# seen. If not, then the sequence and table have got out of sync
|
||||
# somehow.
|
||||
assert max(self.get_positions().values(), default=0) < min(next_ids)
|
||||
|
||||
with self._lock:
|
||||
assert max(self._current_positions.values(), default=0) < min(next_ids)
|
||||
|
||||
self._unfinished_ids.update(next_ids)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield next_ids
|
||||
yield [self._return_factor * i for i in next_ids]
|
||||
finally:
|
||||
for i in next_ids:
|
||||
self._mark_id_as_finished(i)
|
||||
@@ -327,7 +339,7 @@ class MultiWriterIdGenerator:
|
||||
txn.call_after(self._mark_id_as_finished, next_id)
|
||||
txn.call_on_exception(self._mark_id_as_finished, next_id)
|
||||
|
||||
return next_id
|
||||
return self._return_factor * next_id
|
||||
|
||||
def _mark_id_as_finished(self, next_id: int):
|
||||
"""The ID has finished being processed so we should advance the
|
||||
@@ -359,20 +371,25 @@ class MultiWriterIdGenerator:
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._current_positions.get(instance_name, 0)
|
||||
return self._return_factor * self._current_positions.get(instance_name, 0)
|
||||
|
||||
def get_positions(self) -> Dict[str, int]:
|
||||
"""Get a copy of the current positon map.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return dict(self._current_positions)
|
||||
return {
|
||||
name: self._return_factor * i
|
||||
for name, i in self._current_positions.items()
|
||||
}
|
||||
|
||||
def advance(self, instance_name: str, new_id: int):
|
||||
"""Advance the postion of the named writer to the given ID, if greater
|
||||
than existing entry.
|
||||
"""
|
||||
|
||||
new_id *= self._return_factor
|
||||
|
||||
with self._lock:
|
||||
self._current_positions[instance_name] = max(
|
||||
new_id, self._current_positions.get(instance_name, 0)
|
||||
@@ -390,7 +407,7 @@ class MultiWriterIdGenerator:
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._persisted_upto_position
|
||||
return self._return_factor * self._persisted_upto_position
|
||||
|
||||
def _add_persisted_position(self, new_id: int):
|
||||
"""Record that we have persisted a position.
|
||||
|
||||
@@ -972,7 +972,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
def test_well_known_cache(self):
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
@@ -995,7 +997,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
well_known_server.loseConnection()
|
||||
|
||||
# repeat the request: it should hit the cache
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
r = self.successResultOf(fetch_d)
|
||||
self.assertEqual(r.delegated_server, b"target-server")
|
||||
|
||||
@@ -1003,7 +1007,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
self.reactor.pump((1000.0,))
|
||||
|
||||
# now it should connect again
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
@@ -1026,7 +1032,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
||||
# there should be an attempt to connect on port 443 for the .well-known
|
||||
clients = self.reactor.tcpClients
|
||||
@@ -1052,7 +1060,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
# another lookup.
|
||||
self.reactor.pump((900.0,))
|
||||
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
||||
# The resolver may retry a few times, so fonx all requests that come along
|
||||
attempts = 0
|
||||
@@ -1082,7 +1092,9 @@ class MatrixFederationAgentTests(unittest.TestCase):
|
||||
self.reactor.pump((10000.0,))
|
||||
|
||||
# Repated the request, this time it should fail if the lookup fails.
|
||||
fetch_d = self.well_known_resolver.get_well_known(b"testserv")
|
||||
fetch_d = defer.ensureDeferred(
|
||||
self.well_known_resolver.get_well_known(b"testserv")
|
||||
)
|
||||
|
||||
clients = self.reactor.tcpClients
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
|
||||
@@ -264,3 +264,108 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
# We assume that so long as `get_next` does correctly advance the
|
||||
# `persisted_upto_position` in this case, then it will be correct in the
|
||||
# other cases that are tested above (since they'll hit the same code).
|
||||
|
||||
|
||||
class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
"""Tests MultiWriterIdGenerator that produce *negative* stream IDs.
|
||||
"""
|
||||
|
||||
if not USE_POSTGRES_FOR_TESTS:
|
||||
skip = "Requires Postgres"
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.db_pool = self.store.db_pool # type: DatabasePool
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_setup_db", self._setup_db))
|
||||
|
||||
def _setup_db(self, txn):
|
||||
txn.execute("CREATE SEQUENCE foobar_seq")
|
||||
txn.execute(
|
||||
"""
|
||||
CREATE TABLE foobar (
|
||||
stream_id BIGINT NOT NULL,
|
||||
instance_name TEXT NOT NULL,
|
||||
data TEXT
|
||||
);
|
||||
"""
|
||||
)
|
||||
|
||||
def _create_id_generator(self, instance_name="master") -> MultiWriterIdGenerator:
|
||||
def _create(conn):
|
||||
return MultiWriterIdGenerator(
|
||||
conn,
|
||||
self.db_pool,
|
||||
instance_name=instance_name,
|
||||
table="foobar",
|
||||
instance_column="instance_name",
|
||||
id_column="stream_id",
|
||||
sequence_name="foobar_seq",
|
||||
positive=False,
|
||||
)
|
||||
|
||||
return self.get_success(self.db_pool.runWithConnection(_create))
|
||||
|
||||
def _insert_row(self, instance_name: str, stream_id: int):
|
||||
"""Insert one row as the given instance with given stream_id.
|
||||
"""
|
||||
|
||||
def _insert(txn):
|
||||
txn.execute(
|
||||
"INSERT INTO foobar VALUES (?, ?)", (stream_id, instance_name,),
|
||||
)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("_insert_row", _insert))
|
||||
|
||||
def test_single_instance(self):
|
||||
"""Test that reads and writes from a single process are handled
|
||||
correctly.
|
||||
"""
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
with self.get_success(id_gen.get_next()) as stream_id:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -1})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -1)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen.get_next_mult(3)) as stream_ids:
|
||||
for stream_id in stream_ids:
|
||||
self._insert_row("master", stream_id)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), -4)
|
||||
self.assertEqual(id_gen.get_persisted_upto_position(), -4)
|
||||
|
||||
# Test loading from DB by creating a second ID gen
|
||||
second_id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(second_id_gen.get_positions(), {"master": -4})
|
||||
self.assertEqual(second_id_gen.get_current_token_for_writer("master"), -4)
|
||||
self.assertEqual(second_id_gen.get_persisted_upto_position(), -4)
|
||||
|
||||
def test_multiple_instance(self):
|
||||
"""Tests that having multiple instances that get advanced over
|
||||
federation works corretly.
|
||||
"""
|
||||
id_gen_1 = self._create_id_generator("first")
|
||||
id_gen_2 = self._create_id_generator("second")
|
||||
|
||||
with self.get_success(id_gen_1.get_next()) as stream_id:
|
||||
self._insert_row("first", stream_id)
|
||||
id_gen_2.advance("first", stream_id)
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -1)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -1)
|
||||
|
||||
with self.get_success(id_gen_2.get_next()) as stream_id:
|
||||
self._insert_row("second", stream_id)
|
||||
id_gen_1.advance("second", stream_id)
|
||||
|
||||
self.assertEqual(id_gen_1.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_2.get_positions(), {"first": -1, "second": -2})
|
||||
self.assertEqual(id_gen_1.get_persisted_upto_position(), -2)
|
||||
self.assertEqual(id_gen_2.get_persisted_upto_position(), -2)
|
||||
|
||||
@@ -13,14 +13,13 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Tuple
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import synapse.server
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.types import Collection
|
||||
|
||||
"""
|
||||
Utility functions for poking events into the storage of the server under test.
|
||||
@@ -58,7 +57,7 @@ async def inject_member_event(
|
||||
async def inject_event(
|
||||
hs: synapse.server.HomeServer,
|
||||
room_version: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> EventBase:
|
||||
"""Inject a generic event into a room
|
||||
@@ -80,7 +79,7 @@ async def inject_event(
|
||||
async def create_event(
|
||||
hs: synapse.server.HomeServer,
|
||||
room_version: Optional[str] = None,
|
||||
prev_event_ids: Optional[Collection[str]] = None,
|
||||
prev_event_ids: Optional[List[str]] = None,
|
||||
**kwargs
|
||||
) -> Tuple[EventBase, EventContext]:
|
||||
if room_version is None:
|
||||
|
||||
Reference in New Issue
Block a user