1
0

Compare commits

..

1 Commits

Author SHA1 Message Date
Travis Ralston b024acffea Add rudimentary API for promoting/demoting other people in a group
For https://github.com/matrix-org/synapse/issues/2855 (initial)
2020-08-18 15:21:30 -06:00
58 changed files with 452 additions and 366 deletions
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Fix a long-standing bug where invalid JSON would be accepted by Synapse.
-1
View File
@@ -1 +0,0 @@
Separate `get_current_token` into two since there are two different use cases for it.
-1
View File
@@ -1 +0,0 @@
Iteratively encode JSON to avoid blocking the reactor.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Updated documentation to note that Synapse does not follow `HTTP 308` redirects due to an upstream library not supporting them. Contributed by Ryan Cole.
-1
View File
@@ -1 +0,0 @@
Convert various parts of the codebase to async/await.
-1
View File
@@ -1 +0,0 @@
Remove `ChainedIdGenerator`.
-12
View File
@@ -47,18 +47,6 @@ you invite them to. This can be caused by an incorrectly-configured reverse
proxy: see [reverse_proxy.md](<reverse_proxy.md>) for instructions on how to correctly
configure a reverse proxy.
### Known issues
**HTTP `308 Permanent Redirect` redirects are not followed**: Due to missing features
in the HTTP library used by Synapse, 308 redirects are currently not followed by
federating servers, which can cause `M_UNKNOWN` or `401 Unauthorized` errors. This
may affect users who are redirecting apex-to-www (e.g. `example.com` -> `www.example.com`),
and especially users of the Kubernetes *Nginx Ingress* module, which uses 308 redirect
codes by default. For those Kubernetes users, [this Stackoverflow post](https://stackoverflow.com/a/52617528/5096871)
might be helpful. For other users, switching to a `301 Moved Permanently` code may be
an option. 308 redirect codes will be supported properly in a future
release of Synapse.
## Running a demo federation of Synapses
If you want to get up and running quickly with a trio of homeservers in a
+3 -3
View File
@@ -21,9 +21,9 @@ import typing
from http import HTTPStatus
from typing import Dict, List, Optional, Union
from twisted.web import http
from canonicaljson import json
from synapse.util import json_decoder
from twisted.web import http
if typing.TYPE_CHECKING:
from synapse.types import JsonDict
@@ -593,7 +593,7 @@ class HttpResponseException(CodeMessageException):
# try to parse the body as json, to get better errcode/msg, but
# default to M_UNKNOWN with the HTTP status as the error text
try:
j = json_decoder.decode(self.response.decode("utf-8"))
j = json.loads(self.response.decode("utf-8"))
except ValueError:
j = {}
+1 -1
View File
@@ -47,7 +47,7 @@ def check(
Args:
room_version_obj: the version of the room
event: the event being checked.
auth_events: the existing room state.
auth_events (dict: event-key -> event): the existing room state.
Raises:
AuthError if the checks fail
+3 -2
View File
@@ -28,6 +28,7 @@ from typing import (
Union,
)
from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -62,7 +63,7 @@ from synapse.replication.http.federation import (
ReplicationGetQueryRestServlet,
)
from synapse.types import JsonDict, get_domain_from_id
from synapse.util import glob_to_regex, json_decoder, unwrapFirstError
from synapse.util import glob_to_regex, unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
@@ -550,7 +551,7 @@ class FederationServer(FederationBase):
for device_id, keys in device_keys.items():
for key_id, json_str in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_str)
key_id: json.loads(json_str)
}
logger.info(
@@ -15,6 +15,8 @@
import logging
from typing import TYPE_CHECKING, List, Tuple
from canonicaljson import json
from synapse.api.errors import HttpResponseException
from synapse.events import EventBase
from synapse.federation.persistence import TransactionActions
@@ -26,7 +28,6 @@ from synapse.logging.opentracing import (
tags,
whitelisted_homeserver,
)
from synapse.util import json_decoder
from synapse.util.metrics import measure_func
if TYPE_CHECKING:
@@ -70,7 +71,7 @@ class TransactionManager(object):
for edu in pending_edus:
context = edu.get_context()
if context:
span_contexts.append(extract_text_map(json_decoder.decode(context)))
span_contexts.append(extract_text_map(json.loads(context)))
if keep_destination:
edu.strip_context()
+21
View File
@@ -719,6 +719,27 @@ class GroupsServerHandler(GroupsServerWorkerHandler):
raise NotImplementedError()
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
await self.check_group_is_ours(group_id, requester_user_id, and_exists=True)
if requester_user_id == user_id:
raise SynapseError(400, "User cannot target themselves")
is_admin = await self.store.is_user_admin_in_group(
group_id, requester_user_id
)
if not is_admin:
raise SynapseError(403, "User is not admin in group")
await self.store.change_user_admin_in_group(group_id, user_id, want_admin)
return {}
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
+4 -4
View File
@@ -19,7 +19,7 @@ import logging
from typing import Dict, List, Optional, Tuple
import attr
from canonicaljson import encode_canonical_json
from canonicaljson import encode_canonical_json, json
from signedjson.key import VerifyKey, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from unpaddedbase64 import decode_base64
@@ -35,7 +35,7 @@ from synapse.types import (
get_domain_from_id,
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination
@@ -404,7 +404,7 @@ class E2eKeysHandler(object):
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json_decoder.decode(json_bytes)
key_id: json.loads(json_bytes)
}
@trace
@@ -1186,7 +1186,7 @@ def _exception_to_failure(e):
def _one_time_keys_match(old_key_json, new_key):
old_key = json_decoder.decode(old_key_json)
old_key = json.loads(old_key_json)
# if either is a string rather than an object, they must match exactly
if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+14 -23
View File
@@ -17,7 +17,6 @@
"""Contains handlers for federation events."""
import collections
import itertools
import logging
from collections.abc import Container
@@ -1369,25 +1368,11 @@ class FederationHandler(BaseHandler):
self.config.worker.writers.events, "events", max_stream_id
)
# Check whether this room is the result of an upgrade of a room we
# already know about. If so, migrate over user information
#
# Note: we do this manually rather than asking the DB to avoid a
# race where the current state hasn't yet updated.
predecessor = None
for s in state:
if s.type == EventTypes.Create:
predecessor = s.content.get("predecessor", None)
# Ensure the key is a dictionary
if not isinstance(predecessor, collections.abc.Mapping):
predecessor = None
break
# Check whether this room is the result of an upgrade of a room we already know
# about. If so, migrate over user information
predecessor = await self.store.get_room_predecessor(room_id)
if not predecessor or not isinstance(predecessor.get("room_id"), str):
return event.event_id, max_stream_id
old_room_id = predecessor["room_id"]
logger.debug(
"Found predecessor for %s during remote join: %s", room_id, old_room_id
@@ -1792,7 +1777,9 @@ class FederationHandler(BaseHandler):
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(event_id, check_room_id=room_id)
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = await self.state_store.get_state_groups(room_id, [event_id])
@@ -1818,7 +1805,9 @@ class FederationHandler(BaseHandler):
async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
"""Returns the state at the event. i.e. not including said event.
"""
event = await self.store.get_event(event_id, check_room_id=room_id)
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
@@ -2166,9 +2155,9 @@ class FederationHandler(BaseHandler):
auth_types = auth_types_for_event(event)
current_state_ids = [e for k, e in current_state_ids.items() if k in auth_types]
auth_events_map = await self.store.get_events(current_state_ids)
current_auth_events = await self.store.get_events(current_state_ids)
current_auth_events = {
(e.type, e.state_key): e for e in auth_events_map.values()
(e.type, e.state_key): e for e in current_auth_events.values()
}
try:
@@ -2184,7 +2173,9 @@ class FederationHandler(BaseHandler):
if not in_room:
raise AuthError(403, "Host not in room.")
event = await self.store.get_event(event_id, check_room_id=room_id)
event = await self.store.get_event(
event_id, allow_none=False, check_room_id=room_id
)
# Just go through and process each event in `remote_auth_chain`. We
# don't want to fall into the trap of `missing` being wrong.
+19
View File
@@ -461,6 +461,25 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def change_user_admin_in_group(
self, group_id, user_id, want_admin, requester_user_id, content
):
"""Promotes or demotes a user in a group.
"""
if not self.is_mine_id(user_id):
raise SynapseError(400, "User not on this server")
# TODO: We should probably support federation, but this is fine for now
if not self.is_mine_id(group_id):
raise SynapseError(400, "Group not on this server")
res = await self.groups_server_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return res
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
+3 -2
View File
@@ -21,6 +21,8 @@ import logging
import urllib.parse
from typing import Awaitable, Callable, Dict, List, Optional, Tuple
from canonicaljson import json
from twisted.internet.error import TimeoutError
from synapse.api.errors import (
@@ -32,7 +34,6 @@ from synapse.api.errors import (
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.http.client import SimpleHttpClient
from synapse.types import JsonDict, Requester
from synapse.util import json_decoder
from synapse.util.hash import sha256_and_url_safe_base64
from synapse.util.stringutils import assert_valid_client_secret, random_string
@@ -176,7 +177,7 @@ class IdentityHandler(BaseHandler):
except TimeoutError:
raise SynapseError(500, "Timed out contacting identity server")
except CodeMessageException as e:
data = json_decoder.decode(e.msg) # XXX WAT?
data = json.loads(e.msg) # XXX WAT?
return data
logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
+5 -6
View File
@@ -17,7 +17,7 @@
import logging
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from canonicaljson import encode_canonical_json
from canonicaljson import encode_canonical_json, json
from twisted.internet.interfaces import IDelayedCall
@@ -55,7 +55,6 @@ from synapse.types import (
UserID,
create_requester,
)
from synapse.util import json_decoder
from synapse.util.async_helpers import Linearizer
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.metrics import measure_func
@@ -865,7 +864,7 @@ class EventCreationHandler(object):
# Ensure that we can round trip before trying to persist in db
try:
dump = frozendict_json_encoder.encode(event.content)
json_decoder.decode(dump)
json.loads(dump)
except Exception:
logger.exception("Failed to encode content: %r", event.content)
raise
@@ -961,7 +960,7 @@ class EventCreationHandler(object):
allow_none=True,
)
is_admin_redaction = bool(
is_admin_redaction = (
original_event and event.sender != original_event.sender
)
@@ -1081,8 +1080,8 @@ class EventCreationHandler(object):
auth_events_ids = self.auth.compute_auth_events(
event, prev_state_ids, for_verification=True
)
auth_events_map = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
auth_events = await self.store.get_events(auth_events_ids)
auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
room_version = await self.store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+3 -3
View File
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
from urllib.parse import urlencode
@@ -38,7 +39,6 @@ from synapse.http.server import respond_with_html
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
from synapse.types import UserID, map_username_to_mxid_localpart
from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -367,7 +367,7 @@ class OidcHandler:
# and check for an error field. If not, we respond with a generic
# error message.
try:
resp = json_decoder.decode(resp_body.decode("utf-8"))
resp = json.loads(resp_body.decode("utf-8"))
error = resp["error"]
description = resp.get("error_description", error)
except (ValueError, KeyError):
@@ -384,7 +384,7 @@ class OidcHandler:
# Since it is a not a 5xx code, body should be a valid JSON. It will
# raise if not.
resp = json_decoder.decode(resp_body.decode("utf-8"))
resp = json.loads(resp_body.decode("utf-8"))
if "error" in resp:
error = resp["error"]
+1 -1
View File
@@ -716,7 +716,7 @@ class RoomMemberHandler(object):
guest_access = await self.store.get_event(guest_access_id)
return bool(
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
+3 -2
View File
@@ -16,12 +16,13 @@
import logging
from typing import Any
from canonicaljson import json
from twisted.web.client import PartialDownloadError
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, LoginError, SynapseError
from synapse.config.emailconfig import ThreepidBehaviour
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -116,7 +117,7 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = json_decoder.decode(data.decode("utf-8"))
resp_body = json.loads(data.decode("utf-8"))
if "success" in resp_body:
# Note that we do NOT check the hostname here: we explicitly
+5 -6
View File
@@ -19,7 +19,7 @@ import urllib
from io import BytesIO
import treq
from canonicaljson import encode_canonical_json
from canonicaljson import encode_canonical_json, json
from netaddr import IPAddress
from prometheus_client import Counter
from zope.interface import implementer, provider
@@ -47,7 +47,6 @@ from synapse.http import (
from synapse.http.proxyagent import ProxyAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred
logger = logging.getLogger(__name__)
@@ -392,7 +391,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json_decoder.decode(body.decode("utf-8"))
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -434,7 +433,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json_decoder.decode(body.decode("utf-8"))
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -464,7 +463,7 @@ class SimpleHttpClient(object):
actual_headers.update(headers)
body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
return json.loads(body.decode("utf-8"))
async def put_json(self, uri, json_body, args={}, headers=None):
""" Puts some json to the given URI.
@@ -507,7 +506,7 @@ class SimpleHttpClient(object):
body = await make_deferred_yieldable(readBody(response))
if 200 <= response.code < 300:
return json_decoder.decode(body.decode("utf-8"))
return json.loads(body.decode("utf-8"))
else:
raise HttpResponseException(
response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import random
import time
@@ -25,7 +26,7 @@ from twisted.web.http import stringToDatetime
from twisted.web.http_headers import Headers
from synapse.logging.context import make_deferred_yieldable
from synapse.util import Clock, json_decoder
from synapse.util import Clock
from synapse.util.caches.ttlcache import TTLCache
from synapse.util.metrics import Measure
@@ -180,7 +181,7 @@ class WellKnownResolver(object):
if response.code != 200:
raise Exception("Non-200 response %s" % (response.code,))
parsed_body = json_decoder.decode(body.decode("utf-8"))
parsed_body = json.loads(body.decode("utf-8"))
logger.info("Response from .well-known: %s", parsed_body)
result = parsed_body["m.server"].encode("ascii")
+31 -42
View File
@@ -500,7 +500,7 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
pass
@implementer(interfaces.IPushProducer)
@implementer(interfaces.IPullProducer)
class _ByteProducer:
"""
Iteratively write bytes to the request.
@@ -515,64 +515,52 @@ class _ByteProducer:
):
self._request = request
self._iterator = iterator
self._paused = False
# Register the producer and start producing data.
self._request.registerProducer(self, True)
self.resumeProducing()
def start(self) -> None:
self._request.registerProducer(self, False)
def _send_data(self, data: List[bytes]) -> None:
"""
Send a list of bytes as a chunk of a response.
Send a list of strings as a response to the request.
"""
if not data:
return
self._request.write(b"".join(data))
def pauseProducing(self) -> None:
self._paused = True
def resumeProducing(self) -> None:
# We've stopped producing in the meantime (note that this might be
# re-entrant after calling write).
if not self._request:
return
self._paused = False
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is coalesced until min_chunk_size is
# reached. (This is because JSON encoders produce a very small output
# per iteration.)
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
# Write until there's backpressure telling us to stop.
while not self._paused:
# Get the next chunk and write it to the request.
#
# The output of the JSON encoder is buffered and coalesced until
# min_chunk_size is reached. This is because JSON encoders produce
# very small output per iteration and the Request object converts
# each call to write() to a separate chunk. Without this there would
# be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
#
# Note that buffer stores a list of bytes (instead of appending to
# bytes) to hopefully avoid many allocations.
buffer = []
buffered_bytes = 0
while buffered_bytes < self.min_chunk_size:
try:
data = next(self._iterator)
buffer.append(data)
buffered_bytes += len(data)
except StopIteration:
# The entire JSON object has been serialized, write any
# remaining data, finalize the producer and the request, and
# clean-up any references.
self._send_data(buffer)
self._request.unregisterProducer()
self._request.finish()
self.stopProducing()
return
self._send_data(buffer)
self._send_data(buffer)
def stopProducing(self) -> None:
# Clear a circular reference.
self._request = None
@@ -632,7 +620,8 @@ def respond_with_json(
if send_cors:
set_cors_headers(request)
_ByteProducer(request, encoder(json_object))
producer = _ByteProducer(request, encoder(json_object))
producer.start()
return NOT_DONE_YET
+3 -2
View File
@@ -17,8 +17,9 @@
import logging
from canonicaljson import json
from synapse.api.errors import Codes, SynapseError
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -214,7 +215,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
return None
try:
content = json_decoder.decode(content_bytes.decode("utf-8"))
content = json.loads(content_bytes.decode("utf-8"))
except Exception as e:
logger.warning("Unable to parse JSON: %s", e)
raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
+2 -5
View File
@@ -177,7 +177,6 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.config import ConfigError
from synapse.util import json_decoder
if TYPE_CHECKING:
from synapse.http.site import SynapseRequest
@@ -500,9 +499,7 @@ def start_active_span_from_edu(
if opentracing is None:
return _noop_context_manager()
carrier = json_decoder.decode(edu_content.get("context", "{}")).get(
"opentracing", {}
)
carrier = json.loads(edu_content.get("context", "{}")).get("opentracing", {})
context = opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
_references = [
opentracing.child_of(span_context_from_string(x))
@@ -702,7 +699,7 @@ def span_context_from_string(carrier):
Returns:
The active span context decoded from a string.
"""
carrier = json_decoder.decode(carrier)
carrier = json.loads(carrier)
return opentracing.tracer.extract(opentracing.Format.TEXT_MAP, carrier)
@@ -175,7 +175,7 @@ def run_as_background_process(desc: str, func, *args, **kwargs):
It returns a Deferred which completes when the function completes, but it doesn't
follow the synapse logcontext rules, which makes it appropriate for passing to
clock.looping_call and friends (or for firing-and-forgetting in the middle of a
normal synapse async function).
normal synapse inlineCallbacks function).
Args:
desc: a description for this background process type
@@ -33,11 +33,3 @@ class SlavedIdTracker(object):
int
"""
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
"""
return self.get_current_token()
@@ -14,7 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.replication.slave.storage._slaved_id_tracker import SlavedIdTracker
from synapse.replication.tcp.streams import PushRulesStream
from synapse.storage.databases.main.push_rule import PushRulesWorkerStore
@@ -22,13 +21,16 @@ from .events import SlavedEventStore
class SlavedPushRuleStore(SlavedEventStore, PushRulesWorkerStore):
def get_push_rules_stream_token(self):
return (
self._push_rules_stream_id_gen.get_current_token(),
self._stream_id_gen.get_current_token(),
)
def get_max_push_rules_stream_id(self):
return self._push_rules_stream_id_gen.get_current_token()
def process_replication_rows(self, stream_name, instance_name, token, rows):
# We assert this for the benefit of mypy
assert isinstance(self._push_rules_stream_id_gen, SlavedIdTracker)
if stream_name == PushRulesStream.NAME:
self._push_rules_stream_id_gen.advance(token)
for row in rows:
+7 -5
View File
@@ -21,7 +21,9 @@ import abc
import logging
from typing import Tuple, Type
from synapse.util import json_decoder, json_encoder
from canonicaljson import json
from synapse.util import json_encoder as _json_encoder
logger = logging.getLogger(__name__)
@@ -123,7 +125,7 @@ class RdataCommand(Command):
stream_name,
instance_name,
None if token == "batch" else int(token),
json_decoder.decode(row_json),
json.loads(row_json),
)
def to_line(self):
@@ -132,7 +134,7 @@ class RdataCommand(Command):
self.stream_name,
self.instance_name,
str(self.token) if self.token is not None else "batch",
json_encoder.encode(self.row),
_json_encoder.encode(self.row),
)
)
@@ -357,7 +359,7 @@ class UserIpCommand(Command):
def from_line(cls, line):
user_id, jsn = line.split(" ", 1)
access_token, ip, user_agent, device_id, last_seen = json_decoder.decode(jsn)
access_token, ip, user_agent, device_id, last_seen = json.loads(jsn)
return cls(user_id, access_token, ip, user_agent, device_id, last_seen)
@@ -365,7 +367,7 @@ class UserIpCommand(Command):
return (
self.user_id
+ " "
+ json_encoder.encode(
+ _json_encoder.encode(
(
self.access_token,
self.ip,
+2 -2
View File
@@ -352,7 +352,7 @@ class PushRulesStream(Stream):
)
def _current_token(self, instance_name: str) -> int:
push_rules_token = self.store.get_max_push_rules_stream_id()
push_rules_token, _ = self.store.get_push_rules_stream_token()
return push_rules_token
@@ -405,7 +405,7 @@ class CachesStream(Stream):
store = hs.get_datastore()
super().__init__(
hs.get_instance_name(),
store.get_cache_stream_token_for_writer,
store.get_cache_stream_token,
store.get_all_updated_caches,
)
+1 -1
View File
@@ -159,7 +159,7 @@ class PushRuleRestServlet(RestServlet):
return 200, {}
def notify_user(self, user_id):
stream_id = self.store.get_max_push_rules_stream_id()
stream_id, _ = self.store.get_push_rules_stream_token()
self.notifier.on_new_event("push_rules_key", stream_id, users=[user_id])
async def set_rule_attr(self, user_id, spec, val):
+4 -7
View File
@@ -21,6 +21,8 @@ import re
from typing import List, Optional
from urllib import parse as urlparse
from canonicaljson import json
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import (
AuthError,
@@ -44,7 +46,6 @@ from synapse.rest.client.v2_alpha._base import client_patterns
from synapse.storage.state import StateFilter
from synapse.streams.config import PaginationConfig
from synapse.types import RoomAlias, RoomID, StreamToken, ThirdPartyInstanceID, UserID
from synapse.util import json_decoder
MYPY = False
if MYPY:
@@ -518,9 +519,7 @@ class RoomMessageListRestServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
@@ -632,9 +631,7 @@ class RoomEventContextServlet(RestServlet):
filter_str = parse_string(request, b"filter", encoding="utf-8")
if filter_str:
filter_json = urlparse.unquote(filter_str)
event_filter = Filter(
json_decoder.decode(filter_json)
) # type: Optional[Filter]
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
+26
View File
@@ -548,6 +548,31 @@ class GroupAdminUsersKickServlet(RestServlet):
return 200, result
class GroupAdminChangeAdminServlet(RestServlet):
"""Promote or demote a user in the group
"""
PATTERNS = client_patterns(
"/groups/(?P<group_id>[^/]*)/admin/users/admins/(?P<user_id>[^/]*)$"
)
def __init__(self, hs):
super(GroupAdminChangeAdminServlet, self).__init__()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.groups_handler = hs.get_groups_local_handler()
async def on_POST(self, request, group_id, user_id):
requester = await self.auth.get_user_by_req(request)
requester_user_id = requester.user.to_string()
content = parse_json_object_from_request(request)
want_admin = content["is_admin"]
result = await self.groups_handler.change_user_admin_in_group(
group_id, user_id, want_admin, requester_user_id, content
)
return 200, result
class GroupSelfLeaveServlet(RestServlet):
"""Leave a joined group
@@ -722,6 +747,7 @@ def register_servlets(hs, http_server):
GroupAdminRoomsConfigServlet(hs).register(http_server)
GroupAdminUsersInviteServlet(hs).register(http_server)
GroupAdminUsersKickServlet(hs).register(http_server)
GroupAdminChangeAdminServlet(hs).register(http_server)
GroupSelfLeaveServlet(hs).register(http_server)
GroupSelfJoinServlet(hs).register(http_server)
GroupSelfAcceptInviteServlet(hs).register(http_server)
+3 -2
View File
@@ -16,6 +16,8 @@
import itertools
import logging
from canonicaljson import json
from synapse.api.constants import PresenceState
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
@@ -27,7 +29,6 @@ from synapse.handlers.presence import format_user_presence_state
from synapse.handlers.sync import SyncConfig
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
from synapse.types import StreamToken
from synapse.util import json_decoder
from ._base import client_patterns, set_timeline_upper_limit
@@ -124,7 +125,7 @@ class SyncRestServlet(RestServlet):
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
filter_object = json_decoder.decode(filter_id)
filter_object = json.loads(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
+3 -5
View File
@@ -15,19 +15,19 @@
import logging
from typing import Dict, Set
from canonicaljson import json
from signedjson.sign import sign_json
from synapse.api.errors import Codes, SynapseError
from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import DirectServeJsonResource, respond_with_json
from synapse.http.servlet import parse_integer, parse_json_object_from_request
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
class RemoteKey(DirectServeJsonResource):
"""HTTP resource for retrieving the TLS certificate and NACL signature
"""HTTP resource for retreiving the TLS certificate and NACL signature
verification keys for a collection of servers. Checks that the reported
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
that the NACL signature for the remote server is valid. Returns a dict of
@@ -209,15 +209,13 @@ class RemoteKey(DirectServeJsonResource):
# Cast to bytes since postgresql returns a memoryview.
json_results.add(bytes(result["key_json"]))
# If there is a cache miss, request the missing keys, then recurse (and
# ensure the result is sent).
if cache_misses and query_remote_on_cache_miss:
await self.fetcher.get_keys(cache_misses)
await self.query_keys(request, query, query_remote_on_cache_miss=False)
else:
signed_keys = []
for key_json in json_results:
key_json = json_decoder.decode(key_json.decode("utf-8"))
key_json = json.loads(key_json.decode("utf-8"))
for signing_key in self.config.key_server_signing_keys:
key_json = sign_json(key_json, self.config.server_name, signing_key)
+1 -1
View File
@@ -51,5 +51,5 @@ class SpamCheckerApi(object):
state_ids = yield self._store.get_filtered_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
state = yield defer.ensureDeferred(self._store.get_events(state_ids.values()))
state = yield self._store.get_events(state_ids.values())
return state.values()
+1 -1
View File
@@ -641,7 +641,7 @@ class StateResolutionStore(object):
allow_rejected (bool): If True return rejected events.
Returns:
Awaitable[dict[str, FrozenEvent]]: Dict from event_id to event.
Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
"""
return self.store.get_events(
+4 -3
View File
@@ -19,11 +19,12 @@ import random
from abc import ABCMeta
from typing import Any, Optional
from canonicaljson import json
from synapse.storage.database import LoggingTransaction # noqa: F401
from synapse.storage.database import make_in_list_sql_clause # noqa: F401
from synapse.storage.database import DatabasePool
from synapse.types import Collection, get_domain_from_id
from synapse.util import json_decoder
logger = logging.getLogger(__name__)
@@ -98,13 +99,13 @@ def db_to_json(db_content):
if isinstance(db_content, memoryview):
db_content = db_content.tobytes()
# Decode it to a Unicode string before feeding it to the JSON decoder, since
# Decode it to a Unicode string before feeding it to json.loads, since
# Python 3.5 does not support deserializing bytes.
if isinstance(db_content, (bytes, bytearray)):
db_content = db_content.decode("utf8")
try:
return json_decoder.decode(db_content)
return json.loads(db_content)
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
+14 -13
View File
@@ -516,16 +516,14 @@ class DatabasePool(object):
logger.warning("Starting db txn '%s' from sentinel context", desc)
try:
result = yield defer.ensureDeferred(
self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
)
result = yield self.runWithConnection(
self.new_transaction,
desc,
after_callbacks,
exception_callbacks,
func,
*args,
**kwargs
)
for after_callback, after_args, after_kwargs in after_callbacks:
@@ -537,7 +535,8 @@ class DatabasePool(object):
return result
async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
@defer.inlineCallbacks
def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
"""Wraps the .runWithConnection() method on the underlying db_pool.
Arguments:
@@ -548,7 +547,7 @@ class DatabasePool(object):
kwargs: named args to pass to `func`
Returns:
The result of func
Deferred: The result of func
"""
parent_context = current_context() # type: Optional[LoggingContextOrSentinel]
if not parent_context:
@@ -571,10 +570,12 @@ class DatabasePool(object):
return func(conn, *args, **kwargs)
return await make_deferred_yieldable(
result = yield make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
return result
@staticmethod
def cursor_to_dict(cursor):
"""Converts a SQL cursor into an list of dicts.
+2 -2
View File
@@ -299,8 +299,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
},
)
def get_cache_stream_token_for_writer(self, instance_name: str) -> int:
def get_cache_stream_token(self, instance_name):
if self._cache_id_gen:
return self._cache_id_gen.get_current_token_for_writer(instance_name)
return self._cache_id_gen.get_current_token(instance_name)
else:
return 0
@@ -30,7 +30,7 @@ logger = logging.getLogger(__name__)
class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBaseStore):
async def get_auth_chain(self, event_ids, include_given=False):
def get_auth_chain(self, event_ids, include_given=False):
"""Get auth events for given event_ids. The events *must* be state events.
Args:
@@ -40,10 +40,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
Returns:
list of events
"""
event_ids = await self.get_auth_chain_ids(
return self.get_auth_chain_ids(
event_ids, include_given=include_given
)
return await self.get_events_as_list(event_ids)
).addCallback(self.get_events_as_list)
def get_auth_chain_ids(
self,
@@ -460,7 +459,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
"get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
)
async def get_backfill_events(self, room_id, event_list, limit):
def get_backfill_events(self, room_id, event_list, limit):
"""Get a list of Events for a given topic that occurred before (and
including) the events in event_list. Return a list of max size `limit`
@@ -470,15 +469,17 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
event_list (list)
limit (int)
"""
event_ids = await self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
return (
self.db_pool.runInteraction(
"get_backfill_events",
self._get_backfill_events,
room_id,
event_list,
limit,
)
.addCallback(self.get_events_as_list)
.addCallback(lambda l: sorted(l, key=lambda e: -e.depth))
)
events = await self.get_events_as_list(event_ids)
return sorted(events, key=lambda e: -e.depth)
def _get_backfill_events(self, txn, room_id, event_list, limit):
logger.debug("_get_backfill_events: %s, %r, %s", room_id, event_list, limit)
@@ -539,7 +540,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
latest_events,
limit,
)
return await self.get_events_as_list(ids)
events = await self.get_events_as_list(ids)
return events
def _get_missing_events(self, txn, room_id, earliest_events, latest_events, limit):
+60 -88
View File
@@ -19,10 +19,9 @@ import itertools
import logging
import threading
from collections import namedtuple
from typing import Dict, Iterable, List, Optional, Tuple, overload
from typing import List, Optional, Tuple
from constantly import NamedConstant, Names
from typing_extensions import Literal
from twisted.internet import defer
@@ -33,7 +32,7 @@ from synapse.api.room_versions import (
EventFormatVersions,
RoomVersions,
)
from synapse.events import EventBase, make_event_from_dict
from synapse.events import make_event_from_dict
from synapse.events.utils import prune_event
from synapse.logging.context import PreserveLoggingContext, current_context
from synapse.metrics.background_process_metrics import run_as_background_process
@@ -43,8 +42,8 @@ from synapse.replication.tcp.streams.events import EventsStream
from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import DatabasePool
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import Collection, get_domain_from_id
from synapse.util.caches.descriptors import Cache, cached
from synapse.types import get_domain_from_id
from synapse.util.caches.descriptors import Cache, cachedInlineCallbacks
from synapse.util.iterutils import batch_iter
from synapse.util.metrics import Measure
@@ -138,33 +137,8 @@ class EventsWorkerStore(SQLBaseStore):
desc="get_received_ts",
)
# Inform mypy that if allow_none is False (the default) then get_event
# always returns an EventBase.
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[False] = False,
check_room_id: Optional[str] = None,
) -> EventBase:
...
@overload
async def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
allow_none: Literal[True] = False,
check_room_id: Optional[str] = None,
) -> Optional[EventBase]:
...
async def get_event(
@defer.inlineCallbacks
def get_event(
self,
event_id: str,
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
@@ -172,7 +146,7 @@ class EventsWorkerStore(SQLBaseStore):
allow_rejected: bool = False,
allow_none: bool = False,
check_room_id: Optional[str] = None,
) -> Optional[EventBase]:
):
"""Get an event from the database by event_id.
Args:
@@ -197,12 +171,12 @@ class EventsWorkerStore(SQLBaseStore):
If there is a mismatch, behave as per allow_none.
Returns:
The event, or None if the event was not found.
Deferred[EventBase|None]
"""
if not isinstance(event_id, str):
raise TypeError("Invalid event event_id %r" % (event_id,))
events = await self.get_events_as_list(
events = yield self.get_events_as_list(
[event_id],
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -220,13 +194,14 @@ class EventsWorkerStore(SQLBaseStore):
return event
async def get_events(
@defer.inlineCallbacks
def get_events(
self,
event_ids: Iterable[str],
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
) -> Dict[str, EventBase]:
):
"""Get events from the database
Args:
@@ -245,9 +220,9 @@ class EventsWorkerStore(SQLBaseStore):
omits rejeted events from the response.
Returns:
A mapping from event_id to event.
Deferred : Dict from event_id to event.
"""
events = await self.get_events_as_list(
events = yield self.get_events_as_list(
event_ids,
redact_behaviour=redact_behaviour,
get_prev_content=get_prev_content,
@@ -256,13 +231,14 @@ class EventsWorkerStore(SQLBaseStore):
return {e.event_id: e for e in events}
async def get_events_as_list(
@defer.inlineCallbacks
def get_events_as_list(
self,
event_ids: Collection[str],
event_ids: List[str],
redact_behaviour: EventRedactBehaviour = EventRedactBehaviour.REDACT,
get_prev_content: bool = False,
allow_rejected: bool = False,
) -> List[EventBase]:
):
"""Get events from the database and return in a list in the same order
as given by `event_ids` arg.
@@ -283,8 +259,8 @@ class EventsWorkerStore(SQLBaseStore):
omits rejected events from the response.
Returns:
List of events fetched from the database. The events are in the same
order as `event_ids` arg.
Deferred[list[EventBase]]: List of events fetched from the database. The
events are in the same order as `event_ids` arg.
Note that the returned list may be smaller than the list of event
IDs if not all events could be fetched.
@@ -294,7 +270,7 @@ class EventsWorkerStore(SQLBaseStore):
return []
# there may be duplicates so we cast the list to a set
event_entry_map = await self._get_events_from_cache_or_db(
event_entry_map = yield self._get_events_from_cache_or_db(
set(event_ids), allow_rejected=allow_rejected
)
@@ -329,7 +305,7 @@ class EventsWorkerStore(SQLBaseStore):
continue
redacted_event_id = entry.event.redacts
event_map = await self._get_events_from_cache_or_db([redacted_event_id])
event_map = yield self._get_events_from_cache_or_db([redacted_event_id])
original_event_entry = event_map.get(redacted_event_id)
if not original_event_entry:
# we don't have the redacted event (or it was rejected).
@@ -395,7 +371,7 @@ class EventsWorkerStore(SQLBaseStore):
if get_prev_content:
if "replaces_state" in event.unsigned:
prev = await self.get_event(
prev = yield self.get_event(
event.unsigned["replaces_state"],
get_prev_content=False,
allow_none=True,
@@ -407,7 +383,8 @@ class EventsWorkerStore(SQLBaseStore):
return events
async def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
@defer.inlineCallbacks
def _get_events_from_cache_or_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the cache or the database.
If events are pulled from the database, they will be cached for future lookups.
@@ -422,7 +399,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Dict[str, _EventCacheEntry]:
Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result
"""
event_entry_map = self._get_events_from_cache(
@@ -440,7 +417,7 @@ class EventsWorkerStore(SQLBaseStore):
# the events have been redacted, and if so pulling the redaction event out
# of the database to check it.
#
missing_events = await self._get_events_from_db(
missing_events = yield self._get_events_from_db(
missing_events_ids, allow_rejected=allow_rejected
)
@@ -548,7 +525,8 @@ class EventsWorkerStore(SQLBaseStore):
with PreserveLoggingContext():
self.hs.get_reactor().callFromThread(fire, event_list, e)
async def _get_events_from_db(self, event_ids, allow_rejected=False):
@defer.inlineCallbacks
def _get_events_from_db(self, event_ids, allow_rejected=False):
"""Fetch a bunch of events from the database.
Returned events will be added to the cache for future lookups.
@@ -562,7 +540,7 @@ class EventsWorkerStore(SQLBaseStore):
rejected events are omitted from the response.
Returns:
Dict[str, _EventCacheEntry]:
Deferred[Dict[str, _EventCacheEntry]]:
map from event id to result. May return extra events which
weren't asked for.
"""
@@ -570,7 +548,7 @@ class EventsWorkerStore(SQLBaseStore):
events_to_fetch = event_ids
while events_to_fetch:
row_map = await self._enqueue_events(events_to_fetch)
row_map = yield self._enqueue_events(events_to_fetch)
# we need to recursively fetch any redactions of those events
redaction_ids = set()
@@ -596,20 +574,8 @@ class EventsWorkerStore(SQLBaseStore):
if not allow_rejected and rejected_reason:
continue
# If the event or metadata cannot be parsed, log the error and act
# as if the event is unknown.
try:
d = db_to_json(row["json"])
except ValueError:
logger.error("Unable to parse json from event: %s", event_id)
continue
try:
internal_metadata = db_to_json(row["internal_metadata"])
except ValueError:
logger.error(
"Unable to parse internal_metadata from event: %s", event_id
)
continue
d = db_to_json(row["json"])
internal_metadata = db_to_json(row["internal_metadata"])
format_version = row["format_version"]
if format_version is None:
@@ -684,7 +650,8 @@ class EventsWorkerStore(SQLBaseStore):
return result_map
async def _enqueue_events(self, events):
@defer.inlineCallbacks
def _enqueue_events(self, events):
"""Fetches events from the database using the _event_fetch_list. This
allows batch and bulk fetching of events - it allows us to fetch events
without having to create a new transaction for each request for events.
@@ -693,7 +660,7 @@ class EventsWorkerStore(SQLBaseStore):
events (Iterable[str]): events to be fetched.
Returns:
Dict[str, Dict]: map from event id to row data from the database.
Deferred[Dict[str, Dict]]: map from event id to row data from the database.
May contain events that weren't requested.
"""
@@ -716,7 +683,7 @@ class EventsWorkerStore(SQLBaseStore):
logger.debug("Loading %d events: %s", len(events), events)
with PreserveLoggingContext():
row_map = await events_d
row_map = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(row_map))
return row_map
@@ -875,29 +842,33 @@ class EventsWorkerStore(SQLBaseStore):
# no valid redaction found for this event
return None
async def have_events_in_timeline(self, event_ids):
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = await self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
rows = yield defer.ensureDeferred(
self.db_pool.simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
)
return {r["event_id"] for r in rows}
async def have_seen_events(self, event_ids):
@defer.inlineCallbacks
def have_seen_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
Args:
event_ids (iterable[str]):
Returns:
set[str]: The events we have already seen.
Deferred[set[str]]: The events we have already seen.
"""
results = set()
@@ -913,7 +884,7 @@ class EventsWorkerStore(SQLBaseStore):
# break the input up into chunks of 100
input_iterator = iter(event_ids)
for chunk in iter(lambda: list(itertools.islice(input_iterator, 100)), []):
await self.db_pool.runInteraction(
yield self.db_pool.runInteraction(
"have_seen_events", have_seen_events_txn, chunk
)
return results
@@ -943,7 +914,8 @@ class EventsWorkerStore(SQLBaseStore):
room_id,
)
async def get_room_complexity(self, room_id):
@defer.inlineCallbacks
def get_room_complexity(self, room_id):
"""
Get a rough approximation of the complexity of the room. This is used by
remote servers to decide whether they wish to join the room or not.
@@ -954,9 +926,9 @@ class EventsWorkerStore(SQLBaseStore):
room_id (str)
Returns:
dict[str:int] of complexity version to complexity.
Deferred[dict[str:int]] of complexity version to complexity.
"""
state_events = await self.get_current_state_event_counts(room_id)
state_events = yield self.get_current_state_event_counts(room_id)
# Call this one "v1", so we can introduce new ones as we want to develop
# it.
@@ -1193,9 +1165,9 @@ class EventsWorkerStore(SQLBaseStore):
to_2, so_2 = await self.get_event_ordering(event_id2)
return (to_1, so_1) > (to_2, so_2)
@cached(max_entries=5000)
async def get_event_ordering(self, event_id):
res = await self.db_pool.simple_select_one(
@cachedInlineCallbacks(max_entries=5000)
def get_event_ordering(self, event_id):
res = yield self.db_pool.simple_select_one(
table="events",
retcols=["topological_ordering", "stream_ordering"],
keyvalues={"event_id": event_id},
@@ -1038,6 +1038,14 @@ class GroupServerStore(GroupServerWorkerStore):
"remove_user_from_group", _remove_user_from_group_txn
)
def change_user_admin_in_group(self, group_id, user_id, is_admin):
return self.db_pool.simple_update(
table="group_users",
keyvalues={"group_id": group_id, "user_id": user_id},
updatevalues={"is_admin": is_admin},
desc="change_user_admin_in_group"
)
def add_room_to_group(self, group_id, room_id, is_public):
return self.db_pool.simple_insert(
table="group_rooms",
+19 -17
View File
@@ -30,7 +30,7 @@ from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.receipts import ReceiptsWorkerStore
from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
from synapse.storage.push_rule import InconsistentRuleException, RuleNotFoundException
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.storage.util.id_generators import ChainedIdGenerator
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
@@ -82,9 +82,9 @@ class PushRulesWorkerStore(
super(PushRulesWorkerStore, self).__init__(database, db_conn, hs)
if hs.config.worker.worker_app is None:
self._push_rules_stream_id_gen = StreamIdGenerator(
db_conn, "push_rules_stream", "stream_id"
) # type: Union[StreamIdGenerator, SlavedIdTracker]
self._push_rules_stream_id_gen = ChainedIdGenerator(
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
) # type: Union[ChainedIdGenerator, SlavedIdTracker]
else:
self._push_rules_stream_id_gen = SlavedIdTracker(
db_conn, "push_rules_stream", "stream_id"
@@ -338,9 +338,8 @@ class PushRuleStore(PushRulesWorkerStore):
) -> None:
conditions_json = json_encoder.encode(conditions)
actions_json = json_encoder.encode(actions)
with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
if before or after:
await self.db_pool.runInteraction(
"_add_push_rule_relative_txn",
@@ -560,9 +559,8 @@ class PushRuleStore(PushRulesWorkerStore):
txn, stream_id, event_stream_ordering, user_id, rule_id, op="DELETE"
)
with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
"delete_push_rule",
delete_push_rule_txn,
@@ -571,9 +569,8 @@ class PushRuleStore(PushRulesWorkerStore):
)
async def set_push_rule_enabled(self, user_id, rule_id, enabled) -> None:
with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
"_set_push_rule_enabled_txn",
self._set_push_rule_enabled_txn,
@@ -646,9 +643,8 @@ class PushRuleStore(PushRulesWorkerStore):
data={"actions": actions_json},
)
with self._push_rules_stream_id_gen.get_next() as stream_id:
event_stream_ordering = self._stream_id_gen.get_current_token()
with self._push_rules_stream_id_gen.get_next() as ids:
stream_id, event_stream_ordering = ids
await self.db_pool.runInteraction(
"set_push_rule_actions",
set_push_rule_actions_txn,
@@ -677,5 +673,11 @@ class PushRuleStore(PushRulesWorkerStore):
self.push_rules_stream_cache.entity_has_changed, user_id, stream_id
)
def get_max_push_rules_stream_id(self):
def get_push_rules_stream_token(self):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_current_token()
def get_max_push_rules_stream_id(self):
return self.get_push_rules_stream_token()[0]
+1
View File
@@ -379,6 +379,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
limit: int = 0,
order: str = "DESC",
) -> Tuple[List[EventBase], str]:
"""Get new room events in stream ordering since `from_key`.
Args:
+65 -17
View File
@@ -16,7 +16,7 @@
import contextlib
import threading
from collections import deque
from typing import Dict, Set
from typing import Dict, Set, Tuple
from typing_extensions import Deque
@@ -158,13 +158,63 @@ class StreamIdGenerator(object):
return self._current
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
For streams with single writers this is equivalent to
`get_current_token`.
class ChainedIdGenerator(object):
"""Used to generate new stream ids where the stream must be kept in sync
with another stream. It generates pairs of IDs, the first element is an
integer ID for this stream, the second element is the ID for the stream
that this stream needs to be kept in sync with."""
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._table = table
self._lock = threading.Lock()
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque() # type: Deque[Tuple[int, int]]
def get_next(self):
"""
return self.get_current_token()
Usage:
with stream_id_gen.get_next() as (stream_id, chained_id):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@contextlib.contextmanager
def manager():
try:
yield (next_id, chained_id)
finally:
with self._lock:
self._unfinished_ids.remove((next_id, chained_id))
return manager()
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
stream_id, chained_id = self._unfinished_ids[0]
return stream_id - 1, chained_id
return self._current_max, self.chained_generator.get_current_token()
def advance(self, token: int):
"""Stub implementation for advancing the token when receiving updates
over replication; raises an exception as this instance should be the
only source of updates.
"""
raise Exception(
"Attempted to advance token on source for table %r", self._table
)
class MultiWriterIdGenerator:
@@ -248,7 +298,7 @@ class MultiWriterIdGenerator:
# Assert the fetched ID is actually greater than what we currently
# believe the ID to be. If not, then the sequence and table have got
# out of sync somehow.
assert self.get_current_token_for_writer(self._instance_name) < next_id
assert self.get_current_token() < next_id
with self._lock:
self._unfinished_ids.add(next_id)
@@ -294,18 +344,16 @@ class MultiWriterIdGenerator:
curr = self._current_positions.get(self._instance_name, 0)
self._current_positions[self._instance_name] = max(curr, next_id)
def get_current_token(self) -> int:
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
def get_current_token(self, instance_name: str = None) -> int:
"""Gets the current position of a named writer (defaults to current
instance).
Returns 0 if we don't have a position for the named writer (likely due
to it being a new writer).
"""
# Currently we don't support this operation, as it's not obvious how to
# condense the stream positions of multiple writers into a single int.
raise NotImplementedError()
def get_current_token_for_writer(self, instance_name: str) -> int:
"""Returns the position of the given writer.
"""
if instance_name is None:
instance_name = self._instance_name
with self._lock:
return self._current_positions.get(instance_name, 0)
+1 -1
View File
@@ -39,7 +39,7 @@ class EventSources(object):
self.store = hs.get_datastore()
def get_current_token(self) -> StreamToken:
push_rules_key = self.store.get_max_push_rules_stream_id()
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
groups_key = self.store.get_group_stream_token()
+2 -12
View File
@@ -25,18 +25,8 @@ from synapse.logging import context
logger = logging.getLogger(__name__)
def _reject_invalid_json(val):
"""Do not allow Infinity, -Infinity, or NaN values in JSON."""
raise json.JSONDecodeError("Invalid JSON value: '%s'" % val)
# Create a custom encoder to reduce the whitespace produced by JSON encoding and
# ensure that valid JSON is produced.
json_encoder = json.JSONEncoder(allow_nan=False, separators=(",", ":"))
# Create a custom decoder to reject Python extensions to JSON.
json_decoder = json.JSONDecoder(parse_constant=_reject_invalid_json)
# Create a custom encoder to reduce the whitespace produced by JSON encoding.
json_encoder = json.JSONEncoder(separators=(",", ":"))
def unwrapFirstError(failure):
+46 -8
View File
@@ -285,9 +285,16 @@ class Cache(object):
class _CacheDescriptorBase(object):
def __init__(self, orig: _CachedFunction, num_args, cache_context=False):
def __init__(
self, orig: _CachedFunction, num_args, inlineCallbacks, cache_context=False
):
self.orig = orig
if inlineCallbacks:
self.function_to_call = defer.inlineCallbacks(orig)
else:
self.function_to_call = orig
arg_spec = inspect.getfullargspec(orig)
all_args = arg_spec.args
@@ -357,7 +364,7 @@ class CacheDescriptor(_CacheDescriptorBase):
invalidated) by adding a special "cache_context" argument to the function
and passing that as a kwarg to all caches called. For example::
@cached(cache_context=True)
@cachedInlineCallbacks(cache_context=True)
def foo(self, key, cache_context):
r1 = yield self.bar1(key, on_invalidate=cache_context.invalidate)
r2 = yield self.bar2(key, on_invalidate=cache_context.invalidate)
@@ -375,11 +382,17 @@ class CacheDescriptor(_CacheDescriptorBase):
max_entries=1000,
num_args=None,
tree=False,
inlineCallbacks=False,
cache_context=False,
iterable=False,
):
super().__init__(orig, num_args=num_args, cache_context=cache_context)
super(CacheDescriptor, self).__init__(
orig,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
cache_context=cache_context,
)
self.max_entries = max_entries
self.tree = tree
@@ -452,7 +465,9 @@ class CacheDescriptor(_CacheDescriptorBase):
observer = defer.succeed(cached_result_d)
except KeyError:
ret = defer.maybeDeferred(preserve_fn(self.orig), obj, *args, **kwargs)
ret = defer.maybeDeferred(
preserve_fn(self.function_to_call), obj, *args, **kwargs
)
def onErr(f):
cache.invalidate(cache_key)
@@ -495,7 +510,9 @@ class CacheListDescriptor(_CacheDescriptorBase):
of results.
"""
def __init__(self, orig, cached_method_name, list_name, num_args=None):
def __init__(
self, orig, cached_method_name, list_name, num_args=None, inlineCallbacks=False
):
"""
Args:
orig (function)
@@ -504,8 +521,12 @@ class CacheListDescriptor(_CacheDescriptorBase):
num_args (int): number of positional arguments (excluding ``self``,
but including list_name) to use as cache keys. Defaults to all
named args of the function.
inlineCallbacks (bool): Whether orig is a generator that should
be wrapped by defer.inlineCallbacks
"""
super().__init__(orig, num_args=num_args)
super(CacheListDescriptor, self).__init__(
orig, num_args=num_args, inlineCallbacks=inlineCallbacks
)
self.list_name = list_name
@@ -610,7 +631,7 @@ class CacheListDescriptor(_CacheDescriptorBase):
cached_defers.append(
defer.maybeDeferred(
preserve_fn(self.orig), **args_to_call
preserve_fn(self.function_to_call), **args_to_call
).addCallbacks(complete_all, errback)
)
@@ -674,7 +695,21 @@ def cached(
)
def cachedList(cached_method_name, list_name, num_args=None):
def cachedInlineCallbacks(
max_entries=1000, num_args=None, tree=False, cache_context=False, iterable=False
):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
iterable=iterable,
)
def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=False):
"""Creates a descriptor that wraps a function in a `CacheListDescriptor`.
Used to do batch lookups for an already created cache. A single argument
@@ -690,6 +725,8 @@ def cachedList(cached_method_name, list_name, num_args=None):
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
(including list_name). Defaults to all named parameters.
inlineCallbacks (bool): Should the function be wrapped in an
`defer.inlineCallbacks`?
Example:
@@ -707,4 +744,5 @@ def cachedList(cached_method_name, list_name, num_args=None):
cached_method_name=cached_method_name,
list_name=list_name,
num_args=num_args,
inlineCallbacks=inlineCallbacks,
)
+11 -5
View File
@@ -62,7 +62,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
self.render(request)
if i == 5:
@@ -75,13 +76,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
self.reactor.advance(retry_after_ms / 1000.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit" + str(i)},
"password": "monkey",
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -109,7 +111,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
self.render(request)
if i == 5:
@@ -129,6 +132,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "monkey",
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
@@ -156,7 +160,8 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request, channel = self.make_request(b"POST", LOGIN_URL, params)
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, request_data)
self.render(request)
if i == 5:
@@ -169,13 +174,14 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
# than 1min.
self.assertTrue(retry_after_ms < 6000)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
self.reactor.advance(retry_after_ms / 1000.0)
params = {
"type": "m.login.password",
"identifier": {"type": "m.id.user", "user": "kermit"},
"password": "notamonkey",
}
request_data = json.dumps(params)
request, channel = self.make_request(b"POST", LOGIN_URL, params)
self.render(request)
+2 -2
View File
@@ -160,7 +160,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
self.reactor.advance(retry_after_ms / 1000.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -186,7 +186,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
else:
self.assertEquals(channel.result["code"], b"200", channel.result)
self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
self.reactor.advance(retry_after_ms / 1000.0)
request, channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
self.render(request)
@@ -104,7 +104,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=make_awaitable({"123": mock_event})
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
# Would be better to check the content, but once == remove blocking event
@@ -122,7 +122,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=make_awaitable({"123": mock_event})
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
@@ -217,7 +217,7 @@ class TestResourceLimitsServerNotices(unittest.HomeserverTestCase):
type=EventTypes.Message, content={"msgtype": ServerNoticeMsgType}
)
self._rlsn._store.get_events = Mock(
return_value=make_awaitable({"123": mock_event})
return_value=defer.succeed({"123": mock_event})
)
self.get_success(self._rlsn.maybe_send_server_notice_to_user(self.user_id))
+1 -2
View File
@@ -31,7 +31,6 @@ from synapse.storage.databases.main.appservice import (
)
from tests import unittest
from tests.test_utils import make_awaitable
from tests.utils import setup_test_homeserver
@@ -358,7 +357,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
other_events = [Mock(event_id="e5"), Mock(event_id="e6")]
# we aren't testing store._base stuff here, so mock this out
self.store.get_events_as_list = Mock(return_value=make_awaitable(events))
self.store.get_events_as_list = Mock(return_value=defer.succeed(events))
yield self._insert_txn(self.as_list[1]["id"], 9, other_events)
yield self._insert_txn(service.id, 10, events)
+1 -2
View File
@@ -353,7 +353,6 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
self.event_creator_handler._rooms_to_exclude_from_dummy_event_insertion[
"3"
] = 300000
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
# All entries within time frame
self.assertEqual(
@@ -363,7 +362,7 @@ class CleanupExtremDummyEventsTestCase(HomeserverTestCase):
3,
)
# Oldest room to expire
self.pump(1.01)
self.pump(1)
self.event_creator_handler._expire_rooms_to_exclude_from_dummy_event_insertion()
self.assertEqual(
len(
+8 -8
View File
@@ -88,7 +88,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.assertEqual(id_gen.get_current_token("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -98,12 +98,12 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.assertEqual(id_gen.get_current_token("master"), 7)
self.get_success(_get_next_async())
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
self.assertEqual(id_gen.get_current_token("master"), 8)
def test_multi_instance(self):
"""Test that reads and writes from multiple processes are handled
@@ -116,8 +116,8 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
second_id_gen = self._create_id_generator("second")
self.assertEqual(first_id_gen.get_positions(), {"first": 3, "second": 7})
self.assertEqual(first_id_gen.get_current_token_for_writer("first"), 3)
self.assertEqual(first_id_gen.get_current_token_for_writer("second"), 7)
self.assertEqual(first_id_gen.get_current_token("first"), 3)
self.assertEqual(first_id_gen.get_current_token("second"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -166,7 +166,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
id_gen = self._create_id_generator()
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.assertEqual(id_gen.get_current_token("master"), 7)
# Try allocating a new ID gen and check that we only see position
# advanced after we leave the context manager.
@@ -176,9 +176,9 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
self.assertEqual(stream_id, 8)
self.assertEqual(id_gen.get_positions(), {"master": 7})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 7)
self.assertEqual(id_gen.get_current_token("master"), 7)
self.get_success(self.db_pool.runInteraction("test", _get_next_txn))
self.assertEqual(id_gen.get_positions(), {"master": 8})
self.assertEqual(id_gen.get_current_token_for_writer("master"), 8)
self.assertEqual(id_gen.get_current_token("master"), 8)
+6 -6
View File
@@ -366,11 +366,11 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert current_context().request == "c1"
# we want this to behave like an asynchronous function
await run_on_reactor()
yield run_on_reactor()
assert current_context().request == "c1"
return self.mock(args1, arg2)
@@ -416,10 +416,10 @@ class CachedListDescriptorTestCase(unittest.TestCase):
def fn(self, arg1, arg2):
pass
@descriptors.cachedList("fn", "args1")
async def list_fn(self, args1, arg2):
@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
await run_on_reactor()
yield run_on_reactor()
return self.mock(args1, arg2)
obj = Cls()