Compare commits
23 commits (michaelkay, release-v1)
| Author | SHA1 | Date |
|---|---|---|
| | 4de1c35728 | |
| | 15c788e22d | |
| | a6333b8d42 | |
| | ea0a3aaf0a | |
| | 3f49d80dcf | |
| | 33a02f0f52 | |
| | 4db07f9aef | |
| | a4fa044c00 | |
| | 922788c604 | |
| | d790d0d314 | |
| | 0c330423bc | |
| | 16f9f93eb7 | |
| | a5daae2a5f | |
| | 0279e0e086 | |
| | aee10768d8 | |
| | 7f5d753d06 | |
| | 16108c579d | |
| | f00c4e7af0 | |
| | ad8589d392 | |
| | 16ec8c3272 | |
| | a0bc9d387e | |
| | e12077a78a | |
| | ddb240293a | |
.gitignore (vendored, 3 changes)

@@ -6,13 +6,14 @@
 *.egg
 *.egg-info
 *.lock
-*.pyc
+*.py[cod]
 *.snap
 *.tac
 _trial_temp/
 _trial_temp*/
 /out
 .DS_Store
+__pycache__/

 # stuff that is likely to exist when you run a server locally
 /*.db
CHANGES.md (61 changes)

@@ -1,9 +1,66 @@
-Synapse 1.xx.0
-==============
+Synapse 1.29.0 (2021-03-08)
+===========================
+
+Note that synapse now expects an `X-Forwarded-Proto` header when used with a reverse proxy. Please see [UPGRADE.rst](UPGRADE.rst#upgrading-to-v1290) for more details on this change.
+
+No significant changes.
+
+
+Synapse 1.29.0rc1 (2021-03-04)
+==============================
+
+Features
+--------
+
+- Add rate limiters to cross-user key sharing requests. ([\#8957](https://github.com/matrix-org/synapse/issues/8957))
+- Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel. ([\#8978](https://github.com/matrix-org/synapse/issues/8978))
+- Add some configuration settings to make users' profile data more private. ([\#9203](https://github.com/matrix-org/synapse/issues/9203))
+- The `no_proxy` and `NO_PROXY` environment variables are now respected in proxied HTTP clients with the lowercase form taking precedence if both are present. Additionally, the lowercase `https_proxy` environment variable is now respected in proxied HTTP clients on top of existing support for the uppercase `HTTPS_PROXY` form and takes precedence if both are present. Contributed by Timothy Leung. ([\#9372](https://github.com/matrix-org/synapse/issues/9372))
+- Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users. ([\#9383](https://github.com/matrix-org/synapse/issues/9383), [\#9385](https://github.com/matrix-org/synapse/issues/9385))
+- Add support for regenerating thumbnails if they have been deleted but the original image is still stored. ([\#9438](https://github.com/matrix-org/synapse/issues/9438))
+- Add support for `X-Forwarded-Proto` header when using a reverse proxy. ([\#9472](https://github.com/matrix-org/synapse/issues/9472), [\#9501](https://github.com/matrix-org/synapse/issues/9501), [\#9512](https://github.com/matrix-org/synapse/issues/9512), [\#9539](https://github.com/matrix-org/synapse/issues/9539))
+
+
+Bugfixes
+--------
+
+- Fix a bug where users' pushers were not all deleted when they deactivated their account. ([\#9285](https://github.com/matrix-org/synapse/issues/9285), [\#9516](https://github.com/matrix-org/synapse/issues/9516))
+- Fix a bug where a lot of unnecessary presence updates were sent when joining a room. ([\#9402](https://github.com/matrix-org/synapse/issues/9402))
+- Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results. ([\#9416](https://github.com/matrix-org/synapse/issues/9416))
+- Fix a bug in single sign-on which could cause a "No session cookie found" error. ([\#9436](https://github.com/matrix-org/synapse/issues/9436))
+- Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined. ([\#9440](https://github.com/matrix-org/synapse/issues/9440))
+- Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`. ([\#9449](https://github.com/matrix-org/synapse/issues/9449))
+- Fix deleting pushers when using sharded pushers. ([\#9465](https://github.com/matrix-org/synapse/issues/9465), [\#9466](https://github.com/matrix-org/synapse/issues/9466), [\#9479](https://github.com/matrix-org/synapse/issues/9479), [\#9536](https://github.com/matrix-org/synapse/issues/9536))
+- Fix missing startup checks for the consistency of certain PostgreSQL sequences. ([\#9470](https://github.com/matrix-org/synapse/issues/9470))
+- Fix a long-standing bug where the media repository could leak file descriptors while previewing media. ([\#9497](https://github.com/matrix-org/synapse/issues/9497))
+- Properly purge the event chain cover index when purging history. ([\#9498](https://github.com/matrix-org/synapse/issues/9498))
+- Fix missing chain cover index due to a schema delta not being applied correctly. Only affected servers that ran development versions. ([\#9503](https://github.com/matrix-org/synapse/issues/9503))
+- Fix a bug introduced in v1.25.0 where `/_synapse/admin/join/` would fail when given a room alias. ([\#9506](https://github.com/matrix-org/synapse/issues/9506))
+- Prevent presence background jobs from running when presence is disabled. ([\#9530](https://github.com/matrix-org/synapse/issues/9530))
+- Fix rare edge case that caused a background update to fail if the server had rejected an event that had duplicate auth events. ([\#9537](https://github.com/matrix-org/synapse/issues/9537))
+
+
+Improved Documentation
+----------------------
+
+- Update the example systemd config to propagate reloads to individual units. ([\#9463](https://github.com/matrix-org/synapse/issues/9463))
+
+
+Internal Changes
+----------------
+
+- Add documentation and type hints to `parse_duration`. ([\#9432](https://github.com/matrix-org/synapse/issues/9432))
+- Remove vestiges of `uploads_path` configuration setting. ([\#9462](https://github.com/matrix-org/synapse/issues/9462))
+- Add a comment about systemd-python. ([\#9464](https://github.com/matrix-org/synapse/issues/9464))
+- Test that we require validated email for email pushers. ([\#9496](https://github.com/matrix-org/synapse/issues/9496))
+- Allow python to generate bytecode for synapse. ([\#9502](https://github.com/matrix-org/synapse/issues/9502))
+- Fix incorrect type hints. ([\#9515](https://github.com/matrix-org/synapse/issues/9515), [\#9518](https://github.com/matrix-org/synapse/issues/9518))
+- Add type hints to device and event report admin API. ([\#9519](https://github.com/matrix-org/synapse/issues/9519))
+- Add type hints to user admin API. ([\#9521](https://github.com/matrix-org/synapse/issues/9521))
+- Bump the versions of mypy and mypy-zope used for static type checking. ([\#9529](https://github.com/matrix-org/synapse/issues/9529))
+
+
 Synapse 1.28.0 (2021-02-25)
 ===========================
UPGRADE.rst

@@ -98,9 +98,9 @@ will log a warning on each received request.

 To avoid the warning, administrators using a reverse proxy should ensure that
 the reverse proxy sets `X-Forwarded-Proto` header to `https` or `http` to
-indicate the protocol used by the client. See the [reverse proxy
-documentation](docs/reverse_proxy.md), where the example configurations have
-been updated to show how to set this header.
+indicate the protocol used by the client. See the `reverse proxy documentation
+<docs/reverse_proxy.md>`_, where the example configurations have been updated to
+show how to set this header.

 (Users of `Caddy <https://caddyserver.com/>`_ are unaffected, since we believe it
 sets `X-Forwarded-Proto` by default.)
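The reverse-proxy requirement described above is mechanical: Synapse trusts the `X-Forwarded-Proto` value to recover the scheme the client actually used. A minimal sketch of that recovery logic (illustrative only, not Synapse's code; the real implementation is the `XForwardedForRequest` change in `synapse/http/site.py` further down):

```python
# Minimal sketch, not Synapse code: recover the original scheme of a
# proxied request from the X-Forwarded-Proto header, falling back to the
# transport-level scheme when the proxy did not set one.
def original_scheme(headers: dict, transport_is_tls: bool) -> str:
    forwarded = headers.get("X-Forwarded-Proto")
    if forwarded in ("http", "https"):
        return forwarded
    return "https" if transport_is_tls else "http"

assert original_scheme({"X-Forwarded-Proto": "https"}, transport_is_tls=False) == "https"
assert original_scheme({}, transport_is_tls=True) == "https"
```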
changelog.d newsfragments (one deleted file per hunk):

@@ -1 +0,0 @@
-Temporarily drop cross-user m.room_key_request to_device messages over performance concerns.
@@ -1 +0,0 @@
-Add rate limiters to cross-user key sharing requests.
@@ -1 +0,0 @@
-Add `order_by` to the admin API `GET /_synapse/admin/v1/users/<user_id>/media`. Contributed by @dklimpel.
@@ -1 +0,0 @@
-Add some configuration settings to make users' profile data more private.
@@ -1 +0,0 @@
-Fix a bug where users' pushers were not all deleted when they deactivated their account.
@@ -1 +0,0 @@
-Added a fix that invalidates cache for empty timed-out sync responses.
@@ -1 +0,0 @@
-Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users.
@@ -1 +0,0 @@
-Add a configuration option, `user_directory.prefer_local_users`, which when enabled will make it more likely for users on the same server as you to appear above other users.
@@ -1 +0,0 @@
-Fix a bug where a lot of unnecessary presence updates were sent when joining a room.
@@ -1 +0,0 @@
-Fix a bug that caused multiple calls to the experimental `shared_rooms` endpoint to return stale results.
@@ -1 +0,0 @@
-Add documentation and type hints to `parse_duration`.
@@ -1 +0,0 @@
-Fix a bug in single sign-on which could cause a "No session cookie found" error.
@@ -1 +0,0 @@
-Add support for regenerating thumbnails if they have been deleted but the original image is still stored.
@@ -1 +0,0 @@
-Fix bug introduced in v1.27.0 where allowing a user to choose their own username when logging in via single sign-on did not work unless an `idp_icon` was defined.
@@ -1 +0,0 @@
-Fix a bug introduced in v1.26.0 where some sequences were not properly configured when running `synapse_port_db`.
@@ -1 +0,0 @@
-Remove vestiges of `uploads_path` configuration setting.
@@ -1 +0,0 @@
-Update the example systemd config to propagate reloads to individual units.
@@ -1 +0,0 @@
-Add a comment about systemd-python.
@@ -1 +0,0 @@
-Fix deleting pushers when using sharded pushers.
@@ -1 +0,0 @@
-Fix deleting pushers when using sharded pushers.
@@ -1 +0,0 @@
-Fix missing startup checks for the consistency of certain PostgreSQL sequences.
@@ -1 +0,0 @@
-Add support for `X-Forwarded-Proto` header when using a reverse proxy.
@@ -1 +0,0 @@
-Fix deleting pushers when using sharded pushers.
@@ -1 +0,0 @@
-Test that we require validated email for email pushers.
@@ -1 +0,0 @@
-Add support for `X-Forwarded-Proto` header when using a reverse proxy.
debian/build_virtualenv (vendored, 6 changes)

@@ -58,10 +58,10 @@ trap "rm -r $tmpdir" EXIT
 cp -r tests "$tmpdir"

 PYTHONPATH="$tmpdir" \
-    "${TARGET_PYTHON}" -B -m twisted.trial --reporter=text -j2 tests
+    "${TARGET_PYTHON}" -m twisted.trial --reporter=text -j2 tests

 # build the config file
-"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_config" \
+"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_config" \
     --config-dir="/etc/matrix-synapse" \
     --data-dir="/var/lib/matrix-synapse" |
     perl -pe '
@@ -87,7 +87,7 @@ PYTHONPATH="$tmpdir" \
 ' > "${PACKAGE_BUILD_DIR}/etc/matrix-synapse/homeserver.yaml"

 # build the log config file
-"${TARGET_PYTHON}" -B "${VIRTUALENV_DIR}/bin/generate_log_config" \
+"${TARGET_PYTHON}" "${VIRTUALENV_DIR}/bin/generate_log_config" \
     --output-file="${PACKAGE_BUILD_DIR}/etc/matrix-synapse/log.yaml"

 # add a dependency on the right version of python to substvars.
debian/changelog (vendored, 10 changes)

@@ -1,3 +1,13 @@
+matrix-synapse-py3 (1.29.0) stable; urgency=medium
+
+  [ Jonathan de Jong ]
+  * Remove the python -B flag (don't generate bytecode) in scripts and documentation.
+
+  [ Synapse Packaging team ]
+  * New synapse release 1.29.0.
+
+ -- Synapse Packaging team <packages@matrix.org>  Mon, 08 Mar 2021 13:51:50 +0000
+
 matrix-synapse-py3 (1.28.0) stable; urgency=medium

   * New synapse release 1.28.0.
debian/synctl.1 (vendored, 2 changes)

@@ -44,7 +44,7 @@ Configuration file may be generated as follows:
 .
 .nf

-$ python \-B \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
+$ python \-m synapse\.app\.homeserver \-c config\.yaml \-\-generate\-config \-\-server\-name=<server name>
 .
 .fi
 .
debian/synctl.ronn (vendored, 2 changes)

@@ -41,7 +41,7 @@ process.

 Configuration file may be generated as follows:

-    $ python -B -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>
+    $ python -m synapse.app.homeserver -c config.yaml --generate-config --server-name=<server name>

 ## ENVIRONMENT
docs/reverse_proxy.md

@@ -53,6 +53,8 @@ server {
         proxy_pass http://localhost:8008;
         proxy_set_header X-Forwarded-For $remote_addr;
+        proxy_set_header X-Forwarded-Proto $scheme;
+        proxy_set_header Host $host;

         # Nginx by default only allows file uploads up to 1M in size
         # Increase client_max_body_size to match max_upload_size defined in homeserver.yaml
         client_max_body_size 50M;
scripts/synapse_port_db

@@ -47,6 +47,7 @@ from synapse.storage.databases.main.events_bg_updates import (
 from synapse.storage.databases.main.media_repository import (
     MediaRepositoryBackgroundUpdateStore,
 )
+from synapse.storage.databases.main.pusher import PusherWorkerStore
 from synapse.storage.databases.main.registration import (
     RegistrationBackgroundUpdateStore,
     find_max_generated_user_id_localpart,
@@ -177,6 +178,7 @@ class Store(
     UserDirectoryBackgroundUpdateStore,
     EndToEndKeyBackgroundStore,
     StatsStore,
+    PusherWorkerStore,
 ):
     def execute(self, f, *args, **kwargs):
         return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
setup.py (2 changes)

@@ -102,7 +102,7 @@ CONDITIONAL_REQUIREMENTS["lint"] = [
     "flake8",
 ]

-CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.790", "mypy-zope==0.2.8"]
+CONDITIONAL_REQUIREMENTS["mypy"] = ["mypy==0.812", "mypy-zope==0.2.11"]

 # Dependencies which are exclusively required by unit test code. This is
 # NOT a list of all modules that are necessary to run the unit tests.
synapse/__init__.py

@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass

-__version__ = "1.28.0"
+__version__ = "1.29.0"

 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when
synapse/app/__init__.py

@@ -17,8 +17,6 @@ import sys

 from synapse import python_dependencies  # noqa: E402

-sys.dont_write_bytecode = True
-
 logger = logging.getLogger(__name__)

 try:
synapse/app/generic_worker.py

@@ -23,6 +23,7 @@ from typing_extensions import ContextManager

 from twisted.internet import address
 from twisted.web.resource import IResource
+from twisted.web.server import Request

 import synapse
 import synapse.events
@@ -190,7 +191,7 @@ class KeyUploadServlet(RestServlet):
         self.http_client = hs.get_simple_http_client()
         self.main_uri = hs.config.worker_main_http_uri

-    async def on_POST(self, request, device_id):
+    async def on_POST(self, request: Request, device_id: Optional[str]):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
         body = parse_json_object_from_request(request)
@@ -223,10 +224,12 @@ class KeyUploadServlet(RestServlet):
                 header: request.requestHeaders.getRawHeaders(header, [])
                 for header in (b"Authorization", b"User-Agent")
             }
-            # Add the previous hop the the X-Forwarded-For header.
+            # Add the previous hop to the X-Forwarded-For header.
             x_forwarded_for = request.requestHeaders.getRawHeaders(
                 b"X-Forwarded-For", []
             )
+            # we use request.client here, since we want the previous hop, not the
+            # original client (as returned by request.getClientAddress()).
            if isinstance(request.client, (address.IPv4Address, address.IPv6Address)):
                previous_host = request.client.host.encode("ascii")
                # If the header exists, add to the comma-separated list of the first
@@ -239,6 +242,14 @@ class KeyUploadServlet(RestServlet):
                     x_forwarded_for = [previous_host]
             headers[b"X-Forwarded-For"] = x_forwarded_for

+            # Replicate the original X-Forwarded-Proto header. Note that
+            # XForwardedForRequest overrides isSecure() to give us the original protocol
+            # used by the client, as opposed to the protocol used by our upstream proxy
+            # - which is what we want here.
+            headers[b"X-Forwarded-Proto"] = [
+                b"https" if request.isSecure() else b"http"
+            ]
+
             try:
                 result = await self.http_client.post_json_get_json(
                     self.main_uri + request.uri.decode("ascii"), body, headers=headers
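The hunk's elided middle ("If the header exists, add to the comma-separated list...") follows the standard `X-Forwarded-For` convention: each proxy appends the address it received the request from. A hypothetical standalone helper (not the servlet's actual code) showing the same append rule:

```python
# Hypothetical helper illustrating the X-Forwarded-For convention the hunk
# relies on: hops accumulate as a comma-separated list, oldest first.
def append_forwarded_for(existing: bytes, previous_hop: bytes) -> bytes:
    return existing + b", " + previous_hop if existing else previous_hop

assert append_forwarded_for(b"203.0.113.7", b"10.0.0.2") == b"203.0.113.7, 10.0.0.2"
assert append_forwarded_for(b"", b"10.0.0.2") == b"10.0.0.2"
```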
synapse/federation/federation_server.py

@@ -936,10 +936,6 @@ class FederationHandlerRegistry:
         ):
             return

-        # Temporary patch to drop cross-user key share requests
-        if edu_type == "m.room_key_request":
-            return
-
         # Check if we have a handler on this instance
         handler = self.edu_handlers.get(edu_type)
         if handler:
synapse/handlers/auth.py

@@ -36,7 +36,7 @@ import attr
 import bcrypt
 import pymacaroons

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.constants import LoginType
 from synapse.api.errors import (
@@ -481,7 +481,7 @@ class AuthHandler(BaseHandler):
         sid = authdict["session"]

         # Convert the URI and method to strings.
-        uri = request.uri.decode("utf-8")
+        uri = request.uri.decode("utf-8")  # type: ignore
         method = request.method.decode("utf-8")

         # If there's no session ID, create a new session.
synapse/handlers/message.py

@@ -252,7 +252,7 @@ class MessageHandler:
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
         # is a user in the room that the AS is "interested in"
-        if False and requester.app_service and user_id not in users_with_profile:
+        if requester.app_service and user_id not in users_with_profile:
             for uid in users_with_profile:
                 if requester.app_service.is_interested_in_user(uid):
                     break
synapse/handlers/presence.py

@@ -274,22 +274,25 @@ class PresenceHandler(BasePresenceHandler):

         self.external_sync_linearizer = Linearizer(name="external_sync_linearizer")

-        # Start a LoopingCall in 30s that fires every 5s.
-        # The initial delay is to allow disconnected clients a chance to
-        # reconnect before we treat them as offline.
-        def run_timeout_handler():
-            return run_as_background_process(
-                "handle_presence_timeouts", self._handle_timeouts
-            )
-
-        self.clock.call_later(30, self.clock.looping_call, run_timeout_handler, 5000)
-
-        def run_persister():
-            return run_as_background_process(
-                "persist_presence_changes", self._persist_unpersisted_changes
-            )
-
-        self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)
+        if self._presence_enabled:
+            # Start a LoopingCall in 30s that fires every 5s.
+            # The initial delay is to allow disconnected clients a chance to
+            # reconnect before we treat them as offline.
+            def run_timeout_handler():
+                return run_as_background_process(
+                    "handle_presence_timeouts", self._handle_timeouts
+                )
+
+            self.clock.call_later(
+                30, self.clock.looping_call, run_timeout_handler, 5000
+            )
+
+            def run_persister():
+                return run_as_background_process(
+                    "persist_presence_changes", self._persist_unpersisted_changes
+                )
+
+            self.clock.call_later(60, self.clock.looping_call, run_persister, 60 * 1000)

         LaterGauge(
             "synapse_handlers_presence_wheel_timer_size",
@@ -299,7 +302,7 @@ class PresenceHandler(BasePresenceHandler):
         )

         # Used to handle sending of presence to newly joined users/servers
-        if hs.config.use_presence:
+        if self._presence_enabled:
             self.notifier.add_replication_callback(self.notify_new_event)

         # Presence is best effort and quickly heals itself, so lets just always
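The guarded block keeps the existing scheduling trick: delay the start of a recurring job so briefly-disconnected clients are not immediately marked offline. In plain Twisted the same pattern looks like this (a sketch, independent of Synapse's `Clock` wrapper):

```python
from twisted.internet import reactor, task

def handle_presence_timeouts():
    # Placeholder for the real timeout handler.
    print("checking presence timeouts")

# Fire every 5 seconds, but only start 30 seconds from now so that
# clients that dropped their connection have a chance to reconnect
# before being treated as offline. Requires reactor.run() to execute.
loop = task.LoopingCall(handle_presence_timeouts)
reactor.callLater(30, loop.start, 5)
```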
synapse/handlers/room_list.py

@@ -43,7 +43,6 @@ class RoomListHandler(BaseHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.enable_room_list_search = hs.config.enable_room_list_search

         self.response_cache = ResponseCache(
             hs, "room_list"
         )  # type: ResponseCache[Tuple[Optional[int], Optional[str], ThirdPartyInstanceID]]
synapse/handlers/room_member.py

@@ -66,7 +66,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.account_data_handler = hs.get_account_data_handler()

         self.member_linearizer = Linearizer(name="member")
-        self.member_limiter = Linearizer(max_count=10, name="member_as_limiter")

         self.clock = hs.get_clock()
         self.spam_checker = hs.get_spam_checker()
@@ -337,38 +336,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):

         key = (room_id,)

-        as_id = object()
-        if requester.app_service:
-            as_id = requester.app_service.id
-
-        then = self.clock.time_msec()
-
-        with (await self.member_limiter.queue(as_id)):
-            diff = self.clock.time_msec() - then
-
-            if diff > 80 * 1000:
-                # haproxy would have timed the request out anyway...
-                raise SynapseError(504, "took to long to process")
-
-            with (await self.member_linearizer.queue(key)):
-                diff = self.clock.time_msec() - then
-
-                if diff > 80 * 1000:
-                    # haproxy would have timed the request out anyway...
-                    raise SynapseError(504, "took to long to process")
-
-                result = await self.update_membership_locked(
-                    requester,
-                    target,
-                    room_id,
-                    action,
-                    txn_id=txn_id,
-                    remote_room_hosts=remote_room_hosts,
-                    third_party_signed=third_party_signed,
-                    ratelimit=ratelimit,
-                    content=content,
-                    require_consent=require_consent,
-                )
+        with (await self.member_linearizer.queue(key)):
+            result = await self.update_membership_locked(
+                requester,
+                target,
+                room_id,
+                action,
+                txn_id=txn_id,
+                remote_room_hosts=remote_room_hosts,
+                third_party_signed=third_party_signed,
+                ratelimit=ratelimit,
+                content=content,
+                require_consent=require_consent,
+            )

         return result
synapse/handlers/sso.py

@@ -31,8 +31,8 @@ from urllib.parse import urlencode
 import attr
 from typing_extensions import NoReturn, Protocol

-from twisted.web.http import Request
 from twisted.web.iweb import IRequest
+from twisted.web.server import Request

 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
synapse/handlers/sync.py

@@ -52,7 +52,6 @@ logger = logging.getLogger(__name__)
 # Debug logger for https://github.com/matrix-org/synapse/issues/4422
 issue4422_logger = logging.getLogger("synapse.handler.sync.4422_debug")

-SYNC_RESPONSE_CACHE_MS = 2 * 60 * 1000

 # Counts the number of times we returned a non-empty sync. `type` is one of
 # "initial_sync", "full_state_sync" or "incremental_sync", `lazy_loaded` is
@@ -245,7 +244,7 @@ class SyncHandler:
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
         self.response_cache = ResponseCache(
-            hs, "sync", timeout_ms=SYNC_RESPONSE_CACHE_MS
+            hs, "sync"
         )  # type: ResponseCache[Tuple[Any, ...]]
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
@@ -278,9 +277,8 @@ class SyncHandler:
         user_id = sync_config.user.to_string()
         await self.auth.check_auth_blocking(requester=requester)

-        res = await self.response_cache.wrap_conditional(
+        res = await self.response_cache.wrap(
             sync_config.request_key,
-            lambda result: since_token != result.next_batch,
             self._wait_for_sync_for_user,
             sync_config,
             since_token,
synapse/http/client.py

@@ -289,8 +289,7 @@ class SimpleHttpClient:
         treq_args: Dict[str, Any] = {},
         ip_whitelist: Optional[IPSet] = None,
         ip_blacklist: Optional[IPSet] = None,
-        http_proxy: Optional[bytes] = None,
-        https_proxy: Optional[bytes] = None,
+        use_proxy: bool = False,
     ):
         """
         Args:
@@ -300,8 +299,8 @@ class SimpleHttpClient:
                 we may not request.
             ip_whitelist: The whitelisted IP addresses, that we can
                 request if it were otherwise caught in a blacklist.
-            http_proxy: proxy server to use for http connections. host[:port]
-            https_proxy: proxy server to use for https connections. host[:port]
+            use_proxy: Whether proxy settings should be discovered and used
+                from conventional environment variables.
         """
         self.hs = hs
@@ -345,8 +344,7 @@ class SimpleHttpClient:
             connectTimeout=15,
             contextFactory=self.hs.get_http_client_context_factory(),
             pool=pool,
-            http_proxy=http_proxy,
-            https_proxy=https_proxy,
+            use_proxy=use_proxy,
         )

         if self._ip_blacklist:
@@ -750,7 +748,32 @@ class BodyExceededMaxSize(Exception):
     """The maximum allowed size of the HTTP body was exceeded."""


+class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
+    """A protocol which immediately errors upon receiving data."""
+
+    def __init__(self, deferred: defer.Deferred):
+        self.deferred = deferred
+
+    def _maybe_fail(self):
+        """
+        Report a max size exceed error and disconnect the first time this is called.
+        """
+        if not self.deferred.called:
+            self.deferred.errback(BodyExceededMaxSize())
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
+
+    def dataReceived(self, data: bytes) -> None:
+        self._maybe_fail()
+
+    def connectionLost(self, reason: Failure) -> None:
+        self._maybe_fail()
+
+
 class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
     """A protocol which reads body to a stream, erroring if the body exceeds a maximum size."""

     def __init__(
         self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
     ):
@@ -807,13 +830,15 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    d = defer.Deferred()
+
     # If the Content-Length header gives a size larger than the maximum allowed
     # size, do not bother downloading the body.
     if max_size is not None and response.length != UNKNOWN_LENGTH:
         if response.length > max_size:
-            return defer.fail(BodyExceededMaxSize())
+            response.deliverBody(_DiscardBodyWithMaxSizeProtocol(d))
+            return d

-    d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
     return d
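Either path above, the early `Content-Length` rejection or the mid-stream abort, surfaces to callers as a `BodyExceededMaxSize` errback, so calling code is unchanged by this fix. A sketch of a caller, assuming the names from this module plus Synapse's `make_deferred_yieldable` helper:

```python
# Sketch only: illustrates how BodyExceededMaxSize is consumed, whether the
# overflow was detected up front (Content-Length) or while streaming.
from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size
from synapse.logging.context import make_deferred_yieldable

async def download_limited(response, output_stream, max_size: int) -> int:
    try:
        # Returns the number of bytes written to output_stream.
        return await make_deferred_yieldable(
            read_body_with_max_size(response, output_stream, max_size)
        )
    except BodyExceededMaxSize:
        raise ValueError("remote body exceeded %d bytes" % max_size)
```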
synapse/http/federation/matrix_federation_agent.py

@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import List, Optional
+from typing import Any, Generator, List, Optional

 from netaddr import AddrFormatError, IPAddress, IPSet
 from zope.interface import implementer
@@ -116,7 +116,7 @@ class MatrixFederationAgent:
         uri: bytes,
         headers: Optional[Headers] = None,
         bodyProducer: Optional[IBodyProducer] = None,
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """
         Args:
             method: HTTP method: GET/POST/etc
@@ -177,17 +177,17 @@ class MatrixFederationAgent:
         # We need to make sure the host header is set to the netloc of the
         # server and that a user-agent is provided.
         if headers is None:
-            headers = Headers()
+            request_headers = Headers()
         else:
-            headers = headers.copy()
+            request_headers = headers.copy()

-        if not headers.hasHeader(b"host"):
-            headers.addRawHeader(b"host", parsed_uri.netloc)
-        if not headers.hasHeader(b"user-agent"):
-            headers.addRawHeader(b"user-agent", self.user_agent)
+        if not request_headers.hasHeader(b"host"):
+            request_headers.addRawHeader(b"host", parsed_uri.netloc)
+        if not request_headers.hasHeader(b"user-agent"):
+            request_headers.addRawHeader(b"user-agent", self.user_agent)

         res = yield make_deferred_yieldable(
-            self._agent.request(method, uri, headers, bodyProducer)
+            self._agent.request(method, uri, request_headers, bodyProducer)
         )

         return res
synapse/http/matrixfederationclient.py

@@ -1049,14 +1049,14 @@ def check_content_type_is_json(headers: Headers) -> None:
         RequestSendFailed: if the Content-Type header is missing or isn't JSON

     """
-    c_type = headers.getRawHeaders(b"Content-Type")
-    if c_type is None:
+    content_type_headers = headers.getRawHeaders(b"Content-Type")
+    if content_type_headers is None:
         raise RequestSendFailed(
             RuntimeError("No Content-Type header received from remote server"),
             can_retry=False,
         )

-    c_type = c_type[0].decode("ascii")  # only the first header
+    c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
     if val != "application/json":
         raise RequestSendFailed(
synapse/http/proxyagent.py

@@ -14,6 +14,7 @@
 # limitations under the License.
 import logging
 import re
+from urllib.request import getproxies_environment, proxy_bypass_environment

 from zope.interface import implementer

@@ -58,6 +59,9 @@ class ProxyAgent(_AgentBase):

         pool (HTTPConnectionPool|None): connection pool to be used. If None, a
             non-persistent pool instance will be created.
+
+        use_proxy (bool): Whether proxy settings should be discovered and used
+            from conventional environment variables.
     """

     def __init__(
@@ -68,8 +72,7 @@ class ProxyAgent(_AgentBase):
         connectTimeout=None,
         bindAddress=None,
         pool=None,
-        http_proxy=None,
-        https_proxy=None,
+        use_proxy=False,
     ):
         _AgentBase.__init__(self, reactor, pool)

@@ -84,6 +87,15 @@ class ProxyAgent(_AgentBase):
         if bindAddress is not None:
             self._endpoint_kwargs["bindAddress"] = bindAddress

+        http_proxy = None
+        https_proxy = None
+        no_proxy = None
+        if use_proxy:
+            proxies = getproxies_environment()
+            http_proxy = proxies["http"].encode() if "http" in proxies else None
+            https_proxy = proxies["https"].encode() if "https" in proxies else None
+            no_proxy = proxies["no"] if "no" in proxies else None
+
         self.http_proxy_endpoint = _http_proxy_endpoint(
             http_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )
@@ -92,6 +104,8 @@ class ProxyAgent(_AgentBase):
             https_proxy, self.proxy_reactor, **self._endpoint_kwargs
         )

+        self.no_proxy = no_proxy
+
         self._policy_for_https = contextFactory
         self._reactor = reactor

@@ -139,13 +153,28 @@ class ProxyAgent(_AgentBase):
         pool_key = (parsed_uri.scheme, parsed_uri.host, parsed_uri.port)
         request_path = parsed_uri.originForm

-        if parsed_uri.scheme == b"http" and self.http_proxy_endpoint:
+        should_skip_proxy = False
+        if self.no_proxy is not None:
+            should_skip_proxy = proxy_bypass_environment(
+                parsed_uri.host.decode(),
+                proxies={"no": self.no_proxy},
+            )
+
+        if (
+            parsed_uri.scheme == b"http"
+            and self.http_proxy_endpoint
+            and not should_skip_proxy
+        ):
             # Cache *all* connections under the same key, since we are only
             # connecting to a single destination, the proxy:
             pool_key = ("http-proxy", self.http_proxy_endpoint)
             endpoint = self.http_proxy_endpoint
             request_path = uri
-        elif parsed_uri.scheme == b"https" and self.https_proxy_endpoint:
+        elif (
+            parsed_uri.scheme == b"https"
+            and self.https_proxy_endpoint
+            and not should_skip_proxy
+        ):
             endpoint = HTTPConnectProxyEndpoint(
                 self.proxy_reactor,
                 self.https_proxy_endpoint,
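The discovery and bypass logic is deliberately thin: both helpers come from the standard library's `urllib.request`, which is what gives the changelog's `no_proxy`/`NO_PROXY` behaviour. A standalone sketch of those two calls (hypothetical hostnames):

```python
import os
from urllib.request import getproxies_environment, proxy_bypass_environment

# Only lowercase variables are set here to keep the demo unambiguous.
os.environ["http_proxy"] = "http://proxy.example.com:8888"
os.environ["no_proxy"] = "localhost,127.0.0.1,internal.example.com"

# Collects every *_proxy environment variable into a dict keyed by scheme,
# with the no_proxy list under the "no" key.
proxies = getproxies_environment()
print(proxies.get("http"))  # http://proxy.example.com:8888

# True for hosts matched by the no_proxy list, so the agent dials direct:
print(proxy_bypass_environment("internal.example.com", proxies={"no": proxies["no"]}))
print(proxy_bypass_environment("matrix.org", proxies={"no": proxies["no"]}))
```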
synapse/http/server.py

@@ -21,6 +21,7 @@ import logging
 import types
 import urllib
 from http import HTTPStatus
+from inspect import isawaitable
 from io import BytesIO
 from typing import (
     Any,
@@ -30,6 +31,7 @@ from typing import (
     Iterable,
     Iterator,
     List,
+    Optional,
     Pattern,
     Tuple,
     Union,
@@ -79,10 +81,12 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""

     if f.check(SynapseError):
-        error_code = f.value.code
-        error_dict = f.value.error_dict()
+        # mypy doesn't understand that f.check asserts the type.
+        exc = f.value  # type: SynapseError  # type: ignore
+        error_code = exc.code
+        error_dict = exc.error_dict()

-        logger.info("%s SynapseError: %s - %s", request, error_code, f.value.msg)
+        logger.info("%s SynapseError: %s - %s", request, error_code, exc.msg)
     else:
         error_code = 500
         error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN}
@@ -91,7 +95,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
         "Failed handle request via %r: %r",
         request.request_metrics.name,
         request,
-        exc_info=(f.type, f.value, f.getTracebackObject()),
+        exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
     )

     # Only respond with an error response if we haven't already started writing,
@@ -128,7 +132,8 @@ def return_html_error(
         `{msg}` placeholders), or a jinja2 template
     """
     if f.check(CodeMessageException):
-        cme = f.value
+        # mypy doesn't understand that f.check asserts the type.
+        cme = f.value  # type: CodeMessageException  # type: ignore
         code = cme.code
         msg = cme.msg
@@ -142,7 +147,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )
     else:
         code = HTTPStatus.INTERNAL_SERVER_ERROR
@@ -151,7 +156,7 @@ def return_html_error(
         logger.error(
             "Failed handle request %r",
             request,
-            exc_info=(f.type, f.value, f.getTracebackObject()),
+            exc_info=(f.type, f.value, f.getTracebackObject()),  # type: ignore
         )

     if isinstance(error_template, str):
@@ -278,7 +283,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
             raw_callback_return = method_handler(request)

             # Is it synchronous? We'll allow this for now.
-            if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)):
+            if isawaitable(raw_callback_return):
                 callback_return = await raw_callback_return
             else:
                 callback_return = raw_callback_return  # type: ignore
@@ -399,8 +404,10 @@ class JsonResource(DirectServeJsonResource):
             A tuple of the callback to use, the name of the servlet, and the
             key word arguments to pass to the callback
         """
+        # At this point the path must be bytes.
+        request_path_bytes = request.path  # type: bytes  # type: ignore
+        request_path = request_path_bytes.decode("ascii")
         # Treat HEAD requests as GET requests.
-        request_path = request.path.decode("ascii")
         request_method = request.method
         if request_method == b"HEAD":
             request_method = b"GET"
@@ -551,7 +558,7 @@ class _ByteProducer:
         request: Request,
         iterator: Iterator[bytes],
     ):
-        self._request = request
+        self._request = request  # type: Optional[Request]
         self._iterator = iterator
         self._paused = False
@@ -563,7 +570,7 @@ class _ByteProducer:
         """
         Send a list of bytes as a chunk of a response.
         """
-        if not data:
+        if not data or not self._request:
             return
         self._request.write(b"".join(data))
synapse/http/site.py

@@ -14,7 +14,7 @@
 import contextlib
 import logging
 import time
-from typing import Optional, Union
+from typing import Optional, Type, Union

 import attr
 from zope.interface import implementer
@@ -57,7 +57,7 @@ class SynapseRequest(Request):

     def __init__(self, channel, *args, **kw):
         Request.__init__(self, channel, *args, **kw)
-        self.site = channel.site
+        self.site = channel.site  # type: SynapseSite
         self._channel = channel  # this is used by the tests
         self.start_time = 0.0
@@ -96,25 +96,34 @@ class SynapseRequest(Request):
     def get_request_id(self):
         return "%s-%i" % (self.get_method(), self.request_seq)

-    def get_redacted_uri(self):
-        uri = self.uri
+    def get_redacted_uri(self) -> str:
+        """Gets the redacted URI associated with the request (or placeholder if the URI
+        has not yet been received).
+
+        Note: This is necessary as the placeholder value in twisted is str
+        rather than bytes, so we need to sanitise `self.uri`.
+
+        Returns:
+            The redacted URI as a string.
+        """
+        uri = self.uri  # type: Union[bytes, str]
         if isinstance(uri, bytes):
-            uri = self.uri.decode("ascii", errors="replace")
+            uri = uri.decode("ascii", errors="replace")
         return redact_uri(uri)

-    def get_method(self):
-        """Gets the method associated with the request (or placeholder if not
-        method has yet been received).
+    def get_method(self) -> str:
+        """Gets the method associated with the request (or placeholder if method
+        has not yet been received).

         Note: This is necessary as the placeholder value in twisted is str
         rather than bytes, so we need to sanitise `self.method`.

         Returns:
-            str
+            The request method as a string.
         """
-        method = self.method
+        method = self.method  # type: Union[bytes, str]
         if isinstance(method, bytes):
-            method = self.method.decode("ascii")
+            return self.method.decode("ascii")
         return method

     def render(self, resrc):
@@ -375,9 +384,9 @@ class XForwardedForRequest(SynapseRequest):
         else:
             # this is done largely for backwards-compatibility so that people that
             # haven't set an x-forwarded-proto header don't get a redirect loop.
-            #logger.warning(
-            #    "forwarded request lacks an x-forwarded-proto header: assuming https"
-            #)
+            logger.warning(
+                "forwarded request lacks an x-forwarded-proto header: assuming https"
+            )
             self._forwarded_https = True

     def isSecure(self):
@@ -432,7 +441,9 @@ class SynapseSite(Site):

         assert config.http_options is not None
         proxied = config.http_options.x_forwarded
-        self.requestFactory = XForwardedForRequest if proxied else SynapseRequest
+        self.requestFactory = (
+            XForwardedForRequest if proxied else SynapseRequest
+        )  # type: Type[Request]
         self.access_logger = logging.getLogger(logger_name)
         self.server_version_string = server_version_string.encode("ascii")
synapse/logging/_remote.py

@@ -32,7 +32,7 @@ from twisted.internet.endpoints import (
     TCP4ClientEndpoint,
     TCP6ClientEndpoint,
 )
-from twisted.internet.interfaces import IPushProducer, ITransport
+from twisted.internet.interfaces import IPushProducer, IStreamClientEndpoint, ITransport
 from twisted.internet.protocol import Factory, Protocol
 from twisted.python.failure import Failure

@@ -121,7 +121,9 @@ class RemoteHandler(logging.Handler):
         try:
             ip = ip_address(self.host)
             if isinstance(ip, IPv4Address):
-                endpoint = TCP4ClientEndpoint(_reactor, self.host, self.port)
+                endpoint = TCP4ClientEndpoint(
+                    _reactor, self.host, self.port
+                )  # type: IStreamClientEndpoint
             elif isinstance(ip, IPv6Address):
                 endpoint = TCP6ClientEndpoint(_reactor, self.host, self.port)
             else:
synapse/metrics/__init__.py

@@ -527,7 +527,7 @@ class ReactorLastSeenMetric:
 REGISTRY.register(ReactorLastSeenMetric())


-def runUntilCurrentTimer(func):
+def runUntilCurrentTimer(reactor, func):
     @functools.wraps(func)
     def f(*args, **kwargs):
         now = reactor.seconds()
@@ -590,13 +590,14 @@ def runUntilCurrentTimer(func):

 try:
     # Ensure the reactor has all the attributes we expect
-    reactor.runUntilCurrent
-    reactor._newTimedCalls
-    reactor.threadCallQueue
+    reactor.seconds  # type: ignore
+    reactor.runUntilCurrent  # type: ignore
+    reactor._newTimedCalls  # type: ignore
+    reactor.threadCallQueue  # type: ignore

     # runUntilCurrent is called when we have pending calls. It is called once
     # per iteratation after fd polling.
-    reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
+    reactor.runUntilCurrent = runUntilCurrentTimer(reactor, reactor.runUntilCurrent)  # type: ignore

     # We manually run the GC each reactor tick so that we can get some metrics
     # about time spent doing GC,
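The signature change is the whole point of this hunk: the timer wrapper now receives the reactor explicitly instead of closing over a module-level global. The general decorator shape, as a sketch with a stand-in clock:

```python
import functools
import time

def timed(clock, func):
    # Wrap func, preserving its metadata, and measure each call with the
    # explicitly supplied clock rather than a global.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = clock()
        try:
            return func(*args, **kwargs)
        finally:
            print("%s took %.6fs" % (func.__name__, clock() - start))
    return wrapper

timed_len = timed(time.monotonic, len)
print(timed_len(range(1000)))  # 1000, plus a timing line
```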
synapse/module_api/__init__.py

@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Generator, Iterable, Optional, Tuple

 from twisted.internet import defer

@@ -307,7 +307,7 @@ class ModuleApi:
     @defer.inlineCallbacks
     def get_state_events_in_room(
         self, room_id: str, types: Iterable[Tuple[str, Optional[str]]]
-    ) -> defer.Deferred:
+    ) -> Generator[defer.Deferred, Any, defer.Deferred]:
         """Gets current state events for the given room.

         (This is exposed for compatibility with the old SpamCheckerApi. We should
synapse/push/httppusher.py

@@ -15,11 +15,12 @@
 # limitations under the License.
 import logging
 import urllib.parse
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Union
+from typing import TYPE_CHECKING, Any, Dict, Iterable, Optional, Union

 from prometheus_client import Counter

 from twisted.internet.error import AlreadyCalled, AlreadyCancelled
+from twisted.internet.interfaces import IDelayedCall

 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
@@ -71,7 +72,7 @@ class HttpPusher(Pusher):
         self.data = pusher_config.data
         self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
         self.failing_since = pusher_config.failing_since
-        self.timed_call = None
+        self.timed_call = None  # type: Optional[IDelayedCall]
         self._is_processing = False
         self._group_unread_count_by_room = hs.config.push_group_unread_count_by_room
         self._pusherpool = hs.get_pusherpool()
@@ -101,11 +102,6 @@ class HttpPusher(Pusher):
                 "'url' must have a path of '/_matrix/push/v1/notify'"
             )

-        url = url.replace(
-            "https://matrix.org/_matrix/push/v1/notify",
-            "http://10.103.0.7/_matrix/push/v1/notify",
-        )
-
         self.url = url
         self.http_client = hs.get_proxied_blacklisted_http_client()
         self.data_minus_url = {}
synapse/replication/http/membership.py

@@ -15,9 +15,10 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.http.servlet import parse_json_object_from_request
+from synapse.http.site import SynapseRequest
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util.distributor import user_left_room
@@ -78,7 +79,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         }

     async def _handle_request(  # type: ignore
-        self, request: Request, room_id: str, user_id: str
+        self, request: SynapseRequest, room_id: str, user_id: str
     ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
@@ -86,7 +87,6 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
         event_content = content["content"]

         requester = Requester.deserialize(self.store, content["requester"])

         request.requester = requester

         logger.info("remote_join: %s into room: %s", user_id, room_id)
@@ -147,7 +147,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         }

     async def _handle_request(  # type: ignore
-        self, request: Request, invite_event_id: str
+        self, request: SynapseRequest, invite_event_id: str
     ) -> Tuple[int, JsonDict]:
         content = parse_json_object_from_request(request)
@@ -155,7 +155,6 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         event_content = content["content"]

         requester = Requester.deserialize(self.store, content["requester"])

         request.requester = requester

         # hopefully we're now on the master, so this won't recurse!
synapse/replication/tcp/client.py

@@ -108,9 +108,7 @@ class ReplicationDataHandler:

         # Map from stream to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters = (
-            {}
-        )  # type: Dict[str, List[Tuple[int, Deferred[None]]]]
+        self._streams_to_waiters = {}  # type: Dict[str, List[Tuple[int, Deferred]]]

     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
synapse/replication/tcp/streams/_base.py

@@ -502,7 +502,7 @@ class AccountDataStream(Stream):
     """Global or per room account data was changed"""

     AccountDataStreamRow = namedtuple(
-        "AccountDataStream",
+        "AccountDataStreamRow",
         ("user_id", "room_id", "data_type"),  # str  # Optional[str]  # str
     )
synapse/rest/admin/devices.py

@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import TYPE_CHECKING, Tuple

 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.http.servlet import (
@@ -20,8 +21,12 @@ from synapse.http.servlet import (
     assert_params_in_dict,
     parse_json_object_from_request,
 )
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
-from synapse.types import UserID
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer

 logger = logging.getLogger(__name__)

@@ -35,14 +40,16 @@ class DeviceRestServlet(RestServlet):
         "/users/(?P<user_id>[^/]*)/devices/(?P<device_id>[^/]*)$", "v2"
     )

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()

-    async def on_GET(self, request, user_id, device_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -58,7 +65,9 @@ class DeviceRestServlet(RestServlet):
         )
         return 200, device

-    async def on_DELETE(self, request, user_id, device_id):
+    async def on_DELETE(
+        self, request: SynapseRequest, user_id: str, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -72,7 +81,9 @@ class DeviceRestServlet(RestServlet):
         await self.device_handler.delete_device(target_user.to_string(), device_id)
         return 200, {}

-    async def on_PUT(self, request, user_id, device_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str, device_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -97,7 +108,7 @@ class DevicesRestServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/devices$", "v2")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         """
         Args:
             hs (synapse.server.HomeServer): server
@@ -107,7 +118,9 @@ class DevicesRestServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -130,13 +143,15 @@ class DeleteDevicesRestServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/delete_devices$", "v2")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastore()

-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
synapse/rest/admin/event_reports.py

@@ -14,10 +14,16 @@
 # limitations under the License.

 import logging
+from typing import TYPE_CHECKING, Tuple

 from synapse.api.errors import Codes, NotFoundError, SynapseError
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
+from synapse.http.site import SynapseRequest
 from synapse.rest.admin._base import admin_patterns, assert_requester_is_admin
+from synapse.types import JsonDict
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer

 logger = logging.getLogger(__name__)

@@ -45,12 +51,12 @@ class EventReportsRestServlet(RestServlet):

     PATTERNS = admin_patterns("/event_reports$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()

-    async def on_GET(self, request):
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         start = parse_integer(request, "from", default=0)
@@ -106,26 +112,28 @@ class EventReportDetailRestServlet(RestServlet):

     PATTERNS = admin_patterns("/event_reports/(?P<report_id>[^/]*)$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()

-    async def on_GET(self, request, report_id):
+    async def on_GET(
+        self, request: SynapseRequest, report_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         message = (
             "The report_id parameter must be a string representing a positive integer."
         )
         try:
-            report_id = int(report_id)
+            resolved_report_id = int(report_id)
         except ValueError:
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)

-        if report_id < 0:
+        if resolved_report_id < 0:
             raise SynapseError(400, message, errcode=Codes.INVALID_PARAM)

-        ret = await self.store.get_event_report(report_id)
+        ret = await self.store.get_event_report(resolved_report_id)
         if not ret:
             raise NotFoundError("Event report not found")
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Tuple
|
||||
|
||||
from twisted.web.http import Request
|
||||
from twisted.web.server import Request
|
||||
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, SynapseError
|
||||
from synapse.http.servlet import RestServlet, parse_boolean, parse_integer
|
||||
|
||||
@@ -44,6 +44,48 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ResolveRoomIdMixin:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.room_member_handler = hs.get_room_member_handler()
|
||||
|
||||
async def resolve_room_id(
|
||||
self, room_identifier: str, remote_room_hosts: Optional[List[str]] = None
|
||||
) -> Tuple[str, Optional[List[str]]]:
|
||||
"""
|
||||
Resolve a room identifier to a room ID, if necessary.
|
||||
|
||||
This also performanes checks to ensure the room ID is of the proper form.
|
||||
|
||||
Args:
|
||||
room_identifier: The room ID or alias.
|
||||
remote_room_hosts: The potential remote room hosts to use.
|
||||
|
||||
Returns:
|
||||
The resolved room ID.
|
||||
|
||||
Raises:
|
||||
SynapseError if the room ID is of the wrong form.
|
||||
"""
|
||||
if RoomID.is_valid(room_identifier):
|
||||
resolved_room_id = room_identifier
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
(
|
||||
room_id,
|
||||
remote_room_hosts,
|
||||
) = await self.room_member_handler.lookup_room_alias(room_alias)
|
||||
resolved_room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(
|
||||
400, "%s was not legal room ID or room alias" % (room_identifier,)
|
||||
)
|
||||
if not resolved_room_id:
|
||||
raise SynapseError(
|
||||
400, "Unknown room ID or room alias %s" % room_identifier
|
||||
)
|
||||
return resolved_room_id, remote_room_hosts
|
||||
|
||||
|
||||
class ShutdownRoomRestServlet(RestServlet):
|
||||
"""Shuts down a room by removing all local users from the room and blocking
|
||||
all future invites and joins to the room. Any local aliases will be repointed
|
||||
@@ -334,14 +376,14 @@ class RoomStateRestServlet(RestServlet):
         return 200, ret


-class JoinRoomAliasServlet(RestServlet):
+class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):

     PATTERNS = admin_patterns("/join/(?P<room_identifier>[^/]*)")

     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.admin_handler = hs.get_admin_handler()
         self.state_handler = hs.get_state_handler()

@@ -362,22 +404,16 @@ class JoinRoomAliasServlet(RestServlet):
         if not await self.admin_handler.get_user(target_user):
             raise NotFoundError("User not found")

-        if RoomID.is_valid(room_identifier):
-            room_id = room_identifier
-            try:
-                remote_room_hosts = [
-                    x.decode("ascii") for x in request.args[b"server_name"]
-                ]  # type: Optional[List[str]]
-            except Exception:
-                remote_room_hosts = None
-        elif RoomAlias.is_valid(room_identifier):
-            handler = self.room_member_handler
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
+        # Get the room ID from the identifier.
+        try:
+            remote_room_hosts = [
+                x.decode("ascii") for x in request.args[b"server_name"]
+            ]  # type: Optional[List[str]]
+        except Exception:
+            remote_room_hosts = None
+        room_id, remote_room_hosts = await self.resolve_room_id(
+            room_identifier, remote_room_hosts
+        )

         fake_requester = create_requester(
             target_user, authenticated_entity=requester.authenticated_entity
@@ -412,7 +448,7 @@ class JoinRoomAliasServlet(RestServlet):
         return 200, {"room_id": room_id}


-class MakeRoomAdminRestServlet(RestServlet):
+class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
     """Allows a server admin to get power in a room if a local user has power in
     a room. Will also invite the user if they're not in the room and it's a
     private room. Can specify another user (rather than the admin user) to be
@@ -427,29 +463,21 @@ class MakeRoomAdminRestServlet(RestServlet):
     PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/make_room_admin")

     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
         self.state_handler = hs.get_state_handler()
         self.is_mine_id = hs.is_mine_id

-    async def on_POST(self, request, room_identifier):
+    async def on_POST(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         content = parse_json_object_from_request(request, allow_empty_body=True)

-        # Resolve to a room ID, if necessary.
-        if RoomID.is_valid(room_identifier):
-            room_id = room_identifier
-        elif RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
-            room_id = room_id.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
+        room_id, _ = await self.resolve_room_id(room_identifier)

         # Which user to grant room admin rights to.
         user_to_add = content.get("user_id", requester.user.to_string())
@@ -556,7 +584,7 @@ class MakeRoomAdminRestServlet(RestServlet):
         return 200, {}


-class ForwardExtremitiesRestServlet(RestServlet):
+class ForwardExtremitiesRestServlet(ResolveRoomIdMixin, RestServlet):
     """Allows a server admin to get or clear forward extremities.

     Clearing does not require restarting the server.
@@ -571,43 +599,29 @@ class ForwardExtremitiesRestServlet(RestServlet):
     PATTERNS = admin_patterns("/rooms/(?P<room_identifier>[^/]*)/forward_extremities")

     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
-        self.room_member_handler = hs.get_room_member_handler()
         self.store = hs.get_datastore()

-    async def resolve_room_id(self, room_identifier: str) -> str:
-        """Resolve to a room ID, if necessary."""
-        if RoomID.is_valid(room_identifier):
-            resolved_room_id = room_identifier
-        elif RoomAlias.is_valid(room_identifier):
-            room_alias = RoomAlias.from_string(room_identifier)
-            room_id, _ = await self.room_member_handler.lookup_room_alias(room_alias)
-            resolved_room_id = room_id.to_string()
-        else:
-            raise SynapseError(
-                400, "%s was not legal room ID or room alias" % (room_identifier,)
-            )
-        if not resolved_room_id:
-            raise SynapseError(
-                400, "Unknown room ID or room alias %s" % room_identifier
-            )
-        return resolved_room_id
-
-    async def on_DELETE(self, request, room_identifier):
+    async def on_DELETE(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)

-        room_id = await self.resolve_room_id(room_identifier)
+        room_id, _ = await self.resolve_room_id(room_identifier)

         deleted_count = await self.store.delete_forward_extremities_for_room(room_id)
         return 200, {"deleted": deleted_count}

-    async def on_GET(self, request, room_identifier):
+    async def on_GET(
+        self, request: SynapseRequest, room_identifier: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)

-        room_id = await self.resolve_room_id(room_identifier)
+        room_id, _ = await self.resolve_room_id(room_identifier)

         extremities = await self.store.get_forward_extremities_for_room(room_id)
         return 200, {"count": len(extremities), "results": extremities}
@@ -623,14 +637,16 @@ class RoomEventContextServlet(RestServlet):

     PATTERNS = admin_patterns("/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.clock = hs.get_clock()
         self.room_context_handler = hs.get_room_context_handler()
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()

-    async def on_GET(self, request, room_id, event_id):
+    async def on_GET(
+        self, request: SynapseRequest, room_id: str, event_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=False)
         await assert_user_is_admin(self.auth, requester.user)

@@ -16,7 +16,7 @@ import hashlib
 import hmac
 import logging
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple

 from synapse.api.constants import UserTypes
 from synapse.api.errors import Codes, NotFoundError, SynapseError
@@ -47,13 +47,15 @@ logger = logging.getLogger(__name__)
 class UsersRestServlet(RestServlet):
     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, List[JsonDict]]:
         target_user = UserID.from_string(user_id)
         await assert_requester_is_admin(self.auth, request)

@@ -153,7 +155,7 @@ class UserRestServletV2(RestServlet):
     otherwise an error.
     """

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
@@ -165,7 +167,9 @@ class UserRestServletV2(RestServlet):
         self.registration_handler = hs.get_registration_handler()
         self.pusher_pool = hs.get_pusherpool()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -179,7 +183,9 @@ class UserRestServletV2(RestServlet):

         return 200, ret

-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)

@@ -273,6 +279,8 @@ class UserRestServletV2(RestServlet):
             )

             user = await self.admin_handler.get_user(target_user)
+            assert user is not None

             return 200, user

         else:  # create user
@@ -330,9 +338,10 @@ class UserRestServletV2(RestServlet):
                 target_user, requester, body["avatar_url"], True
             )

-            ret = await self.admin_handler.get_user(target_user)
+            user = await self.admin_handler.get_user(target_user)
+            assert user is not None

-            return 201, ret
+            return 201, user


 class UserRegisterServlet(RestServlet):
@@ -346,10 +355,10 @@ class UserRegisterServlet(RestServlet):
     PATTERNS = admin_patterns("/register")
     NONCE_TIMEOUT = 60

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.auth_handler = hs.get_auth_handler()
         self.reactor = hs.get_reactor()
-        self.nonces = {}
+        self.nonces = {}  # type: Dict[str, int]
         self.hs = hs

     def _clear_old_nonces(self):
@@ -362,7 +371,7 @@ class UserRegisterServlet(RestServlet):
             if now - v > self.NONCE_TIMEOUT:
                 del self.nonces[k]

-    def on_GET(self, request):
+    def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         """
         Generate a new nonce.
         """
@@ -372,7 +381,7 @@ class UserRegisterServlet(RestServlet):
         self.nonces[nonce] = int(self.reactor.seconds())
         return 200, {"nonce": nonce}

-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         self._clear_old_nonces()

         if not self.hs.config.registration_shared_secret:
@@ -478,12 +487,14 @@ class WhoisRestServlet(RestServlet):
         client_patterns("/admin" + path_regex, v1=True)
     )

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         target_user = UserID.from_string(user_id)
         requester = await self.auth.get_user_by_req(request)
         auth_user = requester.user
@@ -508,7 +519,9 @@ class DeactivateAccountRestServlet(RestServlet):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()

-    async def on_POST(self, request: str, target_user_id: str) -> Tuple[int, JsonDict]:
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)

@@ -550,7 +563,7 @@ class AccountValidityRenewServlet(RestServlet):
         self.account_activity_handler = hs.get_account_validity_handler()
         self.auth = hs.get_auth()

-    async def on_POST(self, request):
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         body = parse_json_object_from_request(request)
@@ -584,14 +597,16 @@ class ResetPasswordRestServlet(RestServlet):

     PATTERNS = admin_patterns("/reset_password/(?P<target_user_id>[^/]*)")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self._set_password_handler = hs.get_set_password_handler()

-    async def on_POST(self, request, target_user_id):
+    async def on_POST(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, JsonDict]:
         """Post request to allow an administrator reset password for a user.
         This needs user to have administrator access in Synapse.
         """
@@ -626,12 +641,14 @@ class SearchUsersRestServlet(RestServlet):

     PATTERNS = admin_patterns("/search_users/(?P<target_user_id>[^/]*)")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()

-    async def on_GET(self, request, target_user_id):
+    async def on_GET(
+        self, request: SynapseRequest, target_user_id: str
+    ) -> Tuple[int, Optional[List[JsonDict]]]:
         """Get request to search user table for specific users according to
         search term.
         This needs user to have a administrator access in Synapse.
@@ -682,12 +699,14 @@ class UserAdminServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/admin$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         target_user = UserID.from_string(user_id)
@@ -699,7 +718,9 @@ class UserAdminServlet(RestServlet):

         return 200, {"admin": is_admin}

-    async def on_PUT(self, request, user_id):
+    async def on_PUT(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -730,12 +751,14 @@ class UserMembershipRestServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/joined_rooms$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()

-    async def on_GET(self, request, user_id):
+    async def on_GET(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -758,7 +781,7 @@ class PushersRestServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]*)/pushers$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
@@ -799,7 +822,7 @@ class UserMediaRestServlet(RestServlet):

     PATTERNS = admin_patterns("/users/(?P<user_id>[^/]+)/media$")

-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.is_mine = hs.is_mine
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
@@ -891,7 +914,9 @@ class UserTokenRestServlet(RestServlet):
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()

-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         await assert_user_is_admin(self.auth, requester.user)
         auth_user = requester.user
@@ -943,7 +968,9 @@ class ShadowBanRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()

-    async def on_POST(self, request, user_id):
+    async def on_POST(
+        self, request: SynapseRequest, user_id: str
+    ) -> Tuple[int, JsonDict]:
         await assert_requester_is_admin(self.auth, request)

         if not self.hs.is_mine_id(user_id):
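Several hunks above add bare asserts (`assert user is not None`, and later `assert content_type_headers  # for mypy`) after calls whose return type is `Optional`. That is the standard way to narrow `Optional[T]` to `T` for mypy when the surrounding logic guarantees a value is present. A small self-contained sketch of the idiom (hypothetical names, for illustration):

    from typing import Optional


    def get_display_name(user_id: str) -> Optional[str]:
        # Stand-in for a lookup that can, in principle, return None.
        return {"@alice:example.com": "Alice"}.get(user_id)


    def greet(user_id: str) -> str:
        name = get_display_name(user_id)
        # Callers only pass users that are known to exist, so the lookup
        # cannot miss; the assert narrows Optional[str] to str for mypy.
        assert name is not None
        return "Hello, " + name


    print(greet("@alice:example.com"))  # Hello, Alice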
@@ -18,7 +18,7 @@ import logging
 from functools import wraps
 from typing import TYPE_CHECKING, Optional, Tuple

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.constants import (
     MAX_GROUP_CATEGORYID_LENGTH,

@@ -21,7 +21,7 @@ from typing import Awaitable, Dict, Generator, List, Optional, Tuple

 from twisted.internet.interfaces import IConsumer
 from twisted.protocols.basic import FileSender
-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import Codes, SynapseError, cs_error
 from synapse.http.server import finish_request, respond_with_json
@@ -49,18 +49,20 @@ TEXT_CONTENT_TYPES = [

 def parse_media_id(request: Request) -> Tuple[str, str, Optional[str]]:
     try:
+        # The type on postpath seems incorrect in Twisted 21.2.0.
+        postpath = request.postpath  # type: List[bytes]  # type: ignore
+        assert postpath
+
         # This allows users to append e.g. /test.png to the URL. Useful for
         # clients that parse the URL to see content type.
-        server_name, media_id = request.postpath[:2]
-
-        if isinstance(server_name, bytes):
-            server_name = server_name.decode("utf-8")
-            media_id = media_id.decode("utf8")
+        server_name_bytes, media_id_bytes = postpath[:2]
+        server_name = server_name_bytes.decode("utf-8")
+        media_id = media_id_bytes.decode("utf8")

         file_name = None
-        if len(request.postpath) > 2:
+        if len(postpath) > 2:
             try:
-                file_name = urllib.parse.unquote(request.postpath[-1].decode("utf-8"))
+                file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
             except UnicodeDecodeError:
                 pass
         return server_name, media_id, file_name
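`parse_media_id` splits the trailing path segments of a media URL into a server name, a media ID, and an optional file name. A quick standalone re-implementation of the mapping, for illustration only (not the Synapse function itself):

    import urllib.parse
    from typing import List, Optional, Tuple


    def split_media_path(postpath: List[bytes]) -> Tuple[str, str, Optional[str]]:
        # e.g. .../download/example.com/abc123/test.png arrives here
        # as [b"example.com", b"abc123", b"test.png"]
        server_name = postpath[0].decode("utf-8")
        media_id = postpath[1].decode("utf-8")
        file_name = None
        if len(postpath) > 2:
            file_name = urllib.parse.unquote(postpath[-1].decode("utf-8"))
        return server_name, media_id, file_name


    print(split_media_path([b"example.com", b"abc123", b"test.png"]))
    # ('example.com', 'abc123', 'test.png')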
@@ -17,7 +17,7 @@

 from typing import TYPE_CHECKING

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.http.server import DirectServeJsonResource, respond_with_json


@@ -16,7 +16,7 @@
 import logging
 from typing import TYPE_CHECKING

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.http.server import DirectServeJsonResource, set_cors_headers
 from synapse.http.servlet import parse_boolean

@@ -22,8 +22,8 @@ from typing import IO, TYPE_CHECKING, Dict, List, Optional, Set, Tuple

 import twisted.internet.error
 import twisted.web.http
-from twisted.web.http import Request
 from twisted.web.resource import Resource
+from twisted.web.server import Request

 from synapse.api.errors import (
     FederationDeniedError,

@@ -29,7 +29,7 @@ from urllib import parse as urlparse
 import attr

 from twisted.internet.error import DNSLookupError
-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import Codes, SynapseError
 from synapse.http.client import SimpleHttpClient
@@ -149,8 +149,7 @@ class PreviewUrlResource(DirectServeJsonResource):
             treq_args={"browser_like_redirects": True},
             ip_whitelist=hs.config.url_preview_ip_range_whitelist,
             ip_blacklist=hs.config.url_preview_ip_range_blacklist,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+            use_proxy=True,
         )
         self.media_repo = media_repo
         self.primary_base_path = media_repo.primary_base_path

@@ -18,7 +18,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict, List, Optional

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import SynapseError
 from synapse.http.server import DirectServeJsonResource, set_cors_headers

@@ -15,9 +15,9 @@
 # limitations under the License.

 import logging
-from typing import TYPE_CHECKING
+from typing import IO, TYPE_CHECKING

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import Codes, SynapseError
 from synapse.http.server import DirectServeJsonResource, respond_with_json
@@ -79,7 +79,9 @@ class UploadResource(DirectServeJsonResource):
         headers = request.requestHeaders

         if headers.hasHeader(b"Content-Type"):
-            media_type = headers.getRawHeaders(b"Content-Type")[0].decode("ascii")
+            content_type_headers = headers.getRawHeaders(b"Content-Type")
+            assert content_type_headers  # for mypy
+            media_type = content_type_headers[0].decode("ascii")
         else:
             raise SynapseError(msg="Upload request missing 'Content-Type'", code=400)

@@ -88,8 +90,9 @@
         # TODO(markjh): parse content-dispostion

         try:
+            content = request.content  # type: IO  # type: ignore
             content_uri = await self.media_repo.create_content(
-                media_type, upload_name, request.content, content_length, requester.user
+                media_type, upload_name, content, content_length, requester.user
             )
         except SpamMediaException:
             # For uploading of media we want to respond with a 400, instead of

@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Tuple

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import ThreepidValidationError
 from synapse.config.emailconfig import ThreepidBehaviour

@@ -16,8 +16,8 @@
 import logging
 from typing import TYPE_CHECKING, List

-from twisted.web.http import Request
 from twisted.web.resource import Resource
+from twisted.web.server import Request

 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

@@ -16,7 +16,7 @@
 import logging
 from typing import TYPE_CHECKING

-from twisted.web.http import Request
+from twisted.web.server import Request

 from synapse.api.errors import SynapseError
 from synapse.handlers.sso import get_username_mapping_session_cookie_from_request

@@ -24,7 +24,6 @@
 import abc
 import functools
 import logging
-import os
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -39,6 +38,7 @@ from typing import (

 import twisted.internet.base
 import twisted.internet.tcp
+from twisted.internet import defer
 from twisted.mail.smtp import sendmail
 from twisted.web.iweb import IPolicyForHTTPS

@@ -370,11 +370,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         """
        An HTTP client that uses configured HTTP(S) proxies.
         """
-        return SimpleHttpClient(
-            self,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
-        )
+        return SimpleHttpClient(self, use_proxy=True)

     @cache_in_self
     def get_proxied_blacklisted_http_client(self) -> SimpleHttpClient:
@@ -386,8 +382,7 @@ class HomeServer(metaclass=abc.ABCMeta):
             self,
             ip_whitelist=self.config.ip_range_whitelist,
             ip_blacklist=self.config.ip_range_blacklist,
-            http_proxy=os.getenvb(b"http_proxy"),
-            https_proxy=os.getenvb(b"HTTPS_PROXY"),
+            use_proxy=True,
         )

     @cache_in_self
@@ -409,7 +404,7 @@ class HomeServer(metaclass=abc.ABCMeta):
         return RoomShutdownHandler(self)

     @cache_in_self
-    def get_sendmail(self) -> sendmail:
+    def get_sendmail(self) -> Callable[..., defer.Deferred]:
         return sendmail

     @cache_in_self
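These hunks replace per-call-site `http_proxy`/`https_proxy` plumbing with a single `use_proxy=True` flag, so proxy selection from the environment happens in one place inside the HTTP client. A rough sketch of the environment-driven selection this centralises (simplified illustration, not Synapse's actual `ProxyAgent`; the exact precedence order here is an assumption):

    import os
    from typing import Optional


    def getenv_any(*names: str) -> Optional[str]:
        """Return the first environment variable that is set.

        Assumption for this sketch: the lowercase form is consulted first.
        """
        for name in names:
            value = os.environ.get(name)
            if value:
                return value
        return None


    def pick_proxy(scheme: str) -> Optional[str]:
        # use_proxy=True roughly means: consult the process environment
        # instead of passing proxy URLs at every construction site.
        if scheme == "https":
            return getenv_any("https_proxy", "HTTPS_PROXY")
        return getenv_any("http_proxy")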
@@ -16,7 +16,7 @@
 # limitations under the License.

 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import List, Optional, Tuple

 from synapse.api.constants import PresenceState
 from synapse.config.homeserver import HomeServerConfig
@@ -27,7 +27,7 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache

 from .account_data import AccountDataStore
@@ -264,7 +264,7 @@ class DataStore(

         return [UserPresenceState(**row) for row in rows]

-    async def get_users(self) -> List[Dict[str, Any]]:
+    async def get_users(self) -> List[JsonDict]:
         """Function to retrieve a list of users in users table.

         Returns:
@@ -292,7 +292,7 @@ class DataStore(
         name: Optional[str] = None,
         guests: bool = True,
         deactivated: bool = False,
-    ) -> Tuple[List[Dict[str, Any]], int]:
+    ) -> Tuple[List[JsonDict], int]:
         """Function to retrieve a paginated list of users from
         users list. This will return a json list of users and the
         total number of users matching the filter criteria.
@@ -353,7 +353,7 @@ class DataStore(
             "get_users_paginate_txn", get_users_paginate_txn
         )

-    async def search_users(self, term: str) -> Optional[List[Dict[str, Any]]]:
+    async def search_users(self, term: str) -> Optional[List[JsonDict]]:
         """Function to search users list for one or more users with
         the matched term.


@@ -27,7 +27,7 @@ logger = logging.getLogger(__name__)
 # Number of msec of granularity to store the user IP 'last seen' time. Smaller
 # times give more inserts into the database even for readonly API hits
 # 120 seconds == 2 minutes
-LAST_SEEN_GRANULARITY = 10 * 60 * 1000
+LAST_SEEN_GRANULARITY = 120 * 1000


 class ClientIpBackgroundUpdateStore(SQLBaseStore):
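Restoring `LAST_SEEN_GRANULARITY` to two minutes bounds how often a busy client's "last seen" time is rewritten: an update is only worth issuing once the stored value is at least one granularity interval old. A minimal sketch of that rate-limiting check (hypothetical helper, for illustration):

    LAST_SEEN_GRANULARITY = 120 * 1000  # two minutes, in milliseconds


    def should_update_last_seen(last_seen_ms: int, now_ms: int) -> bool:
        # Skip the database write unless the stored timestamp is stale by
        # at least one granularity interval; readonly API hits then cause
        # at most one insert per client per two minutes.
        return now_ms - last_seen_ms >= LAST_SEEN_GRANULARITY


    assert should_update_last_seen(0, 120_000) is True
    assert should_update_last_seen(0, 60_000) is False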
@@ -696,7 +696,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )

             if not has_event_auth:
-                for auth_id in event.auth_event_ids():
+                # Old, dodgy, events may have duplicate auth events, which we
+                # need to deduplicate as we have a unique constraint.
+                for auth_id in set(event.auth_event_ids()):
                     auth_events.append(
                         {
                             "room_id": event.room_id,
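Wrapping `event.auth_event_ids()` in `set()` drops duplicate IDs before rows are built, since inserting the same pair twice would violate the table's unique constraint. The effect in isolation:

    auth_event_ids = ["$a", "$b", "$a"]  # old, dodgy events can repeat an ID

    rows = [{"auth_id": auth_id} for auth_id in set(auth_event_ids)]
    assert len(rows) == 2  # one row per distinct auth event, insert-safe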
@@ -139,7 +139,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         start: int,
         limit: int,
         user_id: str,
-        order_by: MediaSortOrder = MediaSortOrder.CREATED_TS.value,
+        order_by: str = MediaSortOrder.CREATED_TS.value,
         direction: str = "f",
     ) -> Tuple[List[Dict[str, Any]], int]:
         """Get a paginated list of metadata for a local piece of media

@@ -28,7 +28,10 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
     async def purge_history(
         self, room_id: str, token: str, delete_local_events: bool
     ) -> Set[int]:
-        """Deletes room history before a certain point
+        """Deletes room history before a certain point.
+
+        Note that only a single purge can occur at once, this is guaranteed via
+        a higher level (in the PaginationHandler).

         Args:
             room_id:
@@ -52,7 +55,9 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             delete_local_events,
         )

-    def _purge_history_txn(self, txn, room_id, token, delete_local_events):
+    def _purge_history_txn(
+        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+    ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
         #     event_backward_extremities
@@ -103,7 +108,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         if max_depth < token.topological:
             # We need to ensure we don't delete all the events from the database
             # otherwise we wouldn't be able to send any events (due to not
-            # having any backwards extremeties)
+            # having any backwards extremities)
             raise SynapseError(
                 400, "topological_ordering is greater than forward extremeties"
             )
@@ -154,7 +159,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):

         logger.info("[purge] Finding new backward extremities")

-        # We calculate the new entries for the backward extremeties by finding
+        # We calculate the new entries for the backward extremities by finding
         # events to be purged that are pointed to by events we're not going to
         # purge.
         txn.execute(
@@ -296,7 +301,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "purge_room", self._purge_room_txn, room_id
         )

-    def _purge_room_txn(self, txn, room_id):
+    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
@@ -310,6 +315,31 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):

         state_groups = [row[0] for row in txn]

+        # Get all the auth chains that are referenced by events that are to be
+        # deleted.
+        txn.execute(
+            """
+            SELECT chain_id, sequence_number FROM events
+            LEFT JOIN event_auth_chains USING (event_id)
+            WHERE room_id = ?
+            """,
+            (room_id,),
+        )
+        referenced_chain_id_tuples = list(txn)
+
+        logger.info("[purge] removing events from event_auth_chain_links")
+        txn.executemany(
+            """
+            DELETE FROM event_auth_chain_links WHERE
+            (origin_chain_id = ? AND origin_sequence_number = ?) OR
+            (target_chain_id = ? AND target_sequence_number = ?)
+            """,
+            (
+                (chain_id, seq_num, chain_id, seq_num)
+                for (chain_id, seq_num) in referenced_chain_id_tuples
+            ),
+        )
+
         # Now we delete tables which lack an index on room_id but have one on event_id
         for table in (
             "event_auth",
@@ -319,6 +349,8 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
             "event_reference_hashes",
             "event_relations",
             "event_to_state_groups",
+            "event_auth_chains",
+            "event_auth_chain_to_calculate",
             "redactions",
             "rejections",
             "state_events",

@@ -39,6 +39,16 @@ class PusherWorkerStore(SQLBaseStore):
             db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
         )

+        self.db_pool.updates.register_background_update_handler(
+            "remove_deactivated_pushers",
+            self._remove_deactivated_pushers,
+        )
+
+        self.db_pool.updates.register_background_update_handler(
+            "remove_stale_pushers",
+            self._remove_stale_pushers,
+        )
+
     def _decode_pushers_rows(self, rows: Iterable[dict]) -> Iterator[PusherConfig]:
         """JSON-decode the data in the rows returned from the `pushers` table

@@ -284,6 +294,101 @@ class PusherWorkerStore(SQLBaseStore):
             lock=False,
         )

+    async def _remove_deactivated_pushers(self, progress: dict, batch_size: int) -> int:
+        """A background update that deletes all pushers for deactivated users.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+        """
+
+        last_user = progress.get("last_user", "")
+
+        def _delete_pushers(txn) -> int:
+
+            sql = """
+                SELECT name FROM users
+                WHERE deactivated = ? and name > ?
+                ORDER BY name ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (1, last_user, batch_size))
+            users = [row[0] for row in txn]
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="pushers",
+                column="user_name",
+                iterable=users,
+                keyvalues={},
+            )
+
+            if users:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_deactivated_pushers", {"last_user": users[-1]}
+                )
+
+            return len(users)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_deactivated_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update(
+                "remove_deactivated_pushers"
+            )
+
+        return number_deleted
+
+    async def _remove_stale_pushers(self, progress: dict, batch_size: int) -> int:
+        """A background update that deletes all pushers for logged out devices.
+
+        Note that we don't proactively tell the pusherpool that we've deleted
+        these (just because it's a bit of a faff to do from here), but they will
+        get cleaned up at the next restart.
+        """
+
+        last_pusher = progress.get("last_pusher", 0)
+
+        def _delete_pushers(txn) -> int:
+
+            sql = """
+                SELECT p.id, access_token FROM pushers AS p
+                LEFT JOIN access_tokens AS a ON (p.access_token = a.id)
+                WHERE p.id > ?
+                ORDER BY p.id ASC
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_pusher, batch_size))
+            pushers = [(row[0], row[1]) for row in txn]
+
+            self.db_pool.simple_delete_many_txn(
+                txn,
+                table="pushers",
+                column="id",
+                iterable=(pusher_id for pusher_id, token in pushers if token is None),
+                keyvalues={},
+            )
+
+            if pushers:
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "remove_stale_pushers", {"last_pusher": pushers[-1][0]}
+                )
+
+            return len(pushers)
+
+        number_deleted = await self.db_pool.runInteraction(
+            "_remove_stale_pushers", _delete_pushers
+        )
+
+        if number_deleted < batch_size:
+            await self.db_pool.updates._end_background_update("remove_stale_pushers")
+
+        return number_deleted
+

 class PusherStore(PusherWorkerStore):
     def get_pushers_stream_token(self) -> int:
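Both new handlers follow Synapse's standard batched background-update shape: read a progress cursor, process up to `batch_size` rows past it, persist the new cursor, and end the update once a short batch signals completion. A stripped-down sketch of that control flow (hypothetical update with synchronous stand-ins for the `db_pool` helpers):

    from typing import Dict, List

    PROGRESS: Dict[str, int] = {}          # stands in for the background_updates row
    ROWS: List[int] = list(range(1, 26))   # stands in for the table being processed


    def run_batch(batch_size: int) -> int:
        last_id = PROGRESS.get("last_id", 0)
        batch = [r for r in ROWS if r > last_id][:batch_size]
        # ... delete/transform the batch here ...
        if batch:
            PROGRESS["last_id"] = batch[-1]  # persist the cursor for the next run
        return len(batch)


    while run_batch(10) == 10:  # a short batch means we've caught up
        pass
    print("background update complete at", PROGRESS["last_id"])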
@@ -14,8 +14,7 @@
 */


--- We may not have deleted all pushers for deactivated accounts. Do so now.
---
--- Note: We don't bother updating the `deleted_pushers` table as it's just use
--- to stop pushers on workers, and that will happen when they get next restarted.
-DELETE FROM pushers WHERE user_name IN (SELECT name FROM users WHERE deactivated = 1);
+-- We may not have deleted all pushers for deactivated accounts, so we set up a
+-- background job to delete them.
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5908, 'remove_deactivated_pushers', '{}');

@@ -16,4 +16,5 @@

 -- Delete all pushers associated with deleted devices. This is to clear up after
 -- a bug where they weren't correctly deleted when using workers.
-DELETE FROM pushers WHERE access_token NOT IN (SELECT id FROM access_tokens);
+INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
+  (5908, 'remove_stale_pushers', '{}');

@@ -13,5 +13,14 @@
  * limitations under the License.
  */

+-- This originally was in 58/, but landed after 59/ was created, and so some
+-- servers running develop didn't run this delta. Running it again should be
+-- safe.
+--
+-- We first delete any in progress `rejected_events_metadata` background update,
+-- to ensure that we don't conflict when trying to insert the new one. (We could
+-- alternatively do an ON CONFLICT DO NOTHING, but that syntax isn't supported
+-- by older SQLite versions. Plus, this should be a rare case).
+DELETE FROM background_updates WHERE update_name = 'rejected_events_metadata';
 INSERT INTO background_updates (ordering, update_name, progress_json) VALUES
   (5828, 'rejected_events_metadata', '{}');
@@ -707,7 +707,7 @@ def _parse_query(database_engine, search_term):
     results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)

     if isinstance(database_engine, PostgresEngine):
-        return " & ".join(result for result in results)
+        return " & ".join(result + ":*" for result in results)
     elif isinstance(database_engine, Sqlite3Engine):
         return " & ".join(result + "*" for result in results)
     else:
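The `_parse_query` change makes the Postgres full-text branch match on word prefixes (`:*`) the way the SQLite branch already did with `*`. A standalone illustration of the tokenisation, mirroring the function's regex (the comments show what the expressions evaluate to):

    import re


    def to_tsquery(search_term: str, postgres: bool) -> str:
        results = re.findall(r"([\w\-]+)", search_term, re.UNICODE)
        if postgres:
            return " & ".join(result + ":*" for result in results)
        return " & ".join(result + "*" for result in results)


    print(to_tsquery("alice wonder", postgres=True))   # alice:* & wonder:*
    print(to_tsquery("alice wonder", postgres=False))  # alice* & wonder*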
@@ -73,9 +73,6 @@ class PurgeEventsStorage:
         Returns:
             The set of state groups that can be deleted.
         """
-        # Graph of state group -> previous group
-        graph = {}
-
         # Set of events that we have found to be referenced by events
         referenced_groups = set()

@@ -111,8 +108,6 @@ class PurgeEventsStorage:
             next_to_search |= prevs
             state_groups_seen |= prevs

-            graph.update(edges)
-
         to_delete = state_groups_seen - referenced_groups

         return to_delete

@@ -25,7 +25,7 @@ RoomsForUser = namedtuple(
 )

 GetRoomsForUserWithStreamOrdering = namedtuple(
-    "_GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
+    "GetRoomsForUserWithStreamOrdering", ("room_id", "event_pos")
 )


@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Set, TypeVar
+from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, TypeVar

 from twisted.internet import defer

@@ -40,7 +40,6 @@ class ResponseCache(Generic[T]):
     def __init__(self, hs: "HomeServer", name: str, timeout_ms: float = 0):
         # Requests that haven't finished yet.
         self.pending_result_cache = {}  # type: Dict[T, ObservableDeferred]
-        self.pending_conditionals = {}  # type: Dict[T, Set[Callable[[Any], bool]]]

         self.clock = hs.get_clock()
         self.timeout_sec = timeout_ms / 1000.0
@@ -102,11 +101,7 @@ class ResponseCache(Generic[T]):
         self.pending_result_cache[key] = result

         def remove(r):
-            should_cache = all(
-                func(r) for func in self.pending_conditionals.pop(key, [])
-            )
-
-            if self.timeout_sec and should_cache:
+            if self.timeout_sec:
                 self.clock.call_later(
                     self.timeout_sec, self.pending_result_cache.pop, key, None
                 )
@@ -117,31 +112,6 @@ class ResponseCache(Generic[T]):
         result.addBoth(remove)
         return result.observe()

-    def add_conditional(self, key: T, conditional: Callable[[Any], bool]):
-        self.pending_conditionals.setdefault(key, set()).add(conditional)
-
-    def wrap_conditional(
-        self,
-        key: T,
-        should_cache: Callable[[Any], bool],
-        callback: "Callable[..., Any]",
-        *args: Any,
-        **kwargs: Any
-    ) -> defer.Deferred:
-        """The same as wrap(), but adds a conditional to the final execution.
-
-        When the final execution completes, *all* conditionals need to return True for it to properly cache,
-        else it'll not be cached in a timed fashion.
-        """
-
-        # See if there's already a result on this key that hasn't yet completed. Due to the single-threaded nature of
-        # python, adding a key immediately in the same execution thread will not cause a race condition.
-        result = self.get(key)
-        if not result or isinstance(result, defer.Deferred) and not result.called:
-            self.add_conditional(key, should_cache)
-
-        return self.wrap(key, callback, *args, **kwargs)
-
     def wrap(
         self, key: T, callback: "Callable[..., Any]", *args: Any, **kwargs: Any
     ) -> defer.Deferred:
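With `wrap_conditional` removed, `ResponseCache.wrap` is again the single entry point: concurrent callers with the same key share one in-flight result. A rough sketch of that deduplication idea, using asyncio rather than Twisted Deferreds for brevity (simplified illustration, not the Synapse class; the timed retention after completion is omitted):

    import asyncio
    from typing import Any, Awaitable, Callable, Dict


    class TinyResponseCache:
        """Share one in-flight computation per key (asyncio sketch)."""

        def __init__(self) -> None:
            self._pending: Dict[str, asyncio.Task] = {}

        async def wrap(self, key: str, callback: Callable[[], Awaitable[Any]]) -> Any:
            task = self._pending.get(key)
            if task is None:
                # First caller for this key actually runs the callback ...
                task = asyncio.ensure_future(callback())
                self._pending[key] = task
                task.add_done_callback(lambda _: self._pending.pop(key, None))
            # ... later callers just await the same task.
            return await task


    async def main() -> None:
        cache = TinyResponseCache()

        async def expensive() -> int:
            await asyncio.sleep(0.01)
            return 42

        a, b = await asyncio.gather(cache.wrap("k", expensive), cache.wrap("k", expensive))
        assert a == b == 42


    asyncio.run(main())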
synctl
@@ -30,7 +30,7 @@ import yaml

 from synapse.config import find_config_files

-SYNAPSE = [sys.executable, "-B", "-m", "synapse.app.homeserver"]
+SYNAPSE = [sys.executable, "-m", "synapse.app.homeserver"]

 GREEN = "\x1b[1;32m"
 YELLOW = "\x1b[1;33m"
@@ -117,7 +117,6 @@ def start_worker(app: str, configfile: str, worker_configfile: str) -> bool:

     args = [
         sys.executable,
-        "-B",
         "-m",
         app,
         "-c",

@@ -26,77 +26,96 @@ from tests.unittest import TestCase


 class ReadBodyWithMaxSizeTests(TestCase):
-    def setUp(self):
-        response = Mock(length=UNKNOWN_LENGTH)
-        self.result = BytesIO()
-        self.deferred = read_body_with_max_size(response, self.result, 6)
+    def _build_response(self, length=UNKNOWN_LENGTH):
+        """Start reading the body, returns the response, result and proto"""
+        response = Mock(length=length)
+        result = BytesIO()
+        deferred = read_body_with_max_size(response, result, 6)

         # Fish the protocol out of the response.
-        self.protocol = response.deliverBody.call_args[0][0]
-        self.protocol.transport = Mock()
+        protocol = response.deliverBody.call_args[0][0]
+        protocol.transport = Mock()

-    def _cleanup_error(self):
+        return result, deferred, protocol
+
+    def _assert_error(self, deferred, protocol):
+        """Ensure that the expected error is received."""
+        self.assertIsInstance(deferred.result, Failure)
+        self.assertIsInstance(deferred.result.value, BodyExceededMaxSize)
+        protocol.transport.abortConnection.assert_called_once()
+
+    def _cleanup_error(self, deferred):
         """Ensure that the error in the Deferred is handled gracefully."""
         called = [False]

         def errback(f):
             called[0] = True

-        self.deferred.addErrback(errback)
+        deferred.addErrback(errback)
         self.assertTrue(called[0])

     def test_no_error(self):
         """A response that is NOT too large."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"12345")
+        protocol.dataReceived(b"12345")
         # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"12345")
-        self.assertEqual(self.deferred.result, 5)
+        self.assertEqual(result.getvalue(), b"12345")
+        self.assertEqual(deferred.result, 5)

     def test_too_large(self):
         """A response which is too large raises an exception."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"1234567890")
-        # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.dataReceived(b"1234567890")

-        self.assertEqual(self.result.getvalue(), b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self._cleanup_error()
+        self.assertEqual(result.getvalue(), b"1234567890")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)

     def test_multiple_packets(self):
-        """Data should be accummulated through mutliple packets."""
+        """Data should be accumulated through multiple packets."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"12")
-        self.protocol.dataReceived(b"34")
+        protocol.dataReceived(b"12")
+        protocol.dataReceived(b"34")
         # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.connectionLost(Failure(ResponseDone()))

-        self.assertEqual(self.result.getvalue(), b"1234")
-        self.assertEqual(self.deferred.result, 4)
+        self.assertEqual(result.getvalue(), b"1234")
+        self.assertEqual(deferred.result, 4)

     def test_additional_data(self):
         """A connection can receive data after being closed."""
+        result, deferred, protocol = self._build_response()

         # Start sending data.
-        self.protocol.dataReceived(b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self.protocol.transport.abortConnection.assert_called_once()
+        protocol.dataReceived(b"1234567890")
+        self._assert_error(deferred, protocol)

         # More data might have come in.
-        self.protocol.dataReceived(b"1234567890")
-        # Close the connection.
-        self.protocol.connectionLost(Failure(ResponseDone()))
+        protocol.dataReceived(b"1234567890")

-        self.assertEqual(self.result.getvalue(), b"1234567890")
-        self.assertIsInstance(self.deferred.result, Failure)
-        self.assertIsInstance(self.deferred.result.value, BodyExceededMaxSize)
-        self._cleanup_error()
+        self.assertEqual(result.getvalue(), b"1234567890")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)

+    def test_content_length(self):
+        """The body shouldn't be read (at all) if the Content-Length header is too large."""
+        result, deferred, protocol = self._build_response(length=10)
+
+        # Deferred shouldn't be called yet.
+        self.assertFalse(deferred.called)
+
+        # Start sending data.
+        protocol.dataReceived(b"12345")
+        self._assert_error(deferred, protocol)
+        self._cleanup_error(deferred)
+
+        # The data is never consumed.
+        self.assertEqual(result.getvalue(), b"")
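These tests drive `read_body_with_max_size(response, output_stream, max_size)` directly: it streams a Twisted response body into a stream, fires its Deferred with the byte count on success, and errbacks with `BodyExceededMaxSize` (aborting the connection) as soon as either the received data or the advertised Content-Length exceeds the cap. A hedged sketch of calling it from client code (names taken from the tests; the import path and the async/await context are assumptions):

    from io import BytesIO

    from synapse.http.client import BodyExceededMaxSize, read_body_with_max_size


    async def fetch_limited(agent, url: bytes, max_size: int) -> bytes:
        """Download at most max_size bytes of a response body, else fail.

        Sketch only: assumes a Twisted agent and an await-friendly
        reactor context, as in Synapse's own client code.
        """
        response = await agent.request(b"GET", url)
        buffer = BytesIO()
        try:
            await read_body_with_max_size(response, buffer, max_size)
        except BodyExceededMaxSize:
            raise ValueError("response larger than %d bytes" % max_size)
        return buffer.getvalue()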
@@ -13,6 +13,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+import os
+from unittest.mock import patch

 import treq
 from netaddr import IPSet
@@ -100,62 +102,36 @@ class MatrixFederationAgentTests(TestCase):

         return http_protocol

-    def test_http_request(self):
-        agent = ProxyAgent(self.reactor)
-
-        self.reactor.lookups["test.com"] = "1.2.3.4"
-        d = agent.request(b"GET", b"http://test.com")
+    def _test_request_direct_connection(self, agent, scheme, hostname, path):
+        """Runs a test case for a direct connection not going through a proxy.
+
+        Args:
+            agent (ProxyAgent): the proxy agent being tested
+
+            scheme (bytes): expected to be either "http" or "https"
+
+            hostname (bytes): the hostname to connect to in the test
+
+            path (bytes): the path to connect to in the test
+        """
+        is_https = scheme == b"https"
+
+        self.reactor.lookups[hostname.decode()] = "1.2.3.4"
+        d = agent.request(b"GET", scheme + b"://" + hostname + b"/" + path)

         # there should be a pending TCP connection
         clients = self.reactor.tcpClients
         self.assertEqual(len(clients), 1)
         (host, port, client_factory, _timeout, _bindAddress) = clients[0]
         self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 80)
-
-        # make a test server, and wire up the client
-        http_server = self._make_connection(
-            client_factory, _get_test_protocol_factory()
-        )
-
-        # the FakeTransport is async, so we need to pump the reactor
-        self.reactor.advance(0)
-
-        # now there should be a pending request
-        self.assertEqual(len(http_server.requests), 1)
-
-        request = http_server.requests[0]
-        self.assertEqual(request.method, b"GET")
-        self.assertEqual(request.path, b"/")
-        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
-        request.write(b"result")
-        request.finish()
-
-        self.reactor.advance(0)
-
-        resp = self.successResultOf(d)
-        body = self.successResultOf(treq.content(resp))
-        self.assertEqual(body, b"result")
-
-    def test_https_request(self):
-        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
-
-        self.reactor.lookups["test.com"] = "1.2.3.4"
-        d = agent.request(b"GET", b"https://test.com/abc")
-
-        # there should be a pending TCP connection
-        clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients[0]
-        self.assertEqual(host, "1.2.3.4")
-        self.assertEqual(port, 443)
+        self.assertEqual(port, 443 if is_https else 80)

         # make a test server, and wire up the client
         http_server = self._make_connection(
             client_factory,
             _get_test_protocol_factory(),
-            ssl=True,
-            expected_sni=b"test.com",
+            ssl=is_https,
+            expected_sni=hostname if is_https else None,
         )

         # the FakeTransport is async, so we need to pump the reactor
@@ -166,8 +142,8 @@ class MatrixFederationAgentTests(TestCase):

         request = http_server.requests[0]
         self.assertEqual(request.method, b"GET")
-        self.assertEqual(request.path, b"/abc")
-        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [b"test.com"])
+        self.assertEqual(request.path, b"/" + path)
+        self.assertEqual(request.requestHeaders.getRawHeaders(b"host"), [hostname])
         request.write(b"result")
         request.finish()

@@ -177,8 +153,58 @@ class MatrixFederationAgentTests(TestCase):
         body = self.successResultOf(treq.content(resp))
         self.assertEqual(body, b"result")

+    def test_http_request(self):
+        agent = ProxyAgent(self.reactor)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    def test_https_request(self):
+        agent = ProxyAgent(self.reactor, contextFactory=get_test_https_policy())
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+    def test_http_request_use_proxy_empty_environment(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "NO_PROXY": "test.com"})
+    def test_http_request_via_uppercase_no_proxy(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(
+        os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "test.com,unused.com"}
+    )
+    def test_http_request_via_no_proxy(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(
+        os.environ, {"https_proxy": "proxy.com", "no_proxy": "test.com,unused.com"}
+    )
+    def test_https_request_via_no_proxy(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            use_proxy=True,
+        )
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
+
+    @patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "*"})
+    def test_http_request_via_no_proxy_star(self):
+        agent = ProxyAgent(self.reactor, use_proxy=True)
+        self._test_request_direct_connection(agent, b"http", b"test.com", b"")
+
+    @patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "*"})
+    def test_https_request_via_no_proxy_star(self):
+        agent = ProxyAgent(
+            self.reactor,
+            contextFactory=get_test_https_policy(),
+            use_proxy=True,
+        )
+        self._test_request_direct_connection(agent, b"https", b"test.com", b"abc")
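The new tests pin down the `no_proxy`/`NO_PROXY` bypass behaviour: a `*` wildcard or a matching hostname in the comma-separated list forces a direct connection even when a proxy is configured. A small sketch of that matching rule using the stdlib helpers (illustrative only; whether Synapse's `ProxyAgent` delegates to exactly these functions is an assumption):

    from urllib.request import getproxies_environment, proxy_bypass_environment


    def proxy_for(hostname: str, scheme: str):
        """Return the proxy to use for hostname, or None for a direct hit.

        The stdlib helpers honour no_proxy/NO_PROXY, including "*".
        """
        if proxy_bypass_environment(hostname):
            return None
        return getproxies_environment().get(scheme)


    # With http_proxy=proxy.com:8888 and no_proxy=test.com in the environment:
    #   proxy_for("test.com", "http")  -> None (direct connection)
    #   proxy_for("other.com", "http") -> "proxy.com:8888"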
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888", "no_proxy": "unused.com"})
|
||||
def test_http_request_via_proxy(self):
|
||||
agent = ProxyAgent(self.reactor, http_proxy=b"proxy.com:8888")
|
||||
agent = ProxyAgent(self.reactor, use_proxy=True)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
d = agent.request(b"GET", b"http://test.com")
|
||||
@@ -214,11 +240,12 @@ class MatrixFederationAgentTests(TestCase):
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
@patch.dict(os.environ, {"https_proxy": "proxy.com", "no_proxy": "unused.com"})
|
||||
def test_https_request_via_proxy(self):
|
||||
agent = ProxyAgent(
|
||||
self.reactor,
|
||||
contextFactory=get_test_https_policy(),
|
||||
https_proxy=b"proxy.com",
|
||||
use_proxy=True,
|
||||
)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
@@ -294,6 +321,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
@patch.dict(os.environ, {"http_proxy": "proxy.com:8888"})
|
||||
def test_http_request_via_proxy_with_blacklist(self):
|
||||
# The blacklist includes the configured proxy IP.
|
||||
agent = ProxyAgent(
|
||||
@@ -301,7 +329,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||
self.reactor, ip_whitelist=None, ip_blacklist=IPSet(["1.0.0.0/8"])
|
||||
),
|
||||
self.reactor,
|
||||
http_proxy=b"proxy.com:8888",
|
||||
use_proxy=True,
|
||||
)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
@@ -338,7 +366,8 @@ class MatrixFederationAgentTests(TestCase):
|
||||
body = self.successResultOf(treq.content(resp))
|
||||
self.assertEqual(body, b"result")
|
||||
|
||||
def test_https_request_via_proxy_with_blacklist(self):
|
||||
@patch.dict(os.environ, {"HTTPS_PROXY": "proxy.com"})
|
||||
def test_https_request_via_uppercase_proxy_with_blacklist(self):
|
||||
# The blacklist includes the configured proxy IP.
|
||||
agent = ProxyAgent(
|
||||
BlacklistingReactorWrapper(
|
||||
@@ -346,7 +375,7 @@ class MatrixFederationAgentTests(TestCase):
|
||||
),
|
||||
self.reactor,
|
||||
contextFactory=get_test_https_policy(),
|
||||
https_proxy=b"proxy.com",
|
||||
use_proxy=True,
|
||||
)
|
||||
|
||||
self.reactor.lookups["proxy.com"] = "1.2.3.5"
|
||||
|
||||
tests/rest/client/v1/test_login.py
@@ -522,7 +522,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             shorthand=False,
         )
         self.assertEqual(channel.code, 302, channel.result)
-        cas_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        cas_uri = location_headers[0]
         cas_uri_path, cas_uri_query = cas_uri.split("?", 1)

         # it should redirect us to the login page of the cas server
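The recurring change in this file replaces direct indexing of `getRawHeaders(...)[0]` with an intermediate variable plus `assert`. Twisted's `Headers.getRawHeaders` returns `None` when the header is absent, so the `assert` both fails loudly in the test and narrows the `Optional` type for mypy. A self-contained sketch of the pattern; `fake_get_raw_headers` is a hypothetical stand-in for Twisted's API:

```python
from typing import List, Optional

def fake_get_raw_headers(name: str) -> Optional[List[str]]:
    # stand-in for twisted.web.http_headers.Headers.getRawHeaders,
    # which returns None when the header was never set
    headers = {"location": ["https://cas.example.com/login?service=x"]}
    return headers.get(name.lower())

location_headers = fake_get_raw_headers("Location")
assert location_headers  # an AssertionError here beats a TypeError on None[0]
cas_uri = location_headers[0]
print(cas_uri.split("?", 1))  # ['https://cas.example.com/login', 'service=x']
```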
@@ -545,7 +547,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=saml",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        saml_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        saml_uri = location_headers[0]
         saml_uri_path, saml_uri_query = saml_uri.split("?", 1)

         # it should redirect us to the login page of the SAML server
@@ -567,17 +571,21 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=oidc",
         )
         self.assertEqual(channel.code, 302, channel.result)
-        oidc_uri = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        oidc_uri = location_headers[0]
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)

         # it should redirect us to the auth page of the OIDC server
         self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)

         # ... and should have set a cookie including the redirect url
-        cookies = dict(
-            h.split(";")[0].split("=", maxsplit=1)
-            for h in channel.headers.getRawHeaders("Set-Cookie")
-        )
+        cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
+        assert cookie_headers
+        cookies = {}  # type: Dict[str, str]
+        for h in cookie_headers:
+            key, value = h.split(";")[0].split("=", maxsplit=1)
+            cookies[key] = value

         oidc_session_cookie = cookies["oidc_session"]
         macaroon = pymacaroons.Macaroon.deserialize(oidc_session_cookie)
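The rewritten cookie handling is equivalent to the old dict comprehension, but it checks for a missing `Set-Cookie` header first and is easier to type-annotate. Run on sample header values (the cookie names and attributes below are made up for illustration), the loop behaves like this:

```python
# Demonstration of the Set-Cookie parsing loop above on sample values.
cookie_headers = [
    "oidc_session=abc123; Path=/_synapse/client/oidc; HttpOnly; Secure",
    "another=value; Path=/",
]
cookies = {}  # type: dict
for h in cookie_headers:
    # keep only the name=value pair before the first attribute
    key, value = h.split(";")[0].split("=", maxsplit=1)
    cookies[key] = value
print(cookies)  # {'oidc_session': 'abc123', 'another': 'value'}
```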
@@ -590,9 +598,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):

         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
-        self.assertTrue(
-            channel.headers.getRawHeaders("Content-Type")[-1].startswith("text/html")
-        )
+        content_type_headers = channel.headers.getRawHeaders("Content-Type")
+        assert content_type_headers
+        self.assertTrue(content_type_headers[-1].startswith("text/html"))
         p = TestHtmlParser()
         p.feed(channel.text_body)
         p.close()
@@ -806,6 +814,7 @@ class CASTestCase(unittest.HomeserverTestCase):

         self.assertEqual(channel.code, 302)
         location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
         self.assertEqual(location_headers[0][: len(redirect_url)], redirect_url)

     @override_config({"sso": {"client_whitelist": ["https://legit-site.com/"]}})
@@ -1248,7 +1257,9 @@ class UsernamePickerTestCase(HomeserverTestCase):

         # that should redirect to the username picker
         self.assertEqual(channel.code, 302, channel.result)
-        picker_url = channel.headers.getRawHeaders("Location")[0]
+        location_headers = channel.headers.getRawHeaders("Location")
+        assert location_headers
+        picker_url = location_headers[0]
         self.assertEqual(picker_url, "/_synapse/client/pick_username/account_details")

         # ... with a username_mapping_session cookie
@@ -1291,6 +1302,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers

         # send a request to the completion page, which should 302 to the client redirectUrl
         chan = self.make_request(
@@ -1300,6 +1312,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
         self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
+        assert location_headers

         # ensure that the returned location matches the requested redirect URL
         path, query = location_headers[0].split("?", 1)