
Compare commits

...

15 Commits

Author  SHA1  Message  Date
Erik Johnston  b143e6c688  Make it simpler  2020-01-22 15:39:21 +00:00
Erik Johnston  83ae89a7bc  Refactor HomeServer object to work with type hints  2020-01-22 10:37:35 +00:00
Erik Johnston  0e68760078  Add a DeltaState to track changes to be made to current state (#6716)  2020-01-20 18:07:20 +00:00
Erik Johnston  b0a66ab83c  Fixup synapse.rest to pass mypy (#6732)  2020-01-20 17:38:21 +00:00
Erik Johnston  74b74462f1  Fix /events/:event_id deprecated API. (#6731)  2020-01-20 17:38:09 +00:00
Erik Johnston  0f6e525be3  Fixup synapse.api to pass mypy (#6733)  2020-01-20 17:34:13 +00:00
Erik Johnston  ceecedc68b  Fix changing password via user admin API. (#6730)  2020-01-20 17:23:59 +00:00
Andrew Morgan  e9e066055f  Fix empty account_validity config block (#6747)  2020-01-20 16:21:59 +00:00
Andrew Morgan  351fdfede6  Update changelog.d/6747.bugfix (Co-Authored-By: Erik Johnston <erik@matrix.org>)  2020-01-20 15:58:44 +00:00
Erik Johnston  2f23eb27b3  Revert "Newsfile" (This reverts commit 11c23af465.)  2020-01-20 15:12:58 +00:00
Erik Johnston  11c23af465  Newsfile  2020-01-20 15:11:38 +00:00
Andrew Morgan  026f4bdf3c  Add changelog  2020-01-20 14:12:21 +00:00
Andrew Morgan  198d52da3a  Fix empty account_validity config block  2020-01-20 14:05:29 +00:00
Brendan Abolivier  a17f64361c  Add more logging around message retention policies support (#6717) (So we can debug issues like #6683 more easily)  2020-01-17 20:51:44 +00:00
Erik Johnston  5909751936  Fix up changelog  2020-01-17 15:13:27 +00:00
34 changed files with 459 additions and 446 deletions

changelog.d/6716.misc Normal file
View File

@@ -0,0 +1 @@
Add a `DeltaState` to track changes to be made to current state during event persistence.

changelog.d/6717.misc Normal file
View File

@@ -0,0 +1 @@
Add more logging around message retention policies support.

View File

@@ -1 +0,0 @@
Fix a bug causing `ValueError: unsupported format character ''' (0x27) at index 312` error when running the schema 57 upgrade script.

changelog.d/6728.misc Normal file
View File

@@ -0,0 +1 @@
Add `local_current_membership` table for tracking local user membership state in rooms.

changelog.d/6730.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix changing password via user admin API.

changelog.d/6731.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix `/events/:event_id` deprecated API.

changelog.d/6732.misc Normal file
View File

@@ -0,0 +1 @@
Fixup `synapse.rest` to pass mypy.

changelog.d/6733.misc Normal file
View File

@@ -0,0 +1 @@
Fixup synapse.api to pass mypy.

changelog.d/6747.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug when setting `account_validity` to an empty block in the config. Thanks to @Sorunome for reporting.

View File

@@ -7,6 +7,9 @@ show_error_codes = True
show_traceback = True
mypy_path = stubs
[mypy-pymacaroons.*]
ignore_missing_imports = True
[mypy-zope]
ignore_missing_imports = True
@@ -63,3 +66,12 @@ ignore_missing_imports = True
[mypy-sentry_sdk]
ignore_missing_imports = True
[mypy-PIL.*]
ignore_missing_imports = True
[mypy-lxml]
ignore_missing_imports = True
[mypy-jwt.*]
ignore_missing_imports = True

View File

@@ -15,6 +15,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from six import text_type
import jsonschema
@@ -293,7 +295,7 @@ class Filter(object):
room_id = None
ev_type = "m.presence"
contains_url = False
labels = []
labels = [] # type: List[str]
else:
sender = event.get("sender", None)
if not sender:
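
The `labels = []  # type: List[str]` change uses mypy's comment-style annotations rather than PEP 526 inline annotations; the comment form works on older interpreters and is the convention used throughout these commits. A minimal, standalone illustration (not Synapse code):

from typing import List

# Read by mypy exactly like "labels: List[str] = []".
labels = []  # type: List[str]

labels.append("#fun")  # fine
labels.append(123)     # rejected by mypy: int is not str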

View File

@@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from collections import OrderedDict
from typing import Any, Optional, Tuple
from synapse.api.errors import LimitExceededError
@@ -23,7 +24,9 @@ class Ratelimiter(object):
"""
def __init__(self):
self.message_counts = collections.OrderedDict()
self.message_counts = (
OrderedDict()
) # type: OrderedDict[Any, Tuple[float, int, Optional[float]]]
def can_do_action(self, key, time_now_s, rate_hz, burst_count, update=True):
"""Can the entity (e.g. user or IP address) perform the action?

View File

@@ -29,6 +29,7 @@ class AccountValidityConfig(Config):
def __init__(self, config, synapse_config):
if config is None:
return
super(AccountValidityConfig, self).__init__()
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -93,7 +94,7 @@ class RegistrationConfig(Config):
)
self.account_validity = AccountValidityConfig(
config.get("account_validity", {}), config
config.get("account_validity") or {}, config
)
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
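
The `or {}` substitution is the heart of the `account_validity` fix (#6747): a YAML key that is present but has an empty body parses to None, so `config.get("account_validity", {})` returns None rather than the default, and AccountValidityConfig is handed None instead of a dict. A rough illustration, assuming a PyYAML-style loader:

import yaml

config = yaml.safe_load("account_validity:\n")  # key present, body empty
print(config)                                   # {'account_validity': None}
print(config.get("account_validity", {}))       # None -- the default is not used
print(config.get("account_validity") or {})     # {}   -- what the fix relies on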

View File

@@ -294,6 +294,14 @@ class ServerConfig(Config):
self.retention_default_min_lifetime = None
self.retention_default_max_lifetime = None
if self.retention_enabled:
logger.info(
"Message retention policies support enabled with the following default"
" policy: min_lifetime = %s ; max_lifetime = %s",
self.retention_default_min_lifetime,
self.retention_default_max_lifetime,
)
self.retention_allowed_lifetime_min = retention_config.get(
"allowed_lifetime_min"
)

View File

@@ -634,7 +634,7 @@ def get_public_keys(invite_event):
return public_keys
def auth_types_for_event(event) -> Set[Tuple[str]]:
def auth_types_for_event(event) -> Set[Tuple[str, str]]:
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
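
The only change here is the type hint: `Tuple[str]` denotes a one-element tuple, whereas the (event type, state key) pairs this function actually returns need `Tuple[str, str]`. For example:

from typing import Set, Tuple

auth_types = set()  # type: Set[Tuple[str, str]]
auth_types.add(("m.room.member", "@alice:example.com"))  # (EventType, StateKey)
# Under the old Set[Tuple[str]] hint, mypy would reject this two-element tuple.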

View File

@@ -88,6 +88,8 @@ class PaginationHandler(object):
if hs.config.retention_enabled:
# Run the purge jobs described in the configuration file.
for job in hs.config.retention_purge_jobs:
logger.info("Setting up purge job with config: %s", job)
self.clock.looping_call(
run_as_background_process,
job["interval"],
@@ -130,11 +132,22 @@ class PaginationHandler(object):
else:
include_null = False
logger.info(
"[purge] Running purge job for %d < max_lifetime <= %d (include NULLs = %s)",
min_ms,
max_ms,
include_null,
)
rooms = yield self.store.get_rooms_for_retention_period_in_range(
min_ms, max_ms, include_null
)
logger.debug("[purge] Rooms to purge: %s", rooms)
for room_id, retention_policy in iteritems(rooms):
logger.info("[purge] Attempting to purge messages in room %s", room_id)
if room_id in self._purges_in_progress_by_room:
logger.warning(
"[purge] not purging room %s as there's an ongoing purge running"

View File

@@ -193,8 +193,8 @@ class UserRestServletV2(RestServlet):
raise SynapseError(400, "Invalid password")
else:
new_password = body["password"]
await self._set_password_handler.set_password(
target_user, new_password, requester
await self.set_password_handler.set_password(
target_user.to_string(), new_password, requester
)
if "deactivated" in body:
@@ -338,21 +338,22 @@ class UserRegisterServlet(RestServlet):
got_mac = body["mac"]
want_mac = hmac.new(
want_mac_builder = hmac.new(
key=self.hs.config.registration_shared_secret.encode(),
digestmod=hashlib.sha1,
)
want_mac.update(nonce.encode("utf8"))
want_mac.update(b"\x00")
want_mac.update(username)
want_mac.update(b"\x00")
want_mac.update(password)
want_mac.update(b"\x00")
want_mac.update(b"admin" if admin else b"notadmin")
want_mac_builder.update(nonce.encode("utf8"))
want_mac_builder.update(b"\x00")
want_mac_builder.update(username)
want_mac_builder.update(b"\x00")
want_mac_builder.update(password)
want_mac_builder.update(b"\x00")
want_mac_builder.update(b"admin" if admin else b"notadmin")
if user_type:
want_mac.update(b"\x00")
want_mac.update(user_type.encode("utf8"))
want_mac = want_mac.hexdigest()
want_mac_builder.update(b"\x00")
want_mac_builder.update(user_type.encode("utf8"))
want_mac = want_mac_builder.hexdigest()
if not hmac.compare_digest(want_mac.encode("ascii"), got_mac.encode("ascii")):
raise SynapseError(403, "HMAC incorrect")
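
The `want_mac` -> `want_mac_builder` rename exists only so mypy sees each name bound to a single type (the HMAC object, with the final hex digest kept separate); the MAC itself is unchanged. For reference, the same calculation done client-side against the shared registration secret looks roughly like this (hypothetical helper and values; the optional user_type field is omitted):

import hashlib
import hmac

def registration_mac(shared_secret: bytes, nonce: str,
                     username: bytes, password: bytes, admin: bool = False) -> str:
    mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
    mac.update(nonce.encode("utf8"))
    mac.update(b"\x00")
    mac.update(username)
    mac.update(b"\x00")
    mac.update(password)
    mac.update(b"\x00")
    mac.update(b"admin" if admin else b"notadmin")
    return mac.hexdigest()

# e.g. registration_mac(b"shared", "abcd1234", b"alice", b"wonderland")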

View File

@@ -70,7 +70,6 @@ class EventStreamRestServlet(RestServlet):
return 200, {}
# TODO: Unit test gets, with and without auth, with different kinds of events.
class EventRestServlet(RestServlet):
PATTERNS = client_patterns("/events/(?P<event_id>[^/]*)$", v1=True)
@@ -78,6 +77,7 @@ class EventRestServlet(RestServlet):
super(EventRestServlet, self).__init__()
self.clock = hs.get_clock()
self.event_handler = hs.get_event_handler()
self.auth = hs.get_auth()
self._event_serializer = hs.get_event_client_serializer()
async def on_GET(self, request, event_id):

View File

@@ -514,7 +514,7 @@ class CasTicketServlet(RestServlet):
if user is None:
raise Exception("CAS response does not contain user")
except Exception:
logger.error("Error parsing CAS response", exc_info=1)
logger.exception("Error parsing CAS response")
raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED)
if not success:
raise LoginError(

View File

@@ -16,6 +16,7 @@
""" This module contains REST servlets to do with rooms: /rooms/<paths> """
import logging
from typing import List, Optional
from six.moves.urllib import parse as urlparse
@@ -207,7 +208,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
requester, event_dict, txn_id=txn_id
)
ret = {}
ret = {} # type: dict
if event:
set_tag("event_id", event.event_id)
ret = {"event_id": event.event_id}
@@ -285,7 +286,7 @@ class JoinRoomAliasServlet(TransactionRestServlet):
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
@@ -375,7 +376,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
server = parse_string(request, "server", default=None)
content = parse_json_object_from_request(request)
limit = int(content.get("limit", 100))
limit = int(content.get("limit", 100)) # type: Optional[int]
since_token = content.get("since", None)
search_filter = content.get("filter", None)
@@ -504,11 +505,16 @@ class RoomMessageListRestServlet(RestServlet):
filter_bytes = parse_string(request, b"filter", encoding=None)
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes.decode("UTF-8"))
event_filter = Filter(json.loads(filter_json))
if event_filter.filter_json.get("event_format", "client") == "federation":
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
if (
event_filter
and event_filter.filter_json.get("event_format", "client")
== "federation"
):
as_client_event = False
else:
event_filter = None
msgs = await self.pagination_handler.get_messages(
room_id=room_id,
requester=requester,
@@ -611,7 +617,7 @@ class RoomEventContextServlet(RestServlet):
filter_bytes = parse_string(request, "filter")
if filter_bytes:
filter_json = urlparse.unquote(filter_bytes)
event_filter = Filter(json.loads(filter_json))
event_filter = Filter(json.loads(filter_json)) # type: Optional[Filter]
else:
event_filter = None
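
The pattern used in these hunks puts the type comment on the branch where the variable is first bound (e.g. `# type: Optional[Filter]`), which is how mypy learns that the later `= None` assignments are legitimate. A small sketch of the same idea:

from typing import Optional

def parse_limit(raw):
    # Annotate the first binding so mypy accepts the None branch below.
    if raw is not None:
        limit = int(raw)  # type: Optional[int]
    else:
        limit = None
    return limit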

View File

@@ -21,6 +21,7 @@ from typing import List, Union
from six import string_types
import synapse
import synapse.api.auth
import synapse.types
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -405,7 +406,7 @@ class RegisterRestServlet(RestServlet):
return ret
elif kind != b"user":
raise UnrecognizedRequestError(
"Do not understand membership kind: %s" % (kind,)
"Do not understand membership kind: %s" % (kind.decode("utf8"),)
)
# we do basic sanity checks here because the auth layer will store these
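
The `kind.decode("utf8")` change avoids the Python 3 quirk where interpolating bytes into a str uses the bytes repr, so the error message would otherwise read `b'user'` rather than `user`. For example:

kind = b"guest"
print("Do not understand membership kind: %s" % (kind,))                 # ... b'guest'
print("Do not understand membership kind: %s" % (kind.decode("utf8"),))  # ... guest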

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import Tuple
from synapse.http import servlet
from synapse.http.servlet import parse_json_object_from_request
@@ -60,7 +61,7 @@ class SendToDeviceRestServlet(servlet.RestServlet):
sender_user_id, message_type, content["messages"]
)
response = (200, {})
response = (200, {}) # type: Tuple[int, dict]
return response

View File

@@ -13,6 +13,7 @@
# limitations under the License.
import logging
from typing import Dict, Set
from canonicaljson import encode_canonical_json, json
from signedjson.sign import sign_json
@@ -103,7 +104,7 @@ class RemoteKey(DirectServeResource):
async def _async_render_GET(self, request):
if len(request.postpath) == 1:
(server,) = request.postpath
query = {server.decode("ascii"): {}}
query = {server.decode("ascii"): {}} # type: dict
elif len(request.postpath) == 2:
server, key_id = request.postpath
minimum_valid_until_ts = parse_integer(request, "minimum_valid_until_ts")
@@ -148,7 +149,7 @@ class RemoteKey(DirectServeResource):
time_now_ms = self.clock.time_msec()
cache_misses = dict()
cache_misses = dict() # type: Dict[str, Set[str]]
for (server_name, key_id, from_server), results in cached.items():
results = [(result["ts_added_ms"], result) for result in results]

View File

@@ -18,6 +18,7 @@ import errno
import logging
import os
import shutil
from typing import Dict, Tuple
from six import iteritems
@@ -605,7 +606,7 @@ class MediaRepository(object):
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails = {}
thumbnails = {} # type: Dict[Tuple[int, int, str], str]
for r_width, r_height, r_method, r_type in requirements:
if r_method == "crop":
thumbnails.setdefault((r_width, r_height, r_type), r_method)

View File

@@ -23,6 +23,7 @@ import re
import shutil
import sys
import traceback
from typing import Dict, Optional
import six
from six import string_types
@@ -237,8 +238,8 @@ class PreviewUrlResource(DirectServeResource):
# If we don't find a match, we'll look at the HTTP Content-Type, and
# if that doesn't exist, we'll fall back to UTF-8.
if not encoding:
match = _content_type_match.match(media_info["media_type"])
encoding = match.group(1) if match else "utf-8"
content_match = _content_type_match.match(media_info["media_type"])
encoding = content_match.group(1) if content_match else "utf-8"
og = decode_and_calc_og(body, media_info["uri"], encoding)
@@ -518,7 +519,7 @@ def _calc_og(tree, media_uri):
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
og = {} # type: Dict[str, Optional[str]]
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss

View File

@@ -296,8 +296,8 @@ class ThumbnailResource(DirectServeResource):
d_h = desired_height
if desired_method.lower() == "crop":
info_list = []
info_list2 = []
crop_info_list = []
crop_info_list2 = []
for info in thumbnail_infos:
t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"]
@@ -309,7 +309,7 @@ class ThumbnailResource(DirectServeResource):
type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h:
info_list.append(
crop_info_list.append(
(
aspect_quality,
min_quality,
@@ -320,7 +320,7 @@ class ThumbnailResource(DirectServeResource):
)
)
else:
info_list2.append(
crop_info_list2.append(
(
aspect_quality,
min_quality,
@@ -330,10 +330,10 @@ class ThumbnailResource(DirectServeResource):
info,
)
)
if info_list:
return min(info_list)[-1]
if crop_info_list:
return min(crop_info_list2)[-1]
else:
return min(info_list2)[-1]
return min(crop_info_list2)[-1]
else:
info_list = []
info_list2 = []

View File

@@ -24,10 +24,15 @@
import abc
import logging
import os
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, TypeVar, cast
from twisted.internet import tcp
from twisted.mail.smtp import sendmail
from twisted.web.client import BrowserLikePolicyForHTTPS
from twisted.web.iweb import IPolicyForHTTPS
import synapse
from synapse.api.auth import Auth
from synapse.api.filtering import Filtering
from synapse.api.ratelimiting import Ratelimiter
@@ -104,6 +109,43 @@ from synapse.util.distributor import Distributor
logger = logging.getLogger(__name__)
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
def builder(f: F) -> F:
"""Decorator to wrap a HomeServer method to cache result and detect
cyclical dependencies.
"""
if not f.__name__.startswith("get_"):
raise Exception("Function must be named `get_*`")
depname = f.__name__[len("get_") :] # type: str
@wraps(f)
def _get(self):
try:
return getattr(self, depname)
except AttributeError:
pass
# Prevent cyclic dependencies from deadlocking
if depname in self._building:
raise ValueError("Cyclic dependency while building %s" % (depname,))
try:
self._building[depname] = True
dep = f(self)
finally:
self._building.pop(depname, None)
setattr(self, self.depname, dep)
return dep
return cast(F, _get)
class HomeServer(object):
"""A basic homeserver object without lazy component builders.
@@ -111,17 +153,6 @@ class HomeServer(object):
constructor arguments, or the relevant methods overriding to create them.
Typically this would only be used for unit tests.
For every dependency in the DEPENDENCIES list below, this class creates one
method,
def get_DEPENDENCY(self)
which returns the value of that dependency. If no value has yet been set
nor was provided to the constructor, it will attempt to call a lazy builder
method called
def build_DEPENDENCY(self)
which must be implemented by the subclass. This code may call any of the
required "get" methods on the instance to obtain the sub-dependencies that
one requires.
Attributes:
config (synapse.config.homeserver.HomeserverConfig):
_listening_services (list[twisted.internet.tcp.Port]): TCP ports that
@@ -130,77 +161,6 @@ class HomeServer(object):
__metaclass__ = abc.ABCMeta
DEPENDENCIES = [
"http_client",
"federation_client",
"federation_server",
"handlers",
"auth",
"room_creation_handler",
"state_handler",
"state_resolution_handler",
"presence_handler",
"sync_handler",
"typing_handler",
"room_list_handler",
"acme_handler",
"auth_handler",
"device_handler",
"stats_handler",
"e2e_keys_handler",
"e2e_room_keys_handler",
"event_handler",
"event_stream_handler",
"initial_sync_handler",
"application_service_api",
"application_service_scheduler",
"application_service_handler",
"device_message_handler",
"profile_handler",
"event_creation_handler",
"deactivate_account_handler",
"set_password_handler",
"notifier",
"event_sources",
"keyring",
"pusherpool",
"event_builder_factory",
"filtering",
"http_client_context_factory",
"simple_http_client",
"proxied_http_client",
"media_repository",
"media_repository_resource",
"federation_transport_client",
"federation_sender",
"receipts_handler",
"macaroon_generator",
"tcp_replication",
"read_marker_handler",
"action_generator",
"user_directory_handler",
"groups_local_handler",
"groups_server_handler",
"groups_attestation_signing",
"groups_attestation_renewer",
"secrets",
"spam_checker",
"third_party_event_rules",
"room_member_handler",
"federation_registry",
"server_notices_manager",
"server_notices_sender",
"message_handler",
"pagination_handler",
"room_context_handler",
"sendmail",
"registration_handler",
"account_validity_handler",
"saml_handler",
"event_client_serializer",
"storage",
]
REQUIRED_ON_MASTER_STARTUP = ["user_directory_handler", "stats_handler"]
# This is overridden in derived application classes
@@ -215,14 +175,16 @@ class HomeServer(object):
config: The full config for the homeserver.
"""
if not reactor:
from twisted.internet import reactor
from twisted import internet
reactor = internet.reactor
self._reactor = reactor
self.hostname = hostname
self.config = config
self._building = {}
self._listening_services = []
self.start_time = None
self._building = {} # type: Dict[str, bool]
self._listening_services = [] # type: List[tcp.Port]
self.start_time = None # type: Optional[int]
self.clock = Clock(reactor)
self.distributor = Distributor()
@@ -230,7 +192,7 @@ class HomeServer(object):
self.admin_redaction_ratelimiter = Ratelimiter()
self.registration_ratelimiter = Ratelimiter()
self.datastores = None
self.datastores = None # type: Optional[DataStores]
# Other kwargs are explicit dependencies
for depname in kwargs:
@@ -261,182 +223,231 @@ class HomeServer(object):
# X-Forwarded-For is handled by our custom request type.
return request.getClientIP()
def is_mine(self, domain_specific_string):
def is_mine(self, domain_specific_string) -> bool:
return domain_specific_string.domain == self.hostname
def is_mine_id(self, string):
def is_mine_id(self, string: str) -> bool:
return string.split(":", 1)[1] == self.hostname
def get_clock(self):
def get_clock(self) -> Clock:
return self.clock
def get_datastore(self):
if not self.datastores:
raise Exception("HomeServer has not been set up yet")
return self.datastores.main
def get_datastores(self):
def get_datastores(self) -> DataStores:
if not self.datastores:
raise Exception("HomeServer has not been set up yet")
return self.datastores
def get_config(self):
def get_config(self) -> HomeServerConfig:
return self.config
def get_distributor(self):
def get_distributor(self) -> Distributor:
return self.distributor
def get_ratelimiter(self):
def get_ratelimiter(self) -> Ratelimiter:
return self.ratelimiter
def get_registration_ratelimiter(self):
def get_registration_ratelimiter(self) -> Ratelimiter:
return self.registration_ratelimiter
def get_admin_redaction_ratelimiter(self):
def get_admin_redaction_ratelimiter(self) -> Ratelimiter:
return self.admin_redaction_ratelimiter
def build_federation_client(self):
@builder
def get_federation_client(self) -> FederationClient:
return FederationClient(self)
def build_federation_server(self):
@builder
def get_federation_server(self) -> FederationServer:
return FederationServer(self)
def build_handlers(self):
@builder
def get_handlers(self) -> Handlers:
return Handlers(self)
def build_notifier(self):
@builder
def get_notifier(self) -> Notifier:
return Notifier(self)
def build_auth(self):
@builder
def get_auth(self) -> Auth:
return Auth(self)
def build_http_client_context_factory(self):
@builder
def get_http_client_context_factory(self) -> IPolicyForHTTPS:
return (
InsecureInterceptableContextFactory()
if self.config.use_insecure_ssl_client_just_for_testing_do_not_use
else BrowserLikePolicyForHTTPS()
)
def build_simple_http_client(self):
@builder
def get_simple_http_client(self) -> SimpleHttpClient:
return SimpleHttpClient(self)
def build_proxied_http_client(self):
@builder
def get_proxied_http_client(self) -> SimpleHttpClient:
return SimpleHttpClient(
self,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
def build_room_creation_handler(self):
@builder
def get_room_creation_handler(self) -> RoomCreationHandler:
return RoomCreationHandler(self)
def build_sendmail(self):
@builder
def get_sendmail(self) -> sendmail:
return sendmail
def build_state_handler(self):
@builder
def get_state_handler(self) -> StateHandler:
return StateHandler(self)
def build_state_resolution_handler(self):
@builder
def get_state_resolution_handler(self) -> StateResolutionHandler:
return StateResolutionHandler(self)
def build_presence_handler(self):
@builder
def get_presence_handler(self) -> PresenceHandler:
return PresenceHandler(self)
def build_typing_handler(self):
@builder
def get_typing_handler(self) -> TypingHandler:
return TypingHandler(self)
def build_sync_handler(self):
@builder
def get_sync_handler(self) -> SyncHandler:
return SyncHandler(self)
def build_room_list_handler(self):
@builder
def get_room_list_handler(self) -> RoomListHandler:
return RoomListHandler(self)
def build_auth_handler(self):
@builder
def get_auth_handler(self) -> AuthHandler:
return AuthHandler(self)
def build_macaroon_generator(self):
@builder
def get_macaroon_generator(self) -> MacaroonGenerator:
return MacaroonGenerator(self)
def build_device_handler(self):
@builder
def get_device_handler(self) -> DeviceWorkerHandler:
if self.config.worker_app:
return DeviceWorkerHandler(self)
else:
return DeviceHandler(self)
def build_device_message_handler(self):
@builder
def get_device_message_handler(self) -> DeviceMessageHandler:
return DeviceMessageHandler(self)
def build_e2e_keys_handler(self):
@builder
def get_e2e_keys_handler(self) -> E2eKeysHandler:
return E2eKeysHandler(self)
def build_e2e_room_keys_handler(self):
@builder
def get_e2e_room_keys_handler(self) -> E2eRoomKeysHandler:
return E2eRoomKeysHandler(self)
def build_acme_handler(self):
@builder
def get_acme_handler(self) -> AcmeHandler:
return AcmeHandler(self)
def build_application_service_api(self):
@builder
def get_application_service_api(self) -> ApplicationServiceApi:
return ApplicationServiceApi(self)
def build_application_service_scheduler(self):
@builder
def get_application_service_scheduler(self) -> ApplicationServiceScheduler:
return ApplicationServiceScheduler(self)
def build_application_service_handler(self):
@builder
def get_application_service_handler(self) -> ApplicationServicesHandler:
return ApplicationServicesHandler(self)
def build_event_handler(self):
@builder
def get_event_handler(self) -> EventHandler:
return EventHandler(self)
def build_event_stream_handler(self):
@builder
def get_event_stream_handler(self) -> EventStreamHandler:
return EventStreamHandler(self)
def build_initial_sync_handler(self):
@builder
def get_initial_sync_handler(self) -> InitialSyncHandler:
return InitialSyncHandler(self)
def build_profile_handler(self):
@builder
def get_profile_handler(self):
if self.config.worker_app:
return BaseProfileHandler(self)
else:
return MasterProfileHandler(self)
def build_event_creation_handler(self):
@builder
def get_event_creation_handler(self) -> EventCreationHandler:
return EventCreationHandler(self)
def build_deactivate_account_handler(self):
@builder
def get_deactivate_account_handler(self) -> DeactivateAccountHandler:
return DeactivateAccountHandler(self)
def build_set_password_handler(self):
@builder
def get_set_password_handler(self) -> SetPasswordHandler:
return SetPasswordHandler(self)
def build_event_sources(self):
@builder
def get_event_sources(self) -> EventSources:
return EventSources(self)
def build_keyring(self):
@builder
def get_keyring(self) -> Keyring:
return Keyring(self)
def build_event_builder_factory(self):
@builder
def get_event_builder_factory(self) -> EventBuilderFactory:
return EventBuilderFactory(self)
def build_filtering(self):
@builder
def get_filtering(self) -> Filtering:
return Filtering(self)
def build_pusherpool(self):
@builder
def get_pusherpool(self) -> PusherPool:
return PusherPool(self)
def build_http_client(self):
@builder
def get_http_client(self) -> MatrixFederationHttpClient:
tls_client_options_factory = context_factory.ClientTLSOptionsFactory(
self.config
)
return MatrixFederationHttpClient(self, tls_client_options_factory)
def build_media_repository_resource(self):
@builder
def get_media_repository_resource(self) -> MediaRepositoryResource:
# build the media repo resource. This indirects through the HomeServer
# to ensure that we only have a single instance of
return MediaRepositoryResource(self)
def build_media_repository(self):
@builder
def get_media_repository(self) -> MediaRepository:
return MediaRepository(self)
def build_federation_transport_client(self):
@builder
def get_federation_transport_client(self) -> TransportLayerClient:
return TransportLayerClient(self)
def build_federation_sender(self):
@builder
def get_federation_sender(self):
if self.should_send_federation():
return FederationSender(self)
elif not self.config.worker_app:
@@ -444,135 +455,126 @@ class HomeServer(object):
else:
raise Exception("Workers cannot send federation traffic")
def build_receipts_handler(self):
@builder
def get_receipts_handler(self) -> ReceiptsHandler:
return ReceiptsHandler(self)
def build_read_marker_handler(self):
@builder
def get_read_marker_handler(self) -> ReadMarkerHandler:
return ReadMarkerHandler(self)
def build_tcp_replication(self):
@builder
def get_tcp_replication(self):
raise NotImplementedError()
def build_action_generator(self):
@builder
def get_action_generator(self) -> ActionGenerator:
return ActionGenerator(self)
def build_user_directory_handler(self):
@builder
def get_user_directory_handler(self) -> UserDirectoryHandler:
return UserDirectoryHandler(self)
def build_groups_local_handler(self):
@builder
def get_groups_local_handler(self) -> GroupsLocalHandler:
return GroupsLocalHandler(self)
def build_groups_server_handler(self):
@builder
def get_groups_server_handler(self) -> GroupsServerHandler:
return GroupsServerHandler(self)
def build_groups_attestation_signing(self):
@builder
def get_groups_attestation_signing(self) -> GroupAttestationSigning:
return GroupAttestationSigning(self)
def build_groups_attestation_renewer(self):
@builder
def get_groups_attestation_renewer(self) -> GroupAttestionRenewer:
return GroupAttestionRenewer(self)
def build_secrets(self):
@builder
def get_secrets(self):
return Secrets()
def build_stats_handler(self):
@builder
def get_stats_handler(self) -> StatsHandler:
return StatsHandler(self)
def build_spam_checker(self):
@builder
def get_spam_checker(self) -> SpamChecker:
return SpamChecker(self)
def build_third_party_event_rules(self):
@builder
def get_third_party_event_rules(self) -> ThirdPartyEventRules:
return ThirdPartyEventRules(self)
def build_room_member_handler(self):
@builder
def get_room_member_handler(self):
if self.config.worker_app:
return RoomMemberWorkerHandler(self)
return RoomMemberMasterHandler(self)
def build_federation_registry(self):
@builder
def get_federation_registry(self):
if self.config.worker_app:
return ReplicationFederationHandlerRegistry(self)
else:
return FederationHandlerRegistry()
def build_server_notices_manager(self):
@builder
def get_server_notices_manager(self) -> ServerNoticesManager:
if self.config.worker_app:
raise Exception("Workers cannot send server notices")
return ServerNoticesManager(self)
def build_server_notices_sender(self):
@builder
def get_server_notices_sender(self):
if self.config.worker_app:
return WorkerServerNoticesSender(self)
return ServerNoticesSender(self)
def build_message_handler(self):
@builder
def get_message_handler(self) -> MessageHandler:
return MessageHandler(self)
def build_pagination_handler(self):
@builder
def get_pagination_handler(self) -> PaginationHandler:
return PaginationHandler(self)
def build_room_context_handler(self):
@builder
def get_room_context_handler(self) -> RoomContextHandler:
return RoomContextHandler(self)
def build_registration_handler(self):
@builder
def get_registration_handler(self) -> RegistrationHandler:
return RegistrationHandler(self)
def build_account_validity_handler(self):
@builder
def get_account_validity_handler(self) -> AccountValidityHandler:
return AccountValidityHandler(self)
def build_saml_handler(self):
@builder
def get_saml_handler(self) -> "synapse.handlers.saml_handler.SamlHandler":
from synapse.handlers.saml_handler import SamlHandler
return SamlHandler(self)
def build_event_client_serializer(self):
@builder
def get_event_client_serializer(self) -> EventClientSerializer:
return EventClientSerializer(self)
def build_storage(self) -> Storage:
@builder
def get_storage(self) -> Storage:
if self.datastores is None:
raise Exception("HomeServer has not been set up yet")
return Storage(self, self.datastores)
def remove_pusher(self, app_id, push_key, user_id):
def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
def should_send_federation(self):
def should_send_federation(self) -> bool:
"Should this server be sending federation traffic directly?"
return self.config.send_federation and (
not self.config.worker_app
or self.config.worker_app == "synapse.app.federation_sender"
)
def _make_dependency_method(depname):
def _get(hs):
try:
return getattr(hs, depname)
except AttributeError:
pass
try:
builder = getattr(hs, "build_%s" % (depname))
except AttributeError:
builder = None
if builder:
# Prevent cyclic dependencies from deadlocking
if depname in hs._building:
raise ValueError("Cyclic dependency while building %s" % (depname,))
hs._building[depname] = 1
dep = builder()
setattr(hs, depname, dep)
del hs._building[depname]
return dep
raise NotImplementedError(
"%s has no %s nor a builder for it" % (type(hs).__name__, depname)
)
setattr(HomeServer, "get_%s" % (depname), _get)
# Build magic accessors for every dependency
for depname in HomeServer.DEPENDENCIES:
_make_dependency_method(depname)
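
The net effect of this refactor is that the DEPENDENCIES list and the runtime-generated `get_*` accessors are replaced by explicit `get_*` methods wrapped in `@builder`, so every dependency is declared once, carries a return type for mypy, and is still built lazily, cached on first use, and protected against cyclic construction. A stripped-down sketch of the same pattern outside Synapse (hypothetical classes, not the homeserver code itself):

from functools import wraps
from typing import Any, Callable, Dict, TypeVar, cast

F = TypeVar("F", bound=Callable[..., Any])

def builder(f: F) -> F:
    """Cache the result of a get_* method on the instance and detect cycles."""
    if not f.__name__.startswith("get_"):
        raise Exception("Function must be named `get_*`")
    name = f.__name__[len("get_"):]

    @wraps(f)
    def _get(self):
        if name in self.__dict__:      # already built: return the cached instance
            return self.__dict__[name]
        if name in self._building:     # e.g. get_a -> get_b -> get_a
            raise ValueError("Cyclic dependency while building %s" % (name,))
        self._building[name] = True
        try:
            dep = f(self)
        finally:
            self._building.pop(name, None)
        self.__dict__[name] = dep      # cache for subsequent calls
        return dep

    return cast(F, _get)

class Clock:
    pass

class Server:
    def __init__(self):
        self._building = {}  # type: Dict[str, bool]

    @builder
    def get_clock(self) -> Clock:
        return Clock()

server = Server()
assert server.get_clock() is server.get_clock()  # built once, then cached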

View File

@@ -1,99 +0,0 @@
import twisted.internet
import synapse.api.auth
import synapse.config.homeserver
import synapse.federation.sender
import synapse.federation.transaction_queue
import synapse.federation.transport.client
import synapse.handlers
import synapse.handlers.auth
import synapse.handlers.deactivate_account
import synapse.handlers.device
import synapse.handlers.e2e_keys
import synapse.handlers.message
import synapse.handlers.presence
import synapse.handlers.room
import synapse.handlers.room_member
import synapse.handlers.set_password
import synapse.http.client
import synapse.notifier
import synapse.rest.media.v1.media_repository
import synapse.server_notices.server_notices_manager
import synapse.server_notices.server_notices_sender
import synapse.state
import synapse.storage
class HomeServer(object):
@property
def config(self) -> synapse.config.homeserver.HomeServerConfig:
pass
def get_auth(self) -> synapse.api.auth.Auth:
pass
def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler:
pass
def get_datastore(self) -> synapse.storage.DataStore:
pass
def get_device_handler(self) -> synapse.handlers.device.DeviceHandler:
pass
def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler:
pass
def get_handlers(self) -> synapse.handlers.Handlers:
pass
def get_state_handler(self) -> synapse.state.StateHandler:
pass
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
pass
def get_simple_http_client(self) -> synapse.http.client.SimpleHttpClient:
"""Fetch an HTTP client implementation which doesn't do any blacklisting
or support any HTTP_PROXY settings"""
pass
def get_proxied_http_client(self) -> synapse.http.client.SimpleHttpClient:
"""Fetch an HTTP client implementation which doesn't do any blacklisting
but does support HTTP_PROXY settings"""
pass
def get_deactivate_account_handler(
self,
) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
pass
def get_room_creation_handler(self) -> synapse.handlers.room.RoomCreationHandler:
pass
def get_room_member_handler(self) -> synapse.handlers.room_member.RoomMemberHandler:
pass
def get_event_creation_handler(
self,
) -> synapse.handlers.message.EventCreationHandler:
pass
def get_set_password_handler(
self,
) -> synapse.handlers.set_password.SetPasswordHandler:
pass
def get_federation_sender(self) -> synapse.federation.sender.FederationSender:
pass
def get_federation_transport_client(
self,
) -> synapse.federation.transport.client.TransportLayerClient:
pass
def get_media_repository_resource(
self,
) -> synapse.rest.media.v1.media_repository.MediaRepositoryResource:
pass
def get_media_repository(
self,
) -> synapse.rest.media.v1.media_repository.MediaRepository:
pass
def get_server_notices_manager(
self,
) -> synapse.server_notices.server_notices_manager.ServerNoticesManager:
pass
def get_server_notices_sender(
self,
) -> synapse.server_notices.server_notices_sender.ServerNoticesSender:
pass
def get_notifier(self) -> synapse.notifier.Notifier:
pass
def get_presence_handler(self) -> synapse.handlers.presence.PresenceHandler:
pass
def get_clock(self) -> synapse.util.Clock:
pass
def get_reactor(self) -> twisted.internet.base.ReactorBase:
pass

View File

@@ -19,6 +19,7 @@ import itertools
import logging
from collections import Counter as c_counter, OrderedDict, namedtuple
from functools import wraps
from typing import Dict, List, Tuple
from six import iteritems, text_type
from six.moves import range
@@ -41,8 +42,9 @@ from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.data_stores.main.event_federation import EventFederationStore
from synapse.storage.data_stores.main.events_worker import EventsWorkerStore
from synapse.storage.data_stores.main.state import StateGroupWorkerStore
from synapse.storage.database import Database
from synapse.types import RoomStreamToken, get_domain_from_id
from synapse.storage.database import Database, LoggingTransaction
from synapse.storage.persist_events import DeltaState
from synapse.types import RoomStreamToken, StateMap, get_domain_from_id
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.frozenutils import frozendict_json_encoder
from synapse.util.iterutils import batch_iter
@@ -148,30 +150,26 @@ class EventsStore(
@defer.inlineCallbacks
def _persist_events_and_state_updates(
self,
events_and_contexts,
current_state_for_room,
state_delta_for_room,
new_forward_extremeties,
backfilled=False,
delete_existing=False,
events_and_contexts: List[Tuple[EventBase, EventContext]],
current_state_for_room: Dict[str, StateMap[str]],
state_delta_for_room: Dict[str, DeltaState],
new_forward_extremeties: Dict[str, List[str]],
backfilled: bool = False,
delete_existing: bool = False,
):
"""Persist a set of events alongside updates to the current state and
forward extremities tables.
Args:
events_and_contexts (list[(EventBase, EventContext)]):
current_state_for_room (dict[str, dict]): Map from room_id to the
current state of the room based on forward extremities
state_delta_for_room (dict[str, tuple]): Map from room_id to tuple
of `(to_delete, to_insert)` where to_delete is a list
of type/state keys to remove from current state, and to_insert
is a map (type,key)->event_id giving the state delta in each
room.
new_forward_extremities (dict[str, list[str]]): Map from room_id
to list of event IDs that are the new forward extremities of
the room.
backfilled (bool)
delete_existing (bool):
events_and_contexts:
current_state_for_room: Map from room_id to the current state of
the room based on forward extremities
state_delta_for_room: Map from room_id to the delta to apply to
room state
new_forward_extremities: Map from room_id to list of event IDs
that are the new forward extremities of the room.
backfilled
delete_existing
Returns:
Deferred: resolves when the events have been persisted
@@ -352,12 +350,12 @@ class EventsStore(
@log_function
def _persist_events_txn(
self,
txn,
events_and_contexts,
backfilled,
delete_existing=False,
state_delta_for_room={},
new_forward_extremeties={},
txn: LoggingTransaction,
events_and_contexts: List[Tuple[EventBase, EventContext]],
backfilled: bool,
delete_existing: bool = False,
state_delta_for_room: Dict[str, DeltaState] = {},
new_forward_extremeties: Dict[str, List[str]] = {},
):
"""Insert some number of room events into the necessary database tables.
@@ -366,21 +364,16 @@ class EventsStore(
whether the event was rejected.
Args:
txn (twisted.enterprise.adbapi.Connection): db connection
events_and_contexts (list[(EventBase, EventContext)]):
events to persist
backfilled (bool): True if the events were backfilled
delete_existing (bool): True to purge existing table rows for the
events from the database. This is useful when retrying due to
txn
events_and_contexts: events to persist
backfilled: True if the events were backfilled
delete_existing True to purge existing table rows for the events
from the database. This is useful when retrying due to
IntegrityError.
state_delta_for_room (dict[str, (list, dict)]):
The current-state delta for each room. For each room, a tuple
(to_delete, to_insert), being a list of type/state keys to be
removed from the current state, and a state set to be added to
the current state.
new_forward_extremeties (dict[str, list[str]]):
The new forward extremities for each room. For each room, a
list of the event ids which are the forward extremities.
state_delta_for_room: The current-state delta for each room.
new_forward_extremetie: The new forward extremities for each room.
For each room, a list of the event ids which are the forward
extremities.
"""
all_events_and_contexts = events_and_contexts
@@ -465,9 +458,15 @@ class EventsStore(
# room_memberships, where applicable.
self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
def _update_current_state_txn(self, txn, state_delta_by_room, stream_id):
for room_id, current_state_tuple in iteritems(state_delta_by_room):
to_delete, to_insert = current_state_tuple
def _update_current_state_txn(
self,
txn: LoggingTransaction,
state_delta_by_room: Dict[str, DeltaState],
stream_id: int,
):
for room_id, delta_state in iteritems(state_delta_by_room):
to_delete = delta_state.to_delete
to_insert = delta_state.to_insert
# First we add entries to the current_state_delta_stream. We
# do this before updating the current_state_events table so

View File

@@ -17,19 +17,24 @@
import logging
from collections import deque, namedtuple
from typing import Iterable, List, Optional, Tuple
from six import iteritems
from six.moves import range
import attr
from prometheus_client import Counter, Histogram
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.events.snapshot import EventContext
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage.data_stores import DataStores
from synapse.types import StateMap
from synapse.util.async_helpers import ObservableDeferred
from synapse.util.metrics import Measure
@@ -67,6 +72,19 @@ stale_forward_extremities_counter = Histogram(
)
@attr.s(slots=True, frozen=True)
class DeltaState:
"""Deltas to use to update the `current_state_events` table.
Attributes:
to_delete: List of type/state_keys to delete from current state
to_insert: Map of state to upsert into current state
"""
to_delete = attr.ib(type=List[Tuple[str, str]])
to_insert = attr.ib(type=StateMap[str])
class _EventPeristenceQueue(object):
"""Queues up events so that they can be persisted in bulk with only one
concurrent transaction per room.
@@ -138,13 +156,12 @@ class _EventPeristenceQueue(object):
self._currently_persisting_rooms.add(room_id)
@defer.inlineCallbacks
def handle_queue_loop():
async def handle_queue_loop():
try:
queue = self._get_drainining_queue(room_id)
for item in queue:
try:
ret = yield per_item_callback(item)
ret = await per_item_callback(item)
except Exception:
with PreserveLoggingContext():
item.deferred.errback()
@@ -191,12 +208,16 @@ class EventsPersistenceStorage(object):
self._state_resolution_handler = hs.get_state_resolution_handler()
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False):
def persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
backfilled: bool = False,
):
"""
Write events to the database
Args:
events_and_contexts: list of tuples of (event, context)
backfilled (bool): Whether the results are retrieved from federation
backfilled: Whether the results are retrieved from federation
via backfill or not. Used to determine if they're "new" events
which might update the current state etc.
@@ -226,16 +247,12 @@ class EventsPersistenceStorage(object):
return max_persisted_id
@defer.inlineCallbacks
def persist_event(self, event, context, backfilled=False):
def persist_event(
self, event: FrozenEvent, context: EventContext, backfilled: bool = False
):
"""
Args:
event (EventBase):
context (EventContext):
backfilled (bool):
Returns:
Deferred: resolves to (int, int): the stream ordering of ``event``,
Deferred[Tuple[int, int]]: the stream ordering of ``event``,
and the stream ordering of the latest persisted event
"""
deferred = self._event_persist_queue.add_to_queue(
@@ -249,28 +266,22 @@ class EventsPersistenceStorage(object):
max_persisted_id = yield self.main_store.get_current_events_token()
return (event.internal_metadata.stream_ordering, max_persisted_id)
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
def _maybe_start_persisting(self, room_id: str):
async def persisting_queue(item):
with Measure(self._clock, "persist_events"):
yield self._persist_events(
await self._persist_events(
item.events_and_contexts, backfilled=item.backfilled
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@defer.inlineCallbacks
def _persist_events(self, events_and_contexts, backfilled=False):
async def _persist_events(
self,
events_and_contexts: List[Tuple[FrozenEvent, EventContext]],
backfilled: bool = False,
):
"""Calculates the change to current state and forward extremities, and
persists the given events and with those updates.
Args:
events_and_contexts (list[(EventBase, EventContext)]):
backfilled (bool):
delete_existing (bool):
Returns:
Deferred: resolves when the events have been persisted
"""
if not events_and_contexts:
return
@@ -315,10 +326,10 @@ class EventsPersistenceStorage(object):
)
for room_id, ev_ctx_rm in iteritems(events_by_room):
latest_event_ids = yield self.main_store.get_latest_event_ids_in_room(
latest_event_ids = await self.main_store.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = yield self._calculate_new_extremities(
new_latest_event_ids = await self._calculate_new_extremities(
room_id, ev_ctx_rm, latest_event_ids
)
@@ -374,7 +385,7 @@ class EventsPersistenceStorage(object):
with Measure(
self._clock, "persist_events.get_new_state_after_events"
):
res = yield self._get_new_state_after_events(
res = await self._get_new_state_after_events(
room_id,
ev_ctx_rm,
latest_event_ids,
@@ -389,12 +400,12 @@ class EventsPersistenceStorage(object):
# If there is a delta we know that we've
# only added or replaced state, never
# removed keys entirely.
state_delta_for_room[room_id] = ([], delta_ids)
state_delta_for_room[room_id] = DeltaState([], delta_ids)
elif current_state is not None:
with Measure(
self._clock, "persist_events.calculate_state_delta"
):
delta = yield self._calculate_state_delta(
delta = await self._calculate_state_delta(
room_id, current_state
)
state_delta_for_room[room_id] = delta
@@ -404,7 +415,7 @@ class EventsPersistenceStorage(object):
if current_state is not None:
current_state_for_room[room_id] = current_state
yield self.main_store._persist_events_and_state_updates(
await self.main_store._persist_events_and_state_updates(
chunk,
current_state_for_room=current_state_for_room,
state_delta_for_room=state_delta_for_room,
@@ -412,8 +423,12 @@ class EventsPersistenceStorage(object):
backfilled=backfilled,
)
@defer.inlineCallbacks
def _calculate_new_extremities(self, room_id, event_contexts, latest_event_ids):
async def _calculate_new_extremities(
self,
room_id: str,
event_contexts: List[Tuple[FrozenEvent, EventContext]],
latest_event_ids: List[str],
):
"""Calculates the new forward extremities for a room given events to
persist.
@@ -444,13 +459,13 @@ class EventsPersistenceStorage(object):
)
# Remove any events which are prev_events of any existing events.
existing_prevs = yield self.main_store._get_events_which_are_prevs(result)
existing_prevs = await self.main_store._get_events_which_are_prevs(result)
result.difference_update(existing_prevs)
# Finally handle the case where the new events have soft-failed prev
# events. If they do we need to remove them and their prev events,
# otherwise we end up with dangling extremities.
existing_prevs = yield self.main_store._get_prevs_before_rejected(
existing_prevs = await self.main_store._get_prevs_before_rejected(
e_id for event in new_events for e_id in event.prev_event_ids()
)
result.difference_update(existing_prevs)
@@ -464,10 +479,13 @@ class EventsPersistenceStorage(object):
return result
@defer.inlineCallbacks
def _get_new_state_after_events(
self, room_id, events_context, old_latest_event_ids, new_latest_event_ids
):
async def _get_new_state_after_events(
self,
room_id: str,
events_context: List[Tuple[FrozenEvent, EventContext]],
old_latest_event_ids: Iterable[str],
new_latest_event_ids: Iterable[str],
) -> Tuple[Optional[StateMap[str]], Optional[StateMap[str]]]:
"""Calculate the current state dict after adding some new events to
a room
@@ -485,7 +503,6 @@ class EventsPersistenceStorage(object):
the new forward extremities for the room.
Returns:
Deferred[tuple[dict[(str,str), str]|None, dict[(str,str), str]|None]]:
Returns a tuple of two state maps, the first being the full new current
state and the second being the delta to the existing current state.
If both are None then there has been no change.
@@ -547,7 +564,7 @@ class EventsPersistenceStorage(object):
if missing_event_ids:
# Now pull out the state groups for any missing events from DB
event_to_groups = yield self.main_store._get_state_group_for_events(
event_to_groups = await self.main_store._get_state_group_for_events(
missing_event_ids
)
event_id_to_state_group.update(event_to_groups)
@@ -588,7 +605,7 @@ class EventsPersistenceStorage(object):
# their state IDs so we can resolve to a single state set.
missing_state = new_state_groups - set(state_groups_map)
if missing_state:
group_to_state = yield self.state_store._get_state_for_groups(missing_state)
group_to_state = await self.state_store._get_state_for_groups(missing_state)
state_groups_map.update(group_to_state)
if len(new_state_groups) == 1:
@@ -612,10 +629,10 @@ class EventsPersistenceStorage(object):
break
if not room_version:
room_version = yield self.main_store.get_room_version(room_id)
room_version = await self.main_store.get_room_version(room_id)
logger.debug("calling resolve_state_groups from preserve_events")
res = yield self._state_resolution_handler.resolve_state_groups(
res = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups,
@@ -625,18 +642,14 @@ class EventsPersistenceStorage(object):
return res.state, None
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, current_state):
async def _calculate_state_delta(
self, room_id: str, current_state: StateMap[str]
) -> DeltaState:
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
tuple[list, dict] (to_delete, to_insert): where to_delete are the
type/state_keys to remove from current_state_events and `to_insert`
are the updates to current_state_events.
"""
existing_state = yield self.main_store.get_current_state_ids(room_id)
existing_state = await self.main_store.get_current_state_ids(room_id)
to_delete = [key for key in existing_state if key not in current_state]
@@ -646,4 +659,4 @@ class EventsPersistenceStorage(object):
if ev_id != existing_state.get(key)
}
return to_delete, to_insert
return DeltaState(to_delete=to_delete, to_insert=to_insert)
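
With this change the (to_delete, to_insert) tuple that previously described current-state changes becomes the attrs-based DeltaState: `_calculate_state_delta` now returns one and `_update_current_state_txn` unpacks its fields. A toy construction, assuming this branch's module layout (the same import used in the events store above) and hypothetical event IDs:

from synapse.storage.persist_events import DeltaState

delta = DeltaState(
    to_delete=[("m.room.member", "@bob:example.com")],    # type/state_key pairs to drop
    to_insert={("m.room.topic", ""): "$topic_event_id"},  # state map to upsert
)
to_delete, to_insert = delta.to_delete, delta.to_insert   # what _update_current_state_txn reads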

View File

@@ -435,6 +435,19 @@ class UserRestTestCase(unittest.HomeserverTestCase):
self.assertEqual(0, channel.json_body["is_guest"])
self.assertEqual(0, channel.json_body["deactivated"])
# Change password
body = json.dumps({"password": "hahaha"})
request, channel = self.make_request(
"PUT",
self.url,
access_token=self.admin_user_tok,
content=body.encode(encoding="utf_8"),
)
self.render(request)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
# Modify user
body = json.dumps({"displayname": "foobar", "deactivated": True})

View File

@@ -134,3 +134,30 @@ class EventStreamPermissionsTestCase(unittest.HomeserverTestCase):
# someone else set topic, expect 6 (join,send,topic,join,send,topic)
pass
class GetEventsTestCase(unittest.HomeserverTestCase):
servlets = [
events.register_servlets,
room.register_servlets,
synapse.rest.admin.register_servlets_for_client_rest_resource,
login.register_servlets,
]
def prepare(self, hs, reactor, clock):
# register an account
self.user_id = self.register_user("sid1", "pass")
self.token = self.login(self.user_id, "pass")
self.room_id = self.helper.create_room_as(self.user_id, tok=self.token)
def test_get_event_via_events(self):
resp = self.helper.send(self.room_id, tok=self.token)
event_id = resp["event_id"]
request, channel = self.make_request(
"GET", "/events/" + event_id, access_token=self.token,
)
self.render(request)
self.assertEquals(channel.code, 200, msg=channel.result)

View File

@@ -463,7 +463,7 @@ class HomeserverTestCase(TestCase):
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
self.assertEqual(channel.code, 200)
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)

View File

@@ -177,13 +177,13 @@ env =
MYPYPATH = stubs/
extras = all
commands = mypy \
synapse/api \
synapse/config/ \
synapse/handlers/ui_auth \
synapse/logging/ \
synapse/module_api \
synapse/replication \
synapse/rest/consent \
synapse/rest/saml2 \
synapse/rest \
synapse/spam_checker_api \
synapse/storage/engines \
synapse/streams