
Merge commit '1baab2035' into anoa/dinsic_release_1_31_0

Andrew Morgan
2021-04-22 18:30:10 +01:00
38 changed files with 653 additions and 271 deletions

changelog.d/9164.bugfix (new file)

@@ -0,0 +1 @@
Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.

changelog.d/9198.misc (new file)

@@ -0,0 +1 @@
Precompute joined hosts and store in Redis.

changelog.d/9199.removal (new file)

@@ -0,0 +1 @@
The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`.

changelog.d/9218.bugfix (new file)

@@ -0,0 +1 @@
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.

changelog.d/9222.misc (new file)

@@ -0,0 +1 @@
Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.

changelog.d/9223.misc (new file)

@@ -0,0 +1 @@
Add type hints to handlers code.


@@ -2057,10 +2057,6 @@ cas_config:
#
#server_url: "https://cas-server.com"
# The public URL of the homeserver.
#
#service_url: "https://homeserver.domain.com:8448"
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.


@@ -26,6 +26,8 @@ files =
synapse/handlers/_base.py,
synapse/handlers/account_data.py,
synapse/handlers/account_validity.py,
synapse/handlers/acme.py,
synapse/handlers/acme_issuing_service.py,
synapse/handlers/admin.py,
synapse/handlers/appservice.py,
synapse/handlers/auth.py,
@@ -36,6 +38,7 @@ files =
synapse/handlers/directory.py,
synapse/handlers/events.py,
synapse/handlers/federation.py,
synapse/handlers/groups_local.py,
synapse/handlers/identity.py,
synapse/handlers/initial_sync.py,
synapse/handlers/message.py,
@@ -52,8 +55,13 @@ files =
synapse/handlers/room_member.py,
synapse/handlers/room_member_worker.py,
synapse/handlers/saml_handler.py,
synapse/handlers/search.py,
synapse/handlers/set_password.py,
synapse/handlers/sso.py,
synapse/handlers/state_deltas.py,
synapse/handlers/stats.py,
synapse/handlers/sync.py,
synapse/handlers/typing.py,
synapse/handlers/user_directory.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
@@ -194,3 +202,9 @@ ignore_missing_imports = True
[mypy-hiredis]
ignore_missing_imports = True
[mypy-josepy.*]
ignore_missing_imports = True
[mypy-txacme.*]
ignore_missing_imports = True


@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
#
# We pin black so that our tests don't start failing on new releases.
CONDITIONAL_REQUIREMENTS["lint"] = [
"isort==5.0.3",
"isort==5.7.0",
"black==19.10b0",
"flake8-comprehensions",
"flake8",


@@ -15,12 +15,23 @@
"""Contains *incomplete* type hints for txredisapi.
"""
from typing import List, Optional, Type, Union
from typing import Any, List, Optional, Type, Union
class RedisProtocol:
def publish(self, channel: str, message: bytes): ...
async def ping(self) -> None: ...
async def set(
self,
key: str,
value: Any,
expire: Optional[int] = None,
pexpire: Optional[int] = None,
only_if_not_exists: bool = False,
only_if_exists: bool = False,
) -> None: ...
async def get(self, key: str) -> Any: ...
class SubscriberProtocol:
class SubscriberProtocol(RedisProtocol):
def __init__(self, *args, **kwargs): ...
password: Optional[str]
def subscribe(self, channels: Union[str, List[str]]): ...
@@ -39,14 +50,13 @@ def lazyConnection(
convertNumbers: bool = ...,
) -> RedisProtocol: ...
class SubscriberFactory:
def buildProtocol(self, addr): ...
class ConnectionHandler: ...
class RedisFactory:
continueTrying: bool
handler: RedisProtocol
pool: List[RedisProtocol]
replyTimeout: Optional[int]
def __init__(
self,
uuid: str,
@@ -59,3 +69,7 @@ class RedisFactory:
replyTimeout: Optional[int] = None,
convertNumbers: Optional[int] = True,
): ...
def buildProtocol(self, addr) -> RedisProtocol: ...
class SubscriberFactory(RedisFactory):
def __init__(self): ...
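As a usage sketch of the newly-stubbed `set`/`get` methods (the `redis` connection object and the key are hypothetical, for illustration only):

# redis: a txredisapi.RedisProtocol obtained elsewhere (hypothetical)
await redis.set("cache_v1:example:key", b"value", pexpire=60_000)  # expire after 60s
value = await redis.get("cache_v1:example:key")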


@@ -20,6 +20,7 @@ from synapse.config import (
password_auth_providers,
push,
ratelimiting,
redis,
registration,
repository,
room_directory,
@@ -83,6 +84,7 @@ class RootConfig:
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
redis: redis.RedisConfig
config_classes: List = ...
def __init__(self) -> None: ...


@@ -30,7 +30,13 @@ class CasConfig(Config):
if self.cas_enabled:
self.cas_server_url = cas_config["server_url"]
self.cas_service_url = cas_config["service_url"]
public_base_url = cas_config.get("service_url") or self.public_baseurl
if public_base_url[-1] != "/":
public_base_url += "/"
# TODO Update this to a _synapse URL.
self.cas_service_url = (
public_base_url + "_matrix/client/r0/login/cas/ticket"
)
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
self.cas_required_attributes = cas_config.get("required_attributes") or {}
else:
@@ -53,10 +59,6 @@ class CasConfig(Config):
#
#server_url: "https://cas-server.com"
# The public URL of the homeserver.
#
#service_url: "https://homeserver.domain.com:8448"
# The attribute of the CAS response to use as the display name.
#
# If unset, no displayname will be set.
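For illustration, with a hypothetical public_baseurl of https://matrix.example.com, the new computation above yields:

public_base_url = "https://matrix.example.com"  # hypothetical value of public_baseurl
if public_base_url[-1] != "/":
    public_base_url += "/"
cas_service_url = public_base_url + "_matrix/client/r0/login/cas/ticket"
# -> "https://matrix.example.com/_matrix/client/r0/login/cas/ticket"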


@@ -54,8 +54,7 @@ class OIDCConfig(Config):
"Multiple OIDC providers have the idp_id %r." % idp_id
)
public_baseurl = self.public_baseurl
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback"
@property
def oidc_enabled(self) -> bool:


@@ -142,6 +142,8 @@ class FederationSender:
self._wake_destinations_needing_catchup,
)
self._external_cache = hs.get_external_cache()
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
"""Get or create a PerDestinationQueue for the given destination
@@ -197,22 +199,40 @@ class FederationSender:
if not event.internal_metadata.should_proactively_send():
return
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
destinations = None # type: Optional[Set[str]]
if not event.prev_event_ids():
# If there are no prev event IDs then the state is empty,
# so there are no remote servers in the room
destinations = set()
else:
# We check the external cache for the destinations, which is
# stored per state group.
sg = await self._external_cache.get(
"event_to_prev_state_group", event.event_id
)
except Exception:
logger.exception(
"Failed to calculate hosts in room for event: %s",
event.event_id,
)
return
if sg:
destinations = await self._external_cache.get(
"get_joined_hosts", str(sg)
)
if destinations is None:
try:
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
destinations = await self.state.get_hosts_in_room_at_events(
event.room_id, event_ids=event.prev_event_ids()
)
except Exception:
logger.exception(
"Failed to calculate hosts in room for event: %s",
event.event_id,
)
return
destinations = {
d
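Since old and new code are interleaved in the hunk above, here is a condensed sketch of the lookup order this change introduces (cache names as in this diff; the exception handling around the fallback is elided):

destinations = None  # type: Optional[Set[str]]
if not event.prev_event_ids():
    destinations = set()  # empty prior state, so no remote servers
else:
    # Check the external cache, which is keyed per state group.
    sg = await self._external_cache.get("event_to_prev_state_group", event.event_id)
    if sg:
        destinations = await self._external_cache.get("get_joined_hosts", str(sg))
if destinations is None:
    # Cache miss: fall back to computing the hosts from the state
    # before the event.
    destinations = await self.state.get_hosts_in_room_at_events(
        event.room_id, event_ids=event.prev_event_ids()
    )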


@@ -14,6 +14,7 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
import twisted
import twisted.internet.error
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
from synapse.app import check_bind_error
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
ACME_REGISTER_FAIL_ERROR = """
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
class AcmeHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.reactor = hs.get_reactor()
self._acme_domain = hs.config.acme_domain
async def start_listening(self):
async def start_listening(self) -> None:
from synapse.handlers import acme_issuing_service
# Configure logging for txacme, if you need to debug
@@ -85,7 +89,7 @@ class AcmeHandler:
logger.error(ACME_REGISTER_FAIL_ERROR)
raise
async def provision_certificate(self):
async def provision_certificate(self) -> None:
logger.warning("Reprovisioning %s", self._acme_domain)
@@ -110,5 +114,3 @@ class AcmeHandler:
except Exception:
logger.exception("Failed saving!")
raise
return True


@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to be
imported conditionally.
"""
import logging
from typing import Dict, Iterable, List
import attr
import pem
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from josepy import JWKRSA
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
from zope.interface import implementer
from twisted.internet import defer
from twisted.internet.interfaces import IReactorTCP
from twisted.python.filepath import FilePath
from twisted.python.url import URL
from twisted.web.resource import IResource
logger = logging.getLogger(__name__)
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
def create_issuing_service(
reactor: IReactorTCP,
acme_url: str,
account_key_file: str,
well_known_resource: IResource,
) -> AcmeIssuingService:
"""Create an ACME issuing service, and attach it to a web Resource
Args:
reactor: twisted reactor
acme_url (str): URL to use to request certificates
account_key_file (str): where to store the account key
well_known_resource (twisted.web.IResource): web resource for .well-known.
acme_url: URL to use to request certificates
account_key_file: where to store the account key
well_known_resource: web resource for .well-known.
we will attach a child resource for "acme-challenge".
Returns:
@@ -83,18 +92,20 @@ class ErsatzStore:
A store that only stores in memory.
"""
certs = attr.ib(default=attr.Factory(dict))
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
def store(self, server_name, pem_objects):
def store(
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
) -> defer.Deferred:
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
return defer.succeed(None)
def load_or_create_client_key(key_file):
def load_or_create_client_key(key_file: str) -> JWKRSA:
"""Load the ACME account key from a file, creating it if it does not exist.
Args:
key_file (str): name of the file to use as the account key
key_file: name of the file to use as the account key
"""
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
# hardcode the 'client.key' filename


@@ -99,11 +99,7 @@ class CasHandler:
Returns:
The URL to use as a "service" parameter.
"""
return "%s%s?%s" % (
self._cas_service_url,
"/_matrix/client/r0/login/cas/ticket",
urllib.parse.urlencode(args),
)
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
async def _validate_ticket(
self, ticket: str, service_args: Dict[str, str]


@@ -2285,6 +2285,11 @@ class FederationHandler(BaseHandler):
if event.type == EventTypes.GuestAccess and not context.rejected:
await self.maybe_kick_guest_users(event)
# If we are going to send this event over federation we precalculate
# the joined hosts.
if event.internal_metadata.get_send_on_behalf_of():
await self.event_creation_handler.cache_joined_hosts_for_event(event)
return context
async def _check_for_soft_fail(


@@ -15,9 +15,13 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
from synapse.types import GroupID, get_domain_from_id
from synapse.types import GroupID, JsonDict, get_domain_from_id
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
class GroupsLocalWorkerHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.room_list_handler = hs.get_room_list_handler()
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
get_group_role = _create_rerouter("get_group_role")
get_group_roles = _create_rerouter("get_group_roles")
async def get_group_summary(self, group_id, requester_user_id):
async def get_group_summary(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get the group summary for a group.
If the group is remote we check that the users have valid attestations.
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
return res
async def get_users_in_group(self, group_id, requester_user_id):
async def get_users_in_group(
self, group_id: str, requester_user_id: str
) -> JsonDict:
"""Get users in a group
"""
if self.is_mine_id(group_id):
res = await self.groups_server_handler.get_users_in_group(
return await self.groups_server_handler.get_users_in_group(
group_id, requester_user_id
)
return res
group_server_name = get_domain_from_id(group_id)
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
return res
async def get_joined_groups(self, user_id):
async def get_joined_groups(self, user_id: str) -> JsonDict:
group_ids = await self.store.get_joined_groups(user_id)
return {"groups": group_ids}
async def get_publicised_groups_for_user(self, user_id):
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
if self.hs.is_mine_id(user_id):
result = await self.store.get_publicised_groups_for_user(user_id)
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
# TODO: Verify attestations
return {"groups": result}
async def bulk_get_publicised_groups(self, user_ids, proxy=True):
destinations = {}
async def bulk_get_publicised_groups(
self, user_ids: Iterable[str], proxy: bool = True
) -> JsonDict:
destinations = {} # type: Dict[str, Set[str]]
local_users = set()
for user_id in user_ids:
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
raise SynapseError(400, "Some user_ids are not local")
results = {}
failed_results = []
failed_results = [] # type: List[str]
for destination, dest_user_ids in destinations.items():
try:
r = await self.transport_client.bulk_get_publicised_groups(
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
class GroupsLocalHandler(GroupsLocalWorkerHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
# Ensure attestations get renewed
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
set_group_join_policy = _create_rerouter("set_group_join_policy")
async def create_group(self, group_id, user_id, content):
async def create_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Create a group
"""
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
local_attestation = None
remote_attestation = None
else:
local_attestation = self.attestations.create_attestation(group_id, user_id)
content["attestation"] = local_attestation
content["user_profile"] = await self.profile_handler.get_profile(user_id)
try:
res = await self.transport_client.create_group(
get_domain_from_id(group_id), group_id, user_id, content
)
except HttpResponseException as e:
raise e.to_synapse_error()
except RequestSendFailed:
raise SynapseError(502, "Failed to contact group server")
remote_attestation = res["attestation"]
await self.attestations.verify_attestation(
remote_attestation,
group_id=group_id,
user_id=user_id,
server_name=get_domain_from_id(group_id),
)
raise SynapseError(400, "Unable to create remote groups")
is_publicised = content.get("publicise", False)
token = await self.store.register_user_group_membership(
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def join_group(self, group_id, user_id, content):
async def join_group(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Request to join a group
"""
if self.is_mine_id(group_id):
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
async def accept_invite(self, group_id, user_id, content):
async def accept_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""Accept an invite to a group
"""
if self.is_mine_id(group_id):
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {}
async def invite(self, group_id, user_id, requester_user_id, config):
async def invite(
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
) -> JsonDict:
"""Invite a user to a group
"""
content = {"requester_user_id": requester_user_id, "config": config}
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def on_invite(self, group_id, user_id, content):
async def on_invite(
self, group_id: str, user_id: str, content: JsonDict
) -> JsonDict:
"""One of our users were invited to a group
"""
# TODO: Support auto join and rejection
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return {"state": "invite", "user_profile": user_profile}
async def remove_user_from_group(
self, group_id, user_id, requester_user_id, content
):
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
) -> JsonDict:
"""Remove a user from a group
"""
if user_id == requester_user_id:
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
return res
async def user_removed_from_group(self, group_id, user_id, content):
async def user_removed_from_group(
self, group_id: str, user_id: str, content: JsonDict
) -> None:
"""One of our users was removed/kicked from a group
"""
# TODO: Check if user in group


@@ -434,6 +434,8 @@ class EventCreationHandler:
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
self._external_cache = hs.get_external_cache()
async def create_event(
self,
requester: Requester,
@@ -941,6 +943,8 @@ class EventCreationHandler:
await self.action_generator.handle_push_actions_for_event(event, context)
await self.cache_joined_hosts_for_event(event)
try:
# If we're a worker we need to hit out to the master.
writer_instance = self._events_shard_config.get_instance(event.room_id)
@@ -980,6 +984,44 @@ class EventCreationHandler:
await self.store.remove_push_actions_from_staging(event.event_id)
raise
async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
"""Precalculate the joined hosts at the event, when using Redis, so that
external federation senders don't have to recalculate it themselves.
"""
if not self._external_cache.is_enabled():
return
# We actually store two mappings, event ID -> prev state group,
# state group -> joined hosts, which is much more space efficient
# than event ID -> joined hosts.
#
# Note: We have to cache event ID -> prev state group, as we don't
# store that in the DB.
#
# Note: We always set the state group -> joined hosts cache, even if
# we already set it, so that the expiry time is reset.
state_entry = await self.state.resolve_state_groups_for_events(
event.room_id, event_ids=event.prev_event_ids()
)
if state_entry.state_group:
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
await self._external_cache.set(
"event_to_prev_state_group",
event.event_id,
state_entry.state_group,
expiry_ms=60 * 60 * 1000,
)
await self._external_cache.set(
"get_joined_hosts",
str(state_entry.state_group),
list(joined_hosts),
expiry_ms=60 * 60 * 1000,
)
async def _validate_canonical_alias(
self, directory_handler, room_alias_str: str, expected_room_id: str
) -> None:


@@ -15,23 +15,28 @@
import itertools
import logging
from typing import Iterable
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
from unpaddedbase64 import decode_base64, encode_base64
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import NotFoundError, SynapseError
from synapse.api.filtering import Filter
from synapse.events import EventBase
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, UserID
from synapse.visibility import filter_events_for_client
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class SearchHandler(BaseHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._event_serializer = hs.get_event_client_serializer()
self.storage = hs.get_storage()
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
return historical_room_ids
async def search(self, user, content, batch=None):
async def search(
self, user: UserID, content: JsonDict, batch: Optional[str] = None
) -> JsonDict:
"""Performs a full text search for a user.
Args:
user (UserID)
content (dict): Search parameters
batch (str): The next_batch parameter. Used for pagination.
user
content: Search parameters
batch: The next_batch parameter. Used for pagination.
Returns:
dict to be returned to the client with results of search
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
# If doing a search over a subset of all rooms, check if any of the rooms
# are from an upgraded room, and search their contents as well
if search_filter.rooms:
historical_room_ids = []
historical_room_ids = [] # type: List[str]
for room_id in search_filter.rooms:
# Add any previous rooms to the search if they exist
ids = await self.get_old_rooms_from_upgraded_room(room_id)
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
rank_map = {} # event_id -> rank of event
allowed_events = []
room_groups = {} # Holds result of grouping by room, if applicable
sender_group = {} # Holds result of grouping by sender, if applicable
# Holds result of grouping by room, if applicable
room_groups = {} # type: Dict[str, JsonDict]
# Holds result of grouping by sender, if applicable
sender_group = {} # type: Dict[str, JsonDict]
# Holds the next_batch for the entire result set if one of those exists
global_next_batch = None
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
s["results"].append(e.event_id)
elif order_by == "recent":
room_events = []
room_events = [] # type: List[EventBase]
i = 0
pagination_token = batch_token
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
state_results = {}
if include_state:
rooms = {e.room_id for e in allowed_events}
for room_id in rooms:
for room_id in {e.room_id for e in allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state_results[room_id] = list(state.values())
state_results.values()
# We're now about to serialize the events. We should not make any
# blocking calls after this. Otherwise the 'age' will be wrong
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
if state_results:
s = {}
for room_id, state in state_results.items():
for room_id, state_events in state_results.items():
s[room_id] = await self._event_serializer.serialize_events(
state, time_now
state_events, time_now
)
rooms_cat_res["state"] = s


@@ -14,24 +14,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Optional
from typing import TYPE_CHECKING, Optional
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.types import Requester
from ._base import BaseHandler
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class SetPasswordHandler(BaseHandler):
"""Handler which deals with changing user account passwords"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self._auth_handler = hs.get_auth_handler()
self._device_handler = hs.get_device_handler()
self._password_policy_handler = hs.get_password_policy_handler()
async def set_password(
self,
@@ -39,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
password_hash: str,
logout_devices: bool,
requester: Optional[Requester] = None,
):
) -> None:
if not self.hs.config.password_localdb_enabled:
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)


@@ -14,15 +14,25 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Optional
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
class StateDeltasHandler:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
async def _get_key_change(
self,
prev_event_id: Optional[str],
event_id: Optional[str],
key_name: str,
public_value: str,
) -> Optional[bool]:
"""Given two events check if the `key_name` field in content changed
from not matching `public_value` to doing so.
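A hedged usage sketch (the event IDs are placeholders): checking whether a room's history visibility changed to world-readable between two state events:

change = await self._get_key_change(
    prev_event_id, event_id,
    key_name="history_visibility",
    public_value="world_readable",
)
# True  -> now matches "world_readable"
# False -> no longer matches
# None  -> no change either way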


@@ -14,13 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from collections import Counter
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
from typing_extensions import Counter as CounterType
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import event_processing_positions
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__)
@@ -33,7 +39,7 @@ class StatsHandler:
Heavily derived from UserDirectoryHandler
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
@@ -46,7 +52,7 @@ class StatsHandler:
self.stats_enabled = hs.config.stats_enabled
# The current position in the current_state_delta stream
self.pos = None
self.pos = None # type: Optional[int]
# Guard to ensure we only process deltas one at a time
self._is_processing = False
@@ -58,7 +64,7 @@ class StatsHandler:
# we start populating stats
self.clock.call_later(0, self.notify_new_event)
def notify_new_event(self):
def notify_new_event(self) -> None:
"""Called when there may be more deltas to process
"""
if not self.stats_enabled or self._is_processing:
@@ -74,7 +80,7 @@ class StatsHandler:
run_as_background_process("stats.notify_new_event", process)
async def _unsafe_process(self):
async def _unsafe_process(self) -> None:
# If self.pos is None, it means we haven't fetched it from the DB yet
if self.pos is None:
self.pos = await self.store.get_stats_positions()
@@ -112,10 +118,10 @@ class StatsHandler:
)
for room_id, fields in room_count.items():
room_deltas.setdefault(room_id, {}).update(fields)
room_deltas.setdefault(room_id, Counter()).update(fields)
for user_id, fields in user_count.items():
user_deltas.setdefault(user_id, {}).update(fields)
user_deltas.setdefault(user_id, Counter()).update(fields)
logger.debug("room_deltas: %s", room_deltas)
logger.debug("user_deltas: %s", user_deltas)
@@ -133,19 +139,20 @@ class StatsHandler:
self.pos = max_pos
async def _handle_deltas(self, deltas):
async def _handle_deltas(
self, deltas: Iterable[JsonDict]
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
"""Called with the state deltas to process
Returns:
tuple[dict[str, Counter], dict[str, counter]]
Two dicts: the room deltas and the user deltas,
mapping from room/user ID to changes in the various fields.
"""
room_to_stats_deltas = {}
user_to_stats_deltas = {}
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
room_to_state_updates = {}
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
for delta in deltas:
typ = delta["type"]
@@ -175,7 +182,7 @@ class StatsHandler:
)
continue
event_content = {}
event_content = {} # type: JsonDict
sender = None
if event_id is not None:
@@ -263,13 +270,13 @@ class StatsHandler:
)
if has_changed_joinedness:
delta = +1 if membership == Membership.JOIN else -1
membership_delta = +1 if membership == Membership.JOIN else -1
user_to_stats_deltas.setdefault(user_id, Counter())[
"joined_rooms"
] += delta
] += membership_delta
room_stats_delta["local_users_in_room"] += delta
room_stats_delta["local_users_in_room"] += membership_delta
elif typ == EventTypes.Create:
room_state["is_federatable"] = (


@@ -15,13 +15,13 @@
import logging
import random
from collections import namedtuple
from typing import TYPE_CHECKING, List, Set, Tuple
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
from synapse.appservice import ApplicationService
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.tcp.streams import TypingStream
from synapse.types import JsonDict, UserID, get_domain_from_id
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.metrics import Measure
from synapse.util.wheel_timer import WheelTimer
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
)
# map room IDs to serial numbers
self._room_serials = {}
self._room_serials = {} # type: Dict[str, int]
# map room IDs to sets of users currently typing
self._room_typing = {}
self._room_typing = {} # type: Dict[str, Set[str]]
self._member_last_federation_poke = {}
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
self.wheel_timer = WheelTimer(bucket_size=5000)
self._latest_room_serial = 0
self.clock.looping_call(self._handle_timeouts, 5000)
def _reset(self):
def _reset(self) -> None:
"""Reset the typing handler's data caches.
"""
# map room IDs to serial numbers
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
self._member_last_federation_poke = {}
self.wheel_timer = WheelTimer(bucket_size=5000)
def _handle_timeouts(self):
def _handle_timeouts(self) -> None:
logger.debug("Checking for typing timeouts")
now = self.clock.time_msec()
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
for member in members:
self._handle_timeout_for_member(now, member)
def _handle_timeout_for_member(self, now: int, member: RoomMember):
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
if not self.is_typing(member):
# Nothing to do if they're no longer typing
return
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
# each person typing.
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
def is_typing(self, member):
def is_typing(self, member: RoomMember) -> bool:
return member.user_id in self._room_typing.get(member.room_id, [])
async def _push_remote(self, member, typing):
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
if not self.federation:
return
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
) -> None:
"""Should be called whenever we receive updates for typing stream.
"""
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
async def _send_changes_in_typing_to_remotes(
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
):
) -> None:
"""Process a change in typing of a room from replication, sending EDUs
for any local users.
"""
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
if self.is_mine_id(user_id):
await self._push_remote(RoomMember(room_id, user_id), False)
def get_current_token(self):
def get_current_token(self) -> int:
return self._latest_room_serial
class TypingWriterHandler(FollowerTypingHandler):
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
assert hs.config.worker.writers.typing == hs.get_instance_name()
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
hs.get_distributor().observe("user_left_room", self.user_left_room)
self._member_typing_until = {} # clock time we expect to stop
# clock time we expect to stop
self._member_typing_until = {} # type: Dict[RoomMember, int]
# caches which room_ids changed at which serials
self._typing_stream_change_cache = StreamChangeCache(
"TypingStreamChangeCache", self._latest_room_serial
)
def _handle_timeout_for_member(self, now: int, member: RoomMember):
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
super()._handle_timeout_for_member(now, member)
if not self.is_typing(member):
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
return
async def started_typing(self, target_user, requester, room_id, timeout):
async def started_typing(
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
if was_present:
# No point sending another notification
return None
return
self._push_update(member=member, typing=True)
async def stopped_typing(self, target_user, requester, room_id):
async def stopped_typing(
self, target_user: UserID, requester: Requester, room_id: str
) -> None:
target_user_id = target_user.to_string()
auth_user_id = requester.user.to_string()
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
self._stopped_typing(member)
def user_left_room(self, user, room_id):
def user_left_room(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
if self.is_mine_id(user_id):
member = RoomMember(room_id=room_id, user_id=user_id)
self._stopped_typing(member)
def _stopped_typing(self, member):
def _stopped_typing(self, member: RoomMember) -> None:
if member.user_id not in self._room_typing.get(member.room_id, set()):
# No point
return None
return
self._member_typing_until.pop(member, None)
self._member_last_federation_poke.pop(member, None)
self._push_update(member=member, typing=False)
def _push_update(self, member, typing):
def _push_update(self, member: RoomMember, typing: bool) -> None:
if self.hs.is_mine_id(member.user_id):
# Only send updates for changes to our own users.
run_as_background_process(
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self._push_update_local(member=member, typing=typing)
async def _recv_edu(self, origin, content):
async def _recv_edu(self, origin: str, content: JsonDict) -> None:
room_id = content["room_id"]
user_id = content["user_id"]
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
self._push_update_local(member=member, typing=content["typing"])
def _push_update_local(self, member, typing):
def _push_update_local(self, member: RoomMember, typing: bool) -> None:
room_set = self._room_typing.setdefault(member.room_id, set())
if typing:
room_set.add(member.user_id)
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
last_id
)
) # type: Optional[Iterable[str]]
if changed_rooms is None:
changed_rooms = self._room_serials
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
def process_replication_rows(
self, token: int, rows: List[TypingStream.TypingStreamRow]
):
) -> None:
# The writing process should never get updates from replication.
raise Exception("Typing writer instance got typing info over replication")
class TypingNotificationEventSource:
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.clock = hs.get_clock()
# We can't call get_typing_handler here because there's a cycle:
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
#
self.get_typing_handler = hs.get_typing_handler
def _make_event_for(self, room_id):
def _make_event_for(self, room_id: str) -> JsonDict:
typing = self.get_typing_handler()._room_typing[room_id]
return {
"type": "m.typing",
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
async def get_new_events(self, from_key, room_ids, **kwargs):
async def get_new_events(
self, from_key: int, room_ids: Iterable[str], **kwargs
) -> Tuple[List[JsonDict], int]:
with Measure(self.clock, "typing.get_new_events"):
from_key = int(from_key)
handler = self.get_typing_handler()
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
return (events, handler._latest_room_serial)
def get_current_key(self):
def get_current_key(self) -> int:
return self.get_typing_handler()._latest_room_serial


@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
if self.pos is None:
self.pos = await self.store.get_user_directory_stream_pos()
# If still None then the initial background update hasn't happened yet
if self.pos is None:
return None
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
if change: # The user joined
event = await self.store.get_event(event_id, allow_none=True)
# It isn't expected for this event to be missing, but we
# don't want the entire background process to break if it is.
if event is None:
continue
profile = ProfileInfo(
avatar_url=event.content.get("avatar_url"),
display_name=event.content.get("displayname"),


@@ -0,0 +1,105 @@
# -*- coding: utf-8 -*-
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Optional
from prometheus_client import Counter
from synapse.logging.context import make_deferred_yieldable
from synapse.util import json_decoder, json_encoder
if TYPE_CHECKING:
from synapse.server import HomeServer
set_counter = Counter(
"synapse_external_cache_set",
"Number of times we set a cache",
labelnames=["cache_name"],
)
get_counter = Counter(
"synapse_external_cache_get",
"Number of times we get a cache",
labelnames=["cache_name", "hit"],
)
logger = logging.getLogger(__name__)
class ExternalCache:
"""A cache backed by an external Redis. Does nothing if no Redis is
configured.
"""
def __init__(self, hs: "HomeServer"):
self._redis_connection = hs.get_outbound_redis_connection()
def _get_redis_key(self, cache_name: str, key: str) -> str:
return "cache_v1:%s:%s" % (cache_name, key)
def is_enabled(self) -> bool:
"""Whether the external cache is used or not.
It's safe to use the cache when this returns False; the methods will
just no-op. But the check is useful to avoid doing unnecessary work.
"""
return self._redis_connection is not None
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
"""Add the key/value to the named cache, with the expiry time given.
"""
if self._redis_connection is None:
return
set_counter.labels(cache_name).inc()
# txredisapi requires the value to be string, bytes or numbers, so we
# encode stuff in JSON.
encoded_value = json_encoder.encode(value)
logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
return await make_deferred_yieldable(
self._redis_connection.set(
self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
)
)
async def get(self, cache_name: str, key: str) -> Optional[Any]:
"""Look up a key/value in the named cache.
"""
if self._redis_connection is None:
return None
result = await make_deferred_yieldable(
self._redis_connection.get(self._get_redis_key(cache_name, key))
)
logger.debug("Got cache result %s %s: %r", cache_name, key, result)
get_counter.labels(cache_name, result is not None).inc()
if not result:
return None
# Integers get converted back to ints by txredisapi (convertNumbers),
# so return them as-is rather than JSON-decoding.
if isinstance(result, int):
return result
return json_decoder.decode(result)
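A minimal usage sketch (the caller code is hypothetical): values round-trip through JSON, so anything JSON-encodable works.

cache = hs.get_external_cache()
if cache.is_enabled():
    await cache.set("event_to_prev_state_group", event_id, state_group, expiry_ms=60 * 60 * 1000)
    sg = await cache.get("event_to_prev_state_group", event_id)  # None on a miss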


@@ -15,6 +15,7 @@
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
TypingStream,
)
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
back out to connections.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self._replication_data_handler = hs.get_replication_data_handler()
self._presence_handler = hs.get_presence_handler()
self._store = hs.get_datastore()
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
if hs.config.redis.redis_enabled:
from synapse.replication.tcp.redis import (
RedisDirectTcpReplicationClientFactory,
lazyConnection,
)
logger.info(
"Connecting to redis (host=%r port=%r)",
hs.config.redis_host,
hs.config.redis_port,
)
# First let's ensure that we have a ReplicationStreamer started.
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
# connection after SUBSCRIBE is called).
# First create the connection for sending commands.
outbound_redis_connection = lazyConnection(
reactor=hs.get_reactor(),
host=hs.config.redis_host,
port=hs.config.redis_port,
password=hs.config.redis.redis_password,
reconnect=True,
)
outbound_redis_connection = hs.get_outbound_redis_connection()
# Now create the factory/connection for the subscription stream.
self._factory = RedisDirectTcpReplicationClientFactory(


@@ -15,7 +15,7 @@
import logging
from inspect import isawaitable
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Type, cast
import txredisapi
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
from synapse.metrics.background_process_metrics import (
BackgroundProcessLoggingContext,
run_as_background_process,
wrap_as_background_process,
)
from synapse.replication.tcp.commands import (
Command,
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
immediately after initialisation.
Attributes:
handler: The command handler to handle incoming commands.
stream_name: The *redis* stream name to subscribe to and publish from
(not anything to do with Synapse replication streams).
outbound_redis_connection: The connection to redis to use to send
synapse_handler: The command handler to handle incoming commands.
synapse_stream_name: The *redis* stream name to subscribe to and publish
from (not anything to do with Synapse replication streams).
synapse_outbound_redis_connection: The connection to redis to use to send
commands.
"""
handler = None # type: ReplicationCommandHandler
stream_name = None # type: str
outbound_redis_connection = None # type: txredisapi.RedisProtocol
synapse_handler = None # type: ReplicationCommandHandler
synapse_stream_name = None # type: str
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
# it's important to make sure that we only send the REPLICATE command once we
# have successfully subscribed to the stream - otherwise we might miss the
# POSITION response sent back by the other end.
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
await make_deferred_yieldable(self.subscribe(self.stream_name))
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
logger.info(
"Successfully subscribed to redis stream, sending REPLICATE command"
)
self.handler.new_connection(self)
self.synapse_handler.new_connection(self)
await self._async_send_command(ReplicateCommand())
logger.info("REPLICATE successfully sent")
# We send out our positions when there is a new connection in case the
# other side missed updates. We do this for Redis connections as the
# other side won't know we've connected and so won't issue a REPLICATE.
self.handler.send_positions_to_connection(self)
self.synapse_handler.send_positions_to_connection(self)
def messageReceived(self, pattern: str, channel: str, message: str):
"""Received a message from redis.
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
cmd: received command
"""
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
if not cmd_func:
logger.warning("Unhandled command: %r", cmd)
return
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
def connectionLost(self, reason):
logger.info("Lost connection to redis")
super().connectionLost(reason)
self.handler.lost_connection(self)
self.synapse_handler.lost_connection(self)
# mark the logging context as finished
self._logging_context.__exit__(None, None, None)
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
await make_deferred_yieldable(
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
self.synapse_outbound_redis_connection.publish(
self.synapse_stream_name, encoded_string
)
)
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
class SynapseRedisFactory(txredisapi.RedisFactory):
"""A subclass of RedisFactory that periodically sends pings to ensure that
we detect dead connections.
"""
def __init__(
self,
hs: "HomeServer",
uuid: str,
dbid: Optional[int],
poolsize: int,
isLazy: bool = False,
handler: Type = txredisapi.ConnectionHandler,
charset: str = "utf-8",
password: Optional[str] = None,
replyTimeout: int = 30,
convertNumbers: Optional[int] = True,
):
super().__init__(
uuid=uuid,
dbid=dbid,
poolsize=poolsize,
isLazy=isLazy,
handler=handler,
charset=charset,
password=password,
replyTimeout=replyTimeout,
convertNumbers=convertNumbers,
)
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
@wrap_as_background_process("redis_ping")
async def _send_ping(self):
for connection in self.pool:
try:
await make_deferred_yieldable(connection.ping())
except Exception:
logger.warning("Failed to send ping to a redis connection")
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
"""This is a reconnecting factory that connects to redis and immediately
subscribes to a stream.
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
):
super().__init__()
super().__init__(
hs,
uuid="subscriber",
dbid=None,
poolsize=1,
replyTimeout=30,
password=hs.config.redis.redis_password,
)
# This sets the password on the RedisFactory base class (as
# SubscriberFactory constructor doesn't pass it through).
self.password = hs.config.redis.redis_password
self.synapse_handler = hs.get_tcp_replication()
self.synapse_stream_name = hs.hostname
self.handler = hs.get_tcp_replication()
self.stream_name = hs.hostname
self.outbound_redis_connection = outbound_redis_connection
self.synapse_outbound_redis_connection = outbound_redis_connection
def buildProtocol(self, addr):
p = super().buildProtocol(addr) # type: RedisSubscriber
p = super().buildProtocol(addr)
p = cast(RedisSubscriber, p)
# We do this here rather than add to the constructor of `RedisSubscriber`
# as to do so would involve overriding `buildProtocol` entirely, however
# the base method does some other things than just instantiating the
# protocol.
p.handler = self.handler
p.outbound_redis_connection = self.outbound_redis_connection
p.stream_name = self.stream_name
p.password = self.password
p.synapse_handler = self.synapse_handler
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
p.synapse_stream_name = self.synapse_stream_name
return p
def lazyConnection(
reactor,
hs: "HomeServer",
host: str = "localhost",
port: int = 6379,
dbid: Optional[int] = None,
reconnect: bool = True,
charset: str = "utf-8",
password: Optional[str] = None,
connectTimeout: Optional[int] = None,
replyTimeout: Optional[int] = None,
convertNumbers: bool = True,
replyTimeout: int = 30,
) -> txredisapi.RedisProtocol:
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
reactor.
"""Creates a connection to Redis that is lazily set up and reconnects if the
connection is lost.
"""
isLazy = True
poolsize = 1
uuid = "%s:%d" % (host, port)
factory = txredisapi.RedisFactory(
uuid,
dbid,
poolsize,
isLazy,
txredisapi.ConnectionHandler,
charset,
password,
replyTimeout,
convertNumbers,
factory = SynapseRedisFactory(
hs,
uuid=uuid,
dbid=dbid,
poolsize=1,
isLazy=True,
handler=txredisapi.ConnectionHandler,
password=password,
replyTimeout=replyTimeout,
)
factory.continueTrying = reconnect
for x in range(poolsize):
reactor.connectTCP(host, port, factory, connectTimeout)
reactor = hs.get_reactor()
reactor.connectTCP(host, port, factory, 30)
return factory.handler
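A hedged sketch of calling the reworked helper (host and port are the defaults shown above):

redis = lazyConnection(
    hs=hs,
    host="localhost",
    port=6379,
    password=hs.config.redis.redis_password,
    reconnect=True,
)
# The returned handler connects lazily; the SynapseRedisFactory behind it
# pings each pooled connection every 30s so dead connections are noticed.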


@@ -385,7 +385,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Check whether the URL should be downloaded as oEmbed content instead.
Params:
Args:
url: The URL to check.
Returns:
@@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeJsonResource):
"""
Request content from an oEmbed endpoint.
Params:
Args:
endpoint: The oEmbed API endpoint.
url: The URL to pass to the API.
@@ -691,27 +691,51 @@ class PreviewUrlResource(DirectServeJsonResource):
def decode_and_calc_og(
body: bytes, media_uri: str, request_encoding: Optional[str] = None
) -> Dict[str, Optional[str]]:
"""
Calculate metadata for an HTML document.
This uses lxml to parse the HTML document into the OG response. If errors
occur during processing of the document, an empty response is returned.
Args:
body: The HTML document, as bytes.
media_uri: The URI used to download the body.
request_encoding: The character encoding of the body, as a string.
Returns:
The OG response as a dictionary.
"""
# If there's no body, nothing useful is going to be found.
if not body:
return {}
from lxml import etree
# Create an HTML parser. If this fails, log and return no metadata.
try:
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body, parser)
og = _calc_og(tree, media_uri)
except LookupError:
# blindly consider the encoding as utf-8.
parser = etree.HTMLParser(recover=True, encoding="utf-8")
except Exception as e:
logger.warning("Unable to create HTML parser: %s" % (e,))
return {}
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
# Attempt to parse the body. If this fails, log and return no metadata.
tree = etree.fromstring(body_attempt, parser)
return _calc_og(tree, media_uri)
# Attempt to parse the body. If this fails, log and return no metadata.
try:
return _attempt_calc_og(body)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
og = _calc_og(tree, media_uri)
return og
return _attempt_calc_og(body.decode("utf-8", "ignore"))
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
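A usage sketch of the reworked helper (the document and URL are hypothetical):

og = decode_and_calc_og(
    b"<html><head><title>Hi</title></head><body></body></html>",
    "http://example.com/test.html",
)
# og is a dict of Open Graph values, e.g. og.get("og:title") == "Hi";
# a bad request_encoding now falls back to utf-8 instead of erroring.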


@@ -102,6 +102,7 @@ from synapse.notifier import Notifier
from synapse.push.action_generator import ActionGenerator
from synapse.push.pusherpool import PusherPool
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.external_cache import ExternalCache
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamer
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
@@ -127,6 +128,8 @@ from synapse.util.stringutils import random_string
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from txredisapi import RedisProtocol
from synapse.handlers.oidc_handler import OidcHandler
from synapse.handlers.saml_handler import SamlHandler
@@ -710,6 +713,33 @@ class HomeServer(metaclass=abc.ABCMeta):
def get_account_data_handler(self) -> AccountDataHandler:
return AccountDataHandler(self)
@cache_in_self
def get_external_cache(self) -> ExternalCache:
return ExternalCache(self)
@cache_in_self
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
if not self.config.redis.redis_enabled:
return None
# We only want to import the redis module if we're using it, as we have
# `txredisapi` as an optional dependency.
from synapse.replication.tcp.redis import lazyConnection
logger.info(
"Connecting to redis (host=%r port=%r) for external cache",
self.config.redis_host,
self.config.redis_port,
)
return lazyConnection(
hs=self,
host=self.config.redis_host,
port=self.config.redis_port,
password=self.config.redis.redis_password,
reconnect=True,
)
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)


@@ -310,6 +310,7 @@ class StateHandler:
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
entry = None
else:
# otherwise, we'll need to resolve the state across the prev_events.
@@ -340,9 +341,13 @@ class StateHandler:
current_state_ids=state_ids_before_event,
)
# XXX: can we update the state cache entry for the new state group? or
# could we set a flag on resolve_state_groups_for_events to tell it to
# always make a state group?
# Assign the new state group to the cached state entry.
#
# Note that this can race in that we could generate multiple state
# groups for the same state entry, but that is just inefficient
# rather than dangerous.
if entry and entry.state_group is None:
entry.state_group = state_group_before_event
#
# now if it's not a state event, we're done


@@ -89,7 +89,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
SELECT event_id, state_group, depth, received_ts
FROM event_forward_extremities
INNER JOIN event_to_state_groups USING (event_id)
INNER JOIN events INNER JOIN USING (event_id)
INNER JOIN events USING (room_id, event_id)
WHERE room_id = ?
"""


@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
from synapse.storage.database import DatabasePool
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import Collection
logger = logging.getLogger(__name__)
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
async def search_rooms(
self,
room_ids: List[str],
room_ids: Collection[str],
search_term: str,
keys: List[str],
limit,


@@ -15,11 +15,12 @@
# limitations under the License.
import logging
from collections import Counter
from enum import Enum
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple
from typing_extensions import Counter
from twisted.internet.defer import DeferredLock
from synapse.api.constants import EventTypes, Membership
@@ -320,7 +321,9 @@ class StatsStore(StateDeltasStore):
return slice_list
@cached()
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
async def get_earliest_token_for_stats(
self, stats_type: str, id: str
) -> Optional[int]:
"""
Fetch the "earliest token". This is used by the room stats delta
processor to ignore deltas that have been processed between the
@@ -340,7 +343,7 @@ class StatsStore(StateDeltasStore):
)
async def bulk_update_stats_delta(
self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
) -> None:
"""Bulk update stats tables for a given stream_id and updates the stats
incremental position.
@@ -666,7 +669,7 @@ class StatsStore(StateDeltasStore):
async def get_changes_room_total_events_and_bytes(
self, min_pos: int, max_pos: int
) -> Dict[str, Dict[str, int]]:
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Fetches the counts of events in the given range of stream IDs.
Args:
@@ -684,18 +687,19 @@ class StatsStore(StateDeltasStore):
max_pos,
)
def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
def get_changes_room_total_events_and_bytes_txn(
self, txn, low_pos: int, high_pos: int
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
"""Gets the total_events and total_event_bytes counts for rooms and
senders, in a range of stream_orderings (including backfilled events).
Args:
txn
low_pos (int): Low stream ordering
high_pos (int): High stream ordering
low_pos: Low stream ordering
high_pos: High stream ordering
Returns:
tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
room and user deltas for total_events/total_event_bytes in the
The room and user deltas for total_events/total_event_bytes in the
format of `stats_id` -> fields
"""


@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
desc="get_user_in_directory",
)
async def update_user_directory_stream_pos(self, stream_id: str) -> None:
async def update_user_directory_stream_pos(self, stream_id: int) -> None:
await self.db_pool.simple_update_one(
table="user_directory_stream_pos",
keyvalues={},


@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in-memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
# We may have an attempt to connect to redis for the external cache already.
self.connect_any_redis_attempts()
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
return client_to_server_transport, server_to_client_transport
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
# Since we use SET/GET to cache things we can safely no-op them.
elif command == b"SET":
self.send("OK")
elif command == b"GET":
self.send(None)
else:
raise Exception("Unknown command")
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
if obj is None:
return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
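For reference, the RESP wire encodings this fake server's encode() helper produces (the integer case is truncated above; standard RESP encodes integers as :<n>\r\n):

encode(None)  # "$-1\r\n"        null bulk string, used for GET misses
encode("OK")  # "$2\r\nOK\r\n"   bulk string, used for SET replies
encode(7)     # ":7\r\n"         integer reply (assumed from the RESP spec)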


@@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
html = ""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {})
def test_invalid_encoding(self):
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
html = """
<html>
<head><title>Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(
html, "http://example.com/test.html", "invalid-encoding"
)
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
def test_invalid_encoding2(self):
"""A body which doesn't match the sent character encoding."""
# Note that this contains an invalid UTF-8 sequence in the title.
html = b"""
<html>
<head><title>\xff\xff Foo</title></head>
<body>
Some text.
</body>
</html>
"""
og = decode_and_calc_og(html, "http://example.com/test.html")
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})