1
0

Compare commits

..

23 Commits

Author SHA1 Message Date
Erik Johnston
4de1c35728 Fixup changelog 2021-03-08 13:59:17 +00:00
Erik Johnston
15c788e22d 1.29.0 2021-03-08 13:52:13 +00:00
Erik Johnston
a6333b8d42 Fix link in UPGRADES 2021-03-04 10:32:44 +00:00
Erik Johnston
ea0a3aaf0a Fix changelog 2021-03-04 10:29:43 +00:00
Erik Johnston
3f49d80dcf 1.29.0rc1 2021-03-04 10:12:53 +00:00
Patrick Cloke
33a02f0f52 Fix additional type hints from Twisted upgrade. (#9518) 2021-03-03 15:47:38 -05:00
Richard van der Hoff
4db07f9aef Set X-Forwarded-Proto header when frontend-proxy proxies a request (#9539)
Should fix some remaining warnings
2021-03-03 18:49:08 +00:00
Erik Johnston
a4fa044c00 Fix 'rejected_events_metadata' background update (#9537)
Turns out matrix.org has an event that has duplicate auth events (which really isn't supposed to happen, but here we are). This caused the background update to fail due to `UniqueViolation`.
2021-03-03 16:04:24 +00:00
Patrick Cloke
922788c604 Purge chain cover tables when purging events. (#9498) 2021-03-03 11:04:08 -05:00
Dirk Klimpel
d790d0d314 Add type hints to user admin API. (#9521) 2021-03-03 08:09:39 -05:00
Patrick Cloke
0c330423bc Bump the mypy and mypy-zope versions. (#9529) 2021-03-03 07:19:19 -05:00
Erik Johnston
16f9f93eb7 Make deleting stale pushers a background update (#9536) 2021-03-03 12:08:16 +00:00
Richard van der Hoff
a5daae2a5f Update nginx reverse-proxy docs (#9512)
Turns out nginx overwrites the Host header by default.
2021-03-03 11:08:11 +00:00
Aaron Raimist
0279e0e086 Prevent presence background jobs from running when presence is disabled (#9530)
Prevent presence background jobs from running when presence is disabled

Signed-off-by: Aaron Raimist <aaron@raim.ist>
2021-03-03 10:21:46 +00:00
Patrick Cloke
aee10768d8 Revert "Fix #8518 (sync requests being cached wrongly on timeout) (#9358)"
This reverts commit f5c93fc993.

This is being backed out due to a regression (#9507) and additional
review feedback being provided.
2021-03-02 09:43:34 -05:00
Erik Johnston
7f5d753d06 Re-run rejected metadata background update. (#9503)
It landed in schema version 58 after 59 had been created, causing some
servers to not run it. The main effect of was that not all rooms had
their chain cover calculated correctly. After the BG updates complete
the chain covers will get fixed when a new state event in the affected
rooms is received.
2021-03-02 14:31:23 +00:00
Erik Johnston
16108c579d Fix SQL delta file taking a long time to run (#9516)
Fixes #9504
2021-03-02 14:05:01 +00:00
Dirk Klimpel
f00c4e7af0 Add type hints to device and event report admin API (#9519) 2021-03-02 09:31:12 +00:00
Patrick Cloke
ad8589d392 Fix a bug when a room alias is given to the admin join endpoint (#9506) 2021-03-01 13:59:01 -05:00
Patrick Cloke
16ec8c3272 (Hopefully) stop leaking file descriptors in media repo. (#9497)
By consuming the response if the headers imply that the
content is too large.
2021-03-01 12:45:00 -05:00
Patrick Cloke
a0bc9d387e Use the proper Request in type hints. (#9515)
This also pins the Twisted version in the mypy job for CI until
proper type hints are fixed throughout Synapse.
2021-03-01 12:23:46 -05:00
Jonathan de Jong
e12077a78a Allow bytecode again (#9502)
In #75, bytecode was disabled (from a bit of FUD back in `python<2.4` days, according to dev chat), I think it's safe enough to enable it again.

Added in `__pycache__/` and `.pyc`/`.pyd` to `.gitignore`, to extra-insure compiled files don't get committed.

`Signed-off-by: Jonathan de Jong <jonathan@automatia.nl>`
2021-02-26 18:30:54 +00:00
Tim Leung
ddb240293a Add support for no_proxy and case insensitive env variables (#9372)
### Changes proposed in this PR

- Add support for the `no_proxy` and `NO_PROXY` environment variables
  - Internally rely on urllib's [`proxy_bypass_environment`](bdb941be42/Lib/urllib/request.py (L2519))
- Extract env variables using urllib's `getproxies`/[`getproxies_environment`](bdb941be42/Lib/urllib/request.py (L2488)) which supports lowercase + uppercase, preferring lowercase, except for `HTTP_PROXY` in a CGI environment

This does contain behaviour changes for consumers so making sure these are called out:
- `no_proxy`/`NO_PROXY` is now respected
- lowercase `https_proxy` is now allowed and taken over `HTTPS_PROXY`

Related to #9306 which also uses `ProxyAgent`

Signed-off-by: Timothy Leung tim95@hotmail.co.uk
2021-02-26 17:37:57 +00:00
95 changed files with 812 additions and 472 deletions

3
.gitignore vendored
View File

@@ -6,13 +6,14 @@
*.egg
*.egg-info
*.lock
*.pyc
*.py[cod]
*.snap
*.tac
_trial_temp/
_trial_temp*/
/out
.DS_Store
__pycache__/
# stuff that is likely to exist when you run a server locally
/*.db

View File

@@ -1,9 +1,66 @@
Synapse 1.xx.0
==============
Synapse 1.29.0 (2021-03-08)
===========================
Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
No significant changes.
Synapse 1.29.0rc1 (2021-03-04)
==============================
Features
--------
- Add rate limiters to cross-user key sharing requests. ([\#8957](https://github.com/matrix-org/synapse/issues/8957))
- Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel. ([\#8978](https://github.com/matrix-org/synapse/issues/8978))
- Add some configuration settings to make users' profile data more private. ([\#9203](https://github.com/matrix-org/synapse/issues/9203))
- The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. ([\#9372](https://github.com/matrix-org/synapse/issues/9372))
- Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. ([\#9383](https://github.com/matrix-org/synapse/issues/9383), [\#9385](https://github.com/matrix-org/synapse/issues/9385))
- Add support for regenerating thumbnails if they have been deleted but the original image is still stored. ([\#9438](https://github.com/matrix-org/synapse/issues/9438))
- Add support for `X-Forwarded-Proto` header when using a reverse proxy. ([\#9472](https://github.com/matrix-org/synapse/issues/9472), [\#9501](https://github.com/matrix-org/synapse/issues/9501), [\#9512](https://github.com/matrix-org/synapse/issues/9512), [\#9539](https://github.com/matrix-org/synapse/issues/9539))
Bugfixes
--------
- Fix a bug where users' pushers were not all deleted when they deactivated their account. ([\#9285](https://github.com/matrix-org/synapse/issues/9285), [\#9516](https://github.com/matrix-org/synapse/issues/9516))
- Fix a bug where a lot of unnecessary presence updates were sent when joining a room. ([\#9402](https://github.com/matrix-org/synapse/issues/9402))
- Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. ([\#9416](https://github.com/matrix-org/synapse/issues/9416))
- Fix a bug in single sign-on which could cause a "No session cookie found" error. ([\#9436](https://github.com/matrix-org/synapse/issues/9436))
- Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. ([\#9440](https://github.com/matrix-org/synapse/issues/9440))
- Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. ([\#9449](https://github.com/matrix-org/synapse/issues/9449))
- Fix deleting pushers when using sharded pushers. ([\#9465](https://github.com/matrix-org/synapse/issues/9465), [\#9466](https://github.com/matrix-org/synapse/issues/9466), [\#9479](https://github.com/matrix-org/synapse/issues/9479), [\#9536](https://github.com/matrix-org/synapse/issues/9536))
- Fix missing startup checks for the consistency of certain PostgreSQL sequences. ([\#9470](https://github.com/matrix-org/synapse/issues/9470))
- Fix a long-standing bug where the media repository could leak file descriptors while previewing media. ([\#9497](https://github.com/matrix-org/synapse/issues/9497))
- Properly purge the event chain cover index when purging history. ([\#9498](https://github.com/matrix-org/synapse/issues/9498))
- Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. ([\#9503](https://github.com/matrix-org/synapse/issues/9503))
- Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. ([\#9506](https://github.com/matrix-org/synapse/issues/9506))
- Prevent presence background jobs from running when presence is disabled. ([\#9530](https://github.com/matrix-org/synapse/issues/9530))
- Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. ([\#9537](https://github.com/matrix-org/synapse/issues/9537))
Improved Documentation
----------------------
- Update the example systemd config to propagate reloads to individual units. ([\#9463](https://github.com/matrix-org/synapse/issues/9463))
Internal Changes
----------------
- Add documentation and type hints to `parse_duration`. ([\#9432](https://github.com/matrix-org/synapse/issues/9432))
- Remove vestiges of `uploads_path` configuration setting. ([\#9462](https://github.com/matrix-org/synapse/issues/9462))
- Add a comment about systemd-python. ([\#9464](https://github.com/matrix-org/synapse/issues/9464))
- Test that we require validated email for email pushers. ([\#9496](https://github.com/matrix-org/synapse/issues/9496))
- Allow python to generate bytecode for synapse. ([\#9502](https://github.com/matrix-org/synapse/issues/9502))
- Fix incorrect type hints. ([\#9515](https://github.com/matrix-org/synapse/issues/9515), [\#9518](https://github.com/matrix-org/synapse/issues/9518))
- Add type hints to device and event report admin API. ([\#9519](https://github.com/matrix-org/synapse/issues/9519))
- Add type hints to user admin API. ([\#9521](https://github.com/matrix-org/synapse/issues/9521))
- Bump the versions of mypy and mypy-zope used for static type checking. ([\#9529](https://github.com/matrix-org/synapse/issues/9529))
Synapse 1.28.0 (2021-02-25)
===========================

View File

@@ -98,9 +98,9 @@ will log a warning on each received request.
To avoid the warning, administrators using a reverse proxy should ensure that
the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to
indicate the protocol used by the client. See the [reverse proxy
documentation](docs/reverse_proxy.md), where the example configurations have
been updated to show how to set this header.
indicate the protocol used by the client. See the `reverse proxy documentation
<docs/reverse_proxy.md>`_, where the example configurations have been updated to
show how to set this header.
(Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
sets `X-Forwarded-Proto` by default.)

View File

@@ -1 +0,0 @@
Temporarily drop cross-user m.room_key_request to_device messages over performance concerns.

View File

@@ -1 +0,0 @@
Add rate limiters to cross-user key sharing requests.

View File

@@ -1 +0,0 @@
Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel.

View File

@@ -1 +0,0 @@
Add some configuration settings to make users' profile data more private.

View File

@@ -1 +0,0 @@
Fix a bug where users' pushers were not all deleted when they deactivated their account.

View File

@@ -1 +0,0 @@
Added a fix that invalidates cache for empty timed-out sync responses.

View File

@@ -1 +0,0 @@
Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users.

View File

@@ -1 +0,0 @@
Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users.

View File

@@ -1 +0,0 @@
Fix a bug where a lot of unnecessary presence updates were sent when joining a room.

View File

@@ -1 +0,0 @@
Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results.

View File

@@ -1 +0,0 @@
Add documentation and type hints to `parse_duration`.

View File

@@ -1 +0,0 @@
Fix a bug in single sign-on which could cause a "No session cookie found" error.

View File

@@ -1 +0,0 @@
Add support for regenerating thumbnails if they have been deleted but the original image is still stored.

View File

@@ -1 +0,0 @@
Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined.

View File

@@ -1 +0,0 @@
Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`.

View File

@@ -1 +0,0 @@
Remove vestiges of `uploads_path` configuration setting.

View File

@@ -1 +0,0 @@
Update the example systemd config to propagate reloads to individual units.

View File

@@ -1 +0,0 @@
Add a comment about systemd-python.

View File

@@ -1 +0,0 @@
Fix deleting pushers when using sharded pushers.

View File

@@ -1 +0,0 @@
Fix deleting pushers when using sharded pushers.

View File

@@ -1 +0,0 @@
Fix missing startup checks for the consistency of certain PostgreSQL sequences.

View File

@@ -1 +0,0 @@
Add support for `X-Forwarded-Proto` header when using a reverse proxy.

View File

@@ -1 +0,0 @@
Fix deleting pushers when using sharded pushers.

View File

@@ -1 +0,0 @@
Test that we require validated email for email pushers.

View File

@@ -1 +0,0 @@
Add support for `X-Forwarded-Proto` header when using a reverse proxy.

View File

@@ -58,10 +58,10 @@ trap "rm -r $tmpdir" EXIT
cp -r tests "$tmpdir"
PYTHONPATH="$tmpdir" \
"${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests
"${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests
# build the config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \
"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
--config-dir="/etc/matrix-synapse" \
--data-dir="/var/lib/matrix-synapse" |
perl -pe '
@@ -87,7 +87,7 @@ PYTHONPATH="$tmpdir" \
' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"
# build the log config file
"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \
"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_log_config" \
--output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml"
# add a dependency on the right version of python to substvars.

10
debian/changelog vendored
View File

@@ -1,3 +1,13 @@
matrix-synapse-py3 (1.29.0) stable; urgency=medium
[ Jonathan de Jong ]
* Remove the python -B flag (don't generate bytecode) in scripts and documentation.
[ Synapse Packaging team ]
* New synapse release 1.29.0.
-- Synapse Packaging team <packages@matrix.org> Mon, 08 Mar 2021 13:51:50 +0000
matrix-synapse-py3 (1.28.0) stable; urgency=medium
* New synapse release 1.28.0.

2
debian/synctl.1 vendored
View File

@@ -44,7 +44,7 @@ Configuration file may be generated as follows:
.
.nf
$ python \-B \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
$ python \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
.
.fi
.

2
debian/synctl.ronn vendored
View File

@@ -41,7 +41,7 @@ process.
Configuration file may be generated as follows:
$ python -B -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
$ python -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
## ENVIRONMENT

View File

@@ -53,6 +53,8 @@ server {
proxy_pass http://localhost:8008;
proxy_set_header X-Forwarded-For $remote_addr;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_set_header Host $host;
# Nginx by default only allows file uploads up to 1M in size
# Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
client_max_body_size 50M;

View File

@@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_bg_updates import (
from synapse.storage.databases.main.media_repository import (
MediaRepositoryBackgroundUpdateStore,
)
from synapse.storage.databases.main.pusher import PusherWorkerStore
from synapse.storage.databases.main.registration import (
RegistrationBackgroundUpdateStore,
find_max_generated_user_id_localpart,
@@ -177,6 +178,7 @@ class Store(
UserDirectoryBackgroundUpdateStore,
EndToEndKeyBackgroundStore,
StatsStore,
PusherWorkerStore,
):
def execute(self, f, *args, **kwargs):
return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)

View File

@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
"flake8",
]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]
# Dependencies which are exclusively required by unit test code. This is
# NOT a list of all modules that are necessary to run the unit tests.

View File

@@ -48,7 +48,7 @@ try:
except ImportError:
pass
__version__ = "1.28.0"
__version__ = "1.29.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when

View File

@@ -17,8 +17,6 @@ import sys
from synapse import python_dependencies # noqa: E402
sys.dont_write_bytecode = True
logger = logging.getLogger(__name__)
try:

View File

@@ -23,6 +23,7 @@ from typing_extensions import ContextManager
from twisted.internet import address
from twisted.web.resource import IResource
from twisted.web.server import Request
import synapse
import synapse.events
@@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
self.http_client = hs.get_simple_http_client()
self.main_uri = hs.config.worker_main_http_uri
async def on_POST(self, request, device_id):
async def on_POST(self, request: Request, device_id: Optional[str]):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
body = parse_json_object_from_request(request)
@@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
header: request.requestHeaders.getRawHeaders(header, [])
for header in (b"Authorization", b"User-Agent")
}
# Add the previous hop the the X-Forwarded-For header.
# Add the previous hop to the X-Forwarded-For header.
x_forwarded_for = request.requestHeaders.getRawHeaders(
b"X-Forwarded-For", []
)
# we use request.client here, since we want the previous hop, not the
# original client (as returned by request.getClientAddress()).
if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
previous_host = request.client.host.encode("ascii")
# If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
x_forwarded_for = [previous_host]
headers[b"X-Forwarded-For"] = x_forwarded_for
# Replicate the original X-Forwarded-Proto header. Note that
# XForwardedForRequest overrides isSecure() to give us the original protocol
# used by the client, as opposed to the protocol used by our upstream proxy
# - which is what we want here.
headers[b"X-Forwarded-Proto"] = [
b"https" if request.isSecure() else b"http"
]
try:
result = await self.http_client.post_json_get_json(
self.main_uri + request.uri.decode("ascii"), body, headers=headers

View File

@@ -936,10 +936,6 @@ class FederationHandlerRegistry:
):
return
# Temporary patch to drop cross-user key share requests
if edu_type == "m.room_key_request":
return
# Check if we have a handler on this instance
handler = self.edu_handlers.get(edu_type)
if handler:

View File

@@ -36,7 +36,7 @@ import attr
import bcrypt
import pymacaroons
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import (
@@ -481,7 +481,7 @@ class AuthHandler(BaseHandler):
sid = authdict["session"]
# Convert the URI and method to strings.
uri = request.uri.decode("utf-8")
uri = request.uri.decode("utf-8") # type: ignore
method = request.method.decode("utf-8")
# If there's no session ID, create a new session.

View File

@@ -252,7 +252,7 @@ class MessageHandler:
# If this is an AS, double check that they are allowed to see the members.
# This can either be because the AS user is in the room or because there
# is a user in the room that the AS is "interested in"
if False and requester.app_service and user_id not in users_with_profile:
if requester.app_service and user_id not in users_with_profile:
for uid in users_with_profile:
if requester.app_service.is_interested_in_user(uid):
break

View File

@@ -274,22 +274,25 @@ class PresenceHandler(BasePresenceHandler):
self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler():
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
if self._presence_enabled:
# Start a LoopingCall in 30s that fires every 5s.
# The initial delay is to allow disconnected clients a chance to
# reconnect before we treat them as offline.
def run_timeout_handler():
return run_as_background_process(
"handle_presence_timeouts", self._handle_timeouts
)
self.clock.call_later(
30, self.clock.looping_call, run_timeout_handler, 5000
)
self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
def run_persister():
return run_as_background_process(
"persist_presence_changes", self._persist_unpersisted_changes
)
self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
LaterGauge(
"synapse_handlers_presence_wheel_timer_size",
@@ -299,7 +302,7 @@ class PresenceHandler(BasePresenceHandler):
)
# Used to handle sending of presence to newly joined users/servers
if hs.config.use_presence:
if self._presence_enabled:
self.notifier.add_replication_callback(self.notify_new_event)
# Presence is best effort and quickly heals itself, so lets just always

View File

@@ -43,7 +43,6 @@ class RoomListHandler(BaseHandler):
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.enable_room_list_search = hs.config.enable_room_list_search
self.response_cache = ResponseCache(
hs, "room_list"
) # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]

View File

@@ -66,7 +66,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
self.account_data_handler = hs.get_account_data_handler()
self.member_linearizer = Linearizer(name="member")
self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")
self.clock = hs.get_clock()
self.spam_checker = hs.get_spam_checker()
@@ -337,38 +336,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
key = (room_id,)
as_id = object()
if requester.app_service:
as_id = requester.app_service.id
then = self.clock.time_msec()
with (await self.member_limiter.queue(as_id)):
diff = self.clock.time_msec() - then
if diff > 80 * 1000:
# haproxy would have timed the request out anyway...
raise SynapseError(504, "took to long to process")
with (await self.member_linearizer.queue(key)):
diff = self.clock.time_msec() - then
if diff > 80 * 1000:
# haproxy would have timed the request out anyway...
raise SynapseError(504, "took to long to process")
result = await self.update_membership_locked(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
require_consent=require_consent,
)
with (await self.member_linearizer.queue(key)):
result = await self.update_membership_locked(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
content=content,
require_consent=require_consent,
)
return result

View File

@@ -31,8 +31,8 @@ from urllib.parse import urlencode
import attr
from typing_extensions import NoReturn, Protocol
from twisted.web.http import Request
from twisted.web.iweb import IRequest
from twisted.web.server import Request
from synapse.api.constants import LoginType
from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError

View File

@@ -52,7 +52,6 @@ logger = logging.getLogger(__name__)
# Debug logger for https://github.com/matrix-org/synapse/issues/4422
issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug")
SYNC_RESPONSE_CACHE_MS = 2 * 60 * 1000
# Counts the number of times we returned a non-empty sync. `type` is one of
# "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is
@@ -245,7 +244,7 @@ class SyncHandler:
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache(
hs, "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS
hs, "sync"
) # type: ResponseCache[Tuple[Any, ...]]
self.state = hs.get_state_handler()
self.auth = hs.get_auth()
@@ -278,9 +277,8 @@ class SyncHandler:
user_id = sync_config.user.to_string()
await self.auth.check_auth_blocking(requester=requester)
res = await self.response_cache.wrap_conditional(
res = await self.response_cache.wrap(
sync_config.request_key,
lambda result: since_token != result.next_batch,
self._wait_for_sync_for_user,
sync_config,
since_token,

View File

@@ -289,8 +289,7 @@ class SimpleHttpClient:
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
use_proxy: bool = False,
):
"""
Args:
@@ -300,8 +299,8 @@ class SimpleHttpClient:
we may not request.
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
use_proxy: Whether proxy settings should be discovered and used
from conventional environment variables.
"""
self.hs = hs
@@ -345,8 +344,7 @@ class SimpleHttpClient:
connectTimeout=15,
contextFactory=self.hs.get_http_client_context_factory(),
pool=pool,
http_proxy=http_proxy,
https_proxy=https_proxy,
use_proxy=use_proxy,
)
if self._ip_blacklist:
@@ -750,7 +748,32 @@ class BodyExceededMaxSize(Exception):
"""The maximum allowed size of the HTTP body was exceeded."""
class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which immediately errors upon receiving data."""
def __init__(self, deferred: defer.Deferred):
self.deferred = deferred
def _maybe_fail(self):
"""
Report a max size exceed error and disconnect the first time this is called.
"""
if not self.deferred.called:
self.deferred.errback(BodyExceededMaxSize())
# Close the connection (forcefully) since all the data will get
# discarded anyway.
self.transport.abortConnection()
def dataReceived(self, data: bytes) -> None:
self._maybe_fail()
def connectionLost(self, reason: Failure) -> None:
self._maybe_fail()
class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
"""A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
@@ -807,13 +830,15 @@ def read_body_with_max_size(
Returns:
A Deferred which resolves to the length of the read body.
"""
d = defer.Deferred()
# If the Content-Length header gives a size larger than the maximum allowed
# size, do not bother downloading the body.
if max_size is not None and response.length != UNKNOWN_LENGTH:
if response.length > max_size:
return defer.fail(BodyExceededMaxSize())
response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
return d
d = defer.Deferred()
response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
return d

View File

@@ -14,7 +14,7 @@
# limitations under the License.
import logging
import urllib.parse
from typing import List, Optional
from typing import Any, Generator, List, Optional
from netaddr import AddrFormatError, IPAddress, IPSet
from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""
Args:
method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
# We need to make sure the host header is set to the netloc of the
# server and that a user-agent is provided.
if headers is None:
headers = Headers()
request_headers = Headers()
else:
headers = headers.copy()
request_headers = headers.copy()
if not headers.hasHeader(b"host"):
headers.addRawHeader(b"host", parsed_uri.netloc)
if not headers.hasHeader(b"user-agent"):
headers.addRawHeader(b"user-agent", self.user_agent)
if not request_headers.hasHeader(b"host"):
request_headers.addRawHeader(b"host", parsed_uri.netloc)
if not request_headers.hasHeader(b"user-agent"):
request_headers.addRawHeader(b"user-agent", self.user_agent)
res = yield make_deferred_yieldable(
self._agent.request(method, uri, headers, bodyProducer)
self._agent.request(method, uri, request_headers, bodyProducer)
)
return res

View File

@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
RequestSendFailed: if the Content-Type header is missing or isn't JSON
"""
c_type = headers.getRawHeaders(b"Content-Type")
if c_type is None:
content_type_headers = headers.getRawHeaders(b"Content-Type")
if content_type_headers is None:
raise RequestSendFailed(
RuntimeError("No Content-Type header received from remote server"),
can_retry=False,
)
c_type = c_type[0].decode("ascii") # only the first header
c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
if val != "application/json":
raise RequestSendFailed(

View File

@@ -14,6 +14,7 @@
# limitations under the License.
import logging
import re
from urllib.request import getproxies_environment, proxy_bypass_environment
from zope.interface import implementer
@@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):
pool (HTTPConnectionPool|None): connection pool to be used. If None, a
non-persistent pool instance will be created.
use_proxy (bool): Whether proxy settings should be discovered and used
from conventional environment variables.
"""
def __init__(
@@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
connectTimeout=None,
bindAddress=None,
pool=None,
http_proxy=None,
https_proxy=None,
use_proxy=False,
):
_AgentBase.__init__(self, reactor, pool)
@@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
if bindAddress is not None:
self._endpoint_kwargs["bindAddress"] = bindAddress
http_proxy = None
https_proxy = None
no_proxy = None
if use_proxy:
proxies = getproxies_environment()
http_proxy = proxies["http"].encode() if "http" in proxies else None
https_proxy = proxies["https"].encode() if "https" in proxies else None
no_proxy = proxies["no"] if "no" in proxies else None
self.http_proxy_endpoint = _http_proxy_endpoint(
http_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
@@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
https_proxy, self.proxy_reactor, **self._endpoint_kwargs
)
self.no_proxy = no_proxy
self._policy_for_https = contextFactory
self._reactor = reactor
@@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
request_path = parsed_uri.originForm
if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
should_skip_proxy = False
if self.no_proxy is not None:
should_skip_proxy = proxy_bypass_environment(
parsed_uri.host.decode(),
proxies={"no": self.no_proxy},
)
if (
parsed_uri.scheme == b"http"
and self.http_proxy_endpoint
and not should_skip_proxy
):
# Cache *all* connections under the same key, since we are only
# connecting to a single destination, the proxy:
pool_key = ("http-proxy", self.http_proxy_endpoint)
endpoint = self.http_proxy_endpoint
request_path = uri
elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
elif (
parsed_uri.scheme == b"https"
and self.https_proxy_endpoint
and not should_skip_proxy
):
endpoint = HTTPConnectProxyEndpoint(
self.proxy_reactor,
self.https_proxy_endpoint,

View File

@@ -21,6 +21,7 @@ import logging
import types
import urllib
from http import HTTPStatus
from inspect import isawaitable
from io import BytesIO
from typing import (
Any,
@@ -30,6 +31,7 @@ from typing import (
Iterable,
Iterator,
List,
Optional,
Pattern,
Tuple,
Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"""Sends a JSON error response to clients."""
if f.check(SynapseError):
error_code = f.value.code
error_dict = f.value.error_dict()
# mypy doesn't understand that f.check asserts the type.
exc = f.value # type: SynapseError # type: ignore
error_code = exc.code
error_dict = exc.error_dict()
logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
else:
error_code = 500
error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
"Failed handle request via %r: %r",
request.request_metrics.name,
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
# Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
`{msg}` placeholders), or a jinja2 template
"""
if f.check(CodeMessageException):
cme = f.value
# mypy doesn't understand that f.check asserts the type.
cme = f.value # type: CodeMessageException # type: ignore
code = cme.code
msg = cme.msg
@@ -142,7 +147,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
else:
code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
logger.error(
"Failed handle request %r",
request,
exc_info=(f.type, f.value, f.getTracebackObject()),
exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore
)
if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
raw_callback_return = method_handler(request)
# Is it synchronous? We'll allow this for now.
if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
if isawaitable(raw_callback_return):
callback_return = await raw_callback_return
else:
callback_return = raw_callback_return # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
A tuple of the callback to use, the name of the servlet, and the
key word arguments to pass to the callback
"""
# At this point the path must be bytes.
request_path_bytes = request.path # type: bytes # type: ignore
request_path = request_path_bytes.decode("ascii")
# Treat HEAD requests as GET requests.
request_path = request.path.decode("ascii")
request_method = request.method
if request_method == b"HEAD":
request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
request: Request,
iterator: Iterator[bytes],
):
self._request = request
self._request = request # type: Optional[Request]
self._iterator = iterator
self._paused = False
@@ -563,7 +570,7 @@ class _ByteProducer:
"""
Send a list of bytes as a chunk of a response.
"""
if not data:
if not data or not self._request:
return
self._request.write(b"".join(data))

View File

@@ -14,7 +14,7 @@
import contextlib
import logging
import time
from typing import Optional, Union
from typing import Optional, Type, Union
import attr
from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):
def __init__(self, channel, *args, **kw):
Request.__init__(self, channel, *args, **kw)
self.site = channel.site
self.site = channel.site # type: SynapseSite
self._channel = channel # this is used by the tests
self.start_time = 0.0
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
def get_request_id(self):
return "%s-%i" % (self.get_method(), self.request_seq)
def get_redacted_uri(self):
uri = self.uri
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if the URI
has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.uri`.
Returns:
The redacted URI as a string.
"""
uri = self.uri # type: Union[bytes, str]
if isinstance(uri, bytes):
uri = self.uri.decode("ascii", errors="replace")
uri = uri.decode("ascii", errors="replace")
return redact_uri(uri)
def get_method(self):
"""Gets the method associated with the request (or placeholder if not
method has yet been received).
def get_method(self) -> str:
"""Gets the method associated with the request (or placeholder if method
has not yet been received).
Note: This is necessary as the placeholder value in twisted is str
rather than bytes, so we need to sanitise `self.method`.
Returns:
str
The request method as a string.
"""
method = self.method
method = self.method # type: Union[bytes, str]
if isinstance(method, bytes):
method = self.method.decode("ascii")
return self.method.decode("ascii")
return method
def render(self, resrc):
@@ -375,9 +384,9 @@ class XForwardedForRequest(SynapseRequest):
else:
# this is done largely for backwards-compatibility so that people that
# haven't set an x-forwarded-proto header don't get a redirect loop.
#logger.warning(
# "forwarded request lacks an x-forwarded-proto header: assuming https"
#)
logger.warning(
"forwarded request lacks an x-forwarded-proto header: assuming https"
)
self._forwarded_https = True
def isSecure(self):
@@ -432,7 +441,9 @@ class SynapseSite(Site):
assert config.http_options is not None
proxied = config.http_options.x_forwarded
self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
self.requestFactory = (
XForwardedForRequest if proxied else SynapseRequest
) # type: Type[Request]
self.access_logger = logging.getLogger(logger_name)
self.server_version_string = server_version_string.encode("ascii")

View File

@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
TCP4ClientEndpoint,
TCP6ClientEndpoint,
)
from twisted.internet.interfaces import IPushProducer, ITransport
from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
from twisted.internet.protocol import Factory, Protocol
from twisted.python.failure import Failure
@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
try:
ip = ip_address(self.host)
if isinstance(ip, IPv4Address):
endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
endpoint = TCP4ClientEndpoint(
_reactor, self.host, self.port
) # type: IStreamClientEndpoint
elif isinstance(ip, IPv6Address):
endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
else:

View File

@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
REGISTRY.register(ReactorLastSeenMetric())
def runUntilCurrentTimer(func):
def runUntilCurrentTimer(reactor, func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):
try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
reactor.seconds # type: ignore
reactor.runUntilCurrent # type: ignore
reactor._newTimedCalls # type: ignore
reactor.threadCallQueue # type: ignore
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent) # type: ignore
# We manually run the GC each reactor tick so that we can get some metrics
# about time spent doing GC,

View File

@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Iterable, Optional, Tuple
from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple
from twisted.internet import defer
@@ -307,7 +307,7 @@ class ModuleApi:
@defer.inlineCallbacks
def get_state_events_in_room(
self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
) -> defer.Deferred:
) -> Generator[defer.Deferred, Any, defer.Deferred]:
"""Gets current state events for the given room.
(This is exposed for compatibility with the old SpamCheckerApi. We should

View File

@@ -15,11 +15,12 @@
# limitations under the License.
import logging
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union
from prometheus_client import Counter
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.internet.interfaces import IDelayedCall
from synapse.api.constants import EventTypes
from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
self.data = pusher_config.data
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusher_config.failing_since
self.timed_call = None
self.timed_call = None # type: Optional[IDelayedCall]
self._is_processing = False
self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
self._pusherpool = hs.get_pusherpool()
@@ -101,11 +102,6 @@ class HttpPusher(Pusher):
"'url' must have a path of '/_matrix/push/v1/notify'"
)
url = url.replace(
"https://matrix.org/_matrix/push/v1/notify",
"http://10.103.0.7/_matrix/push/v1/notify",
)
self.url = url
self.http_client = hs.get_proxied_blacklisted_http_client()
self.data_minus_url = {}

View File

@@ -15,9 +15,10 @@
import logging
from typing import TYPE_CHECKING, List, Optional, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.http.servlet import parse_json_object_from_request
from synapse.http.site import SynapseRequest
from synapse.replication.http._base import ReplicationEndpoint
from synapse.types import JsonDict, Requester, UserID
from synapse.util.distributor import user_left_room
@@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
self, request: Request, room_id: str, user_id: str
self, request: SynapseRequest, room_id: str, user_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -147,7 +147,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
}
async def _handle_request( # type: ignore
self, request: Request, invite_event_id: str
self, request: SynapseRequest, invite_event_id: str
) -> Tuple[int, JsonDict]:
content = parse_json_object_from_request(request)
@@ -155,7 +155,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
event_content = content["content"]
requester = Requester.deserialize(self.store, content["requester"])
request.requester = requester
# hopefully we're now on the master, so this won't recurse!

View File

@@ -108,9 +108,7 @@ class ReplicationDataHandler:
# Map from stream to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
self._streams_to_waiters = (
{}
) # type: Dict[str, List[Tuple[int, Deferred[None]]]]
self._streams_to_waiters = {} # type: Dict[str, List[Tuple[int, Deferred]]]
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list

View File

@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
"""Global or per room account data was changed"""
AccountDataStreamRow = namedtuple(
"AccountDataStream",
"AccountDataStreamRow",
("user_id", "room_id", "data_type"), # str # Optional[str] # str
)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import NotFoundError, SynapseError
from synapse.http.servlet import (
@@ -20,8 +21,12 @@ from synapse.http.servlet import (
assert_params_in_dict,
parse_json_object_from_request,
)
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import UserID
from synapse.types import JsonDict, UserID
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
"/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id, device_id):
async def on_GET(
self, request: SynapseRequest, user_id, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
)
return 200, device
async def on_DELETE(self, request, user_id, device_id):
async def on_DELETE(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
await self.device_handler.delete_device(target_user.to_string(), device_id)
return 200, {}
async def on_PUT(self, request, user_id, device_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str, device_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
"""
Args:
hs (synapse.server.HomeServer): server
@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
self.store = hs.get_datastore()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)

View File

@@ -14,10 +14,16 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Tuple
from synapse.api.errors import Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
from synapse.types import JsonDict
if TYPE_CHECKING:
from synapse.server import HomeServer
logger = logging.getLogger(__name__)
@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
start = parse_integer(request, "from", default=0)
@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):
PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, report_id):
async def on_GET(
self, request: SynapseRequest, report_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
message = (
"The report_id parameter must be a string representing a positive integer."
)
try:
report_id = int(report_id)
resolved_report_id = int(report_id)
except ValueError:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
if report_id < 0:
if resolved_report_id < 0:
raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)
ret = await self.store.get_event_report(report_id)
ret = await self.store.get_event_report(resolved_report_id)
if not ret:
raise NotFoundError("Event report not found")

View File

@@ -17,7 +17,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer

View File

@@ -44,6 +44,48 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class ResolveRoomIdMixin:
def __init__(self, hs: "HomeServer"):
self.room_member_handler = hs.get_room_member_handler()
async def resolve_room_id(
self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
) -> Tuple[str, Optional[List[str]]]:
"""
Resolve a room identifier to a room ID, if necessary.
This also performanes checks to ensure the room ID is of the proper form.
Args:
room_identifier: The room ID or alias.
remote_room_hosts: The potential remote room hosts to use.
Returns:
The resolved room ID.
Raises:
SynapseError if the room ID is of the wrong form.
"""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
(
room_id,
remote_room_hosts,
) = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id, remote_room_hosts
class ShutdownRoomRestServlet(RestServlet):
"""Shuts down a room by removing all local users from the room and blocking
all future invites and joins to the room. Any local aliases will be repointed
@@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
return 200, ret
class JoinRoomAliasServlet(RestServlet):
class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
@@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet):
if not await self.admin_handler.get_user(target_user):
raise NotFoundError("User not found")
if RoomID.is_valid(room_identifier):
room_id = room_identifier
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
elif RoomAlias.is_valid(room_identifier):
handler = self.room_member_handler
room_alias = RoomAlias.from_string(room_identifier)
room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
# Get the room ID from the identifier.
try:
remote_room_hosts = [
x.decode("ascii") for x in request.args[b"server_name"]
] # type: Optional[List[str]]
except Exception:
remote_room_hosts = None
room_id, remote_room_hosts = await self.resolve_room_id(
room_identifier, remote_room_hosts
)
fake_requester = create_requester(
target_user, authenticated_entity=requester.authenticated_entity
@@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
return 200, {"room_id": room_id}
class MakeRoomAdminRestServlet(RestServlet):
class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get power in a room if a local user has power in
a room. Will also invite the user if they're not in the room and it's a
private room. Can specify another user (rather than the admin user) to be
@@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
async def on_POST(self, request, room_identifier):
async def on_POST(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
content = parse_json_object_from_request(request, allow_empty_body=True)
# Resolve to a room ID, if necessary.
if RoomID.is_valid(room_identifier):
room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
room_id, _ = await self.resolve_room_id(room_identifier)
# Which user to grant room admin rights to.
user_to_add = content.get("user_id", requester.user.to_string())
@@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
return 200, {}
class ForwardExtremitiesRestServlet(RestServlet):
class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
"""Allows a server admin to get or clear forward extremities.
Clearing does not require restarting the server.
@@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")
def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.hs = hs
self.auth = hs.get_auth()
self.room_member_handler = hs.get_room_member_handler()
self.store = hs.get_datastore()
async def resolve_room_id(self, room_identifier: str) -> str:
"""Resolve to a room ID, if necessary."""
if RoomID.is_valid(room_identifier):
resolved_room_id = room_identifier
elif RoomAlias.is_valid(room_identifier):
room_alias = RoomAlias.from_string(room_identifier)
room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
resolved_room_id = room_id.to_string()
else:
raise SynapseError(
400, "%s was not legal room ID or room alias" % (room_identifier,)
)
if not resolved_room_id:
raise SynapseError(
400, "Unknown room ID or room alias %s" % room_identifier
)
return resolved_room_id
async def on_DELETE(self, request, room_identifier):
async def on_DELETE(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
room_id, _ = await self.resolve_room_id(room_identifier)
deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
return 200, {"deleted": deleted_count}
async def on_GET(self, request, room_identifier):
async def on_GET(
self, request: SynapseRequest, room_identifier: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
room_id = await self.resolve_room_id(room_identifier)
room_id, _ = await self.resolve_room_id(room_identifier)
extremities = await self.store.get_forward_extremities_for_room(room_id)
return 200, {"count": len(extremities), "results": extremities}
@@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):
PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
super().__init__()
self.clock = hs.get_clock()
self.room_context_handler = hs.get_room_context_handler()
self._event_serializer = hs.get_event_client_serializer()
self.auth = hs.get_auth()
async def on_GET(self, request, room_id, event_id):
async def on_GET(
self, request: SynapseRequest, room_id: str, event_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request, allow_guest=False)
await assert_user_is_admin(self.auth, requester.user)

View File

@@ -16,7 +16,7 @@ import hashlib
import hmac
import logging
from http import HTTPStatus
from typing import TYPE_CHECKING, Tuple
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from synapse.api.constants import UserTypes
from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
class UsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, List[JsonDict]]:
target_user = UserID.from_string(user_id)
await assert_requester_is_admin(self.auth, request)
@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
otherwise an error.
"""
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
self.registration_handler = hs.get_registration_handler()
self.pusher_pool = hs.get_pusherpool()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):
return 200, ret
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
)
user = await self.admin_handler.get_user(target_user)
assert user is not None
return 200, user
else: # create user
@@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
target_user, requester, body["avatar_url"], True
)
ret = await self.admin_handler.get_user(target_user)
user = await self.admin_handler.get_user(target_user)
assert user is not None
return 201, ret
return 201, user
class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
PATTERNS = admin_patterns("/register")
NONCE_TIMEOUT = 60
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.auth_handler = hs.get_auth_handler()
self.reactor = hs.get_reactor()
self.nonces = {}
self.nonces = {} # type: Dict[str, int]
self.hs = hs
def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
if now - v > self.NONCE_TIMEOUT:
del self.nonces[k]
def on_GET(self, request):
def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
"""
Generate a new nonce.
"""
@@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
self.nonces[nonce] = int(self.reactor.seconds())
return 200, {"nonce": nonce}
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
self._clear_old_nonces()
if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
client_patterns("/admin" + path_regex, v1=True)
)
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
target_user = UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
auth_user = requester.user
@@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
@@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
self.account_activity_handler = hs.get_account_validity_handler()
self.auth = hs.get_auth()
async def on_POST(self, request):
async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):
PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.hs = hs
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
self._set_password_handler = hs.get_set_password_handler()
async def on_POST(self, request, target_user_id):
async def on_POST(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, JsonDict]:
"""Post request to allow an administrator reset password for a user.
This needs user to have administrator access in Synapse.
"""
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):
PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, target_user_id):
async def on_GET(
self, request: SynapseRequest, target_user_id: str
) -> Tuple[int, Optional[List[JsonDict]]]:
"""Get request to search user table for specific users according to
search term.
This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):
return 200, {"admin": is_admin}
async def on_PUT(self, request, user_id):
async def on_PUT(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
async def on_GET(self, request, user_id):
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.store = hs.get_datastore()
self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):
PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")
def __init__(self, hs):
def __init__(self, hs: "HomeServer"):
self.is_mine = hs.is_mine
self.auth = hs.get_auth()
self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
requester = await self.auth.get_user_by_req(request)
await assert_user_is_admin(self.auth, requester.user)
auth_user = requester.user
@@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
self.store = hs.get_datastore()
self.auth = hs.get_auth()
async def on_POST(self, request, user_id):
async def on_POST(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
await assert_requester_is_admin(self.auth, request)
if not self.hs.is_mine_id(user_id):

View File

@@ -18,7 +18,7 @@ import logging
from functools import wraps
from typing import TYPE_CHECKING, Optional, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.constants import (
MAX_GROUP_CATEGORYID_LENGTH,

View File

@@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple
from twisted.internet.interfaces import IConsumer
from twisted.protocols.basic import FileSender
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError, cs_error
from synapse.http.server import finish_request, respond_with_json
@@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [
def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
try:
# The type on postpath seems incorrect in Twisted 21.2.0.
postpath = request.postpath # type: List[bytes] # type: ignore
assert postpath
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
if isinstance(server_name, bytes):
server_name = server_name.decode("utf-8")
media_id = media_id.decode("utf8")
server_name_bytes, media_id_bytes = postpath[:2]
server_name = server_name_bytes.decode("utf-8")
media_id = media_id_bytes.decode("utf8")
file_name = None
if len(request.postpath) > 2:
if len(postpath) > 2:
try:
file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
except UnicodeDecodeError:
pass
return server_name, media_id, file_name

View File

@@ -17,7 +17,7 @@
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, respond_with_json

View File

@@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.http.server import DirectServeJsonResource, set_cors_headers
from synapse.http.servlet import parse_boolean

View File

@@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple
import twisted.internet.error
import twisted.web.http
from twisted.web.http import Request
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import (
FederationDeniedError,

View File

@@ -29,7 +29,7 @@ from urllib import parse as urlparse
import attr
from twisted.internet.error import DNSLookupError
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.client import SimpleHttpClient
@@ -149,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource):
treq_args={"browser_like_redirects": True},
ip_whitelist=hs.config.url_preview_ip_range_whitelist,
ip_blacklist=hs.config.url_preview_ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
use_proxy=True,
)
self.media_repo = media_repo
self.primary_base_path = media_repo.primary_base_path

View File

@@ -18,7 +18,7 @@
import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.http.server import DirectServeJsonResource, set_cors_headers

View File

@@ -15,9 +15,9 @@
# limitations under the License.
import logging
from typing import TYPE_CHECKING
from typing import IO, TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import DirectServeJsonResource, respond_with_json
@@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
headers = request.requestHeaders
if headers.hasHeader(b"Content-Type"):
media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
content_type_headers = headers.getRawHeaders(b"Content-Type")
assert content_type_headers # for mypy
media_type = content_type_headers[0].decode("ascii")
else:
raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)
@@ -88,8 +90,9 @@ class UploadResource(DirectServeJsonResource):
# TODO(markjh): parse content-dispostion
try:
content = request.content # type: IO # type: ignore
content_uri = await self.media_repo.create_content(
media_type, upload_name, request.content, content_length, requester.user
media_type, upload_name, content, content_length, requester.user
)
except SpamMediaException:
# For uploading of media we want to respond with a 400, instead of

View File

@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View File

@@ -15,7 +15,7 @@
import logging
from typing import TYPE_CHECKING, Tuple
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import ThreepidValidationError
from synapse.config.emailconfig import ThreepidBehaviour

View File

@@ -16,8 +16,8 @@
import logging
from typing import TYPE_CHECKING, List
from twisted.web.http import Request
from twisted.web.resource import Resource
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View File

@@ -16,7 +16,7 @@
import logging
from typing import TYPE_CHECKING
from twisted.web.http import Request
from twisted.web.server import Request
from synapse.api.errors import SynapseError
from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

View File

@@ -24,7 +24,6 @@
import abc
import functools
import logging
import os
from typing import (
TYPE_CHECKING,
Any,
@@ -39,6 +38,7 @@ from typing import (
import twisted.internet.base
import twisted.internet.tcp
from twisted.internet import defer
from twisted.mail.smtp import sendmail
from twisted.web.iweb import IPolicyForHTTPS
@@ -370,11 +370,7 @@ class HomeServer(metaclass=abc.ABCMeta):
"""
An HTTP client that uses configured HTTP(S) proxies.
"""
return SimpleHttpClient(
self,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
)
return SimpleHttpClient(self, use_proxy=True)
@cache_in_self
def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@@ -386,8 +382,7 @@ class HomeServer(metaclass=abc.ABCMeta):
self,
ip_whitelist=self.config.ip_range_whitelist,
ip_blacklist=self.config.ip_range_blacklist,
http_proxy=os.getenvb(b"http_proxy"),
https_proxy=os.getenvb(b"HTTPS_PROXY"),
use_proxy=True,
)
@cache_in_self
@@ -409,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
return RoomShutdownHandler(self)
@cache_in_self
def get_sendmail(self) -> sendmail:
def get_sendmail(self) -> Callable[..., defer.Deferred]:
return sendmail
@cache_in_self

View File

@@ -16,7 +16,7 @@
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional, Tuple
from typing import List, Optional, Tuple
from synapse.api.constants import PresenceState
from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
from synapse.types import get_domain_from_id
from synapse.types import JsonDict, get_domain_from_id
from synapse.util.caches.stream_change_cache import StreamChangeCache
from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(
return [UserPresenceState(**row) for row in rows]
async def get_users(self) -> List[Dict[str, Any]]:
async def get_users(self) -> List[JsonDict]:
"""Function to retrieve a list of users in users table.
Returns:
@@ -292,7 +292,7 @@ class DataStore(
name: Optional[str] = None,
guests: bool = True,
deactivated: bool = False,
) -> Tuple[List[Dict[str, Any]], int]:
) -> Tuple[List[JsonDict], int]:
"""Function to retrieve a paginated list of users from
users list. This will return a json list of users and the
total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
"get_users_paginate_txn", get_users_paginate_txn
)
async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
async def search_users(self, term: str) -> Optional[List[JsonDict]]:
"""Function to search users list for one or more users with
the matched term.

View File

@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
# times give more inserts into the database even for readonly API hits
# 120 seconds == 2 minutes
LAST_SEEN_GRANULARITY = 10 * 60 * 1000
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpBackgroundUpdateStore(SQLBaseStore):

View File

@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
)
if not has_event_auth:
for auth_id in event.auth_event_ids():
# Old, dodgy, events may have duplicate auth events, which we
# need to deduplicate as we have a unique constraint.
for auth_id in set(event.auth_event_ids()):
auth_events.append(
{
"room_id": event.room_id,

View File

@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
start: int,
limit: int,
user_id: str,
order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
order_by: str = MediaSortOrder.CREATED_TS.value,
direction: str = "f",
) -> Tuple[List[Dict[str, Any]], int]:
"""Get a paginated list of metadata for a local piece of media

View File

@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
async def purge_history(
self, room_id: str, token: str, delete_local_events: bool
) -> Set[int]:
"""Deletes room history before a certain point
"""Deletes room history before a certain point.
Note that only a single purge can occur at once, this is guaranteed via
a higher level (in the PaginationHandler).
Args:
room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
delete_local_events,
)
def _purge_history_txn(self, txn, room_id, token, delete_local_events):
def _purge_history_txn(
self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
) -> Set[int]:
# Tables that should be pruned:
# event_auth
# event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
if max_depth < token.topological:
# We need to ensure we don't delete all the events from the database
# otherwise we wouldn't be able to send any events (due to not
# having any backwards extremeties)
# having any backwards extremities)
raise SynapseError(
400, "topological_ordering is greater than forward extremeties"
)
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
logger.info("[purge] Finding new backward extremities")
# We calculate the new entries for the backward extremeties by finding
# We calculate the new entries for the backward extremities by finding
# events to be purged that are pointed to by events we're not going to
# purge.
txn.execute(
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"purge_room", self._purge_room_txn, room_id
)
def _purge_room_txn(self, txn, room_id):
def _purge_room_txn(self, txn, room_id: str) -> List[int]:
# First we fetch all the state groups that should be deleted, before
# we delete that information.
txn.execute(
@@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
state_groups = [row[0] for row in txn]
# Get all the auth chains that are referenced by events that are to be
# deleted.
txn.execute(
"""
SELECT chain_id, sequence_number FROM events
LEFT JOIN event_auth_chains USING (event_id)
WHERE room_id = ?
""",
(room_id,),
)
referenced_chain_id_tuples = list(txn)
logger.info("[purge] removing events from event_auth_chain_links")
txn.executemany(
"""
DELETE FROM event_auth_chain_links WHERE
(origin_chain_id = ? AND origin_sequence_number = ?) OR
(target_chain_id = ? AND target_sequence_number = ?)
""",
(
(chain_id, seq_num, chain_id, seq_num)
for (chain_id, seq_num) in referenced_chain_id_tuples
),
)
# Now we delete tables which lack an index on room_id but have one on event_id
for table in (
"event_auth",
@@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
"event_reference_hashes",
"event_relations",
"event_to_state_groups",
"event_auth_chains",
"event_auth_chain_to_calculate",
"redactions",
"rejections",
"state_events",

View File

@@ -39,6 +39,16 @@ class PusherWorkerStore(SQLBaseStore):
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
)
self.db_pool.updates.register_background_update_handler(
"remove_deactivated_pushers",
self._remove_deactivated_pushers,
)
self.db_pool.updates.register_background_update_handler(
"remove_stale_pushers",
self._remove_stale_pushers,
)
def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
"""JSON-decode the data in the rows returned from the `pushers` table
@@ -284,6 +294,101 @@ class PusherWorkerStore(SQLBaseStore):
lock=False,
)
async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
"""A background update that deletes all pushers for deactivated users.
Note that we don't proacively tell the pusherpool that we've deleted
these (just because its a bit off a faff to do from here), but they will
get cleaned up at the next restart
"""
last_user = progress.get("last_user", "")
def _delete_pushers(txn) -> int:
sql = """
SELECT name FROM users
WHERE deactivated = ? and name > ?
ORDER BY name ASC
LIMIT ?
"""
txn.execute(sql, (1, last_user, batch_size))
users = [row[0] for row in txn]
self.db_pool.simple_delete_many_txn(
txn,
table="pushers",
column="user_name",
iterable=users,
keyvalues={},
)
if users:
self.db_pool.updates._background_update_progress_txn(
txn, "remove_deactivated_pushers", {"last_user": users[-1]}
)
return len(users)
number_deleted = await self.db_pool.runInteraction(
"_remove_deactivated_pushers", _delete_pushers
)
if number_deleted < batch_size:
await self.db_pool.updates._end_background_update(
"remove_deactivated_pushers"
)
return number_deleted
async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
"""A background update that deletes all pushers for logged out devices.
Note that we don't proacively tell the pusherpool that we've deleted
these (just because its a bit off a faff to do from here), but they will
get cleaned up at the next restart
"""
last_pusher = progress.get("last_pusher", 0)
def _delete_pushers(txn) -> int:
sql = """
SELECT p.id, access_token FROM pushers AS p
LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
WHERE p.id > ?
ORDER BY p.id ASC
LIMIT ?
"""
txn.execute(sql, (last_pusher, batch_size))
pushers = [(row[0], row[1]) for row in txn]
self.db_pool.simple_delete_many_txn(
txn,
table="pushers",
column="id",
iterable=(pusher_id for pusher_id, token in pushers if token is None),
keyvalues={},
)
if pushers:
self.db_pool.updates._background_update_progress_txn(
txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
)
return len(pushers)
number_deleted = await self.db_pool.runInteraction(
"_remove_stale_pushers", _delete_pushers
)
if number_deleted < batch_size:
await self.db_pool.updates._end_background_update("remove_stale_pushers")
return number_deleted
class PusherStore(PusherWorkerStore):
def get_pushers_stream_token(self) -> int:

View File

@@ -14,8 +14,7 @@
*/
-- We may not have deleted all pushers for deactivated accounts. Do so now.
--
-- Note: We don't bother updating the `deleted_pushers` table as it's just use
-- to stop pushers on workers, and that will happen when they get next restarted.
DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);
-- We may not have deleted all pushers for deactivated accounts, so we set up a
-- background job to delete them.
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5908, 'remove_deactivated_pushers', '{}');

View File

@@ -16,4 +16,5 @@
-- Delete all pushers associated with deleted devices. This is to clear up after
-- a bug where they weren't correctly deleted when using workers.
DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5908, 'remove_stale_pushers', '{}');

View File

@@ -13,5 +13,14 @@
* limitations under the License.
*/
-- This originally was in 58/, but landed after 59/ was created, and so some
-- servers running develop didn't run this delta. Running it again should be
-- safe.
--
-- We first delete any in progress `rejected_events_metadata` background update,
-- to ensure that we don't conflict when trying to insert the new one. (We could
-- alternatively do an ON CONFLICT DO NOTHING, but that syntax isn't supported
-- by older SQLite versions. Plus, this should be a rare case).
DELETE FROM background_updates WHERE update_name = 'rejected_events_metadata';
INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
(5828, 'rejected_events_metadata', '{}');

View File

@@ -707,7 +707,7 @@ def _parse_query(database_engine, search_term):
results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
if isinstance(database_engine, PostgresEngine):
return " & ".join(result for result in results)
return " & ".join(result + ":*" for result in results)
elif isinstance(database_engine, Sqlite3Engine):
return " & ".join(result + "*" for result in results)
else:

View File

@@ -73,9 +73,6 @@ class PurgeEventsStorage:
Returns:
The set of state groups that can be deleted.
"""
# Graph of state group -> previous group
graph = {}
# Set of events that we have found to be referenced by events
referenced_groups = set()
@@ -111,8 +108,6 @@ class PurgeEventsStorage:
next_to_search |= prevs
state_groups_seen |= prevs
graph.update(edges)
to_delete = state_groups_seen - referenced_groups
return to_delete

View File

@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
)
GetRoomsForUserWithStreamOrdering = namedtuple(
"_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
"GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
)

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar
from twisted.internet import defer
@@ -40,7 +40,6 @@ class ResponseCache(Generic[T]):
def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
# Requests that haven't finished yet.
self.pending_result_cache = {} # type: Dict[T, ObservableDeferred]
self.pending_conditionals = {} # type: Dict[T, Set[Callable[[Any], bool]]]
self.clock = hs.get_clock()
self.timeout_sec = timeout_ms / 1000.0
@@ -102,11 +101,7 @@ class ResponseCache(Generic[T]):
self.pending_result_cache[key] = result
def remove(r):
should_cache = all(
func(r) for func in self.pending_conditionals.pop(key, [])
)
if self.timeout_sec and should_cache:
if self.timeout_sec:
self.clock.call_later(
self.timeout_sec, self.pending_result_cache.pop, key, None
)
@@ -117,31 +112,6 @@ class ResponseCache(Generic[T]):
result.addBoth(remove)
return result.observe()
def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
self.pending_conditionals.setdefault(key, set()).add(conditional)
def wrap_conditional(
self,
key: T,
should_cache: Callable[[Any], bool],
callback: "Callable[..., Any]",
*args: Any,
**kwargs: Any
) -> defer.Deferred:
"""The same as wrap(), but adds a conditional to the final execution.
When the final execution completes, *all* conditionals need to return True for it to properly cache,
else it'll not be cached in a timed fashion.
"""
# See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
# python, adding a key immediately in the same execution thread will not cause a race condition.
result = self.get(key)
if not result or isinstance(result, defer.Deferred) and not result.called:
self.add_conditional(key, should_cache)
return self.wrap(key, callback, *args, **kwargs)
def wrap(
self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
) -> defer.Deferred:

3
synctl
View File

@@ -30,7 +30,7 @@ import yaml
from synapse.config import find_config_files
SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]
GREEN = "\x1b[1;32m"
YELLOW = "\x1b[1;33m"
@@ -117,7 +117,6 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:
args = [
sys.executable,
"-B",
"-m",
app,
"-c",

View File

@@ -26,77 +26,96 @@ from tests.unittest import TestCase
class ReadBodyWithMaxSizeTests(TestCase):
def setUp(self):
def _build_response(self, length=UNKNOWN_LENGTH):
"""Start reading the body, returns the response, result and proto"""
response = Mock(length=UNKNOWN_LENGTH)
self.result = BytesIO()
self.deferred = read_body_with_max_size(response, self.result, 6)
response = Mock(length=length)
result = BytesIO()
deferred = read_body_with_max_size(response, result, 6)
# Fish the protocol out of the response.
self.protocol = response.deliverBody.call_args[0][0]
self.protocol.transport = Mock()
protocol = response.deliverBody.call_args[0][0]
protocol.transport = Mock()
def _cleanup_error(self):
return result, deferred, protocol
def _assert_error(self, deferred, protocol):
"""Ensure that the expected error is received."""
self.assertIsInstance(deferred.result, Failure)
self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
protocol.transport.abortConnection.assert_called_once()
def _cleanup_error(self, deferred):
"""Ensure that the error in the Deferred is handled gracefully."""
called = [False]
def errback(f):
called[0] = True
self.deferred.addErrback(errback)
deferred.addErrback(errback)
self.assertTrue(called[0])
def test_no_error(self):
"""A response that is NOT too large."""
result, deferred, protocol = self._build_response()
# Start sending data.
self.protocol.dataReceived(b"12345")
protocol.dataReceived(b"12345")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"12345")
self.assertEqual(self.deferred.result, 5)
self.assertEqual(result.getvalue(), b"12345")
self.assertEqual(deferred.result, 5)
def test_too_large(self):
"""A response which is too large raises an exception."""
result, deferred, protocol = self._build_response()
# Start sending data.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
protocol.dataReceived(b"1234567890")
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()
self.assertEqual(result.getvalue(), b"1234567890")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_multiple_packets(self):
"""Data should be accummulated through mutliple packets."""
"""Data should be accumulated through mutliple packets."""
result, deferred, protocol = self._build_response()
# Start sending data.
self.protocol.dataReceived(b"12")
self.protocol.dataReceived(b"34")
protocol.dataReceived(b"12")
protocol.dataReceived(b"34")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
protocol.connectionLost(Failure(ResponseDone()))
self.assertEqual(self.result.getvalue(), b"1234")
self.assertEqual(self.deferred.result, 4)
self.assertEqual(result.getvalue(), b"1234")
self.assertEqual(deferred.result, 4)
def test_additional_data(self):
"""A connection can receive data after being closed."""
result, deferred, protocol = self._build_response()
# Start sending data.
self.protocol.dataReceived(b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self.protocol.transport.abortConnection.assert_called_once()
protocol.dataReceived(b"1234567890")
self._assert_error(deferred, protocol)
# More data might have come in.
self.protocol.dataReceived(b"1234567890")
# Close the connection.
self.protocol.connectionLost(Failure(ResponseDone()))
protocol.dataReceived(b"1234567890")
self.assertEqual(self.result.getvalue(), b"1234567890")
self.assertIsInstance(self.deferred.result, Failure)
self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
self._cleanup_error()
self.assertEqual(result.getvalue(), b"1234567890")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
def test_content_length(self):
"""The body shouldn't be read (at all) if the Content-Length header is too large."""
result, deferred, protocol = self._build_response(length=10)
# Deferred shouldn't be called yet.
self.assertFalse(deferred.called)
# Start sending data.
protocol.dataReceived(b"12345")
self._assert_error(deferred, protocol)
self._cleanup_error(deferred)
# The data is never consumed.
self.assertEqual(result.getvalue(), b"")

View File

@@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from unittest.mock import patch
import treq
from netaddr import IPSet
@@ -100,62 +102,36 @@ class MatrixFederationAgentTests(TestCase):
return http_protocol
def test_http_request(self):
agent = ProxyAgent(self.reactor)
def _test_request_direct_connection(self, agent, scheme, hostname, path):
"""Runs a test case for a direct connection not going through a proxy.
self.reactor.lookups["test.com"] = "1.2.3.4"
d = agent.request(b"GET", b"http://test.com")
Args:
agent (ProxyAgent): the proxy agent being tested
scheme (bytes): expected to be either "http" or "https"
hostname (bytes): the hostname to connect to in the test
path (bytes): the path to connect to in the test
"""
is_https = scheme == b"https"
self.reactor.lookups[hostname.decode()] = "1.2.3.4"
d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 80)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory, _get_test_protocol_factory()
)
# the FakeTransport is async, so we need to pump the reactor
self.reactor.advance(0)
# now there should be a pending request
self.assertEqual(len(http_server.requests), 1)
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
request.write(b"result")
request.finish()
self.reactor.advance(0)
resp = self.successResultOf(d)
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
def test_https_request(self):
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
self.reactor.lookups["test.com"] = "1.2.3.4"
d = agent.request(b"GET", b"https://test.com/abc")
# there should be a pending TCP connection
clients = self.reactor.tcpClients
self.assertEqual(len(clients), 1)
(host, port, client_factory, _timeout, _bindAddress) = clients[0]
self.assertEqual(host, "1.2.3.4")
self.assertEqual(port, 443)
self.assertEqual(port, 443 if is_https else 80)
# make a test server, and wire up the client
http_server = self._make_connection(
client_factory,
_get_test_protocol_factory(),
ssl=True,
expected_sni=b"test.com",
ssl=is_https,
expected_sni=hostname if is_https else None,
)
# the FakeTransport is async, so we need to pump the reactor
@@ -166,8 +142,8 @@ class MatrixFederationAgentTests(TestCase):
request = http_server.requests[0]
self.assertEqual(request.method, b"GET")
self.assertEqual(request.path, b"/abc")
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
self.assertEqual(request.path, b"/" + path)
self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
request.write(b"result")
request.finish()
@@ -177,8 +153,58 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
def test_http_request(self):
agent = ProxyAgent(self.reactor)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
def test_https_request(self):
agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
def test_http_request_use_proxy_empty_environment(self):
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
def test_http_request_via_uppercase_no_proxy(self):
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
)
def test_http_request_via_no_proxy(self):
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(
os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
)
def test_https_request_via_no_proxy(self):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
use_proxy=True,
)
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
def test_http_request_via_no_proxy_star(self):
agent = ProxyAgent(self.reactor, use_proxy=True)
self._test_request_direct_connection(agent, b"http", b"test.com", b"")
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
def test_https_request_via_no_proxy_star(self):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
use_proxy=True,
)
self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
def test_http_request_via_proxy(self):
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
agent = ProxyAgent(self.reactor, use_proxy=True)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
d = agent.request(b"GET", b"http://test.com")
@@ -214,11 +240,12 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
def test_https_request_via_proxy(self):
agent = ProxyAgent(
self.reactor,
contextFactory=get_test_https_policy(),
https_proxy=b"proxy.com",
use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -294,6 +321,7 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
def test_http_request_via_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
@@ -301,7 +329,7 @@ class MatrixFederationAgentTests(TestCase):
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
),
self.reactor,
http_proxy=b"proxy.com:8888",
use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"
@@ -338,7 +366,8 @@ class MatrixFederationAgentTests(TestCase):
body = self.successResultOf(treq.content(resp))
self.assertEqual(body, b"result")
def test_https_request_via_proxy_with_blacklist(self):
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
def test_https_request_via_uppercase_proxy_with_blacklist(self):
# The blacklist includes the configured proxy IP.
agent = ProxyAgent(
BlacklistingReactorWrapper(
@@ -346,7 +375,7 @@ class MatrixFederationAgentTests(TestCase):
),
self.reactor,
contextFactory=get_test_https_policy(),
https_proxy=b"proxy.com",
use_proxy=True,
)
self.reactor.lookups["proxy.com"] = "1.2.3.5"

View File

@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
self.assertEqual(channel.code, 302, channel.result)
cas_uri = channel.headers.getRawHeaders("Location")[0]
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
cas_uri = location_headers[0]
cas_uri_path, cas_uri_query = cas_uri.split("?", 1)
# it should redirect us to the login page of the cas server
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=saml",
)
self.assertEqual(channel.code, 302, channel.result)
saml_uri = channel.headers.getRawHeaders("Location")[0]
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
saml_uri = location_headers[0]
saml_uri_path, saml_uri_query = saml_uri.split("?", 1)
# it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
+ "&idp=oidc",
)
self.assertEqual(channel.code, 302, channel.result)
oidc_uri = channel.headers.getRawHeaders("Location")[0]
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
oidc_uri = location_headers[0]
oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
# it should redirect us to the auth page of the OIDC server
self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
# ... and should have set a cookie including the redirect url
cookies = dict(
h.split(";")[0].split("=", maxsplit=1)
for h in channel.headers.getRawHeaders("Set-Cookie")
)
cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
assert cookie_headers
cookies = {} # type: Dict[str, str]
for h in cookie_headers:
key, value = h.split(";")[0].split("=", maxsplit=1)
cookies[key] = value
oidc_session_cookie = cookies["oidc_session"]
macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
# that should serve a confirmation page
self.assertEqual(channel.code, 200, channel.result)
self.assertTrue(
channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
)
content_type_headers = channel.headers.getRawHeaders("Content-Type")
assert content_type_headers
self.assertTrue(content_type_headers[-1].startswith("text/html"))
p = TestHtmlParser()
p.feed(channel.text_body)
p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, 302)
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)
@override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):
# that should redirect to the username picker
self.assertEqual(channel.code, 302, channel.result)
picker_url = channel.headers.getRawHeaders("Location")[0]
location_headers = channel.headers.getRawHeaders("Location")
assert location_headers
picker_url = location_headers[0]
self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")
# ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
# send a request to the completion page, which should 302 to the client redirectUrl
chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
)
self.assertEqual(chan.code, 302, chan.result)
location_headers = chan.headers.getRawHeaders("Location")
assert location_headers
# ensure that the returned location matches the requested redirect URL
path, query = location_headers[0].split("?", 1)

View File

@@ -189,5 +189,7 @@ commands=
[testenv:mypy]
deps =
{[base]deps}
# Type hints are broken with Twisted > 20.3.0, see https://github.com/matrix-org/synapse/issues/9513
twisted==20.3.0
extras = all,mypy
commands = mypy