Compare commits
2 Commits
clokep/tes
...
anoa/docs_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3360be1829 | ||
|
|
19ca533bcc |
@@ -33,7 +33,7 @@ site-url = "/synapse/"
|
||||
additional-css = [
|
||||
"docs/website_files/table-of-contents.css",
|
||||
"docs/website_files/remove-nav-buttons.css",
|
||||
"docs/website_files/indent-section-headers.css",
|
||||
"docs/website_files/section-headers.css",
|
||||
]
|
||||
additional-js = ["docs/website_files/table-of-contents.js"]
|
||||
theme = "docs/website_files/theme"
|
||||
@@ -1 +1 @@
|
||||
Add type hints to tests files.
|
||||
Add type hints to `tests/rest/client`.
|
||||
|
||||
@@ -1 +1 @@
|
||||
Add type hints to tests files.
|
||||
Add type hints to `tests/rest`.
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker.
|
||||
@@ -1 +0,0 @@
|
||||
Avoid trying to calculate the state at outlier events.
|
||||
@@ -1 +0,0 @@
|
||||
Fix a misleading comment in the function `check_event_for_spam`.
|
||||
@@ -1 +0,0 @@
|
||||
Document that contributors can sign off privately by email.
|
||||
@@ -1 +0,0 @@
|
||||
Remove unnecessary `pass` statements.
|
||||
@@ -1 +0,0 @@
|
||||
Add type hints to tests files.
|
||||
@@ -1 +0,0 @@
|
||||
Add type hints to tests files.
|
||||
@@ -1 +0,0 @@
|
||||
Update the SSO username picker template to comply with SIWA guidelines.
|
||||
@@ -458,17 +458,6 @@ Git allows you to add this signoff automatically when using the `-s`
|
||||
flag to `git commit`, which uses the name and email set in your
|
||||
`user.name` and `user.email` git configs.
|
||||
|
||||
### Private Sign off
|
||||
|
||||
If you would like to provide your legal name privately to the Matrix.org
|
||||
Foundation (instead of in a public commit or comment), you can do so
|
||||
by emailing your legal name and a link to the pull request to
|
||||
[dco@matrix.org](mailto:dco@matrix.org?subject=Private%20sign%20off).
|
||||
It helps to include "sign off" or similar in the subject line. You will then
|
||||
be instructed further.
|
||||
|
||||
Once private sign off is complete, doing so for future contributions will not
|
||||
be required.
|
||||
|
||||
# 10. Turn feedback into better code.
|
||||
|
||||
|
||||
@@ -1947,13 +1947,8 @@ saml2_config:
|
||||
#
|
||||
# localpart_template: Jinja2 template for the localpart of the MXID.
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username (see the documentation for the
|
||||
# 'sso_auth_account_details.html' template).
|
||||
#
|
||||
# confirm_localpart: Whether to prompt the user to validate (or
|
||||
# change) the generated localpart (see the documentation for the
|
||||
# 'sso_auth_account_details.html' template), instead of
|
||||
# registering the account right away.
|
||||
# own username (see 'sso_auth_account_details.html' in the 'sso'
|
||||
# section of this file).
|
||||
#
|
||||
# display_name_template: Jinja2 template for the display name to set
|
||||
# on first login. If unset, no displayname will be set.
|
||||
|
||||
@@ -176,11 +176,8 @@ Below are the templates Synapse will look for when generating pages related to S
|
||||
for the brand of the IdP
|
||||
* `user_attributes`: an object containing details about the user that
|
||||
we received from the IdP. May have the following attributes:
|
||||
* `display_name`: the user's display name
|
||||
* `emails`: a list of email addresses
|
||||
* `localpart`: the local part of the Matrix user ID to register,
|
||||
if `localpart_template` is set in the mapping provider configuration (empty
|
||||
string if not)
|
||||
* display_name: the user's display_name
|
||||
* emails: a list of email addresses
|
||||
The template should render a form which submits the following fields:
|
||||
* `username`: the localpart of the user's chosen user id
|
||||
* `sso_new_user_consent.html`: HTML page allowing the user to consent to the
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
/*
|
||||
* Indents each chapter title in the left sidebar so that they aren't
|
||||
* at the same level as the section headers.
|
||||
*/
|
||||
.chapter-item {
|
||||
margin-left: 1em;
|
||||
}
|
||||
20
docs/website_files/section-headers.css
Normal file
20
docs/website_files/section-headers.css
Normal file
@@ -0,0 +1,20 @@
|
||||
/*
|
||||
* Indents each chapter title in the left sidebar so that they aren't
|
||||
* at the same level as the section headers.
|
||||
*/
|
||||
.chapter-item {
|
||||
margin-left: 1em;
|
||||
}
|
||||
|
||||
/*
|
||||
* Prevents a large gap between successive section headers.
|
||||
*
|
||||
* mdbook sets 'margin-top: 2.5em' on h2 and h3 headers. This makes sense when separating
|
||||
* a header from the paragraph beforehand, but has the downside of introducing a large
|
||||
* gap between headers that are next to each other with no text in between.
|
||||
*
|
||||
* This rule reduces the margin in this case.
|
||||
*/
|
||||
h1 + h2, h2 + h3 {
|
||||
margin-top: 1.0em;
|
||||
}
|
||||
@@ -351,11 +351,8 @@ is only supported with Redis-based replication.)
|
||||
|
||||
To enable this, the worker must have a HTTP replication listener configured,
|
||||
have a `worker_name` and be listed in the `instance_map` config. The same worker
|
||||
can handle multiple streams, but unless otherwise documented, each stream can only
|
||||
have a single writer.
|
||||
|
||||
For example, to move event persistence off to a dedicated worker, the shared
|
||||
configuration would include:
|
||||
can handle multiple streams. For example, to move event persistence off to a
|
||||
dedicated worker, the shared configuration would include:
|
||||
|
||||
```yaml
|
||||
instance_map:
|
||||
@@ -373,8 +370,8 @@ streams and the endpoints associated with them:
|
||||
|
||||
##### The `events` stream
|
||||
|
||||
The `events` stream experimentally supports having multiple writers, where work
|
||||
is sharded between them by room ID. Note that you *must* restart all worker
|
||||
The `events` stream also experimentally supports having multiple writers, where
|
||||
work is sharded between them by room ID. Note that you *must* restart all worker
|
||||
instances when adding or removing event persisters. An example `stream_writers`
|
||||
configuration with multiple writers:
|
||||
|
||||
@@ -387,38 +384,38 @@ stream_writers:
|
||||
|
||||
##### The `typing` stream
|
||||
|
||||
The following endpoints should be routed directly to the worker configured as
|
||||
the stream writer for the `typing` stream:
|
||||
The following endpoints should be routed directly to the workers configured as
|
||||
stream writers for the `typing` stream:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing
|
||||
|
||||
##### The `to_device` stream
|
||||
|
||||
The following endpoints should be routed directly to the worker configured as
|
||||
the stream writer for the `to_device` stream:
|
||||
The following endpoints should be routed directly to the workers configured as
|
||||
stream writers for the `to_device` stream:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
|
||||
|
||||
##### The `account_data` stream
|
||||
|
||||
The following endpoints should be routed directly to the worker configured as
|
||||
the stream writer for the `account_data` stream:
|
||||
The following endpoints should be routed directly to the workers configured as
|
||||
stream writers for the `account_data` stream:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
|
||||
|
||||
##### The `receipts` stream
|
||||
|
||||
The following endpoints should be routed directly to the worker configured as
|
||||
the stream writer for the `receipts` stream:
|
||||
The following endpoints should be routed directly to the workers configured as
|
||||
stream writers for the `receipts` stream:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
|
||||
|
||||
##### The `presence` stream
|
||||
|
||||
The following endpoints should be routed directly to the worker configured as
|
||||
the stream writer for the `presence` stream:
|
||||
The following endpoints should be routed directly to the workers configured as
|
||||
stream writers for the `presence` stream:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
|
||||
|
||||
|
||||
1
mypy.ini
1
mypy.ini
@@ -90,6 +90,7 @@ exclude = (?x)
|
||||
|tests/push/test_push_rule_evaluator.py
|
||||
|tests/rest/client/test_transactions.py
|
||||
|tests/rest/media/v1/test_media_storage.py
|
||||
|tests/rest/media/v1/test_url_preview.py
|
||||
|tests/scripts/test_new_matrix_user.py
|
||||
|tests/server.py
|
||||
|tests/server_notices/test_resource_limits_server_notices.py
|
||||
|
||||
@@ -182,13 +182,8 @@ class OIDCConfig(Config):
|
||||
#
|
||||
# localpart_template: Jinja2 template for the localpart of the MXID.
|
||||
# If this is not set, the user will be prompted to choose their
|
||||
# own username (see the documentation for the
|
||||
# 'sso_auth_account_details.html' template).
|
||||
#
|
||||
# confirm_localpart: Whether to prompt the user to validate (or
|
||||
# change) the generated localpart (see the documentation for the
|
||||
# 'sso_auth_account_details.html' template), instead of
|
||||
# registering the account right away.
|
||||
# own username (see 'sso_auth_account_details.html' in the 'sso'
|
||||
# section of this file).
|
||||
#
|
||||
# display_name_template: Jinja2 template for the display name to set
|
||||
# on first login. If unset, no displayname will be set.
|
||||
|
||||
@@ -245,8 +245,8 @@ class SpamChecker:
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
sent by a local user. If it is sent by a user on another server, the
|
||||
event is soft-failed.
|
||||
sent by a local user. If it is sent by a user on another server, then
|
||||
users receive a blank event.
|
||||
|
||||
Args:
|
||||
event: the event to be checked
|
||||
|
||||
@@ -371,6 +371,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
log_kv(
|
||||
{"reason": "User doesn't have device id.", "device_id": device_id}
|
||||
)
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
@@ -413,6 +414,7 @@ class DeviceHandler(DeviceWorkerHandler):
|
||||
# no match
|
||||
set_tag("error", True)
|
||||
set_tag("reason", "User doesn't have that device id.")
|
||||
pass
|
||||
else:
|
||||
raise
|
||||
|
||||
|
||||
@@ -1228,7 +1228,6 @@ class OidcSessionData:
|
||||
|
||||
class UserAttributeDict(TypedDict):
|
||||
localpart: Optional[str]
|
||||
confirm_localpart: bool
|
||||
display_name: Optional[str]
|
||||
emails: List[str]
|
||||
|
||||
@@ -1317,7 +1316,6 @@ class JinjaOidcMappingConfig:
|
||||
display_name_template: Optional[Template]
|
||||
email_template: Optional[Template]
|
||||
extra_attributes: Dict[str, Template]
|
||||
confirm_localpart: bool = False
|
||||
|
||||
|
||||
class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||
@@ -1359,17 +1357,12 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||
"invalid jinja template", path=["extra_attributes", key]
|
||||
) from e
|
||||
|
||||
confirm_localpart = config.get("confirm_localpart") or False
|
||||
if not isinstance(confirm_localpart, bool):
|
||||
raise ConfigError("must be a bool", path=["confirm_localpart"])
|
||||
|
||||
return JinjaOidcMappingConfig(
|
||||
subject_claim=subject_claim,
|
||||
localpart_template=localpart_template,
|
||||
display_name_template=display_name_template,
|
||||
email_template=email_template,
|
||||
extra_attributes=extra_attributes,
|
||||
confirm_localpart=confirm_localpart,
|
||||
)
|
||||
|
||||
def get_remote_user_id(self, userinfo: UserInfo) -> str:
|
||||
@@ -1405,10 +1398,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
|
||||
emails.append(email)
|
||||
|
||||
return UserAttributeDict(
|
||||
localpart=localpart,
|
||||
display_name=display_name,
|
||||
emails=emails,
|
||||
confirm_localpart=self._config.confirm_localpart,
|
||||
localpart=localpart, display_name=display_name, emails=emails
|
||||
)
|
||||
|
||||
async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
|
||||
|
||||
@@ -267,6 +267,7 @@ class BasePresenceHandler(abc.ABC):
|
||||
is_syncing: Whether or not the user is now syncing
|
||||
sync_time_msec: Time in ms when the user was last syncing
|
||||
"""
|
||||
pass
|
||||
|
||||
async def update_external_syncs_clear(self, process_id: str) -> None:
|
||||
"""Marks all users that had been marked as syncing by a given process
|
||||
@@ -276,6 +277,7 @@ class BasePresenceHandler(abc.ABC):
|
||||
|
||||
This is a no-op when presence is handled by a different worker.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def process_replication_rows(
|
||||
self, stream_name: str, instance_name: str, token: int, rows: list
|
||||
|
||||
@@ -132,7 +132,6 @@ class UserAttributes:
|
||||
# if `None`, the mapper has not picked a userid, and the user should be prompted to
|
||||
# enter one.
|
||||
localpart: Optional[str]
|
||||
confirm_localpart: bool = False
|
||||
display_name: Optional[str] = None
|
||||
emails: Collection[str] = attr.Factory(list)
|
||||
|
||||
@@ -562,10 +561,9 @@ class SsoHandler:
|
||||
# Must provide either attributes or session, not both
|
||||
assert (attributes is not None) != (session is not None)
|
||||
|
||||
if (
|
||||
attributes
|
||||
and (attributes.localpart is None or attributes.confirm_localpart is True)
|
||||
) or (session and session.chosen_localpart is None):
|
||||
if (attributes and attributes.localpart is None) or (
|
||||
session and session.chosen_localpart is None
|
||||
):
|
||||
return b"/_synapse/client/pick_username/account_details"
|
||||
elif self._consent_at_registration and not (
|
||||
session and session.terms_accepted_version
|
||||
|
||||
@@ -120,6 +120,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
|
||||
"""Called when response has finished streaming and the parser should
|
||||
return the final result (or error).
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True, auto_attribs=True)
|
||||
@@ -600,6 +601,7 @@ class MatrixFederationHttpClient:
|
||||
response.code,
|
||||
response_phrase,
|
||||
)
|
||||
pass
|
||||
else:
|
||||
logger.info(
|
||||
"{%s} [%s] Got response headers: %d %s",
|
||||
|
||||
@@ -233,6 +233,7 @@ class HttpServer(Protocol):
|
||||
servlet_classname (str): The name of the handler to be used in prometheus
|
||||
and opentracing logs.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
|
||||
|
||||
@@ -118,21 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
|
||||
# have successfully subscribed to the stream - otherwise we might miss the
|
||||
# POSITION response sent back by the other end.
|
||||
logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
|
||||
try:
|
||||
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
||||
except txredisapi.ConnectionError:
|
||||
# The connection died, the factory will attempt to reconnect.
|
||||
return
|
||||
await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
|
||||
logger.info(
|
||||
"Successfully subscribed to redis stream, sending REPLICATE command"
|
||||
)
|
||||
|
||||
# If the connection has been severed for some reason, bail.
|
||||
if not self.connected:
|
||||
return
|
||||
|
||||
self.synapse_handler.new_connection(self)
|
||||
await self._async_send_command(ReplicateCommand())
|
||||
logger.info("REPLICATE successfully sent")
|
||||
|
||||
# We send out our positions when there is a new connection in case the
|
||||
# other side missed updates. We do this for Redis connections as the
|
||||
@@ -263,15 +255,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
|
||||
replyTimeout=replyTimeout,
|
||||
convertNumbers=convertNumbers,
|
||||
)
|
||||
self.hs = hs
|
||||
|
||||
# Set the homeserver reactor as the clock, if this is not done than
|
||||
# twisted.internet.protocol.ReconnectingClientFactory.retry will default
|
||||
# to the reactor.
|
||||
self.clock = hs.get_reactor()
|
||||
|
||||
# Send pings every 30 seconds (not that get_clock() returns a Clock, not
|
||||
# a reactor).
|
||||
hs.get_clock().looping_call(self._send_ping, 30 * 1000)
|
||||
|
||||
@wrap_as_background_process("redis_ping")
|
||||
@@ -369,7 +353,6 @@ def lazyConnection(
|
||||
reconnect: bool = True,
|
||||
password: Optional[str] = None,
|
||||
replyTimeout: int = 30,
|
||||
handler: Optional[txredisapi.ConnectionHandler] = None,
|
||||
) -> txredisapi.ConnectionHandler:
|
||||
"""Creates a connection to Redis that is lazily set up and reconnects if the
|
||||
connections is lost.
|
||||
|
||||
@@ -130,15 +130,15 @@
|
||||
</head>
|
||||
<body>
|
||||
<header>
|
||||
<h1>Choose your user name</h1>
|
||||
<p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
|
||||
<h1>Your account is nearly ready</h1>
|
||||
<p>Check your details before creating an account on {{ server_name }}</p>
|
||||
</header>
|
||||
<main>
|
||||
<form method="post" class="form__input" id="form">
|
||||
<div class="username_input" id="username_input">
|
||||
<label for="field-username">Username</label>
|
||||
<div class="prefix">@</div>
|
||||
<input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
|
||||
<input type="text" name="username" id="field-username" autofocus>
|
||||
<div class="postfix">:{{ server_name }}</div>
|
||||
</div>
|
||||
<output for="username_input" id="field-username-output"></output>
|
||||
|
||||
@@ -298,6 +298,7 @@ class Responder:
|
||||
Returns:
|
||||
Resolves once the response has finished being written
|
||||
"""
|
||||
pass
|
||||
|
||||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
@@ -92,20 +92,12 @@ class AccountDetailsResource(DirectServeHtmlResource):
|
||||
self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
|
||||
return
|
||||
|
||||
# The configuration might mandate going through this step to validate an
|
||||
# automatically generated localpart, so session.chosen_localpart might already
|
||||
# be set.
|
||||
localpart = ""
|
||||
if session.chosen_localpart is not None:
|
||||
localpart = session.chosen_localpart
|
||||
|
||||
idp_id = session.auth_provider_id
|
||||
template_params = {
|
||||
"idp": self._sso_handler.get_identity_providers()[idp_id],
|
||||
"user_attributes": {
|
||||
"display_name": session.display_name,
|
||||
"emails": session.emails,
|
||||
"localpart": localpart,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -328,6 +328,7 @@ class HomeServer(metaclass=abc.ABCMeta):
|
||||
Does nothing in this base class; overridden in derived classes to start the
|
||||
appropriate listeners.
|
||||
"""
|
||||
pass
|
||||
|
||||
def setup_background_tasks(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -48,6 +48,8 @@ class ExternalIDReuseException(Exception):
|
||||
"""Exception if writing an external id for a user fails,
|
||||
because this external id is given to an other user."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
@attr.s(frozen=True, slots=True, auto_attribs=True)
|
||||
class TokenLookupResult:
|
||||
|
||||
@@ -36,6 +36,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
|
||||
config_files = config.appservice.app_service_config_files
|
||||
except AttributeError:
|
||||
logger.warning("Could not get app_service_config_files from config")
|
||||
pass
|
||||
|
||||
appservices = load_appservices(config.server.server_name, config_files)
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ class TreeCacheNode(dict):
|
||||
leaves.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TreeCache:
|
||||
"""
|
||||
|
||||
@@ -15,15 +15,11 @@
|
||||
from collections import Counter
|
||||
from unittest.mock import Mock
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
import synapse.rest.admin
|
||||
import synapse.storage
|
||||
from synapse.api.constants import EventTypes, JoinRules
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.rest.client import knock, login, room
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -36,7 +32,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
knock.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
|
||||
self.user1 = self.register_user("user1", "password")
|
||||
@@ -45,7 +41,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.user2 = self.register_user("user2", "password")
|
||||
self.token2 = self.login("user2", "password")
|
||||
|
||||
def test_single_public_joined_room(self) -> None:
|
||||
def test_single_public_joined_room(self):
|
||||
"""Test that we write *all* events for a public room"""
|
||||
room_id = self.helper.create_room_as(
|
||||
self.user1, tok=self.token1, is_public=True
|
||||
@@ -78,7 +74,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
|
||||
|
||||
def test_single_private_joined_room(self) -> None:
|
||||
def test_single_private_joined_room(self):
|
||||
"""Tests that we correctly write state when we can't see all events in
|
||||
a room.
|
||||
"""
|
||||
@@ -116,7 +112,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
|
||||
|
||||
def test_single_left_room(self) -> None:
|
||||
def test_single_left_room(self):
|
||||
"""Tests that we don't see events in the room after we leave."""
|
||||
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
|
||||
self.helper.send(room_id, body="Hello!", tok=self.token1)
|
||||
@@ -148,7 +144,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
|
||||
|
||||
def test_single_left_rejoined_private_room(self) -> None:
|
||||
def test_single_left_rejoined_private_room(self):
|
||||
"""Tests that see the correct events in private rooms when we
|
||||
repeatedly join and leave.
|
||||
"""
|
||||
@@ -189,7 +185,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
|
||||
self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
|
||||
|
||||
def test_invite(self) -> None:
|
||||
def test_invite(self):
|
||||
"""Tests that pending invites get handled correctly."""
|
||||
room_id = self.helper.create_room_as(self.user1, tok=self.token1)
|
||||
self.helper.send(room_id, body="Hello!", tok=self.token1)
|
||||
@@ -208,7 +204,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
|
||||
self.assertEqual(args[1].content["membership"], "invite")
|
||||
self.assertTrue(args[2]) # Assert there is at least one bit of state
|
||||
|
||||
def test_knock(self) -> None:
|
||||
def test_knock(self):
|
||||
"""Tests that knock get handled correctly."""
|
||||
# create a knockable v7 room
|
||||
room_id = self.helper.create_room_as(
|
||||
|
||||
@@ -15,12 +15,8 @@ from unittest.mock import Mock
|
||||
|
||||
import pymacaroons
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import AuthError, ResourceLimitError
|
||||
from synapse.rest import admin
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.test_utils import make_awaitable
|
||||
@@ -31,7 +27,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.macaroon_generator = hs.get_macaroon_generator()
|
||||
|
||||
@@ -46,23 +42,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.user1 = self.register_user("a_user", "pass")
|
||||
|
||||
def test_macaroon_caveats(self) -> None:
|
||||
def test_macaroon_caveats(self):
|
||||
token = self.macaroon_generator.generate_guest_access_token("a_user")
|
||||
macaroon = pymacaroons.Macaroon.deserialize(token)
|
||||
|
||||
def verify_gen(caveat: str) -> bool:
|
||||
def verify_gen(caveat):
|
||||
return caveat == "gen = 1"
|
||||
|
||||
def verify_user(caveat: str) -> bool:
|
||||
def verify_user(caveat):
|
||||
return caveat == "user_id = a_user"
|
||||
|
||||
def verify_type(caveat: str) -> bool:
|
||||
def verify_type(caveat):
|
||||
return caveat == "type = access"
|
||||
|
||||
def verify_nonce(caveat: str) -> bool:
|
||||
def verify_nonce(caveat):
|
||||
return caveat.startswith("nonce =")
|
||||
|
||||
def verify_guest(caveat: str) -> bool:
|
||||
def verify_guest(caveat):
|
||||
return caveat == "guest = true"
|
||||
|
||||
v = pymacaroons.Verifier()
|
||||
@@ -73,7 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
v.satisfy_general(verify_guest)
|
||||
v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
|
||||
|
||||
def test_short_term_login_token_gives_user_id(self) -> None:
|
||||
def test_short_term_login_token_gives_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
@@ -88,7 +84,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_short_term_login_token_gives_auth_provider(self) -> None:
|
||||
def test_short_term_login_token_gives_auth_provider(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, auth_provider_id="my_idp"
|
||||
)
|
||||
@@ -96,7 +92,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
self.assertEqual(self.user1, res.user_id)
|
||||
self.assertEqual("my_idp", res.auth_provider_id)
|
||||
|
||||
def test_short_term_login_token_cannot_replace_user_id(self) -> None:
|
||||
def test_short_term_login_token_cannot_replace_user_id(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
@@ -116,7 +112,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
AuthError,
|
||||
)
|
||||
|
||||
def test_mau_limits_disabled(self) -> None:
|
||||
def test_mau_limits_disabled(self):
|
||||
self.auth_blocking._limit_usage_by_mau = False
|
||||
# Ensure does not throw exception
|
||||
self.get_success(
|
||||
@@ -131,7 +127,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_mau_limits_exceeded_large(self) -> None:
|
||||
def test_mau_limits_exceeded_large(self):
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||
return_value=make_awaitable(self.large_number_of_users)
|
||||
@@ -154,7 +150,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
ResourceLimitError,
|
||||
)
|
||||
|
||||
def test_mau_limits_parity(self) -> None:
|
||||
def test_mau_limits_parity(self):
|
||||
# Ensure we're not at the unix epoch.
|
||||
self.reactor.advance(1)
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
@@ -193,7 +189,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_mau_limits_not_exceeded(self) -> None:
|
||||
def test_mau_limits_not_exceeded(self):
|
||||
self.auth_blocking._limit_usage_by_mau = True
|
||||
|
||||
self.hs.get_datastores().main.get_monthly_active_count = Mock(
|
||||
@@ -215,7 +211,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def _get_macaroon(self) -> pymacaroons.Macaroon:
|
||||
def _get_macaroon(self):
|
||||
token = self.macaroon_generator.generate_short_term_login_token(
|
||||
self.user1, "", duration_in_ms=5000
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.token = self.login("user", "pass")
|
||||
|
||||
def _deactivate_my_account(self) -> None:
|
||||
def _deactivate_my_account(self):
|
||||
"""
|
||||
Deactivates the account `self.user` using `self.token` and asserts
|
||||
that it returns a 200 success code.
|
||||
|
||||
@@ -14,14 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.api.errors import NotFoundError, SynapseError
|
||||
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
|
||||
from synapse.server import HomeServer
|
||||
from synapse.util import Clock
|
||||
import synapse.api.errors
|
||||
import synapse.handlers.device
|
||||
import synapse.storage
|
||||
|
||||
from tests import unittest
|
||||
|
||||
@@ -30,27 +25,28 @@ user2 = "@theresa:bbb"
|
||||
|
||||
|
||||
class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# These tests assume that it starts 1000 seconds in.
|
||||
self.reactor.advance(1000)
|
||||
|
||||
def test_device_is_created_with_invalid_name(self) -> None:
|
||||
def test_device_is_created_with_invalid_name(self):
|
||||
self.get_failure(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
device_id="foo",
|
||||
initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
|
||||
initial_device_display_name="a"
|
||||
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
|
||||
),
|
||||
SynapseError,
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
def test_device_is_created_if_doesnt_exist(self) -> None:
|
||||
def test_device_is_created_if_doesnt_exist(self):
|
||||
res = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
@@ -63,7 +59,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_is_preserved_if_exists(self) -> None:
|
||||
def test_device_is_preserved_if_exists(self):
|
||||
res1 = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@boris:foo",
|
||||
@@ -85,7 +81,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
|
||||
self.assertEqual(dev["display_name"], "display name")
|
||||
|
||||
def test_device_id_is_made_up_if_unspecified(self) -> None:
|
||||
def test_device_id_is_made_up_if_unspecified(self):
|
||||
device_id = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id="@theresa:foo",
|
||||
@@ -97,7 +93,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
|
||||
self.assertEqual(dev["display_name"], "display")
|
||||
|
||||
def test_get_devices_by_user(self) -> None:
|
||||
def test_get_devices_by_user(self):
|
||||
self._record_users()
|
||||
|
||||
res = self.get_success(self.handler.get_devices_by_user(user1))
|
||||
@@ -135,7 +131,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
device_map["abc"],
|
||||
)
|
||||
|
||||
def test_get_device(self) -> None:
|
||||
def test_get_device(self):
|
||||
self._record_users()
|
||||
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
@@ -150,19 +146,21 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
res,
|
||||
)
|
||||
|
||||
def test_delete_device(self) -> None:
|
||||
def test_delete_device(self):
|
||||
self._record_users()
|
||||
|
||||
# delete the device
|
||||
self.get_success(self.handler.delete_device(user1, "abc"))
|
||||
|
||||
# check the device was deleted
|
||||
self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
|
||||
self.get_failure(
|
||||
self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
|
||||
)
|
||||
|
||||
# we'd like to check the access token was invalidated, but that's a
|
||||
# bit of a PITA.
|
||||
|
||||
def test_delete_device_and_device_inbox(self) -> None:
|
||||
def test_delete_device_and_device_inbox(self):
|
||||
self._record_users()
|
||||
|
||||
# add an device_inbox
|
||||
@@ -193,7 +191,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertIsNone(res)
|
||||
|
||||
def test_update_device(self) -> None:
|
||||
def test_update_device(self):
|
||||
self._record_users()
|
||||
|
||||
update = {"display_name": "new display"}
|
||||
@@ -202,29 +200,32 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
self.assertEqual(res["display_name"], "new display")
|
||||
|
||||
def test_update_device_too_long_display_name(self) -> None:
|
||||
def test_update_device_too_long_display_name(self):
|
||||
"""Update a device with a display name that is invalid (too long)."""
|
||||
self._record_users()
|
||||
|
||||
# Request to update a device display name with a new value that is longer than allowed.
|
||||
update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
|
||||
update = {
|
||||
"display_name": "a"
|
||||
* (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
|
||||
}
|
||||
self.get_failure(
|
||||
self.handler.update_device(user1, "abc", update),
|
||||
SynapseError,
|
||||
synapse.api.errors.SynapseError,
|
||||
)
|
||||
|
||||
# Ensure the display name was not updated.
|
||||
res = self.get_success(self.handler.get_device(user1, "abc"))
|
||||
self.assertEqual(res["display_name"], "display 2")
|
||||
|
||||
def test_update_unknown_device(self) -> None:
|
||||
def test_update_unknown_device(self):
|
||||
update = {"display_name": "new_display"}
|
||||
self.get_failure(
|
||||
self.handler.update_device("user_id", "unknown_device_id", update),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
def _record_users(self) -> None:
|
||||
def _record_users(self):
|
||||
# check this works for both devices which have a recorded client_ip,
|
||||
# and those which don't.
|
||||
self._record_user(user1, "xyz", "display 0")
|
||||
@@ -237,13 +238,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
self.reactor.advance(10000)
|
||||
|
||||
def _record_user(
|
||||
self,
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
display_name: str,
|
||||
access_token: Optional[str] = None,
|
||||
ip: Optional[str] = None,
|
||||
) -> None:
|
||||
self, user_id, device_id, display_name, access_token=None, ip=None
|
||||
):
|
||||
device_id = self.get_success(
|
||||
self.handler.check_device_registered(
|
||||
user_id=user_id,
|
||||
@@ -252,7 +248,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
if access_token is not None and ip is not None:
|
||||
if ip is not None:
|
||||
self.get_success(
|
||||
self.store.insert_client_ip(
|
||||
user_id, access_token, ip, "user_agent", device_id
|
||||
@@ -262,7 +258,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
|
||||
class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
hs = self.setup_test_homeserver("server", federation_http_client=None)
|
||||
self.handler = hs.get_device_handler()
|
||||
self.registration = hs.get_registration_handler()
|
||||
@@ -270,7 +266,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
self.store = hs.get_datastores().main
|
||||
return hs
|
||||
|
||||
def test_dehydrate_and_rehydrate_device(self) -> None:
|
||||
def test_dehydrate_and_rehydrate_device(self):
|
||||
user_id = "@boris:dehydration"
|
||||
|
||||
self.get_success(self.store.register_user(user_id, "foobar"))
|
||||
@@ -307,7 +303,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
access_token=access_token,
|
||||
device_id="not the right device ID",
|
||||
),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
# dehydrating the right devices should succeed and change our device ID
|
||||
@@ -335,7 +331,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
|
||||
# make sure that the device ID that we were initially assigned no longer exists
|
||||
self.get_failure(
|
||||
self.handler.get_device(user_id, device_id),
|
||||
NotFoundError,
|
||||
synapse.api.errors.NotFoundError,
|
||||
)
|
||||
|
||||
# make sure that there's no device available for dehydrating now
|
||||
|
||||
@@ -124,6 +124,7 @@ class PasswordCustomAuthProvider:
|
||||
("m.login.password", ("password",)): self.check_auth,
|
||||
}
|
||||
)
|
||||
pass
|
||||
|
||||
def check_auth(self, *args):
|
||||
return mock_password_provider.check_auth(*args)
|
||||
|
||||
@@ -549,6 +549,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
|
||||
|
||||
def default_config(self):
|
||||
conf = super().default_config()
|
||||
conf["redis"] = {"enabled": "true"}
|
||||
conf["stream_writers"] = {"presence": ["presence_writer"]}
|
||||
conf["instance_map"] = {
|
||||
"presence_writer": {"host": "testserv", "port": 1001},
|
||||
|
||||
@@ -14,14 +14,20 @@
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from twisted.internet.address import IPv4Address
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from synapse.app.generic_worker import GenericWorkerServer
|
||||
from synapse.http.site import SynapseRequest, SynapseSite
|
||||
from synapse.replication.http import ReplicationRestResource
|
||||
from synapse.replication.tcp.client import ReplicationDataHandler
|
||||
from synapse.replication.tcp.handler import ReplicationCommandHandler
|
||||
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
|
||||
from synapse.replication.tcp.resource import (
|
||||
ReplicationStreamProtocolFactory,
|
||||
ServerReplicationStreamProtocol,
|
||||
)
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from tests import unittest
|
||||
@@ -35,55 +41,6 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FakeOutboundConnector:
|
||||
"""
|
||||
A fake connector class, reconnects.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: HomeServer):
|
||||
self._hs = hs
|
||||
|
||||
def stopConnecting(self):
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
# Restart replication.
|
||||
from synapse.replication.tcp.redis import lazyConnection
|
||||
|
||||
handler = self._hs.get_outbound_redis_connection()
|
||||
|
||||
reactor = self._hs.get_reactor()
|
||||
reactor.connectTCP(
|
||||
self._hs.config.redis.redis_host,
|
||||
self._hs.config.redis.redis_port,
|
||||
handler._factory,
|
||||
timeout=30,
|
||||
bindAddress=None,
|
||||
)
|
||||
|
||||
def getDestination(self):
|
||||
return "blah"
|
||||
|
||||
|
||||
class FakeReplicationHandlerConnector:
|
||||
"""
|
||||
A fake connector class, reconnects.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: HomeServer):
|
||||
self._hs = hs
|
||||
|
||||
def stopConnecting(self):
|
||||
pass
|
||||
|
||||
def connect(self):
|
||||
# Restart replication.
|
||||
self._hs.get_replication_command_handler().start_replication(self._hs)
|
||||
|
||||
def getDestination(self):
|
||||
return "blah"
|
||||
|
||||
|
||||
class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests of the replication streams"""
|
||||
|
||||
@@ -92,33 +49,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
if not hiredis:
|
||||
skip = "Requires hiredis"
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["redis"] = {"enabled": True}
|
||||
return config
|
||||
|
||||
def prepare(self, reactor, clock, hs):
|
||||
# build a replication server
|
||||
server_factory = ReplicationStreamProtocolFactory(hs)
|
||||
self.streamer = hs.get_replication_streamer()
|
||||
|
||||
# Fake in memory Redis server that servers can connect to.
|
||||
self._redis_transports = []
|
||||
self._redis_server = FakeRedisPubSubServer()
|
||||
|
||||
# We may have an attempt to connect to redis for the external cache already.
|
||||
self.connect_any_redis_attempts()
|
||||
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
|
||||
IPv4Address("TCP", "127.0.0.1", 0)
|
||||
)
|
||||
|
||||
# Make a new HomeServer object for the worker
|
||||
self.reactor.lookups["testserv"] = "1.2.3.4"
|
||||
self.reactor.lookups["localhost"] = "127.0.0.1"
|
||||
|
||||
# Handle attempts to connect to fake redis server.
|
||||
self.reactor.add_tcp_client_callback(
|
||||
"localhost",
|
||||
6379,
|
||||
self.connect_any_redis_attempts,
|
||||
)
|
||||
|
||||
self.worker_hs = self.setup_test_homeserver(
|
||||
federation_http_client=None,
|
||||
homeserver_to_use=GenericWorkerServer,
|
||||
@@ -141,11 +81,18 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
self.test_handler = self._build_replication_data_handler()
|
||||
self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined]
|
||||
|
||||
self.hs.get_replication_command_handler().start_replication(self.hs)
|
||||
self.worker_hs.get_replication_command_handler().start_replication(
|
||||
self.worker_hs
|
||||
repl_handler = ReplicationCommandHandler(self.worker_hs)
|
||||
self.client = ClientReplicationStreamProtocol(
|
||||
self.worker_hs,
|
||||
"client",
|
||||
"test",
|
||||
clock,
|
||||
repl_handler,
|
||||
)
|
||||
|
||||
self._client_transport = None
|
||||
self._server_transport = None
|
||||
|
||||
def create_resource_dict(self) -> Dict[str, Resource]:
|
||||
d = super().create_resource_dict()
|
||||
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
|
||||
@@ -162,46 +109,26 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
return TestReplicationDataHandler(self.worker_hs)
|
||||
|
||||
def reconnect(self):
|
||||
self.disconnect()
|
||||
print("RECONNECTING")
|
||||
if self._client_transport:
|
||||
self.client.close()
|
||||
|
||||
# Make a `FakeConnector` to emulate the behavior of `connectTCP. That
|
||||
# creates an `IConnector`, which is responsible for calling the factory
|
||||
# `clientConnectionLost`. The reconnecting factory then calls
|
||||
# `IConnector.connect` to attempt a reconnection. The transport is meant
|
||||
# to call `connectionLost` on the `IConnector`.
|
||||
#
|
||||
# Most of that is bypassed by directly calling `retry` on the factory,
|
||||
# which schedules a `connect()` call on the connector.
|
||||
timeouts = []
|
||||
for hs in (self.hs, self.worker_hs):
|
||||
hs_factory_outbound = hs.get_outbound_redis_connection()._factory
|
||||
hs_factory_outbound.clientConnectionLost(
|
||||
FakeOutboundConnector(hs), Failure(RuntimeError(""))
|
||||
)
|
||||
timeouts.append(hs_factory_outbound.delay)
|
||||
if self._server_transport:
|
||||
self.server.close()
|
||||
|
||||
hs_factory = hs.get_replication_command_handler()._factory
|
||||
hs_factory.clientConnectionLost(
|
||||
FakeReplicationHandlerConnector(hs),
|
||||
Failure(RuntimeError("")),
|
||||
)
|
||||
timeouts.append(hs_factory.delay)
|
||||
self._client_transport = FakeTransport(self.server, self.reactor)
|
||||
self.client.makeConnection(self._client_transport)
|
||||
|
||||
# Wait for the reconnects to happen.
|
||||
self.pump(max(timeouts) + 1)
|
||||
|
||||
self.connect_any_redis_attempts()
|
||||
self._server_transport = FakeTransport(self.client, self.reactor)
|
||||
self.server.makeConnection(self._server_transport)
|
||||
|
||||
def disconnect(self):
|
||||
print("DISCONNECTING")
|
||||
for (
|
||||
client_to_server_transport,
|
||||
server_to_client_transport,
|
||||
) in self._redis_transports:
|
||||
client_to_server_transport.abortConnection()
|
||||
server_to_client_transport.abortConnection()
|
||||
self._redis_transports = []
|
||||
if self._client_transport:
|
||||
self._client_transport = None
|
||||
self.client.close()
|
||||
|
||||
if self._server_transport:
|
||||
self._server_transport = None
|
||||
self.server.close()
|
||||
|
||||
def replicate(self):
|
||||
"""Tell the master side of replication that something has happened, and then
|
||||
@@ -285,40 +212,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(request.method, b"GET")
|
||||
|
||||
def connect_any_redis_attempts(self):
|
||||
"""If redis is enabled we need to deal with workers connecting to a
|
||||
redis server. We don't want to use a real Redis server so we use a
|
||||
fake one.
|
||||
"""
|
||||
clients = self.reactor.tcpClients
|
||||
while clients:
|
||||
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
|
||||
self.assertEqual(host, "localhost")
|
||||
self.assertEqual(port, 6379)
|
||||
|
||||
client_protocol = client_factory.buildProtocol(None)
|
||||
server_protocol = self._redis_server.buildProtocol(None)
|
||||
if client_protocol.__class__.__name__ == "RedisSubscriber":
|
||||
print(client_protocol, client_protocol.synapse_handler._presence_handler.hs, client_protocol.synapse_outbound_redis_connection)
|
||||
else:
|
||||
print(client_protocol, client_protocol.factory.hs)
|
||||
print()
|
||||
|
||||
client_to_server_transport = FakeTransport(
|
||||
server_protocol, self.reactor, client_protocol
|
||||
)
|
||||
client_protocol.makeConnection(client_to_server_transport)
|
||||
|
||||
server_to_client_transport = FakeTransport(
|
||||
client_protocol, self.reactor, server_protocol
|
||||
)
|
||||
server_protocol.makeConnection(server_to_client_transport)
|
||||
|
||||
# Store for potentially disconnecting.
|
||||
self._redis_transports.append(
|
||||
(client_to_server_transport, server_to_client_transport)
|
||||
)
|
||||
|
||||
|
||||
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
"""Base class for tests running multiple workers.
|
||||
@@ -327,14 +220,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
unlike `BaseStreamTestCase`.
|
||||
"""
|
||||
|
||||
def default_config(self):
|
||||
config = super().default_config()
|
||||
config["redis"] = {"enabled": True}
|
||||
return config
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
|
||||
# build a replication server
|
||||
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
|
||||
self.streamer = self.hs.get_replication_streamer()
|
||||
|
||||
# Fake in memory Redis server that servers can connect to.
|
||||
@@ -353,14 +243,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
# handling inbound HTTP requests to that instance.
|
||||
self._hs_to_site = {self.hs: self.site}
|
||||
|
||||
# Handle attempts to connect to fake redis server.
|
||||
self.reactor.add_tcp_client_callback(
|
||||
"localhost",
|
||||
6379,
|
||||
self.connect_any_redis_attempts,
|
||||
)
|
||||
if self.hs.config.redis.redis_enabled:
|
||||
# Handle attempts to connect to fake redis server.
|
||||
self.reactor.add_tcp_client_callback(
|
||||
"localhost",
|
||||
6379,
|
||||
self.connect_any_redis_attempts,
|
||||
)
|
||||
|
||||
self.hs.get_replication_command_handler().start_replication(self.hs)
|
||||
self.hs.get_replication_command_handler().start_replication(self.hs)
|
||||
|
||||
# When we see a connection attempt to the master replication listener we
|
||||
# automatically set up the connection. This is so that tests don't
|
||||
@@ -444,6 +335,27 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
store = worker_hs.get_datastores().main
|
||||
store.db_pool._db_pool = self.database_pool._db_pool
|
||||
|
||||
# Set up TCP replication between master and the new worker if we don't
|
||||
# have Redis support enabled.
|
||||
if not worker_hs.config.redis.redis_enabled:
|
||||
repl_handler = ReplicationCommandHandler(worker_hs)
|
||||
client = ClientReplicationStreamProtocol(
|
||||
worker_hs,
|
||||
"client",
|
||||
"test",
|
||||
self.clock,
|
||||
repl_handler,
|
||||
)
|
||||
server = self.server_factory.buildProtocol(
|
||||
IPv4Address("TCP", "127.0.0.1", 0)
|
||||
)
|
||||
|
||||
client_transport = FakeTransport(server, self.reactor)
|
||||
client.makeConnection(client_transport)
|
||||
|
||||
server_transport = FakeTransport(client, self.reactor)
|
||||
server.makeConnection(server_transport)
|
||||
|
||||
# Set up a resource for the worker
|
||||
resource = ReplicationRestResource(worker_hs)
|
||||
|
||||
@@ -462,7 +374,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
reactor=self.reactor,
|
||||
)
|
||||
|
||||
worker_hs.get_replication_command_handler().start_replication(worker_hs)
|
||||
if worker_hs.config.redis.redis_enabled:
|
||||
worker_hs.get_replication_command_handler().start_replication(worker_hs)
|
||||
|
||||
return worker_hs
|
||||
|
||||
@@ -511,7 +424,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Note: at this point we've wired everything up, but we need to return
|
||||
# before the data starts flowing over the connections as this is called
|
||||
# inside `connectTCP` before the connection has been passed back to the
|
||||
# inside `connecTCP` before the connection has been passed back to the
|
||||
# code that requested the TCP connection.
|
||||
|
||||
def connect_any_redis_attempts(self):
|
||||
@@ -623,13 +536,8 @@ class FakeRedisPubSubProtocol(Protocol):
|
||||
self.send("OK")
|
||||
elif command == b"GET":
|
||||
self.send(None)
|
||||
|
||||
# Connection keep-alives.
|
||||
elif command == b"PING":
|
||||
self.send("PONG")
|
||||
|
||||
else:
|
||||
raise Exception(f"Unknown command: {command}")
|
||||
raise Exception("Unknown command")
|
||||
|
||||
def send(self, msg):
|
||||
"""Send a message back to the client."""
|
||||
|
||||
@@ -250,14 +250,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
|
||||
self.replicate()
|
||||
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
|
||||
|
||||
# limit the replication rate from server -> client.
|
||||
print(len(self._redis_transports))
|
||||
for x in self._redis_transports:
|
||||
print(f"\t{x}")
|
||||
assert len(self._redis_transports) == 1
|
||||
for _, repl_transport in self._redis_transports:
|
||||
assert isinstance(repl_transport, FakeTransport)
|
||||
repl_transport.autoflush = False
|
||||
# limit the replication rate
|
||||
repl_transport = self._server_transport
|
||||
assert isinstance(repl_transport, FakeTransport)
|
||||
repl_transport.autoflush = False
|
||||
|
||||
# build the join and message events and persist them in the same batch.
|
||||
logger.info("----- build test events ------")
|
||||
|
||||
@@ -28,7 +28,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
|
||||
return Mock(wraps=super()._build_replication_data_handler())
|
||||
|
||||
def test_receipt(self):
|
||||
# self.reconnect()
|
||||
self.reconnect()
|
||||
|
||||
# tell the master to send a new receipt
|
||||
self.get_success(
|
||||
|
||||
@@ -27,8 +27,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
|
||||
servlets = [register.register_servlets]
|
||||
|
||||
def _get_worker_hs_config(self) -> dict:
|
||||
config = super()._get_worker_hs_config()
|
||||
config = self.default_config()
|
||||
config["worker_app"] = "synapse.app.client_reader"
|
||||
config["worker_replication_host"] = "testserv"
|
||||
config["worker_replication_http_port"] = "8765"
|
||||
return config
|
||||
|
||||
def test_register_single_worker(self):
|
||||
|
||||
@@ -51,6 +51,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
|
||||
|
||||
def default_config(self):
|
||||
conf = super().default_config()
|
||||
conf["redis"] = {"enabled": "true"}
|
||||
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
|
||||
conf["instance_map"] = {
|
||||
"worker1": {"host": "testserv", "port": 1001},
|
||||
|
||||
@@ -24,7 +24,6 @@ from synapse.util import Clock
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from tests import unittest
|
||||
from tests.unittest import override_config
|
||||
|
||||
one_hour_ms = 3600000
|
||||
one_day_ms = one_hour_ms * 24
|
||||
@@ -39,10 +38,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
config = self.default_config()
|
||||
|
||||
# merge this default retention config with anything that was specified in
|
||||
# @override_config
|
||||
retention_config = {
|
||||
config["retention"] = {
|
||||
"enabled": True,
|
||||
"default_policy": {
|
||||
"min_lifetime": one_day_ms,
|
||||
@@ -51,8 +47,6 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
"allowed_lifetime_min": one_day_ms,
|
||||
"allowed_lifetime_max": one_day_ms * 3,
|
||||
}
|
||||
retention_config.update(config.get("retention", {}))
|
||||
config["retention"] = retention_config
|
||||
|
||||
self.hs = self.setup_test_homeserver(config=config)
|
||||
|
||||
@@ -121,20 +115,22 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self._test_retention_event_purged(room_id, one_day_ms * 2)
|
||||
|
||||
@override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
|
||||
def test_visibility(self) -> None:
|
||||
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
|
||||
outdated events, even if the purge job hasn't got to them yet.
|
||||
|
||||
We do this by setting a very long time between purge jobs.
|
||||
outdated events
|
||||
"""
|
||||
store = self.hs.get_datastores().main
|
||||
storage = self.hs.get_storage()
|
||||
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
|
||||
events = []
|
||||
|
||||
# Send a first event, which should be filtered out at the end of the test.
|
||||
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
|
||||
first_event_id = resp.get("event_id")
|
||||
|
||||
# Get the event from the store so that we end up with a FrozenEvent that we can
|
||||
# give to filter_events_for_client. We need to do this now because the event won't
|
||||
# be in the database anymore after it has expired.
|
||||
events.append(self.get_success(store.get_event(resp.get("event_id"))))
|
||||
|
||||
# Advance the time by 2 days. We're using the default retention policy, therefore
|
||||
# after this the first event will still be valid.
|
||||
@@ -142,17 +138,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
# Send another event, which shouldn't get filtered out.
|
||||
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
|
||||
|
||||
valid_event_id = resp.get("event_id")
|
||||
|
||||
events.append(self.get_success(store.get_event(valid_event_id)))
|
||||
|
||||
# Advance the time by another 2 days. After this, the first event should be
|
||||
# outdated but not the second one.
|
||||
self.reactor.advance(one_day_ms * 2 / 1000)
|
||||
|
||||
# Fetch the events, and run filter_events_for_client on them
|
||||
events = self.get_success(
|
||||
store.get_events_as_list([first_event_id, valid_event_id])
|
||||
)
|
||||
self.assertEqual(2, len(events), "events retrieved from database")
|
||||
# Run filter_events_for_client with our list of FrozenEvents.
|
||||
filtered_events = self.get_success(
|
||||
filter_events_for_client(storage, self.user_id, events)
|
||||
)
|
||||
|
||||
@@ -1,18 +1,3 @@
|
||||
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from http import HTTPStatus
|
||||
from unittest.mock import Mock, call
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
@@ -26,14 +11,14 @@ from tests.utils import MockClock
|
||||
|
||||
|
||||
class HttpTransactionCacheTestCase(unittest.TestCase):
|
||||
def setUp(self) -> None:
|
||||
def setUp(self):
|
||||
self.clock = MockClock()
|
||||
self.hs = Mock()
|
||||
self.hs.get_clock = Mock(return_value=self.clock)
|
||||
self.hs.get_auth = Mock()
|
||||
self.cache = HttpTransactionCache(self.hs)
|
||||
|
||||
self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
|
||||
self.mock_http_response = (200, "GOOD JOB!")
|
||||
self.mock_key = "foo"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -16,7 +16,7 @@ import shutil
|
||||
import tempfile
|
||||
from binascii import unhexlify
|
||||
from io import BytesIO
|
||||
from typing import Any, BinaryIO, Dict, List, Optional, Union
|
||||
from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
from urllib import parse
|
||||
|
||||
@@ -26,24 +26,18 @@ from PIL import Image as Image
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
|
||||
from synapse.events import EventBase
|
||||
from synapse.events.spamcheck import load_legacy_spam_checkers
|
||||
from synapse.logging.context import make_deferred_yieldable
|
||||
from synapse.module_api import ModuleApi
|
||||
from synapse.rest import admin
|
||||
from synapse.rest.client import login
|
||||
from synapse.rest.media.v1._base import FileInfo
|
||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
|
||||
from synapse.rest.media.v1.media_storage import MediaStorage
|
||||
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import RoomAlias
|
||||
from synapse.util import Clock
|
||||
|
||||
from tests import unittest
|
||||
from tests.server import FakeChannel, FakeSite, make_request
|
||||
from tests.server import FakeSite, make_request
|
||||
from tests.test_utils import SMALL_PNG
|
||||
from tests.utils import default_config
|
||||
|
||||
@@ -52,7 +46,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||
|
||||
needs_threadpool = True
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
|
||||
self.addCleanup(shutil.rmtree, self.test_dir)
|
||||
|
||||
@@ -68,7 +62,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||
hs, self.primary_base_path, self.filepaths, storage_providers
|
||||
)
|
||||
|
||||
def test_ensure_media_is_in_local_cache(self) -> None:
|
||||
def test_ensure_media_is_in_local_cache(self):
|
||||
media_id = "some_media_id"
|
||||
test_body = "Test\n"
|
||||
|
||||
@@ -111,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(test_body, body)
|
||||
|
||||
|
||||
@attr.s(auto_attribs=True, slots=True, frozen=True)
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class _TestImage:
|
||||
"""An image for testing thumbnailing with the expected results
|
||||
|
||||
@@ -127,18 +121,18 @@ class _TestImage:
|
||||
a 404 is expected.
|
||||
"""
|
||||
|
||||
data: bytes
|
||||
content_type: bytes
|
||||
extension: bytes
|
||||
expected_cropped: Optional[bytes] = None
|
||||
expected_scaled: Optional[bytes] = None
|
||||
expected_found: bool = True
|
||||
data = attr.ib(type=bytes)
|
||||
content_type = attr.ib(type=bytes)
|
||||
extension = attr.ib(type=bytes)
|
||||
expected_cropped = attr.ib(type=Optional[bytes], default=None)
|
||||
expected_scaled = attr.ib(type=Optional[bytes], default=None)
|
||||
expected_found = attr.ib(default=True, type=bool)
|
||||
|
||||
|
||||
@parameterized_class(
|
||||
("test_image",),
|
||||
[
|
||||
# small png
|
||||
# smoll png
|
||||
(
|
||||
_TestImage(
|
||||
SMALL_PNG,
|
||||
@@ -199,17 +193,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
hijack_auth = True
|
||||
user_id = "@test:user"
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
self.fetches = []
|
||||
|
||||
def get_file(
|
||||
destination: str,
|
||||
path: str,
|
||||
output_stream: BinaryIO,
|
||||
args: Optional[Dict[str, Union[str, List[str]]]] = None,
|
||||
max_size: Optional[int] = None,
|
||||
) -> Deferred:
|
||||
def get_file(destination, path, output_stream, args=None, max_size=None):
|
||||
"""
|
||||
Returns tuple[int,dict,str,int] of file length, response headers,
|
||||
absolute URI, and response code.
|
||||
@@ -250,7 +238,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
||||
media_resource = hs.get_media_repository_resource()
|
||||
self.download_resource = media_resource.children[b"download"]
|
||||
@@ -260,9 +248,8 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
|
||||
self.media_id = "example.com/12345"
|
||||
|
||||
def _req(
|
||||
self, content_disposition: Optional[bytes], include_content_type: bool = True
|
||||
) -> FakeChannel:
|
||||
def _req(self, content_disposition, include_content_type=True):
|
||||
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
FakeSite(self.download_resource, self.reactor),
|
||||
@@ -301,7 +288,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
|
||||
return channel
|
||||
|
||||
def test_handle_missing_content_type(self) -> None:
|
||||
def test_handle_missing_content_type(self):
|
||||
channel = self._req(
|
||||
b"inline; filename=out" + self.test_image.extension,
|
||||
include_content_type=False,
|
||||
@@ -312,7 +299,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
|
||||
)
|
||||
|
||||
def test_disposition_filename_ascii(self) -> None:
|
||||
def test_disposition_filename_ascii(self):
|
||||
"""
|
||||
If the filename is filename=<ascii> then Synapse will decode it as an
|
||||
ASCII string, and use filename= in the response.
|
||||
@@ -328,7 +315,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
[b"inline; filename=out" + self.test_image.extension],
|
||||
)
|
||||
|
||||
def test_disposition_filenamestar_utf8escaped(self) -> None:
|
||||
def test_disposition_filenamestar_utf8escaped(self):
|
||||
"""
|
||||
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
|
||||
correctly decode it as the UTF-8 string, and use filename* in the
|
||||
@@ -348,7 +335,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
|
||||
)
|
||||
|
||||
def test_disposition_none(self) -> None:
|
||||
def test_disposition_none(self):
|
||||
"""
|
||||
If there is no filename, one isn't passed on in the Content-Disposition
|
||||
of the request.
|
||||
@@ -361,26 +348,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
|
||||
|
||||
def test_thumbnail_crop(self) -> None:
|
||||
def test_thumbnail_crop(self):
|
||||
"""Test that a cropped remote thumbnail is available."""
|
||||
self._test_thumbnail(
|
||||
"crop", self.test_image.expected_cropped, self.test_image.expected_found
|
||||
)
|
||||
|
||||
def test_thumbnail_scale(self) -> None:
|
||||
def test_thumbnail_scale(self):
|
||||
"""Test that a scaled remote thumbnail is available."""
|
||||
self._test_thumbnail(
|
||||
"scale", self.test_image.expected_scaled, self.test_image.expected_found
|
||||
)
|
||||
|
||||
def test_invalid_type(self) -> None:
|
||||
def test_invalid_type(self):
|
||||
"""An invalid thumbnail type is never available."""
|
||||
self._test_thumbnail("invalid", None, False)
|
||||
|
||||
@unittest.override_config(
|
||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
|
||||
)
|
||||
def test_no_thumbnail_crop(self) -> None:
|
||||
def test_no_thumbnail_crop(self):
|
||||
"""
|
||||
Override the config to generate only scaled thumbnails, but request a cropped one.
|
||||
"""
|
||||
@@ -389,13 +376,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
@unittest.override_config(
|
||||
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
|
||||
)
|
||||
def test_no_thumbnail_scale(self) -> None:
|
||||
def test_no_thumbnail_scale(self):
|
||||
"""
|
||||
Override the config to generate only cropped thumbnails, but request a scaled one.
|
||||
"""
|
||||
self._test_thumbnail("scale", None, False)
|
||||
|
||||
def test_thumbnail_repeated_thumbnail(self) -> None:
|
||||
def test_thumbnail_repeated_thumbnail(self):
|
||||
"""Test that fetching the same thumbnail works, and deleting the on disk
|
||||
thumbnail regenerates it.
|
||||
"""
|
||||
@@ -456,9 +443,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
channel.result["body"],
|
||||
)
|
||||
|
||||
def _test_thumbnail(
|
||||
self, method: str, expected_body: Optional[bytes], expected_found: bool
|
||||
) -> None:
|
||||
def _test_thumbnail(self, method, expected_body, expected_found):
|
||||
params = "?width=32&height=32&method=" + method
|
||||
channel = make_request(
|
||||
self.reactor,
|
||||
@@ -500,7 +485,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
|
||||
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
|
||||
def test_same_quality(self, method: str, desired_size: int) -> None:
|
||||
def test_same_quality(self, method, desired_size):
|
||||
"""Test that choosing between thumbnails with the same quality rating succeeds.
|
||||
|
||||
We are not particular about which thumbnail is chosen."""
|
||||
@@ -536,7 +521,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
def test_x_robots_tag_header(self) -> None:
|
||||
def test_x_robots_tag_header(self):
|
||||
"""
|
||||
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
|
||||
to not index, archive, or follow links in media.
|
||||
@@ -555,38 +540,29 @@ class TestSpamChecker:
|
||||
`evil`.
|
||||
"""
|
||||
|
||||
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
|
||||
def __init__(self, config, api):
|
||||
self.config = config
|
||||
self.api = api
|
||||
|
||||
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
def parse_config(config):
|
||||
return config
|
||||
|
||||
async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
|
||||
async def check_event_for_spam(self, foo):
|
||||
return False # allow all events
|
||||
|
||||
async def user_may_invite(
|
||||
self,
|
||||
inviter_userid: str,
|
||||
invitee_userid: str,
|
||||
room_id: str,
|
||||
) -> bool:
|
||||
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
return True # allow all invites
|
||||
|
||||
async def user_may_create_room(self, userid: str) -> bool:
|
||||
async def user_may_create_room(self, userid):
|
||||
return True # allow all room creations
|
||||
|
||||
async def user_may_create_room_alias(
|
||||
self, userid: str, room_alias: RoomAlias
|
||||
) -> bool:
|
||||
async def user_may_create_room_alias(self, userid, room_alias):
|
||||
return True # allow all room aliases
|
||||
|
||||
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
|
||||
async def user_may_publish_room(self, userid, room_id):
|
||||
return True # allow publishing of all rooms
|
||||
|
||||
async def check_media_file_for_spam(
|
||||
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
|
||||
) -> bool:
|
||||
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
|
||||
buf = BytesIO()
|
||||
await file_wrapper.write_chunks_to(buf.write)
|
||||
|
||||
@@ -599,7 +575,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||
admin.register_servlets,
|
||||
]
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
self.user = self.register_user("user", "pass")
|
||||
self.tok = self.login("user", "pass")
|
||||
|
||||
@@ -610,7 +586,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
load_legacy_spam_checkers(hs)
|
||||
|
||||
def default_config(self) -> Dict[str, Any]:
|
||||
def default_config(self):
|
||||
config = default_config("test")
|
||||
|
||||
config.update(
|
||||
@@ -626,13 +602,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
return config
|
||||
|
||||
def test_upload_innocent(self) -> None:
|
||||
def test_upload_innocent(self):
|
||||
"""Attempt to upload some innocent data that should be allowed."""
|
||||
self.helper.upload_media(
|
||||
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
|
||||
)
|
||||
|
||||
def test_upload_ban(self) -> None:
|
||||
def test_upload_ban(self):
|
||||
"""Attempt to upload some data that includes bytes "evil", which should
|
||||
get rejected by the spam checker.
|
||||
"""
|
||||
|
||||
@@ -16,21 +16,16 @@ import base64
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Optional, Sequence, Tuple, Type
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from twisted.internet._resolver import HostResolution
|
||||
from twisted.internet.address import IPv4Address, IPv6Address
|
||||
from twisted.internet.error import DNSLookupError
|
||||
from twisted.internet.interfaces import IAddress, IResolutionReceiver
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
|
||||
from twisted.test.proto_helpers import AccumulatingProtocol
|
||||
|
||||
from synapse.config.oembed import OEmbedEndpointConfig
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
|
||||
from synapse.server import HomeServer
|
||||
from synapse.types import JsonDict
|
||||
from synapse.util import Clock
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
from tests import unittest
|
||||
@@ -57,7 +52,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
b"</head></html>"
|
||||
)
|
||||
|
||||
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|
||||
def make_homeserver(self, reactor, clock):
|
||||
|
||||
config = self.default_config()
|
||||
config["url_preview_enabled"] = True
|
||||
@@ -118,22 +113,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
return hs
|
||||
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
|
||||
def prepare(self, reactor, clock, hs):
|
||||
|
||||
self.media_repo = hs.get_media_repository_resource()
|
||||
self.preview_url = self.media_repo.children[b"preview_url"]
|
||||
|
||||
self.lookups: Dict[str, Any] = {}
|
||||
self.lookups = {}
|
||||
|
||||
class Resolver:
|
||||
def resolveHostName(
|
||||
_self,
|
||||
resolutionReceiver: IResolutionReceiver,
|
||||
hostName: str,
|
||||
portNumber: int = 0,
|
||||
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
|
||||
transportSemantics: str = "TCP",
|
||||
) -> IResolutionReceiver:
|
||||
resolutionReceiver,
|
||||
hostName,
|
||||
portNumber=0,
|
||||
addressTypes=None,
|
||||
transportSemantics="TCP",
|
||||
):
|
||||
|
||||
resolution = HostResolution(hostName)
|
||||
resolutionReceiver.resolutionBegan(resolution)
|
||||
@@ -145,9 +140,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
resolutionReceiver.resolutionComplete()
|
||||
return resolutionReceiver
|
||||
|
||||
self.reactor.nameResolver = Resolver() # type: ignore[assignment]
|
||||
self.reactor.nameResolver = Resolver()
|
||||
|
||||
def create_test_resource(self) -> MediaRepositoryResource:
|
||||
def create_test_resource(self):
|
||||
return self.hs.get_media_repository_resource()
|
||||
|
||||
def _assert_small_png(self, json_body: JsonDict) -> None:
|
||||
@@ -158,7 +153,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(json_body["og:image:type"], "image/png")
|
||||
self.assertEqual(json_body["matrix:image:size"], 67)
|
||||
|
||||
def test_cache_returns_correct_type(self) -> None:
|
||||
def test_cache_returns_correct_type(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
channel = self.make_request(
|
||||
@@ -212,7 +207,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
def test_non_ascii_preview_httpequiv(self) -> None:
|
||||
def test_non_ascii_preview_httpequiv(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
end_content = (
|
||||
@@ -248,7 +243,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
||||
|
||||
def test_video_rejected(self) -> None:
|
||||
def test_video_rejected(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
end_content = b"anything"
|
||||
@@ -284,7 +279,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_audio_rejected(self) -> None:
|
||||
def test_audio_rejected(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
end_content = b"anything"
|
||||
@@ -320,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_non_ascii_preview_content_type(self) -> None:
|
||||
def test_non_ascii_preview_content_type(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
end_content = (
|
||||
@@ -355,7 +350,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
|
||||
|
||||
def test_overlong_title(self) -> None:
|
||||
def test_overlong_title(self):
|
||||
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
end_content = (
|
||||
@@ -392,7 +387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
# We should only see the `og:description` field, as `title` is too long and should be stripped out
|
||||
self.assertCountEqual(["og:description"], res.keys())
|
||||
|
||||
def test_ipaddr(self) -> None:
|
||||
def test_ipaddr(self):
|
||||
"""
|
||||
IP addresses can be previewed directly.
|
||||
"""
|
||||
@@ -422,7 +417,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_specific(self) -> None:
|
||||
def test_blacklisted_ip_specific(self):
|
||||
"""
|
||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||
"""
|
||||
@@ -443,7 +438,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_range(self) -> None:
|
||||
def test_blacklisted_ip_range(self):
|
||||
"""
|
||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||
"""
|
||||
@@ -462,7 +457,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_specific_direct(self) -> None:
|
||||
def test_blacklisted_ip_specific_direct(self):
|
||||
"""
|
||||
Blacklisted IP addresses, accessed directly, are not spidered.
|
||||
"""
|
||||
@@ -481,7 +476,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
self.assertEqual(channel.code, 403)
|
||||
|
||||
def test_blacklisted_ip_range_direct(self) -> None:
|
||||
def test_blacklisted_ip_range_direct(self):
|
||||
"""
|
||||
Blacklisted IP ranges, accessed directly, are not spidered.
|
||||
"""
|
||||
@@ -498,7 +493,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
|
||||
def test_blacklisted_ip_range_whitelisted_ip(self):
|
||||
"""
|
||||
Blacklisted but then subsequently whitelisted IP addresses can be
|
||||
spidered.
|
||||
@@ -531,7 +526,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
|
||||
)
|
||||
|
||||
def test_blacklisted_ip_with_external_ip(self) -> None:
|
||||
def test_blacklisted_ip_with_external_ip(self):
|
||||
"""
|
||||
If a hostname resolves a blacklisted IP, even if there's a
|
||||
non-blacklisted one, it will be rejected.
|
||||
@@ -554,7 +549,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ipv6_specific(self) -> None:
|
||||
def test_blacklisted_ipv6_specific(self):
|
||||
"""
|
||||
Blacklisted IP addresses, found via DNS, are not spidered.
|
||||
"""
|
||||
@@ -577,7 +572,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_blacklisted_ipv6_range(self) -> None:
|
||||
def test_blacklisted_ipv6_range(self):
|
||||
"""
|
||||
Blacklisted IP ranges, IPs found over DNS, are not spidered.
|
||||
"""
|
||||
@@ -596,7 +591,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_OPTIONS(self) -> None:
|
||||
def test_OPTIONS(self):
|
||||
"""
|
||||
OPTIONS returns the OPTIONS.
|
||||
"""
|
||||
@@ -606,7 +601,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self.assertEqual(channel.json_body, {})
|
||||
|
||||
def test_accept_language_config_option(self) -> None:
|
||||
def test_accept_language_config_option(self):
|
||||
"""
|
||||
Accept-Language header is sent to the remote server
|
||||
"""
|
||||
@@ -657,7 +652,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
server.data,
|
||||
)
|
||||
|
||||
def test_data_url(self) -> None:
|
||||
def test_data_url(self):
|
||||
"""
|
||||
Requesting to preview a data URL is not supported.
|
||||
"""
|
||||
@@ -680,7 +675,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
|
||||
self.assertEqual(channel.code, 500)
|
||||
|
||||
def test_inline_data_url(self) -> None:
|
||||
def test_inline_data_url(self):
|
||||
"""
|
||||
An inline image (as a data URL) should be parsed properly.
|
||||
"""
|
||||
@@ -717,7 +712,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(channel.code, 200)
|
||||
self._assert_small_png(channel.json_body)
|
||||
|
||||
def test_oembed_photo(self) -> None:
|
||||
def test_oembed_photo(self):
|
||||
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
|
||||
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||
@@ -776,7 +771,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
|
||||
self._assert_small_png(body)
|
||||
|
||||
def test_oembed_rich(self) -> None:
|
||||
def test_oembed_rich(self):
|
||||
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
|
||||
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
@@ -822,7 +817,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_oembed_format(self) -> None:
|
||||
def test_oembed_format(self):
|
||||
"""Test an oEmbed endpoint which requires the format in the URL."""
|
||||
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
|
||||
|
||||
@@ -871,7 +866,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
},
|
||||
)
|
||||
|
||||
def test_oembed_autodiscovery(self) -> None:
|
||||
def test_oembed_autodiscovery(self):
|
||||
"""
|
||||
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
|
||||
1. Request a preview of a URL which is not known to the oEmbed code.
|
||||
@@ -967,7 +962,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
)
|
||||
self._assert_small_png(body)
|
||||
|
||||
def _download_image(self) -> Tuple[str, str]:
|
||||
def _download_image(self):
|
||||
"""Downloads an image into the URL cache.
|
||||
Returns:
|
||||
A (host, media_id) tuple representing the MXC URI of the image.
|
||||
@@ -1000,7 +995,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
self.assertIsNone(_port)
|
||||
return host, media_id
|
||||
|
||||
def test_storage_providers_exclude_files(self) -> None:
|
||||
def test_storage_providers_exclude_files(self):
|
||||
"""Test that files are not stored in or fetched from storage providers."""
|
||||
host, media_id = self._download_image()
|
||||
|
||||
@@ -1042,7 +1037,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
"URL cache file was unexpectedly retrieved from a storage provider",
|
||||
)
|
||||
|
||||
def test_storage_providers_exclude_thumbnails(self) -> None:
|
||||
def test_storage_providers_exclude_thumbnails(self):
|
||||
"""Test that thumbnails are not stored in or fetched from storage providers."""
|
||||
host, media_id = self._download_image()
|
||||
|
||||
@@ -1095,7 +1090,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
|
||||
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
|
||||
)
|
||||
|
||||
def test_cache_expiry(self) -> None:
|
||||
def test_cache_expiry(self):
|
||||
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
|
||||
self.preview_url.clock = MockClock()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user