Merge commit '1baab2035' into anoa/dinsic_release_1_31_0
This commit is contained in:
1
changelog.d/9164.bugfix
Normal file
1
changelog.d/9164.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix a long-standing bug where an internal server error was raised when attempting to preview an HTML document in an unknown character encoding.
|
||||
1
changelog.d/9198.misc
Normal file
1
changelog.d/9198.misc
Normal file
@@ -0,0 +1 @@
|
||||
Precompute joined hosts and store in Redis.
|
||||
1
changelog.d/9199.removal
Normal file
1
changelog.d/9199.removal
Normal file
@@ -0,0 +1 @@
|
||||
The `service_url` parameter in `cas_config` is deprecated in favor of `public_baseurl`.
|
||||
1
changelog.d/9218.bugfix
Normal file
1
changelog.d/9218.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug where we sometimes didn't detect that Redis connections had died, causing workers to not see new data.
|
||||
1
changelog.d/9222.misc
Normal file
1
changelog.d/9222.misc
Normal file
@@ -0,0 +1 @@
|
||||
Update `isort` to v5.7.0 to bypass a bug where it would disagree with `black` about formatting.
|
||||
1
changelog.d/9223.misc
Normal file
1
changelog.d/9223.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add type hints to handlers code.
|
||||
@@ -2057,10 +2057,6 @@ cas_config:
|
||||
#
|
||||
#server_url: "https://cas-server.com"
|
||||
|
||||
# The public URL of the homeserver.
|
||||
#
|
||||
#service_url: "https://homeserver.domain.com:8448"
|
||||
|
||||
# The attribute of the CAS response to use as the display name.
|
||||
#
|
||||
# If unset, no displayname will be set.
|
||||
|
||||
14
mypy.ini
14
mypy.ini
@@ -26,6 +26,8 @@ files =
|
||||
synapse/handlers/_base.py,
|
||||
synapse/handlers/account_data.py,
|
||||
synapse/handlers/account_validity.py,
|
||||
synapse/handlers/acme.py,
|
||||
synapse/handlers/acme_issuing_service.py,
|
||||
synapse/handlers/admin.py,
|
||||
synapse/handlers/appservice.py,
|
||||
synapse/handlers/auth.py,
|
||||
@@ -36,6 +38,7 @@ files =
|
||||
synapse/handlers/directory.py,
|
||||
synapse/handlers/events.py,
|
||||
synapse/handlers/federation.py,
|
||||
synapse/handlers/groups_local.py,
|
||||
synapse/handlers/identity.py,
|
||||
synapse/handlers/initial_sync.py,
|
||||
synapse/handlers/message.py,
|
||||
@@ -52,8 +55,13 @@ files =
|
||||
synapse/handlers/room_member.py,
|
||||
synapse/handlers/room_member_worker.py,
|
||||
synapse/handlers/saml_handler.py,
|
||||
synapse/handlers/search.py,
|
||||
synapse/handlers/set_password.py,
|
||||
synapse/handlers/sso.py,
|
||||
synapse/handlers/state_deltas.py,
|
||||
synapse/handlers/stats.py,
|
||||
synapse/handlers/sync.py,
|
||||
synapse/handlers/typing.py,
|
||||
synapse/handlers/user_directory.py,
|
||||
synapse/handlers/ui_auth,
|
||||
synapse/http/client.py,
|
||||
@@ -194,3 +202,9 @@ ignore_missing_imports = True
|
||||
|
||||
[mypy-hiredis]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-josepy.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-txacme.*]
|
||||
ignore_missing_imports = True
|
||||
|
||||
2
setup.py
2
setup.py
@@ -96,7 +96,7 @@ CONDITIONAL_REQUIREMENTS["all"] = list(ALL_OPTIONAL_REQUIREMENTS)
|
||||
#
|
||||
# We pin black so that our tests don't start failing on new releases.
|
||||
CONDITIONAL_REQUIREMENTS["lint"] = [
|
||||
"isort==5.0.3",
|
||||
"isort==5.7.0",
|
||||
"black==19.10b0",
|
||||
"flake8-comprehensions",
|
||||
"flake8",
|
||||
|
||||
@@ -15,12 +15,23 @@
|
||||
|
||||
"""Contains *incomplete* type hints for txredisapi.
|
||||
"""
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Any, List, Optional, Type, Union
|
||||
|
||||
class RedisProtocol:
|
||||
def publish(self, channel: str, message: bytes): ...
|
||||
async def ping(self) -> None: ...
|
||||
async def set(
|
||||
self,
|
||||
key: str,
|
||||
value: Any,
|
||||
expire: Optional[int] = None,
|
||||
pexpire: Optional[int] = None,
|
||||
only_if_not_exists: bool = False,
|
||||
only_if_exists: bool = False,
|
||||
) -> None: ...
|
||||
async def get(self, key: str) -> Any: ...
|
||||
|
||||
class SubscriberProtocol:
|
||||
class SubscriberProtocol(RedisProtocol):
|
||||
def __init__(self, *args, **kwargs): ...
|
||||
password: Optional[str]
|
||||
def subscribe(self, channels: Union[str, List[str]]): ...
|
||||
@@ -39,14 +50,13 @@ def lazyConnection(
|
||||
convertNumbers: bool = ...,
|
||||
) -> RedisProtocol: ...
|
||||
|
||||
class SubscriberFactory:
|
||||
def buildProtocol(self, addr): ...
|
||||
|
||||
class ConnectionHandler: ...
|
||||
|
||||
class RedisFactory:
|
||||
continueTrying: bool
|
||||
handler: RedisProtocol
|
||||
pool: List[RedisProtocol]
|
||||
replyTimeout: Optional[int]
|
||||
def __init__(
|
||||
self,
|
||||
uuid: str,
|
||||
@@ -59,3 +69,7 @@ class RedisFactory:
|
||||
replyTimeout: Optional[int] = None,
|
||||
convertNumbers: Optional[int] = True,
|
||||
): ...
|
||||
def buildProtocol(self, addr) -> RedisProtocol: ...
|
||||
|
||||
class SubscriberFactory(RedisFactory):
|
||||
def __init__(self): ...
|
||||
|
||||
@@ -20,6 +20,7 @@ from synapse.config import (
|
||||
password_auth_providers,
|
||||
push,
|
||||
ratelimiting,
|
||||
redis,
|
||||
registration,
|
||||
repository,
|
||||
room_directory,
|
||||
@@ -83,6 +84,7 @@ class RootConfig:
|
||||
roomdirectory: room_directory.RoomDirectoryConfig
|
||||
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
|
||||
tracer: tracer.TracerConfig
|
||||
redis: redis.RedisConfig
|
||||
|
||||
config_classes: List = ...
|
||||
def __init__(self) -> None: ...
|
||||
|
||||
@@ -30,7 +30,13 @@ class CasConfig(Config):
|
||||
|
||||
if self.cas_enabled:
|
||||
self.cas_server_url = cas_config["server_url"]
|
||||
self.cas_service_url = cas_config["service_url"]
|
||||
public_base_url = cas_config.get("service_url") or self.public_baseurl
|
||||
if public_base_url[-1] != "/":
|
||||
public_base_url += "/"
|
||||
# TODO Update this to a _synapse URL.
|
||||
self.cas_service_url = (
|
||||
public_base_url + "_matrix/client/r0/login/cas/ticket"
|
||||
)
|
||||
self.cas_displayname_attribute = cas_config.get("displayname_attribute")
|
||||
self.cas_required_attributes = cas_config.get("required_attributes") or {}
|
||||
else:
|
||||
@@ -53,10 +59,6 @@ class CasConfig(Config):
|
||||
#
|
||||
#server_url: "https://cas-server.com"
|
||||
|
||||
# The public URL of the homeserver.
|
||||
#
|
||||
#service_url: "https://homeserver.domain.com:8448"
|
||||
|
||||
# The attribute of the CAS response to use as the display name.
|
||||
#
|
||||
# If unset, no displayname will be set.
|
||||
|
||||
@@ -54,8 +54,7 @@ class OIDCConfig(Config):
|
||||
"Multiple OIDC providers have the idp_id %r." % idp_id
|
||||
)
|
||||
|
||||
public_baseurl = self.public_baseurl
|
||||
self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
|
||||
self.oidc_callback_url = self.public_baseurl + "_synapse/oidc/callback"
|
||||
|
||||
@property
|
||||
def oidc_enabled(self) -> bool:
|
||||
|
||||
@@ -142,6 +142,8 @@ class FederationSender:
|
||||
self._wake_destinations_needing_catchup,
|
||||
)
|
||||
|
||||
self._external_cache = hs.get_external_cache()
|
||||
|
||||
def _get_per_destination_queue(self, destination: str) -> PerDestinationQueue:
|
||||
"""Get or create a PerDestinationQueue for the given destination
|
||||
|
||||
@@ -197,22 +199,40 @@ class FederationSender:
|
||||
if not event.internal_metadata.should_proactively_send():
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the state from before the event.
|
||||
# We need to make sure that this is the state from before
|
||||
# the event and not from after it.
|
||||
# Otherwise if the last member on a server in a room is
|
||||
# banned then it won't receive the event because it won't
|
||||
# be in the room after the ban.
|
||||
destinations = await self.state.get_hosts_in_room_at_events(
|
||||
event.room_id, event_ids=event.prev_event_ids()
|
||||
destinations = None # type: Optional[Set[str]]
|
||||
if not event.prev_event_ids():
|
||||
# If there are no prev event IDs then the state is empty
|
||||
# and so no remote servers in the room
|
||||
destinations = set()
|
||||
else:
|
||||
# We check the external cache for the destinations, which is
|
||||
# stored per state group.
|
||||
|
||||
sg = await self._external_cache.get(
|
||||
"event_to_prev_state_group", event.event_id
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to calculate hosts in room for event: %s",
|
||||
event.event_id,
|
||||
)
|
||||
return
|
||||
if sg:
|
||||
destinations = await self._external_cache.get(
|
||||
"get_joined_hosts", str(sg)
|
||||
)
|
||||
|
||||
if destinations is None:
|
||||
try:
|
||||
# Get the state from before the event.
|
||||
# We need to make sure that this is the state from before
|
||||
# the event and not from after it.
|
||||
# Otherwise if the last member on a server in a room is
|
||||
# banned then it won't receive the event because it won't
|
||||
# be in the room after the ban.
|
||||
destinations = await self.state.get_hosts_in_room_at_events(
|
||||
event.room_id, event_ids=event.prev_event_ids()
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to calculate hosts in room for event: %s",
|
||||
event.event_id,
|
||||
)
|
||||
return
|
||||
|
||||
destinations = {
|
||||
d
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import twisted
|
||||
import twisted.internet.error
|
||||
@@ -22,6 +23,9 @@ from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app import check_bind_error
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ACME_REGISTER_FAIL_ERROR = """
|
||||
@@ -35,12 +39,12 @@ solutions, please read https://github.com/matrix-org/synapse/blob/master/docs/AC
|
||||
|
||||
|
||||
class AcmeHandler:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.reactor = hs.get_reactor()
|
||||
self._acme_domain = hs.config.acme_domain
|
||||
|
||||
async def start_listening(self):
|
||||
async def start_listening(self) -> None:
|
||||
from synapse.handlers import acme_issuing_service
|
||||
|
||||
# Configure logging for txacme, if you need to debug
|
||||
@@ -85,7 +89,7 @@ class AcmeHandler:
|
||||
logger.error(ACME_REGISTER_FAIL_ERROR)
|
||||
raise
|
||||
|
||||
async def provision_certificate(self):
|
||||
async def provision_certificate(self) -> None:
|
||||
|
||||
logger.warning("Reprovisioning %s", self._acme_domain)
|
||||
|
||||
@@ -110,5 +114,3 @@ class AcmeHandler:
|
||||
except Exception:
|
||||
logger.exception("Failed saving!")
|
||||
raise
|
||||
|
||||
return True
|
||||
|
||||
@@ -22,8 +22,10 @@ only need (and may only have available) if we are doing ACME, so is designed to
|
||||
imported conditionally.
|
||||
"""
|
||||
import logging
|
||||
from typing import Dict, Iterable, List
|
||||
|
||||
import attr
|
||||
import pem
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import serialization
|
||||
from josepy import JWKRSA
|
||||
@@ -36,20 +38,27 @@ from txacme.util import generate_private_key
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.interfaces import IReactorTCP
|
||||
from twisted.python.filepath import FilePath
|
||||
from twisted.python.url import URL
|
||||
from twisted.web.resource import IResource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_issuing_service(reactor, acme_url, account_key_file, well_known_resource):
|
||||
def create_issuing_service(
|
||||
reactor: IReactorTCP,
|
||||
acme_url: str,
|
||||
account_key_file: str,
|
||||
well_known_resource: IResource,
|
||||
) -> AcmeIssuingService:
|
||||
"""Create an ACME issuing service, and attach it to a web Resource
|
||||
|
||||
Args:
|
||||
reactor: twisted reactor
|
||||
acme_url (str): URL to use to request certificates
|
||||
account_key_file (str): where to store the account key
|
||||
well_known_resource (twisted.web.IResource): web resource for .well-known.
|
||||
acme_url: URL to use to request certificates
|
||||
account_key_file: where to store the account key
|
||||
well_known_resource: web resource for .well-known.
|
||||
we will attach a child resource for "acme-challenge".
|
||||
|
||||
Returns:
|
||||
@@ -83,18 +92,20 @@ class ErsatzStore:
|
||||
A store that only stores in memory.
|
||||
"""
|
||||
|
||||
certs = attr.ib(default=attr.Factory(dict))
|
||||
certs = attr.ib(type=Dict[bytes, List[bytes]], default=attr.Factory(dict))
|
||||
|
||||
def store(self, server_name, pem_objects):
|
||||
def store(
|
||||
self, server_name: bytes, pem_objects: Iterable[pem.AbstractPEMObject]
|
||||
) -> defer.Deferred:
|
||||
self.certs[server_name] = [o.as_bytes() for o in pem_objects]
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
def load_or_create_client_key(key_file):
|
||||
def load_or_create_client_key(key_file: str) -> JWKRSA:
|
||||
"""Load the ACME account key from a file, creating it if it does not exist.
|
||||
|
||||
Args:
|
||||
key_file (str): name of the file to use as the account key
|
||||
key_file: name of the file to use as the account key
|
||||
"""
|
||||
# this is based on txacme.endpoint.load_or_create_client_key, but doesn't
|
||||
# hardcode the 'client.key' filename
|
||||
|
||||
@@ -99,11 +99,7 @@ class CasHandler:
|
||||
Returns:
|
||||
The URL to use as a "service" parameter.
|
||||
"""
|
||||
return "%s%s?%s" % (
|
||||
self._cas_service_url,
|
||||
"/_matrix/client/r0/login/cas/ticket",
|
||||
urllib.parse.urlencode(args),
|
||||
)
|
||||
return "%s?%s" % (self._cas_service_url, urllib.parse.urlencode(args),)
|
||||
|
||||
async def _validate_ticket(
|
||||
self, ticket: str, service_args: Dict[str, str]
|
||||
|
||||
@@ -2285,6 +2285,11 @@ class FederationHandler(BaseHandler):
|
||||
if event.type == EventTypes.GuestAccess and not context.rejected:
|
||||
await self.maybe_kick_guest_users(event)
|
||||
|
||||
# If we are going to send this event over federation we precaclculate
|
||||
# the joined hosts.
|
||||
if event.internal_metadata.get_send_on_behalf_of():
|
||||
await self.event_creation_handler.cache_joined_hosts_for_event(event)
|
||||
|
||||
return context
|
||||
|
||||
async def _check_for_soft_fail(
|
||||
|
||||
@@ -15,9 +15,13 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Set
|
||||
|
||||
from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
|
||||
from synapse.types import GroupID, get_domain_from_id
|
||||
from synapse.types import GroupID, JsonDict, get_domain_from_id
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -56,7 +60,7 @@ def _create_rerouter(func_name):
|
||||
|
||||
|
||||
class GroupsLocalWorkerHandler:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.room_list_handler = hs.get_room_list_handler()
|
||||
@@ -84,7 +88,9 @@ class GroupsLocalWorkerHandler:
|
||||
get_group_role = _create_rerouter("get_group_role")
|
||||
get_group_roles = _create_rerouter("get_group_roles")
|
||||
|
||||
async def get_group_summary(self, group_id, requester_user_id):
|
||||
async def get_group_summary(
|
||||
self, group_id: str, requester_user_id: str
|
||||
) -> JsonDict:
|
||||
"""Get the group summary for a group.
|
||||
|
||||
If the group is remote we check that the users have valid attestations.
|
||||
@@ -137,14 +143,15 @@ class GroupsLocalWorkerHandler:
|
||||
|
||||
return res
|
||||
|
||||
async def get_users_in_group(self, group_id, requester_user_id):
|
||||
async def get_users_in_group(
|
||||
self, group_id: str, requester_user_id: str
|
||||
) -> JsonDict:
|
||||
"""Get users in a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
res = await self.groups_server_handler.get_users_in_group(
|
||||
return await self.groups_server_handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
return res
|
||||
|
||||
group_server_name = get_domain_from_id(group_id)
|
||||
|
||||
@@ -178,11 +185,11 @@ class GroupsLocalWorkerHandler:
|
||||
|
||||
return res
|
||||
|
||||
async def get_joined_groups(self, user_id):
|
||||
async def get_joined_groups(self, user_id: str) -> JsonDict:
|
||||
group_ids = await self.store.get_joined_groups(user_id)
|
||||
return {"groups": group_ids}
|
||||
|
||||
async def get_publicised_groups_for_user(self, user_id):
|
||||
async def get_publicised_groups_for_user(self, user_id: str) -> JsonDict:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
result = await self.store.get_publicised_groups_for_user(user_id)
|
||||
|
||||
@@ -206,8 +213,10 @@ class GroupsLocalWorkerHandler:
|
||||
# TODO: Verify attestations
|
||||
return {"groups": result}
|
||||
|
||||
async def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
destinations = {}
|
||||
async def bulk_get_publicised_groups(
|
||||
self, user_ids: Iterable[str], proxy: bool = True
|
||||
) -> JsonDict:
|
||||
destinations = {} # type: Dict[str, Set[str]]
|
||||
local_users = set()
|
||||
|
||||
for user_id in user_ids:
|
||||
@@ -220,7 +229,7 @@ class GroupsLocalWorkerHandler:
|
||||
raise SynapseError(400, "Some user_ids are not local")
|
||||
|
||||
results = {}
|
||||
failed_results = []
|
||||
failed_results = [] # type: List[str]
|
||||
for destination, dest_user_ids in destinations.items():
|
||||
try:
|
||||
r = await self.transport_client.bulk_get_publicised_groups(
|
||||
@@ -242,7 +251,7 @@ class GroupsLocalWorkerHandler:
|
||||
|
||||
|
||||
class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
# Ensure attestations get renewed
|
||||
@@ -271,7 +280,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
set_group_join_policy = _create_rerouter("set_group_join_policy")
|
||||
|
||||
async def create_group(self, group_id, user_id, content):
|
||||
async def create_group(
|
||||
self, group_id: str, user_id: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
"""Create a group
|
||||
"""
|
||||
|
||||
@@ -284,27 +295,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
content["user_profile"] = await self.profile_handler.get_profile(user_id)
|
||||
|
||||
try:
|
||||
res = await self.transport_client.create_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, content
|
||||
)
|
||||
except HttpResponseException as e:
|
||||
raise e.to_synapse_error()
|
||||
except RequestSendFailed:
|
||||
raise SynapseError(502, "Failed to contact group server")
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
await self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
server_name=get_domain_from_id(group_id),
|
||||
)
|
||||
raise SynapseError(400, "Unable to create remote groups")
|
||||
|
||||
is_publicised = content.get("publicise", False)
|
||||
token = await self.store.register_user_group_membership(
|
||||
@@ -320,7 +311,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return res
|
||||
|
||||
async def join_group(self, group_id, user_id, content):
|
||||
async def join_group(
|
||||
self, group_id: str, user_id: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
"""Request to join a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
@@ -365,7 +358,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return {}
|
||||
|
||||
async def accept_invite(self, group_id, user_id, content):
|
||||
async def accept_invite(
|
||||
self, group_id: str, user_id: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
"""Accept an invite to a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
@@ -410,7 +405,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return {}
|
||||
|
||||
async def invite(self, group_id, user_id, requester_user_id, config):
|
||||
async def invite(
|
||||
self, group_id: str, user_id: str, requester_user_id: str, config: JsonDict
|
||||
) -> JsonDict:
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
content = {"requester_user_id": requester_user_id, "config": config}
|
||||
@@ -434,7 +431,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return res
|
||||
|
||||
async def on_invite(self, group_id, user_id, content):
|
||||
async def on_invite(
|
||||
self, group_id: str, user_id: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
"""One of our users were invited to a group
|
||||
"""
|
||||
# TODO: Support auto join and rejection
|
||||
@@ -465,8 +464,8 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
return {"state": "invite", "user_profile": user_profile}
|
||||
|
||||
async def remove_user_from_group(
|
||||
self, group_id, user_id, requester_user_id, content
|
||||
):
|
||||
self, group_id: str, user_id: str, requester_user_id: str, content: JsonDict
|
||||
) -> JsonDict:
|
||||
"""Remove a user from a group
|
||||
"""
|
||||
if user_id == requester_user_id:
|
||||
@@ -499,7 +498,9 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
|
||||
|
||||
return res
|
||||
|
||||
async def user_removed_from_group(self, group_id, user_id, content):
|
||||
async def user_removed_from_group(
|
||||
self, group_id: str, user_id: str, content: JsonDict
|
||||
) -> None:
|
||||
"""One of our users was removed/kicked from a group
|
||||
"""
|
||||
# TODO: Check if user in group
|
||||
|
||||
@@ -434,6 +434,8 @@ class EventCreationHandler:
|
||||
|
||||
self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
|
||||
|
||||
self._external_cache = hs.get_external_cache()
|
||||
|
||||
async def create_event(
|
||||
self,
|
||||
requester: Requester,
|
||||
@@ -941,6 +943,8 @@ class EventCreationHandler:
|
||||
|
||||
await self.action_generator.handle_push_actions_for_event(event, context)
|
||||
|
||||
await self.cache_joined_hosts_for_event(event)
|
||||
|
||||
try:
|
||||
# If we're a worker we need to hit out to the master.
|
||||
writer_instance = self._events_shard_config.get_instance(event.room_id)
|
||||
@@ -980,6 +984,44 @@ class EventCreationHandler:
|
||||
await self.store.remove_push_actions_from_staging(event.event_id)
|
||||
raise
|
||||
|
||||
async def cache_joined_hosts_for_event(self, event: EventBase) -> None:
|
||||
"""Precalculate the joined hosts at the event, when using Redis, so that
|
||||
external federation senders don't have to recalculate it themselves.
|
||||
"""
|
||||
|
||||
if not self._external_cache.is_enabled():
|
||||
return
|
||||
|
||||
# We actually store two mappings, event ID -> prev state group,
|
||||
# state group -> joined hosts, which is much more space efficient
|
||||
# than event ID -> joined hosts.
|
||||
#
|
||||
# Note: We have to cache event ID -> prev state group, as we don't
|
||||
# store that in the DB.
|
||||
#
|
||||
# Note: We always set the state group -> joined hosts cache, even if
|
||||
# we already set it, so that the expiry time is reset.
|
||||
|
||||
state_entry = await self.state.resolve_state_groups_for_events(
|
||||
event.room_id, event_ids=event.prev_event_ids()
|
||||
)
|
||||
|
||||
if state_entry.state_group:
|
||||
joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
|
||||
|
||||
await self._external_cache.set(
|
||||
"event_to_prev_state_group",
|
||||
event.event_id,
|
||||
state_entry.state_group,
|
||||
expiry_ms=60 * 60 * 1000,
|
||||
)
|
||||
await self._external_cache.set(
|
||||
"get_joined_hosts",
|
||||
str(state_entry.state_group),
|
||||
list(joined_hosts),
|
||||
expiry_ms=60 * 60 * 1000,
|
||||
)
|
||||
|
||||
async def _validate_canonical_alias(
|
||||
self, directory_handler, room_alias_str: str, expected_room_id: str
|
||||
) -> None:
|
||||
|
||||
@@ -15,23 +15,28 @@
|
||||
|
||||
import itertools
|
||||
import logging
|
||||
from typing import Iterable
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
|
||||
|
||||
from unpaddedbase64 import decode_base64, encode_base64
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.api.filtering import Filter
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, UserID
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SearchHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self.storage = hs.get_storage()
|
||||
@@ -87,13 +92,15 @@ class SearchHandler(BaseHandler):
|
||||
|
||||
return historical_room_ids
|
||||
|
||||
async def search(self, user, content, batch=None):
|
||||
async def search(
|
||||
self, user: UserID, content: JsonDict, batch: Optional[str] = None
|
||||
) -> JsonDict:
|
||||
"""Performs a full text search for a user.
|
||||
|
||||
Args:
|
||||
user (UserID)
|
||||
content (dict): Search parameters
|
||||
batch (str): The next_batch parameter. Used for pagination.
|
||||
user
|
||||
content: Search parameters
|
||||
batch: The next_batch parameter. Used for pagination.
|
||||
|
||||
Returns:
|
||||
dict to be returned to the client with results of search
|
||||
@@ -186,7 +193,7 @@ class SearchHandler(BaseHandler):
|
||||
# If doing a subset of all rooms seearch, check if any of the rooms
|
||||
# are from an upgraded room, and search their contents as well
|
||||
if search_filter.rooms:
|
||||
historical_room_ids = []
|
||||
historical_room_ids = [] # type: List[str]
|
||||
for room_id in search_filter.rooms:
|
||||
# Add any previous rooms to the search if they exist
|
||||
ids = await self.get_old_rooms_from_upgraded_room(room_id)
|
||||
@@ -209,8 +216,10 @@ class SearchHandler(BaseHandler):
|
||||
|
||||
rank_map = {} # event_id -> rank of event
|
||||
allowed_events = []
|
||||
room_groups = {} # Holds result of grouping by room, if applicable
|
||||
sender_group = {} # Holds result of grouping by sender, if applicable
|
||||
# Holds result of grouping by room, if applicable
|
||||
room_groups = {} # type: Dict[str, JsonDict]
|
||||
# Holds result of grouping by sender, if applicable
|
||||
sender_group = {} # type: Dict[str, JsonDict]
|
||||
|
||||
# Holds the next_batch for the entire result set if one of those exists
|
||||
global_next_batch = None
|
||||
@@ -254,7 +263,7 @@ class SearchHandler(BaseHandler):
|
||||
s["results"].append(e.event_id)
|
||||
|
||||
elif order_by == "recent":
|
||||
room_events = []
|
||||
room_events = [] # type: List[EventBase]
|
||||
i = 0
|
||||
|
||||
pagination_token = batch_token
|
||||
@@ -418,13 +427,10 @@ class SearchHandler(BaseHandler):
|
||||
|
||||
state_results = {}
|
||||
if include_state:
|
||||
rooms = {e.room_id for e in allowed_events}
|
||||
for room_id in rooms:
|
||||
for room_id in {e.room_id for e in allowed_events}:
|
||||
state = await self.state_handler.get_current_state(room_id)
|
||||
state_results[room_id] = list(state.values())
|
||||
|
||||
state_results.values()
|
||||
|
||||
# We're now about to serialize the events. We should not make any
|
||||
# blocking calls after this. Otherwise the 'age' will be wrong
|
||||
|
||||
@@ -448,9 +454,9 @@ class SearchHandler(BaseHandler):
|
||||
|
||||
if state_results:
|
||||
s = {}
|
||||
for room_id, state in state_results.items():
|
||||
for room_id, state_events in state_results.items():
|
||||
s[room_id] = await self._event_serializer.serialize_events(
|
||||
state, time_now
|
||||
state_events, time_now
|
||||
)
|
||||
|
||||
rooms_cat_res["state"] = s
|
||||
|
||||
@@ -14,24 +14,26 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.types import Requester
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SetPasswordHandler(BaseHandler):
|
||||
"""Handler which deals with changing user account passwords"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
self._auth_handler = hs.get_auth_handler()
|
||||
self._device_handler = hs.get_device_handler()
|
||||
self._password_policy_handler = hs.get_password_policy_handler()
|
||||
|
||||
async def set_password(
|
||||
self,
|
||||
@@ -39,7 +41,7 @@ class SetPasswordHandler(BaseHandler):
|
||||
password_hash: str,
|
||||
logout_devices: bool,
|
||||
requester: Optional[Requester] = None,
|
||||
):
|
||||
) -> None:
|
||||
if not self.hs.config.password_localdb_enabled:
|
||||
raise SynapseError(403, "Password change disabled", errcode=Codes.FORBIDDEN)
|
||||
|
||||
|
||||
@@ -14,15 +14,25 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateDeltasHandler:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
async def _get_key_change(self, prev_event_id, event_id, key_name, public_value):
|
||||
async def _get_key_change(
|
||||
self,
|
||||
prev_event_id: Optional[str],
|
||||
event_id: Optional[str],
|
||||
key_name: str,
|
||||
public_value: str,
|
||||
) -> Optional[bool]:
|
||||
"""Given two events check if the `key_name` field in content changed
|
||||
from not matching `public_value` to doing so.
|
||||
|
||||
|
||||
@@ -14,13 +14,19 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import Counter
|
||||
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
from typing_extensions import Counter as CounterType
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.metrics import event_processing_positions
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.types import JsonDict
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.app.homeserver import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -33,7 +39,7 @@ class StatsHandler:
|
||||
Heavily derived from UserDirectoryHandler
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
@@ -46,7 +52,7 @@ class StatsHandler:
|
||||
self.stats_enabled = hs.config.stats_enabled
|
||||
|
||||
# The current position in the current_state_delta stream
|
||||
self.pos = None
|
||||
self.pos = None # type: Optional[int]
|
||||
|
||||
# Guard to ensure we only process deltas one at a time
|
||||
self._is_processing = False
|
||||
@@ -58,7 +64,7 @@ class StatsHandler:
|
||||
# we start populating stats
|
||||
self.clock.call_later(0, self.notify_new_event)
|
||||
|
||||
def notify_new_event(self):
|
||||
def notify_new_event(self) -> None:
|
||||
"""Called when there may be more deltas to process
|
||||
"""
|
||||
if not self.stats_enabled or self._is_processing:
|
||||
@@ -74,7 +80,7 @@ class StatsHandler:
|
||||
|
||||
run_as_background_process("stats.notify_new_event", process)
|
||||
|
||||
async def _unsafe_process(self):
|
||||
async def _unsafe_process(self) -> None:
|
||||
# If self.pos is None then means we haven't fetched it from DB
|
||||
if self.pos is None:
|
||||
self.pos = await self.store.get_stats_positions()
|
||||
@@ -112,10 +118,10 @@ class StatsHandler:
|
||||
)
|
||||
|
||||
for room_id, fields in room_count.items():
|
||||
room_deltas.setdefault(room_id, {}).update(fields)
|
||||
room_deltas.setdefault(room_id, Counter()).update(fields)
|
||||
|
||||
for user_id, fields in user_count.items():
|
||||
user_deltas.setdefault(user_id, {}).update(fields)
|
||||
user_deltas.setdefault(user_id, Counter()).update(fields)
|
||||
|
||||
logger.debug("room_deltas: %s", room_deltas)
|
||||
logger.debug("user_deltas: %s", user_deltas)
|
||||
@@ -133,19 +139,20 @@ class StatsHandler:
|
||||
|
||||
self.pos = max_pos
|
||||
|
||||
async def _handle_deltas(self, deltas):
|
||||
async def _handle_deltas(
|
||||
self, deltas: Iterable[JsonDict]
|
||||
) -> Tuple[Dict[str, CounterType[str]], Dict[str, CounterType[str]]]:
|
||||
"""Called with the state deltas to process
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, Counter], dict[str, counter]]
|
||||
Two dicts: the room deltas and the user deltas,
|
||||
mapping from room/user ID to changes in the various fields.
|
||||
"""
|
||||
|
||||
room_to_stats_deltas = {}
|
||||
user_to_stats_deltas = {}
|
||||
room_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||
user_to_stats_deltas = {} # type: Dict[str, CounterType[str]]
|
||||
|
||||
room_to_state_updates = {}
|
||||
room_to_state_updates = {} # type: Dict[str, Dict[str, Any]]
|
||||
|
||||
for delta in deltas:
|
||||
typ = delta["type"]
|
||||
@@ -175,7 +182,7 @@ class StatsHandler:
|
||||
)
|
||||
continue
|
||||
|
||||
event_content = {}
|
||||
event_content = {} # type: JsonDict
|
||||
|
||||
sender = None
|
||||
if event_id is not None:
|
||||
@@ -263,13 +270,13 @@ class StatsHandler:
|
||||
)
|
||||
|
||||
if has_changed_joinedness:
|
||||
delta = +1 if membership == Membership.JOIN else -1
|
||||
membership_delta = +1 if membership == Membership.JOIN else -1
|
||||
|
||||
user_to_stats_deltas.setdefault(user_id, Counter())[
|
||||
"joined_rooms"
|
||||
] += delta
|
||||
] += membership_delta
|
||||
|
||||
room_stats_delta["local_users_in_room"] += delta
|
||||
room_stats_delta["local_users_in_room"] += membership_delta
|
||||
|
||||
elif typ == EventTypes.Create:
|
||||
room_state["is_federatable"] = (
|
||||
|
||||
@@ -15,13 +15,13 @@
|
||||
import logging
|
||||
import random
|
||||
from collections import namedtuple
|
||||
from typing import TYPE_CHECKING, List, Set, Tuple
|
||||
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
|
||||
|
||||
from synapse.api.errors import AuthError, ShadowBanError, SynapseError
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.replication.tcp.streams import TypingStream
|
||||
from synapse.types import JsonDict, UserID, get_domain_from_id
|
||||
from synapse.types import JsonDict, Requester, UserID, get_domain_from_id
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.wheel_timer import WheelTimer
|
||||
@@ -65,17 +65,17 @@ class FollowerTypingHandler:
|
||||
)
|
||||
|
||||
# map room IDs to serial numbers
|
||||
self._room_serials = {}
|
||||
self._room_serials = {} # type: Dict[str, int]
|
||||
# map room IDs to sets of users currently typing
|
||||
self._room_typing = {}
|
||||
self._room_typing = {} # type: Dict[str, Set[str]]
|
||||
|
||||
self._member_last_federation_poke = {}
|
||||
self._member_last_federation_poke = {} # type: Dict[RoomMember, int]
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
self._latest_room_serial = 0
|
||||
|
||||
self.clock.looping_call(self._handle_timeouts, 5000)
|
||||
|
||||
def _reset(self):
|
||||
def _reset(self) -> None:
|
||||
"""Reset the typing handler's data caches.
|
||||
"""
|
||||
# map room IDs to serial numbers
|
||||
@@ -86,7 +86,7 @@ class FollowerTypingHandler:
|
||||
self._member_last_federation_poke = {}
|
||||
self.wheel_timer = WheelTimer(bucket_size=5000)
|
||||
|
||||
def _handle_timeouts(self):
|
||||
def _handle_timeouts(self) -> None:
|
||||
logger.debug("Checking for typing timeouts")
|
||||
|
||||
now = self.clock.time_msec()
|
||||
@@ -96,7 +96,7 @@ class FollowerTypingHandler:
|
||||
for member in members:
|
||||
self._handle_timeout_for_member(now, member)
|
||||
|
||||
def _handle_timeout_for_member(self, now: int, member: RoomMember):
|
||||
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
|
||||
if not self.is_typing(member):
|
||||
# Nothing to do if they're no longer typing
|
||||
return
|
||||
@@ -114,10 +114,10 @@ class FollowerTypingHandler:
|
||||
# each person typing.
|
||||
self.wheel_timer.insert(now=now, obj=member, then=now + 60 * 1000)
|
||||
|
||||
def is_typing(self, member):
|
||||
def is_typing(self, member: RoomMember) -> bool:
|
||||
return member.user_id in self._room_typing.get(member.room_id, [])
|
||||
|
||||
async def _push_remote(self, member, typing):
|
||||
async def _push_remote(self, member: RoomMember, typing: bool) -> None:
|
||||
if not self.federation:
|
||||
return
|
||||
|
||||
@@ -148,7 +148,7 @@ class FollowerTypingHandler:
|
||||
|
||||
def process_replication_rows(
|
||||
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
||||
):
|
||||
) -> None:
|
||||
"""Should be called whenever we receive updates for typing stream.
|
||||
"""
|
||||
|
||||
@@ -178,7 +178,7 @@ class FollowerTypingHandler:
|
||||
|
||||
async def _send_changes_in_typing_to_remotes(
|
||||
self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
|
||||
):
|
||||
) -> None:
|
||||
"""Process a change in typing of a room from replication, sending EDUs
|
||||
for any local users.
|
||||
"""
|
||||
@@ -194,12 +194,12 @@ class FollowerTypingHandler:
|
||||
if self.is_mine_id(user_id):
|
||||
await self._push_remote(RoomMember(room_id, user_id), False)
|
||||
|
||||
def get_current_token(self):
|
||||
def get_current_token(self) -> int:
|
||||
return self._latest_room_serial
|
||||
|
||||
|
||||
class TypingWriterHandler(FollowerTypingHandler):
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
super().__init__(hs)
|
||||
|
||||
assert hs.config.worker.writers.typing == hs.get_instance_name()
|
||||
@@ -213,14 +213,15 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
hs.get_distributor().observe("user_left_room", self.user_left_room)
|
||||
|
||||
self._member_typing_until = {} # clock time we expect to stop
|
||||
# clock time we expect to stop
|
||||
self._member_typing_until = {} # type: Dict[RoomMember, int]
|
||||
|
||||
# caches which room_ids changed at which serials
|
||||
self._typing_stream_change_cache = StreamChangeCache(
|
||||
"TypingStreamChangeCache", self._latest_room_serial
|
||||
)
|
||||
|
||||
def _handle_timeout_for_member(self, now: int, member: RoomMember):
|
||||
def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None:
|
||||
super()._handle_timeout_for_member(now, member)
|
||||
|
||||
if not self.is_typing(member):
|
||||
@@ -233,7 +234,9 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
self._stopped_typing(member)
|
||||
return
|
||||
|
||||
async def started_typing(self, target_user, requester, room_id, timeout):
|
||||
async def started_typing(
|
||||
self, target_user: UserID, requester: Requester, room_id: str, timeout: int
|
||||
) -> None:
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = requester.user.to_string()
|
||||
|
||||
@@ -263,11 +266,13 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
if was_present:
|
||||
# No point sending another notification
|
||||
return None
|
||||
return
|
||||
|
||||
self._push_update(member=member, typing=True)
|
||||
|
||||
async def stopped_typing(self, target_user, requester, room_id):
|
||||
async def stopped_typing(
|
||||
self, target_user: UserID, requester: Requester, room_id: str
|
||||
) -> None:
|
||||
target_user_id = target_user.to_string()
|
||||
auth_user_id = requester.user.to_string()
|
||||
|
||||
@@ -290,23 +295,23 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
self._stopped_typing(member)
|
||||
|
||||
def user_left_room(self, user, room_id):
|
||||
def user_left_room(self, user: UserID, room_id: str) -> None:
|
||||
user_id = user.to_string()
|
||||
if self.is_mine_id(user_id):
|
||||
member = RoomMember(room_id=room_id, user_id=user_id)
|
||||
self._stopped_typing(member)
|
||||
|
||||
def _stopped_typing(self, member):
|
||||
def _stopped_typing(self, member: RoomMember) -> None:
|
||||
if member.user_id not in self._room_typing.get(member.room_id, set()):
|
||||
# No point
|
||||
return None
|
||||
return
|
||||
|
||||
self._member_typing_until.pop(member, None)
|
||||
self._member_last_federation_poke.pop(member, None)
|
||||
|
||||
self._push_update(member=member, typing=False)
|
||||
|
||||
def _push_update(self, member, typing):
|
||||
def _push_update(self, member: RoomMember, typing: bool) -> None:
|
||||
if self.hs.is_mine_id(member.user_id):
|
||||
# Only send updates for changes to our own users.
|
||||
run_as_background_process(
|
||||
@@ -315,7 +320,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
self._push_update_local(member=member, typing=typing)
|
||||
|
||||
async def _recv_edu(self, origin, content):
|
||||
async def _recv_edu(self, origin: str, content: JsonDict) -> None:
|
||||
room_id = content["room_id"]
|
||||
user_id = content["user_id"]
|
||||
|
||||
@@ -340,7 +345,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
self.wheel_timer.insert(now=now, obj=member, then=now + FEDERATION_TIMEOUT)
|
||||
self._push_update_local(member=member, typing=content["typing"])
|
||||
|
||||
def _push_update_local(self, member, typing):
|
||||
def _push_update_local(self, member: RoomMember, typing: bool) -> None:
|
||||
room_set = self._room_typing.setdefault(member.room_id, set())
|
||||
if typing:
|
||||
room_set.add(member.user_id)
|
||||
@@ -386,7 +391,7 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
changed_rooms = self._typing_stream_change_cache.get_all_entities_changed(
|
||||
last_id
|
||||
)
|
||||
) # type: Optional[Iterable[str]]
|
||||
|
||||
if changed_rooms is None:
|
||||
changed_rooms = self._room_serials
|
||||
@@ -412,13 +417,13 @@ class TypingWriterHandler(FollowerTypingHandler):
|
||||
|
||||
def process_replication_rows(
|
||||
self, token: int, rows: List[TypingStream.TypingStreamRow]
|
||||
):
|
||||
) -> None:
|
||||
# The writing process should never get updates from replication.
|
||||
raise Exception("Typing writer instance got typing info over replication")
|
||||
|
||||
|
||||
class TypingNotificationEventSource:
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
# We can't call get_typing_handler here because there's a cycle:
|
||||
@@ -427,7 +432,7 @@ class TypingNotificationEventSource:
|
||||
#
|
||||
self.get_typing_handler = hs.get_typing_handler
|
||||
|
||||
def _make_event_for(self, room_id):
|
||||
def _make_event_for(self, room_id: str) -> JsonDict:
|
||||
typing = self.get_typing_handler()._room_typing[room_id]
|
||||
return {
|
||||
"type": "m.typing",
|
||||
@@ -462,7 +467,9 @@ class TypingNotificationEventSource:
|
||||
|
||||
return (events, handler._latest_room_serial)
|
||||
|
||||
async def get_new_events(self, from_key, room_ids, **kwargs):
|
||||
async def get_new_events(
|
||||
self, from_key: int, room_ids: Iterable[str], **kwargs
|
||||
) -> Tuple[List[JsonDict], int]:
|
||||
with Measure(self.clock, "typing.get_new_events"):
|
||||
from_key = int(from_key)
|
||||
handler = self.get_typing_handler()
|
||||
@@ -478,5 +485,5 @@ class TypingNotificationEventSource:
|
||||
|
||||
return (events, handler._latest_room_serial)
|
||||
|
||||
def get_current_key(self):
|
||||
def get_current_key(self) -> int:
|
||||
return self.get_typing_handler()._latest_room_serial
|
||||
|
||||
@@ -145,10 +145,6 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
if self.pos is None:
|
||||
self.pos = await self.store.get_user_directory_stream_pos()
|
||||
|
||||
# If still None then the initial background update hasn't happened yet
|
||||
if self.pos is None:
|
||||
return None
|
||||
|
||||
# Loop round handling deltas until we're up to date
|
||||
while True:
|
||||
with Measure(self.clock, "user_dir_delta"):
|
||||
@@ -233,6 +229,11 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
|
||||
if change: # The user joined
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
# It isn't expected for this event to not exist, but we
|
||||
# don't want the entire background process to break.
|
||||
if event is None:
|
||||
continue
|
||||
|
||||
profile = ProfileInfo(
|
||||
avatar_url=event.content.get("avatar_url"),
|
||||
display_name=event.content.get("displayname"),
|
||||
|
||||
105
synapse/replication/tcp/external_cache.py
Normal file
105
synapse/replication/tcp/external_cache.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
from prometheus_client import Counter
|
||||
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.util import json_decoder, json_encoder
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
set_counter = Counter(
|
||||
"synapse_external_cache_set",
|
||||
"Number of times we set a cache",
|
||||
labelnames=["cache_name"],
|
||||
)
|
||||
|
||||
get_counter = Counter(
|
||||
"synapse_external_cache_get",
|
||||
"Number of times we get a cache",
|
||||
labelnames=["cache_name", "hit"],
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExternalCache:
|
||||
"""A cache backed by an external Redis. Does nothing if no Redis is
|
||||
configured.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._redis_connection = hs.get_outbound_redis_connection()
|
||||
|
||||
def _get_redis_key(self, cache_name: str, key: str) -> str:
|
||||
return "cache_v1:%s:%s" % (cache_name, key)
|
||||
|
||||
def is_enabled(self) -> bool:
|
||||
"""Whether the external cache is used or not.
|
||||
|
||||
It's safe to use the cache when this returns false, the methods will
|
||||
just no-op, but the function is useful to avoid doing unnecessary work.
|
||||
"""
|
||||
return self._redis_connection is not None
|
||||
|
||||
async def set(self, cache_name: str, key: str, value: Any, expiry_ms: int) -> None:
|
||||
"""Add the key/value to the named cache, with the expiry time given.
|
||||
"""
|
||||
|
||||
if self._redis_connection is None:
|
||||
return
|
||||
|
||||
set_counter.labels(cache_name).inc()
|
||||
|
||||
# txredisapi requires the value to be string, bytes or numbers, so we
|
||||
# encode stuff in JSON.
|
||||
encoded_value = json_encoder.encode(value)
|
||||
|
||||
logger.debug("Caching %s %s: %r", cache_name, key, encoded_value)
|
||||
|
||||
return await make_deferred_yieldable(
|
||||
self._redis_connection.set(
|
||||
self._get_redis_key(cache_name, key), encoded_value, pexpire=expiry_ms,
|
||||
)
|
||||
)
|
||||
|
||||
async def get(self, cache_name: str, key: str) -> Optional[Any]:
|
||||
"""Look up a key/value in the named cache.
|
||||
"""
|
||||
|
||||
if self._redis_connection is None:
|
||||
return None
|
||||
|
||||
result = await make_deferred_yieldable(
|
||||
self._redis_connection.get(self._get_redis_key(cache_name, key))
|
||||
)
|
||||
|
||||
logger.debug("Got cache result %s %s: %r", cache_name, key, result)
|
||||
|
||||
get_counter.labels(cache_name, result is not None).inc()
|
||||
|
||||
if not result:
|
||||
return None
|
||||
|
||||
# For some reason the integers get magically converted back to integers
|
||||
if isinstance(result, int):
|
||||
return result
|
||||
|
||||
return json_decoder.decode(result)
|
||||
@@ -15,6 +15,7 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
@@ -63,6 +64,9 @@ from synapse.replication.tcp.streams import (
|
||||
TypingStream,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.server import HomeServer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -88,7 +92,7 @@ class ReplicationCommandHandler:
|
||||
back out to connections.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._replication_data_handler = hs.get_replication_data_handler()
|
||||
self._presence_handler = hs.get_presence_handler()
|
||||
self._store = hs.get_datastore()
|
||||
@@ -282,13 +286,6 @@ class ReplicationCommandHandler:
|
||||
if hs.config.redis.redis_enabled:
|
||||
from synapse.replication.tcp.redis import (
|
||||
RedisDirectTcpReplicationClientFactory,
|
||||
lazyConnection,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Connecting to redis (host=%r port=%r)",
|
||||
hs.config.redis_host,
|
||||
hs.config.redis_port,
|
||||
)
|
||||
|
||||
# First let's ensure that we have a ReplicationStreamer started.
|
||||
@@ -299,13 +296,7 @@ class ReplicationCommandHandler:
|
||||
# connection after SUBSCRIBE is called).
|
||||
|
||||
# First create the connection for sending commands.
|
||||
outbound_redis_connection = lazyConnection(
|
||||
reactor=hs.get_reactor(),
|
||||
host=hs.config.redis_host,
|
||||
port=hs.config.redis_port,
|
||||
password=hs.config.redis.redis_password,
|
||||
reconnect=True,
|
||||
)
|
||||
outbound_redis_connection = hs.get_outbound_redis_connection()
|
||||
|
||||
# Now create the factory/connection for the subscription stream.
|
||||
self._factory = RedisDirectTcpReplicationClientFactory(
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
|
||||
import logging
|
||||
from inspect import isawaitable
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
from typing import TYPE_CHECKING, Optional, Type, cast
|
||||
|
||||
import txredisapi
|
||||
|
||||
@@ -23,6 +23,7 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yielda
|
||||
from synapse.metrics.background_process_metrics import (
|
||||
BackgroundProcessLoggingContext,
|
||||
run_as_background_process,
|
||||
wrap_as_background_process,
|
||||
)
|
||||
from synapse.replication.tcp.commands import (
|
||||
Command,
|
||||
@@ -59,16 +60,16 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
immediately after initialisation.
|
||||
|
||||
Attributes:
|
||||
handler: The command handler to handle incoming commands.
|
||||
stream_name: The *redis* stream name to subscribe to and publish from
|
||||
(not anything to do with Synapse replication streams).
|
||||
outbound_redis_connection: The connection to redis to use to send
|
||||
synapse_handler: The command handler to handle incoming commands.
|
||||
synapse_stream_name: The *redis* stream name to subscribe to and publish
|
||||
from (not anything to do with Synapse replication streams).
|
||||
synapse_outbound_redis_connection: The connection to redis to use to send
|
||||
commands.
|
||||
"""
|
||||
|
||||
handler = None # type: ReplicationCommandHandler
|
||||
stream_name = None # type: str
|
||||
outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
synapse_handler = None # type: ReplicationCommandHandler
|
||||
synapse_stream_name = None # type: str
|
||||
synapse_outbound_redis_connection = None # type: txredisapi.RedisProtocol
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -88,19 +89,19 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
# it's important to make sure that we only send the REPLICATE command once we
|
||||
# have successfully subscribed to the stream - otherwise we might miss the
|
||||
# POSITION response sent back by the other end.
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.stream_name)
|
||||
await make_deferred_yieldable(self.subscribe(self.stream_name))
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
|
||||
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
||||
logger.info(
|
||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||
)
|
||||
self.handler.new_connection(self)
|
||||
self.synapse_handler.new_connection(self)
|
||||
await self._async_send_command(ReplicateCommand())
|
||||
logger.info("REPLICATE successfully sent")
|
||||
|
||||
# We send out our positions when there is a new connection in case the
|
||||
# other side missed updates. We do this for Redis connections as the
|
||||
# otherside won't know we've connected and so won't issue a REPLICATE.
|
||||
self.handler.send_positions_to_connection(self)
|
||||
self.synapse_handler.send_positions_to_connection(self)
|
||||
|
||||
def messageReceived(self, pattern: str, channel: str, message: str):
|
||||
"""Received a message from redis.
|
||||
@@ -137,7 +138,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
cmd: received command
|
||||
"""
|
||||
|
||||
cmd_func = getattr(self.handler, "on_%s" % (cmd.NAME,), None)
|
||||
cmd_func = getattr(self.synapse_handler, "on_%s" % (cmd.NAME,), None)
|
||||
if not cmd_func:
|
||||
logger.warning("Unhandled command: %r", cmd)
|
||||
return
|
||||
@@ -155,7 +156,7 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
def connectionLost(self, reason):
|
||||
logger.info("Lost connection to redis")
|
||||
super().connectionLost(reason)
|
||||
self.handler.lost_connection(self)
|
||||
self.synapse_handler.lost_connection(self)
|
||||
|
||||
# mark the logging context as finished
|
||||
self._logging_context.__exit__(None, None, None)
|
||||
@@ -183,11 +184,54 @@ class RedisSubscriber(txredisapi.SubscriberProtocol, AbstractConnection):
|
||||
tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
|
||||
|
||||
await make_deferred_yieldable(
|
||||
self.outbound_redis_connection.publish(self.stream_name, encoded_string)
|
||||
self.synapse_outbound_redis_connection.publish(
|
||||
self.synapse_stream_name, encoded_string
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
"""A subclass of RedisFactory that periodically sends pings to ensure that
|
||||
we detect dead connections.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hs: "HomeServer",
|
||||
uuid: str,
|
||||
dbid: Optional[int],
|
||||
poolsize: int,
|
||||
isLazy: bool = False,
|
||||
handler: Type = txredisapi.ConnectionHandler,
|
||||
charset: str = "utf-8",
|
||||
password: Optional[str] = None,
|
||||
replyTimeout: int = 30,
|
||||
convertNumbers: Optional[int] = True,
|
||||
):
|
||||
super().__init__(
|
||||
uuid=uuid,
|
||||
dbid=dbid,
|
||||
poolsize=poolsize,
|
||||
isLazy=isLazy,
|
||||
handler=handler,
|
||||
charset=charset,
|
||||
password=password,
|
||||
replyTimeout=replyTimeout,
|
||||
convertNumbers=convertNumbers,
|
||||
)
|
||||
|
||||
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
||||
|
||||
@wrap_as_background_process("redis_ping")
|
||||
async def _send_ping(self):
|
||||
for connection in self.pool:
|
||||
try:
|
||||
await make_deferred_yieldable(connection.ping())
|
||||
except Exception:
|
||||
logger.warning("Failed to send ping to a redis connection")
|
||||
|
||||
|
||||
class RedisDirectTcpReplicationClientFactory(SynapseRedisFactory):
|
||||
"""This is a reconnecting factory that connects to redis and immediately
|
||||
subscribes to a stream.
|
||||
|
||||
@@ -206,65 +250,62 @@ class RedisDirectTcpReplicationClientFactory(txredisapi.SubscriberFactory):
|
||||
self, hs: "HomeServer", outbound_redis_connection: txredisapi.RedisProtocol
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
super().__init__(
|
||||
hs,
|
||||
uuid="subscriber",
|
||||
dbid=None,
|
||||
poolsize=1,
|
||||
replyTimeout=30,
|
||||
password=hs.config.redis.redis_password,
|
||||
)
|
||||
|
||||
# This sets the password on the RedisFactory base class (as
|
||||
# SubscriberFactory constructor doesn't pass it through).
|
||||
self.password = hs.config.redis.redis_password
|
||||
self.synapse_handler = hs.get_tcp_replication()
|
||||
self.synapse_stream_name = hs.hostname
|
||||
|
||||
self.handler = hs.get_tcp_replication()
|
||||
self.stream_name = hs.hostname
|
||||
|
||||
self.outbound_redis_connection = outbound_redis_connection
|
||||
self.synapse_outbound_redis_connection = outbound_redis_connection
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = super().buildProtocol(addr) # type: RedisSubscriber
|
||||
p = super().buildProtocol(addr)
|
||||
p = cast(RedisSubscriber, p)
|
||||
|
||||
# We do this here rather than add to the constructor of `RedisSubcriber`
|
||||
# as to do so would involve overriding `buildProtocol` entirely, however
|
||||
# the base method does some other things than just instantiating the
|
||||
# protocol.
|
||||
p.handler = self.handler
|
||||
p.outbound_redis_connection = self.outbound_redis_connection
|
||||
p.stream_name = self.stream_name
|
||||
p.password = self.password
|
||||
p.synapse_handler = self.synapse_handler
|
||||
p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection
|
||||
p.synapse_stream_name = self.synapse_stream_name
|
||||
|
||||
return p
|
||||
|
||||
|
||||
def lazyConnection(
|
||||
reactor,
|
||||
hs: "HomeServer",
|
||||
host: str = "localhost",
|
||||
port: int = 6379,
|
||||
dbid: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
charset: str = "utf-8",
|
||||
password: Optional[str] = None,
|
||||
connectTimeout: Optional[int] = None,
|
||||
replyTimeout: Optional[int] = None,
|
||||
convertNumbers: bool = True,
|
||||
replyTimeout: int = 30,
|
||||
) -> txredisapi.RedisProtocol:
|
||||
"""Equivalent to `txredisapi.lazyConnection`, except allows specifying a
|
||||
reactor.
|
||||
"""Creates a connection to Redis that is lazily set up and reconnects if the
|
||||
connections is lost.
|
||||
"""
|
||||
|
||||
isLazy = True
|
||||
poolsize = 1
|
||||
|
||||
uuid = "%s:%d" % (host, port)
|
||||
factory = txredisapi.RedisFactory(
|
||||
uuid,
|
||||
dbid,
|
||||
poolsize,
|
||||
isLazy,
|
||||
txredisapi.ConnectionHandler,
|
||||
charset,
|
||||
password,
|
||||
replyTimeout,
|
||||
convertNumbers,
|
||||
factory = SynapseRedisFactory(
|
||||
hs,
|
||||
uuid=uuid,
|
||||
dbid=dbid,
|
||||
poolsize=1,
|
||||
isLazy=True,
|
||||
handler=txredisapi.ConnectionHandler,
|
||||
password=password,
|
||||
replyTimeout=replyTimeout,
|
||||
)
|
||||
factory.continueTrying = reconnect
|
||||
for x in range(poolsize):
|
||||
reactor.connectTCP(host, port, factory, connectTimeout)
|
||||
|
||||
reactor = hs.get_reactor()
|
||||
reactor.connectTCP(host, port, factory, 30)
|
||||
|
||||
return factory.handler
|
||||
|
||||
@@ -385,7 +385,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
"""
|
||||
Check whether the URL should be downloaded as oEmbed content instead.
|
||||
|
||||
Params:
|
||||
Args:
|
||||
url: The URL to check.
|
||||
|
||||
Returns:
|
||||
@@ -402,7 +402,7 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
"""
|
||||
Request content from an oEmbed endpoint.
|
||||
|
||||
Params:
|
||||
Args:
|
||||
endpoint: The oEmbed API endpoint.
|
||||
url: The URL to pass to the API.
|
||||
|
||||
@@ -691,27 +691,51 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
def decode_and_calc_og(
|
||||
body: bytes, media_uri: str, request_encoding: Optional[str] = None
|
||||
) -> Dict[str, Optional[str]]:
|
||||
"""
|
||||
Calculate metadata for an HTML document.
|
||||
|
||||
This uses lxml to parse the HTML document into the OG response. If errors
|
||||
occur during processing of the document, an empty response is returned.
|
||||
|
||||
Args:
|
||||
body: The HTML document, as bytes.
|
||||
media_url: The URI used to download the body.
|
||||
request_encoding: The character encoding of the body, as a string.
|
||||
|
||||
Returns:
|
||||
The OG response as a dictionary.
|
||||
"""
|
||||
# If there's no body, nothing useful is going to be found.
|
||||
if not body:
|
||||
return {}
|
||||
|
||||
from lxml import etree
|
||||
|
||||
# Create an HTML parser. If this fails, log and return no metadata.
|
||||
try:
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body, parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
except LookupError:
|
||||
# blindly consider the encoding as utf-8.
|
||||
parser = etree.HTMLParser(recover=True, encoding="utf-8")
|
||||
except Exception as e:
|
||||
logger.warning("Unable to create HTML parser: %s" % (e,))
|
||||
return {}
|
||||
|
||||
def _attempt_calc_og(body_attempt: Union[bytes, str]) -> Dict[str, Optional[str]]:
|
||||
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||
tree = etree.fromstring(body_attempt, parser)
|
||||
return _calc_og(tree, media_uri)
|
||||
|
||||
# Attempt to parse the body. If this fails, log and return no metadata.
|
||||
try:
|
||||
return _attempt_calc_og(body)
|
||||
except UnicodeDecodeError:
|
||||
# blindly try decoding the body as utf-8, which seems to fix
|
||||
# the charset mismatches on https://google.com
|
||||
parser = etree.HTMLParser(recover=True, encoding=request_encoding)
|
||||
tree = etree.fromstring(body.decode("utf-8", "ignore"), parser)
|
||||
og = _calc_og(tree, media_uri)
|
||||
|
||||
return og
|
||||
return _attempt_calc_og(body.decode("utf-8", "ignore"))
|
||||
|
||||
|
||||
def _calc_og(tree, media_uri: str) -> Dict[str, Optional[str]]:
|
||||
def _calc_og(tree: "etree.Element", media_uri: str) -> Dict[str, Optional[str]]:
|
||||
# suck our tree into lxml and define our OG response.
|
||||
|
||||
# if we see any image URLs in the OG response, then spider them
|
||||
|
||||
@@ -102,6 +102,7 @@ from synapse.notifier import Notifier
|
||||
from synapse.push.action_generator import ActionGenerator
|
||||
from synapse.push.pusherpool import PusherPool
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.external_cache import ExternalCache
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.resource import ReplicationStreamer
|
||||
from synapse.replication.tcp.streams import STREAMS_MAP, Stream
|
||||
@@ -127,6 +128,8 @@ from synapse.util.stringutils import random_string
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from txredisapi import RedisProtocol
|
||||
|
||||
from synapse.handlers.oidc_handler import OidcHandler
|
||||
from synapse.handlers.saml_handler import SamlHandler
|
||||
|
||||
@@ -710,6 +713,33 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
def get_account_data_handler(self) -> AccountDataHandler:
|
||||
return AccountDataHandler(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_external_cache(self) -> ExternalCache:
|
||||
return ExternalCache(self)
|
||||
|
||||
@cache_in_self
|
||||
def get_outbound_redis_connection(self) -> Optional["RedisProtocol"]:
|
||||
if not self.config.redis.redis_enabled:
|
||||
return None
|
||||
|
||||
# We only want to import redis module if we're using it, as we have
|
||||
# `txredisapi` as an optional dependency.
|
||||
from synapse.replication.tcp.redis import lazyConnection
|
||||
|
||||
logger.info(
|
||||
"Connecting to redis (host=%r port=%r) for external cache",
|
||||
self.config.redis_host,
|
||||
self.config.redis_port,
|
||||
)
|
||||
|
||||
return lazyConnection(
|
||||
hs=self,
|
||||
host=self.config.redis_host,
|
||||
port=self.config.redis_port,
|
||||
password=self.config.redis.redis_password,
|
||||
reconnect=True,
|
||||
)
|
||||
|
||||
async def remove_pusher(self, app_id: str, push_key: str, user_id: str):
|
||||
return await self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
||||
@@ -310,6 +310,7 @@ class StateHandler:
|
||||
state_group_before_event = None
|
||||
state_group_before_event_prev_group = None
|
||||
deltas_to_state_group_before_event = None
|
||||
entry = None
|
||||
|
||||
else:
|
||||
# otherwise, we'll need to resolve the state across the prev_events.
|
||||
@@ -340,9 +341,13 @@ class StateHandler:
|
||||
current_state_ids=state_ids_before_event,
|
||||
)
|
||||
|
||||
# XXX: can we update the state cache entry for the new state group? or
|
||||
# could we set a flag on resolve_state_groups_for_events to tell it to
|
||||
# always make a state group?
|
||||
# Assign the new state group to the cached state entry.
|
||||
#
|
||||
# Note that this can race in that we could generate multiple state
|
||||
# groups for the same state entry, but that is just inefficient
|
||||
# rather than dangerous.
|
||||
if entry and entry.state_group is None:
|
||||
entry.state_group = state_group_before_event
|
||||
|
||||
#
|
||||
# now if it's not a state event, we're done
|
||||
|
||||
@@ -89,7 +89,7 @@ class EventForwardExtremitiesStore(SQLBaseStore):
|
||||
SELECT event_id, state_group, depth, received_ts
|
||||
FROM event_forward_extremities
|
||||
INNER JOIN event_to_state_groups USING (event_id)
|
||||
INNER JOIN events INNER JOIN USING (event_id)
|
||||
INNER JOIN events USING (room_id, event_id)
|
||||
WHERE room_id = ?
|
||||
"""
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_cla
|
||||
from synapse.storage.database import DatabasePool
|
||||
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
|
||||
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
|
||||
from synapse.types import Collection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -460,7 +461,7 @@ class SearchStore(SearchBackgroundUpdateStore):
|
||||
|
||||
async def search_rooms(
|
||||
self,
|
||||
room_ids: List[str],
|
||||
room_ids: Collection[str],
|
||||
search_term: str,
|
||||
keys: List[str],
|
||||
limit,
|
||||
|
||||
@@ -15,11 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from collections import Counter
|
||||
from enum import Enum
|
||||
from itertools import chain
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from typing_extensions import Counter
|
||||
|
||||
from twisted.internet.defer import DeferredLock
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
@@ -320,7 +321,9 @@ class StatsStore(StateDeltasStore):
|
||||
return slice_list
|
||||
|
||||
@cached()
|
||||
async def get_earliest_token_for_stats(self, stats_type: str, id: str) -> int:
|
||||
async def get_earliest_token_for_stats(
|
||||
self, stats_type: str, id: str
|
||||
) -> Optional[int]:
|
||||
"""
|
||||
Fetch the "earliest token". This is used by the room stats delta
|
||||
processor to ignore deltas that have been processed between the
|
||||
@@ -340,7 +343,7 @@ class StatsStore(StateDeltasStore):
|
||||
)
|
||||
|
||||
async def bulk_update_stats_delta(
|
||||
self, ts: int, updates: Dict[str, Dict[str, Dict[str, Counter]]], stream_id: int
|
||||
self, ts: int, updates: Dict[str, Dict[str, Counter[str]]], stream_id: int
|
||||
) -> None:
|
||||
"""Bulk update stats tables for a given stream_id and updates the stats
|
||||
incremental position.
|
||||
@@ -666,7 +669,7 @@ class StatsStore(StateDeltasStore):
|
||||
|
||||
async def get_changes_room_total_events_and_bytes(
|
||||
self, min_pos: int, max_pos: int
|
||||
) -> Dict[str, Dict[str, int]]:
|
||||
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
|
||||
"""Fetches the counts of events in the given range of stream IDs.
|
||||
|
||||
Args:
|
||||
@@ -684,18 +687,19 @@ class StatsStore(StateDeltasStore):
|
||||
max_pos,
|
||||
)
|
||||
|
||||
def get_changes_room_total_events_and_bytes_txn(self, txn, low_pos, high_pos):
|
||||
def get_changes_room_total_events_and_bytes_txn(
|
||||
self, txn, low_pos: int, high_pos: int
|
||||
) -> Tuple[Dict[str, Dict[str, int]], Dict[str, Dict[str, int]]]:
|
||||
"""Gets the total_events and total_event_bytes counts for rooms and
|
||||
senders, in a range of stream_orderings (including backfilled events).
|
||||
|
||||
Args:
|
||||
txn
|
||||
low_pos (int): Low stream ordering
|
||||
high_pos (int): High stream ordering
|
||||
low_pos: Low stream ordering
|
||||
high_pos: High stream ordering
|
||||
|
||||
Returns:
|
||||
tuple[dict[str, dict[str, int]], dict[str, dict[str, int]]]: The
|
||||
room and user deltas for total_events/total_event_bytes in the
|
||||
The room and user deltas for total_events/total_event_bytes in the
|
||||
format of `stats_id` -> fields
|
||||
"""
|
||||
|
||||
|
||||
@@ -540,7 +540,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
desc="get_user_in_directory",
|
||||
)
|
||||
|
||||
async def update_user_directory_stream_pos(self, stream_id: str) -> None:
|
||||
async def update_user_directory_stream_pos(self, stream_id: int) -> None:
|
||||
await self.db_pool.simple_update_one(
|
||||
table="user_directory_stream_pos",
|
||||
keyvalues={},
|
||||
|
||||
@@ -212,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
# Fake in memory Redis server that servers can connect to.
|
||||
self._redis_server = FakeRedisPubSubServer()
|
||||
|
||||
# We may have an attempt to connect to redis for the external cache already.
|
||||
self.connect_any_redis_attempts()
|
||||
|
||||
store = self.hs.get_datastore()
|
||||
self.database_pool = store.db_pool
|
||||
|
||||
@@ -401,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
fake one.
|
||||
"""
|
||||
clients = self.reactor.tcpClients
|
||||
self.assertEqual(len(clients), 1)
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
while clients:
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
server_protocol = self._redis_server.buildProtocol(None)
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
server_protocol = self._redis_server.buildProtocol(None)
|
||||
|
||||
client_to_server_transport = FakeTransport(
|
||||
server_protocol, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
client_to_server_transport = FakeTransport(
|
||||
server_protocol, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, server_protocol
|
||||
)
|
||||
server_protocol.makeConnection(server_to_client_transport)
|
||||
|
||||
return client_to_server_transport, server_to_client_transport
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, server_protocol
|
||||
)
|
||||
server_protocol.makeConnection(server_to_client_transport)
|
||||
|
||||
|
||||
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
|
||||
@@ -624,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||
(channel,) = args
|
||||
self._server.add_subscriber(self)
|
||||
self.send(["subscribe", channel, 1])
|
||||
|
||||
# Since we use SET/GET to cache things we can safely no-op them.
|
||||
elif command == b"SET":
|
||||
self.send("OK")
|
||||
elif command == b"GET":
|
||||
self.send(None)
|
||||
else:
|
||||
raise Exception("Unknown command")
|
||||
|
||||
@@ -645,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||
# We assume bytes are just unicode strings.
|
||||
obj = obj.decode("utf-8")
|
||||
|
||||
if obj is None:
|
||||
return "$-1\r\n"
|
||||
if isinstance(obj, str):
|
||||
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
|
||||
if isinstance(obj, int):
|
||||
|
||||
@@ -261,3 +261,32 @@ class PreviewUrlTestCase(unittest.TestCase):
|
||||
html = ""
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {})
|
||||
|
||||
def test_invalid_encoding(self):
|
||||
"""An invalid character encoding should be ignored and treated as UTF-8, if possible."""
|
||||
html = """
|
||||
<html>
|
||||
<head><title>Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
og = decode_and_calc_og(
|
||||
html, "http://example.com/test.html", "invalid-encoding"
|
||||
)
|
||||
self.assertEqual(og, {"og:title": "Foo", "og:description": "Some text."})
|
||||
|
||||
def test_invalid_encoding2(self):
|
||||
"""A body which doesn't match the sent character encoding."""
|
||||
# Note that this contains an invalid UTF-8 sequence in the title.
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title>\xff\xff Foo</title></head>
|
||||
<body>
|
||||
Some text.
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
og = decode_and_calc_og(html, "http://example.com/test.html")
|
||||
self.assertEqual(og, {"og:title": "ÿÿ Foo", "og:description": "Some text."})
|
||||
|
||||
Reference in New Issue
Block a user