Merge commit 'c9c544cda' into anoa/dinsic_release_1_21_x
* commit 'c9c544cda': Remove `ChainedIdGenerator`. (#8123) Switch the JSON byte producer from a pull to a push producer. (#8116) Updated docs: Added note about missing 308 redirect support. (#8120) Be stricter about JSON that is accepted by Synapse (#8106) Convert runWithConnection to async. (#8121) Remove the unused inlineCallbacks code-paths in the caching code (#8119) Separate `get_current_token` into two. (#8113) Convert events worker database to async/await. (#8071) Add a link to the matrix-synapse-rest-password-provider. (#8111)
This commit is contained in:
1
changelog.d/8071.misc
Normal file
1
changelog.d/8071.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8106.bugfix
Normal file
1
changelog.d/8106.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a long-standing bug where invalid JSON would be accepted by Synapse.
|
||||
1
changelog.d/8111.doc
Normal file
1
changelog.d/8111.doc
Normal file
@@ -0,0 +1 @@
|
||||
Link to matrix-synapse-rest-password-provider in the password provider documentation.
|
||||
1
changelog.d/8113.misc
Normal file
1
changelog.d/8113.misc
Normal file
@@ -0,0 +1 @@
|
||||
Separate `get_current_token` into two since there are two different use cases for it.
|
||||
1
changelog.d/8116.feature
Normal file
1
changelog.d/8116.feature
Normal file
@@ -0,0 +1 @@
|
||||
Iteratively encode JSON to avoid blocking the reactor.
|
||||
1
changelog.d/8119.misc
Normal file
1
changelog.d/8119.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8120.doc
Normal file
1
changelog.d/8120.doc
Normal file
@@ -0,0 +1 @@
|
||||
Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole.
|
||||
1
changelog.d/8121.misc
Normal file
1
changelog.d/8121.misc
Normal file
@@ -0,0 +1 @@
|
||||
Convert various parts of the codebase to async/await.
|
||||
1
changelog.d/8123.misc
Normal file
1
changelog.d/8123.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove `ChainedIdGenerator`.
|
||||
@@ -47,6 +47,18 @@ you invite them to. This can be caused by an incorrectly-configured reverse
|
||||
proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
|
||||
configure a reverse proxy.
|
||||
|
||||
### Known issues
|
||||
|
||||
**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
|
||||
in the HTTP library used by Synapse, 308 redirects are currently not followed by
|
||||
federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
|
||||
may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
|
||||
and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
|
||||
codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871)
|
||||
might be helpful. For other users, switching to a `301 Moved Permanently` code may be
|
||||
an option. 308 redirect codes will be supported properly in a future
|
||||
release of Synapse.
|
||||
|
||||
## Running a demo federation of Synapses
|
||||
|
||||
If you want to get up and running quickly with a trio of homeservers in a
|
||||
|
||||
@@ -14,6 +14,7 @@ password auth provider module implementations:
|
||||
|
||||
* [matrix-synapse-ldap3](https://github.com/matrix-org/matrix-synapse-ldap3/)
|
||||
* [matrix-synapse-shared-secret-auth](https://github.com/devture/matrix-synapse-shared-secret-auth)
|
||||
* [matrix-synapse-rest-password-provider](https://github.com/ma1uta/matrix-synapse-rest-password-provider)
|
||||
|
||||
## Required methods
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ import typing
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.web import http
|
||||
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from synapse.types import JsonDict
|
||||
|
||||
@@ -594,7 +594,7 @@ class HttpResponseException(CodeMessageException):
|
||||
# try to parse the body as json, to get better errcode/msg, but
|
||||
# default to M_UNKNOWN with the HTTP status as the error text
|
||||
try:
|
||||
j = json.loads(self.response.decode("utf-8"))
|
||||
j = json_decoder.decode(self.response.decode("utf-8"))
|
||||
except ValueError:
|
||||
j = {}
|
||||
|
||||
|
||||
@@ -47,7 +47,7 @@ def check(
|
||||
Args:
|
||||
room_version_obj: the version of the room
|
||||
event: the event being checked.
|
||||
auth_events (dict: event-key -> event): the existing room state.
|
||||
auth_events: the existing room state.
|
||||
|
||||
Raises:
|
||||
AuthError if the checks fail
|
||||
|
||||
@@ -28,7 +28,6 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
from canonicaljson import json
|
||||
from prometheus_client import Counter, Histogram
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -63,7 +62,7 @@ from synapse.replication.http.federation import (
|
||||
ReplicationGetQueryRestServlet,
|
||||
)
|
||||
from synapse.types import JsonDict, get_domain_from_id
|
||||
from synapse.util import glob_to_regex, unwrapFirstError
|
||||
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer, concurrently_execute
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
|
||||
@@ -551,7 +550,7 @@ class FederationServer(FederationBase):
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, json_str in keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = {
|
||||
key_id: json.loads(json_str)
|
||||
key_id: json_decoder.decode(json_str)
|
||||
}
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.events import EventBase
|
||||
from synapse.federation.persistence import TransactionActions
|
||||
@@ -28,6 +26,7 @@ from synapse.logging.opentracing import (
|
||||
tags,
|
||||
whitelisted_homeserver,
|
||||
)
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.metrics import measure_func
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -71,7 +70,7 @@ class TransactionManager(object):
|
||||
for edu in pending_edus:
|
||||
context = edu.get_context()
|
||||
if context:
|
||||
span_contexts.append(extract_text_map(json.loads(context)))
|
||||
span_contexts.append(extract_text_map(json_decoder.decode(context)))
|
||||
if keep_destination:
|
||||
edu.strip_context()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ import logging
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
from signedjson.key import VerifyKey, decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
@@ -35,7 +35,7 @@ from synapse.types import (
|
||||
get_domain_from_id,
|
||||
get_verify_key_from_cross_signing_key,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import json_decoder, unwrapFirstError
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
|
||||
for device_id, keys in device_keys.items():
|
||||
for key_id, json_bytes in keys.items():
|
||||
json_result.setdefault(user_id, {})[device_id] = {
|
||||
key_id: json.loads(json_bytes)
|
||||
key_id: json_decoder.decode(json_bytes)
|
||||
}
|
||||
|
||||
@trace
|
||||
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
|
||||
|
||||
|
||||
def _one_time_keys_match(old_key_json, new_key):
|
||||
old_key = json.loads(old_key_json)
|
||||
old_key = json_decoder.decode(old_key_json)
|
||||
|
||||
# if either is a string rather than an object, they must match exactly
|
||||
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
|
||||
|
||||
@@ -1787,9 +1787,7 @@ class FederationHandler(BaseHandler):
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
|
||||
|
||||
@@ -1815,9 +1813,7 @@ class FederationHandler(BaseHandler):
|
||||
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
|
||||
"""Returns the state at the event. i.e. not including said event.
|
||||
"""
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
|
||||
|
||||
@@ -2165,9 +2161,9 @@ class FederationHandler(BaseHandler):
|
||||
auth_types = auth_types_for_event(event)
|
||||
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
|
||||
|
||||
current_auth_events = await self.store.get_events(current_state_ids)
|
||||
auth_events_map = await self.store.get_events(current_state_ids)
|
||||
current_auth_events = {
|
||||
(e.type, e.state_key): e for e in current_auth_events.values()
|
||||
(e.type, e.state_key): e for e in auth_events_map.values()
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -2183,9 +2179,7 @@ class FederationHandler(BaseHandler):
|
||||
if not in_room:
|
||||
raise AuthError(403, "Host not in room.")
|
||||
|
||||
event = await self.store.get_event(
|
||||
event_id, allow_none=False, check_room_id=room_id
|
||||
)
|
||||
event = await self.store.get_event(event_id, check_room_id=room_id)
|
||||
|
||||
# Just go through and process each event in `remote_auth_chain`. We
|
||||
# don't want to fall into the trap of `missing` being wrong.
|
||||
|
||||
@@ -21,9 +21,6 @@ import logging
|
||||
import urllib.parse
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.error import TimeoutError
|
||||
|
||||
from synapse.api.errors import (
|
||||
@@ -37,6 +34,7 @@ from synapse.api.errors import (
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.types import JsonDict, Requester
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.hash import sha256_and_url_safe_base64
|
||||
from synapse.util.stringutils import assert_valid_client_secret, random_string
|
||||
|
||||
@@ -197,7 +195,7 @@ class IdentityHandler(BaseHandler):
|
||||
except TimeoutError:
|
||||
raise SynapseError(500, "Timed out contacting identity server")
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg) # XXX WAT?
|
||||
data = json_decoder.decode(e.msg) # XXX WAT?
|
||||
return data
|
||||
|
||||
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
|
||||
@@ -620,18 +618,19 @@ class IdentityHandler(BaseHandler):
|
||||
# the CS API. They should be consolidated with those in RoomMemberHandler
|
||||
# https://github.com/matrix-org/synapse-dinsic/issues/25
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def proxy_lookup_3pid(self, id_server, medium, address):
|
||||
async def proxy_lookup_3pid(
|
||||
self, id_server: str, medium: str, address: str
|
||||
) -> JsonDict:
|
||||
"""Looks up a 3pid in the passed identity server.
|
||||
|
||||
Args:
|
||||
id_server (str): The server name (including port, if required)
|
||||
id_server: The server name (including port, if required)
|
||||
of the identity server to use.
|
||||
medium (str): The type of the third party identifier (e.g. "email").
|
||||
address (str): The third party identifier (e.g. "foo@example.com").
|
||||
medium: The type of the third party identifier (e.g. "email").
|
||||
address: The third party identifier (e.g. "foo@example.com").
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The result of the lookup. See
|
||||
The result of the lookup. See
|
||||
https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
|
||||
for details
|
||||
"""
|
||||
@@ -643,16 +642,11 @@ class IdentityHandler(BaseHandler):
|
||||
id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
|
||||
|
||||
try:
|
||||
data = yield self.http_client.get_json(
|
||||
data = await self.http_client.get_json(
|
||||
"%s/_matrix/identity/api/v1/lookup" % (id_server_url,),
|
||||
{"medium": medium, "address": address},
|
||||
)
|
||||
|
||||
if "mxid" in data:
|
||||
if "signatures" not in data:
|
||||
raise AuthError(401, "No signatures on 3pid binding")
|
||||
yield self._verify_any_signature(data, id_server)
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.info("Proxied lookup failed: %r", e)
|
||||
raise e.to_synapse_error()
|
||||
@@ -662,18 +656,19 @@ class IdentityHandler(BaseHandler):
|
||||
|
||||
return data
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def proxy_bulk_lookup_3pid(self, id_server, threepids):
|
||||
async def proxy_bulk_lookup_3pid(
|
||||
self, id_server: str, threepids: List[List[str]]
|
||||
) -> JsonDict:
|
||||
"""Looks up given 3pids in the passed identity server.
|
||||
|
||||
Args:
|
||||
id_server (str): The server name (including port, if required)
|
||||
id_server: The server name (including port, if required)
|
||||
of the identity server to use.
|
||||
threepids ([[str, str]]): The third party identifiers to lookup, as
|
||||
threepids: The third party identifiers to lookup, as
|
||||
a list of 2-string sized lists ([medium, address]).
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The result of the lookup. See
|
||||
The result of the lookup. See
|
||||
https://matrix.org/docs/spec/identity_service/r0.1.0.html#association-lookup
|
||||
for details
|
||||
"""
|
||||
@@ -685,7 +680,7 @@ class IdentityHandler(BaseHandler):
|
||||
id_server_url = self.rewrite_id_server_url(id_server, add_https=True)
|
||||
|
||||
try:
|
||||
data = yield self.http_client.post_json_get_json(
|
||||
data = await self.http_client.post_json_get_json(
|
||||
"%s/_matrix/identity/api/v1/bulk_lookup" % (id_server_url,),
|
||||
{"threepids": threepids},
|
||||
)
|
||||
@@ -697,7 +692,7 @@ class IdentityHandler(BaseHandler):
|
||||
logger.info("Failed to contact %s: %s", id_server, e)
|
||||
raise ProxiedRequestError(503, "Failed to contact identity server")
|
||||
|
||||
defer.returnValue(data)
|
||||
return data
|
||||
|
||||
async def lookup_3pid(
|
||||
self,
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
|
||||
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from twisted.internet.interfaces import IDelayedCall
|
||||
|
||||
@@ -55,6 +55,7 @@ from synapse.types import (
|
||||
UserID,
|
||||
create_requester,
|
||||
)
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
from synapse.util.frozenutils import frozendict_json_encoder
|
||||
from synapse.util.metrics import measure_func
|
||||
@@ -867,7 +868,7 @@ class EventCreationHandler(object):
|
||||
# Ensure that we can round trip before trying to persist in db
|
||||
try:
|
||||
dump = frozendict_json_encoder.encode(event.content)
|
||||
json.loads(dump)
|
||||
json_decoder.decode(dump)
|
||||
except Exception:
|
||||
logger.exception("Failed to encode content: %r", event.content)
|
||||
raise
|
||||
@@ -963,7 +964,7 @@ class EventCreationHandler(object):
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
is_admin_redaction = (
|
||||
is_admin_redaction = bool(
|
||||
original_event and event.sender != original_event.sender
|
||||
)
|
||||
|
||||
@@ -1083,8 +1084,8 @@ class EventCreationHandler(object):
|
||||
auth_events_ids = self.auth.compute_auth_events(
|
||||
event, prev_state_ids, for_verification=True
|
||||
)
|
||||
auth_events = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
|
||||
auth_events_map = await self.store.get_events(auth_events_ids)
|
||||
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
|
||||
|
||||
room_version = await self.store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
@@ -12,7 +12,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
|
||||
from urllib.parse import urlencode
|
||||
@@ -39,6 +38,7 @@ from synapse.http.server import respond_with_html
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.types import UserID, map_username_to_mxid_localpart
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
@@ -367,7 +367,7 @@ class OidcHandler:
|
||||
# and check for an error field. If not, we respond with a generic
|
||||
# error message.
|
||||
try:
|
||||
resp = json.loads(resp_body.decode("utf-8"))
|
||||
resp = json_decoder.decode(resp_body.decode("utf-8"))
|
||||
error = resp["error"]
|
||||
description = resp.get("error_description", error)
|
||||
except (ValueError, KeyError):
|
||||
@@ -384,7 +384,7 @@ class OidcHandler:
|
||||
|
||||
# Since it is a not a 5xx code, body should be a valid JSON. It will
|
||||
# raise if not.
|
||||
resp = json.loads(resp_body.decode("utf-8"))
|
||||
resp = json_decoder.decode(resp_body.decode("utf-8"))
|
||||
|
||||
if "error" in resp:
|
||||
error = resp["error"]
|
||||
|
||||
@@ -133,8 +133,12 @@ class BaseProfileHandler(BaseHandler):
|
||||
body = {"batchnum": batchnum, "batch": batch, "origin_server": self.hs.hostname}
|
||||
signed_body = sign_json(body, self.hs.hostname, self.hs.config.signing_key[0])
|
||||
try:
|
||||
yield self.http_client.post_json_get_json(url, signed_body)
|
||||
yield self.store.update_replication_batch_for_host(host, batchnum)
|
||||
yield defer.ensureDeferred(
|
||||
self.http_client.post_json_get_json(url, signed_body)
|
||||
)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.update_replication_batch_for_host(host, batchnum)
|
||||
)
|
||||
logger.info("Sucessfully replicated profile batch %d to %s", batchnum, host)
|
||||
except Exception:
|
||||
# This will get retried when the looping call next comes around
|
||||
|
||||
@@ -747,7 +747,7 @@ class RoomMemberHandler(object):
|
||||
|
||||
guest_access = await self.store.get_event(guest_access_id)
|
||||
|
||||
return (
|
||||
return bool(
|
||||
guest_access
|
||||
and guest_access.content
|
||||
and "guest_access" in guest_access.content
|
||||
|
||||
@@ -16,13 +16,12 @@
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import Codes, LoginError, SynapseError
|
||||
from synapse.config.emailconfig import ThreepidBehaviour
|
||||
from synapse.util import json_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -117,7 +116,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted is silly
|
||||
data = pde.response
|
||||
resp_body = json.loads(data.decode("utf-8"))
|
||||
resp_body = json_decoder.decode(data.decode("utf-8"))
|
||||
|
||||
if "success" in resp_body:
|
||||
# Note that we do NOT check the hostname here: we explicitly
|
||||
|
||||
@@ -19,7 +19,7 @@ import urllib
|
||||
from io import BytesIO
|
||||
|
||||
import treq
|
||||
from canonicaljson import encode_canonical_json, json
|
||||
from canonicaljson import encode_canonical_json
|
||||
from netaddr import IPAddress
|
||||
from prometheus_client import Counter
|
||||
from zope.interface import implementer, provider
|
||||
@@ -47,6 +47,7 @@ from synapse.http import (
|
||||
from synapse.http.proxyagent import ProxyAgent
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.logging.opentracing import set_tag, start_active_span, tags
|
||||
from synapse.util import json_decoder
|
||||
from synapse.util.async_helpers import timeout_deferred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return json_decoder.decode(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return json_decoder.decode(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
|
||||
actual_headers.update(headers)
|
||||
|
||||
body = await self.get_raw(uri, args, headers=headers)
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return json_decoder.decode(body.decode("utf-8"))
|
||||
|
||||
async def put_json(self, uri, json_body, args={}, headers=None):
|
||||
""" Puts some json to the given URI.
|
||||
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
|
||||
body = await make_deferred_yieldable(readBody(response))
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
return json.loads(body.decode("utf-8"))
|
||||
return json_decoder.decode(body.decode("utf-8"))
|
||||
else:
|
||||
raise HttpResponseException(
|
||||
response.code, response.phrase.decode("ascii", errors="replace"), body
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
@@ -26,7 +25,7 @@ from twisted.web.http import stringToDatetime
|
||||
from twisted.web.http_headers import Headers
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import Clock
|
||||
from synapse.util import Clock, json_decoder
|
||||
from synapse.util.caches.ttlcache import TTLCache
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -181,7 +180,7 @@ class WellKnownResolver(object):
|
||||
if response.code != 200:
|
||||
raise Exception("Non-200 response %s" % (response.code,))
|
||||
|
||||
parsed_body = json.loads(body.decode("utf-8"))
|
||||
parsed_body = json_decoder.decode(body.decode("utf-8"))
|
||||
logger.info("Response from .well-known: %s", parsed_body)
|
||||
|
||||
result = parsed_body["m.server"].encode("ascii")
|
||||
|
||||
@@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
|
||||
pass
|
||||
|
||||
|
||||
@implementer(interfaces.IPullProducer)
|
||||
@implementer(interfaces.IPushProducer)
|
||||
class _ByteProducer:
|
||||
"""
|
||||
Iteratively write bytes to the request.
|
||||
@@ -515,52 +515,64 @@ class _ByteProducer:
|
||||
):
|
||||
self._request = request
|
||||
self._iterator = iterator
|
||||
self._paused = False
|
||||
|
||||
def start(self) -> None:
|
||||
self._request.registerProducer(self, False)
|
||||
# Register the producer and start producing data.
|
||||
self._request.registerProducer(self, True)
|
||||
self.resumeProducing()
|
||||
|
||||
def _send_data(self, data: List[bytes]) -> None:
|
||||
"""
|
||||
Send a list of strings as a response to the request.
|
||||
Send a list of bytes as a chunk of a response.
|
||||
"""
|
||||
if not data:
|
||||
return
|
||||
self._request.write(b"".join(data))
|
||||
|
||||
def pauseProducing(self) -> None:
|
||||
self._paused = True
|
||||
|
||||
def resumeProducing(self) -> None:
|
||||
# We've stopped producing in the meantime (note that this might be
|
||||
# re-entrant after calling write).
|
||||
if not self._request:
|
||||
return
|
||||
|
||||
# Get the next chunk and write it to the request.
|
||||
#
|
||||
# The output of the JSON encoder is coalesced until min_chunk_size is
|
||||
# reached. (This is because JSON encoders produce a very small output
|
||||
# per iteration.)
|
||||
#
|
||||
# Note that buffer stores a list of bytes (instead of appending to
|
||||
# bytes) to hopefully avoid many allocations.
|
||||
buffer = []
|
||||
buffered_bytes = 0
|
||||
while buffered_bytes < self.min_chunk_size:
|
||||
try:
|
||||
data = next(self._iterator)
|
||||
buffer.append(data)
|
||||
buffered_bytes += len(data)
|
||||
except StopIteration:
|
||||
# The entire JSON object has been serialized, write any
|
||||
# remaining data, finalize the producer and the request, and
|
||||
# clean-up any references.
|
||||
self._send_data(buffer)
|
||||
self._request.unregisterProducer()
|
||||
self._request.finish()
|
||||
self.stopProducing()
|
||||
return
|
||||
self._paused = False
|
||||
|
||||
self._send_data(buffer)
|
||||
# Write until there's backpressure telling us to stop.
|
||||
while not self._paused:
|
||||
# Get the next chunk and write it to the request.
|
||||
#
|
||||
# The output of the JSON encoder is buffered and coalesced until
|
||||
# min_chunk_size is reached. This is because JSON encoders produce
|
||||
# very small output per iteration and the Request object converts
|
||||
# each call to write() to a separate chunk. Without this there would
|
||||
# be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
|
||||
#
|
||||
# Note that buffer stores a list of bytes (instead of appending to
|
||||
# bytes) to hopefully avoid many allocations.
|
||||
buffer = []
|
||||
buffered_bytes = 0
|
||||
while buffered_bytes < self.min_chunk_size:
|
||||
try:
|
||||
data = next(self._iterator)
|
||||
buffer.append(data)
|
||||
buffered_bytes += len(data)
|
||||
except StopIteration:
|
||||
# The entire JSON object has been serialized, write any
|
||||
# remaining data, finalize the producer and the request, and
|
||||
# clean-up any references.
|
||||
self._send_data(buffer)
|
||||
self._request.unregisterProducer()
|
||||
self._request.finish()
|
||||
self.stopProducing()
|
||||
return
|
||||
|
||||
self._send_data(buffer)
|
||||
|
||||
def stopProducing(self) -> None:
|
||||
# Clear a circular reference.
|
||||
self._request = None
|
||||
|
||||
|
||||
@@ -620,8 +632,7 @@ def respond_with_json(
|
||||
if send_cors:
|
||||
set_cors_headers(request)
|
||||
|
||||
producer = _ByteProducer(request, encoder(json_object))
|
||||
producer.start()
|
||||
_ByteProducer(request, encoder(json_object))
|
||||
return NOT_DONE_YET
|
||||
|
||||
|
||||
|
||||
@@ -17,9 +17,8 @@
|
||||
|
||||
import logging
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.util import json_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
content = json.loads(content_bytes.decode("utf-8"))
|
||||
content = json_decoder.decode(content_bytes.decode("utf-8"))
|
||||
except Exception as e:
|
||||
logger.warning("Unable to parse JSON: %s", e)
|
||||
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
|
||||
|
||||
@@ -177,6 +177,7 @@ from canonicaljson import json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.config import ConfigError
|
||||
from synapse.util import json_decoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.http.site import SynapseRequest
|
||||
@@ -499,7 +500,9 @@ def start_active_span_from_edu(
|
||||
if opentracing is None:
|
||||
return _noop_context_manager()
|
||||
|
||||
carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
|
||||
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
|
||||
"opentracing", {}
|
||||
)
|
||||
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
|
||||
_references = [
|
||||
opentracing.child_of(span_context_from_string(x))
|
||||
@@ -699,7 +702,7 @@ def span_context_from_string(carrier):
|
||||
Returns:
|
||||
The active span context decoded from a string.
|
||||
"""
|
||||
carrier = json.loads(carrier)
|
||||
carrier = json_decoder.decode(carrier)
|
||||
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
|
||||
|
||||
|
||||
|
||||
@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
|
||||
It returns a Deferred which completes when the function completes, but it doesn't
|
||||
follow the synapse logcontext rules, which makes it appropriate for passing to
|
||||
clock.looping_call and friends (or for firing-and-forgetting in the middle of a
|
||||
normal synapse inlineCallbacks function).
|
||||
normal synapse async function).
|
||||
|
||||
Args:
|
||||
desc: a description for this background process type
|
||||
|
||||
@@ -33,3 +33,11 @@ class SlavedIdTracker(object):
|
||||
int
|
||||
"""
|
||||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
|
||||
For streams with single writers this is equivalent to
|
||||
`get_current_token`.
|
||||
"""
|
||||
return self.get_current_token()
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.replication.tcp.streams import PushRulesStream
|
||||
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
|
||||
|
||||
@@ -21,16 +22,13 @@ from .events import SlavedEventStore
|
||||
|
||||
|
||||
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
|
||||
def get_push_rules_stream_token(self):
|
||||
return (
|
||||
self._push_rules_stream_id_gen.get_current_token(),
|
||||
self._stream_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
def get_max_push_rules_stream_id(self):
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def process_replication_rows(self, stream_name, instance_name, token, rows):
|
||||
# We assert this for the benefit of mypy
|
||||
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
|
||||
|
||||
if stream_name == PushRulesStream.NAME:
|
||||
self._push_rules_stream_id_gen.advance(token)
|
||||
for row in rows:
|
||||
|
||||
@@ -21,9 +21,7 @@ import abc
|
||||
import logging
|
||||
from typing import Tuple, Type
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.util import json_encoder as _json_encoder
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -125,7 +123,7 @@ class RdataCommand(Command):
|
||||
stream_name,
|
||||
instance_name,
|
||||
None if token == "batch" else int(token),
|
||||
json.loads(row_json),
|
||||
json_decoder.decode(row_json),
|
||||
)
|
||||
|
||||
def to_line(self):
|
||||
@@ -134,7 +132,7 @@ class RdataCommand(Command):
|
||||
self.stream_name,
|
||||
self.instance_name,
|
||||
str(self.token) if self.token is not None else "batch",
|
||||
_json_encoder.encode(self.row),
|
||||
json_encoder.encode(self.row),
|
||||
)
|
||||
)
|
||||
|
||||
@@ -359,7 +357,7 @@ class UserIpCommand(Command):
|
||||
def from_line(cls, line):
|
||||
user_id, jsn = line.split(" ", 1)
|
||||
|
||||
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
|
||||
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
|
||||
|
||||
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
|
||||
|
||||
@@ -367,7 +365,7 @@ class UserIpCommand(Command):
|
||||
return (
|
||||
self.user_id
|
||||
+ " "
|
||||
+ _json_encoder.encode(
|
||||
+ json_encoder.encode(
|
||||
(
|
||||
self.access_token,
|
||||
self.ip,
|
||||
|
||||
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
|
||||
)
|
||||
|
||||
def _current_token(self, instance_name: str) -> int:
|
||||
push_rules_token, _ = self.store.get_push_rules_stream_token()
|
||||
push_rules_token = self.store.get_max_push_rules_stream_id()
|
||||
return push_rules_token
|
||||
|
||||
|
||||
@@ -405,7 +405,7 @@ class CachesStream(Stream):
|
||||
store = hs.get_datastore()
|
||||
super().__init__(
|
||||
hs.get_instance_name(),
|
||||
store.get_cache_stream_token,
|
||||
store.get_cache_stream_token_for_writer,
|
||||
store.get_all_updated_caches,
|
||||
)
|
||||
|
||||
|
||||
@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
|
||||
return 200, {}
|
||||
|
||||
def notify_user(self, user_id):
|
||||
stream_id, _ = self.store.get_push_rules_stream_token()
|
||||
stream_id = self.store.get_max_push_rules_stream_id()
|
||||
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
|
||||
|
||||
async def set_rule_attr(self, user_id, spec, val):
|
||||
|
||||
@@ -21,8 +21,6 @@ import re
|
||||
from typing import List, Optional
|
||||
from urllib import parse as urlparse
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
@@ -46,6 +44,7 @@ from synapse.rest.client.v2_alpha._base import client_patterns
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.streams.config import PaginationConfig
|
||||
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
|
||||
from synapse.util import json_decoder
|
||||
|
||||
MYPY = False
|
||||
if MYPY:
|
||||
@@ -519,7 +518,9 @@ class RoomMessageListRestServlet(RestServlet):
|
||||
filter_str = parse_string(request, b"filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
if (
|
||||
event_filter
|
||||
and event_filter.filter_json.get("event_format", "client")
|
||||
@@ -631,7 +632,9 @@ class RoomEventContextServlet(RestServlet):
|
||||
filter_str = parse_string(request, b"filter", encoding="utf-8")
|
||||
if filter_str:
|
||||
filter_json = urlparse.unquote(filter_str)
|
||||
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
|
||||
event_filter = Filter(
|
||||
json_decoder.decode(filter_json)
|
||||
) # type: Optional[Filter]
|
||||
else:
|
||||
event_filter = None
|
||||
|
||||
|
||||
@@ -1005,12 +1005,11 @@ class ThreepidLookupRestServlet(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
async def on_GET(self, request):
|
||||
"""Proxy a /_matrix/identity/api/v1/lookup request to an identity
|
||||
server
|
||||
"""
|
||||
yield self.auth.get_user_by_req(request)
|
||||
await self.auth.get_user_by_req(request)
|
||||
|
||||
# Verify query parameters
|
||||
query_params = request.args
|
||||
@@ -1023,9 +1022,9 @@ class ThreepidLookupRestServlet(RestServlet):
|
||||
|
||||
# Proxy the request to the identity server. lookup_3pid handles checking
|
||||
# if the lookup is allowed so we don't need to do it here.
|
||||
ret = yield self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
|
||||
ret = await self.identity_handler.proxy_lookup_3pid(id_server, medium, address)
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
return 200, ret
|
||||
|
||||
|
||||
class ThreepidBulkLookupRestServlet(RestServlet):
|
||||
@@ -1036,12 +1035,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
|
||||
self.auth = hs.get_auth()
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
async def on_POST(self, request):
|
||||
"""Proxy a /_matrix/identity/api/v1/bulk_lookup request to an identity
|
||||
server
|
||||
"""
|
||||
yield self.auth.get_user_by_req(request)
|
||||
await self.auth.get_user_by_req(request)
|
||||
|
||||
body = parse_json_object_from_request(request)
|
||||
|
||||
@@ -1049,11 +1047,11 @@ class ThreepidBulkLookupRestServlet(RestServlet):
|
||||
|
||||
# Proxy the request to the identity server. lookup_3pid handles checking
|
||||
# if the lookup is allowed so we don't need to do it here.
|
||||
ret = yield self.identity_handler.proxy_bulk_lookup_3pid(
|
||||
ret = await self.identity_handler.proxy_bulk_lookup_3pid(
|
||||
body["id_server"], body["threepids"]
|
||||
)
|
||||
|
||||
defer.returnValue((200, ret))
|
||||
return 200, ret
|
||||
|
||||
|
||||
def assert_valid_next_link(hs: "HomeServer", next_link: str):
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
import itertools
|
||||
import logging
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
|
||||
@@ -29,6 +27,7 @@ from synapse.handlers.presence import format_user_presence_state
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
|
||||
from synapse.types import StreamToken
|
||||
from synapse.util import json_decoder
|
||||
|
||||
from ._base import client_patterns, set_timeline_upper_limit
|
||||
|
||||
@@ -125,7 +124,7 @@ class SyncRestServlet(RestServlet):
|
||||
filter_collection = DEFAULT_FILTER_COLLECTION
|
||||
elif filter_id.startswith("{"):
|
||||
try:
|
||||
filter_object = json.loads(filter_id)
|
||||
filter_object = json_decoder.decode(filter_id)
|
||||
set_timeline_upper_limit(
|
||||
filter_object, self.hs.config.filter_timeline_limit
|
||||
)
|
||||
|
||||
@@ -15,19 +15,19 @@
|
||||
import logging
|
||||
from typing import Dict, Set
|
||||
|
||||
from canonicaljson import json
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.crypto.keyring import ServerKeyFetcher
|
||||
from synapse.http.server import DirectServeJsonResource, respond_with_json
|
||||
from synapse.http.servlet import parse_integer, parse_json_object_from_request
|
||||
from synapse.util import json_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteKey(DirectServeJsonResource):
|
||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||
"""HTTP resource for retrieving the TLS certificate and NACL signature
|
||||
verification keys for a collection of servers. Checks that the reported
|
||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||
that the NACL signature for the remote server is valid. Returns a dict of
|
||||
@@ -209,13 +209,15 @@ class RemoteKey(DirectServeJsonResource):
|
||||
# Cast to bytes since postgresql returns a memoryview.
|
||||
json_results.add(bytes(result["key_json"]))
|
||||
|
||||
# If there is a cache miss, request the missing keys, then recurse (and
|
||||
# ensure the result is sent).
|
||||
if cache_misses and query_remote_on_cache_miss:
|
||||
await self.fetcher.get_keys(cache_misses)
|
||||
await self.query_keys(request, query, query_remote_on_cache_miss=False)
|
||||
else:
|
||||
signed_keys = []
|
||||
for key_json in json_results:
|
||||
key_json = json.loads(key_json.decode("utf-8"))
|
||||
key_json = json_decoder.decode(key_json.decode("utf-8"))
|
||||
for signing_key in self.config.key_server_signing_keys:
|
||||
key_json = sign_json(key_json, self.config.server_name, signing_key)
|
||||
|
||||
|
||||
@@ -51,5 +51,5 @@ class SpamCheckerApi(object):
|
||||
state_ids = yield self._store.get_filtered_current_state_ids(
|
||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||
)
|
||||
state = yield self._store.get_events(state_ids.values())
|
||||
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
|
||||
return state.values()
|
||||
|
||||
@@ -641,7 +641,7 @@ class StateResolutionStore(object):
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
|
||||
"""
|
||||
|
||||
return self.store.get_events(
|
||||
|
||||
@@ -19,12 +19,11 @@ import random
|
||||
from abc import ABCMeta
|
||||
from typing import Any, Optional
|
||||
|
||||
from canonicaljson import json
|
||||
|
||||
from synapse.storage.database import LoggingTransaction # noqa: F401
|
||||
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util import json_decoder
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -99,13 +98,13 @@ def db_to_json(db_content):
|
||||
if isinstance(db_content, memoryview):
|
||||
db_content = db_content.tobytes()
|
||||
|
||||
# Decode it to a Unicode string before feeding it to json.loads, since
|
||||
# Decode it to a Unicode string before feeding it to the JSON decoder, since
|
||||
# Python 3.5 does not support deserializing bytes.
|
||||
if isinstance(db_content, (bytes, bytearray)):
|
||||
db_content = db_content.decode("utf8")
|
||||
|
||||
try:
|
||||
return json.loads(db_content)
|
||||
return json_decoder.decode(db_content)
|
||||
except Exception:
|
||||
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
|
||||
raise
|
||||
|
||||
@@ -516,14 +516,16 @@ class DatabasePool(object):
|
||||
logger.warning("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
try:
|
||||
result = yield self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
result = yield defer.ensureDeferred(
|
||||
self.runWithConnection(
|
||||
self.new_transaction,
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
@@ -535,8 +537,7 @@ class DatabasePool(object):
|
||||
|
||||
return result
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
|
||||
async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
@@ -547,7 +548,7 @@ class DatabasePool(object):
|
||||
kwargs: named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
The result of func
|
||||
"""
|
||||
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
|
||||
if not parent_context:
|
||||
@@ -570,12 +571,10 @@ class DatabasePool(object):
|
||||
|
||||
return func(conn, *args, **kwargs)
|
||||
|
||||
result = yield make_deferred_yieldable(
|
||||
return await make_deferred_yieldable(
|
||||
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def cursor_to_dict(cursor):
|
||||
"""Converts a SQL cursor into an list of dicts.
|
||||
|
||||
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
|
||||
def get_cache_stream_token(self, instance_name):
|
||||
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
|
||||
if self._cache_id_gen:
|
||||
return self._cache_id_gen.get_current_token(instance_name)
|
||||
return self._cache_id_gen.get_current_token_for_writer(instance_name)
|
||||
else:
|
||||
return 0
|
||||
|
||||
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
|
||||
def get_auth_chain(self, event_ids, include_given=False):
|
||||
async def get_auth_chain(self, event_ids, include_given=False):
|
||||
"""Get auth events for given event_ids. The events *must* be state events.
|
||||
|
||||
Args:
|
||||
@@ -40,9 +40,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
Returns:
|
||||
list of events
|
||||
"""
|
||||
return self.get_auth_chain_ids(
|
||||
event_ids = await self.get_auth_chain_ids(
|
||||
event_ids, include_given=include_given
|
||||
).addCallback(self.get_events_as_list)
|
||||
)
|
||||
return await self.get_events_as_list(event_ids)
|
||||
|
||||
def get_auth_chain_ids(
|
||||
self,
|
||||
@@ -459,7 +460,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
|
||||
)
|
||||
|
||||
def get_backfill_events(self, room_id, event_list, limit):
|
||||
async def get_backfill_events(self, room_id, event_list, limit):
|
||||
"""Get a list of Events for a given topic that occurred before (and
|
||||
including) the events in event_list. Return a list of max size `limit`
|
||||
|
||||
@@ -469,17 +470,15 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
event_list (list)
|
||||
limit (int)
|
||||
"""
|
||||
return (
|
||||
self.db_pool.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events,
|
||||
room_id,
|
||||
event_list,
|
||||
limit,
|
||||
)
|
||||
.addCallback(self.get_events_as_list)
|
||||
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
|
||||
event_ids = await self.db_pool.runInteraction(
|
||||
"get_backfill_events",
|
||||
self._get_backfill_events,
|
||||
room_id,
|
||||
event_list,
|
||||
limit,
|
||||
)
|
||||
events = await self.get_events_as_list(event_ids)
|
||||
return sorted(events, key=lambda e: -e.depth)
|
||||
|
||||
def _get_backfill_events(self, txn, room_id, event_list, limit):
|
||||
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
|
||||
@@ -540,8 +539,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
|
||||
latest_events,
|
||||
limit,
|
||||
)
|
||||
events = await self.get_events_as_list(ids)
|
||||
return events
|
||||
return await self.get_events_as_list(ids)
|
||||
|
||||
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
|
||||
|
||||
|
||||
@@ -19,9 +19,10 @@ import itertools
|
||||
import logging
|
||||
import threading
|
||||
from collections import namedtuple
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Dict, Iterable, List, Optional, Tuple, overload
|
||||
|
||||
from constantly import NamedConstant, Names
|
||||
from typing_extensions import Literal
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -32,7 +33,7 @@ from synapse.api.room_versions import (
|
||||
EventFormatVersions,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import make_event_from_dict
|
||||
from synapse.events import EventBase, make_event_from_dict
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.logging.context import PreserveLoggingContext, current_context
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
@@ -42,8 +43,8 @@ from synapse.replication.tcp.streams.events import EventsStream
|
||||
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
|
||||
from synapse.types import Collection, get_domain_from_id
|
||||
from synapse.util.caches.descriptors import Cache, cached
|
||||
from synapse.util.iterutils import batch_iter
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
@@ -137,8 +138,33 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
desc="get_received_ts",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(
|
||||
# Inform mypy that if allow_none is False (the default) then get_event
|
||||
# always returns an EventBase.
|
||||
@overload
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[False] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
) -> EventBase:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
allow_none: Literal[True] = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_event(
|
||||
self,
|
||||
event_id: str,
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
@@ -146,7 +172,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
allow_rejected: bool = False,
|
||||
allow_none: bool = False,
|
||||
check_room_id: Optional[str] = None,
|
||||
):
|
||||
) -> Optional[EventBase]:
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
@@ -171,12 +197,12 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
If there is a mismatch, behave as per allow_none.
|
||||
|
||||
Returns:
|
||||
Deferred[EventBase|None]
|
||||
The event, or None if the event was not found.
|
||||
"""
|
||||
if not isinstance(event_id, str):
|
||||
raise TypeError("Invalid event event_id %r" % (event_id,))
|
||||
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
[event_id],
|
||||
redact_behaviour=redact_behaviour,
|
||||
get_prev_content=get_prev_content,
|
||||
@@ -194,14 +220,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return event
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events(
|
||||
async def get_events(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
event_ids: Iterable[str],
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
):
|
||||
) -> Dict[str, EventBase]:
|
||||
"""Get events from the database
|
||||
|
||||
Args:
|
||||
@@ -220,9 +245,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
omits rejeted events from the response.
|
||||
|
||||
Returns:
|
||||
Deferred : Dict from event_id to event.
|
||||
A mapping from event_id to event.
|
||||
"""
|
||||
events = yield self.get_events_as_list(
|
||||
events = await self.get_events_as_list(
|
||||
event_ids,
|
||||
redact_behaviour=redact_behaviour,
|
||||
get_prev_content=get_prev_content,
|
||||
@@ -231,14 +256,13 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return {e.event_id: e for e in events}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_events_as_list(
|
||||
async def get_events_as_list(
|
||||
self,
|
||||
event_ids: List[str],
|
||||
event_ids: Collection[str],
|
||||
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
|
||||
get_prev_content: bool = False,
|
||||
allow_rejected: bool = False,
|
||||
):
|
||||
) -> List[EventBase]:
|
||||
"""Get events from the database and return in a list in the same order
|
||||
as given by `event_ids` arg.
|
||||
|
||||
@@ -259,8 +283,8 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
omits rejected events from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[list[EventBase]]: List of events fetched from the database. The
|
||||
events are in the same order as `event_ids` arg.
|
||||
List of events fetched from the database. The events are in the same
|
||||
order as `event_ids` arg.
|
||||
|
||||
Note that the returned list may be smaller than the list of event
|
||||
IDs if not all events could be fetched.
|
||||
@@ -270,7 +294,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
return []
|
||||
|
||||
# there may be duplicates so we cast the list to a set
|
||||
event_entry_map = yield self._get_events_from_cache_or_db(
|
||||
event_entry_map = await self._get_events_from_cache_or_db(
|
||||
set(event_ids), allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
@@ -305,7 +329,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
continue
|
||||
|
||||
redacted_event_id = entry.event.redacts
|
||||
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
|
||||
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
|
||||
original_event_entry = event_map.get(redacted_event_id)
|
||||
if not original_event_entry:
|
||||
# we don't have the redacted event (or it was rejected).
|
||||
@@ -371,7 +395,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
if get_prev_content:
|
||||
if "replaces_state" in event.unsigned:
|
||||
prev = yield self.get_event(
|
||||
prev = await self.get_event(
|
||||
event.unsigned["replaces_state"],
|
||||
get_prev_content=False,
|
||||
allow_none=True,
|
||||
@@ -383,8 +407,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return events
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
||||
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
|
||||
"""Fetch a bunch of events from the cache or the database.
|
||||
|
||||
If events are pulled from the database, they will be cached for future lookups.
|
||||
@@ -399,7 +422,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
rejected events are omitted from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, _EventCacheEntry]]:
|
||||
Dict[str, _EventCacheEntry]:
|
||||
map from event id to result
|
||||
"""
|
||||
event_entry_map = self._get_events_from_cache(
|
||||
@@ -417,7 +440,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# the events have been redacted, and if so pulling the redaction event out
|
||||
# of the database to check it.
|
||||
#
|
||||
missing_events = yield self._get_events_from_db(
|
||||
missing_events = await self._get_events_from_db(
|
||||
missing_events_ids, allow_rejected=allow_rejected
|
||||
)
|
||||
|
||||
@@ -525,8 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
with PreserveLoggingContext():
|
||||
self.hs.get_reactor().callFromThread(fire, event_list, e)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_events_from_db(self, event_ids, allow_rejected=False):
|
||||
async def _get_events_from_db(self, event_ids, allow_rejected=False):
|
||||
"""Fetch a bunch of events from the database.
|
||||
|
||||
Returned events will be added to the cache for future lookups.
|
||||
@@ -540,7 +562,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
rejected events are omitted from the response.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, _EventCacheEntry]]:
|
||||
Dict[str, _EventCacheEntry]:
|
||||
map from event id to result. May return extra events which
|
||||
weren't asked for.
|
||||
"""
|
||||
@@ -548,7 +570,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
events_to_fetch = event_ids
|
||||
|
||||
while events_to_fetch:
|
||||
row_map = yield self._enqueue_events(events_to_fetch)
|
||||
row_map = await self._enqueue_events(events_to_fetch)
|
||||
|
||||
# we need to recursively fetch any redactions of those events
|
||||
redaction_ids = set()
|
||||
@@ -574,8 +596,20 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
if not allow_rejected and rejected_reason:
|
||||
continue
|
||||
|
||||
d = db_to_json(row["json"])
|
||||
internal_metadata = db_to_json(row["internal_metadata"])
|
||||
# If the event or metadata cannot be parsed, log the error and act
|
||||
# as if the event is unknown.
|
||||
try:
|
||||
d = db_to_json(row["json"])
|
||||
except ValueError:
|
||||
logger.error("Unable to parse json from event: %s", event_id)
|
||||
continue
|
||||
try:
|
||||
internal_metadata = db_to_json(row["internal_metadata"])
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Unable to parse internal_metadata from event: %s", event_id
|
||||
)
|
||||
continue
|
||||
|
||||
format_version = row["format_version"]
|
||||
if format_version is None:
|
||||
@@ -650,8 +684,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
return result_map
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _enqueue_events(self, events):
|
||||
async def _enqueue_events(self, events):
|
||||
"""Fetches events from the database using the _event_fetch_list. This
|
||||
allows batch and bulk fetching of events - it allows us to fetch events
|
||||
without having to create a new transaction for each request for events.
|
||||
@@ -660,7 +693,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
events (Iterable[str]): events to be fetched.
|
||||
|
||||
Returns:
|
||||
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
|
||||
Dict[str, Dict]: map from event id to row data from the database.
|
||||
May contain events that weren't requested.
|
||||
"""
|
||||
|
||||
@@ -683,7 +716,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
|
||||
logger.debug("Loading %d events: %s", len(events), events)
|
||||
with PreserveLoggingContext():
|
||||
row_map = yield events_d
|
||||
row_map = await events_d
|
||||
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
|
||||
|
||||
return row_map
|
||||
@@ -842,33 +875,29 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# no valid redaction found for this event
|
||||
return None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_events_in_timeline(self, event_ids):
|
||||
async def have_events_in_timeline(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed and
|
||||
stored them as non outliers.
|
||||
"""
|
||||
rows = yield defer.ensureDeferred(
|
||||
self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
rows = await self.db_pool.simple_select_many_batch(
|
||||
table="events",
|
||||
retcols=("event_id",),
|
||||
column="event_id",
|
||||
iterable=list(event_ids),
|
||||
keyvalues={"outlier": False},
|
||||
desc="have_events_in_timeline",
|
||||
)
|
||||
|
||||
return {r["event_id"] for r in rows}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def have_seen_events(self, event_ids):
|
||||
async def have_seen_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Args:
|
||||
event_ids (iterable[str]):
|
||||
|
||||
Returns:
|
||||
Deferred[set[str]]: The events we have already seen.
|
||||
set[str]: The events we have already seen.
|
||||
"""
|
||||
results = set()
|
||||
|
||||
@@ -884,7 +913,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
# break the input up into chunks of 100
|
||||
input_iterator = iter(event_ids)
|
||||
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
|
||||
yield self.db_pool.runInteraction(
|
||||
await self.db_pool.runInteraction(
|
||||
"have_seen_events", have_seen_events_txn, chunk
|
||||
)
|
||||
return results
|
||||
@@ -914,8 +943,7 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_id,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_complexity(self, room_id):
|
||||
async def get_room_complexity(self, room_id):
|
||||
"""
|
||||
Get a rough approximation of the complexity of the room. This is used by
|
||||
remote servers to decide whether they wish to join the room or not.
|
||||
@@ -926,9 +954,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
room_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred[dict[str:int]] of complexity version to complexity.
|
||||
dict[str:int] of complexity version to complexity.
|
||||
"""
|
||||
state_events = yield self.get_current_state_event_counts(room_id)
|
||||
state_events = await self.get_current_state_event_counts(room_id)
|
||||
|
||||
# Call this one "v1", so we can introduce new ones as we want to develop
|
||||
# it.
|
||||
@@ -1165,9 +1193,9 @@ class EventsWorkerStore(SQLBaseStore):
|
||||
to_2, so_2 = await self.get_event_ordering(event_id2)
|
||||
return (to_1, so_1) > (to_2, so_2)
|
||||
|
||||
@cachedInlineCallbacks(max_entries=5000)
|
||||
def get_event_ordering(self, event_id):
|
||||
res = yield self.db_pool.simple_select_one(
|
||||
@cached(max_entries=5000)
|
||||
async def get_event_ordering(self, event_id):
|
||||
res = await self.db_pool.simple_select_one(
|
||||
table="events",
|
||||
retcols=["topological_ordering", "stream_ordering"],
|
||||
keyvalues={"event_id": event_id},
|
||||
|
||||
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
|
||||
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
|
||||
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
|
||||
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
|
||||
from synapse.storage.util.id_generators import ChainedIdGenerator
|
||||
from synapse.storage.util.id_generators import StreamIdGenerator
|
||||
from synapse.util import json_encoder
|
||||
from synapse.util.caches.descriptors import cached, cachedList
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
|
||||
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
|
||||
|
||||
if hs.config.worker.worker_app is None:
|
||||
self._push_rules_stream_id_gen = ChainedIdGenerator(
|
||||
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
|
||||
) # type: Union[ChainedIdGenerator, SlavedIdTracker]
|
||||
self._push_rules_stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "push_rules_stream", "stream_id"
|
||||
) # type: Union[StreamIdGenerator, SlavedIdTracker]
|
||||
else:
|
||||
self._push_rules_stream_id_gen = SlavedIdTracker(
|
||||
db_conn, "push_rules_stream", "stream_id"
|
||||
@@ -338,8 +338,9 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
) -> None:
|
||||
conditions_json = json_encoder.encode(conditions)
|
||||
actions_json = json_encoder.encode(actions)
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
if before or after:
|
||||
await self.db_pool.runInteraction(
|
||||
"_add_push_rule_relative_txn",
|
||||
@@ -559,8 +560,9 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"delete_push_rule",
|
||||
delete_push_rule_txn,
|
||||
@@ -569,8 +571,9 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
)
|
||||
|
||||
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"_set_push_rule_enabled_txn",
|
||||
self._set_push_rule_enabled_txn,
|
||||
@@ -643,8 +646,9 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
data={"actions": actions_json},
|
||||
)
|
||||
|
||||
with self._push_rules_stream_id_gen.get_next() as ids:
|
||||
stream_id, event_stream_ordering = ids
|
||||
with self._push_rules_stream_id_gen.get_next() as stream_id:
|
||||
event_stream_ordering = self._stream_id_gen.get_current_token()
|
||||
|
||||
await self.db_pool.runInteraction(
|
||||
"set_push_rule_actions",
|
||||
set_push_rule_actions_txn,
|
||||
@@ -673,11 +677,5 @@ class PushRuleStore(PushRulesWorkerStore):
|
||||
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
|
||||
)
|
||||
|
||||
def get_push_rules_stream_token(self):
|
||||
"""Get the position of the push rules stream.
|
||||
Returns a pair of a stream id for the push_rules stream and the
|
||||
room stream ordering it corresponds to."""
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
def get_max_push_rules_stream_id(self):
|
||||
return self.get_push_rules_stream_token()[0]
|
||||
return self._push_rules_stream_id_gen.get_current_token()
|
||||
|
||||
@@ -379,7 +379,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
limit: int = 0,
|
||||
order: str = "DESC",
|
||||
) -> Tuple[List[EventBase], str]:
|
||||
|
||||
"""Get new room events in stream ordering since `from_key`.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
import contextlib
|
||||
import threading
|
||||
from collections import deque
|
||||
from typing import Dict, Set, Tuple
|
||||
from typing import Dict, Set
|
||||
|
||||
from typing_extensions import Deque
|
||||
|
||||
@@ -158,63 +158,13 @@ class StreamIdGenerator(object):
|
||||
|
||||
return self._current
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
|
||||
class ChainedIdGenerator(object):
|
||||
"""Used to generate new stream ids where the stream must be kept in sync
|
||||
with another stream. It generates pairs of IDs, the first element is an
|
||||
integer ID for this stream, the second element is the ID for the stream
|
||||
that this stream needs to be kept in sync with."""
|
||||
|
||||
def __init__(self, chained_generator, db_conn, table, column):
|
||||
self.chained_generator = chained_generator
|
||||
self._table = table
|
||||
self._lock = threading.Lock()
|
||||
self._current_max = _load_current_id(db_conn, table, column)
|
||||
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
|
||||
|
||||
def get_next(self):
|
||||
For streams with single writers this is equivalent to
|
||||
`get_current_token`.
|
||||
"""
|
||||
Usage:
|
||||
with stream_id_gen.get_next() as (stream_id, chained_id):
|
||||
# ... persist event ...
|
||||
"""
|
||||
with self._lock:
|
||||
self._current_max += 1
|
||||
next_id = self._current_max
|
||||
chained_id = self.chained_generator.get_current_token()
|
||||
|
||||
self._unfinished_ids.append((next_id, chained_id))
|
||||
|
||||
@contextlib.contextmanager
|
||||
def manager():
|
||||
try:
|
||||
yield (next_id, chained_id)
|
||||
finally:
|
||||
with self._lock:
|
||||
self._unfinished_ids.remove((next_id, chained_id))
|
||||
|
||||
return manager()
|
||||
|
||||
def get_current_token(self):
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
"""
|
||||
with self._lock:
|
||||
if self._unfinished_ids:
|
||||
stream_id, chained_id = self._unfinished_ids[0]
|
||||
return stream_id - 1, chained_id
|
||||
|
||||
return self._current_max, self.chained_generator.get_current_token()
|
||||
|
||||
def advance(self, token: int):
|
||||
"""Stub implementation for advancing the token when receiving updates
|
||||
over replication; raises an exception as this instance should be the
|
||||
only source of updates.
|
||||
"""
|
||||
|
||||
raise Exception(
|
||||
"Attempted to advance token on source for table %r", self._table
|
||||
)
|
||||
return self.get_current_token()
|
||||
|
||||
|
||||
class MultiWriterIdGenerator:
|
||||
@@ -298,7 +248,7 @@ class MultiWriterIdGenerator:
|
||||
# Assert the fetched ID is actually greater than what we currently
|
||||
# believe the ID to be. If not, then the sequence and table have got
|
||||
# out of sync somehow.
|
||||
assert self.get_current_token() < next_id
|
||||
assert self.get_current_token_for_writer(self._instance_name) < next_id
|
||||
|
||||
with self._lock:
|
||||
self._unfinished_ids.add(next_id)
|
||||
@@ -344,16 +294,18 @@ class MultiWriterIdGenerator:
|
||||
curr = self._current_positions.get(self._instance_name, 0)
|
||||
self._current_positions[self._instance_name] = max(curr, next_id)
|
||||
|
||||
def get_current_token(self, instance_name: str = None) -> int:
|
||||
"""Gets the current position of a named writer (defaults to current
|
||||
instance).
|
||||
|
||||
Returns 0 if we don't have a position for the named writer (likely due
|
||||
to it being a new writer).
|
||||
def get_current_token(self) -> int:
|
||||
"""Returns the maximum stream id such that all stream ids less than or
|
||||
equal to it have been successfully persisted.
|
||||
"""
|
||||
|
||||
if instance_name is None:
|
||||
instance_name = self._instance_name
|
||||
# Currently we don't support this operation, as it's not obvious how to
|
||||
# condense the stream positions of multiple writers into a single int.
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_current_token_for_writer(self, instance_name: str) -> int:
|
||||
"""Returns the position of the given writer.
|
||||
"""
|
||||
|
||||
with self._lock:
|
||||
return self._current_positions.get(instance_name, 0)
|
||||
|
||||
@@ -39,7 +39,7 @@ class EventSources(object):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
def get_current_token(self) -> StreamToken:
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
push_rules_key = self.store.get_max_push_rules_stream_id()
|
||||
to_device_key = self.store.get_to_device_stream_token()
|
||||
device_list_key = self.store.get_device_stream_token()
|
||||
groups_key = self.store.get_group_stream_token()
|
||||
|
||||
@@ -25,8 +25,18 @@ from synapse.logging import context
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Create a custom encoder to reduce the whitespace produced by JSON encoding.
|
||||
json_encoder = json.JSONEncoder(separators=(",", ":"))
|
||||
|
||||
def _reject_invalid_json(val):
|
||||
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
|
||||
raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
|
||||
|
||||
|
||||
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
|
||||
# ensure that valid JSON is produced.
|
||||
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
|
||||
|
||||
# Create a custom decoder to reject Python extensions to JSON.
|
||||
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
|
||||
|
||||
|
||||
def unwrapFirstError(failure):
|
||||
|
||||
@@ -285,16 +285,9 @@ class Cache(object):
|
||||
|
||||
|
||||
class _CacheDescriptorBase(object):
|
||||
def __init__(
|
||||
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
|
||||
):
|
||||
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
|
||||
self.orig = orig
|
||||
|
||||
if inlineCallbacks:
|
||||
self.function_to_call = defer.inlineCallbacks(orig)
|
||||
else:
|
||||
self.function_to_call = orig
|
||||
|
||||
arg_spec = inspect.getfullargspec(orig)
|
||||
all_args = arg_spec.args
|
||||
|
||||
@@ -364,7 +357,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||
invalidated) by adding a special "cache_context" argument to the function
|
||||
and passing that as a kwarg to all caches called. For example::
|
||||
|
||||
@cachedInlineCallbacks(cache_context=True)
|
||||
@cached(cache_context=True)
|
||||
def foo(self, key, cache_context):
|
||||
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
|
||||
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
|
||||
@@ -382,17 +375,11 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||
max_entries=1000,
|
||||
num_args=None,
|
||||
tree=False,
|
||||
inlineCallbacks=False,
|
||||
cache_context=False,
|
||||
iterable=False,
|
||||
):
|
||||
|
||||
super(CacheDescriptor, self).__init__(
|
||||
orig,
|
||||
num_args=num_args,
|
||||
inlineCallbacks=inlineCallbacks,
|
||||
cache_context=cache_context,
|
||||
)
|
||||
super().__init__(orig, num_args=num_args, cache_context=cache_context)
|
||||
|
||||
self.max_entries = max_entries
|
||||
self.tree = tree
|
||||
@@ -465,9 +452,7 @@ class CacheDescriptor(_CacheDescriptorBase):
|
||||
observer = defer.succeed(cached_result_d)
|
||||
|
||||
except KeyError:
|
||||
ret = defer.maybeDeferred(
|
||||
preserve_fn(self.function_to_call), obj, *args, **kwargs
|
||||
)
|
||||
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
|
||||
|
||||
def onErr(f):
|
||||
cache.invalidate(cache_key)
|
||||
@@ -510,9 +495,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||
of results.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
|
||||
):
|
||||
def __init__(self, orig, cached_method_name, list_name, num_args=None):
|
||||
"""
|
||||
Args:
|
||||
orig (function)
|
||||
@@ -521,12 +504,8 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||
num_args (int): number of positional arguments (excluding ``self``,
|
||||
but including list_name) to use as cache keys. Defaults to all
|
||||
named args of the function.
|
||||
inlineCallbacks (bool): Whether orig is a generator that should
|
||||
be wrapped by defer.inlineCallbacks
|
||||
"""
|
||||
super(CacheListDescriptor, self).__init__(
|
||||
orig, num_args=num_args, inlineCallbacks=inlineCallbacks
|
||||
)
|
||||
super().__init__(orig, num_args=num_args)
|
||||
|
||||
self.list_name = list_name
|
||||
|
||||
@@ -631,7 +610,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
|
||||
|
||||
cached_defers.append(
|
||||
defer.maybeDeferred(
|
||||
preserve_fn(self.function_to_call), **args_to_call
|
||||
preserve_fn(self.orig), **args_to_call
|
||||
).addCallbacks(complete_all, errback)
|
||||
)
|
||||
|
||||
@@ -695,21 +674,7 @@ def cached(
|
||||
)
|
||||
|
||||
|
||||
def cachedInlineCallbacks(
|
||||
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
|
||||
):
|
||||
return lambda orig: CacheDescriptor(
|
||||
orig,
|
||||
max_entries=max_entries,
|
||||
num_args=num_args,
|
||||
tree=tree,
|
||||
inlineCallbacks=True,
|
||||
cache_context=cache_context,
|
||||
iterable=iterable,
|
||||
)
|
||||
|
||||
|
||||
def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
|
||||
def cachedList(cached_method_name, list_name, num_args=None):
|
||||
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
|
||||
|
||||
Used to do batch lookups for an already created cache. A single argument
|
||||
@@ -725,8 +690,6 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
||||
do batch lookups in the cache.
|
||||
num_args (int): Number of arguments to use as the key in the cache
|
||||
(including list_name). Defaults to all named parameters.
|
||||
inlineCallbacks (bool): Should the function be wrapped in an
|
||||
`defer.inlineCallbacks`?
|
||||
|
||||
Example:
|
||||
|
||||
@@ -744,5 +707,4 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
|
||||
cached_method_name=cached_method_name,
|
||||
list_name=list_name,
|
||||
num_args=num_args,
|
||||
inlineCallbacks=inlineCallbacks,
|
||||
)
|
||||
|
||||
@@ -71,7 +71,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
|
||||
)
|
||||
|
||||
displayname = yield defer.ensureDeferred(
|
||||
self.handler.get_displayname(self.frank)
|
||||
@@ -112,10 +114,17 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.hs.config.enable_set_displayname = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
yield self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.frank.localpart, "Frank", 1)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
(yield self.store.get_profile_displayname(self.frank.localpart)), "Frank",
|
||||
(
|
||||
yield defer.ensureDeferred(
|
||||
self.store.get_profile_displayname(self.frank.localpart)
|
||||
)
|
||||
),
|
||||
"Frank",
|
||||
)
|
||||
|
||||
# Setting displayname a second time is forbidden
|
||||
@@ -158,7 +167,9 @@ class ProfileTestCase(unittest.TestCase):
|
||||
@defer.inlineCallbacks
|
||||
def test_incoming_fed_query(self):
|
||||
yield defer.ensureDeferred(self.store.create_profile("caroline"))
|
||||
yield self.store.set_profile_displayname("caroline", "Caroline", 1)
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname("caroline", "Caroline", 1)
|
||||
)
|
||||
|
||||
response = yield defer.ensureDeferred(
|
||||
self.query_handlers["profile"](
|
||||
@@ -170,8 +181,10 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_avatar(self):
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png", 1
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png", 1
|
||||
)
|
||||
)
|
||||
avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
|
||||
|
||||
@@ -211,8 +224,10 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.hs.config.enable_set_avatar_url = False
|
||||
|
||||
# Setting displayname for the first time is allowed
|
||||
yield self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png", 1
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_avatar_url(
|
||||
self.frank.localpart, "http://my.server/me.png", 1
|
||||
)
|
||||
)
|
||||
|
||||
self.assertEquals(
|
||||
|
||||
@@ -62,8 +62,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||
"password": "monkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
if i == 5:
|
||||
@@ -76,14 +75,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
# than 1min.
|
||||
self.assertTrue(retry_after_ms < 6000)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0)
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
params = {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
|
||||
"password": "monkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
@@ -111,8 +109,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
if i == 5:
|
||||
@@ -132,7 +129,6 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "monkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
@@ -160,8 +156,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "notamonkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
if i == 5:
|
||||
@@ -174,14 +169,13 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
# than 1min.
|
||||
self.assertTrue(retry_after_ms < 6000)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0)
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
params = {
|
||||
"type": "m.login.password",
|
||||
"identifier": {"type": "m.id.user", "user": "kermit"},
|
||||
"password": "notamonkey",
|
||||
}
|
||||
request_data = json.dumps(params)
|
||||
request, channel = self.make_request(b"POST", LOGIN_URL, params)
|
||||
self.render(request)
|
||||
|
||||
|
||||
@@ -156,7 +156,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0)
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
@@ -182,7 +182,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
else:
|
||||
self.assertEquals(channel.result["code"], b"200", channel.result)
|
||||
|
||||
self.reactor.advance(retry_after_ms / 1000.0)
|
||||
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
|
||||
|
||||
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
|
||||
self.render(request)
|
||||
|
||||
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
# Would be better to check the content, but once == remove blocking event
|
||||
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
|
||||
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
|
||||
)
|
||||
self._rlsn._store.get_events = Mock(
|
||||
return_value=defer.succeed({"123": mock_event})
|
||||
return_value=make_awaitable({"123": mock_event})
|
||||
)
|
||||
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from synapse.storage.databases.main.appservice import (
|
||||
)
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
from tests.utils import setup_test_homeserver
|
||||
|
||||
|
||||
@@ -357,7 +358,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
|
||||
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
|
||||
|
||||
# we aren't testing store._base stuff here, so mock this out
|
||||
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
|
||||
self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
|
||||
|
||||
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
|
||||
yield self._insert_txn(service.id, 10, events)
|
||||
|
||||
@@ -353,6 +353,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
|
||||
"3"
|
||||
] = 300000
|
||||
|
||||
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
|
||||
# All entries within time frame
|
||||
self.assertEqual(
|
||||
@@ -362,7 +363,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
|
||||
3,
|
||||
)
|
||||
# Oldest room to expire
|
||||
self.pump(1)
|
||||
self.pump(1.01)
|
||||
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
|
||||
self.assertEqual(
|
||||
len(
|
||||
|
||||
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
self.get_success(_get_next_async())
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 8)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
def test_multi_instance(self):
|
||||
"""Test that reads and writes from multiple processes are handled
|
||||
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
second_id_gen = self._create_id_generator("second")
|
||||
|
||||
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
|
||||
self.assertEqual(first_id_gen.get_current_token("first"), 3)
|
||||
self.assertEqual(first_id_gen.get_current_token("second"), 7)
|
||||
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
|
||||
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
id_gen = self._create_id_generator()
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
# Try allocating a new ID gen and check that we only see position
|
||||
# advanced after we leave the context manager.
|
||||
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
|
||||
self.assertEqual(stream_id, 8)
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 7})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 7)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
|
||||
|
||||
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
|
||||
|
||||
self.assertEqual(id_gen.get_positions(), {"master": 8})
|
||||
self.assertEqual(id_gen.get_current_token("master"), 8)
|
||||
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
|
||||
|
||||
@@ -34,10 +34,12 @@ class DataStoreTestCase(unittest.TestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_users_paginate(self):
|
||||
yield self.store.register_user(self.user.to_string(), "pass")
|
||||
yield defer.ensureDeferred(
|
||||
self.store.register_user(self.user.to_string(), "pass")
|
||||
)
|
||||
yield defer.ensureDeferred(self.store.create_profile(self.user.localpart))
|
||||
yield self.store.set_profile_displayname(
|
||||
self.user.localpart, self.displayname, 1
|
||||
yield defer.ensureDeferred(
|
||||
self.store.set_profile_displayname(self.user.localpart, self.displayname, 1)
|
||||
)
|
||||
|
||||
users, total = yield self.store.get_users_paginate(
|
||||
|
||||
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
def fn(self, arg1, arg2):
|
||||
pass
|
||||
|
||||
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||
def list_fn(self, args1, arg2):
|
||||
@descriptors.cachedList("fn", "args1")
|
||||
async def list_fn(self, args1, arg2):
|
||||
assert current_context().request == "c1"
|
||||
# we want this to behave like an asynchronous function
|
||||
yield run_on_reactor()
|
||||
await run_on_reactor()
|
||||
assert current_context().request == "c1"
|
||||
return self.mock(args1, arg2)
|
||||
|
||||
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
|
||||
def fn(self, arg1, arg2):
|
||||
pass
|
||||
|
||||
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
|
||||
def list_fn(self, args1, arg2):
|
||||
@descriptors.cachedList("fn", "args1")
|
||||
async def list_fn(self, args1, arg2):
|
||||
# we want this to behave like an asynchronous function
|
||||
yield run_on_reactor()
|
||||
await run_on_reactor()
|
||||
return self.mock(args1, arg2)
|
||||
|
||||
obj = Cls()
|
||||
|
||||
Reference in New Issue
Block a user