
Compare commits

..

2 Commits

Author         SHA1        Message                                                              Date
Andrew Morgan  3360be1829  Add header margin change                                             2022-03-10 18:35:08 +00:00
Andrew Morgan  19ca533bcc  Rename indent-section-headers -> section-headers to be more generic  2022-03-10 18:34:58 +00:00
49 changed files with 306 additions and 502 deletions

View File

@@ -33,7 +33,7 @@ site-url = "/synapse/"
 additional-css = [
     "docs/website_files/table-of-contents.css",
     "docs/website_files/remove-nav-buttons.css",
-    "docs/website_files/indent-section-headers.css",
+    "docs/website_files/section-headers.css",
 ]
 additional-js = ["docs/website_files/table-of-contents.js"]
 theme = "docs/website_files/theme"

View File

@@ -1 +1 @@
-Add type hints to tests files.
+Add type hints to `tests/rest/client`.

View File

@@ -1 +1 @@
-Add type hints to tests files.
+Add type hints to `tests/rest`.

View File

@@ -1 +0,0 @@
-Document that the `typing`, `to_device`, `account_data`, `receipts`, and `presence` stream writer can only be used on a single worker.

View File

@@ -1 +0,0 @@
-Avoid trying to calculate the state at outlier events.

View File

@@ -1 +0,0 @@
-Fix a misleading comment in the function `check_event_for_spam`.

View File

@@ -1 +0,0 @@
-Document that contributors can sign off privately by email.

View File

@@ -1 +0,0 @@
-Remove unnecessary `pass` statements.

View File

@@ -1 +0,0 @@
-Add type hints to tests files.

View File

@@ -1 +0,0 @@
-Add type hints to tests files.

View File

@@ -1 +0,0 @@
-Update the SSO username picker template to comply with SIWA guidelines.

View File

@@ -458,17 +458,6 @@ Git allows you to add this signoff automatically when using the `-s`
 flag to `git commit`, which uses the name and email set in your
 `user.name` and `user.email` git configs.
 
-### Private Sign off
-
-If you would like to provide your legal name privately to the Matrix.org
-Foundation (instead of in a public commit or comment), you can do so
-by emailing your legal name and a link to the pull request to
-[dco@matrix.org](mailto:dco@matrix.org?subject=Private%20sign%20off).
-It helps to include "sign off" or similar in the subject line. You will then
-be instructed further.
-
-Once private sign off is complete, doing so for future contributions will not
-be required.
-
 # 10. Turn feedback into better code.

View File

@@ -1947,13 +1947,8 @@ saml2_config:
 #
 #    localpart_template: Jinja2 template for the localpart of the MXID.
 #       If this is not set, the user will be prompted to choose their
-#       own username (see the documentation for the
-#       'sso_auth_account_details.html' template).
-#
-#    confirm_localpart: Whether to prompt the user to validate (or
-#       change) the generated localpart (see the documentation for the
-#       'sso_auth_account_details.html' template), instead of
-#       registering the account right away.
+#       own username (see 'sso_auth_account_details.html' in the 'sso'
+#       section of this file).
 #
 #    display_name_template: Jinja2 template for the display name to set
 #       on first login. If unset, no displayname will be set.
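
After this change, the mapping-provider section of an OIDC provider entry reduces to something like the following sketch (provider details and template values hypothetical, following the shape the surrounding sample config describes):

```yaml
oidc_providers:
  - idp_id: example
    idp_name: Example
    issuer: "https://id.example.com/"
    client_id: "synapse"
    client_secret: "secret"
    user_mapping_provider:
      config:
        # Jinja2 templates, as documented in the comments above.
        localpart_template: "{{ user.preferred_username }}"
        display_name_template: "{{ user.name }}"
```

If `localpart_template` is omitted, the user is prompted to choose a username via the `sso_auth_account_details.html` template instead.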

View File

@@ -176,11 +176,8 @@ Below are the templates Synapse will look for when generating pages related to S
   for the brand of the IdP
 * `user_attributes`: an object containing details about the user that
   we received from the IdP. May have the following attributes:
-  * `display_name`: the user's display name
-  * `emails`: a list of email addresses
-  * `localpart`: the local part of the Matrix user ID to register,
-    if `localpart_template` is set in the mapping provider configuration (empty
-    string if not)
+    * display_name: the user's display_name
+    * emails: a list of email addresses
 
 The template should render a form which submits the following fields:
 * `username`: the localpart of the user's chosen user id

View File

@@ -1,7 +0,0 @@
-/*
- * Indents each chapter title in the left sidebar so that they aren't
- * at the same level as the section headers.
- */
-.chapter-item {
-    margin-left: 1em;
-}

View File

@@ -0,0 +1,20 @@
+/*
+ * Indents each chapter title in the left sidebar so that they aren't
+ * at the same level as the section headers.
+ */
+.chapter-item {
+    margin-left: 1em;
+}
+
+/*
+ * Prevents a large gap between successive section headers.
+ *
+ * mdbook sets 'margin-top: 2.5em' on h2 and h3 headers. This makes sense when separating
+ * a header from the paragraph beforehand, but has the downside of introducing a large
+ * gap between headers that are next to each other with no text in between.
+ *
+ * This rule reduces the margin in this case.
+ */
+h1 + h2, h2 + h3 {
+    margin-top: 1.0em;
+}

View File

@@ -351,11 +351,8 @@ is only supported with Redis-based replication.)
 To enable this, the worker must have a HTTP replication listener configured,
 have a `worker_name` and be listed in the `instance_map` config. The same worker
-can handle multiple streams, but unless otherwise documented, each stream can only
-have a single writer.
-
-For example, to move event persistence off to a dedicated worker, the shared
-configuration would include:
+can handle multiple streams. For example, to move event persistence off to a
+dedicated worker, the shared configuration would include:
 
 ```yaml
 instance_map:
@@ -373,8 +370,8 @@ streams and the endpoints associated with them:
 
 ##### The `events` stream
 
-The `events` stream experimentally supports having multiple writers, where work
-is sharded between them by room ID. Note that you *must* restart all worker
+The `events` stream also experimentally supports having multiple writers, where
+work is sharded between them by room ID. Note that you *must* restart all worker
 instances when adding or removing event persisters. An example `stream_writers`
 configuration with multiple writers:
@@ -387,38 +384,38 @@ stream_writers:
 
 ##### The `typing` stream
 
-The following endpoints should be routed directly to the worker configured as
-the stream writer for the `typing` stream:
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `typing` stream:
 
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/typing
 
 ##### The `to_device` stream
 
-The following endpoints should be routed directly to the worker configured as
-the stream writer for the `to_device` stream:
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `to_device` stream:
 
     ^/_matrix/client/(api/v1|r0|v3|unstable)/sendToDevice/
 
 ##### The `account_data` stream
 
-The following endpoints should be routed directly to the worker configured as
-the stream writer for the `account_data` stream:
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `account_data` stream:
 
     ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/tags
     ^/_matrix/client/(api/v1|r0|v3|unstable)/.*/account_data
 
 ##### The `receipts` stream
 
-The following endpoints should be routed directly to the worker configured as
-the stream writer for the `receipts` stream:
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `receipts` stream:
 
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/receipt
     ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/read_markers
 
 ##### The `presence` stream
 
-The following endpoints should be routed directly to the worker configured as
-the stream writer for the `presence` stream:
+The following endpoints should be routed directly to the workers configured as
+stream writers for the `presence` stream:
 
     ^/_matrix/client/(api/v1|r0|v3|unstable)/presence/
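
Pulling the hunks above together, a shared configuration that assigns these streams to dedicated writers might look like the following sketch (worker names and ports are hypothetical, not taken from the diff):

```yaml
# Each stream writer worker must also appear in the instance_map so other
# workers can reach it over the HTTP replication listener.
stream_writers:
  events:
    - event_persister1
  presence:
    - presence_writer

instance_map:
  event_persister1:
    host: localhost
    port: 8034
  presence_writer:
    host: localhost
    port: 8035
```

The reverse proxy then routes the per-stream endpoints listed above (for example `^/_matrix/client/(api/v1|r0|v3|unstable)/presence/`) to the matching writer worker.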

View File

@@ -90,6 +90,7 @@ exclude = (?x)
   |tests/push/test_push_rule_evaluator.py
   |tests/rest/client/test_transactions.py
   |tests/rest/media/v1/test_media_storage.py
+  |tests/rest/media/v1/test_url_preview.py
   |tests/scripts/test_new_matrix_user.py
   |tests/server.py
   |tests/server_notices/test_resource_limits_server_notices.py

View File

@@ -182,13 +182,8 @@ class OIDCConfig(Config):
 #
 #    localpart_template: Jinja2 template for the localpart of the MXID.
 #       If this is not set, the user will be prompted to choose their
-#       own username (see the documentation for the
-#       'sso_auth_account_details.html' template).
-#
-#    confirm_localpart: Whether to prompt the user to validate (or
-#       change) the generated localpart (see the documentation for the
-#       'sso_auth_account_details.html' template), instead of
-#       registering the account right away.
+#       own username (see 'sso_auth_account_details.html' in the 'sso'
+#       section of this file).
 #
 #    display_name_template: Jinja2 template for the display name to set
 #       on first login. If unset, no displayname will be set.

View File

@@ -245,8 +245,8 @@ class SpamChecker:
         """Checks if a given event is considered "spammy" by this server.
 
         If the server considers an event spammy, then it will be rejected if
-        sent by a local user. If it is sent by a user on another server, the
-        event is soft-failed.
+        sent by a local user. If it is sent by a user on another server, then
+        users receive a blank event.
 
         Args:
             event: the event to be checked

View File

@@ -371,6 +371,7 @@ class DeviceHandler(DeviceWorkerHandler):
                 log_kv(
                     {"reason": "User doesn't have device id.", "device_id": device_id}
                 )
+                pass
             else:
                 raise
@@ -413,6 +414,7 @@ class DeviceHandler(DeviceWorkerHandler):
                 # no match
                 set_tag("error", True)
                 set_tag("reason", "User doesn't have that device id.")
+                pass
             else:
                 raise
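
The reinstated `pass` statements above are no-ops: in Python a block is valid as long as it contains at least one statement, so a trailing `pass` after a call or assignment is redundant but harmless. A minimal sketch (hypothetical function, not from the diff):

```python
def describe(value):
    """Return a label for *value*; the `pass` below is redundant but legal."""
    if value is None:
        label = "missing"
        pass  # no-op: the block already has a statement above it
    else:
        label = "present"
    return label

print(describe(None))  # missing
print(describe(1))     # present
```

This is why "Remove unnecessary `pass` statements" could be reverted without changing behaviour in any of these files.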

View File

@@ -1228,7 +1228,6 @@ class OidcSessionData:
 
 class UserAttributeDict(TypedDict):
     localpart: Optional[str]
-    confirm_localpart: bool
     display_name: Optional[str]
     emails: List[str]
@@ -1317,7 +1316,6 @@ class JinjaOidcMappingConfig:
     display_name_template: Optional[Template]
     email_template: Optional[Template]
     extra_attributes: Dict[str, Template]
-    confirm_localpart: bool = False
 
 class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
@@ -1359,17 +1357,12 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
                     "invalid jinja template", path=["extra_attributes", key]
                 ) from e
 
-        confirm_localpart = config.get("confirm_localpart") or False
-        if not isinstance(confirm_localpart, bool):
-            raise ConfigError("must be a bool", path=["confirm_localpart"])
-
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
             email_template=email_template,
             extra_attributes=extra_attributes,
-            confirm_localpart=confirm_localpart,
         )
 
     def get_remote_user_id(self, userinfo: UserInfo) -> str:
@@ -1405,10 +1398,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             emails.append(email)
 
         return UserAttributeDict(
-            localpart=localpart,
-            display_name=display_name,
-            emails=emails,
-            confirm_localpart=self._config.confirm_localpart,
+            localpart=localpart, display_name=display_name, emails=emails
         )
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
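
For reference, the `UserAttributeDict` shape that remains after the removals above can be sketched with a stdlib `TypedDict` (a simplified stand-in, not Synapse's actual class):

```python
from typing import List, Optional, TypedDict

class UserAttributeDict(TypedDict):
    """Simplified mirror of the post-change dict: no confirm_localpart key."""
    localpart: Optional[str]
    display_name: Optional[str]
    emails: List[str]

# Construction now takes exactly three keys; sample values are hypothetical.
attrs = UserAttributeDict(localpart="alice", display_name="Alice", emails=[])
print(sorted(attrs))  # ['display_name', 'emails', 'localpart']
```

Any caller that previously read `attrs["confirm_localpart"]` would now fail a type check, which is how this kind of field removal surfaces at review time.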

View File

@@ -267,6 +267,7 @@ class BasePresenceHandler(abc.ABC):
             is_syncing: Whether or not the user is now syncing
             sync_time_msec: Time in ms when the user was last syncing
         """
+        pass
 
     async def update_external_syncs_clear(self, process_id: str) -> None:
         """Marks all users that had been marked as syncing by a given process
@@ -276,6 +277,7 @@ class BasePresenceHandler(abc.ABC):
 
         This is a no-op when presence is handled by a different worker.
         """
+        pass
 
     async def process_replication_rows(
         self, stream_name: str, instance_name: str, token: int, rows: list

View File

@@ -132,7 +132,6 @@ class UserAttributes:
     # if `None`, the mapper has not picked a userid, and the user should be prompted to
     # enter one.
     localpart: Optional[str]
-    confirm_localpart: bool = False
     display_name: Optional[str] = None
     emails: Collection[str] = attr.Factory(list)
@@ -562,10 +561,9 @@ class SsoHandler:
         # Must provide either attributes or session, not both
         assert (attributes is not None) != (session is not None)
 
-        if (
-            attributes
-            and (attributes.localpart is None or attributes.confirm_localpart is True)
-        ) or (session and session.chosen_localpart is None):
+        if (attributes and attributes.localpart is None) or (
+            session and session.chosen_localpart is None
+        ):
             return b"/_synapse/client/pick_username/account_details"
         elif self._consent_at_registration and not (
             session and session.terms_accepted_version

View File

@@ -120,6 +120,7 @@ class ByteParser(ByteWriteable, Generic[T], abc.ABC):
         """Called when response has finished streaming and the parser should
         return the final result (or error).
         """
+        pass
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
@@ -600,6 +601,7 @@ class MatrixFederationHttpClient:
                     response.code,
                     response_phrase,
                 )
+                pass
             else:
                 logger.info(
                     "{%s} [%s] Got response headers: %d %s",

View File

@@ -233,6 +233,7 @@ class HttpServer(Protocol):
             servlet_classname (str): The name of the handler to be used in prometheus
                 and opentracing logs.
         """
+        pass
 
 class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):

View File

@@ -118,21 +118,13 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # have successfully subscribed to the stream - otherwise we might miss the
         # POSITION response sent back by the other end.
         logger.info("Sending redis SUBSCRIBE for %s", self.synapse_stream_name)
-        try:
-            await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
-        except txredisapi.ConnectionError:
-            # The connection died, the factory will attempt to reconnect.
-            return
+        await make_deferred_yieldable(self.subscribe(self.synapse_stream_name))
         logger.info(
             "Successfully subscribed to redis stream, sending REPLICATE command"
         )
-
-        # If the connection has been severed for some reason, bail.
-        if not self.connected:
-            return
-
         self.synapse_handler.new_connection(self)
         await self._async_send_command(ReplicateCommand())
         logger.info("REPLICATE successfully sent")
 
         # We send out our positions when there is a new connection in case the
         # other side missed updates. We do this for Redis connections as the
@@ -263,15 +255,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
             replyTimeout=replyTimeout,
             convertNumbers=convertNumbers,
         )
 
-        self.hs = hs
-
-        # Set the homeserver reactor as the clock, if this is not done than
-        # twisted.internet.protocol.ReconnectingClientFactory.retry will default
-        # to the reactor.
-        self.clock = hs.get_reactor()
-
         # Send pings every 30 seconds (not that get_clock() returns a Clock, not
         # a reactor).
         hs.get_clock().looping_call(self._send_ping, 30 * 1000)
 
     @wrap_as_background_process("redis_ping")
@@ -369,7 +353,6 @@ def lazyConnection(
     reconnect: bool = True,
     password: Optional[str] = None,
     replyTimeout: int = 30,
-    handler: Optional[txredisapi.ConnectionHandler] = None,
 ) -> txredisapi.ConnectionHandler:
     """Creates a connection to Redis that is lazily set up and reconnects if the
     connections is lost.

View File

@@ -130,15 +130,15 @@
   </head>
   <body>
     <header>
-      <h1>Choose your user name</h1>
-      <p>This is required to create your account on {{ server_name }}, and you can't change this later.</p>
+      <h1>Your account is nearly ready</h1>
+      <p>Check your details before creating an account on {{ server_name }}</p>
     </header>
     <main>
       <form method="post" class="form__input" id="form">
         <div class="username_input" id="username_input">
           <label for="field-username">Username</label>
           <div class="prefix">@</div>
-          <input type="text" name="username" id="field-username" value="{{ user_attributes.localpart }}" autofocus>
+          <input type="text" name="username" id="field-username" autofocus>
           <div class="postfix">:{{ server_name }}</div>
         </div>
         <output for="username_input" id="field-username-output"></output>

View File

@@ -298,6 +298,7 @@ class Responder:
         Returns:
             Resolves once the response has finished being written
         """
+        pass
 
     def __enter__(self) -> None:
         pass

View File

@@ -92,20 +92,12 @@ class AccountDetailsResource(DirectServeHtmlResource):
             self._sso_handler.render_error(request, "bad_session", e.msg, code=e.code)
             return
 
-        # The configuration might mandate going through this step to validate an
-        # automatically generated localpart, so session.chosen_localpart might already
-        # be set.
-        localpart = ""
-        if session.chosen_localpart is not None:
-            localpart = session.chosen_localpart
-
         idp_id = session.auth_provider_id
         template_params = {
             "idp": self._sso_handler.get_identity_providers()[idp_id],
             "user_attributes": {
                 "display_name": session.display_name,
                 "emails": session.emails,
-                "localpart": localpart,
             },
         }

View File

@@ -328,6 +328,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         Does nothing in this base class; overridden in derived classes to start the
         appropriate listeners.
         """
+        pass
 
     def setup_background_tasks(self) -> None:
         """

View File

@@ -48,6 +48,8 @@ class ExternalIDReuseException(Exception):
     """Exception if writing an external id for a user fails,
     because this external id is given to an other user."""
+
+    pass
 
 @attr.s(frozen=True, slots=True, auto_attribs=True)
 class TokenLookupResult:

View File

@@ -36,6 +36,7 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
         config_files = config.appservice.app_service_config_files
     except AttributeError:
         logger.warning("Could not get app_service_config_files from config")
+        pass
 
     appservices = load_appservices(config.server.server_name, config_files)

View File

@@ -22,6 +22,8 @@ class TreeCacheNode(dict):
     leaves.
     """
+
+    pass
 
 class TreeCache:
     """

View File

@@ -15,15 +15,11 @@
 from collections import Counter
 from unittest.mock import Mock
 
-from twisted.test.proto_helpers import MemoryReactor
-
 import synapse.rest.admin
 import synapse.storage
 from synapse.api.constants import EventTypes, JoinRules
 from synapse.api.room_versions import RoomVersions
 from synapse.rest.client import knock, login, room
-from synapse.server import HomeServer
-from synapse.util import Clock
 
 from tests import unittest
@@ -36,7 +32,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         knock.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+    def prepare(self, reactor, clock, hs):
         self.admin_handler = hs.get_admin_handler()
 
         self.user1 = self.register_user("user1", "password")
@@ -45,7 +41,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.user2 = self.register_user("user2", "password")
         self.token2 = self.login("user2", "password")
 
-    def test_single_public_joined_room(self) -> None:
+    def test_single_public_joined_room(self):
         """Test that we write *all* events for a public room"""
         room_id = self.helper.create_room_as(
             self.user1, tok=self.token1, is_public=True
@@ -78,7 +74,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
-    def test_single_private_joined_room(self) -> None:
+    def test_single_private_joined_room(self):
         """Tests that we correctly write state when we can't see all events in
         a room.
         """
@@ -116,7 +112,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 1)
 
-    def test_single_left_room(self) -> None:
+    def test_single_left_room(self):
         """Tests that we don't see events in the room after we leave."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -148,7 +144,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 2)
 
-    def test_single_left_rejoined_private_room(self) -> None:
+    def test_single_left_rejoined_private_room(self):
         """Tests that see the correct events in private rooms when we
         repeatedly join and leave.
         """
@@ -189,7 +185,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(counter[(EventTypes.Member, self.user1)], 1)
         self.assertEqual(counter[(EventTypes.Member, self.user2)], 3)
 
-    def test_invite(self) -> None:
+    def test_invite(self):
         """Tests that pending invites get handled correctly."""
         room_id = self.helper.create_room_as(self.user1, tok=self.token1)
         self.helper.send(room_id, body="Hello!", tok=self.token1)
@@ -208,7 +204,7 @@ class ExfiltrateData(unittest.HomeserverTestCase):
         self.assertEqual(args[1].content["membership"], "invite")
         self.assertTrue(args[2])  # Assert there is at least one bit of state
 
-    def test_knock(self) -> None:
+    def test_knock(self):
         """Tests that knock get handled correctly."""
         # create a knockable v7 room
         room_id = self.helper.create_room_as(

View File

@@ -15,12 +15,8 @@ from unittest.mock import Mock
 
 import pymacaroons
 
-from twisted.test.proto_helpers import MemoryReactor
-
 from synapse.api.errors import AuthError, ResourceLimitError
 from synapse.rest import admin
-from synapse.server import HomeServer
-from synapse.util import Clock
 
 from tests import unittest
 from tests.test_utils import make_awaitable
@@ -31,7 +27,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         admin.register_servlets,
     ]
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+    def prepare(self, reactor, clock, hs):
         self.auth_handler = hs.get_auth_handler()
         self.macaroon_generator = hs.get_macaroon_generator()
@@ -46,23 +42,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.user1 = self.register_user("a_user", "pass")
 
-    def test_macaroon_caveats(self) -> None:
+    def test_macaroon_caveats(self):
         token = self.macaroon_generator.generate_guest_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        def verify_gen(caveat: str) -> bool:
+        def verify_gen(caveat):
             return caveat == "gen = 1"
 
-        def verify_user(caveat: str) -> bool:
+        def verify_user(caveat):
             return caveat == "user_id = a_user"
 
-        def verify_type(caveat: str) -> bool:
+        def verify_type(caveat):
             return caveat == "type = access"
 
-        def verify_nonce(caveat: str) -> bool:
+        def verify_nonce(caveat):
             return caveat.startswith("nonce =")
 
-        def verify_guest(caveat: str) -> bool:
+        def verify_guest(caveat):
             return caveat == "guest = true"
 
         v = pymacaroons.Verifier()
@@ -73,7 +69,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.satisfy_general(verify_guest)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-    def test_short_term_login_token_gives_user_id(self) -> None:
+    def test_short_term_login_token_gives_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )
@@ -88,7 +84,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_short_term_login_token_gives_auth_provider(self) -> None:
+    def test_short_term_login_token_gives_auth_provider(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, auth_provider_id="my_idp"
         )
@@ -96,7 +92,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.assertEqual(self.user1, res.user_id)
         self.assertEqual("my_idp", res.auth_provider_id)
 
-    def test_short_term_login_token_cannot_replace_user_id(self) -> None:
+    def test_short_term_login_token_cannot_replace_user_id(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )
@@ -116,7 +112,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             AuthError,
         )
 
-    def test_mau_limits_disabled(self) -> None:
+    def test_mau_limits_disabled(self):
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
         self.get_success(
@@ -131,7 +127,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_mau_limits_exceeded_large(self) -> None:
+    def test_mau_limits_exceeded_large(self):
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
@@ -154,7 +150,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
 
-    def test_mau_limits_parity(self) -> None:
+    def test_mau_limits_parity(self):
         # Ensure we're not at the unix epoch.
         self.reactor.advance(1)
         self.auth_blocking._limit_usage_by_mau = True
@@ -193,7 +189,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def test_mau_limits_not_exceeded(self) -> None:
+    def test_mau_limits_not_exceeded(self):
         self.auth_blocking._limit_usage_by_mau = True
 
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -215,7 +211,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-    def _get_macaroon(self) -> pymacaroons.Macaroon:
+    def _get_macaroon(self):
         token = self.macaroon_generator.generate_short_term_login_token(
             self.user1, "", duration_in_ms=5000
         )

View File

@@ -39,7 +39,7 @@ class DeactivateAccountTestCase(HomeserverTestCase):
         self.user = self.register_user("user", "pass")
         self.token = self.login("user", "pass")
 
-    def _deactivate_my_account(self) -> None:
+    def _deactivate_my_account(self):
         """
         Deactivates the account `self.user` using `self.token` and asserts
         that it returns a 200 success code.

View File

@@ -14,14 +14,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional
-
-from twisted.test.proto_helpers import MemoryReactor
-
-from synapse.api.errors import NotFoundError, SynapseError
-from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN
-from synapse.server import HomeServer
-from synapse.util import Clock
+import synapse.api.errors
+import synapse.handlers.device
+import synapse.storage
 
 from tests import unittest
@@ -30,27 +25,28 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+    def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastores().main
         return hs
 
-    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+    def prepare(self, reactor, clock, hs):
         # These tests assume that it starts 1000 seconds in.
         self.reactor.advance(1000)
 
-    def test_device_is_created_with_invalid_name(self) -> None:
+    def test_device_is_created_with_invalid_name(self):
         self.get_failure(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
                 device_id="foo",
-                initial_device_display_name="a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1),
+                initial_device_display_name="a"
+                * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1),
             ),
-            SynapseError,
+            synapse.api.errors.SynapseError,
         )
 
-    def test_device_is_created_if_doesnt_exist(self) -> None:
+    def test_device_is_created_if_doesnt_exist(self):
         res = self.get_success(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
@@ -63,7 +59,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    def test_device_is_preserved_if_exists(self) -> None:
+    def test_device_is_preserved_if_exists(self):
         res1 = self.get_success(
             self.handler.check_device_registered(
                 user_id="@boris:foo",
@@ -85,7 +81,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@boris:foo", "fco"))
         self.assertEqual(dev["display_name"], "display name")
 
-    def test_device_id_is_made_up_if_unspecified(self) -> None:
+    def test_device_id_is_made_up_if_unspecified(self):
         device_id = self.get_success(
             self.handler.check_device_registered(
                 user_id="@theresa:foo",
@@ -97,7 +93,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         dev = self.get_success(self.handler.store.get_device("@theresa:foo", device_id))
         self.assertEqual(dev["display_name"], "display")
 
-    def test_get_devices_by_user(self) -> None:
+    def test_get_devices_by_user(self):
         self._record_users()
 
         res = self.get_success(self.handler.get_devices_by_user(user1))
@@ -135,7 +131,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             device_map["abc"],
         )
 
-    def test_get_device(self) -> None:
+    def test_get_device(self):
         self._record_users()
 
         res = self.get_success(self.handler.get_device(user1, "abc"))
@@ -150,19 +146,21 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             res,
         )
 
-    def test_delete_device(self) -> None:
+    def test_delete_device(self):
         self._record_users()
 
         # delete the device
         self.get_success(self.handler.delete_device(user1, "abc"))
 
         # check the device was deleted
-        self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
+        self.get_failure(
+            self.handler.get_device(user1, "abc"), synapse.api.errors.NotFoundError
+        )
 
         # we'd like to check the access token was invalidated, but that's a
         # bit of a PITA.
 
-    def test_delete_device_and_device_inbox(self) -> None:
+    def test_delete_device_and_device_inbox(self):
         self._record_users()
 
         # add an device_inbox
@@ -193,7 +191,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         )
         self.assertIsNone(res)
 
-    def test_update_device(self) -> None:
+    def test_update_device(self):
         self._record_users()
 
         update = {"display_name": "new display"}
@@ -202,29 +200,32 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "new display")
 
-    def test_update_device_too_long_display_name(self) -> None:
+    def test_update_device_too_long_display_name(self):
         """Update a device with a display name that is invalid (too long)."""
         self._record_users()
 
         # Request to update a device display name with a new value that is longer than allowed.
-        update = {"display_name": "a" * (MAX_DEVICE_DISPLAY_NAME_LEN + 1)}
+        update = {
+            "display_name": "a"
+            * (synapse.handlers.device.MAX_DEVICE_DISPLAY_NAME_LEN + 1)
+        }
 
         self.get_failure(
             self.handler.update_device(user1, "abc", update),
-            SynapseError,
+            synapse.api.errors.SynapseError,
         )
 
         # Ensure the display name was not updated.
         res = self.get_success(self.handler.get_device(user1, "abc"))
         self.assertEqual(res["display_name"], "display 2")
 
-    def test_update_unknown_device(self) -> None:
+    def test_update_unknown_device(self):
         update = {"display_name": "new_display"}
         self.get_failure(
             self.handler.update_device("user_id", "unknown_device_id", update),
-            NotFoundError,
+            synapse.api.errors.NotFoundError,
         )
 
-    def _record_users(self) -> None:
+    def _record_users(self):
         # check this works for both devices which have a recorded client_ip,
         # and those which don't.
         self._record_user(user1, "xyz", "display 0")
@@ -237,13 +238,8 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.reactor.advance(10000)
 
     def _record_user(
-        self,
-        user_id: str,
-        device_id: str,
-        display_name: str,
-        access_token: Optional[str] = None,
-        ip: Optional[str] = None,
-    ) -> None:
+        self, user_id, device_id, display_name, access_token=None, ip=None
+    ):
         device_id = self.get_success(
             self.handler.check_device_registered(
                 user_id=user_id,
@@ -252,7 +248,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        if access_token is not None and ip is not None:
+        if ip is not None:
             self.get_success(
                 self.store.insert_client_ip(
                     user_id, access_token, ip, "user_agent", device_id
@@ -262,7 +258,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
-    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+    def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
@@ -270,7 +266,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         return hs
 
-    def test_dehydrate_and_rehydrate_device(self) -> None:
+    def test_dehydrate_and_rehydrate_device(self):
         user_id = "@boris:dehydration"
 
         self.get_success(self.store.register_user(user_id, "foobar"))
@@ -307,7 +303,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
                 access_token=access_token,
                 device_id="not the right device ID",
             ),
-            NotFoundError,
+            synapse.api.errors.NotFoundError,
         )
 
         # dehydrating the right devices should succeed and change our device ID
@@ -335,7 +331,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         # make sure that the device ID that we were initially assigned no longer exists
         self.get_failure(
             self.handler.get_device(user_id, device_id),
-            NotFoundError,
+            synapse.api.errors.NotFoundError,
        )
 
         # make sure that there's no device available for dehydrating now

View File

@@ -124,6 +124,7 @@ class PasswordCustomAuthProvider:
                 ("m.login.password", ("password",)): self.check_auth,
             }
         )
+        pass
 
     def check_auth(self, *args):
         return mock_password_provider.check_auth(*args)


@@ -549,6 +549,7 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"presence": ["presence_writer"]}
conf["instance_map"] = {
"presence_writer": {"host": "testserv", "port": 1001},


@@ -14,14 +14,20 @@
import logging
from typing import Any, Dict, List, Optional, Tuple
from twisted.internet.address import IPv4Address
from twisted.internet.protocol import Protocol
from twisted.python.failure import Failure
from twisted.web.resource import Resource
from synapse.app.generic_worker import GenericWorkerServer
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import (
ReplicationStreamProtocolFactory,
ServerReplicationStreamProtocol,
)
from synapse.server import HomeServer
from tests import unittest
@@ -35,55 +41,6 @@ except ImportError:
logger = logging.getLogger(__name__)
class FakeOutboundConnector:
"""
A fake connector class, reconnects.
"""
def __init__(self, hs: HomeServer):
self._hs = hs
def stopConnecting(self):
pass
def connect(self):
# Restart replication.
from synapse.replication.tcp.redis import lazyConnection
handler = self._hs.get_outbound_redis_connection()
reactor = self._hs.get_reactor()
reactor.connectTCP(
self._hs.config.redis.redis_host,
self._hs.config.redis.redis_port,
handler._factory,
timeout=30,
bindAddress=None,
)
def getDestination(self):
return "blah"
class FakeReplicationHandlerConnector:
"""
A fake connector class, reconnects.
"""
def __init__(self, hs: HomeServer):
self._hs = hs
def stopConnecting(self):
pass
def connect(self):
# Restart replication.
self._hs.get_replication_command_handler().start_replication(self._hs)
def getDestination(self):
return "blah"
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
@@ -92,33 +49,16 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
def default_config(self):
config = super().default_config()
config["redis"] = {"enabled": True}
return config
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
self.streamer = hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
self._redis_transports = []
self._redis_server = FakeRedisPubSubServer()
# We may have an attempt to connect to redis for the external cache already.
self.connect_any_redis_attempts()
self.server: ServerReplicationStreamProtocol = server_factory.buildProtocol(
IPv4Address("TCP", "127.0.0.1", 0)
)
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.reactor.lookups["localhost"] = "127.0.0.1"
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
6379,
self.connect_any_redis_attempts,
)
self.worker_hs = self.setup_test_homeserver(
federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
@@ -141,11 +81,18 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.test_handler = self._build_replication_data_handler()
self.worker_hs._replication_data_handler = self.test_handler # type: ignore[attr-defined]
self.hs.get_replication_command_handler().start_replication(self.hs)
self.worker_hs.get_replication_command_handler().start_replication(
self.worker_hs
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
self.worker_hs,
"client",
"test",
clock,
repl_handler,
)
self._client_transport = None
self._server_transport = None
def create_resource_dict(self) -> Dict[str, Resource]:
d = super().create_resource_dict()
d["/_synapse/replication"] = ReplicationRestResource(self.hs)
@@ -162,46 +109,26 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
return TestReplicationDataHandler(self.worker_hs)
def reconnect(self):
self.disconnect()
print("RECONNECTING")
if self._client_transport:
self.client.close()
# Make a `FakeConnector` to emulate the behavior of `connectTCP. That
# creates an `IConnector`, which is responsible for calling the factory
# `clientConnectionLost`. The reconnecting factory then calls
# `IConnector.connect` to attempt a reconnection. The transport is meant
# to call `connectionLost` on the `IConnector`.
#
# Most of that is bypassed by directly calling `retry` on the factory,
# which schedules a `connect()` call on the connector.
timeouts = []
for hs in (self.hs, self.worker_hs):
hs_factory_outbound = hs.get_outbound_redis_connection()._factory
hs_factory_outbound.clientConnectionLost(
FakeOutboundConnector(hs), Failure(RuntimeError(""))
)
timeouts.append(hs_factory_outbound.delay)
if self._server_transport:
self.server.close()
hs_factory = hs.get_replication_command_handler()._factory
hs_factory.clientConnectionLost(
FakeReplicationHandlerConnector(hs),
Failure(RuntimeError("")),
)
timeouts.append(hs_factory.delay)
self._client_transport = FakeTransport(self.server, self.reactor)
self.client.makeConnection(self._client_transport)
# Wait for the reconnects to happen.
self.pump(max(timeouts) + 1)
self.connect_any_redis_attempts()
self._server_transport = FakeTransport(self.client, self.reactor)
self.server.makeConnection(self._server_transport)
def disconnect(self):
print("DISCONNECTING")
for (
client_to_server_transport,
server_to_client_transport,
) in self._redis_transports:
client_to_server_transport.abortConnection()
server_to_client_transport.abortConnection()
self._redis_transports = []
if self._client_transport:
self._client_transport = None
self.client.close()
if self._server_transport:
self._server_transport = None
self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -285,40 +212,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
def connect_any_redis_attempts(self):
"""If redis is enabled we need to deal with workers connecting to a
redis server. We don't want to use a real Redis server so we use a
fake one.
"""
clients = self.reactor.tcpClients
while clients:
(host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
self.assertEqual(host, "localhost")
self.assertEqual(port, 6379)
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
if client_protocol.__class__.__name__ == "RedisSubscriber":
print(client_protocol, client_protocol.synapse_handler._presence_handler.hs, client_protocol.synapse_outbound_redis_connection)
else:
print(client_protocol, client_protocol.factory.hs)
print()
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
)
client_protocol.makeConnection(client_to_server_transport)
server_to_client_transport = FakeTransport(
client_protocol, self.reactor, server_protocol
)
server_protocol.makeConnection(server_to_client_transport)
# Store for potentially disconnecting.
self._redis_transports.append(
(client_to_server_transport, server_to_client_transport)
)
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
@@ -327,14 +220,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
unlike `BaseStreamTestCase`.
"""
def default_config(self):
config = super().default_config()
config["redis"] = {"enabled": True}
return config
def setUp(self):
super().setUp()
# build a replication server
self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -353,14 +243,15 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
6379,
self.connect_any_redis_attempts,
)
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
"localhost",
6379,
self.connect_any_redis_attempts,
)
self.hs.get_replication_command_handler().start_replication(self.hs)
self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -444,6 +335,27 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
# Set up TCP replication between master and the new worker if we don't
# have Redis support enabled.
if not worker_hs.config.redis.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
worker_hs,
"client",
"test",
self.clock,
repl_handler,
)
server = self.server_factory.buildProtocol(
IPv4Address("TCP", "127.0.0.1", 0)
)
client_transport = FakeTransport(server, self.reactor)
client.makeConnection(client_transport)
server_transport = FakeTransport(client, self.reactor)
server.makeConnection(server_transport)
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -462,7 +374,8 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
worker_hs.get_replication_command_handler().start_replication(worker_hs)
if worker_hs.config.redis.redis_enabled:
worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -511,7 +424,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Note: at this point we've wired everything up, but we need to return
# before the data starts flowing over the connections as this is called
# inside `connectTCP` before the connection has been passed back to the
# inside `connecTCP` before the connection has been passed back to the
# code that requested the TCP connection.
def connect_any_redis_attempts(self):
@@ -623,13 +536,8 @@ class FakeRedisPubSubProtocol(Protocol):
self.send("OK")
elif command == b"GET":
self.send(None)
# Connection keep-alives.
elif command == b"PING":
self.send("PONG")
else:
raise Exception(f"Unknown command: {command}")
raise Exception("Unknown command")
def send(self, msg):
"""Send a message back to the client."""
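The test harness above replaces real Redis and TCP connections with in-memory fakes by cross-wiring two protocol objects through fake transports. A library-free sketch of that pattern (the class and method names here are illustrative, not Synapse's or Twisted's):

```python
# Each FakeTransport delivers writes straight to the peer protocol's
# data_received, so the two "endpoints" talk entirely in memory.

class FakeTransport:
    def __init__(self, peer):
        self.peer = peer  # protocol object that receives our writes

    def write(self, data: bytes) -> None:
        self.peer.data_received(data)


class EchoProtocol:
    def __init__(self):
        self.received = b""
        self.transport = None

    def make_connection(self, transport) -> None:
        self.transport = transport

    def data_received(self, data: bytes) -> None:
        self.received += data


client, server = EchoProtocol(), EchoProtocol()
# Wire each side's transport to the *other* side's protocol.
client.make_connection(FakeTransport(server))
server.make_connection(FakeTransport(client))

client.transport.write(b"PING")  # lands in server.received
server.transport.write(b"PONG")  # lands in client.received
```

Because nothing ever touches a socket, tests built this way are deterministic and need no running Redis server.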


@@ -250,14 +250,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check("get_rooms_for_user_with_stream_ordering", (USER_ID_2,), set())
# limit the replication rate from server -> client.
print(len(self._redis_transports))
for x in self._redis_transports:
print(f"\t{x}")
assert len(self._redis_transports) == 1
for _, repl_transport in self._redis_transports:
assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# limit the replication rate
repl_transport = self._server_transport
assert isinstance(repl_transport, FakeTransport)
repl_transport.autoflush = False
# build the join and message events and persist them in the same batch.
logger.info("----- build test events ------")
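The hunk above throttles replication by switching `autoflush` off on the fake transport so writes queue up until explicitly flushed. A minimal stdlib sketch of that buffering behaviour (names are illustrative, not Synapse's):

```python
class BufferingTransport:
    """Delivers writes immediately while autoflush is True, otherwise
    queues them until flush() is called."""

    def __init__(self, deliver):
        self.deliver = deliver  # callable receiving bytes
        self.autoflush = True
        self._buffer = b""

    def write(self, data: bytes) -> None:
        if self.autoflush:
            self.deliver(data)
        else:
            self._buffer += data

    def flush(self) -> None:
        self.deliver(self._buffer)
        self._buffer = b""


received = []
t = BufferingTransport(received.append)
t.write(b"a")          # delivered immediately
t.autoflush = False
t.write(b"b")          # held back
t.write(b"c")          # held back
t.flush()              # both queued writes arrive as one batch
```

Pausing delivery like this lets a test persist several events "in the same batch" before the replica sees any of them.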


@@ -28,7 +28,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
return Mock(wraps=super()._build_replication_data_handler())
def test_receipt(self):
# self.reconnect()
self.reconnect()
# tell the master to send a new receipt
self.get_success(


@@ -27,8 +27,10 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
servlets = [register.register_servlets]
def _get_worker_hs_config(self) -> dict:
config = super()._get_worker_hs_config()
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
config["worker_replication_host"] = "testserv"
config["worker_replication_http_port"] = "8765"
return config
def test_register_single_worker(self):


@@ -51,6 +51,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},


@@ -24,7 +24,6 @@ from synapse.util import Clock
from synapse.visibility import filter_events_for_client
from tests import unittest
from tests.unittest import override_config
one_hour_ms = 3600000
one_day_ms = one_hour_ms * 24
@@ -39,10 +38,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
# merge this default retention config with anything that was specified in
# @override_config
retention_config = {
config["retention"] = {
"enabled": True,
"default_policy": {
"min_lifetime": one_day_ms,
@@ -51,8 +47,6 @@ class RetentionTestCase(unittest.HomeserverTestCase):
"allowed_lifetime_min": one_day_ms,
"allowed_lifetime_max": one_day_ms * 3,
}
retention_config.update(config.get("retention", {}))
config["retention"] = retention_config
self.hs = self.setup_test_homeserver(config=config)
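The hunk above merges a hard-coded default retention config with anything supplied via `@override_config`, letting the override win. The merge step can be sketched as a plain shallow dict update (the values below are made-up examples, not Synapse defaults):

```python
def merge_retention_config(defaults: dict, config: dict) -> dict:
    merged = dict(defaults)  # copy so the shared defaults aren't mutated
    # Anything specified under config["retention"] overrides the defaults.
    merged.update(config.get("retention", {}))
    return merged


defaults = {"enabled": True, "min_lifetime": 86400000}
config = {"retention": {"min_lifetime": 3600000}}
merged = merge_retention_config(defaults, config)
```

Note this is a shallow merge, matching the diff: a nested key like `default_policy` would be replaced wholesale rather than merged key-by-key.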
@@ -121,20 +115,22 @@ class RetentionTestCase(unittest.HomeserverTestCase):
self._test_retention_event_purged(room_id, one_day_ms * 2)
@override_config({"retention": {"purge_jobs": [{"interval": "5d"}]}})
def test_visibility(self) -> None:
"""Tests that synapse.visibility.filter_events_for_client correctly filters out
outdated events, even if the purge job hasn't got to them yet.
We do this by setting a very long time between purge jobs.
outdated events
"""
store = self.hs.get_datastores().main
storage = self.hs.get_storage()
room_id = self.helper.create_room_as(self.user_id, tok=self.token)
events = []
# Send a first event, which should be filtered out at the end of the test.
resp = self.helper.send(room_id=room_id, body="1", tok=self.token)
first_event_id = resp.get("event_id")
# Get the event from the store so that we end up with a FrozenEvent that we can
# give to filter_events_for_client. We need to do this now because the event won't
# be in the database anymore after it has expired.
events.append(self.get_success(store.get_event(resp.get("event_id"))))
# Advance the time by 2 days. We're using the default retention policy, therefore
# after this the first event will still be valid.
@@ -142,17 +138,16 @@ class RetentionTestCase(unittest.HomeserverTestCase):
# Send another event, which shouldn't get filtered out.
resp = self.helper.send(room_id=room_id, body="2", tok=self.token)
valid_event_id = resp.get("event_id")
events.append(self.get_success(store.get_event(valid_event_id)))
# Advance the time by another 2 days. After this, the first event should be
# outdated but not the second one.
self.reactor.advance(one_day_ms * 2 / 1000)
# Fetch the events, and run filter_events_for_client on them
events = self.get_success(
store.get_events_as_list([first_event_id, valid_event_id])
)
self.assertEqual(2, len(events), "events retrieved from database")
# Run filter_events_for_client with our list of FrozenEvents.
filtered_events = self.get_success(
filter_events_for_client(storage, self.user_id, events)
)


@@ -1,18 +1,3 @@
# Copyright 2018-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
from unittest.mock import Mock, call
from twisted.internet import defer, reactor
@@ -26,14 +11,14 @@ from tests.utils import MockClock
class HttpTransactionCacheTestCase(unittest.TestCase):
def setUp(self) -> None:
def setUp(self):
self.clock = MockClock()
self.hs = Mock()
self.hs.get_clock = Mock(return_value=self.clock)
self.hs.get_auth = Mock()
self.cache = HttpTransactionCache(self.hs)
self.mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
self.mock_http_response = (200, "GOOD JOB!")
self.mock_key = "foo"
@defer.inlineCallbacks
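The hunk above trades the bare `200` for `HTTPStatus.OK`. The two compare equal because `HTTPStatus` is an `IntEnum`, so the swap is purely for readability:

```python
from http import HTTPStatus

# HTTPStatus members interoperate with plain ints while carrying a
# readable name and reason phrase.
mock_http_response = (HTTPStatus.OK, "GOOD JOB!")
status, body = mock_http_response

assert status == 200               # IntEnum compares equal to the int
assert status.phrase == "OK"       # and still carries enum metadata
```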


@@ -16,7 +16,7 @@ import shutil
import tempfile
from binascii import unhexlify
from io import BytesIO
from typing import Any, BinaryIO, Dict, List, Optional, Union
from typing import Optional
from unittest.mock import Mock
from urllib import parse
@@ -26,24 +26,18 @@ from PIL import Image as Image
from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.events import EventBase
from synapse.events.spamcheck import load_legacy_spam_checkers
from synapse.logging.context import make_deferred_yieldable
from synapse.module_api import ModuleApi
from synapse.rest import admin
from synapse.rest.client import login
from synapse.rest.media.v1._base import FileInfo
from synapse.rest.media.v1.filepath import MediaFilePaths
from synapse.rest.media.v1.media_storage import MediaStorage, ReadableFileWrapper
from synapse.rest.media.v1.media_storage import MediaStorage
from synapse.rest.media.v1.storage_provider import FileStorageProviderBackend
from synapse.server import HomeServer
from synapse.types import RoomAlias
from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel, FakeSite, make_request
from tests.server import FakeSite, make_request
from tests.test_utils import SMALL_PNG
from tests.utils import default_config
@@ -52,7 +46,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
needs_threadpool = True
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
def prepare(self, reactor, clock, hs):
self.test_dir = tempfile.mkdtemp(prefix="synapse-tests-")
self.addCleanup(shutil.rmtree, self.test_dir)
@@ -68,7 +62,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
hs, self.primary_base_path, self.filepaths, storage_providers
)
def test_ensure_media_is_in_local_cache(self) -> None:
def test_ensure_media_is_in_local_cache(self):
media_id = "some_media_id"
test_body = "Test\n"
@@ -111,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
self.assertEqual(test_body, body)
@attr.s(auto_attribs=True, slots=True, frozen=True)
@attr.s(slots=True, frozen=True)
class _TestImage:
"""An image for testing thumbnailing with the expected results
@@ -127,18 +121,18 @@ class _TestImage:
a 404 is expected.
"""
data: bytes
content_type: bytes
extension: bytes
expected_cropped: Optional[bytes] = None
expected_scaled: Optional[bytes] = None
expected_found: bool = True
data = attr.ib(type=bytes)
content_type = attr.ib(type=bytes)
extension = attr.ib(type=bytes)
expected_cropped = attr.ib(type=Optional[bytes], default=None)
expected_scaled = attr.ib(type=Optional[bytes], default=None)
expected_found = attr.ib(default=True, type=bool)
@parameterized_class(
("test_image",),
[
# small png
# smoll png
(
_TestImage(
SMALL_PNG,
@@ -199,17 +193,11 @@ class MediaRepoTests(unittest.HomeserverTestCase):
hijack_auth = True
user_id = "@test:user"
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def make_homeserver(self, reactor, clock):
self.fetches = []
def get_file(
destination: str,
path: str,
output_stream: BinaryIO,
args: Optional[Dict[str, Union[str, List[str]]]] = None,
max_size: Optional[int] = None,
) -> Deferred:
def get_file(destination, path, output_stream, args=None, max_size=None):
"""
Returns tuple[int,dict,str,int] of file length, response headers,
absolute URI, and response code.
@@ -250,7 +238,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
def prepare(self, reactor, clock, hs):
media_resource = hs.get_media_repository_resource()
self.download_resource = media_resource.children[b"download"]
@@ -260,9 +248,8 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.media_id = "example.com/12345"
def _req(
self, content_disposition: Optional[bytes], include_content_type: bool = True
) -> FakeChannel:
def _req(self, content_disposition, include_content_type=True):
channel = make_request(
self.reactor,
FakeSite(self.download_resource, self.reactor),
@@ -301,7 +288,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
return channel
def test_handle_missing_content_type(self) -> None:
def test_handle_missing_content_type(self):
channel = self._req(
b"inline; filename=out" + self.test_image.extension,
include_content_type=False,
@@ -312,7 +299,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
headers.getRawHeaders(b"Content-Type"), [b"application/octet-stream"]
)
def test_disposition_filename_ascii(self) -> None:
def test_disposition_filename_ascii(self):
"""
If the filename is filename=<ascii> then Synapse will decode it as an
ASCII string, and use filename= in the response.
@@ -328,7 +315,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename=out" + self.test_image.extension],
)
def test_disposition_filenamestar_utf8escaped(self) -> None:
def test_disposition_filenamestar_utf8escaped(self):
"""
If the filename is filename=*utf8''<utf8 escaped> then Synapse will
correctly decode it as the UTF-8 string, and use filename* in the
@@ -348,7 +335,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
[b"inline; filename*=utf-8''" + filename + self.test_image.extension],
)
def test_disposition_none(self) -> None:
def test_disposition_none(self):
"""
If there is no filename, one isn't passed on in the Content-Disposition
of the request.
@@ -361,26 +348,26 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self) -> None:
def test_thumbnail_crop(self):
"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self) -> None:
def test_thumbnail_scale(self):
"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
def test_invalid_type(self) -> None:
def test_invalid_type(self):
"""An invalid thumbnail type is never available."""
self._test_thumbnail("invalid", None, False)
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
)
def test_no_thumbnail_crop(self) -> None:
def test_no_thumbnail_crop(self):
"""
Override the config to generate only scaled thumbnails, but request a cropped one.
"""
@@ -389,13 +376,13 @@ class MediaRepoTests(unittest.HomeserverTestCase):
@unittest.override_config(
{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
)
def test_no_thumbnail_scale(self) -> None:
def test_no_thumbnail_scale(self):
"""
Override the config to generate only cropped thumbnails, but request a scaled one.
"""
self._test_thumbnail("scale", None, False)
def test_thumbnail_repeated_thumbnail(self) -> None:
def test_thumbnail_repeated_thumbnail(self):
"""Test that fetching the same thumbnail works, and deleting the on disk
thumbnail regenerates it.
"""
@@ -456,9 +443,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
channel.result["body"],
)
def _test_thumbnail(
self, method: str, expected_body: Optional[bytes], expected_found: bool
) -> None:
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(
self.reactor,
@@ -500,7 +485,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
@parameterized.expand([("crop", 16), ("crop", 64), ("scale", 16), ("scale", 64)])
def test_same_quality(self, method: str, desired_size: int) -> None:
def test_same_quality(self, method, desired_size):
"""Test that choosing between thumbnails with the same quality rating succeeds.
We are not particular about which thumbnail is chosen."""
@@ -536,7 +521,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
)
)
def test_x_robots_tag_header(self) -> None:
def test_x_robots_tag_header(self):
"""
Tests that the `X-Robots-Tag` header is present, which informs web crawlers
to not index, archive, or follow links in media.
@@ -555,38 +540,29 @@ class TestSpamChecker:
`evil`.
"""
def __init__(self, config: Dict[str, Any], api: ModuleApi) -> None:
def __init__(self, config, api):
self.config = config
self.api = api
def parse_config(config: Dict[str, Any]) -> Dict[str, Any]:
def parse_config(config):
return config
async def check_event_for_spam(self, event: EventBase) -> Union[bool, str]:
async def check_event_for_spam(self, foo):
return False # allow all events
async def user_may_invite(
self,
inviter_userid: str,
invitee_userid: str,
room_id: str,
) -> bool:
async def user_may_invite(self, inviter_userid, invitee_userid, room_id):
return True # allow all invites
async def user_may_create_room(self, userid: str) -> bool:
async def user_may_create_room(self, userid):
return True # allow all room creations
async def user_may_create_room_alias(
self, userid: str, room_alias: RoomAlias
) -> bool:
async def user_may_create_room_alias(self, userid, room_alias):
return True # allow all room aliases
async def user_may_publish_room(self, userid: str, room_id: str) -> bool:
async def user_may_publish_room(self, userid, room_id):
return True # allow publishing of all rooms
async def check_media_file_for_spam(
self, file_wrapper: ReadableFileWrapper, file_info: FileInfo
) -> bool:
async def check_media_file_for_spam(self, file_wrapper, file_info) -> bool:
buf = BytesIO()
await file_wrapper.write_chunks_to(buf.write)
@@ -599,7 +575,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
admin.register_servlets,
]
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
def prepare(self, reactor, clock, hs):
self.user = self.register_user("user", "pass")
self.tok = self.login("user", "pass")
@@ -610,7 +586,7 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
load_legacy_spam_checkers(hs)
def default_config(self) -> Dict[str, Any]:
def default_config(self):
config = default_config("test")
config.update(
@@ -626,13 +602,13 @@ class SpamCheckerTestCase(unittest.HomeserverTestCase):
return config
def test_upload_innocent(self) -> None:
def test_upload_innocent(self):
"""Attempt to upload some innocent data that should be allowed."""
self.helper.upload_media(
self.upload_resource, SMALL_PNG, tok=self.tok, expect_code=200
)
def test_upload_ban(self) -> None:
def test_upload_ban(self):
"""Attempt to upload some data that includes bytes "evil", which should
get rejected by the spam checker.
"""
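`TestSpamChecker` above rejects any upload whose payload contains the bytes `evil`. A self-contained sketch of that scan (the class and argument names here are illustrative stand-ins, not Synapse's module API — in the real test the bytes arrive via a `ReadableFileWrapper` rather than a plain buffer):

```python
import asyncio
from io import BytesIO


class EvilByteChecker:
    """Flags any media whose payload contains the marker bytes b"evil"."""

    MARKER = b"evil"

    async def check_media_file_for_spam(self, payload: bytes) -> bool:
        # A BytesIO stands in for the chunked file wrapper used above.
        buf = BytesIO(payload)
        return self.MARKER in buf.getvalue()


checker = EvilByteChecker()
allowed = asyncio.run(checker.check_media_file_for_spam(b"innocent png bytes"))
blocked = asyncio.run(checker.check_media_file_for_spam(b"some evil bytes"))
```

Returning `True` means "spam" in this callback's convention, which is why the test expects the upload containing `evil` to be rejected.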


@@ -16,21 +16,16 @@ import base64
import json
import os
import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from urllib.parse import urlencode
from twisted.internet._resolver import HostResolution
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IAddress, IResolutionReceiver
from twisted.test.proto_helpers import AccumulatingProtocol, MemoryReactor
from twisted.test.proto_helpers import AccumulatingProtocol
from synapse.config.oembed import OEmbedEndpointConfig
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.media.v1.preview_url_resource import IMAGE_CACHE_EXPIRY_MS
from synapse.server import HomeServer
from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.stringutils import parse_and_validate_mxc_uri
from tests import unittest
@@ -57,7 +52,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
b"</head></html>"
)
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
def make_homeserver(self, reactor, clock):
config = self.default_config()
config["url_preview_enabled"] = True
@@ -118,22 +113,22 @@ class URLPreviewTests(unittest.HomeserverTestCase):
return hs
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
def prepare(self, reactor, clock, hs):
self.media_repo = hs.get_media_repository_resource()
self.preview_url = self.media_repo.children[b"preview_url"]
self.lookups: Dict[str, Any] = {}
self.lookups = {}
class Resolver:
def resolveHostName(
_self,
resolutionReceiver: IResolutionReceiver,
hostName: str,
portNumber: int = 0,
addressTypes: Optional[Sequence[Type[IAddress]]] = None,
transportSemantics: str = "TCP",
) -> IResolutionReceiver:
resolutionReceiver,
hostName,
portNumber=0,
addressTypes=None,
transportSemantics="TCP",
):
resolution = HostResolution(hostName)
resolutionReceiver.resolutionBegan(resolution)
@@ -145,9 +140,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
resolutionReceiver.resolutionComplete()
return resolutionReceiver
self.reactor.nameResolver = Resolver() # type: ignore[assignment]
self.reactor.nameResolver = Resolver()
def create_test_resource(self) -> MediaRepositoryResource:
def create_test_resource(self):
return self.hs.get_media_repository_resource()
def _assert_small_png(self, json_body: JsonDict) -> None:
@@ -158,7 +153,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(json_body["og:image:type"], "image/png")
self.assertEqual(json_body["matrix:image:size"], 67)
def test_cache_returns_correct_type(self) -> None:
def test_cache_returns_correct_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
channel = self.make_request(
@@ -212,7 +207,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
def test_non_ascii_preview_httpequiv(self) -> None:
def test_non_ascii_preview_httpequiv(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -248,7 +243,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_video_rejected(self) -> None:
def test_video_rejected(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -284,7 +279,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_audio_rejected(self) -> None:
def test_audio_rejected(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = b"anything"
@@ -320,7 +315,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
def test_non_ascii_preview_content_type(self) -> None:
def test_non_ascii_preview_content_type(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -355,7 +350,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body["og:title"], "\u0434\u043a\u0430")
def test_overlong_title(self) -> None:
def test_overlong_title(self):
self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")]
end_content = (
@@ -392,7 +387,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
# We should only see the `og:description` field, as `title` is too long and should be stripped out
self.assertCountEqual(["og:description"], res.keys())
def test_ipaddr(self) -> None:
def test_ipaddr(self):
"""
IP addresses can be previewed directly.
"""
@@ -422,7 +417,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
-def test_blacklisted_ip_specific(self) -> None:
+def test_blacklisted_ip_specific(self):
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -443,7 +438,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_blacklisted_ip_range(self) -> None:
+def test_blacklisted_ip_range(self):
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -462,7 +457,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_blacklisted_ip_specific_direct(self) -> None:
+def test_blacklisted_ip_specific_direct(self):
"""
Blacklisted IP addresses, accessed directly, are not spidered.
"""
@@ -481,7 +476,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self.assertEqual(channel.code, 403)
-def test_blacklisted_ip_range_direct(self) -> None:
+def test_blacklisted_ip_range_direct(self):
"""
Blacklisted IP ranges, accessed directly, are not spidered.
"""
@@ -498,7 +493,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_blacklisted_ip_range_whitelisted_ip(self) -> None:
+def test_blacklisted_ip_range_whitelisted_ip(self):
"""
Blacklisted but then subsequently whitelisted IP addresses can be
spidered.
@@ -531,7 +526,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
channel.json_body, {"og:title": "~matrix~", "og:description": "hi"}
)
-def test_blacklisted_ip_with_external_ip(self) -> None:
+def test_blacklisted_ip_with_external_ip(self):
"""
If a hostname resolves a blacklisted IP, even if there's a
non-blacklisted one, it will be rejected.
@@ -554,7 +549,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_blacklisted_ipv6_specific(self) -> None:
+def test_blacklisted_ipv6_specific(self):
"""
Blacklisted IP addresses, found via DNS, are not spidered.
"""
@@ -577,7 +572,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_blacklisted_ipv6_range(self) -> None:
+def test_blacklisted_ipv6_range(self):
"""
Blacklisted IP ranges, IPs found over DNS, are not spidered.
"""
@@ -596,7 +591,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_OPTIONS(self) -> None:
+def test_OPTIONS(self):
"""
OPTIONS returns the OPTIONS.
"""
@@ -606,7 +601,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self.assertEqual(channel.json_body, {})
-def test_accept_language_config_option(self) -> None:
+def test_accept_language_config_option(self):
"""
Accept-Language header is sent to the remote server
"""
@@ -657,7 +652,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
server.data,
)
-def test_data_url(self) -> None:
+def test_data_url(self):
"""
Requesting to preview a data URL is not supported.
"""
@@ -680,7 +675,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 500)
-def test_inline_data_url(self) -> None:
+def test_inline_data_url(self):
"""
An inline image (as a data URL) should be parsed properly.
"""
@@ -717,7 +712,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 200)
self._assert_small_png(channel.json_body)
-def test_oembed_photo(self) -> None:
+def test_oembed_photo(self):
"""Test an oEmbed endpoint which returns a 'photo' type which redirects the preview to a new URL."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
self.lookups["cdn.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -776,7 +771,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertEqual(body["og:url"], "http://twitter.com/matrixdotorg/status/12345")
self._assert_small_png(body)
-def test_oembed_rich(self) -> None:
+def test_oembed_rich(self):
"""Test an oEmbed endpoint which returns HTML content via the 'rich' type."""
self.lookups["publish.twitter.com"] = [(IPv4Address, "10.1.2.3")]
@@ -822,7 +817,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_oembed_format(self) -> None:
+def test_oembed_format(self):
"""Test an oEmbed endpoint which requires the format in the URL."""
self.lookups["www.hulu.com"] = [(IPv4Address, "10.1.2.3")]
@@ -871,7 +866,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
},
)
-def test_oembed_autodiscovery(self) -> None:
+def test_oembed_autodiscovery(self):
"""
Autodiscovery works by finding the link in the HTML response and then requesting an oEmbed URL.
1. Request a preview of a URL which is not known to the oEmbed code.
@@ -967,7 +962,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
)
self._assert_small_png(body)
-def _download_image(self) -> Tuple[str, str]:
+def _download_image(self):
"""Downloads an image into the URL cache.
Returns:
A (host, media_id) tuple representing the MXC URI of the image.
@@ -1000,7 +995,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
self.assertIsNone(_port)
return host, media_id
-def test_storage_providers_exclude_files(self) -> None:
+def test_storage_providers_exclude_files(self):
"""Test that files are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1042,7 +1037,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache file was unexpectedly retrieved from a storage provider",
)
-def test_storage_providers_exclude_thumbnails(self) -> None:
+def test_storage_providers_exclude_thumbnails(self):
"""Test that thumbnails are not stored in or fetched from storage providers."""
host, media_id = self._download_image()
@@ -1095,7 +1090,7 @@ class URLPreviewTests(unittest.HomeserverTestCase):
"URL cache thumbnail was unexpectedly retrieved from a storage provider",
)
-def test_cache_expiry(self) -> None:
+def test_cache_expiry(self):
"""Test that URL cache files and thumbnails are cleaned up properly on expiry."""
self.preview_url.clock = MockClock()