Compare commits
30 Commits
matrix-org
...
matthew/e2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f27afe1183 | ||
|
|
c199970627 | ||
|
|
e3bc9f2904 | ||
|
|
4cd381d96d | ||
|
|
1976540db7 | ||
|
|
ebfaeac6c1 | ||
|
|
10cec6f452 | ||
|
|
4fdd02fa24 | ||
|
|
843ce19127 | ||
|
|
7ab77f6502 | ||
|
|
3a8d9ab16f | ||
|
|
ef75daf286 | ||
|
|
eb1a429b74 | ||
|
|
68e9cfedf3 | ||
|
|
1aa87fcca5 | ||
|
|
fecfa638a6 | ||
|
|
98fcffdac4 | ||
|
|
a22968d642 | ||
|
|
dee6849f50 | ||
|
|
a658f4e238 | ||
|
|
ace298f629 | ||
|
|
59aa35b2c1 | ||
|
|
356829ca16 | ||
|
|
3e98a00658 | ||
|
|
124efee9ea | ||
|
|
552934ebac | ||
|
|
602b647503 | ||
|
|
7983840682 | ||
|
|
e558f213e9 | ||
|
|
824b3393cf |
2
.gitignore
vendored
2
.gitignore
vendored
@@ -46,5 +46,3 @@ static/client/register/register_config.js
|
||||
|
||||
env/
|
||||
*.config
|
||||
|
||||
.vscode/
|
||||
|
||||
53
CHANGES.rst
53
CHANGES.rst
@@ -1,56 +1,3 @@
|
||||
Unreleased
|
||||
==========
|
||||
|
||||
synctl no longer starts the main synapse when using ``-a`` option with workers.
|
||||
A new worker file should be added with ``worker_app: synapse.app.homeserver``.
|
||||
|
||||
This release also begins the process of renaming a number of the metrics
|
||||
reported to prometheus. See `docs/metrics-howto.rst <docs/metrics-howto.rst#block-and-response-metrics-renamed-for-0-27-0>`_.
|
||||
|
||||
|
||||
Changes in synapse v0.26.0 (2018-01-05)
|
||||
=======================================
|
||||
|
||||
No changes since v0.26.0-rc1
|
||||
|
||||
|
||||
Changes in synapse v0.26.0-rc1 (2017-12-13)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add ability for ASes to publicise groups for their users (PR #2686)
|
||||
* Add all local users to the user_directory and optionally search them (PR
|
||||
#2723)
|
||||
* Add support for custom login types for validating users (PR #2729)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Update example Prometheus config to new format (PR #2648) Thanks to
|
||||
@krombel!
|
||||
* Rename redact_content option to include_content in Push API (PR #2650)
|
||||
* Declare support for r0.3.0 (PR #2677)
|
||||
* Improve upserts (PR #2684, #2688, #2689, #2713)
|
||||
* Improve documentation of workers (PR #2700)
|
||||
* Improve tracebacks on exceptions (PR #2705)
|
||||
* Allow guest access to group APIs for reading (PR #2715)
|
||||
* Support for posting content in federation_client script (PR #2716)
|
||||
* Delete devices and pushers on logouts etc (PR #2722)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix database port script (PR #2673)
|
||||
* Fix internal server error on login with ldap_auth_provider (PR #2678) Thanks
|
||||
to @jkolo!
|
||||
* Fix error on sqlite 3.7 (PR #2697)
|
||||
* Fix OPTIONS on preview_url (PR #2707)
|
||||
* Fix error handling on dns lookup (PR #2711)
|
||||
* Fix wrong avatars when inviting multiple users when creating room (PR #2717)
|
||||
* Fix 500 when joining matrix-dev (PR #2719)
|
||||
|
||||
|
||||
Changes in synapse v0.25.1 (2017-11-17)
|
||||
=======================================
|
||||
|
||||
|
||||
@@ -632,11 +632,6 @@ largest boxes pause for thought.)
|
||||
|
||||
Troubleshooting
|
||||
---------------
|
||||
|
||||
You can use the federation tester to check if your homeserver is all set:
|
||||
``https://matrix.org/federationtester/api/report?server_name=<your_server_name>``
|
||||
If any of the attributes under "checks" is false, federation won't work.
|
||||
|
||||
The typical failure mode with federation is that when you try to join a room,
|
||||
it is rejected with "401: Unauthorized". Generally this means that other
|
||||
servers in the room couldn't access yours. (Joining a room over federation is a
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
# List all media in a room
|
||||
|
||||
This API gets a list of known media in a room.
|
||||
|
||||
The API is:
|
||||
```
|
||||
GET /_matrix/client/r0/admin/room/<room_id>/media
|
||||
```
|
||||
including an `access_token` of a server admin.
|
||||
|
||||
It returns a JSON body like the following:
|
||||
```
|
||||
{
|
||||
"local": [
|
||||
"mxc://localhost/xwvutsrqponmlkjihgfedcba",
|
||||
"mxc://localhost/abcdefghijklmnopqrstuvwx"
|
||||
],
|
||||
"remote": [
|
||||
"mxc://matrix.org/xwvutsrqponmlkjihgfedcba",
|
||||
"mxc://matrix.org/abcdefghijklmnopqrstuvwx"
|
||||
]
|
||||
}
|
||||
```
|
||||
@@ -4,6 +4,8 @@ Purge History API
|
||||
The purge history API allows server admins to purge historic events from their
|
||||
database, reclaiming disk space.
|
||||
|
||||
**NB!** This will not delete local events (locally sent messages content etc) from the database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
|
||||
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer. During this period users will not be able to
|
||||
paginate further back in the room from the point being purged from.
|
||||
@@ -13,15 +15,3 @@ The API is simply:
|
||||
``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>``
|
||||
|
||||
including an ``access_token`` of a server admin.
|
||||
|
||||
By default, events sent by local users are not deleted, as they may represent
|
||||
the only copies of this content in existence. (Events sent by remote users are
|
||||
deleted, and room state data before the cutoff is always removed).
|
||||
|
||||
To delete local events as well, set ``delete_local_events`` in the body:
|
||||
|
||||
.. code:: json
|
||||
|
||||
{
|
||||
"delete_local_events": true
|
||||
}
|
||||
|
||||
@@ -16,7 +16,7 @@ How to monitor Synapse metrics using Prometheus
|
||||
metrics_port: 9092
|
||||
|
||||
Also ensure that ``enable_metrics`` is set to ``True``.
|
||||
|
||||
|
||||
Restart synapse.
|
||||
|
||||
3. Add a prometheus target for synapse.
|
||||
@@ -28,58 +28,11 @@ How to monitor Synapse metrics using Prometheus
|
||||
static_configs:
|
||||
- targets: ["my.server.here:9092"]
|
||||
|
||||
If your prometheus is older than 1.5.2, you will need to replace
|
||||
If your prometheus is older than 1.5.2, you will need to replace
|
||||
``static_configs`` in the above with ``target_groups``.
|
||||
|
||||
|
||||
Restart prometheus.
|
||||
|
||||
|
||||
Block and response metrics renamed for 0.27.0
|
||||
---------------------------------------------
|
||||
|
||||
Synapse 0.27.0 begins the process of rationalising the duplicate ``*:count``
|
||||
metrics reported for the resource tracking for code blocks and HTTP requests.
|
||||
|
||||
At the same time, the corresponding ``*:total`` metrics are being renamed, as
|
||||
the ``:total`` suffix no longer makes sense in the absence of a corresponding
|
||||
``:count`` metric.
|
||||
|
||||
To enable a graceful migration path, this release just adds new names for the
|
||||
metrics being renamed. A future release will remove the old ones.
|
||||
|
||||
The following table shows the new metrics, and the old metrics which they are
|
||||
replacing.
|
||||
|
||||
==================================================== ===================================================
|
||||
New name Old name
|
||||
==================================================== ===================================================
|
||||
synapse_util_metrics_block_count synapse_util_metrics_block_timer:count
|
||||
synapse_util_metrics_block_count synapse_util_metrics_block_ru_utime:count
|
||||
synapse_util_metrics_block_count synapse_util_metrics_block_ru_stime:count
|
||||
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_count:count
|
||||
synapse_util_metrics_block_count synapse_util_metrics_block_db_txn_duration:count
|
||||
|
||||
synapse_util_metrics_block_time_seconds synapse_util_metrics_block_timer:total
|
||||
synapse_util_metrics_block_ru_utime_seconds synapse_util_metrics_block_ru_utime:total
|
||||
synapse_util_metrics_block_ru_stime_seconds synapse_util_metrics_block_ru_stime:total
|
||||
synapse_util_metrics_block_db_txn_count synapse_util_metrics_block_db_txn_count:total
|
||||
synapse_util_metrics_block_db_txn_duration_seconds synapse_util_metrics_block_db_txn_duration:total
|
||||
|
||||
synapse_http_server_response_count synapse_http_server_requests
|
||||
synapse_http_server_response_count synapse_http_server_response_time:count
|
||||
synapse_http_server_response_count synapse_http_server_response_ru_utime:count
|
||||
synapse_http_server_response_count synapse_http_server_response_ru_stime:count
|
||||
synapse_http_server_response_count synapse_http_server_response_db_txn_count:count
|
||||
synapse_http_server_response_count synapse_http_server_response_db_txn_duration:count
|
||||
|
||||
synapse_http_server_response_time_seconds synapse_http_server_response_time:total
|
||||
synapse_http_server_response_ru_utime_seconds synapse_http_server_response_ru_utime:total
|
||||
synapse_http_server_response_ru_stime_seconds synapse_http_server_response_ru_stime:total
|
||||
synapse_http_server_response_db_txn_count synapse_http_server_response_db_txn_count:total
|
||||
synapse_http_server_response_db_txn_duration_seconds synapse_http_server_response_db_txn_duration:total
|
||||
==================================================== ===================================================
|
||||
|
||||
|
||||
Standard Metric Names
|
||||
---------------------
|
||||
|
||||
@@ -89,7 +42,7 @@ have been changed to seconds, from miliseconds.
|
||||
|
||||
================================== =============================
|
||||
New name Old name
|
||||
================================== =============================
|
||||
---------------------------------- -----------------------------
|
||||
process_cpu_user_seconds_total process_resource_utime / 1000
|
||||
process_cpu_system_seconds_total process_resource_stime / 1000
|
||||
process_open_fds (no 'type' label) process_fds
|
||||
@@ -99,8 +52,8 @@ The python-specific counts of garbage collector performance have been renamed.
|
||||
|
||||
=========================== ======================
|
||||
New name Old name
|
||||
=========================== ======================
|
||||
python_gc_time reactor_gc_time
|
||||
--------------------------- ----------------------
|
||||
python_gc_time reactor_gc_time
|
||||
python_gc_unreachable_total reactor_gc_unreachable
|
||||
python_gc_counts reactor_gc_counts
|
||||
=========================== ======================
|
||||
@@ -109,7 +62,7 @@ The twisted-specific reactor metrics have been renamed.
|
||||
|
||||
==================================== =====================
|
||||
New name Old name
|
||||
==================================== =====================
|
||||
------------------------------------ ---------------------
|
||||
python_twisted_reactor_pending_calls reactor_pending_calls
|
||||
python_twisted_reactor_tick_time reactor_tick_time
|
||||
==================================== =====================
|
||||
|
||||
@@ -30,29 +30,17 @@ requests made to the federation port. The caveats regarding running a
|
||||
reverse-proxy on the federation port still apply (see
|
||||
https://github.com/matrix-org/synapse/blob/master/README.rst#reverse-proxying-the-federation-port).
|
||||
|
||||
To enable workers, you need to add two replication listeners to the master
|
||||
synapse, e.g.::
|
||||
To enable workers, you need to add a replication listener to the master synapse, e.g.::
|
||||
|
||||
listeners:
|
||||
# The TCP replication port
|
||||
- port: 9092
|
||||
bind_address: '127.0.0.1'
|
||||
type: replication
|
||||
# The HTTP replication port
|
||||
- port: 9093
|
||||
bind_address: '127.0.0.1'
|
||||
type: http
|
||||
resources:
|
||||
- names: [replication]
|
||||
|
||||
Under **no circumstances** should these replication API listeners be exposed to
|
||||
the public internet; it currently implements no authentication whatsoever and is
|
||||
Under **no circumstances** should this replication API listener be exposed to the
|
||||
public internet; it currently implements no authentication whatsoever and is
|
||||
unencrypted.
|
||||
|
||||
(Roughly, the TCP port is used for streaming data from the master to the
|
||||
workers, and the HTTP port for the workers to send data to the main
|
||||
synapse process.)
|
||||
|
||||
You then create a set of configs for the various worker processes. These
|
||||
should be worker configuration files, and should be stored in a dedicated
|
||||
subdirectory, to allow synctl to manipulate them.
|
||||
@@ -64,13 +52,8 @@ You should minimise the number of overrides though to maintain a usable config.
|
||||
|
||||
You must specify the type of worker application (``worker_app``). The currently
|
||||
available worker applications are listed below. You must also specify the
|
||||
replication endpoints that it's talking to on the main synapse process.
|
||||
``worker_replication_host`` should specify the host of the main synapse,
|
||||
``worker_replication_port`` should point to the TCP replication listener port and
|
||||
``worker_replication_http_port`` should point to the HTTP replication port.
|
||||
|
||||
Currently, only the ``event_creator`` worker requires specifying
|
||||
``worker_replication_http_port``.
|
||||
replication endpoint that it's talking to on the main synapse process
|
||||
(``worker_replication_host`` and ``worker_replication_port``).
|
||||
|
||||
For instance::
|
||||
|
||||
@@ -79,7 +62,6 @@ For instance::
|
||||
# The replication listener on the synapse to talk to.
|
||||
worker_replication_host: 127.0.0.1
|
||||
worker_replication_port: 9092
|
||||
worker_replication_http_port: 9093
|
||||
|
||||
worker_listeners:
|
||||
- type: http
|
||||
@@ -225,14 +207,3 @@ the ``worker_main_http_uri`` setting in the frontend_proxy worker configuration
|
||||
file. For example::
|
||||
|
||||
worker_main_http_uri: http://127.0.0.1:8008
|
||||
|
||||
|
||||
``synapse.app.event_creator``
|
||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Handles non-state event creation. It can handle REST endpoints matching:
|
||||
|
||||
^/_matrix/client/(api/v1|r0|unstable)/rooms/.*/send
|
||||
|
||||
It will create events locally and then send them on to the main synapse
|
||||
instance to be persisted and handled.
|
||||
|
||||
@@ -1,133 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Moves a list of remote media from one media store to another.
|
||||
|
||||
The input should be a list of media files to be moved, one per line. Each line
|
||||
should be formatted::
|
||||
|
||||
<origin server>|<file id>
|
||||
|
||||
This can be extracted from postgres with::
|
||||
|
||||
psql --tuples-only -A -c "select media_origin, filesystem_id from
|
||||
matrix.remote_media_cache where ..."
|
||||
|
||||
To use, pipe the above into::
|
||||
|
||||
PYTHON_PATH=. ./scripts/move_remote_media_to_new_store.py <source repo> <dest repo>
|
||||
"""
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
|
||||
import sys
|
||||
|
||||
import os
|
||||
|
||||
import shutil
|
||||
|
||||
from synapse.rest.media.v1.filepath import MediaFilePaths
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main(src_repo, dest_repo):
|
||||
src_paths = MediaFilePaths(src_repo)
|
||||
dest_paths = MediaFilePaths(dest_repo)
|
||||
for line in sys.stdin:
|
||||
line = line.strip()
|
||||
parts = line.split('|')
|
||||
if len(parts) != 2:
|
||||
print("Unable to parse input line %s" % line, file=sys.stderr)
|
||||
exit(1)
|
||||
|
||||
move_media(parts[0], parts[1], src_paths, dest_paths)
|
||||
|
||||
|
||||
def move_media(origin_server, file_id, src_paths, dest_paths):
|
||||
"""Move the given file, and any thumbnails, to the dest repo
|
||||
|
||||
Args:
|
||||
origin_server (str):
|
||||
file_id (str):
|
||||
src_paths (MediaFilePaths):
|
||||
dest_paths (MediaFilePaths):
|
||||
"""
|
||||
logger.info("%s/%s", origin_server, file_id)
|
||||
|
||||
# check that the original exists
|
||||
original_file = src_paths.remote_media_filepath(origin_server, file_id)
|
||||
if not os.path.exists(original_file):
|
||||
logger.warn(
|
||||
"Original for %s/%s (%s) does not exist",
|
||||
origin_server, file_id, original_file,
|
||||
)
|
||||
else:
|
||||
mkdir_and_move(
|
||||
original_file,
|
||||
dest_paths.remote_media_filepath(origin_server, file_id),
|
||||
)
|
||||
|
||||
# now look for thumbnails
|
||||
original_thumb_dir = src_paths.remote_media_thumbnail_dir(
|
||||
origin_server, file_id,
|
||||
)
|
||||
if not os.path.exists(original_thumb_dir):
|
||||
return
|
||||
|
||||
mkdir_and_move(
|
||||
original_thumb_dir,
|
||||
dest_paths.remote_media_thumbnail_dir(origin_server, file_id)
|
||||
)
|
||||
|
||||
|
||||
def mkdir_and_move(original_file, dest_file):
|
||||
dirname = os.path.dirname(dest_file)
|
||||
if not os.path.exists(dirname):
|
||||
logger.debug("mkdir %s", dirname)
|
||||
os.makedirs(dirname)
|
||||
logger.debug("mv %s %s", original_file, dest_file)
|
||||
shutil.move(original_file, dest_file)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description=__doc__,
|
||||
formatter_class = argparse.RawDescriptionHelpFormatter,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-v", action='store_true', help='enable debug logging')
|
||||
parser.add_argument(
|
||||
"src_repo",
|
||||
help="Path to source content repo",
|
||||
)
|
||||
parser.add_argument(
|
||||
"dest_repo",
|
||||
help="Path to source content repo",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||
}
|
||||
logging.basicConfig(**logging_config)
|
||||
|
||||
main(args.src_repo, args.dest_repo)
|
||||
@@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.26.0"
|
||||
__version__ = "0.25.1"
|
||||
|
||||
@@ -46,9 +46,9 @@ class Codes(object):
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
THREEPID_IN_USE = "M_THREEPID_IN_USE"
|
||||
THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND"
|
||||
THREEPID_DENIED = "M_THREEPID_DENIED"
|
||||
INVALID_USERNAME = "M_INVALID_USERNAME"
|
||||
SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED"
|
||||
WRONG_ROOM_KEYS_VERSION = "M_WRONG_ROOM_KEYS_VERSION"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
@@ -141,32 +141,6 @@ class RegistrationError(SynapseError):
|
||||
pass
|
||||
|
||||
|
||||
class FederationDeniedError(SynapseError):
|
||||
"""An error raised when the server tries to federate with a server which
|
||||
is not on its federation whitelist.
|
||||
|
||||
Attributes:
|
||||
destination (str): The destination which has been denied
|
||||
"""
|
||||
|
||||
def __init__(self, destination):
|
||||
"""Raised by federation client or server to indicate that we are
|
||||
are deliberately not attempting to contact a given server because it is
|
||||
not on our federation whitelist.
|
||||
|
||||
Args:
|
||||
destination (str): the domain in question
|
||||
"""
|
||||
|
||||
self.destination = destination
|
||||
|
||||
super(FederationDeniedError, self).__init__(
|
||||
code=403,
|
||||
msg="Federation denied with %s." % (self.destination,),
|
||||
errcode=Codes.FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
class InteractiveAuthIncompleteError(Exception):
|
||||
"""An error raised when UI auth is not yet complete
|
||||
|
||||
@@ -276,6 +250,27 @@ class LimitExceededError(SynapseError):
|
||||
)
|
||||
|
||||
|
||||
class RoomKeysVersionError(SynapseError):
|
||||
"""A client has tried to upload to a non-current version of the room_keys store
|
||||
"""
|
||||
def __init__(self, current_version):
|
||||
"""
|
||||
Args:
|
||||
current_version (str): the current version of the store they should have used
|
||||
"""
|
||||
super(RoomKeysVersionError, self).__init__(
|
||||
403, "Wrong room_keys version", Codes.WRONG_ROOM_KEYS_VERSION
|
||||
)
|
||||
self.current_version = current_version
|
||||
|
||||
def error_dict(self):
|
||||
return cs_error(
|
||||
self.msg,
|
||||
self.errcode,
|
||||
current_version=self.current_version,
|
||||
)
|
||||
|
||||
|
||||
def cs_exception(exception):
|
||||
if isinstance(exception, CodeMessageException):
|
||||
return exception.error_dict()
|
||||
|
||||
@@ -49,6 +49,19 @@ class AppserviceSlaveStore(
|
||||
|
||||
|
||||
class AppserviceServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = AppserviceSlaveStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -64,6 +64,19 @@ class ClientReaderSlavedStore(
|
||||
|
||||
|
||||
class ClientReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = ClientReaderSlavedStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import sys
|
||||
|
||||
import synapse
|
||||
from synapse import events
|
||||
from synapse.app import _base
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.config.logger import setup_logging
|
||||
from synapse.crypto import context_factory
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.site import SynapseSite
|
||||
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
||||
from synapse.replication.slave.storage._base import BaseSlavedStore
|
||||
from synapse.replication.slave.storage.appservice import SlavedApplicationServiceStore
|
||||
from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1.room import RoomSendEventRestServlet
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.util.httpresourcetree import create_resource_tree
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.manhole import manhole
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from twisted.internet import reactor
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
logger = logging.getLogger("synapse.app.event_creator")
|
||||
|
||||
|
||||
class EventCreatorSlavedStore(
|
||||
SlavedDeviceStore,
|
||||
SlavedClientIpStore,
|
||||
SlavedApplicationServiceStore,
|
||||
SlavedEventStore,
|
||||
SlavedRegistrationStore,
|
||||
RoomStore,
|
||||
BaseSlavedStore,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
class EventCreatorServer(HomeServer):
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = EventCreatorSlavedStore(self.get_db_conn(), self)
|
||||
logger.info("Finished setting up.")
|
||||
|
||||
def _listen_http(self, listener_config):
|
||||
port = listener_config["port"]
|
||||
bind_addresses = listener_config["bind_addresses"]
|
||||
site_tag = listener_config.get("tag", port)
|
||||
resources = {}
|
||||
for res in listener_config["resources"]:
|
||||
for name in res["names"]:
|
||||
if name == "metrics":
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
RoomSendEventRestServlet(self).register(resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
"/_matrix/client/v2_alpha": resource,
|
||||
"/_matrix/client/api/v1": resource,
|
||||
})
|
||||
|
||||
root_resource = create_resource_tree(resources, Resource())
|
||||
|
||||
_base.listen_tcp(
|
||||
bind_addresses,
|
||||
port,
|
||||
SynapseSite(
|
||||
"synapse.access.http.%s" % (site_tag,),
|
||||
site_tag,
|
||||
listener_config,
|
||||
root_resource,
|
||||
)
|
||||
)
|
||||
|
||||
logger.info("Synapse event creator now listening on port %d", port)
|
||||
|
||||
def start_listening(self, listeners):
|
||||
for listener in listeners:
|
||||
if listener["type"] == "http":
|
||||
self._listen_http(listener)
|
||||
elif listener["type"] == "manhole":
|
||||
_base.listen_tcp(
|
||||
listener["bind_addresses"],
|
||||
listener["port"],
|
||||
manhole(
|
||||
username="matrix",
|
||||
password="rabbithole",
|
||||
globals={"hs": self},
|
||||
)
|
||||
)
|
||||
else:
|
||||
logger.warn("Unrecognized listener type: %s", listener["type"])
|
||||
|
||||
self.get_tcp_replication().start_replication(self)
|
||||
|
||||
def build_tcp_replication(self):
|
||||
return ReplicationClientHandler(self.get_datastore())
|
||||
|
||||
|
||||
def start(config_options):
|
||||
try:
|
||||
config = HomeServerConfig.load_config(
|
||||
"Synapse event creator", config_options
|
||||
)
|
||||
except ConfigError as e:
|
||||
sys.stderr.write("\n" + e.message + "\n")
|
||||
sys.exit(1)
|
||||
|
||||
assert config.worker_app == "synapse.app.event_creator"
|
||||
|
||||
assert config.worker_replication_http_port is not None
|
||||
|
||||
setup_logging(config, use_worker_options=True)
|
||||
|
||||
events.USE_FROZEN_DICTS = config.use_frozen_dicts
|
||||
|
||||
database_engine = create_engine(config.database_config)
|
||||
|
||||
tls_server_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
ss = EventCreatorServer(
|
||||
config.server_name,
|
||||
db_config=config.database_config,
|
||||
tls_server_context_factory=tls_server_context_factory,
|
||||
config=config,
|
||||
version_string="Synapse/" + get_version_string(synapse),
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
ss.setup()
|
||||
ss.get_handlers()
|
||||
ss.start_listening(config.worker_listeners)
|
||||
|
||||
def start():
|
||||
ss.get_state_handler().start_caching()
|
||||
ss.get_datastore().start_profiling()
|
||||
|
||||
reactor.callWhenRunning(start)
|
||||
|
||||
_base.start_worker_reactor("synapse-event-creator", config)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
with LoggingContext("main"):
|
||||
start(sys.argv[1:])
|
||||
@@ -58,6 +58,19 @@ class FederationReaderSlavedStore(
|
||||
|
||||
|
||||
class FederationReaderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -76,6 +76,19 @@ class FederationSenderSlaveStore(
|
||||
|
||||
|
||||
class FederationSenderServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FederationSenderSlaveStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -36,7 +36,6 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
from synapse.replication.slave.storage.devices import SlavedDeviceStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1.base import ClientV1RestServlet, client_path_patterns
|
||||
from synapse.rest.client.v2_alpha._base import client_v2_patterns
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
@@ -50,35 +49,6 @@ from twisted.web.resource import Resource
|
||||
logger = logging.getLogger("synapse.app.frontend_proxy")
|
||||
|
||||
|
||||
class PresenceStatusStubServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/presence/(?P<user_id>[^/]*)/status")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PresenceStatusStubServlet, self).__init__(hs)
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
self.auth = hs.get_auth()
|
||||
self.main_uri = hs.config.worker_main_http_uri
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
# Pass through the auth headers, if any, in case the access token
|
||||
# is there.
|
||||
auth_headers = request.requestHeaders.getRawHeaders("Authorization", [])
|
||||
headers = {
|
||||
"Authorization": auth_headers,
|
||||
}
|
||||
result = yield self.http_client.get_json(
|
||||
self.main_uri + request.uri,
|
||||
headers=headers,
|
||||
)
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class KeyUploadServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$")
|
||||
|
||||
@@ -148,6 +118,19 @@ class FrontendProxySlavedStore(
|
||||
|
||||
|
||||
class FrontendProxyServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = FrontendProxySlavedStore(self.get_db_conn(), self)
|
||||
@@ -165,7 +148,6 @@ class FrontendProxyServer(HomeServer):
|
||||
elif name == "client":
|
||||
resource = JsonResource(self, canonical_json=False)
|
||||
KeyUploadServlet(self).register(resource)
|
||||
PresenceStatusStubServlet(self).register(resource)
|
||||
resources.update({
|
||||
"/_matrix/client/r0": resource,
|
||||
"/_matrix/client/unstable": resource,
|
||||
|
||||
@@ -38,7 +38,6 @@ from synapse.metrics import register_memory_metrics
|
||||
from synapse.metrics.resource import METRICS_PREFIX, MetricsResource
|
||||
from synapse.python_dependencies import CONDITIONAL_REQUIREMENTS, \
|
||||
check_requirements
|
||||
from synapse.replication.http import ReplicationRestResource, REPLICATION_PREFIX
|
||||
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
|
||||
from synapse.rest import ClientRestResource
|
||||
from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||
@@ -220,9 +219,6 @@ class SynapseHomeServer(HomeServer):
|
||||
if name == "metrics" and self.get_config().enable_metrics:
|
||||
resources[METRICS_PREFIX] = MetricsResource(self)
|
||||
|
||||
if name == "replication":
|
||||
resources[REPLICATION_PREFIX] = ReplicationRestResource(self)
|
||||
|
||||
return resources
|
||||
|
||||
def start_listening(self):
|
||||
@@ -270,6 +266,19 @@ class SynapseHomeServer(HomeServer):
|
||||
except IncorrectDatabaseSetup as e:
|
||||
quit_with_error(e.message)
|
||||
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
"""
|
||||
|
||||
@@ -60,6 +60,19 @@ class MediaRepositorySlavedStore(
|
||||
|
||||
|
||||
class MediaRepositoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = MediaRepositorySlavedStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -81,6 +81,19 @@ class PusherSlaveStore(
|
||||
|
||||
|
||||
class PusherServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -117,7 +117,6 @@ class SynchrotronPresence(object):
|
||||
logger.info("Presence process_id is %r", self.process_id)
|
||||
|
||||
def send_user_sync(self, user_id, is_syncing, last_sync_ms):
|
||||
return
|
||||
self.hs.get_tcp_replication().send_user_sync(user_id, is_syncing, last_sync_ms)
|
||||
|
||||
def mark_as_coming_online(self, user_id):
|
||||
@@ -215,8 +214,6 @@ class SynchrotronPresence(object):
|
||||
yield self.notify_from_replication(states, stream_id)
|
||||
|
||||
def get_currently_syncing_users(self):
|
||||
# presence is disabled on matrix.org, so we return the empty set
|
||||
return set()
|
||||
return [
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.iteritems()
|
||||
if count > 0
|
||||
@@ -249,6 +246,19 @@ class SynchrotronApplicationService(object):
|
||||
|
||||
|
||||
class SynchrotronServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = SynchrotronSlavedStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -184,9 +184,6 @@ def main():
|
||||
worker_configfiles.append(worker_configfile)
|
||||
|
||||
if options.all_processes:
|
||||
# To start the main synapse with -a you need to add a worker file
|
||||
# with worker_app == "synapse.app.homeserver"
|
||||
start_stop_synapse = False
|
||||
worker_configdir = options.all_processes
|
||||
if not os.path.isdir(worker_configdir):
|
||||
write(
|
||||
@@ -203,29 +200,11 @@ def main():
|
||||
with open(worker_configfile) as stream:
|
||||
worker_config = yaml.load(stream)
|
||||
worker_app = worker_config["worker_app"]
|
||||
if worker_app == "synapse.app.homeserver":
|
||||
# We need to special case all of this to pick up options that may
|
||||
# be set in the main config file or in this worker config file.
|
||||
worker_pidfile = (
|
||||
worker_config.get("pid_file")
|
||||
or pidfile
|
||||
)
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor") or cache_factor
|
||||
daemonize = worker_config.get("daemonize") or config.get("daemonize")
|
||||
assert daemonize, "Main process must have daemonize set to true"
|
||||
|
||||
# The master process doesn't support using worker_* config.
|
||||
for key in worker_config:
|
||||
if key == "worker_app": # But we allow worker_app
|
||||
continue
|
||||
assert not key.startswith("worker_"), \
|
||||
"Main process cannot use worker_* config"
|
||||
else:
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
||||
worker_configfile, "worker_daemonize")
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
worker_pidfile = worker_config["worker_pid_file"]
|
||||
worker_daemonize = worker_config["worker_daemonize"]
|
||||
assert worker_daemonize, "In config %r: expected '%s' to be True" % (
|
||||
worker_configfile, "worker_daemonize")
|
||||
worker_cache_factor = worker_config.get("synctl_cache_factor")
|
||||
workers.append(Worker(
|
||||
worker_app, worker_configfile, worker_pidfile, worker_cache_factor,
|
||||
))
|
||||
|
||||
@@ -92,6 +92,19 @@ class UserDirectorySlaveStore(
|
||||
|
||||
|
||||
class UserDirectoryServer(HomeServer):
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def setup(self):
|
||||
logger.info("Setting up.")
|
||||
self.datastore = UserDirectorySlaveStore(self.get_db_conn(), self)
|
||||
|
||||
@@ -28,27 +28,27 @@ DEFAULT_LOG_CONFIG = Template("""
|
||||
version: 1
|
||||
|
||||
formatters:
|
||||
precise:
|
||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - \
|
||||
%(request)s - %(message)s'
|
||||
precise:
|
||||
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
|
||||
- %(message)s'
|
||||
|
||||
filters:
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
|
||||
handlers:
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ${log_file}
|
||||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
filters: [context]
|
||||
file:
|
||||
class: logging.handlers.RotatingFileHandler
|
||||
formatter: precise
|
||||
filename: ${log_file}
|
||||
maxBytes: 104857600
|
||||
backupCount: 10
|
||||
filters: [context]
|
||||
console:
|
||||
class: logging.StreamHandler
|
||||
formatter: precise
|
||||
filters: [context]
|
||||
|
||||
loggers:
|
||||
synapse:
|
||||
@@ -74,10 +74,17 @@ class LoggingConfig(Config):
|
||||
self.log_file = self.abspath(config.get("log_file"))
|
||||
|
||||
def default_config(self, config_dir_path, server_name, **kwargs):
|
||||
log_file = self.abspath("homeserver.log")
|
||||
log_config = self.abspath(
|
||||
os.path.join(config_dir_path, server_name + ".log.config")
|
||||
)
|
||||
return """
|
||||
# Logging verbosity level. Ignored if log_config is specified.
|
||||
verbose: 0
|
||||
|
||||
# File to write logging to. Ignored if log_config is specified.
|
||||
log_file: "%(log_file)s"
|
||||
|
||||
# A yaml python logging config file
|
||||
log_config: "%(log_config)s"
|
||||
""" % locals()
|
||||
@@ -116,10 +123,9 @@ class LoggingConfig(Config):
|
||||
def generate_files(self, config):
|
||||
log_config = config.get("log_config")
|
||||
if log_config and not os.path.exists(log_config):
|
||||
log_file = self.abspath("homeserver.log")
|
||||
with open(log_config, "wb") as log_config_file:
|
||||
log_config_file.write(
|
||||
DEFAULT_LOG_CONFIG.substitute(log_file=log_file)
|
||||
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
|
||||
)
|
||||
|
||||
|
||||
@@ -144,9 +150,6 @@ def setup_logging(config, use_worker_options=False):
|
||||
)
|
||||
|
||||
if log_config is None:
|
||||
# We don't have a logfile, so fall back to the 'verbosity' param from
|
||||
# the config or cmdline. (Note that we generate a log config for new
|
||||
# installs, so this will be an unusual case)
|
||||
level = logging.INFO
|
||||
level_for_storage = logging.INFO
|
||||
if config.verbosity:
|
||||
@@ -154,10 +157,11 @@ def setup_logging(config, use_worker_options=False):
|
||||
if config.verbosity > 1:
|
||||
level_for_storage = logging.DEBUG
|
||||
|
||||
# FIXME: we need a logging.WARN for a -q quiet option
|
||||
logger = logging.getLogger('')
|
||||
logger.setLevel(level)
|
||||
|
||||
logging.getLogger('synapse.storage.SQL').setLevel(level_for_storage)
|
||||
logging.getLogger('synapse.storage').setLevel(level_for_storage)
|
||||
|
||||
formatter = logging.Formatter(log_format)
|
||||
if log_file:
|
||||
|
||||
@@ -31,8 +31,6 @@ class RegistrationConfig(Config):
|
||||
strtobool(str(config["disable_registration"]))
|
||||
)
|
||||
|
||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
@@ -54,23 +52,6 @@ class RegistrationConfig(Config):
|
||||
# Enable registration for new users.
|
||||
enable_registration: False
|
||||
|
||||
# The user must provide all of the below types of 3PID when registering.
|
||||
#
|
||||
# registrations_require_3pid:
|
||||
# - email
|
||||
# - msisdn
|
||||
|
||||
# Mandate that users are only allowed to associate certain formats of
|
||||
# 3PIDs with accounts on this server.
|
||||
#
|
||||
# allowed_local_3pids:
|
||||
# - medium: email
|
||||
# pattern: ".*@matrix\\.org"
|
||||
# - medium: email
|
||||
# pattern: ".*@vector\\.im"
|
||||
# - medium: msisdn
|
||||
# pattern: "\\+44"
|
||||
|
||||
# If set, allows registration by anyone who also has the shared
|
||||
# secret, even if registration is otherwise disabled.
|
||||
registration_shared_secret: "%(registration_shared_secret)s"
|
||||
|
||||
@@ -16,8 +16,6 @@
|
||||
from ._base import Config, ConfigError
|
||||
from collections import namedtuple
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
|
||||
MISSING_NETADDR = (
|
||||
"Missing netaddr library. This is required for URL preview API."
|
||||
@@ -38,14 +36,6 @@ ThumbnailRequirement = namedtuple(
|
||||
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
|
||||
)
|
||||
|
||||
MediaStorageProviderConfig = namedtuple(
|
||||
"MediaStorageProviderConfig", (
|
||||
"store_local", # Whether to store newly uploaded local files
|
||||
"store_remote", # Whether to store newly downloaded remote files
|
||||
"store_synchronous", # Whether to wait for successful storage for local uploads
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def parse_thumbnail_requirements(thumbnail_sizes):
|
||||
""" Takes a list of dictionaries with "width", "height", and "method" keys
|
||||
@@ -83,61 +73,16 @@ class ContentRepositoryConfig(Config):
|
||||
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
|
||||
backup_media_store_path = config.get("backup_media_store_path")
|
||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
||||
if self.backup_media_store_path:
|
||||
self.backup_media_store_path = self.ensure_directory(
|
||||
self.backup_media_store_path
|
||||
)
|
||||
|
||||
synchronous_backup_media_store = config.get(
|
||||
self.synchronous_backup_media_store = config.get(
|
||||
"synchronous_backup_media_store", False
|
||||
)
|
||||
|
||||
storage_providers = config.get("media_storage_providers", [])
|
||||
|
||||
if backup_media_store_path:
|
||||
if storage_providers:
|
||||
raise ConfigError(
|
||||
"Cannot use both 'backup_media_store_path' and 'storage_providers'"
|
||||
)
|
||||
|
||||
storage_providers = [{
|
||||
"module": "file_system",
|
||||
"store_local": True,
|
||||
"store_synchronous": synchronous_backup_media_store,
|
||||
"store_remote": True,
|
||||
"config": {
|
||||
"directory": backup_media_store_path,
|
||||
}
|
||||
}]
|
||||
|
||||
# This is a list of config that can be used to create the storage
|
||||
# providers. The entries are tuples of (Class, class_config,
|
||||
# MediaStorageProviderConfig), where Class is the class of the provider,
|
||||
# the class_config the config to pass to it, and
|
||||
# MediaStorageProviderConfig are options for StorageProviderWrapper.
|
||||
#
|
||||
# We don't create the storage providers here as not all workers need
|
||||
# them to be started.
|
||||
self.media_storage_providers = []
|
||||
|
||||
for provider_config in storage_providers:
|
||||
# We special case the module "file_system" so as not to need to
|
||||
# expose FileStorageProviderBackend
|
||||
if provider_config["module"] == "file_system":
|
||||
provider_config["module"] = (
|
||||
"synapse.rest.media.v1.storage_provider"
|
||||
".FileStorageProviderBackend"
|
||||
)
|
||||
|
||||
provider_class, parsed_config = load_module(provider_config)
|
||||
|
||||
wrapper_config = MediaStorageProviderConfig(
|
||||
provider_config.get("store_local", False),
|
||||
provider_config.get("store_remote", False),
|
||||
provider_config.get("store_synchronous", False),
|
||||
)
|
||||
|
||||
self.media_storage_providers.append(
|
||||
(provider_class, parsed_config, wrapper_config,)
|
||||
)
|
||||
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||
@@ -182,19 +127,13 @@ class ContentRepositoryConfig(Config):
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
|
||||
|
||||
# Media storage providers allow media to be stored in different
|
||||
# locations.
|
||||
# media_storage_providers:
|
||||
# - module: file_system
|
||||
# # Whether to write new local files.
|
||||
# store_local: false
|
||||
# # Whether to write new remote media
|
||||
# store_remote: false
|
||||
# # Whether to block upload requests waiting for write to this
|
||||
# # provider to complete
|
||||
# store_synchronous: false
|
||||
# config:
|
||||
# directory: /mnt/some/other/directory
|
||||
# A secondary directory where uploaded images and attachments are
|
||||
# stored as a backup.
|
||||
# backup_media_store_path: "%(media_store)s"
|
||||
|
||||
# Whether to wait for successful write to backup media store before
|
||||
# returning successfully.
|
||||
# synchronous_backup_media_store: false
|
||||
|
||||
# Directory where in-progress uploads are stored.
|
||||
uploads_path: "%(uploads_path)s"
|
||||
|
||||
@@ -55,17 +55,6 @@ class ServerConfig(Config):
|
||||
"block_non_admin_invites", False,
|
||||
)
|
||||
|
||||
# FIXME: federation_domain_whitelist needs sytests
|
||||
self.federation_domain_whitelist = None
|
||||
federation_domain_whitelist = config.get(
|
||||
"federation_domain_whitelist", None
|
||||
)
|
||||
# turn the whitelist into a hash for speed of lookup
|
||||
if federation_domain_whitelist is not None:
|
||||
self.federation_domain_whitelist = {}
|
||||
for domain in federation_domain_whitelist:
|
||||
self.federation_domain_whitelist[domain] = True
|
||||
|
||||
if self.public_baseurl is not None:
|
||||
if self.public_baseurl[-1] != '/':
|
||||
self.public_baseurl += '/'
|
||||
@@ -221,17 +210,6 @@ class ServerConfig(Config):
|
||||
# (except those sent by local server admins). The default is False.
|
||||
# block_non_admin_invites: True
|
||||
|
||||
# Restrict federation to the following whitelist of domains.
|
||||
# N.B. we recommend also firewalling your federation listener to limit
|
||||
# inbound federation traffic as early as possible, rather than relying
|
||||
# purely on this application-layer restriction. If not specified, the
|
||||
# default is to whitelist everything.
|
||||
#
|
||||
# federation_domain_whitelist:
|
||||
# - lon.example.com
|
||||
# - nyc.example.com
|
||||
# - syd.example.com
|
||||
|
||||
# List of ports that Synapse should listen on, their purpose and their
|
||||
# configuration.
|
||||
listeners:
|
||||
|
||||
@@ -96,7 +96,7 @@ class TlsConfig(Config):
|
||||
# certificates returned by this server match one of the fingerprints.
|
||||
#
|
||||
# Synapse automatically adds the fingerprint of its own certificate
|
||||
# to the list. So if federation traffic is handled directly by synapse
|
||||
# to the list. So if federation traffic is handle directly by synapse
|
||||
# then no modification to the list is required.
|
||||
#
|
||||
# If synapse is run behind a load balancer that handles the TLS then it
|
||||
|
||||
@@ -23,26 +23,13 @@ class WorkerConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.worker_app = config.get("worker_app")
|
||||
|
||||
# Canonicalise worker_app so that master always has None
|
||||
if self.worker_app == "synapse.app.homeserver":
|
||||
self.worker_app = None
|
||||
|
||||
self.worker_listeners = config.get("worker_listeners")
|
||||
self.worker_daemonize = config.get("worker_daemonize")
|
||||
self.worker_pid_file = config.get("worker_pid_file")
|
||||
self.worker_log_file = config.get("worker_log_file")
|
||||
self.worker_log_config = config.get("worker_log_config")
|
||||
|
||||
# The host used to connect to the main synapse
|
||||
self.worker_replication_host = config.get("worker_replication_host", None)
|
||||
|
||||
# The port on the main synapse for TCP replication
|
||||
self.worker_replication_port = config.get("worker_replication_port", None)
|
||||
|
||||
# The port on the main synapse for HTTP replication endpoint
|
||||
self.worker_replication_http_port = config.get("worker_replication_http_port")
|
||||
|
||||
self.worker_name = config.get("worker_name", self.worker_app)
|
||||
|
||||
self.worker_main_http_uri = config.get("worker_main_http_uri", None)
|
||||
|
||||
@@ -319,7 +319,7 @@ def _is_membership_change_allowed(event, auth_events):
|
||||
# TODO (erikj): Implement kicks.
|
||||
if target_banned and user_level < ban_level:
|
||||
raise AuthError(
|
||||
403, "You cannot unban user %s." % (target_user_id,)
|
||||
403, "You cannot unban user &s." % (target_user_id,)
|
||||
)
|
||||
elif target_user_id != event.user_id:
|
||||
kick_level = _get_named_level(auth_events, "kick", 50)
|
||||
|
||||
@@ -14,9 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from frozendict import frozendict
|
||||
|
||||
|
||||
class EventContext(object):
|
||||
"""
|
||||
Attributes:
|
||||
@@ -28,9 +25,7 @@ class EventContext(object):
|
||||
The current state map excluding the current event.
|
||||
(type, state_key) -> event_id
|
||||
|
||||
state_group (int|None): state group id, if the state has been stored
|
||||
as a state group. This is usually only None if e.g. the event is
|
||||
an outlier.
|
||||
state_group (int): state group id
|
||||
rejected (bool|str): A rejection reason if the event was rejected, else
|
||||
False
|
||||
|
||||
@@ -76,72 +71,3 @@ class EventContext(object):
|
||||
self.prev_state_events = None
|
||||
|
||||
self.app_service = None
|
||||
|
||||
def serialize(self):
|
||||
"""Converts self to a type that can be serialized as JSON, and then
|
||||
deserialized by `deserialize`
|
||||
|
||||
Returns:
|
||||
dict
|
||||
"""
|
||||
return {
|
||||
"current_state_ids": _encode_state_dict(self.current_state_ids),
|
||||
"prev_state_ids": _encode_state_dict(self.prev_state_ids),
|
||||
"state_group": self.state_group,
|
||||
"rejected": self.rejected,
|
||||
"push_actions": self.push_actions,
|
||||
"prev_group": self.prev_group,
|
||||
"delta_ids": _encode_state_dict(self.delta_ids),
|
||||
"prev_state_events": self.prev_state_events,
|
||||
"app_service_id": self.app_service.id if self.app_service else None
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def deserialize(store, input):
|
||||
"""Converts a dict that was produced by `serialize` back into a
|
||||
EventContext.
|
||||
|
||||
Args:
|
||||
store (DataStore): Used to convert AS ID to AS object
|
||||
input (dict): A dict produced by `serialize`
|
||||
|
||||
Returns:
|
||||
EventContext
|
||||
"""
|
||||
context = EventContext()
|
||||
context.current_state_ids = _decode_state_dict(input["current_state_ids"])
|
||||
context.prev_state_ids = _decode_state_dict(input["prev_state_ids"])
|
||||
context.state_group = input["state_group"]
|
||||
context.rejected = input["rejected"]
|
||||
context.push_actions = input["push_actions"]
|
||||
context.prev_group = input["prev_group"]
|
||||
context.delta_ids = _decode_state_dict(input["delta_ids"])
|
||||
context.prev_state_events = input["prev_state_events"]
|
||||
|
||||
app_service_id = input["app_service_id"]
|
||||
if app_service_id:
|
||||
context.app_service = store.get_app_service_by_id(app_service_id)
|
||||
|
||||
return context
|
||||
|
||||
|
||||
def _encode_state_dict(state_dict):
|
||||
"""Since dicts of (type, state_key) -> event_id cannot be serialized in
|
||||
JSON we need to convert them to a form that can.
|
||||
"""
|
||||
if state_dict is None:
|
||||
return None
|
||||
|
||||
return [
|
||||
(etype, state_key, v)
|
||||
for (etype, state_key), v in state_dict.iteritems()
|
||||
]
|
||||
|
||||
|
||||
def _decode_state_dict(input):
|
||||
"""Decodes a state dict encoded using `_encode_state_dict` above
|
||||
"""
|
||||
if input is None:
|
||||
return None
|
||||
|
||||
return frozendict({(etype, state_key,): v for etype, state_key, v in input})
|
||||
|
||||
@@ -16,9 +16,7 @@ import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.http.servlet import assert_params_in_request
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from twisted.internet import defer
|
||||
|
||||
@@ -171,28 +169,3 @@ class FederationBase(object):
|
||||
)
|
||||
|
||||
return deferreds
|
||||
|
||||
|
||||
def event_from_pdu_json(pdu_json, outlier=False):
|
||||
"""Construct a FrozenEvent from an event json received over federation
|
||||
|
||||
Args:
|
||||
pdu_json (object): pdu as received over federation
|
||||
outlier (bool): True to mark this event as an outlier
|
||||
|
||||
Returns:
|
||||
FrozenEvent
|
||||
|
||||
Raises:
|
||||
SynapseError: if the pdu is missing required fields
|
||||
"""
|
||||
# we could probably enforce a bunch of other fields here (room_id, sender,
|
||||
# origin, etc etc)
|
||||
assert_params_in_request(pdu_json, ('event_id', 'type'))
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
||||
@@ -14,28 +14,28 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .federation_base import FederationBase
|
||||
from synapse.api.constants import Membership
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError,
|
||||
)
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.events import FrozenEvent, builder
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
import random
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, HttpResponseException, SynapseError, FederationDeniedError
|
||||
)
|
||||
from synapse.events import builder
|
||||
from synapse.federation.federation_base import (
|
||||
FederationBase,
|
||||
event_from_pdu_json,
|
||||
)
|
||||
import synapse.metrics
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -184,7 +184,7 @@ class FederationClient(FederationBase):
|
||||
logger.debug("backfill transaction_data=%s", repr(transaction_data))
|
||||
|
||||
pdus = [
|
||||
event_from_pdu_json(p, outlier=False)
|
||||
self.event_from_pdu_json(p, outlier=False)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
@@ -244,7 +244,7 @@ class FederationClient(FederationBase):
|
||||
logger.debug("transaction_data %r", transaction_data)
|
||||
|
||||
pdu_list = [
|
||||
event_from_pdu_json(p, outlier=outlier)
|
||||
self.event_from_pdu_json(p, outlier=outlier)
|
||||
for p in transaction_data["pdus"]
|
||||
]
|
||||
|
||||
@@ -266,9 +266,6 @@ class FederationClient(FederationBase):
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except Exception as e:
|
||||
pdu_attempts[destination] = now
|
||||
|
||||
@@ -339,11 +336,11 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
pdus = [
|
||||
event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
self.event_from_pdu_json(p, outlier=True) for p in result["pdus"]
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in result.get("auth_chain", [])
|
||||
]
|
||||
|
||||
@@ -444,7 +441,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in res["auth_chain"]
|
||||
]
|
||||
|
||||
@@ -573,12 +570,12 @@ class FederationClient(FederationBase):
|
||||
logger.debug("Got content: %s", content)
|
||||
|
||||
state = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("state", [])
|
||||
]
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(p, outlier=True)
|
||||
self.event_from_pdu_json(p, outlier=True)
|
||||
for p in content.get("auth_chain", [])
|
||||
]
|
||||
|
||||
@@ -653,7 +650,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
logger.debug("Got response to send_invite: %s", pdu_dict)
|
||||
|
||||
pdu = event_from_pdu_json(pdu_dict)
|
||||
pdu = self.event_from_pdu_json(pdu_dict)
|
||||
|
||||
# Check signatures are correct.
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
@@ -743,7 +740,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
auth_chain = [
|
||||
event_from_pdu_json(e)
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
@@ -791,7 +788,7 @@ class FederationClient(FederationBase):
|
||||
)
|
||||
|
||||
events = [
|
||||
event_from_pdu_json(e)
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content.get("events", [])
|
||||
]
|
||||
|
||||
@@ -808,6 +805,15 @@ class FederationClient(FederationBase):
|
||||
|
||||
defer.returnValue(signed_events)
|
||||
|
||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def forward_third_party_invite(self, destinations, room_id, event_dict):
|
||||
for destination in destinations:
|
||||
|
||||
@@ -12,24 +12,25 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
import simplejson as json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.federation.federation_base import (
|
||||
FederationBase,
|
||||
event_from_pdu_json,
|
||||
)
|
||||
from synapse.federation.units import Edu, Transaction
|
||||
import synapse.metrics
|
||||
from synapse.types import get_domain_from_id
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util import async
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.types import get_domain_from_id
|
||||
import synapse.metrics
|
||||
|
||||
from synapse.api.errors import AuthError, FederationError, SynapseError
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
|
||||
import simplejson as json
|
||||
import logging
|
||||
|
||||
# when processing incoming transactions, we try to handle multiple rooms in
|
||||
# parallel, up to this limit.
|
||||
@@ -171,7 +172,7 @@ class FederationServer(FederationBase):
|
||||
p["age_ts"] = request_time - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
event = event_from_pdu_json(p)
|
||||
event = self.event_from_pdu_json(p)
|
||||
room_id = event.room_id
|
||||
pdus_by_room.setdefault(room_id, []).append(event)
|
||||
|
||||
@@ -345,7 +346,7 @@ class FederationServer(FederationBase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite_request(self, origin, content):
|
||||
pdu = event_from_pdu_json(content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
ret_pdu = yield self.handler.on_invite_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
defer.returnValue((200, {"event": ret_pdu.get_pdu_json(time_now)}))
|
||||
@@ -353,7 +354,7 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
def on_send_join_request(self, origin, content):
|
||||
logger.debug("on_send_join_request: content: %s", content)
|
||||
pdu = event_from_pdu_json(content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures)
|
||||
res_pdus = yield self.handler.on_send_join_request(origin, pdu)
|
||||
time_now = self._clock.time_msec()
|
||||
@@ -373,7 +374,7 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
def on_send_leave_request(self, origin, content):
|
||||
logger.debug("on_send_leave_request: content: %s", content)
|
||||
pdu = event_from_pdu_json(content)
|
||||
pdu = self.event_from_pdu_json(content)
|
||||
logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures)
|
||||
yield self.handler.on_send_leave_request(origin, pdu)
|
||||
defer.returnValue((200, {}))
|
||||
@@ -410,7 +411,7 @@ class FederationServer(FederationBase):
|
||||
"""
|
||||
with (yield self._server_linearizer.queue((origin, room_id))):
|
||||
auth_chain = [
|
||||
event_from_pdu_json(e)
|
||||
self.event_from_pdu_json(e)
|
||||
for e in content["auth_chain"]
|
||||
]
|
||||
|
||||
@@ -585,6 +586,15 @@ class FederationServer(FederationBase):
|
||||
def __str__(self):
|
||||
return "<ReplicationLayer(%s)>" % self.server_name
|
||||
|
||||
def event_from_pdu_json(self, pdu_json, outlier=False):
|
||||
event = FrozenEvent(
|
||||
pdu_json
|
||||
)
|
||||
|
||||
event.internal_metadata.outlier = outlier
|
||||
|
||||
return event
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def exchange_third_party_invite(
|
||||
self,
|
||||
|
||||
@@ -19,7 +19,7 @@ from twisted.internet import defer
|
||||
from .persistence import TransactionActions
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.errors import HttpResponseException, FederationDeniedError
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util import logcontext, PreserveLoggingContext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
@@ -42,8 +42,6 @@ sent_edus_counter = client_metrics.register_counter("sent_edus")
|
||||
|
||||
sent_transactions_counter = client_metrics.register_counter("sent_transactions")
|
||||
|
||||
events_processed_counter = client_metrics.register_counter("events_processed")
|
||||
|
||||
|
||||
class TransactionQueue(object):
|
||||
"""This class makes sure we only have one transaction in flight at
|
||||
@@ -184,22 +182,17 @@ class TransactionQueue(object):
|
||||
if not is_mine and send_on_behalf_of is None:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Get the state from before the event.
|
||||
# We need to make sure that this is the state from before
|
||||
# the event and not from after it.
|
||||
# Otherwise if the last member on a server in a room is
|
||||
# banned then it won't receive the event because it won't
|
||||
# be in the room after the ban.
|
||||
destinations = yield self.state.get_current_hosts_in_room(
|
||||
event.room_id, latest_event_ids=[
|
||||
prev_id for prev_id, _ in event.prev_events
|
||||
],
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to calculate hosts in room")
|
||||
continue
|
||||
|
||||
# Get the state from before the event.
|
||||
# We need to make sure that this is the state from before
|
||||
# the event and not from after it.
|
||||
# Otherwise if the last member on a server in a room is
|
||||
# banned then it won't receive the event because it won't
|
||||
# be in the room after the ban.
|
||||
destinations = yield self.state.get_current_hosts_in_room(
|
||||
event.room_id, latest_event_ids=[
|
||||
prev_id for prev_id, _ in event.prev_events
|
||||
],
|
||||
)
|
||||
destinations = set(destinations)
|
||||
|
||||
if send_on_behalf_of is not None:
|
||||
@@ -212,8 +205,6 @@ class TransactionQueue(object):
|
||||
|
||||
self._send_pdu(event, destinations)
|
||||
|
||||
events_processed_counter.inc_by(len(events))
|
||||
|
||||
yield self.store.update_federation_out_pos(
|
||||
"events", next_token
|
||||
)
|
||||
@@ -259,7 +250,6 @@ class TransactionQueue(object):
|
||||
Args:
|
||||
states (list(UserPresenceState))
|
||||
"""
|
||||
return
|
||||
|
||||
# First we queue up the new presence by user ID, so multiple presence
|
||||
# updates in quick successtion are correctly handled
|
||||
@@ -496,8 +486,6 @@ class TransactionQueue(object):
|
||||
(e.retry_last_ts + e.retry_interval) / 1000.0
|
||||
),
|
||||
)
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
except Exception as e:
|
||||
logger.warn(
|
||||
"TX [%s] Failed to send transaction: %s",
|
||||
|
||||
@@ -212,9 +212,6 @@ class TransportLayerClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if the remote destination
|
||||
is not in our federation whitelist
|
||||
"""
|
||||
valid_memberships = {Membership.JOIN, Membership.LEAVE}
|
||||
if membership not in valid_memberships:
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.urls import FEDERATION_PREFIX as PREFIX
|
||||
from synapse.api.errors import Codes, SynapseError, FederationDeniedError
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.http.server import JsonResource
|
||||
from synapse.http.servlet import (
|
||||
parse_json_object_from_request, parse_integer_from_args, parse_string_from_args,
|
||||
@@ -81,7 +81,6 @@ class Authenticator(object):
|
||||
self.keyring = hs.get_keyring()
|
||||
self.server_name = hs.hostname
|
||||
self.store = hs.get_datastore()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
# A method just so we can pass 'self' as the authenticator to the Servlets
|
||||
@defer.inlineCallbacks
|
||||
@@ -93,12 +92,6 @@ class Authenticator(object):
|
||||
"signatures": {},
|
||||
}
|
||||
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
self.server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(self.server_name)
|
||||
|
||||
if content is not None:
|
||||
json_request["content"] = content
|
||||
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
@@ -24,10 +23,6 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
events_processed_counter = metrics.register_counter("events_processed")
|
||||
|
||||
|
||||
def log_failure(failure):
|
||||
logger.error(
|
||||
@@ -108,8 +103,6 @@ class ApplicationServicesHandler(object):
|
||||
service, event
|
||||
)
|
||||
|
||||
events_processed_counter.inc_by(len(events))
|
||||
|
||||
yield self.store.set_appservice_last_pos(upper_bound)
|
||||
finally:
|
||||
self.is_processing = False
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
@@ -25,7 +25,6 @@ from synapse.module_api import ModuleApi
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
@@ -715,7 +714,7 @@ class AuthHandler(BaseHandler):
|
||||
if not lookupres:
|
||||
defer.returnValue(None)
|
||||
(user_id, password_hash) = lookupres
|
||||
result = yield self.validate_hash(password, password_hash)
|
||||
result = self.validate_hash(password, password_hash)
|
||||
if not result:
|
||||
logger.warn("Failed password login for user %s", user_id)
|
||||
defer.returnValue(None)
|
||||
@@ -843,13 +842,10 @@ class AuthHandler(BaseHandler):
|
||||
password (str): Password to hash.
|
||||
|
||||
Returns:
|
||||
Deferred(str): Hashed password.
|
||||
Hashed password (str).
|
||||
"""
|
||||
def _do_hash():
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_hash))
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
bcrypt.gensalt(self.bcrypt_rounds))
|
||||
|
||||
def validate_hash(self, password, stored_hash):
|
||||
"""Validates that self.hash(password) == stored_hash.
|
||||
@@ -859,17 +855,13 @@ class AuthHandler(BaseHandler):
|
||||
stored_hash (str): Expected hash value.
|
||||
|
||||
Returns:
|
||||
Deferred(bool): Whether self.hash(password) == stored_hash.
|
||||
Whether self.hash(password) == stored_hash (bool).
|
||||
"""
|
||||
|
||||
def _do_validate_hash():
|
||||
if stored_hash:
|
||||
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
|
||||
stored_hash.encode('utf8')) == stored_hash
|
||||
|
||||
if stored_hash:
|
||||
return make_deferred_yieldable(threads.deferToThread(_do_validate_hash))
|
||||
else:
|
||||
return defer.succeed(False)
|
||||
return False
|
||||
|
||||
|
||||
class MacaroonGeneartor(object):
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
from synapse.api import errors
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import FederationDeniedError
|
||||
from synapse.util import stringutils
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
@@ -514,9 +513,6 @@ class DeviceListEduUpdater(object):
|
||||
# This makes it more likely that the device lists will
|
||||
# eventually become consistent.
|
||||
return
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
return
|
||||
except Exception:
|
||||
# TODO: Remember that we are now out of sync and try again
|
||||
# later
|
||||
|
||||
@@ -17,8 +17,7 @@ import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id, UserID
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
|
||||
@@ -34,7 +33,7 @@ class DeviceMessageHandler(object):
|
||||
"""
|
||||
self.store = hs.get_datastore()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.federation = hs.get_federation_sender()
|
||||
|
||||
hs.get_replication_layer().register_edu_handler(
|
||||
@@ -53,12 +52,6 @@ class DeviceMessageHandler(object):
|
||||
message_type = content["type"]
|
||||
message_id = content["message_id"]
|
||||
for user_id, by_device in content["messages"].items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise SynapseError(400, "Not a user here")
|
||||
|
||||
messages_by_device = {
|
||||
device_id: {
|
||||
"content": message_content,
|
||||
@@ -84,8 +77,7 @@ class DeviceMessageHandler(object):
|
||||
local_messages = {}
|
||||
remote_messages = {}
|
||||
for user_id, by_device in messages.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
if self.is_mine_id(user_id):
|
||||
messages_by_device = {
|
||||
device_id: {
|
||||
"content": message_content,
|
||||
|
||||
@@ -34,7 +34,6 @@ class DirectoryHandler(BaseHandler):
|
||||
|
||||
self.state = hs.get_state_handler()
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.federation.register_query_handler(
|
||||
@@ -250,7 +249,8 @@ class DirectoryHandler(BaseHandler):
|
||||
def send_room_alias_update_event(self, requester, user_id, room_id):
|
||||
aliases = yield self.store.get_aliases_for_room(room_id)
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Aliases,
|
||||
@@ -272,7 +272,8 @@ class DirectoryHandler(BaseHandler):
|
||||
if not alias_event or alias_event.content.get("alias", "") != alias_str:
|
||||
return
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.CanonicalAlias,
|
||||
|
||||
@@ -19,10 +19,8 @@ import logging
|
||||
from canonicaljson import encode_canonical_json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import (
|
||||
SynapseError, CodeMessageException, FederationDeniedError,
|
||||
)
|
||||
from synapse.types import get_domain_from_id, UserID
|
||||
from synapse.api.errors import SynapseError, CodeMessageException
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
@@ -34,7 +32,7 @@ class E2eKeysHandler(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.federation = hs.get_replication_layer()
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# doesn't really work as part of the generic query API, because the
|
||||
@@ -72,8 +70,7 @@ class E2eKeysHandler(object):
|
||||
remote_queries = {}
|
||||
|
||||
for user_id, device_ids in device_keys_query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
if self.is_mine_id(user_id):
|
||||
local_query[user_id] = device_ids
|
||||
else:
|
||||
remote_queries[user_id] = device_ids
|
||||
@@ -142,10 +139,6 @@ class E2eKeysHandler(object):
|
||||
failures[destination] = {
|
||||
"status": 503, "message": "Not ready for retry",
|
||||
}
|
||||
except FederationDeniedError as e:
|
||||
failures[destination] = {
|
||||
"status": 403, "message": "Federation Denied",
|
||||
}
|
||||
except Exception as e:
|
||||
# include ConnectionRefused and other errors
|
||||
failures[destination] = {
|
||||
@@ -177,8 +170,7 @@ class E2eKeysHandler(object):
|
||||
|
||||
result_dict = {}
|
||||
for user_id, device_ids in query.items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if not self.is_mine(UserID.from_string(user_id)):
|
||||
if not self.is_mine_id(user_id):
|
||||
logger.warning("Request for keys for non-local user %s",
|
||||
user_id)
|
||||
raise SynapseError(400, "Not a user here")
|
||||
@@ -221,8 +213,7 @@ class E2eKeysHandler(object):
|
||||
remote_queries = {}
|
||||
|
||||
for user_id, device_keys in query.get("one_time_keys", {}).items():
|
||||
# we use UserID.from_string to catch invalid user ids
|
||||
if self.is_mine(UserID.from_string(user_id)):
|
||||
if self.is_mine_id(user_id):
|
||||
for device_id, algorithm in device_keys.items():
|
||||
local_query.append((user_id, device_id, algorithm))
|
||||
else:
|
||||
|
||||
281
synapse/handlers/e2e_room_keys.py
Normal file
281
synapse/handlers/e2e_room_keys.py
Normal file
@@ -0,0 +1,281 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import StoreError, SynapseError, RoomKeysVersionError
|
||||
from synapse.util.async import Linearizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class E2eRoomKeysHandler(object):
|
||||
"""
|
||||
Implements an optional realtime backup mechanism for encrypted E2E megolm room keys.
|
||||
This gives a way for users to store and recover their megolm keys if they lose all
|
||||
their clients. It should also extend easily to future room key mechanisms.
|
||||
The actual payload of the encrypted keys is completely opaque to the handler.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
# Used to lock whenever a client is uploading key data. This prevents collisions
|
||||
# between clients trying to upload the details of a new session, given all
|
||||
# clients belonging to a user will receive and try to upload a new session at
|
||||
# roughly the same time. Also used to lock out uploads when the key is being
|
||||
# changed.
|
||||
self._upload_linearizer = Linearizer("upload_room_keys_lock")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||
room, or a given session.
|
||||
See EndToEndRoomKeyStore.get_e2e_room_keys for full details.
|
||||
|
||||
Returns:
|
||||
A deferred list of dicts giving the session_data and message metadata for
|
||||
these room keys.
|
||||
"""
|
||||
|
||||
# we deliberately take the lock to get keys so that changing the version
|
||||
# works atomically
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
results = yield self.store.get_e2e_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
|
||||
if results['rooms'] == {}:
|
||||
raise SynapseError(404, "No room_keys found")
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_keys(self, user_id, version, room_id=None, session_id=None):
|
||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||
room or a given session.
|
||||
See EndToEndRoomKeyStore.delete_e2e_room_keys for full details.
|
||||
|
||||
Returns:
|
||||
A deferred of the deletion transaction
|
||||
"""
|
||||
|
||||
# lock for consistency with uploading
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
yield self.store.delete_e2e_room_keys(user_id, version, room_id, session_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def upload_room_keys(self, user_id, version, room_keys):
|
||||
"""Bulk upload a list of room keys into a given backup version, asserting
|
||||
that the given version is the current backup version. room_keys are merged
|
||||
into the current backup as described in RoomKeysServlet.on_PUT().
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're setting
|
||||
version(str): the version ID of the backup we're updating
|
||||
room_keys(dict): a nested dict describing the room_keys we're setting:
|
||||
|
||||
{
|
||||
"rooms": {
|
||||
"!abc:matrix.org": {
|
||||
"sessions": {
|
||||
"c0ff33": {
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Raises:
|
||||
SynapseError: with code 404 if there are no versions defined
|
||||
RoomKeysVersionError: if the uploaded version is not the current version
|
||||
"""
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
||||
# XXX: perhaps we should use a finer grained lock here?
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
|
||||
# Check that the version we're trying to upload is the current version
|
||||
try:
|
||||
version_info = yield self._get_version_info_unlocked(user_id)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Version '%s' not found" % (version,))
|
||||
else:
|
||||
raise e
|
||||
|
||||
if version_info['version'] != version:
|
||||
# Check that the version we're trying to upload actually exists
|
||||
try:
|
||||
version_info = yield self._get_version_info_unlocked(user_id, version)
|
||||
# if we get this far, the version must exist
|
||||
raise RoomKeysVersionError(current_version=version_info['version'])
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Version '%s' not found" % (version,))
|
||||
else:
|
||||
raise e
|
||||
|
||||
# go through the room_keys.
|
||||
# XXX: this should/could be done concurrently, given we're in a lock.
|
||||
for room_id, room in room_keys['rooms'].iteritems():
|
||||
for session_id, session in room['sessions'].iteritems():
|
||||
yield self._upload_room_key(
|
||||
user_id, version, room_id, session_id, session
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _upload_room_key(self, user_id, version, room_id, session_id, room_key):
|
||||
"""Upload a given room_key for a given room and session into a given
|
||||
version of the backup. Merges the key with any which might already exist.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're setting
|
||||
version(str): the version ID of the backup we're updating
|
||||
room_id(str): the ID of the room whose keys we're setting
|
||||
session_id(str): the session whose room_key we're setting
|
||||
room_key(dict): the room_key being set
|
||||
"""
|
||||
|
||||
# get the room_key for this particular row
|
||||
current_room_key = None
|
||||
try:
|
||||
current_room_key = yield self.store.get_e2e_room_key(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
pass
|
||||
else:
|
||||
raise e
|
||||
|
||||
if E2eRoomKeysHandler._should_replace_room_key(current_room_key, room_key):
|
||||
yield self.store.set_e2e_room_key(
|
||||
user_id, version, room_id, session_id, room_key
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _should_replace_room_key(current_room_key, room_key):
|
||||
"""
|
||||
Determine whether to replace a given current_room_key (if any)
|
||||
with a newly uploaded room_key backup
|
||||
|
||||
Args:
|
||||
current_room_key (dict): Optional, the current room_key dict if any
|
||||
room_key (dict): The new room_key dict which may or may not be fit to
|
||||
replace the current_room_key
|
||||
|
||||
Returns:
|
||||
True if current_room_key should be replaced by room_key in the backup
|
||||
"""
|
||||
|
||||
if current_room_key:
|
||||
# spelt out with if/elifs rather than nested boolean expressions
|
||||
# purely for legibility.
|
||||
|
||||
if room_key['is_verified'] and not current_room_key['is_verified']:
|
||||
pass
|
||||
elif (
|
||||
room_key['first_message_index'] <
|
||||
current_room_key['first_message_index']
|
||||
):
|
||||
pass
|
||||
elif room_key['forwarded_count'] < current_room_key['forwarded_count']:
|
||||
pass
|
||||
else:
|
||||
return False
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_version(self, user_id, version_info):
|
||||
"""Create a new backup version. This automatically becomes the new
|
||||
backup version for the user's keys; previous backups will no longer be
|
||||
writeable to.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup version we're creating
|
||||
version_info(dict): metadata about the new version being created
|
||||
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
|
||||
}
|
||||
|
||||
Returns:
|
||||
A deferred of a string that gives the new version number.
|
||||
"""
|
||||
|
||||
# TODO: Validate the JSON to make sure it has the right keys.
|
||||
|
||||
# lock everyone out until we've switched version
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
new_version = yield self.store.create_e2e_room_keys_version(
|
||||
user_id, version_info
|
||||
)
|
||||
defer.returnValue(new_version)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_version_info(self, user_id, version=None):
|
||||
"""Get the info about a given version of the user's backup
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose current backup version we're querying
|
||||
version(str): Optional; if None gives the most recent version
|
||||
otherwise a historical one.
|
||||
Raises:
|
||||
StoreError: code 404 if the requested backup version doesn't exist
|
||||
Returns:
|
||||
A deferred of a info dict that gives the info about the new version.
|
||||
|
||||
{
|
||||
"version": "1234",
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
|
||||
}
|
||||
"""
|
||||
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
res = yield self._get_version_info_unlocked(user_id, version)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_version_info_unlocked(self, user_id, version=None):
|
||||
"""Get the info about a given version of the user's backup
|
||||
without obtaining the upload_linearizer lock. For params see get_version_info
|
||||
"""
|
||||
|
||||
results = yield self.store.get_e2e_room_keys_version_info(user_id, version)
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_version(self, user_id, version=None):
|
||||
"""Deletes a given version of the user's e2e_room_keys backup
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose current backup version we're deleting
|
||||
version(str): the version id of the backup being deleted
|
||||
Raises:
|
||||
StoreError: code 404 if this backup version doesn't exist
|
||||
"""
|
||||
|
||||
with (yield self._upload_linearizer.queue(user_id)):
|
||||
yield self.store.delete_e2e_room_keys_version(user_id, version)
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -23,7 +22,6 @@ from ._base import BaseHandler
|
||||
|
||||
from synapse.api.errors import (
|
||||
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
|
||||
FederationDeniedError,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.events.validator import EventValidator
|
||||
@@ -76,7 +74,6 @@ class FederationHandler(BaseHandler):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
self.replication_layer.set_handler(self)
|
||||
|
||||
@@ -785,9 +782,6 @@ class FederationHandler(BaseHandler):
|
||||
except NotRetryingDestination as e:
|
||||
logger.info(e.message)
|
||||
continue
|
||||
except FederationDeniedError as e:
|
||||
logger.info(e)
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to backfill from %s because %s",
|
||||
@@ -810,12 +804,13 @@ class FederationHandler(BaseHandler):
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
resolve = logcontext.preserve_fn(
|
||||
self.state_handler.resolve_state_groups_for_events
|
||||
)
|
||||
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[resolve(room_id, [e]) for e in event_ids],
|
||||
consumeErrors=True,
|
||||
[
|
||||
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
|
||||
room_id, [e]
|
||||
)
|
||||
for e in event_ids
|
||||
], consumeErrors=True,
|
||||
))
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
@@ -1009,7 +1004,8 @@ class FederationHandler(BaseHandler):
|
||||
})
|
||||
|
||||
try:
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
except AuthError as e:
|
||||
@@ -1249,7 +1245,8 @@ class FederationHandler(BaseHandler):
|
||||
"state_key": user_id,
|
||||
})
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
@@ -1831,8 +1828,8 @@ class FederationHandler(BaseHandler):
|
||||
current_state = set(e.event_id for e in auth_events.values())
|
||||
different_auth = event_auth_events - current_state
|
||||
|
||||
yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key,
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
)
|
||||
|
||||
if different_auth and not event.internal_metadata.is_outlier():
|
||||
@@ -1913,8 +1910,8 @@ class FederationHandler(BaseHandler):
|
||||
# 4. Look at rejects and their proofs.
|
||||
# TODO.
|
||||
|
||||
yield self._update_context_for_auth_events(
|
||||
event, context, auth_events, event_key,
|
||||
self._update_context_for_auth_events(
|
||||
context, auth_events, event_key,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -1923,15 +1920,11 @@ class FederationHandler(BaseHandler):
|
||||
logger.warn("Failed auth resolution for %r because %s", event, e)
|
||||
raise e
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_context_for_auth_events(self, event, context, auth_events,
|
||||
def _update_context_for_auth_events(self, context, auth_events,
|
||||
event_key):
|
||||
"""Update the state_ids in an event context after auth event resolution,
|
||||
storing the changes as a new state group.
|
||||
"""Update the state_ids in an event context after auth event resolution
|
||||
|
||||
Args:
|
||||
event (Event): The event we're handling the context for
|
||||
|
||||
context (synapse.events.snapshot.EventContext): event context
|
||||
to be updated
|
||||
|
||||
@@ -1954,13 +1947,7 @@ class FederationHandler(BaseHandler):
|
||||
context.prev_state_ids.update({
|
||||
k: a.event_id for k, a in auth_events.iteritems()
|
||||
})
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=context.prev_group,
|
||||
delta_ids=context.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def construct_auth_difference(self, local_auth, remote_auth):
|
||||
@@ -2130,7 +2117,8 @@ class FederationHandler(BaseHandler):
|
||||
if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
builder=builder
|
||||
)
|
||||
|
||||
@@ -2168,7 +2156,8 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
|
||||
@@ -2218,9 +2207,8 @@ class FederationHandler(BaseHandler):
|
||||
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
EventValidator().validate_new(builder)
|
||||
event, context = yield self.event_creation_handler.create_new_client_event(
|
||||
builder=builder,
|
||||
)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
event, context = yield message_handler._create_new_client_event(builder=builder)
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -383,12 +383,11 @@ class GroupsLocalHandler(object):
|
||||
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
bulk_result = yield self.transport_client.bulk_get_publicised_groups(
|
||||
get_domain_from_id(user_id), [user_id],
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
get_domain_from_id(user_id), user_id
|
||||
)
|
||||
result = bulk_result.get("users", {}).get(user_id)
|
||||
# TODO: Verify attestations
|
||||
defer.returnValue({"groups": result})
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
|
||||
@@ -372,7 +372,6 @@ class InitialSyncHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence():
|
||||
defer.returnValue([])
|
||||
states = yield presence_handler.get_states(
|
||||
[m.user_id for m in room_members],
|
||||
as_event=True,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 - 2018 New Vector Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -28,7 +28,6 @@ from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.visibility import filter_events_for_client
|
||||
from synapse.replication.http.send_event import send_event_to_master
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -48,11 +47,23 @@ class MessageHandler(BaseHandler):
|
||||
self.hs = hs
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||
# This is to stop us from diverging history *too* much.
|
||||
self.limiter = Limiter(max_count=5)
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id, delete_local_events=False):
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
|
||||
if event.room_id != room_id:
|
||||
@@ -61,7 +72,7 @@ class MessageHandler(BaseHandler):
|
||||
depth = event.depth
|
||||
|
||||
with (yield self.pagination_lock.write(room_id)):
|
||||
yield self.store.purge_history(room_id, depth, delete_local_events)
|
||||
yield self.store.delete_old_state(room_id, depth)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_messages(self, requester, room_id=None, pagin_config=None,
|
||||
@@ -171,6 +182,166 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
defer.returnValue(chunk)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
|
||||
prev_event_ids=None):
|
||||
"""
|
||||
Given a dict from a client, create a new event.
|
||||
|
||||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event_dict (dict): An entire event
|
||||
token_id (str)
|
||||
txn_id (str)
|
||||
prev_event_ids (list): The prev event ids to use when creating the event
|
||||
|
||||
Returns:
|
||||
Tuple of created event (FrozenEvent), Context
|
||||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
with (yield self.limiter.queue(builder.room_id)):
|
||||
self.validator.validate_new(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
if "avatar_url" not in content:
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s",
|
||||
target, e
|
||||
)
|
||||
|
||||
if token_id is not None:
|
||||
builder.internal_metadata.token_id = token_id
|
||||
|
||||
if txn_id is not None:
|
||||
builder.internal_metadata.txn_id = txn_id
|
||||
|
||||
event, context = yield self._create_new_client_event(
|
||||
builder=builder,
|
||||
requester=requester,
|
||||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent) the event to send.
|
||||
context (Context) the context of the event.
|
||||
ratelimit (bool): Whether to rate limit this send.
|
||||
is_guest (bool): Whether the sender is a guest.
|
||||
"""
|
||||
if event.type == EventTypes.Member:
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Tried to send member event through non-member codepath"
|
||||
)
|
||||
|
||||
# We check here if we are currently being rate limited, so that we
|
||||
# don't do unnecessary work. We check again just before we actually
|
||||
# send the event.
|
||||
yield self.ratelimit(requester, update=False)
|
||||
|
||||
user = UserID.from_string(event.sender)
|
||||
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
defer.returnValue(prev_state)
|
||||
|
||||
yield self.handle_new_client_event(
|
||||
requester=requester,
|
||||
event=event,
|
||||
context=context,
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
presence = self.hs.get_presence_handler()
|
||||
# We don't want to block sending messages on any presence code. This
|
||||
# matters as sometimes presence code can take a while.
|
||||
preserve_fn(presence.bump_presence_active_time)(user)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deduplicate_state_event(self, event, context):
|
||||
"""
|
||||
Checks whether event is in the latest resolved state in context.
|
||||
|
||||
If so, returns the version of the event in context.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||
if not prev_event:
|
||||
return
|
||||
|
||||
if prev_event and event.user_id == prev_event.user_id:
|
||||
prev_content = encode_canonical_json(prev_event.content)
|
||||
next_content = encode_canonical_json(event.content)
|
||||
if prev_content == next_content:
|
||||
defer.returnValue(prev_event)
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
|
||||
self,
|
||||
requester,
|
||||
event_dict,
|
||||
ratelimit=True,
|
||||
txn_id=None
|
||||
):
|
||||
"""
|
||||
Creates an event, then sends it.
|
||||
|
||||
See self.create_event and self.send_nonmember_event.
|
||||
"""
|
||||
event, context = yield self.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id
|
||||
)
|
||||
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, basestring):
|
||||
spam_error = "Spam is not permitted here"
|
||||
raise SynapseError(
|
||||
403, spam_error, Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_data(self, user_id=None, room_id=None,
|
||||
event_type=None, state_key="", is_guest=False):
|
||||
@@ -283,7 +454,7 @@ class MessageHandler(BaseHandler):
|
||||
# If this is an AS, double check that they are allowed to see the members.
|
||||
# This can either be because the AS user is in the room or becuase there
|
||||
# is a user in the room that the AS is "interested in"
|
||||
if False and requester.app_service and user_id not in users_with_profile:
|
||||
if requester.app_service and user_id not in users_with_profile:
|
||||
for uid in users_with_profile:
|
||||
if requester.app_service.is_interested_in_user(uid):
|
||||
break
|
||||
@@ -299,189 +470,9 @@ class MessageHandler(BaseHandler):
|
||||
for user_id, profile in users_with_profile.iteritems()
|
||||
})
|
||||
|
||||
|
||||
class EventCreationHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
self.server_name = hs.hostname
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.notifier = hs.get_notifier()
|
||||
self.config = hs.config
|
||||
|
||||
self.http_client = hs.get_simple_http_client()
|
||||
|
||||
# This is only used to get at ratelimit function, and maybe_kick_guest_users
|
||||
self.base_handler = BaseHandler(hs)
|
||||
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
|
||||
# We arbitrarily limit concurrent event creation for a room to 5.
|
||||
# This is to stop us from diverging history *too* much.
|
||||
self.limiter = Limiter(max_count=5)
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@measure_func("_create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def create_event(self, requester, event_dict, token_id=None, txn_id=None,
|
||||
prev_event_ids=None):
|
||||
"""
|
||||
Given a dict from a client, create a new event.
|
||||
|
||||
Creates an FrozenEvent object, filling out auth_events, prev_events,
|
||||
etc.
|
||||
|
||||
Adds display names to Join membership events.
|
||||
|
||||
Args:
|
||||
requester
|
||||
event_dict (dict): An entire event
|
||||
token_id (str)
|
||||
txn_id (str)
|
||||
prev_event_ids (list): The prev event ids to use when creating the event
|
||||
|
||||
Returns:
|
||||
Tuple of created event (FrozenEvent), Context
|
||||
"""
|
||||
builder = self.event_builder_factory.new(event_dict)
|
||||
|
||||
with (yield self.limiter.queue(builder.room_id)):
|
||||
self.validator.validate_new(builder)
|
||||
|
||||
if builder.type == EventTypes.Member:
|
||||
membership = builder.content.get("membership", None)
|
||||
target = UserID.from_string(builder.state_key)
|
||||
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
if "displayname" not in content:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
if "avatar_url" not in content:
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
except Exception as e:
|
||||
logger.info(
|
||||
"Failed to get profile information for %r: %s",
|
||||
target, e
|
||||
)
|
||||
|
||||
if token_id is not None:
|
||||
builder.internal_metadata.token_id = token_id
|
||||
|
||||
if txn_id is not None:
|
||||
builder.internal_metadata.txn_id = txn_id
|
||||
|
||||
event, context = yield self.create_new_client_event(
|
||||
builder=builder,
|
||||
requester=requester,
|
||||
prev_event_ids=prev_event_ids,
|
||||
)
|
||||
|
||||
defer.returnValue((event, context))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_nonmember_event(self, requester, event, context, ratelimit=True):
|
||||
"""
|
||||
Persists and notifies local clients and federation of an event.
|
||||
|
||||
Args:
|
||||
event (FrozenEvent) the event to send.
|
||||
context (Context) the context of the event.
|
||||
ratelimit (bool): Whether to rate limit this send.
|
||||
is_guest (bool): Whether the sender is a guest.
|
||||
"""
|
||||
if event.type == EventTypes.Member:
|
||||
raise SynapseError(
|
||||
500,
|
||||
"Tried to send member event through non-member codepath"
|
||||
)
|
||||
|
||||
user = UserID.from_string(event.sender)
|
||||
|
||||
assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
|
||||
|
||||
if event.is_state():
|
||||
prev_state = yield self.deduplicate_state_event(event, context)
|
||||
if prev_state is not None:
|
||||
defer.returnValue(prev_state)
|
||||
|
||||
yield self.handle_new_client_event(
|
||||
requester=requester,
|
||||
event=event,
|
||||
context=context,
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deduplicate_state_event(self, event, context):
|
||||
"""
|
||||
Checks whether event is in the latest resolved state in context.
|
||||
|
||||
If so, returns the version of the event in context.
|
||||
Otherwise, returns None.
|
||||
"""
|
||||
prev_event_id = context.prev_state_ids.get((event.type, event.state_key))
|
||||
prev_event = yield self.store.get_event(prev_event_id, allow_none=True)
|
||||
if not prev_event:
|
||||
return
|
||||
|
||||
if prev_event and event.user_id == prev_event.user_id:
|
||||
prev_content = encode_canonical_json(prev_event.content)
|
||||
next_content = encode_canonical_json(event.content)
|
||||
if prev_content == next_content:
|
||||
defer.returnValue(prev_event)
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_and_send_nonmember_event(
|
||||
self,
|
||||
requester,
|
||||
event_dict,
|
||||
ratelimit=True,
|
||||
txn_id=None
|
||||
):
|
||||
"""
|
||||
Creates an event, then sends it.
|
||||
|
||||
See self.create_event and self.send_nonmember_event.
|
||||
"""
|
||||
event, context = yield self.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id
|
||||
)
|
||||
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, basestring):
|
||||
spam_error = "Spam is not permitted here"
|
||||
raise SynapseError(
|
||||
403, spam_error, Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
ratelimit=ratelimit,
|
||||
)
|
||||
defer.returnValue(event)
|
||||
|
||||
@measure_func("create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
if prev_event_ids:
|
||||
prev_events = yield self.store.add_event_hashes(prev_event_ids)
|
||||
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
|
||||
@@ -518,7 +509,9 @@ class EventCreationHandler(object):
|
||||
builder.prev_events = prev_events
|
||||
builder.depth = depth
|
||||
|
||||
context = yield self.state.compute_event_context(builder)
|
||||
state_handler = self.state_handler
|
||||
|
||||
context = yield state_handler.compute_event_context(builder)
|
||||
if requester:
|
||||
context.app_service = requester.app_service
|
||||
|
||||
@@ -557,20 +550,8 @@ class EventCreationHandler(object):
|
||||
):
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
# If we're a worker we need to hit out to the master.
|
||||
if self.config.worker_app:
|
||||
yield send_event_to_master(
|
||||
self.http_client,
|
||||
host=self.config.worker_replication_host,
|
||||
port=self.config.worker_replication_http_port,
|
||||
requester=requester,
|
||||
event=event,
|
||||
context=context,
|
||||
)
|
||||
return
|
||||
|
||||
if ratelimit:
|
||||
yield self.base_handler.ratelimit(requester)
|
||||
yield self.ratelimit(requester)
|
||||
|
||||
try:
|
||||
yield self.auth.check_from_context(event, context)
|
||||
@@ -586,7 +567,7 @@ class EventCreationHandler(object):
|
||||
logger.exception("Failed to encode content: %r", event.content)
|
||||
raise
|
||||
|
||||
yield self.base_handler.maybe_kick_guest_users(event, context)
|
||||
yield self.maybe_kick_guest_users(event, context)
|
||||
|
||||
if event.type == EventTypes.CanonicalAlias:
|
||||
# Check the alias is acually valid (at this time at least)
|
||||
@@ -702,9 +683,3 @@ class EventCreationHandler(object):
|
||||
)
|
||||
|
||||
preserve_fn(_notify)()
|
||||
|
||||
if event.type == EventTypes.Message:
|
||||
presence = self.hs.get_presence_handler()
|
||||
# We don't want to block sending messages on any presence code. This
|
||||
# matters as sometimes presence code can take a while.
|
||||
preserve_fn(presence.bump_presence_active_time)(requester.user)
|
||||
|
||||
@@ -372,7 +372,6 @@ class PresenceHandler(object):
|
||||
"""We've seen the user do something that indicates they're interacting
|
||||
with the app.
|
||||
"""
|
||||
return
|
||||
user_id = user.to_string()
|
||||
|
||||
bump_active_time_counter.inc()
|
||||
@@ -402,7 +401,6 @@ class PresenceHandler(object):
|
||||
Useful for streams that are not associated with an actual
|
||||
client that is being used by a user.
|
||||
"""
|
||||
affect_presence = False
|
||||
if affect_presence:
|
||||
curr_sync = self.user_to_num_current_syncs.get(user_id, 0)
|
||||
self.user_to_num_current_syncs[user_id] = curr_sync + 1
|
||||
@@ -445,8 +443,6 @@ class PresenceHandler(object):
|
||||
Returns:
|
||||
set(str): A set of user_id strings.
|
||||
"""
|
||||
# presence is disabled on matrix.org, so we return the empty set
|
||||
return set()
|
||||
syncing_user_ids = {
|
||||
user_id for user_id, count in self.user_to_num_current_syncs.items()
|
||||
if count
|
||||
@@ -466,7 +462,6 @@ class PresenceHandler(object):
|
||||
syncing_user_ids(set(str)): The set of user_ids that are
|
||||
currently syncing on that server.
|
||||
"""
|
||||
return
|
||||
|
||||
# Grab the previous list of user_ids that were syncing on that process
|
||||
prev_syncing_user_ids = (
|
||||
|
||||
@@ -25,7 +25,6 @@ from synapse.http.client import CaptchaServerHttpClient
|
||||
from synapse import types
|
||||
from synapse.types import UserID
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from ._base import BaseHandler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -132,7 +131,7 @@ class RegistrationHandler(BaseHandler):
|
||||
yield run_on_reactor()
|
||||
password_hash = None
|
||||
if password:
|
||||
password_hash = yield self.auth_handler().hash(password)
|
||||
password_hash = self.auth_handler().hash(password)
|
||||
|
||||
if localpart:
|
||||
yield self.check_username(localpart, guest_access_token=guest_access_token)
|
||||
@@ -294,7 +293,7 @@ class RegistrationHandler(BaseHandler):
|
||||
"""
|
||||
|
||||
for c in threepidCreds:
|
||||
logger.info("validating threepidcred sid %s on id server %s",
|
||||
logger.info("validating theeepidcred sid %s on id server %s",
|
||||
c['sid'], c['idServer'])
|
||||
try:
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
@@ -308,11 +307,6 @@ class RegistrationHandler(BaseHandler):
|
||||
logger.info("got threepid with medium '%s' and address '%s'",
|
||||
threepid['medium'], threepid['address'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, threepid['medium'], threepid['address']):
|
||||
raise RegistrationError(
|
||||
403, "Third party identifier is not allowed"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bind_emails(self, user_id, threepidCreds):
|
||||
"""Links emails with a user ID and informs an identity server.
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -65,7 +64,6 @@ class RoomCreationHandler(BaseHandler):
|
||||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, requester, config, ratelimit=True):
|
||||
@@ -165,11 +163,13 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
creation_content = config.get("creation_content", {})
|
||||
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
|
||||
yield self._send_events_for_new_room(
|
||||
requester,
|
||||
room_id,
|
||||
msg_handler,
|
||||
room_member_handler,
|
||||
preset_config=preset_config,
|
||||
invite_list=invite_list,
|
||||
@@ -181,7 +181,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
if "name" in config:
|
||||
name = config["name"]
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Name,
|
||||
@@ -194,7 +194,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
|
||||
if "topic" in config:
|
||||
topic = config["topic"]
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Topic,
|
||||
@@ -249,6 +249,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
self,
|
||||
creator, # A Requester object.
|
||||
room_id,
|
||||
msg_handler,
|
||||
room_member_handler,
|
||||
preset_config,
|
||||
invite_list,
|
||||
@@ -271,7 +272,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
@defer.inlineCallbacks
|
||||
def send(etype, content, **kwargs):
|
||||
event = create(etype, content, **kwargs)
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
creator,
|
||||
event,
|
||||
ratelimit=False
|
||||
|
||||
@@ -44,7 +44,7 @@ EMTPY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||
class RoomListHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomListHandler, self).__init__(hs)
|
||||
self.response_cache = ResponseCache(hs, timeout_ms=10 * 60 * 1000)
|
||||
self.response_cache = ResponseCache(hs)
|
||||
self.remote_response_cache = ResponseCache(hs, timeout_ms=30 * 1000)
|
||||
|
||||
def get_local_public_room_list(self, limit=None, since_token=None,
|
||||
@@ -203,8 +203,7 @@ class RoomListHandler(BaseHandler):
|
||||
if limit:
|
||||
step = limit + 1
|
||||
else:
|
||||
# step cannot be zero
|
||||
step = len(rooms_to_scan) if len(rooms_to_scan) != 0 else 1
|
||||
step = len(rooms_to_scan)
|
||||
|
||||
chunk = []
|
||||
for i in xrange(0, len(rooms_to_scan), step):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -28,7 +27,7 @@ from synapse.api.constants import (
|
||||
)
|
||||
from synapse.api.errors import AuthError, SynapseError, Codes
|
||||
from synapse.types import UserID, RoomID
|
||||
from synapse.util.async import Linearizer, Limiter
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.distributor import user_left_room, user_joined_room
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -47,10 +46,8 @@ class RoomMemberHandler(BaseHandler):
|
||||
super(RoomMemberHandler, self).__init__(hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
|
||||
self.member_linearizer = Linearizer(name="member")
|
||||
self.member_limiter = Limiter(3)
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
@@ -69,12 +66,13 @@ class RoomMemberHandler(BaseHandler):
|
||||
):
|
||||
if content is None:
|
||||
content = {}
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
content["membership"] = membership
|
||||
if requester.is_guest:
|
||||
content["kind"] = "guest"
|
||||
|
||||
event, context = yield self.event_creation_hander.create_event(
|
||||
event, context = yield msg_handler.create_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Member,
|
||||
@@ -92,14 +90,12 @@ class RoomMemberHandler(BaseHandler):
|
||||
)
|
||||
|
||||
# Check if this event matches the previous membership event for the user.
|
||||
duplicate = yield self.event_creation_hander.deduplicate_state_event(
|
||||
event, context,
|
||||
)
|
||||
duplicate = yield msg_handler.deduplicate_state_event(event, context)
|
||||
if duplicate is not None:
|
||||
# Discard the new event since this membership change is a no-op.
|
||||
defer.returnValue(duplicate)
|
||||
|
||||
yield self.event_creation_hander.handle_new_client_event(
|
||||
yield msg_handler.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
@@ -162,23 +158,18 @@ class RoomMemberHandler(BaseHandler):
|
||||
):
|
||||
key = (room_id,)
|
||||
|
||||
as_id = object()
|
||||
if requester.app_service:
|
||||
as_id = requester.app_service.id
|
||||
|
||||
with (yield self.member_limiter.queue(as_id)):
|
||||
with (yield self.member_linearizer.queue(key)):
|
||||
result = yield self._update_membership(
|
||||
requester,
|
||||
target,
|
||||
room_id,
|
||||
action,
|
||||
txn_id=txn_id,
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
third_party_signed=third_party_signed,
|
||||
ratelimit=ratelimit,
|
||||
content=content,
|
||||
)
|
||||
with (yield self.member_linearizer.queue(key)):
|
||||
result = yield self._update_membership(
|
||||
requester,
|
||||
target,
|
||||
room_id,
|
||||
action,
|
||||
txn_id=txn_id,
|
||||
remote_room_hosts=remote_room_hosts,
|
||||
third_party_signed=third_party_signed,
|
||||
ratelimit=ratelimit,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@@ -403,9 +394,8 @@ class RoomMemberHandler(BaseHandler):
|
||||
else:
|
||||
requester = synapse.types.create_requester(target_user)
|
||||
|
||||
prev_event = yield self.event_creation_hander.deduplicate_state_event(
|
||||
event, context,
|
||||
)
|
||||
message_handler = self.hs.get_handlers().message_handler
|
||||
prev_event = yield message_handler.deduplicate_state_event(event, context)
|
||||
if prev_event is not None:
|
||||
return
|
||||
|
||||
@@ -422,7 +412,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
yield self.event_creation_hander.handle_new_client_event(
|
||||
yield message_handler.handle_new_client_event(
|
||||
requester,
|
||||
event,
|
||||
context,
|
||||
@@ -654,7 +644,8 @@ class RoomMemberHandler(BaseHandler):
|
||||
)
|
||||
)
|
||||
|
||||
yield self.event_creation_hander.create_and_send_nonmember_event(
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.ThirdPartyInvite,
|
||||
|
||||
@@ -31,7 +31,7 @@ class SetPasswordHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_password(self, user_id, newpassword, requester=None):
|
||||
password_hash = yield self._auth_handler.hash(newpassword)
|
||||
password_hash = self._auth_handler.hash(newpassword)
|
||||
|
||||
except_device_id = requester.device_id if requester else None
|
||||
except_access_token_id = requester.access_token_id if requester else None
|
||||
|
||||
@@ -585,7 +585,7 @@ class SyncHandler(object):
|
||||
since_token is None and
|
||||
sync_config.filter_collection.blocks_all_presence()
|
||||
)
|
||||
if False and not block_all_presence_data:
|
||||
if not block_all_presence_data:
|
||||
yield self._generate_sync_entry_for_presence(
|
||||
sync_result_builder, newly_joined_rooms, newly_joined_users
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ from OpenSSL.SSL import VERIFY_NONE
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException, MatrixCodeMessageException, SynapseError, Codes,
|
||||
)
|
||||
from synapse.util.caches import CACHE_SIZE_FACTOR
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util import logcontext
|
||||
import synapse.metrics
|
||||
@@ -31,7 +30,6 @@ from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
|
||||
from twisted.web.client import (
|
||||
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
|
||||
readBody, PartialDownloadError,
|
||||
HTTPConnectionPool,
|
||||
)
|
||||
from twisted.web.client import FileBodyProducer as TwistedFileBodyProducer
|
||||
from twisted.web.http import PotentialDataLoss
|
||||
@@ -66,23 +64,13 @@ class SimpleHttpClient(object):
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
|
||||
pool = HTTPConnectionPool(reactor)
|
||||
|
||||
# the pusher makes lots of concurrent SSL connections to sygnal, and
|
||||
# tends to do so in batches, so we need to allow the pool to keep lots
|
||||
# of idle connections around.
|
||||
pool.maxPersistentPerHost = max((100 * CACHE_SIZE_FACTOR, 5))
|
||||
pool.cachedConnectionTimeout = 2 * 60
|
||||
|
||||
# The default context factory in Twisted 14.0.0 (which we require) is
|
||||
# BrowserLikePolicyForHTTPS which will do regular cert validation
|
||||
# 'like a browser'
|
||||
self.agent = Agent(
|
||||
reactor,
|
||||
connectTimeout=15,
|
||||
contextFactory=hs.get_http_client_context_factory(),
|
||||
pool=pool,
|
||||
contextFactory=hs.get_http_client_context_factory()
|
||||
)
|
||||
self.user_agent = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@@ -357,7 +357,8 @@ def _get_hosts_for_srv_record(dns_client, host):
|
||||
def eb(res, record_type):
|
||||
if res.check(DNSNameError):
|
||||
return []
|
||||
logger.warn("Error looking up %s for %s: %s", record_type, host, res)
|
||||
logger.warn("Error looking up %s for %s: %s",
|
||||
record_type, host, res, res.value)
|
||||
return res
|
||||
|
||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||
|
||||
@@ -27,7 +27,7 @@ import synapse.metrics
|
||||
from canonicaljson import encode_canonical_json
|
||||
|
||||
from synapse.api.errors import (
|
||||
SynapseError, Codes, HttpResponseException, FederationDeniedError,
|
||||
SynapseError, Codes, HttpResponseException,
|
||||
)
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
@@ -123,22 +123,11 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``HTTPRequestException``: if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
|
||||
(May also fail with plenty of other Exceptions for things like DNS
|
||||
failures, connection failures, SSL failures.)
|
||||
"""
|
||||
if (
|
||||
self.hs.config.federation_domain_whitelist and
|
||||
destination not in self.hs.config.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(destination)
|
||||
|
||||
limiter = yield synapse.util.retryutils.get_retry_limiter(
|
||||
destination,
|
||||
self.clock,
|
||||
@@ -319,9 +308,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
if not json_data_callback:
|
||||
@@ -382,9 +368,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
@@ -439,9 +422,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
@@ -495,9 +475,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
response = yield self._request(
|
||||
@@ -541,9 +518,6 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
|
||||
Fails with ``FederationDeniedError`` if this destination
|
||||
is not on our federation whitelist
|
||||
"""
|
||||
|
||||
encoded_args = {}
|
||||
|
||||
@@ -42,70 +42,36 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
# total number of responses served, split by method/servlet/tag
|
||||
response_count = metrics.register_counter(
|
||||
"response_count",
|
||||
incoming_requests_counter = metrics.register_counter(
|
||||
"requests",
|
||||
labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
# the following are all deprecated aliases for the same metric
|
||||
metrics.name_prefix + x for x in (
|
||||
"_requests",
|
||||
"_response_time:count",
|
||||
"_response_ru_utime:count",
|
||||
"_response_ru_stime:count",
|
||||
"_response_db_txn_count:count",
|
||||
"_response_db_txn_duration:count",
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
outgoing_responses_counter = metrics.register_counter(
|
||||
"responses",
|
||||
labels=["method", "code"],
|
||||
)
|
||||
|
||||
response_timer = metrics.register_counter(
|
||||
"response_time_seconds",
|
||||
labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_time:total",
|
||||
),
|
||||
response_timer = metrics.register_distribution(
|
||||
"response_time",
|
||||
labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
response_ru_utime = metrics.register_counter(
|
||||
"response_ru_utime_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_ru_utime:total",
|
||||
),
|
||||
response_ru_utime = metrics.register_distribution(
|
||||
"response_ru_utime", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
response_ru_stime = metrics.register_counter(
|
||||
"response_ru_stime_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_ru_stime:total",
|
||||
),
|
||||
response_ru_stime = metrics.register_distribution(
|
||||
"response_ru_stime", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
response_db_txn_count = metrics.register_counter(
|
||||
"response_db_txn_count", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_db_txn_count:total",
|
||||
),
|
||||
response_db_txn_count = metrics.register_distribution(
|
||||
"response_db_txn_count", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
# seconds spent waiting for db txns, excluding scheduling time, when processing
|
||||
# this request
|
||||
response_db_txn_duration = metrics.register_counter(
|
||||
"response_db_txn_duration_seconds", labels=["method", "servlet", "tag"],
|
||||
alternative_names=(
|
||||
metrics.name_prefix + "_response_db_txn_duration:total",
|
||||
),
|
||||
response_db_txn_duration = metrics.register_distribution(
|
||||
"response_db_txn_duration", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
# seconds spent waiting for a db connection, when processing this request
|
||||
response_db_sched_duration = metrics.register_counter(
|
||||
"response_db_sched_duration_seconds", labels=["method", "servlet", "tag"]
|
||||
)
|
||||
|
||||
_next_request_id = 0
|
||||
|
||||
@@ -141,10 +107,6 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||
with LoggingContext(request_id) as request_context:
|
||||
with Measure(self.clock, "wrapped_request_handler"):
|
||||
request_metrics = RequestMetrics()
|
||||
# we start the request metrics timer here with an initial stab
|
||||
# at the servlet name. For most requests that name will be
|
||||
# JsonResource (or a subclass), and JsonResource._async_render
|
||||
# will update it once it picks a servlet.
|
||||
request_metrics.start(self.clock, name=self.__class__.__name__)
|
||||
|
||||
request_context.request = request_id
|
||||
@@ -287,23 +249,12 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
if not m:
|
||||
continue
|
||||
|
||||
# We found a match! First update the metrics object to indicate
|
||||
# which servlet is handling the request.
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
|
||||
callback = path_entry.callback
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
request_metrics.name = servlet_classname
|
||||
|
||||
# Now trigger the callback. If it returns a response, we send it
|
||||
# here. If it throws an exception, that is handled by the wrapper
|
||||
# installed by @request_handler.
|
||||
|
||||
kwargs = intern_dict({
|
||||
name: urllib.unquote(value).decode("UTF-8") if value else value
|
||||
for name, value in m.groupdict().items()
|
||||
@@ -314,14 +265,30 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
|
||||
request_metrics.name = servlet_classname
|
||||
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
request_metrics.name = self.__class__.__name__ + ".UnrecognizedRequest"
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def _send_response(self, request, code, response_json_object,
|
||||
response_code_message=None):
|
||||
# could alternatively use request.notifyFinish() and flip a flag when
|
||||
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
||||
# a waste.
|
||||
if request._disconnected:
|
||||
logger.warn(
|
||||
"Not sending response to request %s, already disconnected.",
|
||||
request)
|
||||
return
|
||||
|
||||
outgoing_responses_counter.inc(request.method, str(code))
|
||||
|
||||
# TODO: Only enable CORS for the requests that need it.
|
||||
@@ -355,7 +322,7 @@ class RequestMetrics(object):
|
||||
)
|
||||
return
|
||||
|
||||
response_count.inc(request.method, self.name, tag)
|
||||
incoming_requests_counter.inc(request.method, self.name, tag)
|
||||
|
||||
response_timer.inc_by(
|
||||
clock.time_msec() - self.start, request.method,
|
||||
@@ -374,10 +341,7 @@ class RequestMetrics(object):
|
||||
context.db_txn_count, request.method, self.name, tag
|
||||
)
|
||||
response_db_txn_duration.inc_by(
|
||||
context.db_txn_duration_ms / 1000., request.method, self.name, tag
|
||||
)
|
||||
response_db_sched_duration.inc_by(
|
||||
context.db_sched_duration_ms / 1000., request.method, self.name, tag
|
||||
context.db_txn_duration, request.method, self.name, tag
|
||||
)
|
||||
|
||||
|
||||
@@ -400,15 +364,6 @@ class RootRedirect(resource.Resource):
|
||||
def respond_with_json(request, code, json_object, send_cors=False,
|
||||
response_code_message=None, pretty_print=False,
|
||||
version_string="", canonical_json=True):
|
||||
# could alternatively use request.notifyFinish() and flip a flag when
|
||||
# the Deferred fires, but since the flag is RIGHT THERE it seems like
|
||||
# a waste.
|
||||
if request._disconnected:
|
||||
logger.warn(
|
||||
"Not sending response to request %s, already disconnected.",
|
||||
request)
|
||||
return
|
||||
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||
else:
|
||||
|
||||
@@ -148,13 +148,11 @@ def parse_string_from_args(args, name, default=None, required=False,
|
||||
return default
|
||||
|
||||
|
||||
def parse_json_value_from_request(request, allow_empty_body=False):
|
||||
def parse_json_value_from_request(request):
|
||||
"""Parse a JSON value from the body of a twisted HTTP request.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
allow_empty_body (bool): if True, an empty body will be accepted and
|
||||
turned into None
|
||||
|
||||
Returns:
|
||||
The JSON value.
|
||||
@@ -167,9 +165,6 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||
except Exception:
|
||||
raise SynapseError(400, "Error reading JSON content.")
|
||||
|
||||
if not content_bytes and allow_empty_body:
|
||||
return None
|
||||
|
||||
try:
|
||||
content = simplejson.loads(content_bytes)
|
||||
except Exception as e:
|
||||
@@ -179,24 +174,17 @@ def parse_json_value_from_request(request, allow_empty_body=False):
|
||||
return content
|
||||
|
||||
|
||||
def parse_json_object_from_request(request, allow_empty_body=False):
|
||||
def parse_json_object_from_request(request):
|
||||
"""Parse a JSON object from the body of a twisted HTTP request.
|
||||
|
||||
Args:
|
||||
request: the twisted HTTP request.
|
||||
allow_empty_body (bool): if True, an empty body will be accepted and
|
||||
turned into an empty dict.
|
||||
|
||||
Raises:
|
||||
SynapseError if the request body couldn't be decoded as JSON or
|
||||
if it wasn't a JSON object.
|
||||
"""
|
||||
content = parse_json_value_from_request(
|
||||
request, allow_empty_body=allow_empty_body,
|
||||
)
|
||||
|
||||
if allow_empty_body and content is None:
|
||||
return {}
|
||||
content = parse_json_value_from_request(request)
|
||||
|
||||
if type(content) != dict:
|
||||
message = "Content must be a JSON object."
|
||||
|
||||
@@ -66,15 +66,14 @@ class SynapseRequest(Request):
|
||||
context = LoggingContext.current_context()
|
||||
ru_utime, ru_stime = context.get_resource_usage()
|
||||
db_txn_count = context.db_txn_count
|
||||
db_txn_duration_ms = context.db_txn_duration_ms
|
||||
db_sched_duration_ms = context.db_sched_duration_ms
|
||||
db_txn_duration = context.db_txn_duration
|
||||
except Exception:
|
||||
ru_utime, ru_stime = (0, 0)
|
||||
db_txn_count, db_txn_duration_ms = (0, 0)
|
||||
db_txn_count, db_txn_duration = (0, 0)
|
||||
|
||||
self.site.access_logger.info(
|
||||
"%s - %s - {%s}"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%dms/%d)"
|
||||
" Processed request: %dms (%dms, %dms) (%dms/%d)"
|
||||
" %sB %s \"%s %s %s\" \"%s\"",
|
||||
self.getClientIP(),
|
||||
self.site.site_tag,
|
||||
@@ -82,8 +81,7 @@ class SynapseRequest(Request):
|
||||
int(time.time() * 1000) - self.start_time,
|
||||
int(ru_utime * 1000),
|
||||
int(ru_stime * 1000),
|
||||
db_sched_duration_ms,
|
||||
db_txn_duration_ms,
|
||||
int(db_txn_duration * 1000),
|
||||
int(db_txn_count),
|
||||
self.sentLength,
|
||||
self.code,
|
||||
|
||||
@@ -146,15 +146,10 @@ def runUntilCurrentTimer(func):
|
||||
num_pending += 1
|
||||
|
||||
num_pending += len(reactor.threadCallQueue)
|
||||
|
||||
start = time.time() * 1000
|
||||
ret = func(*args, **kwargs)
|
||||
end = time.time() * 1000
|
||||
|
||||
# record the amount of wallclock time spent running pending calls.
|
||||
# This is a proxy for the actual amount of time between reactor polls,
|
||||
# since about 25% of time is actually spent running things triggered by
|
||||
# I/O events, but that is harder to capture without rewriting half the
|
||||
# reactor.
|
||||
tick_time.inc_by(end - start)
|
||||
pending_calls_metric.inc_by(num_pending)
|
||||
|
||||
|
||||
@@ -15,38 +15,18 @@
|
||||
|
||||
|
||||
from itertools import chain
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def flatten(items):
|
||||
"""Flatten a list of lists
|
||||
|
||||
Args:
|
||||
items: iterable[iterable[X]]
|
||||
|
||||
Returns:
|
||||
list[X]: flattened list
|
||||
"""
|
||||
return list(chain.from_iterable(items))
|
||||
# TODO(paul): I can't believe Python doesn't have one of these
|
||||
def map_concat(func, items):
|
||||
# flatten a list-of-lists
|
||||
return list(chain.from_iterable(map(func, items)))
|
||||
|
||||
|
||||
class BaseMetric(object):
|
||||
"""Base class for metrics which report a single value per label set
|
||||
"""
|
||||
|
||||
def __init__(self, name, labels=[], alternative_names=[]):
|
||||
"""
|
||||
Args:
|
||||
name (str): principal name for this metric
|
||||
labels (list(str)): names of the labels which will be reported
|
||||
for this metric
|
||||
alternative_names (iterable(str)): list of alternative names for
|
||||
this metric. This can be useful to provide a migration path
|
||||
when renaming metrics.
|
||||
"""
|
||||
self._names = [name] + list(alternative_names)
|
||||
def __init__(self, name, labels=[]):
|
||||
self.name = name
|
||||
self.labels = labels # OK not to clone as we never write it
|
||||
|
||||
def dimension(self):
|
||||
@@ -56,7 +36,7 @@ class BaseMetric(object):
|
||||
return not len(self.labels)
|
||||
|
||||
def _render_labelvalue(self, value):
|
||||
# TODO: escape backslashes, quotes and newlines
|
||||
# TODO: some kind of value escape
|
||||
return '"%s"' % (value)
|
||||
|
||||
def _render_key(self, values):
|
||||
@@ -67,60 +47,19 @@ class BaseMetric(object):
|
||||
for k, v in zip(self.labels, values)])
|
||||
)
|
||||
|
||||
def _render_for_labels(self, label_values, value):
|
||||
"""Render this metric for a single set of labels
|
||||
|
||||
Args:
|
||||
label_values (list[str]): values for each of the labels
|
||||
value: value of the metric at with these labels
|
||||
|
||||
Returns:
|
||||
iterable[str]: rendered metric
|
||||
"""
|
||||
rendered_labels = self._render_key(label_values)
|
||||
return (
|
||||
"%s%s %.12g" % (name, rendered_labels, value)
|
||||
for name in self._names
|
||||
)
|
||||
|
||||
def render(self):
|
||||
"""Render this metric
|
||||
|
||||
Each metric is rendered as:
|
||||
|
||||
name{label1="val1",label2="val2"} value
|
||||
|
||||
https://prometheus.io/docs/instrumenting/exposition_formats/#text-format-details
|
||||
|
||||
Returns:
|
||||
iterable[str]: rendered metrics
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class CounterMetric(BaseMetric):
|
||||
"""The simplest kind of metric; one that stores a monotonically-increasing
|
||||
value that counts events or running totals.
|
||||
|
||||
Example use cases for Counters:
|
||||
- Number of requests processed
|
||||
- Number of items that were inserted into a queue
|
||||
- Total amount of data that a system has processed
|
||||
Counters can only go up (and be reset when the process restarts).
|
||||
"""
|
||||
integer that counts events."""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(CounterMetric, self).__init__(*args, **kwargs)
|
||||
|
||||
# dict[list[str]]: value for each set of label values. the keys are the
|
||||
# label values, in the same order as the labels in self.labels.
|
||||
#
|
||||
# (if the metric is a scalar, the (single) key is the empty list).
|
||||
self.counts = {}
|
||||
|
||||
# Scalar metrics are never empty
|
||||
if self.is_scalar():
|
||||
self.counts[()] = 0.
|
||||
self.counts[()] = 0
|
||||
|
||||
def inc_by(self, incr, *values):
|
||||
if len(values) != self.dimension():
|
||||
@@ -138,11 +77,11 @@ class CounterMetric(BaseMetric):
|
||||
def inc(self, *values):
|
||||
self.inc_by(1, *values)
|
||||
|
||||
def render_item(self, k):
|
||||
return ["%s%s %d" % (self.name, self._render_key(k), self.counts[k])]
|
||||
|
||||
def render(self):
|
||||
return flatten(
|
||||
self._render_for_labels(k, self.counts[k])
|
||||
for k in sorted(self.counts.keys())
|
||||
)
|
||||
return map_concat(self.render_item, sorted(self.counts.keys()))
|
||||
|
||||
|
||||
class CallbackMetric(BaseMetric):
|
||||
@@ -156,19 +95,13 @@ class CallbackMetric(BaseMetric):
|
||||
self.callback = callback
|
||||
|
||||
def render(self):
|
||||
try:
|
||||
value = self.callback()
|
||||
except Exception:
|
||||
logger.exception("Failed to render %s", self.name)
|
||||
return ["# FAILED to render " + self.name]
|
||||
value = self.callback()
|
||||
|
||||
if self.is_scalar():
|
||||
return list(self._render_for_labels([], value))
|
||||
return ["%s %.12g" % (self.name, value)]
|
||||
|
||||
return flatten(
|
||||
self._render_for_labels(k, value[k])
|
||||
for k in sorted(value.keys())
|
||||
)
|
||||
return ["%s%s %.12g" % (self.name, self._render_key(k), value[k])
|
||||
for k in sorted(value.keys())]
|
||||
|
||||
|
||||
class DistributionMetric(object):
|
||||
@@ -193,9 +126,7 @@ class DistributionMetric(object):
|
||||
|
||||
|
||||
class CacheMetric(object):
|
||||
__slots__ = (
|
||||
"name", "cache_name", "hits", "misses", "evicted_size", "size_callback",
|
||||
)
|
||||
__slots__ = ("name", "cache_name", "hits", "misses", "size_callback")
|
||||
|
||||
def __init__(self, name, size_callback, cache_name):
|
||||
self.name = name
|
||||
@@ -203,7 +134,6 @@ class CacheMetric(object):
|
||||
|
||||
self.hits = 0
|
||||
self.misses = 0
|
||||
self.evicted_size = 0
|
||||
|
||||
self.size_callback = size_callback
|
||||
|
||||
@@ -213,9 +143,6 @@ class CacheMetric(object):
|
||||
def inc_misses(self):
|
||||
self.misses += 1
|
||||
|
||||
def inc_evictions(self, size=1):
|
||||
self.evicted_size += size
|
||||
|
||||
def render(self):
|
||||
size = self.size_callback()
|
||||
hits = self.hits
|
||||
@@ -225,9 +152,6 @@ class CacheMetric(object):
|
||||
"""%s:hits{name="%s"} %d""" % (self.name, self.cache_name, hits),
|
||||
"""%s:total{name="%s"} %d""" % (self.name, self.cache_name, total),
|
||||
"""%s:size{name="%s"} %d""" % (self.name, self.cache_name, size),
|
||||
"""%s:evicted_size{name="%s"} %d""" % (
|
||||
self.name, self.cache_name, self.evicted_size
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -13,30 +13,21 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from synapse.push import PusherConfigException
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
|
||||
|
||||
import logging
|
||||
import push_rule_evaluator
|
||||
import push_tools
|
||||
import synapse
|
||||
from synapse.push import PusherConfigException
|
||||
|
||||
from synapse.util.logcontext import LoggingContext
|
||||
from synapse.util.metrics import Measure
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
http_push_processed_counter = metrics.register_counter(
|
||||
"http_pushes_processed",
|
||||
)
|
||||
|
||||
http_push_failed_counter = metrics.register_counter(
|
||||
"http_pushes_failed",
|
||||
)
|
||||
|
||||
|
||||
class HttpPusher(object):
|
||||
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
|
||||
@@ -161,16 +152,9 @@ class HttpPusher(object):
|
||||
self.user_id, self.last_stream_ordering, self.max_stream_ordering
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Processing %i unprocessed push actions for %s starting at "
|
||||
"stream_ordering %s",
|
||||
len(unprocessed), self.name, self.last_stream_ordering,
|
||||
)
|
||||
|
||||
for push_action in unprocessed:
|
||||
processed = yield self._process_one(push_action)
|
||||
if processed:
|
||||
http_push_processed_counter.inc()
|
||||
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
|
||||
self.last_stream_ordering = push_action['stream_ordering']
|
||||
yield self.store.update_pusher_last_stream_ordering_and_success(
|
||||
@@ -185,7 +169,6 @@ class HttpPusher(object):
|
||||
self.failing_since
|
||||
)
|
||||
else:
|
||||
http_push_failed_counter.inc()
|
||||
if not self.failing_since:
|
||||
self.failing_since = self.clock.time_msec()
|
||||
yield self.store.update_pusher_failing_since(
|
||||
@@ -333,10 +316,7 @@ class HttpPusher(object):
|
||||
try:
|
||||
resp = yield self.http_client.post_json_get_json(self.url, notification_dict)
|
||||
except Exception:
|
||||
logger.warn(
|
||||
"Failed to push event %s to %s",
|
||||
event.event_id, self.name, exc_info=True,
|
||||
)
|
||||
logger.warn("Failed to push %s ", self.url)
|
||||
defer.returnValue(False)
|
||||
rejected = []
|
||||
if 'rejected' in resp:
|
||||
@@ -345,7 +325,7 @@ class HttpPusher(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _send_badge(self, badge):
|
||||
logger.info("Sending updated badge count %d to %s", badge, self.name)
|
||||
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
|
||||
d = {
|
||||
'notification': {
|
||||
'id': '',
|
||||
@@ -367,10 +347,7 @@ class HttpPusher(object):
|
||||
try:
|
||||
resp = yield self.http_client.post_json_get_json(self.url, d)
|
||||
except Exception:
|
||||
logger.warn(
|
||||
"Failed to send badge count to %s",
|
||||
self.name, exc_info=True,
|
||||
)
|
||||
logger.exception("Failed to push %s ", self.url)
|
||||
defer.returnValue(False)
|
||||
rejected = []
|
||||
if 'rejected' in resp:
|
||||
|
||||
@@ -36,7 +36,7 @@ REQUIREMENTS = {
|
||||
"pydenticon": ["pydenticon"],
|
||||
"ujson": ["ujson"],
|
||||
"blist": ["blist"],
|
||||
"pysaml2>=3.0.0": ["saml2>=3.0.0"],
|
||||
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
|
||||
"pymacaroons-pynacl": ["pymacaroons"],
|
||||
"msgpack-python>=0.3.0": ["msgpack"],
|
||||
"phonenumbers>=8.2.0": ["phonenumbers"],
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import send_event
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
|
||||
|
||||
REPLICATION_PREFIX = "/_synapse/replication"
|
||||
|
||||
|
||||
class ReplicationRestResource(JsonResource):
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs, canonical_json=False)
|
||||
self.register_servlets(hs)
|
||||
|
||||
def register_servlets(self, hs):
|
||||
send_event.register_servlets(hs, self)
|
||||
@@ -1,108 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.events import FrozenEvent
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.types import Requester
|
||||
|
||||
import logging
|
||||
import re
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def send_event_to_master(client, host, port, requester, event, context):
|
||||
"""Send event to be handled on the master
|
||||
|
||||
Args:
|
||||
client (SimpleHttpClient)
|
||||
host (str): host of master
|
||||
port (int): port on master listening for HTTP replication
|
||||
requester (Requester)
|
||||
event (FrozenEvent)
|
||||
context (EventContext)
|
||||
"""
|
||||
uri = "http://%s:%s/_synapse/replication/send_event" % (host, port,)
|
||||
|
||||
payload = {
|
||||
"event": event.get_pdu_json(),
|
||||
"internal_metadata": event.internal_metadata.get_dict(),
|
||||
"rejected_reason": event.rejected_reason,
|
||||
"context": context.serialize(),
|
||||
"requester": requester.serialize(),
|
||||
}
|
||||
|
||||
return client.post_json_get_json(uri, payload)
|
||||
|
||||
|
||||
class ReplicationSendEventRestServlet(RestServlet):
|
||||
"""Handles events newly created on workers, including persisting and
|
||||
notifying.
|
||||
|
||||
The API looks like:
|
||||
|
||||
POST /_synapse/replication/send_event
|
||||
|
||||
{
|
||||
"event": { .. serialized event .. },
|
||||
"internal_metadata": { .. serialized internal_metadata .. },
|
||||
"rejected_reason": .., // The event.rejected_reason field
|
||||
"context": { .. serialized event context .. },
|
||||
"requester": { .. serialized requester .. },
|
||||
}
|
||||
"""
|
||||
PATTERNS = [re.compile("^/_synapse/replication/send_event$")]
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ReplicationSendEventRestServlet, self).__init__()
|
||||
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
with Measure(self.clock, "repl_send_event_parse"):
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict = content["event"]
|
||||
internal_metadata = content["internal_metadata"]
|
||||
rejected_reason = content["rejected_reason"]
|
||||
event = FrozenEvent(event_dict, internal_metadata, rejected_reason)
|
||||
|
||||
requester = Requester.deserialize(self.store, content["requester"])
|
||||
context = EventContext.deserialize(self.store, content["context"])
|
||||
|
||||
if requester.user:
|
||||
request.authenticated_entity = requester.user.to_string()
|
||||
|
||||
logger.info(
|
||||
"Got event to send with ID: %s into room: %s",
|
||||
event.event_id, event.room_id,
|
||||
)
|
||||
|
||||
yield self.event_creation_handler.handle_new_client_event(
|
||||
requester, event, context,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
ReplicationSendEventRestServlet(hs).register(http_server)
|
||||
@@ -42,8 +42,6 @@ class SlavedClientIpStore(BaseSlavedStore):
|
||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||
return
|
||||
|
||||
self.client_ip_last_seen.prefill(key, now)
|
||||
|
||||
self.hs.get_tcp_replication().send_user_ip(
|
||||
user_id, access_token, ip, user_agent, device_id, now
|
||||
)
|
||||
|
||||
@@ -19,9 +19,8 @@ from synapse.storage import DataStore
|
||||
from synapse.storage.event_federation import EventFederationStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsStore
|
||||
from synapse.storage.roommember import RoomMemberStore
|
||||
from synapse.storage.state import StateGroupWorkerStore
|
||||
from synapse.storage.state import StateGroupReadStore
|
||||
from synapse.storage.stream import StreamStore
|
||||
from synapse.storage.signatures import SignatureStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
@@ -38,7 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
# the method descriptor on the DataStore and chuck them into our class.
|
||||
|
||||
|
||||
class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
|
||||
class SlavedEventStore(StateGroupReadStore, BaseSlavedStore):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedEventStore, self).__init__(db_conn, hs)
|
||||
@@ -171,25 +170,6 @@ class SlavedEventStore(StateGroupWorkerStore, BaseSlavedStore):
|
||||
get_federation_out_pos = DataStore.get_federation_out_pos.__func__
|
||||
update_federation_out_pos = DataStore.update_federation_out_pos.__func__
|
||||
|
||||
get_latest_event_ids_and_hashes_in_room = (
|
||||
DataStore.get_latest_event_ids_and_hashes_in_room.__func__
|
||||
)
|
||||
_get_latest_event_ids_and_hashes_in_room = (
|
||||
DataStore._get_latest_event_ids_and_hashes_in_room.__func__
|
||||
)
|
||||
_get_event_reference_hashes_txn = (
|
||||
DataStore._get_event_reference_hashes_txn.__func__
|
||||
)
|
||||
add_event_hashes = (
|
||||
DataStore.add_event_hashes.__func__
|
||||
)
|
||||
get_event_reference_hashes = (
|
||||
SignatureStore.__dict__["get_event_reference_hashes"]
|
||||
)
|
||||
get_event_reference_hash = (
|
||||
SignatureStore.__dict__["get_event_reference_hash"]
|
||||
)
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedEventStore, self).stream_positions()
|
||||
result["events"] = self._stream_id_gen.get_current_token()
|
||||
|
||||
@@ -517,28 +517,25 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
self.send_error("Wrong remote")
|
||||
|
||||
def on_RDATA(self, cmd):
|
||||
stream_name = cmd.stream_name
|
||||
inbound_rdata_count.inc(stream_name)
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
|
||||
row = STREAMS_MAP[cmd.stream_name].ROW_TYPE(*cmd.row)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] Failed to parse RDATA: %r %r",
|
||||
self.id(), stream_name, cmd.row
|
||||
self.id(), cmd.stream_name, cmd.row
|
||||
)
|
||||
raise
|
||||
|
||||
if cmd.token is None:
|
||||
# I.e. this is part of a batch of updates for this stream. Batch
|
||||
# until we get an update for the stream with a non None token
|
||||
self.pending_batches.setdefault(stream_name, []).append(row)
|
||||
self.pending_batches.setdefault(cmd.stream_name, []).append(row)
|
||||
else:
|
||||
# Check if this is the last of a batch of updates
|
||||
rows = self.pending_batches.pop(stream_name, [])
|
||||
rows = self.pending_batches.pop(cmd.stream_name, [])
|
||||
rows.append(row)
|
||||
|
||||
self.handler.on_rdata(stream_name, cmd.token, rows)
|
||||
self.handler.on_rdata(cmd.stream_name, cmd.token, rows)
|
||||
|
||||
def on_POSITION(self, cmd):
|
||||
self.handler.on_position(cmd.stream_name, cmd.token)
|
||||
@@ -647,9 +644,3 @@ metrics.register_callback(
|
||||
},
|
||||
labels=["command", "name", "conn_id"],
|
||||
)
|
||||
|
||||
# number of updates received for each RDATA stream
|
||||
inbound_rdata_count = metrics.register_counter(
|
||||
"inbound_rdata_count",
|
||||
labels=["stream_name"],
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
MAX_EVENTS_BEHIND = 500000
|
||||
MAX_EVENTS_BEHIND = 10000
|
||||
|
||||
|
||||
EventStreamRow = namedtuple("EventStreamRow", (
|
||||
|
||||
@@ -41,6 +41,7 @@ from synapse.rest.client.v2_alpha import (
|
||||
auth,
|
||||
receipts,
|
||||
read_marker,
|
||||
room_keys,
|
||||
keys,
|
||||
tokenrefresh,
|
||||
tags,
|
||||
@@ -92,6 +93,7 @@ class ClientRestResource(JsonResource):
|
||||
auth.register_servlets(hs, client_resource)
|
||||
receipts.register_servlets(hs, client_resource)
|
||||
read_marker.register_servlets(hs, client_resource)
|
||||
room_keys.register_servlets(hs, client_resource)
|
||||
keys.register_servlets(hs, client_resource)
|
||||
tokenrefresh.register_servlets(hs, client_resource)
|
||||
tags.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -129,14 +128,7 @@ class PurgeHistoryRestServlet(ClientV1RestServlet):
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
body = parse_json_object_from_request(request, allow_empty_body=True)
|
||||
|
||||
delete_local_events = bool(body.get("delete_local_events", False))
|
||||
|
||||
yield self.handlers.message_handler.purge_history(
|
||||
room_id, event_id,
|
||||
delete_local_events=delete_local_events,
|
||||
)
|
||||
yield self.handlers.message_handler.purge_history(room_id, event_id)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@@ -179,7 +171,6 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||
self.store = hs.get_datastore()
|
||||
self.handlers = hs.get_handlers()
|
||||
self.state = hs.get_state_handler()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, room_id):
|
||||
@@ -212,6 +203,18 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||
)
|
||||
new_room_id = info["room_id"]
|
||||
|
||||
msg_handler = self.handlers.message_handler
|
||||
yield msg_handler.create_and_send_nonmember_event(
|
||||
room_creator_requester,
|
||||
{
|
||||
"type": "m.room.message",
|
||||
"content": {"body": message, "msgtype": "m.text"},
|
||||
"room_id": new_room_id,
|
||||
"sender": new_room_user_id,
|
||||
},
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
logger.info("Shutting down room %r", room_id)
|
||||
@@ -249,17 +252,6 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||
|
||||
kicked_users.append(user_id)
|
||||
|
||||
yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
room_creator_requester,
|
||||
{
|
||||
"type": "m.room.message",
|
||||
"content": {"body": message, "msgtype": "m.text"},
|
||||
"room_id": new_room_id,
|
||||
"sender": new_room_user_id,
|
||||
},
|
||||
ratelimit=False,
|
||||
)
|
||||
|
||||
aliases_for_room = yield self.store.get_aliases_for_room(room_id)
|
||||
|
||||
yield self.store.update_aliases_for_room(
|
||||
@@ -297,27 +289,6 @@ class QuarantineMediaInRoom(ClientV1RestServlet):
|
||||
defer.returnValue((200, {"num_quarantined": num_quarantined}))
|
||||
|
||||
|
||||
class ListMediaInRoom(ClientV1RestServlet):
|
||||
"""Lists all of the media in a given room.
|
||||
"""
|
||||
PATTERNS = client_path_patterns("/admin/room/(?P<room_id>[^/]+)/media")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ListMediaInRoom, self).__init__(hs)
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
local_mxcs, remote_mxcs = yield self.store.get_media_mxcs_in_room(room_id)
|
||||
|
||||
defer.returnValue((200, {"local": local_mxcs, "remote": remote_mxcs}))
|
||||
|
||||
|
||||
class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||
"""Post request to allow an administrator reset password for a user.
|
||||
This needs user to have administrator access in Synapse.
|
||||
@@ -516,4 +487,3 @@ def register_servlets(hs, http_server):
|
||||
SearchUsersRestServlet(hs).register(http_server)
|
||||
ShutdownRoomRestServlet(hs).register(http_server)
|
||||
QuarantineMediaInRoom(hs).register(http_server)
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
|
||||
@@ -191,25 +191,19 @@ class LoginRestServlet(ClientV1RestServlet):
|
||||
|
||||
# convert threepid identifiers to user IDs
|
||||
if identifier["type"] == "m.id.thirdparty":
|
||||
address = identifier.get('address')
|
||||
medium = identifier.get('medium')
|
||||
|
||||
if medium is None or address is None:
|
||||
if 'medium' not in identifier or 'address' not in identifier:
|
||||
raise SynapseError(400, "Invalid thirdparty identifier")
|
||||
|
||||
if medium == 'email':
|
||||
address = identifier['address']
|
||||
if identifier['medium'] == 'email':
|
||||
# For emails, transform the address to lowercase.
|
||||
# We store all email addreses as lowercase in the DB.
|
||||
# (See add_threepid in synapse/handlers/auth.py)
|
||||
address = address.lower()
|
||||
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
medium, address,
|
||||
identifier['medium'], address
|
||||
)
|
||||
if not user_id:
|
||||
logger.warn(
|
||||
"unknown 3pid identifier medium %s, address %r",
|
||||
medium, address,
|
||||
)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
identifier = {
|
||||
|
||||
@@ -81,7 +81,7 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||
except Exception:
|
||||
raise SynapseError(400, "Unable to parse state")
|
||||
|
||||
# yield self.presence_handler.set_state(user, state)
|
||||
yield self.presence_handler.set_state(user, state)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
@@ -70,15 +70,10 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def on_GET(self, request):
|
||||
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
|
||||
flows = []
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([
|
||||
return (
|
||||
200,
|
||||
{"flows": [
|
||||
{
|
||||
"type": LoginType.RECAPTCHA,
|
||||
"stages": [
|
||||
@@ -87,34 +82,27 @@ class RegisterRestServlet(ClientV1RestServlet):
|
||||
LoginType.PASSWORD
|
||||
]
|
||||
},
|
||||
])
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([
|
||||
{
|
||||
"type": LoginType.RECAPTCHA,
|
||||
"stages": [LoginType.RECAPTCHA, LoginType.PASSWORD]
|
||||
}
|
||||
])
|
||||
]}
|
||||
)
|
||||
else:
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if require_email or not require_msisdn:
|
||||
flows.extend([
|
||||
return (
|
||||
200,
|
||||
{"flows": [
|
||||
{
|
||||
"type": LoginType.EMAIL_IDENTITY,
|
||||
"stages": [
|
||||
LoginType.EMAIL_IDENTITY, LoginType.PASSWORD
|
||||
]
|
||||
}
|
||||
])
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([
|
||||
},
|
||||
{
|
||||
"type": LoginType.PASSWORD
|
||||
}
|
||||
])
|
||||
return (200, {"flows": flows})
|
||||
]}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -83,7 +82,6 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomStateEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
|
||||
def register(self, http_server):
|
||||
# /room/$roomid/state/$eventtype
|
||||
@@ -164,16 +162,15 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
|
||||
content=content,
|
||||
)
|
||||
else:
|
||||
event, context = yield self.event_creation_hander.create_event(
|
||||
msg_handler = self.handlers.message_handler
|
||||
event, context = yield msg_handler.create_event(
|
||||
requester,
|
||||
event_dict,
|
||||
token_id=requester.access_token_id,
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
yield self.event_creation_hander.send_nonmember_event(
|
||||
requester, event, context,
|
||||
)
|
||||
yield msg_handler.send_nonmember_event(requester, event, context)
|
||||
|
||||
ret = {}
|
||||
if event:
|
||||
@@ -186,7 +183,7 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomSendEventRestServlet, self).__init__(hs)
|
||||
self.event_creation_hander = hs.get_event_creation_handler()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
def register(self, http_server):
|
||||
# /rooms/$roomid/send/$event_type[/$txn_id]
|
||||
@@ -198,19 +195,15 @@ class RoomSendEventRestServlet(ClientV1RestServlet):
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event_dict = {
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
}
|
||||
|
||||
if 'ts' in request.args and requester.app_service:
|
||||
event_dict['origin_server_ts'] = parse_integer(request, "ts", 0)
|
||||
|
||||
event = yield self.event_creation_hander.create_and_send_nonmember_event(
|
||||
msg_handler = self.handlers.message_handler
|
||||
event = yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
event_dict,
|
||||
{
|
||||
"type": event_type,
|
||||
"content": content,
|
||||
"room_id": room_id,
|
||||
"sender": requester.user.to_string(),
|
||||
},
|
||||
txn_id=txn_id,
|
||||
)
|
||||
|
||||
@@ -494,35 +487,13 @@ class RoomInitialSyncRestServlet(ClientV1RestServlet):
|
||||
defer.returnValue((200, content))
|
||||
|
||||
|
||||
class RoomEventServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)/event/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomEventServlet, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, event_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
event = yield self.event_handler.get_event(requester.user, event_id)
|
||||
|
||||
time_now = self.clock.time_msec()
|
||||
if event:
|
||||
defer.returnValue((200, serialize_event(event, time_now)))
|
||||
else:
|
||||
defer.returnValue((404, "Event not found."))
|
||||
|
||||
|
||||
class RoomEventContextServlet(ClientV1RestServlet):
|
||||
class RoomEventContext(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns(
|
||||
"/rooms/(?P<room_id>[^/]*)/context/(?P<event_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomEventContextServlet, self).__init__(hs)
|
||||
super(RoomEventContext, self).__init__(hs)
|
||||
self.clock = hs.get_clock()
|
||||
self.handlers = hs.get_handlers()
|
||||
|
||||
@@ -672,7 +643,6 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||
def __init__(self, hs):
|
||||
super(RoomRedactEventRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
|
||||
def register(self, http_server):
|
||||
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)")
|
||||
@@ -683,7 +653,8 @@ class RoomRedactEventRestServlet(ClientV1RestServlet):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
event = yield self.event_creation_handler.create_and_send_nonmember_event(
|
||||
msg_handler = self.handlers.message_handler
|
||||
event = yield msg_handler.create_and_send_nonmember_event(
|
||||
requester,
|
||||
{
|
||||
"type": EventTypes.Redaction,
|
||||
@@ -832,5 +803,4 @@ def register_servlets(hs, http_server):
|
||||
RoomTypingRestServlet(hs).register(http_server)
|
||||
SearchRestServlet(hs).register(http_server)
|
||||
JoinedRoomsRestServlet(hs).register(http_server)
|
||||
RoomEventServlet(hs).register(http_server)
|
||||
RoomEventContextServlet(hs).register(http_server)
|
||||
RoomEventContext(hs).register(http_server)
|
||||
|
||||
@@ -26,7 +26,6 @@ from synapse.http.servlet import (
|
||||
)
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -48,11 +47,6 @@ class EmailPasswordRequestTokenRestServlet(RestServlet):
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
@@ -84,11 +78,6 @@ class MsisdnPasswordRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
@@ -228,11 +217,6 @@ class EmailThreepidRequestTokenRestServlet(RestServlet):
|
||||
if absent:
|
||||
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
@@ -271,11 +255,6 @@ class MsisdnThreepidRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.datastore.get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
|
||||
@@ -26,7 +26,6 @@ from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||
)
|
||||
from synapse.util.msisdn import phone_number_to_msisdn
|
||||
from synapse.util.threepids import check_3pid_allowed
|
||||
|
||||
from ._base import client_v2_patterns, interactive_auth_handler
|
||||
|
||||
@@ -71,11 +70,6 @@ class EmailRegisterRequestTokenRestServlet(RestServlet):
|
||||
'id_server', 'client_secret', 'email', 'send_attempt'
|
||||
])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "email", body['email']):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'email', body['email']
|
||||
)
|
||||
@@ -111,11 +105,6 @@ class MsisdnRegisterRequestTokenRestServlet(RestServlet):
|
||||
|
||||
msisdn = phone_number_to_msisdn(body['country'], body['phone_number'])
|
||||
|
||||
if not check_3pid_allowed(self.hs, "msisdn", msisdn):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed", Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
existingUid = yield self.hs.get_datastore().get_user_id_by_threepid(
|
||||
'msisdn', msisdn
|
||||
)
|
||||
@@ -316,67 +305,31 @@ class RegisterRestServlet(RestServlet):
|
||||
if 'x_show_msisdn' in body and body['x_show_msisdn']:
|
||||
show_msisdn = True
|
||||
|
||||
# FIXME: need a better error than "no auth flow found" for scenarios
|
||||
# where we required 3PID for registration but the user didn't give one
|
||||
require_email = 'email' in self.hs.config.registrations_require_3pid
|
||||
require_msisdn = 'msisdn' in self.hs.config.registrations_require_3pid
|
||||
|
||||
flows = []
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([[LoginType.RECAPTCHA]])
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]])
|
||||
|
||||
flows = [
|
||||
[LoginType.RECAPTCHA],
|
||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
]
|
||||
if show_msisdn:
|
||||
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||
if not require_email:
|
||||
flows.extend([[LoginType.MSISDN, LoginType.RECAPTCHA]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.MSISDN, LoginType.RECAPTCHA],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA],
|
||||
])
|
||||
else:
|
||||
# only support 3PIDless registration if no 3PIDs are required
|
||||
if not require_email and not require_msisdn:
|
||||
flows.extend([[LoginType.DUMMY]])
|
||||
# only support the email-only flow if we don't require MSISDN 3PIDs
|
||||
if not require_msisdn:
|
||||
flows.extend([[LoginType.EMAIL_IDENTITY]])
|
||||
|
||||
flows = [
|
||||
[LoginType.DUMMY],
|
||||
[LoginType.EMAIL_IDENTITY],
|
||||
]
|
||||
if show_msisdn:
|
||||
# only support the MSISDN-only flow if we don't require email 3PIDs
|
||||
if not require_email or require_msisdn:
|
||||
flows.extend([[LoginType.MSISDN]])
|
||||
# always let users provide both MSISDN & email
|
||||
flows.extend([
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY]
|
||||
[LoginType.MSISDN],
|
||||
[LoginType.MSISDN, LoginType.EMAIL_IDENTITY],
|
||||
])
|
||||
|
||||
auth_result, params, session_id = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
# Check that we're not trying to register a denied 3pid.
|
||||
#
|
||||
# the user-facing checks will probably already have happened in
|
||||
# /register/email/requestToken when we requested a 3pid, but that's not
|
||||
# guaranteed.
|
||||
|
||||
if auth_result:
|
||||
for login_type in [LoginType.EMAIL_IDENTITY, LoginType.MSISDN]:
|
||||
if login_type in auth_result:
|
||||
medium = auth_result[login_type]['medium']
|
||||
address = auth_result[login_type]['address']
|
||||
|
||||
if not check_3pid_allowed(self.hs, medium, address):
|
||||
raise SynapseError(
|
||||
403, "Third party identifier is not allowed",
|
||||
Codes.THREEPID_DENIED,
|
||||
)
|
||||
|
||||
if registered_user_id is not None:
|
||||
logger.info(
|
||||
"Already registered user ID %r for this session",
|
||||
|
||||
352
synapse/rest/client/v2_alpha/room_keys.py
Normal file
352
synapse/rest/client/v2_alpha/room_keys.py
Normal file
@@ -0,0 +1,352 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, parse_string
|
||||
)
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RoomKeysServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/room_keys/keys(/(?P<room_id>[^/]+))?(/(?P<session_id>[^/]+))?$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RoomKeysServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, room_id, session_id):
|
||||
"""
|
||||
Uploads one or more encrypted E2E room keys for backup purposes.
|
||||
room_id: the ID of the room the keys are for (optional)
|
||||
session_id: the ID for the E2E room keys for the room (optional)
|
||||
version: the version of the user's backup which this data is for.
|
||||
the version must already have been created via the /room_keys/version API.
|
||||
|
||||
Each session has:
|
||||
* first_message_index: a numeric index indicating the oldest message
|
||||
encrypted by this session.
|
||||
* forwarded_count: how many times the uploading client claims this key
|
||||
has been shared (forwarded)
|
||||
* is_verified: whether the client that uploaded the keys claims they
|
||||
were sent by a device which they've verified
|
||||
* session_data: base64-encrypted data describing the session.
|
||||
|
||||
Returns 200 OK on success with body {}
|
||||
Returns 403 Forbidden if the version in question is not the most recently
|
||||
created version (i.e. if this is an old client trying to write to a stale backup)
|
||||
Returns 404 Not Found if the version in question doesn't exist
|
||||
|
||||
The API is designed to be otherwise agnostic to the room_key encryption
|
||||
algorithm being used. Sessions are merged with existing ones in the
|
||||
backup using the heuristics:
|
||||
* is_verified sessions always win over unverified sessions
|
||||
* older first_message_index always win over newer sessions
|
||||
* lower forwarded_count always wins over higher forwarded_count
|
||||
|
||||
We trust the clients not to lie and corrupt their own backups.
|
||||
It also means that if your access_token is stolen, the attacker could
|
||||
delete your backup.
|
||||
|
||||
POST /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
|
||||
Or...
|
||||
|
||||
POST /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"sessions": {
|
||||
"c0ff33": {
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Or...
|
||||
|
||||
POST /room_keys/keys?version=1 HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"rooms": {
|
||||
"!abc:matrix.org": {
|
||||
"sessions": {
|
||||
"c0ff33": {
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
body = parse_json_object_from_request(request)
|
||||
version = parse_string(request, "version")
|
||||
|
||||
if session_id:
|
||||
body = {
|
||||
"sessions": {
|
||||
session_id: body
|
||||
}
|
||||
}
|
||||
|
||||
if room_id:
|
||||
body = {
|
||||
"rooms": {
|
||||
room_id: body
|
||||
}
|
||||
}
|
||||
|
||||
yield self.e2e_room_keys_handler.upload_room_keys(
|
||||
user_id, version, body
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id, session_id):
|
||||
"""
|
||||
Retrieves one or more encrypted E2E room keys for backup purposes.
|
||||
Symmetric with the PUT version of the API.
|
||||
|
||||
room_id: the ID of the room to retrieve the keys for (optional)
|
||||
session_id: the ID for the E2E room keys to retrieve the keys for (optional)
|
||||
version: the version of the user's backup which this data is for.
|
||||
the version must already have been created via the /change_secret API.
|
||||
|
||||
Returns as follows:
|
||||
|
||||
GET /room_keys/keys/!abc:matrix.org/c0ff33?version=1 HTTP/1.1
|
||||
{
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
|
||||
Or...
|
||||
|
||||
GET /room_keys/keys/!abc:matrix.org?version=1 HTTP/1.1
|
||||
{
|
||||
"sessions": {
|
||||
"c0ff33": {
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Or...
|
||||
|
||||
GET /room_keys/keys?version=1 HTTP/1.1
|
||||
{
|
||||
"rooms": {
|
||||
"!abc:matrix.org": {
|
||||
"sessions": {
|
||||
"c0ff33": {
|
||||
"first_message_index": 1,
|
||||
"forwarded_count": 1,
|
||||
"is_verified": false,
|
||||
"session_data": "SSBBTSBBIEZJU0gK"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version")
|
||||
|
||||
room_keys = yield self.e2e_room_keys_handler.get_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
|
||||
if session_id:
|
||||
room_keys = room_keys['rooms'][room_id]['sessions'][session_id]
|
||||
elif room_id:
|
||||
room_keys = room_keys['rooms'][room_id]
|
||||
|
||||
defer.returnValue((200, room_keys))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, room_id, session_id):
|
||||
"""
|
||||
Deletes one or more encrypted E2E room keys for a user for backup purposes.
|
||||
|
||||
DELETE /room_keys/keys/!abc:matrix.org/c0ff33?version=1
|
||||
HTTP/1.1 200 OK
|
||||
{}
|
||||
|
||||
room_id: the ID of the room whose keys to delete (optional)
|
||||
session_id: the ID for the E2E session to delete (optional)
|
||||
version: the version of the user's backup which this data is for.
|
||||
the version must already have been created via the /change_secret API.
|
||||
"""
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
version = parse_string(request, "version")
|
||||
|
||||
yield self.e2e_room_keys_handler.delete_room_keys(
|
||||
user_id, version, room_id, session_id
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class RoomKeysVersionServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/room_keys/version(/(?P<version>[^/]+))?$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
"""
|
||||
Args:
|
||||
hs (synapse.server.HomeServer): server
|
||||
"""
|
||||
super(RoomKeysVersionServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.e2e_room_keys_handler = hs.get_e2e_room_keys_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, version):
|
||||
"""
|
||||
Create a new backup version for this user's room_keys with the given
|
||||
info. The version is allocated by the server and returned to the user
|
||||
in the response. This API is intended to be used whenever the user
|
||||
changes the encryption key for their backups, ensuring that backups
|
||||
encrypted with different keys don't collide.
|
||||
|
||||
It takes out an exclusive lock on this user's room_key backups, to ensure
|
||||
clients only upload to the current backup.
|
||||
|
||||
The algorithm passed in the version info is a reverse-DNS namespaced
|
||||
identifier to describe the format of the encrypted backupped keys.
|
||||
|
||||
The auth_data is { user_id: "user_id", nonce: <random string> }
|
||||
encrypted using the algorithm and current encryption key described above.
|
||||
|
||||
POST /room_keys/version
|
||||
Content-Type: application/json
|
||||
{
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
|
||||
}
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
{
|
||||
"version": 12345
|
||||
}
|
||||
"""
|
||||
|
||||
if version:
|
||||
raise SynapseError(405, "Cannot POST to a specific version")
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
info = parse_json_object_from_request(request)
|
||||
|
||||
new_version = yield self.e2e_room_keys_handler.create_version(
|
||||
user_id, info
|
||||
)
|
||||
defer.returnValue((200, {"version": new_version}))
|
||||
|
||||
# we deliberately don't have a PUT /version, as these things really should
|
||||
# be immutable to avoid people footgunning
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, version):
|
||||
"""
|
||||
Retrieve the version information about a given version of the user's
|
||||
room_keys backup. If the version part is missing, returns info about the
|
||||
most current backup version (if any)
|
||||
|
||||
It takes out an exclusive lock on this user's room_key backups, to ensure
|
||||
clients only upload to the current backup.
|
||||
|
||||
Returns 404 if the given version does not exist.
|
||||
|
||||
GET /room_keys/version/12345 HTTP/1.1
|
||||
{
|
||||
"version": "12345",
|
||||
"algorithm": "m.megolm_backup.v1",
|
||||
"auth_data": "dGhpcyBzaG91bGQgYWN0dWFsbHkgYmUgZW5jcnlwdGVkIGpzb24K"
|
||||
}
|
||||
"""
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
info = yield self.e2e_room_keys_handler.get_version_info(
|
||||
user_id, version
|
||||
)
|
||||
defer.returnValue((200, info))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, version):
|
||||
"""
|
||||
Delete the information about a given version of the user's
|
||||
room_keys backup. If the version part is missing, deletes the most
|
||||
current backup version (if any). Doesn't delete the actual room data.
|
||||
|
||||
DELETE /room_keys/version/12345 HTTP/1.1
|
||||
HTTP/1.1 200 OK
|
||||
{}
|
||||
"""
|
||||
|
||||
requester = yield self.auth.get_user_by_req(request, allow_guest=False)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
yield self.e2e_room_keys_handler.delete_version(
|
||||
user_id, version
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RoomKeysServlet(hs).register(http_server)
|
||||
RoomKeysVersionServlet(hs).register(http_server)
|
||||
@@ -93,7 +93,6 @@ class RemoteKey(Resource):
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
def render_GET(self, request):
|
||||
self.async_render_GET(request)
|
||||
@@ -138,13 +137,6 @@ class RemoteKey(Resource):
|
||||
logger.info("Handling query for keys %r", query)
|
||||
store_queries = []
|
||||
for server_name, key_ids in query.items():
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
logger.debug("Federation denied with %s", server_name)
|
||||
continue
|
||||
|
||||
if not key_ids:
|
||||
key_ids = (None,)
|
||||
for key_id in key_ids:
|
||||
|
||||
@@ -70,11 +70,38 @@ def respond_with_file(request, media_type, file_path,
|
||||
logger.debug("Responding with %r", file_path)
|
||||
|
||||
if os.path.isfile(file_path):
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
if upload_name:
|
||||
if is_ascii(upload_name):
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename=%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
else:
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename*=utf-8''%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(
|
||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
if file_size is None:
|
||||
stat = os.stat(file_path)
|
||||
file_size = stat.st_size
|
||||
|
||||
add_file_headers(request, media_type, file_size, upload_name)
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (file_size,)
|
||||
)
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
yield logcontext.make_deferred_yieldable(
|
||||
@@ -84,118 +111,3 @@ def respond_with_file(request, media_type, file_path,
|
||||
finish_request(request)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
||||
|
||||
def add_file_headers(request, media_type, file_size, upload_name):
|
||||
"""Adds the correct response headers in preparation for responding with the
|
||||
media.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request)
|
||||
media_type (str): The media/content type.
|
||||
file_size (int): Size in bytes of the media, if known.
|
||||
upload_name (str): The name of the requested file, if any.
|
||||
"""
|
||||
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
|
||||
if upload_name:
|
||||
if is_ascii(upload_name):
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename=%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
else:
|
||||
request.setHeader(
|
||||
b"Content-Disposition",
|
||||
b"inline; filename*=utf-8''%s" % (
|
||||
urllib.quote(upload_name.encode("utf-8")),
|
||||
),
|
||||
)
|
||||
|
||||
# cache for at least a day.
|
||||
# XXX: we might want to turn this off for data we don't want to
|
||||
# recommend caching as it's sensitive or private - or at least
|
||||
# select private. don't bother setting Expires as all our
|
||||
# clients are smart enough to be happy with Cache-Control
|
||||
request.setHeader(
|
||||
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
|
||||
)
|
||||
|
||||
request.setHeader(
|
||||
b"Content-Length", b"%d" % (file_size,)
|
||||
)
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def respond_with_responder(request, responder, media_type, file_size, upload_name=None):
|
||||
"""Responds to the request with given responder. If responder is None then
|
||||
returns 404.
|
||||
|
||||
Args:
|
||||
request (twisted.web.http.Request)
|
||||
responder (Responder|None)
|
||||
media_type (str): The media/content type.
|
||||
file_size (int|None): Size in bytes of the media. If not known it should be None
|
||||
upload_name (str|None): The name of the requested file, if any.
|
||||
"""
|
||||
if not responder:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
add_file_headers(request, media_type, file_size, upload_name)
|
||||
with responder:
|
||||
yield responder.write_to_consumer(request)
|
||||
finish_request(request)
|
||||
|
||||
|
||||
class Responder(object):
|
||||
"""Represents a response that can be streamed to the requester.
|
||||
|
||||
Responder is a context manager which *must* be used, so that any resources
|
||||
held can be cleaned up.
|
||||
"""
|
||||
def write_to_consumer(self, consumer):
|
||||
"""Stream response into consumer
|
||||
|
||||
Args:
|
||||
consumer (IConsumer)
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once the response has finished being written
|
||||
"""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
pass
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
pass
|
||||
|
||||
|
||||
class FileInfo(object):
|
||||
"""Details about a requested/uploaded file.
|
||||
|
||||
Attributes:
|
||||
server_name (str): The server name where the media originated from,
|
||||
or None if local.
|
||||
file_id (str): The local ID of the file. For local files this is the
|
||||
same as the media_id
|
||||
url_cache (bool): If the file is for the url preview cache
|
||||
thumbnail (bool): Whether the file is a thumbnail or not.
|
||||
thumbnail_width (int)
|
||||
thumbnail_height (int)
|
||||
thumbnail_method (str)
|
||||
thumbnail_type (str): Content type of thumbnail, e.g. image/png
|
||||
"""
|
||||
def __init__(self, server_name, file_id, url_cache=False,
|
||||
thumbnail=False, thumbnail_width=None, thumbnail_height=None,
|
||||
thumbnail_method=None, thumbnail_type=None):
|
||||
self.server_name = server_name
|
||||
self.file_id = file_id
|
||||
self.url_cache = url_cache
|
||||
self.thumbnail = thumbnail
|
||||
self.thumbnail_width = thumbnail_width
|
||||
self.thumbnail_height = thumbnail_height
|
||||
self.thumbnail_method = thumbnail_method
|
||||
self.thumbnail_type = thumbnail_type
|
||||
|
||||
@@ -14,7 +14,7 @@
|
||||
# limitations under the License.
|
||||
import synapse.http.servlet
|
||||
|
||||
from ._base import parse_media_id, respond_404
|
||||
from ._base import parse_media_id, respond_with_file, respond_404
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
|
||||
@@ -32,12 +32,12 @@ class DownloadResource(Resource):
|
||||
def __init__(self, hs, media_repo):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.filepaths = media_repo.filepaths
|
||||
self.media_repo = media_repo
|
||||
self.server_name = hs.hostname
|
||||
|
||||
# Both of these are expected by @request_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
@@ -57,16 +57,59 @@ class DownloadResource(Resource):
|
||||
)
|
||||
server_name, media_id, name = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self.media_repo.get_local_media(request, media_id, name)
|
||||
yield self._respond_local_file(request, media_id, name)
|
||||
else:
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
yield self._respond_remote_file(
|
||||
request, server_name, media_id, name
|
||||
)
|
||||
|
||||
yield self.media_repo.get_remote_media(request, server_name, media_id, name)
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_file(self, request, media_id, name):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
yield respond_with_file(
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_file(self, request, server_name, media_id, name):
|
||||
# don't forward requests for remote media if allow_remote is false
|
||||
allow_remote = synapse.http.servlet.parse_boolean(
|
||||
request, "allow_remote", default=True)
|
||||
if not allow_remote:
|
||||
logger.info(
|
||||
"Rejecting request for remote media %s/%s due to allow_remote",
|
||||
server_name, media_id,
|
||||
)
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
filesystem_id = media_info["filesystem_id"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
|
||||
file_path = self.filepaths.remote_media_filepath(
|
||||
server_name, filesystem_id
|
||||
)
|
||||
|
||||
yield respond_with_file(
|
||||
request, media_type, file_path, media_length,
|
||||
upload_name=upload_name,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014-2016 OpenMarket Ltd
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,7 +18,6 @@ import twisted.internet.error
|
||||
import twisted.web.http
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from ._base import respond_404, FileInfo, respond_with_responder
|
||||
from .upload_resource import UploadResource
|
||||
from .download_resource import DownloadResource
|
||||
from .thumbnail_resource import ThumbnailResource
|
||||
@@ -27,18 +25,15 @@ from .identicon_resource import IdenticonResource
|
||||
from .preview_url_resource import PreviewUrlResource
|
||||
from .filepath import MediaFilePaths
|
||||
from .thumbnailer import Thumbnailer
|
||||
from .storage_provider import StorageProviderWrapper
|
||||
from .media_storage import MediaStorage
|
||||
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
SynapseError, HttpResponseException, NotFoundError, FederationDeniedError,
|
||||
)
|
||||
from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||
NotFoundError
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import os
|
||||
@@ -52,7 +47,7 @@ import urlparse
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
UPDATE_RECENTLY_ACCESSED_TS = 60 * 1000
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000
|
||||
|
||||
|
||||
class MediaRepository(object):
|
||||
@@ -68,62 +63,96 @@ class MediaRepository(object):
|
||||
self.primary_base_path = hs.config.media_store_path
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
|
||||
self.backup_base_path = hs.config.backup_media_store_path
|
||||
|
||||
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
self.remote_media_linearizer = Linearizer(name="media_remote")
|
||||
|
||||
self.recently_accessed_remotes = set()
|
||||
self.recently_accessed_locals = set()
|
||||
|
||||
self.federation_domain_whitelist = hs.config.federation_domain_whitelist
|
||||
|
||||
# List of StorageProviders where we should search for media and
|
||||
# potentially upload to.
|
||||
storage_providers = []
|
||||
|
||||
for clz, provider_config, wrapper_config in hs.config.media_storage_providers:
|
||||
backend = clz(hs, provider_config)
|
||||
provider = StorageProviderWrapper(
|
||||
backend,
|
||||
store_local=wrapper_config.store_local,
|
||||
store_remote=wrapper_config.store_remote,
|
||||
store_synchronous=wrapper_config.store_synchronous,
|
||||
)
|
||||
storage_providers.append(provider)
|
||||
|
||||
self.media_storage = MediaStorage(
|
||||
self.primary_base_path, self.filepaths, storage_providers,
|
||||
)
|
||||
|
||||
self.clock.looping_call(
|
||||
self._update_recently_accessed,
|
||||
UPDATE_RECENTLY_ACCESSED_TS,
|
||||
self._update_recently_accessed_remotes,
|
||||
UPDATE_RECENTLY_ACCESSED_REMOTES_TS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _update_recently_accessed(self):
|
||||
remote_media = self.recently_accessed_remotes
|
||||
def _update_recently_accessed_remotes(self):
|
||||
media = self.recently_accessed_remotes
|
||||
self.recently_accessed_remotes = set()
|
||||
|
||||
local_media = self.recently_accessed_locals
|
||||
self.recently_accessed_locals = set()
|
||||
|
||||
yield self.store.update_cached_last_access_time(
|
||||
local_media, remote_media, self.clock.time_msec()
|
||||
media, self.clock.time_msec()
|
||||
)
|
||||
|
||||
def mark_recently_accessed(self, server_name, media_id):
|
||||
"""Mark the given media as recently accessed.
|
||||
@staticmethod
|
||||
def _makedirs(filepath):
|
||||
dirname = os.path.dirname(filepath)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
@staticmethod
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
Args:
|
||||
server_name (str|None): Origin server of media, or None if local
|
||||
media_id (str): The media ID of the content
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
"""
|
||||
if server_name:
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
else:
|
||||
self.recently_accessed_locals.add(media_id)
|
||||
MediaRepository._makedirs(fname)
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write_to_file_and_backup(self, source, path):
|
||||
"""Write `source` to the on disk media store, and also the backup store
|
||||
if configured.
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
path (str): Relative path to write file to
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
"""
|
||||
fname = os.path.join(self.primary_base_path, path)
|
||||
|
||||
# Write to the main repository
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._write_file_synchronously, source, fname,
|
||||
))
|
||||
|
||||
# Write to backup repository
|
||||
yield self.copy_to_backup(path)
|
||||
|
||||
defer.returnValue(fname)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_to_backup(self, path):
|
||||
"""Copy a file from the primary to backup media store, if configured.
|
||||
|
||||
Args:
|
||||
path(str): Relative path to write file to
|
||||
"""
|
||||
if self.backup_base_path:
|
||||
primary_fname = os.path.join(self.primary_base_path, path)
|
||||
backup_fname = os.path.join(self.backup_base_path, path)
|
||||
|
||||
# We can either wait for successful writing to the backup repository
|
||||
# or write in the background and immediately return
|
||||
if self.synchronous_backup_media_store:
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
))
|
||||
else:
|
||||
preserve_fn(threads.deferToThread)(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(self, media_type, upload_name, content, content_length,
|
||||
@@ -142,13 +171,10 @@ class MediaRepository(object):
|
||||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
fname = yield self.write_to_file_and_backup(
|
||||
content, self.filepaths.local_media_filepath_rel(media_id)
|
||||
)
|
||||
|
||||
fname = yield self.media_storage.store_file(content, file_info)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
yield self.store.store_local_media(
|
||||
@@ -159,275 +185,134 @@ class MediaRepository(object):
|
||||
media_length=content_length,
|
||||
user_id=auth_user,
|
||||
)
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
"media_length": content_length,
|
||||
}
|
||||
|
||||
yield self._generate_thumbnails(
|
||||
None, media_id, media_id, media_type,
|
||||
)
|
||||
yield self._generate_thumbnails(None, media_id, media_info)
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_local_media(self, request, media_id, name):
|
||||
"""Responds to reqests for local media, if exists, or returns 404.
|
||||
|
||||
Args:
|
||||
request(twisted.web.http.Request)
|
||||
media_id (str): The media ID of the content. (This is the same as
|
||||
the file_id for local content.)
|
||||
name (str|None): Optional name that, if specified, will be used as
|
||||
the filename in the Content-Disposition header of the response.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once a response has successfully been written
|
||||
to request
|
||||
"""
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
self.mark_recently_accessed(None, media_id)
|
||||
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
url_cache = media_info["url_cache"]
|
||||
|
||||
file_info = FileInfo(
|
||||
None, media_id,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media(self, request, server_name, media_id, name):
|
||||
"""Respond to requests for remote media.
|
||||
|
||||
Args:
|
||||
request(twisted.web.http.Request)
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
name (str|None): Optional name that, if specified, will be used as
|
||||
the filename in the Content-Disposition header of the response.
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves once a response has successfully been written
|
||||
to request
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(server_name)
|
||||
|
||||
self.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
# We linearize here to ensure that we don't try and download remote
|
||||
# media multiple times concurrently
|
||||
def get_remote_media(self, server_name, media_id):
|
||||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
server_name, media_id,
|
||||
)
|
||||
|
||||
# We deliberately stream the file outside the lock
|
||||
if responder:
|
||||
media_type = media_info["media_type"]
|
||||
media_length = media_info["media_length"]
|
||||
upload_name = name if name else media_info["upload_name"]
|
||||
yield respond_with_responder(
|
||||
request, responder, media_type, media_length, upload_name,
|
||||
)
|
||||
else:
|
||||
respond_404(request)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_remote_media_info(self, server_name, media_id):
|
||||
"""Gets the media info associated with the remote file, downloading
|
||||
if necessary.
|
||||
|
||||
Args:
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: The media_info of the file
|
||||
"""
|
||||
if (
|
||||
self.federation_domain_whitelist is not None and
|
||||
server_name not in self.federation_domain_whitelist
|
||||
):
|
||||
raise FederationDeniedError(server_name)
|
||||
|
||||
# We linearize here to ensure that we don't try and download remote
|
||||
# media multiple times concurrently
|
||||
key = (server_name, media_id)
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
responder, media_info = yield self._get_remote_media_impl(
|
||||
server_name, media_id,
|
||||
)
|
||||
|
||||
# Ensure we actually use the responder so that it releases resources
|
||||
if responder:
|
||||
with responder:
|
||||
pass
|
||||
|
||||
media_info = yield self._get_remote_media_impl(server_name, media_id)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
"""Looks for media in local cache, if not there then attempt to
|
||||
download from remote server.
|
||||
|
||||
Args:
|
||||
server_name (str): Remote server_name where the media originated.
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server).
|
||||
|
||||
Returns:
|
||||
Deferred[(Responder, media_info)]
|
||||
"""
|
||||
media_info = yield self.store.get_cached_remote_media(
|
||||
server_name, media_id
|
||||
)
|
||||
|
||||
# file_id is the ID we use to track the file locally. If we've already
|
||||
# seen the file then reuse the existing ID, otherwise genereate a new
|
||||
# one.
|
||||
if media_info:
|
||||
file_id = media_info["filesystem_id"]
|
||||
if not media_info:
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id
|
||||
)
|
||||
elif media_info["quarantined_by"]:
|
||||
raise NotFoundError()
|
||||
else:
|
||||
file_id = random_string(24)
|
||||
|
||||
file_info = FileInfo(server_name, file_id)
|
||||
|
||||
# If we have an entry in the DB, try and look for it
|
||||
if media_info:
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
raise NotFoundError()
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
defer.returnValue((responder, media_info))
|
||||
|
||||
# Failed to find the file anywhere, lets download it.
|
||||
|
||||
media_info = yield self._download_remote_file(
|
||||
server_name, media_id, file_id
|
||||
)
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
defer.returnValue((responder, media_info))
|
||||
self.recently_accessed_remotes.add((server_name, media_id))
|
||||
yield self.store.update_cached_last_access_time(
|
||||
[(server_name, media_id)], self.clock.time_msec()
|
||||
)
|
||||
defer.returnValue(media_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _download_remote_file(self, server_name, media_id, file_id):
|
||||
"""Attempt to download the remote file from the given server name,
|
||||
using the given file_id as the local id.
|
||||
def _download_remote_file(self, server_name, media_id):
|
||||
file_id = random_string(24)
|
||||
|
||||
Args:
|
||||
server_name (str): Originating server
|
||||
media_id (str): The media ID of the content (as defined by the
|
||||
remote server). This is different than the file_id, which is
|
||||
locally generated.
|
||||
file_id (str): Local file ID
|
||||
|
||||
Returns:
|
||||
Deferred[MediaInfo]
|
||||
"""
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
fpath = self.filepaths.remote_media_filepath_rel(
|
||||
server_name, file_id
|
||||
)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self._makedirs(fname)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
try:
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size, args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
)
|
||||
except twisted.internet.error.DNSLookupError as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
raise NotFoundError()
|
||||
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warn("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield finish()
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition[0],)
|
||||
upload_name = None
|
||||
|
||||
# First check if there is a valid UTF-8 filename
|
||||
upload_name_utf8 = params.get("filename*", None)
|
||||
if upload_name_utf8:
|
||||
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||
upload_name = upload_name_utf8[7:]
|
||||
|
||||
# If there isn't check for an ascii name.
|
||||
if not upload_name:
|
||||
upload_name_ascii = params.get("filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
upload_name = upload_name_ascii
|
||||
|
||||
if upload_name:
|
||||
upload_name = urlparse.unquote(upload_name)
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
request_path = "/".join((
|
||||
"/_matrix/media/v1/download", server_name, media_id,
|
||||
))
|
||||
try:
|
||||
upload_name = upload_name.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
upload_name = None
|
||||
else:
|
||||
upload_name = None
|
||||
length, headers = yield self.client.get_file(
|
||||
server_name, request_path, output_stream=f,
|
||||
max_size=self.max_upload_size, args={
|
||||
# tell the remote server to 404 if it doesn't
|
||||
# recognise the server_name, to make sure we don't
|
||||
# end up with a routing loop.
|
||||
"allow_remote": "false",
|
||||
}
|
||||
)
|
||||
except twisted.internet.error.DNSLookupError as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %r",
|
||||
server_name, media_id, e)
|
||||
raise NotFoundError()
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
except HttpResponseException as e:
|
||||
logger.warn("HTTP error fetching remote media %s/%s: %s",
|
||||
server_name, media_id, e.response)
|
||||
if e.code == twisted.web.http.NOT_FOUND:
|
||||
raise SynapseError.from_http_response_exception(e)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
except SynapseError:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise
|
||||
except NotRetryingDestination:
|
||||
logger.warn("Not retrying destination %r", server_name)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
except Exception:
|
||||
logger.exception("Failed to fetch remote media %s/%s",
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield self.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
if content_disposition:
|
||||
_, params = cgi.parse_header(content_disposition[0],)
|
||||
upload_name = None
|
||||
|
||||
# First check if there is a valid UTF-8 filename
|
||||
upload_name_utf8 = params.get("filename*", None)
|
||||
if upload_name_utf8:
|
||||
if upload_name_utf8.lower().startswith("utf-8''"):
|
||||
upload_name = upload_name_utf8[7:]
|
||||
|
||||
# If there isn't check for an ascii name.
|
||||
if not upload_name:
|
||||
upload_name_ascii = params.get("filename", None)
|
||||
if upload_name_ascii and is_ascii(upload_name_ascii):
|
||||
upload_name = upload_name_ascii
|
||||
|
||||
if upload_name:
|
||||
upload_name = urlparse.unquote(upload_name)
|
||||
try:
|
||||
upload_name = upload_name.decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
upload_name = None
|
||||
else:
|
||||
upload_name = None
|
||||
|
||||
logger.info("Stored remote media in file %r", fname)
|
||||
|
||||
yield self.store.store_cached_remote_media(
|
||||
origin=server_name,
|
||||
media_id=media_id,
|
||||
media_type=media_type,
|
||||
time_now_ms=self.clock.time_msec(),
|
||||
upload_name=upload_name,
|
||||
media_length=length,
|
||||
filesystem_id=file_id,
|
||||
)
|
||||
except Exception:
|
||||
os.remove(fname)
|
||||
raise
|
||||
|
||||
media_info = {
|
||||
"media_type": media_type,
|
||||
@@ -438,7 +323,7 @@ class MediaRepository(object):
|
||||
}
|
||||
|
||||
yield self._generate_thumbnails(
|
||||
server_name, media_id, file_id, media_type,
|
||||
server_name, media_id, media_info
|
||||
)
|
||||
|
||||
defer.returnValue(media_info)
|
||||
@@ -472,10 +357,8 @@ class MediaRepository(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||
t_method, t_type, url_cache):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
None, media_id, url_cache=url_cache,
|
||||
))
|
||||
t_method, t_type):
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
@@ -485,19 +368,11 @@ class MediaRepository(object):
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=media_id,
|
||||
url_cache=url_cache,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
@@ -515,9 +390,7 @@ class MediaRepository(object):
|
||||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||
t_width, t_height, t_method, t_type):
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=False,
|
||||
))
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
@@ -527,18 +400,11 @@ class MediaRepository(object):
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=media_id,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
@@ -555,29 +421,31 @@ class MediaRepository(object):
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_thumbnails(self, server_name, media_id, file_id, media_type,
|
||||
url_cache=False):
|
||||
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
server_name (str|None): The server name if remote media, else None if local
|
||||
media_id (str): The media ID of the content. (This is the same as
|
||||
the file_id for local content)
|
||||
file_id (str): Local file ID
|
||||
media_type (str): The content type of the file
|
||||
url_cache (bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
server_name(str|None): The server name if remote media, else None if local
|
||||
media_id(str)
|
||||
media_info(dict)
|
||||
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
used exclusively by the url previewer
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||
"""
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info.get("filesystem_id")
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
input_path = yield self.media_storage.ensure_media_is_in_local_cache(FileInfo(
|
||||
server_name, file_id, url_cache=url_cache,
|
||||
))
|
||||
if server_name:
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
elif url_cache:
|
||||
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
@@ -604,6 +472,20 @@ class MediaRepository(object):
|
||||
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||
# Work out the correct file name for thumbnail
|
||||
if server_name:
|
||||
file_path = self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
elif url_cache:
|
||||
file_path = self.filepaths.url_cache_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
@@ -623,19 +505,9 @@ class MediaRepository(object):
|
||||
continue
|
||||
|
||||
try:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
thumbnail=True,
|
||||
thumbnail_width=t_width,
|
||||
thumbnail_height=t_height,
|
||||
thumbnail_method=t_method,
|
||||
thumbnail_type=t_type,
|
||||
url_cache=url_cache,
|
||||
)
|
||||
|
||||
output_path = yield self.media_storage.store_file(
|
||||
t_byte_source, file_info,
|
||||
# Write to disk
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source, file_path,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
@@ -748,11 +620,7 @@ class MediaRepositoryResource(Resource):
|
||||
|
||||
self.putChild("upload", UploadResource(hs, media_repo))
|
||||
self.putChild("download", DownloadResource(hs, media_repo))
|
||||
self.putChild("thumbnail", ThumbnailResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
|
||||
self.putChild("identicon", IdenticonResource())
|
||||
if hs.config.url_preview_enabled:
|
||||
self.putChild("preview_url", PreviewUrlResource(
|
||||
hs, media_repo, media_repo.media_storage,
|
||||
))
|
||||
self.putChild("preview_url", PreviewUrlResource(hs, media_repo))
|
||||
|
||||
@@ -1,274 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vecotr Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from ._base import Responder
|
||||
|
||||
from synapse.util.file_consumer import BackgroundFileConsumer
|
||||
from synapse.util.logcontext import make_deferred_yieldable
|
||||
|
||||
import contextlib
|
||||
import os
|
||||
import logging
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MediaStorage(object):
|
||||
"""Responsible for storing/fetching files from local sources.
|
||||
|
||||
Args:
|
||||
local_media_directory (str): Base path where we store media on disk
|
||||
filepaths (MediaFilePaths)
|
||||
storage_providers ([StorageProvider]): List of StorageProvider that are
|
||||
used to fetch and store files.
|
||||
"""
|
||||
|
||||
def __init__(self, local_media_directory, filepaths, storage_providers):
|
||||
self.local_media_directory = local_media_directory
|
||||
self.filepaths = filepaths
|
||||
self.storage_providers = storage_providers
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_file(self, source, file_info):
|
||||
"""Write `source` to the on disk media store, and also any other
|
||||
configured storage providers
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
file_info (FileInfo): Info about the file to store
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
"""
|
||||
path = self._file_info_to_path(file_info)
|
||||
fname = os.path.join(self.local_media_directory, path)
|
||||
|
||||
dirname = os.path.dirname(fname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
# Write to the main repository
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
_write_file_synchronously, source, fname,
|
||||
))
|
||||
|
||||
# Tell the storage providers about the new file. They'll decide
|
||||
# if they should upload it and whether to do so synchronously
|
||||
# or not.
|
||||
for provider in self.storage_providers:
|
||||
yield provider.store_file(path, file_info)
|
||||
|
||||
defer.returnValue(fname)
|
||||
|
||||
@contextlib.contextmanager
|
||||
def store_into_file(self, file_info):
|
||||
"""Context manager used to get a file like object to write into, as
|
||||
described by file_info.
|
||||
|
||||
Actually yields a 3-tuple (file, fname, finish_cb), where file is a file
|
||||
like object that can be written to, fname is the absolute path of file
|
||||
on disk, and finish_cb is a function that returns a Deferred.
|
||||
|
||||
fname can be used to read the contents from after upload, e.g. to
|
||||
generate thumbnails.
|
||||
|
||||
finish_cb must be called and waited on after the file has been
|
||||
successfully been written to. Should not be called if there was an
|
||||
error.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo): Info about the file to store
|
||||
|
||||
Example:
|
||||
|
||||
with media_storage.store_into_file(info) as (f, fname, finish_cb):
|
||||
# .. write into f ...
|
||||
yield finish_cb()
|
||||
"""
|
||||
|
||||
path = self._file_info_to_path(file_info)
|
||||
fname = os.path.join(self.local_media_directory, path)
|
||||
|
||||
dirname = os.path.dirname(fname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
finished_called = [False]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def finish():
|
||||
for provider in self.storage_providers:
|
||||
yield provider.store_file(path, file_info)
|
||||
|
||||
finished_called[0] = True
|
||||
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
yield f, fname, finish
|
||||
except Exception:
|
||||
t, v, tb = sys.exc_info()
|
||||
try:
|
||||
os.remove(fname)
|
||||
except Exception:
|
||||
pass
|
||||
raise t, v, tb
|
||||
|
||||
if not finished_called:
|
||||
raise Exception("Finished callback not called")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_media(self, file_info):
|
||||
"""Attempts to fetch media described by file_info from the local cache
|
||||
and configured storage providers.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
Deferred[Responder|None]: Returns a Responder if the file was found,
|
||||
otherwise None.
|
||||
"""
|
||||
|
||||
path = self._file_info_to_path(file_info)
|
||||
local_path = os.path.join(self.local_media_directory, path)
|
||||
if os.path.exists(local_path):
|
||||
defer.returnValue(FileResponder(open(local_path, "rb")))
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = yield provider.fetch(path, file_info)
|
||||
if res:
|
||||
defer.returnValue(res)
|
||||
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def ensure_media_is_in_local_cache(self, file_info):
|
||||
"""Ensures that the given file is in the local cache. Attempts to
|
||||
download it from storage providers if it isn't.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
Deferred[str]: Full path to local file
|
||||
"""
|
||||
path = self._file_info_to_path(file_info)
|
||||
local_path = os.path.join(self.local_media_directory, path)
|
||||
if os.path.exists(local_path):
|
||||
defer.returnValue(local_path)
|
||||
|
||||
dirname = os.path.dirname(local_path)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
for provider in self.storage_providers:
|
||||
res = yield provider.fetch(path, file_info)
|
||||
if res:
|
||||
with res:
|
||||
consumer = BackgroundFileConsumer(open(local_path, "w"))
|
||||
yield res.write_to_consumer(consumer)
|
||||
yield consumer.wait()
|
||||
defer.returnValue(local_path)
|
||||
|
||||
raise Exception("file could not be found")
|
||||
|
||||
def _file_info_to_path(self, file_info):
|
||||
"""Converts file_info into a relative path.
|
||||
|
||||
The path is suitable for storing files under a directory, e.g. used to
|
||||
store files on local FS under the base media repository directory.
|
||||
|
||||
Args:
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
str
|
||||
"""
|
||||
if file_info.url_cache:
|
||||
if file_info.thumbnail:
|
||||
return self.filepaths.url_cache_thumbnail_rel(
|
||||
media_id=file_info.file_id,
|
||||
width=file_info.thumbnail_width,
|
||||
height=file_info.thumbnail_height,
|
||||
content_type=file_info.thumbnail_type,
|
||||
method=file_info.thumbnail_method,
|
||||
)
|
||||
return self.filepaths.url_cache_filepath_rel(file_info.file_id)
|
||||
|
||||
if file_info.server_name:
|
||||
if file_info.thumbnail:
|
||||
return self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name=file_info.server_name,
|
||||
file_id=file_info.file_id,
|
||||
width=file_info.thumbnail_width,
|
||||
height=file_info.thumbnail_height,
|
||||
content_type=file_info.thumbnail_type,
|
||||
method=file_info.thumbnail_method
|
||||
)
|
||||
return self.filepaths.remote_media_filepath_rel(
|
||||
file_info.server_name, file_info.file_id,
|
||||
)
|
||||
|
||||
if file_info.thumbnail:
|
||||
return self.filepaths.local_media_thumbnail_rel(
|
||||
media_id=file_info.file_id,
|
||||
width=file_info.thumbnail_width,
|
||||
height=file_info.thumbnail_height,
|
||||
content_type=file_info.thumbnail_type,
|
||||
method=file_info.thumbnail_method
|
||||
)
|
||||
return self.filepaths.local_media_filepath_rel(
|
||||
file_info.file_id,
|
||||
)
|
||||
|
||||
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
Args:
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
"""
|
||||
dirname = os.path.dirname(fname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
|
||||
class FileResponder(Responder):
|
||||
"""Wraps an open file that can be sent to a request.
|
||||
|
||||
Args:
|
||||
open_file (file): A file like object to be streamed ot the client,
|
||||
is closed when finished streaming.
|
||||
"""
|
||||
def __init__(self, open_file):
|
||||
self.open_file = open_file
|
||||
|
||||
def write_to_consumer(self, consumer):
|
||||
return FileSender().beginFileTransfer(self.open_file, consumer)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
self.open_file.close()
|
||||
@@ -12,26 +12,11 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import cgi
|
||||
import datetime
|
||||
import errno
|
||||
import fnmatch
|
||||
import itertools
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import traceback
|
||||
import ujson as json
|
||||
import urlparse
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
from ._base import FileInfo
|
||||
|
||||
from synapse.api.errors import (
|
||||
SynapseError, Codes,
|
||||
)
|
||||
@@ -46,13 +31,25 @@ from synapse.http.server import (
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.stringutils import is_ascii
|
||||
|
||||
import os
|
||||
import re
|
||||
import fnmatch
|
||||
import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
import datetime
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PreviewUrlResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
def __init__(self, hs, media_repo):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
@@ -65,7 +62,6 @@ class PreviewUrlResource(Resource):
|
||||
self.client = SpiderHttpClient(hs)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
self.media_storage = media_storage
|
||||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
@@ -186,10 +182,8 @@ class PreviewUrlResource(Resource):
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if _is_media(media_info['media_type']):
|
||||
file_id = media_info['filesystem_id']
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, media_info["media_type"],
|
||||
url_cache=True,
|
||||
None, media_info['filesystem_id'], media_info, url_cache=True,
|
||||
)
|
||||
|
||||
og = {
|
||||
@@ -234,10 +228,8 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
file_id = image_info['filesystem_id']
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, file_id, file_id, image_info["media_type"],
|
||||
url_cache=True,
|
||||
None, image_info['filesystem_id'], image_info, url_cache=True,
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
@@ -281,34 +273,21 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None,
|
||||
file_id=file_id,
|
||||
url_cache=True,
|
||||
)
|
||||
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self.media_repo._makedirs(fname)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
try:
|
||||
with open(fname, "wb") as f:
|
||||
logger.debug("Trying to get url '%s'" % url)
|
||||
length, headers, uri, code = yield self.client.get_file(
|
||||
url, output_stream=f, max_size=self.max_spider_size,
|
||||
)
|
||||
except Exception as e:
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
logger.warn("Error downloading %s: %r", url, e)
|
||||
raise SynapseError(
|
||||
500, "Failed to download content: %s" % (
|
||||
traceback.format_exception_only(sys.exc_type, e),
|
||||
),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
yield finish()
|
||||
|
||||
try:
|
||||
if "Content-Type" in headers:
|
||||
media_type = headers["Content-Type"][0]
|
||||
else:
|
||||
media_type = "application/octet-stream"
|
||||
yield self.media_repo.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
content_disposition = headers.get("Content-Disposition", None)
|
||||
@@ -348,11 +327,11 @@ class PreviewUrlResource(Resource):
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error handling downloaded %s: %r", url, e)
|
||||
# TODO: we really ought to delete the downloaded file in this
|
||||
# case, since we won't have recorded it in the db, and will
|
||||
# therefore not expire it.
|
||||
raise
|
||||
os.remove(fname)
|
||||
raise SynapseError(
|
||||
500, ("Failed to download content: %s" % e),
|
||||
Codes.UNKNOWN
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"media_type": media_type,
|
||||
|
||||
@@ -1,140 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2018 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, threads
|
||||
|
||||
from .media_storage import FileResponder
|
||||
|
||||
from synapse.config._base import Config
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StorageProvider(object):
|
||||
"""A storage provider is a service that can store uploaded media and
|
||||
retrieve them.
|
||||
"""
|
||||
def store_file(self, path, file_info):
|
||||
"""Store the file described by file_info. The actual contents can be
|
||||
retrieved by reading the file in file_info.upload_path.
|
||||
|
||||
Args:
|
||||
path (str): Relative path of file in local cache
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
pass
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
"""Attempt to fetch the file described by file_info and stream it
|
||||
into writer.
|
||||
|
||||
Args:
|
||||
path (str): Relative path of file in local cache
|
||||
file_info (FileInfo)
|
||||
|
||||
Returns:
|
||||
Deferred(Responder): Returns a Responder if the provider has the file,
|
||||
otherwise returns None.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class StorageProviderWrapper(StorageProvider):
|
||||
"""Wraps a storage provider and provides various config options
|
||||
|
||||
Args:
|
||||
backend (StorageProvider)
|
||||
store_local (bool): Whether to store new local files or not.
|
||||
store_synchronous (bool): Whether to wait for file to be successfully
|
||||
uploaded, or todo the upload in the backgroud.
|
||||
store_remote (bool): Whether remote media should be uploaded
|
||||
"""
|
||||
def __init__(self, backend, store_local, store_synchronous, store_remote):
|
||||
self.backend = backend
|
||||
self.store_local = store_local
|
||||
self.store_synchronous = store_synchronous
|
||||
self.store_remote = store_remote
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
if not file_info.server_name and not self.store_local:
|
||||
return defer.succeed(None)
|
||||
|
||||
if file_info.server_name and not self.store_remote:
|
||||
return defer.succeed(None)
|
||||
|
||||
if self.store_synchronous:
|
||||
return self.backend.store_file(path, file_info)
|
||||
else:
|
||||
# TODO: Handle errors.
|
||||
preserve_fn(self.backend.store_file)(path, file_info)
|
||||
return defer.succeed(None)
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
return self.backend.fetch(path, file_info)
|
||||
|
||||
|
||||
class FileStorageProviderBackend(StorageProvider):
|
||||
"""A storage provider that stores files in a directory on a filesystem.
|
||||
|
||||
Args:
|
||||
hs (HomeServer)
|
||||
config: The config returned by `parse_config`.
|
||||
"""
|
||||
|
||||
def __init__(self, hs, config):
|
||||
self.cache_directory = hs.config.media_store_path
|
||||
self.base_directory = config
|
||||
|
||||
def store_file(self, path, file_info):
|
||||
"""See StorageProvider.store_file"""
|
||||
|
||||
primary_fname = os.path.join(self.cache_directory, path)
|
||||
backup_fname = os.path.join(self.base_directory, path)
|
||||
|
||||
dirname = os.path.dirname(backup_fname)
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
return threads.deferToThread(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
)
|
||||
|
||||
def fetch(self, path, file_info):
|
||||
"""See StorageProvider.fetch"""
|
||||
|
||||
backup_fname = os.path.join(self.base_directory, path)
|
||||
if os.path.isfile(backup_fname):
|
||||
return FileResponder(open(backup_fname, "rb"))
|
||||
|
||||
@staticmethod
|
||||
def parse_config(config):
|
||||
"""Called on startup to parse config supplied. This should parse
|
||||
the config and raise if there is a problem.
|
||||
|
||||
The returned value is passed into the constructor.
|
||||
|
||||
In this case we only care about a single param, the directory, so let's
|
||||
just pull that out.
|
||||
"""
|
||||
return Config.ensure_directory(config["directory"])
|
||||
@@ -14,10 +14,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from ._base import (
|
||||
parse_media_id, respond_404, respond_with_file, FileInfo,
|
||||
respond_with_responder,
|
||||
)
|
||||
from ._base import parse_media_id, respond_404, respond_with_file
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.servlet import parse_string, parse_integer
|
||||
from synapse.http.server import request_handler, set_cors_headers
|
||||
@@ -33,12 +30,12 @@ logger = logging.getLogger(__name__)
|
||||
class ThumbnailResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs, media_repo, media_storage):
|
||||
def __init__(self, hs, media_repo):
|
||||
Resource.__init__(self)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
self.filepaths = media_repo.filepaths
|
||||
self.media_repo = media_repo
|
||||
self.media_storage = media_storage
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.server_name = hs.hostname
|
||||
self.version_string = hs.version_string
|
||||
@@ -67,7 +64,6 @@ class ThumbnailResource(Resource):
|
||||
yield self._respond_local_thumbnail(
|
||||
request, media_id, width, height, method, m_type
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(None, media_id)
|
||||
else:
|
||||
if self.dynamic_thumbnails:
|
||||
yield self._select_or_generate_remote_thumbnail(
|
||||
@@ -79,46 +75,48 @@ class ThumbnailResource(Resource):
|
||||
request, server_name, media_id,
|
||||
width, height, method, m_type
|
||||
)
|
||||
self.media_repo.mark_recently_accessed(server_name, media_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_local_thumbnail(self, request, media_id, width, height,
|
||||
method, m_type):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
|
||||
if thumbnail_infos:
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
else:
|
||||
logger.info("Couldn't find any generated thumbnails")
|
||||
respond_404(request)
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_local_thumbnail(self, request, media_id, desired_width,
|
||||
@@ -126,14 +124,15 @@ class ThumbnailResource(Resource):
|
||||
desired_type):
|
||||
media_info = yield self.store.get_local_media(media_id)
|
||||
|
||||
if not media_info:
|
||||
respond_404(request)
|
||||
return
|
||||
if media_info["quarantined_by"]:
|
||||
logger.info("Media is quarantined")
|
||||
if not media_info or media_info["quarantined_by"]:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.local_media_filepath(media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
|
||||
for info in thumbnail_infos:
|
||||
t_w = info["thumbnail_width"] == desired_width
|
||||
@@ -142,43 +141,46 @@ class ThumbnailResource(Resource):
|
||||
t_type = info["thumbnail_type"] == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=None, file_id=media_id,
|
||||
url_cache=media_info["url_cache"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
thumbnail_height=info["thumbnail_height"],
|
||||
thumbnail_type=info["thumbnail_type"],
|
||||
thumbnail_method=info["thumbnail_method"],
|
||||
)
|
||||
if media_info["url_cache"]:
|
||||
# TODO: Check the file still exists, if it doesn't we can redownload
|
||||
# it from the url `media_info["url_cache"]`
|
||||
file_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_type,
|
||||
desired_method,
|
||||
)
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_type,
|
||||
desired_method,
|
||||
)
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
return
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_local_exact_thumbnail(
|
||||
media_id, desired_width, desired_height, desired_method, desired_type,
|
||||
url_cache=media_info["url_cache"],
|
||||
media_id, desired_width, desired_height, desired_method, desired_type
|
||||
)
|
||||
|
||||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warn("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, desired_width, desired_height,
|
||||
desired_method, desired_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
|
||||
desired_width, desired_height,
|
||||
desired_method, desired_type):
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
@@ -193,24 +195,14 @@ class ThumbnailResource(Resource):
|
||||
t_type = info["thumbnail_type"] == desired_type
|
||||
|
||||
if t_w and t_h and t_method and t_type:
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=info["thumbnail_width"],
|
||||
thumbnail_height=info["thumbnail_height"],
|
||||
thumbnail_type=info["thumbnail_type"],
|
||||
thumbnail_method=info["thumbnail_method"],
|
||||
file_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, desired_width, desired_height,
|
||||
desired_type, desired_method,
|
||||
)
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
return
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_length = info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
if responder:
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
return
|
||||
|
||||
logger.debug("We don't have a thumbnail of that size. Generating")
|
||||
logger.debug("We don't have a local thumbnail of that size. Generating")
|
||||
|
||||
# Okay, so we generate one.
|
||||
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
|
||||
@@ -221,16 +213,22 @@ class ThumbnailResource(Resource):
|
||||
if file_path:
|
||||
yield respond_with_file(request, desired_type, file_path)
|
||||
else:
|
||||
logger.warn("Failed to generate thumbnail")
|
||||
respond_404(request)
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, desired_width, desired_height,
|
||||
desired_method, desired_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_remote_thumbnail(self, request, server_name, media_id, width,
|
||||
height, method, m_type):
|
||||
# TODO: Don't download the whole remote file
|
||||
# We should proxy the thumbnail from the remote server instead of
|
||||
# downloading the remote file and generating our own thumbnails.
|
||||
media_info = yield self.media_repo.get_remote_media_info(server_name, media_id)
|
||||
# We should proxy the thumbnail from the remote server instead.
|
||||
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
|
||||
|
||||
# if media_info["media_type"] == "image/svg+xml":
|
||||
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
|
||||
# yield respond_with_file(request, media_info["media_type"], file_path)
|
||||
# return
|
||||
|
||||
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
|
||||
server_name, media_id,
|
||||
@@ -240,23 +238,59 @@ class ThumbnailResource(Resource):
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, method, m_type, thumbnail_infos
|
||||
)
|
||||
file_info = FileInfo(
|
||||
server_name=server_name, file_id=media_info["filesystem_id"],
|
||||
thumbnail=True,
|
||||
thumbnail_width=thumbnail_info["thumbnail_width"],
|
||||
thumbnail_height=thumbnail_info["thumbnail_height"],
|
||||
thumbnail_type=thumbnail_info["thumbnail_type"],
|
||||
thumbnail_method=thumbnail_info["thumbnail_method"],
|
||||
)
|
||||
|
||||
t_type = file_info.thumbnail_type
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
file_id = thumbnail_info["filesystem_id"]
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
responder = yield self.media_storage.fetch_media(file_info)
|
||||
yield respond_with_responder(request, responder, t_type, t_length)
|
||||
file_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path, t_length)
|
||||
else:
|
||||
logger.info("Failed to find any generated thumbnails")
|
||||
yield self._respond_default_thumbnail(
|
||||
request, media_info, width, height, method, m_type,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _respond_default_thumbnail(self, request, media_info, width, height,
|
||||
method, m_type):
|
||||
# XXX: how is this meant to work? store.get_default_thumbnails
|
||||
# appears to always return [] so won't this always 404?
|
||||
media_type = media_info["media_type"]
|
||||
top_level_type = media_type.split("/")[0]
|
||||
sub_type = media_type.split("/")[-1].split(";")[0]
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, sub_type,
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
top_level_type, "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
thumbnail_infos = yield self.store.get_default_thumbnails(
|
||||
"_default", "_default",
|
||||
)
|
||||
if not thumbnail_infos:
|
||||
respond_404(request)
|
||||
return
|
||||
|
||||
thumbnail_info = self._select_thumbnail(
|
||||
width, height, "crop", m_type, thumbnail_infos
|
||||
)
|
||||
|
||||
t_width = thumbnail_info["thumbnail_width"]
|
||||
t_height = thumbnail_info["thumbnail_height"]
|
||||
t_type = thumbnail_info["thumbnail_type"]
|
||||
t_method = thumbnail_info["thumbnail_method"]
|
||||
t_length = thumbnail_info["thumbnail_length"]
|
||||
|
||||
file_path = self.filepaths.default_thumbnail(
|
||||
top_level_type, sub_type, t_width, t_height, t_type, t_method,
|
||||
)
|
||||
yield respond_with_file(request, t_type, file_path, t_length)
|
||||
|
||||
def _select_thumbnail(self, desired_width, desired_height, desired_method,
|
||||
desired_type, thumbnail_infos):
|
||||
|
||||
@@ -43,6 +43,7 @@ from synapse.handlers.deactivate_account import DeactivateAccountHandler
|
||||
from synapse.handlers.devicemessage import DeviceMessageHandler
|
||||
from synapse.handlers.device import DeviceHandler
|
||||
from synapse.handlers.e2e_keys import E2eKeysHandler
|
||||
from synapse.handlers.e2e_room_keys import E2eRoomKeysHandler
|
||||
from synapse.handlers.presence import PresenceHandler
|
||||
from synapse.handlers.room_list import RoomListHandler
|
||||
from synapse.handlers.set_password import SetPasswordHandler
|
||||
@@ -55,7 +56,6 @@ from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.user_directory import UserDirectoryHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.handlers.message import EventCreationHandler
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
@@ -67,7 +67,7 @@ from synapse.rest.media.v1.media_repository import (
|
||||
MediaRepository,
|
||||
MediaRepositoryResource,
|
||||
)
|
||||
from synapse.state import StateHandler, StateResolutionHandler
|
||||
from synapse.state import StateHandler
|
||||
from synapse.storage import DataStore
|
||||
from synapse.streams.events import EventSources
|
||||
from synapse.util import Clock
|
||||
@@ -103,7 +103,6 @@ class HomeServer(object):
|
||||
'v1auth',
|
||||
'auth',
|
||||
'state_handler',
|
||||
'state_resolution_handler',
|
||||
'presence_handler',
|
||||
'sync_handler',
|
||||
'typing_handler',
|
||||
@@ -111,6 +110,7 @@ class HomeServer(object):
|
||||
'auth_handler',
|
||||
'device_handler',
|
||||
'e2e_keys_handler',
|
||||
'e2e_room_keys_handler',
|
||||
'event_handler',
|
||||
'event_stream_handler',
|
||||
'initial_sync_handler',
|
||||
@@ -119,7 +119,6 @@ class HomeServer(object):
|
||||
'application_service_handler',
|
||||
'device_message_handler',
|
||||
'profile_handler',
|
||||
'event_creation_handler',
|
||||
'deactivate_account_handler',
|
||||
'set_password_handler',
|
||||
'notifier',
|
||||
@@ -227,9 +226,6 @@ class HomeServer(object):
|
||||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
def build_state_resolution_handler(self):
|
||||
return StateResolutionHandler(self)
|
||||
|
||||
def build_presence_handler(self):
|
||||
return PresenceHandler(self)
|
||||
|
||||
@@ -257,6 +253,9 @@ class HomeServer(object):
|
||||
def build_e2e_keys_handler(self):
|
||||
return E2eKeysHandler(self)
|
||||
|
||||
def build_e2e_room_keys_handler(self):
|
||||
return E2eRoomKeysHandler(self)
|
||||
|
||||
def build_application_service_api(self):
|
||||
return ApplicationServiceApi(self)
|
||||
|
||||
@@ -278,9 +277,6 @@ class HomeServer(object):
|
||||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
|
||||
def build_event_creation_handler(self):
|
||||
return EventCreationHandler(self)
|
||||
|
||||
def build_deactivate_account_handler(self):
|
||||
return DeactivateAccountHandler(self)
|
||||
|
||||
@@ -316,23 +312,6 @@ class HomeServer(object):
|
||||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def get_db_conn(self, run_new_connection=True):
|
||||
"""Makes a new connection to the database, skipping the db pool
|
||||
|
||||
Returns:
|
||||
Connection: a connection object implementing the PEP-249 spec
|
||||
"""
|
||||
# Any param beginning with cp_ is a parameter for adbapi, and should
|
||||
# not be passed to the database engine.
|
||||
db_params = {
|
||||
k: v for k, v in self.db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
db_conn = self.database_engine.module.connect(**db_params)
|
||||
if run_new_connection:
|
||||
self.database_engine.on_new_connection(db_conn)
|
||||
return db_conn
|
||||
|
||||
def build_media_repository_resource(self):
|
||||
# build the media repo resource. This indirects through the HomeServer
|
||||
# to ensure that we only have a single instance of
|
||||
|
||||
@@ -34,9 +34,6 @@ class HomeServer(object):
|
||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
|
||||
def get_state_resolution_handler(self) -> synapse.state.StateResolutionHandler:
|
||||
pass
|
||||
|
||||
def get_deactivate_account_handler(self) -> synapse.handlers.deactivate_account.DeactivateAccountHandler:
|
||||
pass
|
||||
|
||||
|
||||
336
synapse/state.py
336
synapse/state.py
@@ -58,11 +58,7 @@ class _StateCacheEntry(object):
|
||||
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
|
||||
|
||||
def __init__(self, state, state_group, prev_group=None, delta_ids=None):
|
||||
# dict[(str, str), str] map from (type, state_key) to event_id
|
||||
self.state = frozendict(state)
|
||||
|
||||
# the ID of a state group if one and only one is involved.
|
||||
# otherwise, None otherwise?
|
||||
self.state_group = state_group
|
||||
|
||||
self.prev_group = prev_group
|
||||
@@ -85,19 +81,31 @@ class _StateCacheEntry(object):
|
||||
|
||||
|
||||
class StateHandler(object):
|
||||
"""Fetches bits of state from the stores, and does state resolution
|
||||
where necessary
|
||||
""" Responsible for doing state conflict resolution.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
# dict of set of event_ids -> _StateCacheEntry.
|
||||
self._state_cache = None
|
||||
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
|
||||
|
||||
def start_caching(self):
|
||||
# TODO: remove this shim
|
||||
self._state_resolution_handler.start_caching()
|
||||
logger.debug("start_caching")
|
||||
|
||||
self._state_cache = ExpiringCache(
|
||||
cache_name="state_cache",
|
||||
clock=self.clock,
|
||||
max_len=SIZE_OF_CACHE,
|
||||
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
|
||||
iterable=True,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
self._state_cache.start()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key="",
|
||||
@@ -119,7 +127,7 @@ class StateHandler(object):
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
@@ -138,27 +146,19 @@ class StateHandler(object):
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state_ids(self, room_id, latest_event_ids=None):
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
latest_event_ids (iterable[str]|None): if given, the forward
|
||||
extremities to resolve. If None, we look them up from the
|
||||
database (via a cache)
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str, str), str)]]: the state dict, mapping from
|
||||
(event_type, state_key) -> event_id
|
||||
"""
|
||||
def get_current_state_ids(self, room_id, event_type=None, state_key="",
|
||||
latest_event_ids=None):
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
defer.returnValue(state.get((event_type, state_key)))
|
||||
return
|
||||
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -166,7 +166,7 @@ class StateHandler(object):
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_user_in_room")
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
|
||||
defer.returnValue(joined_users)
|
||||
|
||||
@@ -175,7 +175,7 @@ class StateHandler(object):
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_hosts_in_room")
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
|
||||
joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
|
||||
defer.returnValue(joined_hosts)
|
||||
|
||||
@@ -183,15 +183,8 @@ class StateHandler(object):
|
||||
def compute_event_context(self, event, old_state=None):
|
||||
"""Build an EventContext structure for the event.
|
||||
|
||||
This works out what the current state should be for the event, and
|
||||
generates a new state group if necessary.
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase):
|
||||
old_state (dict|None): The state at the event if it can't be
|
||||
calculated from existing events. This is normally only specified
|
||||
when receiving an event from federation where we don't have the
|
||||
prev events for, e.g. when backfilling.
|
||||
Returns:
|
||||
synapse.events.snapshot.EventContext:
|
||||
"""
|
||||
@@ -215,22 +208,15 @@ class StateHandler(object):
|
||||
context.current_state_ids = {}
|
||||
context.prev_state_ids = {}
|
||||
context.prev_state_events = []
|
||||
|
||||
# We don't store state for outliers, so we don't generate a state
|
||||
# froup for it.
|
||||
context.state_group = None
|
||||
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
defer.returnValue(context)
|
||||
|
||||
if old_state:
|
||||
# We already have the state, so we don't need to calculate it.
|
||||
# Let's just correctly fill out the context and create a
|
||||
# new state group for it.
|
||||
|
||||
context = EventContext()
|
||||
context.prev_state_ids = {
|
||||
(s.type, s.state_key): s.event_id for s in old_state
|
||||
}
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
if event.is_state():
|
||||
key = (event.type, event.state_key)
|
||||
@@ -243,19 +229,11 @@ class StateHandler(object):
|
||||
else:
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=None,
|
||||
delta_ids=None,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
logger.debug("calling resolve_state_groups from compute_event_context")
|
||||
entry = yield self.resolve_state_groups_for_events(
|
||||
entry = yield self.resolve_state_groups(
|
||||
event.room_id, [e for e, _ in event.prev_events],
|
||||
)
|
||||
|
||||
@@ -264,8 +242,7 @@ class StateHandler(object):
|
||||
context = EventContext()
|
||||
context.prev_state_ids = curr_state
|
||||
if event.is_state():
|
||||
# If this is a state event then we need to create a new state
|
||||
# group for the state after this event.
|
||||
context.state_group = self.store.get_next_state_group()
|
||||
|
||||
key = (event.type, event.state_key)
|
||||
if key in context.prev_state_ids:
|
||||
@@ -276,57 +253,38 @@ class StateHandler(object):
|
||||
context.current_state_ids[key] = event.event_id
|
||||
|
||||
if entry.state_group:
|
||||
# If the state at the event has a state group assigned then
|
||||
# we can use that as the prev group
|
||||
context.prev_group = entry.state_group
|
||||
context.delta_ids = {
|
||||
key: event.event_id
|
||||
}
|
||||
elif entry.prev_group:
|
||||
# If the state at the event only has a prev group, then we can
|
||||
# use that as a prev group too.
|
||||
context.prev_group = entry.prev_group
|
||||
context.delta_ids = dict(entry.delta_ids)
|
||||
context.delta_ids[key] = event.event_id
|
||||
|
||||
context.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=context.prev_group,
|
||||
delta_ids=context.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
else:
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
context.prev_group = entry.prev_group
|
||||
context.delta_ids = entry.delta_ids
|
||||
|
||||
if entry.state_group is None:
|
||||
entry.state_group = yield self.store.store_state_group(
|
||||
event.event_id,
|
||||
event.room_id,
|
||||
prev_group=entry.prev_group,
|
||||
delta_ids=entry.delta_ids,
|
||||
current_state_ids=context.current_state_ids,
|
||||
)
|
||||
entry.state_group = self.store.get_next_state_group()
|
||||
entry.state_id = entry.state_group
|
||||
|
||||
context.state_group = entry.state_group
|
||||
context.current_state_ids = context.prev_state_ids
|
||||
context.prev_group = entry.prev_group
|
||||
context.delta_ids = entry.delta_ids
|
||||
|
||||
context.prev_state_events = []
|
||||
defer.returnValue(context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_state_groups_for_events(self, room_id, event_ids):
|
||||
@log_function
|
||||
def resolve_state_groups(self, room_id, event_ids):
|
||||
""" Given a list of event_ids this method fetches the state at each
|
||||
event, resolves conflicts between them and returns them.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
event_ids (list[str]):
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
a Deferred tuple of (`state_group`, `state`, `prev_state`).
|
||||
`state_group` is the name of a state group if one and only one is
|
||||
involved. `state` is a map from (type, state_key) to event, and
|
||||
`prev_state` is a list of event ids.
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
@@ -337,7 +295,13 @@ class StateHandler(object):
|
||||
room_id, event_ids
|
||||
)
|
||||
|
||||
if len(state_groups_ids) == 1:
|
||||
logger.debug(
|
||||
"resolve_state_groups state_groups %s",
|
||||
state_groups_ids.keys()
|
||||
)
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
if len(group_names) == 1:
|
||||
name, state_list = state_groups_ids.items().pop()
|
||||
|
||||
prev_group, delta_ids = yield self.store.get_state_group_delta(name)
|
||||
@@ -349,102 +313,6 @@ class StateHandler(object):
|
||||
delta_ids=delta_ids,
|
||||
))
|
||||
|
||||
result = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups_ids, None, self._state_map_factory,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
|
||||
def _state_map_factory(self, ev_ids):
|
||||
return self.store.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
state_set_ids = [{
|
||||
(ev.type, ev.state_key): ev.event_id
|
||||
for ev in st
|
||||
} for st in state_sets]
|
||||
|
||||
state_map = {
|
||||
ev.event_id: ev
|
||||
for st in state_sets
|
||||
for ev in st
|
||||
}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events_with_state_map(state_set_ids, state_map)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in new_state.items()
|
||||
}
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
class StateResolutionHandler(object):
|
||||
"""Responsible for doing state conflict resolution.
|
||||
|
||||
Note that the storage layer depends on this handler, so all functions must
|
||||
be storage-independent.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
# dict of set of event_ids -> _StateCacheEntry.
|
||||
self._state_cache = None
|
||||
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
|
||||
|
||||
def start_caching(self):
|
||||
logger.debug("start_caching")
|
||||
|
||||
self._state_cache = ExpiringCache(
|
||||
cache_name="state_cache",
|
||||
clock=self.clock,
|
||||
max_len=SIZE_OF_CACHE,
|
||||
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
|
||||
iterable=True,
|
||||
reset_expiry_on_get=True,
|
||||
)
|
||||
|
||||
self._state_cache.start()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def resolve_state_groups(
|
||||
self, room_id, state_groups_ids, event_map, state_map_factory,
|
||||
):
|
||||
"""Resolves conflicts between a set of state groups
|
||||
|
||||
Always generates a new state group (unless we hit the cache), so should
|
||||
not be called for a single state group
|
||||
|
||||
Args:
|
||||
room_id (str): room we are resolving for (used for logging)
|
||||
state_groups_ids (dict[int, dict[(str, str), str]]):
|
||||
map from state group id to the state in that state group
|
||||
(where 'state' is a map from state key to event id)
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
events will be requested via state_map_factory.
|
||||
|
||||
If None, all events will be fetched via state_map_factory.
|
||||
|
||||
Returns:
|
||||
Deferred[_StateCacheEntry]: resolved state
|
||||
"""
|
||||
logger.debug(
|
||||
"resolve_state_groups state_groups %s",
|
||||
state_groups_ids.keys()
|
||||
)
|
||||
|
||||
group_names = frozenset(state_groups_ids.keys())
|
||||
|
||||
with (yield self.resolve_linearizer.queue(group_names)):
|
||||
if self._state_cache is not None:
|
||||
cache = self._state_cache.get(group_names, None)
|
||||
@@ -473,20 +341,17 @@ class StateResolutionHandler(object):
|
||||
if conflicted_state:
|
||||
logger.info("Resolving conflicted state for %r", room_id)
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = yield resolve_events_with_factory(
|
||||
new_state = yield resolve_events(
|
||||
state_groups_ids.values(),
|
||||
event_map=event_map,
|
||||
state_map_factory=state_map_factory,
|
||||
state_map_factory=lambda ev_ids: self.store.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
new_state = {
|
||||
key: e_ids.pop() for key, e_ids in state.items()
|
||||
}
|
||||
|
||||
# if the new state matches any of the input state groups, we can
|
||||
# use that state group again. Otherwise we will generate a state_id
|
||||
# which will be used as a cache key for future resolutions, but
|
||||
# not get persisted.
|
||||
state_group = None
|
||||
new_state_event_ids = frozenset(new_state.values())
|
||||
for sg, events in state_groups_ids.items():
|
||||
@@ -523,6 +388,30 @@ class StateResolutionHandler(object):
|
||||
|
||||
defer.returnValue(cache)
|
||||
|
||||
def resolve_events(self, state_sets, event):
|
||||
logger.info(
|
||||
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
|
||||
)
|
||||
state_set_ids = [{
|
||||
(ev.type, ev.state_key): ev.event_id
|
||||
for ev in st
|
||||
} for st in state_sets]
|
||||
|
||||
state_map = {
|
||||
ev.event_id: ev
|
||||
for st in state_sets
|
||||
for ev in st
|
||||
}
|
||||
|
||||
with Measure(self.clock, "state._resolve_events"):
|
||||
new_state = resolve_events(state_set_ids, state_map)
|
||||
|
||||
new_state = {
|
||||
key: state_map[ev_id] for key, ev_id in new_state.items()
|
||||
}
|
||||
|
||||
return new_state
|
||||
|
||||
|
||||
def _ordered_events(events):
|
||||
def key_func(e):
|
||||
@@ -531,17 +420,19 @@ def _ordered_events(events):
|
||||
return sorted(events, key=key_func)
|
||||
|
||||
|
||||
def resolve_events_with_state_map(state_sets, state_map):
|
||||
def resolve_events(state_sets, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
state_map(dict): a dict from event_id to event, for all events in
|
||||
state_sets.
|
||||
state_map_factory(dict|callable): If callable, then will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event. Otherwise, should be
|
||||
a dict from event_id to event of all events in state_sets.
|
||||
|
||||
Returns
|
||||
dict[(str, str), str]:
|
||||
a map from (type, state_key) to event_id.
|
||||
dict[(str, str), synapse.events.FrozenEvent] is a map from
|
||||
(type, state_key) to event.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
return state_sets[0]
|
||||
@@ -550,6 +441,13 @@ def resolve_events_with_state_map(state_sets, state_map):
|
||||
state_sets,
|
||||
)
|
||||
|
||||
if callable(state_map_factory):
|
||||
return _resolve_with_state_fac(
|
||||
unconflicted_state, conflicted_state, state_map_factory
|
||||
)
|
||||
|
||||
state_map = state_map_factory
|
||||
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
@@ -563,21 +461,6 @@ def _seperate(state_sets):
|
||||
"""Takes the state_sets and figures out which keys are conflicted and
|
||||
which aren't. i.e., which have multiple different event_ids associated
|
||||
with them in different state sets.
|
||||
|
||||
Args:
|
||||
state_sets(list[dict[(str, str), str]]):
|
||||
List of dicts of (type, state_key) -> event_id, which are the
|
||||
different state groups to resolve.
|
||||
|
||||
Returns:
|
||||
(dict[(str, str), str], dict[(str, str), set[str]]):
|
||||
A tuple of (unconflicted_state, conflicted_state), where:
|
||||
|
||||
unconflicted_state is a dict mapping (type, state_key)->event_id
|
||||
for unconflicted state keys.
|
||||
|
||||
conflicted_state is a dict mapping (type, state_key) to a set of
|
||||
event ids for conflicted state keys.
|
||||
"""
|
||||
unconflicted_state = dict(state_sets[0])
|
||||
conflicted_state = {}
|
||||
@@ -608,50 +491,19 @@ def _seperate(state_sets):
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||
"""
|
||||
Args:
|
||||
state_sets(list): List of dicts of (type, state_key) -> event_id,
|
||||
which are the different state groups to resolve.
|
||||
|
||||
event_map(dict[str,FrozenEvent]|None):
|
||||
a dict from event_id to event, for any events that we happen to
|
||||
have in flight (eg, those currently being persisted). This will be
|
||||
used as a starting point fof finding the state we need; any missing
|
||||
events will be requested via state_map_factory.
|
||||
|
||||
If None, all events will be fetched via state_map_factory.
|
||||
|
||||
state_map_factory(func): will be called
|
||||
with a list of event_ids that are needed, and should return with
|
||||
a Deferred of dict of event_id to event.
|
||||
|
||||
Returns
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if len(state_sets) == 1:
|
||||
defer.returnValue(state_sets[0])
|
||||
|
||||
unconflicted_state, conflicted_state = _seperate(
|
||||
state_sets,
|
||||
)
|
||||
|
||||
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
||||
state_map_factory):
|
||||
needed_events = set(
|
||||
event_id
|
||||
for event_ids in conflicted_state.itervalues()
|
||||
for event_id in event_ids
|
||||
)
|
||||
if event_map is not None:
|
||||
needed_events -= set(event_map.iterkeys())
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict (and those in event_map)
|
||||
# the state events which are in conflict.
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
if event_map is not None:
|
||||
state_map.update(event_map)
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
@@ -663,8 +515,6 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
|
||||
|
||||
new_needed_events = set(auth_events.itervalues())
|
||||
new_needed_events -= needed_events
|
||||
if event_map is not None:
|
||||
new_needed_events -= set(event_map.iterkeys())
|
||||
|
||||
logger.info("Asking for %d auth events", len(new_needed_events))
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
from .end_to_end_keys import EndToEndKeyStore
|
||||
from .e2e_room_keys import EndToEndRoomKeyStore
|
||||
|
||||
from .receipts import ReceiptsStore
|
||||
from .search import SearchStore
|
||||
@@ -79,6 +80,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
EventsStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
EndToEndRoomKeyStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
@@ -124,6 +126,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
)
|
||||
|
||||
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
|
||||
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
|
||||
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
|
||||
|
||||
@@ -291,33 +291,33 @@ class SQLBaseStore(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
"""Starts a transaction on the database and runs a given function
|
||||
|
||||
Arguments:
|
||||
desc (str): description of the transaction, for logging and metrics
|
||||
func (func): callback function, which will be called with a
|
||||
database transaction (twisted.enterprise.adbapi.Transaction) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
after_callbacks = []
|
||||
final_callbacks = []
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
return self._new_transaction(
|
||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||
func, *args, **kwargs
|
||||
)
|
||||
with LoggingContext("runInteraction") as context:
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
conn.reconnect()
|
||||
|
||||
current_context.copy_to(context)
|
||||
return self._new_transaction(
|
||||
conn, desc, after_callbacks, final_callbacks, current_context,
|
||||
func, *args, **kwargs
|
||||
)
|
||||
|
||||
try:
|
||||
result = yield self.runWithConnection(inner_func, *args, **kwargs)
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
after_callback(*after_args, **after_kwargs)
|
||||
@@ -329,27 +329,14 @@ class SQLBaseStore(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def runWithConnection(self, func, *args, **kwargs):
|
||||
"""Wraps the .runWithConnection() method on the underlying db_pool.
|
||||
|
||||
Arguments:
|
||||
func (func): callback function, which will be called with a
|
||||
database connection (twisted.enterprise.adbapi.Connection) as
|
||||
its first argument, followed by `args` and `kwargs`.
|
||||
args (list): positional args to pass to `func`
|
||||
kwargs (dict): named args to pass to `func`
|
||||
|
||||
Returns:
|
||||
Deferred: The result of func
|
||||
"""
|
||||
"""Wraps the .runInteraction() method on the underlying db_pool."""
|
||||
current_context = LoggingContext.current_context()
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runWithConnection") as context:
|
||||
sched_duration_ms = time.time() * 1000 - start_time
|
||||
sql_scheduling_timer.inc_by(sched_duration_ms)
|
||||
current_context.add_database_scheduled(sched_duration_ms)
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
|
||||
@@ -99,19 +99,6 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_by_id(self, as_id):
|
||||
"""Get the application service with the given appservice ID.
|
||||
|
||||
Args:
|
||||
as_id (str): The application service ID.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
"""
|
||||
for service in self.services_cache:
|
||||
if service.id == as_id:
|
||||
return service
|
||||
return None
|
||||
|
||||
def get_app_service_rooms(self, service):
|
||||
"""Get a list of RoomsForUser for this application service.
|
||||
|
||||
|
||||
@@ -242,25 +242,6 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
"""
|
||||
self._background_update_handlers[update_name] = update_handler
|
||||
|
||||
def register_noop_background_update(self, update_name):
|
||||
"""Register a noop handler for a background update.
|
||||
|
||||
This is useful when we previously did a background update, but no
|
||||
longer wish to do the update. In this case the background update should
|
||||
be removed from the schema delta files, but there may still be some
|
||||
users who have the background update queued, so this method should
|
||||
also be called to clear the update.
|
||||
|
||||
Args:
|
||||
update_name (str): Name of update
|
||||
"""
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update(update_name)
|
||||
defer.returnValue(1)
|
||||
|
||||
self.register_background_update_handler(update_name, noop_update)
|
||||
|
||||
def register_background_index_update(self, update_name, index_name,
|
||||
table, columns, where_clause=None,
|
||||
unique=False,
|
||||
|
||||
@@ -28,7 +28,7 @@ logger = logging.getLogger(__name__)
|
||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||
# times give more inserts into the database even for readonly API hits
|
||||
# 120 seconds == 2 minutes
|
||||
LAST_SEEN_GRANULARITY = 10 * 60 * 1000
|
||||
LAST_SEEN_GRANULARITY = 120 * 1000
|
||||
|
||||
|
||||
class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
308
synapse/storage/e2e_room_keys.py
Normal file
308
synapse/storage/e2e_room_keys.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
from synapse.api.errors import StoreError
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class EndToEndRoomKeyStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_room_key(self, user_id, version, room_id, session_id):
|
||||
"""Get the encrypted E2E room key for a given session from a given
|
||||
backup version of room_keys. We only store the 'best' room key for a given
|
||||
session at a given time, as determined by the handler.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're querying
|
||||
version(str): the version ID of the backup for the set of keys we're querying
|
||||
room_id(str): the ID of the room whose keys we're querying.
|
||||
This is a bit redundant as it's implied by the session_id, but
|
||||
we include for consistency with the rest of the API.
|
||||
session_id(str): the session whose room_key we're querying.
|
||||
|
||||
Returns:
|
||||
A deferred dict giving the session_data and message metadata for
|
||||
this room key.
|
||||
"""
|
||||
|
||||
row = yield self._simple_select_one(
|
||||
table="e2e_room_keys",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": version,
|
||||
"room_id": room_id,
|
||||
"session_id": session_id,
|
||||
},
|
||||
retcols=(
|
||||
"first_message_index",
|
||||
"forwarded_count",
|
||||
"is_verified",
|
||||
"session_data",
|
||||
),
|
||||
desc="get_e2e_room_key",
|
||||
)
|
||||
|
||||
defer.returnValue(row)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def set_e2e_room_key(self, user_id, version, room_id, session_id, room_key):
|
||||
"""Replaces or inserts the encrypted E2E room key for a given session in
|
||||
a given backup
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're setting
|
||||
version(str): the version ID of the backup we're updating
|
||||
room_id(str): the ID of the room whose keys we're setting
|
||||
session_id(str): the session whose room_key we're setting
|
||||
room_key(dict): the room_key being set
|
||||
Raises:
|
||||
StoreError if stuff goes wrong, probably
|
||||
"""
|
||||
|
||||
yield self._simple_upsert(
|
||||
table="e2e_room_keys",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"room_id": room_id,
|
||||
"session_id": session_id,
|
||||
},
|
||||
values={
|
||||
"version": version,
|
||||
"first_message_index": room_key['first_message_index'],
|
||||
"forwarded_count": room_key['forwarded_count'],
|
||||
"is_verified": room_key['is_verified'],
|
||||
"session_data": room_key['session_data'],
|
||||
},
|
||||
lock=False,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_e2e_room_keys(
|
||||
self, user_id, version, room_id=None, session_id=None
|
||||
):
|
||||
"""Bulk get the E2E room keys for a given backup, optionally filtered to a given
|
||||
room, or a given session.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're querying
|
||||
version(str): the version ID of the backup for the set of keys we're querying
|
||||
room_id(str): Optional. the ID of the room whose keys we're querying, if any.
|
||||
If not specified, we return the keys for all the rooms in the backup.
|
||||
session_id(str): Optional. the session whose room_key we're querying, if any.
|
||||
If specified, we also require the room_id to be specified.
|
||||
If not specified, we return all the keys in this version of
|
||||
the backup (or for the specified room)
|
||||
|
||||
Returns:
|
||||
A deferred list of dicts giving the session_data and message metadata for
|
||||
these room keys.
|
||||
"""
|
||||
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
"version": version,
|
||||
}
|
||||
if room_id:
|
||||
keyvalues['room_id'] = room_id
|
||||
if session_id:
|
||||
keyvalues['session_id'] = session_id
|
||||
|
||||
rows = yield self._simple_select_list(
|
||||
table="e2e_room_keys",
|
||||
keyvalues=keyvalues,
|
||||
retcols=(
|
||||
"user_id",
|
||||
"room_id",
|
||||
"session_id",
|
||||
"first_message_index",
|
||||
"forwarded_count",
|
||||
"is_verified",
|
||||
"session_data",
|
||||
),
|
||||
desc="get_e2e_room_keys",
|
||||
)
|
||||
|
||||
sessions = {'rooms': {}}
|
||||
for row in rows:
|
||||
room_entry = sessions['rooms'].setdefault(row['room_id'], {"sessions": {}})
|
||||
room_entry['sessions'][row['session_id']] = {
|
||||
"first_message_index": row["first_message_index"],
|
||||
"forwarded_count": row["forwarded_count"],
|
||||
"is_verified": row["is_verified"],
|
||||
"session_data": row["session_data"],
|
||||
}
|
||||
|
||||
defer.returnValue(sessions)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_e2e_room_keys(
|
||||
self, user_id, version, room_id=None, session_id=None
|
||||
):
|
||||
"""Bulk delete the E2E room keys for a given backup, optionally filtered to a given
|
||||
room or a given session.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're deleting from
|
||||
version(str): the version ID of the backup for the set of keys we're deleting
|
||||
room_id(str): Optional. the ID of the room whose keys we're deleting, if any.
|
||||
If not specified, we delete the keys for all the rooms in the backup.
|
||||
session_id(str): Optional. the session whose room_key we're querying, if any.
|
||||
If specified, we also require the room_id to be specified.
|
||||
If not specified, we delete all the keys in this version of
|
||||
the backup (or for the specified room)
|
||||
|
||||
Returns:
|
||||
A deferred of the deletion transaction
|
||||
"""
|
||||
|
||||
keyvalues = {
|
||||
"user_id": user_id,
|
||||
"version": version,
|
||||
}
|
||||
if room_id:
|
||||
keyvalues['room_id'] = room_id
|
||||
if session_id:
|
||||
keyvalues['session_id'] = session_id
|
||||
|
||||
yield self._simple_delete(
|
||||
table="e2e_room_keys",
|
||||
keyvalues=keyvalues,
|
||||
desc="delete_e2e_room_keys",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _get_current_version(txn, user_id):
|
||||
txn.execute(
|
||||
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
|
||||
(user_id,)
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
raise StoreError(404, 'No current backup version')
|
||||
return row[0]
|
||||
|
||||
def get_e2e_room_keys_version_info(self, user_id, version=None):
|
||||
"""Get info metadata about a version of our room_keys backup.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're querying
|
||||
version(str): Optional. the version ID of the backup we're querying about
|
||||
If missing, we return the information about the current version.
|
||||
Raises:
|
||||
StoreError: with code 404 if there are no e2e_room_keys_versions present
|
||||
Returns:
|
||||
A deferred dict giving the info metadata for this backup version
|
||||
"""
|
||||
|
||||
def _get_e2e_room_keys_version_info_txn(txn):
|
||||
if version is None:
|
||||
this_version = self._get_current_version(txn, user_id)
|
||||
else:
|
||||
this_version = version
|
||||
|
||||
return self._simple_select_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
},
|
||||
retcols=(
|
||||
"version",
|
||||
"algorithm",
|
||||
"auth_data",
|
||||
),
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_e2e_room_keys_version_info",
|
||||
_get_e2e_room_keys_version_info_txn
|
||||
)
|
||||
|
||||
def create_e2e_room_keys_version(self, user_id, info):
|
||||
"""Atomically creates a new version of this user's e2e_room_keys store
|
||||
with the given version info.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup we're creating a version
|
||||
info(dict): the info about the backup version to be created
|
||||
|
||||
Returns:
|
||||
A deferred string for the newly created version ID
|
||||
"""
|
||||
|
||||
def _create_e2e_room_keys_version_txn(txn):
|
||||
txn.execute(
|
||||
"SELECT MAX(version) FROM e2e_room_keys_versions WHERE user_id=?",
|
||||
(user_id,)
|
||||
)
|
||||
current_version = txn.fetchone()[0]
|
||||
if current_version is None:
|
||||
current_version = '0'
|
||||
|
||||
new_version = str(int(current_version) + 1)
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
values={
|
||||
"user_id": user_id,
|
||||
"version": new_version,
|
||||
"algorithm": info["algorithm"],
|
||||
"auth_data": info["auth_data"],
|
||||
},
|
||||
)
|
||||
|
||||
return new_version
|
||||
|
||||
return self.runInteraction(
|
||||
"create_e2e_room_keys_version_txn", _create_e2e_room_keys_version_txn
|
||||
)
|
||||
|
||||
def delete_e2e_room_keys_version(self, user_id, version=None):
|
||||
"""Delete a given backup version of the user's room keys.
|
||||
Doesn't delete their actual key data.
|
||||
|
||||
Args:
|
||||
user_id(str): the user whose backup version we're deleting
|
||||
version(str): Optional. the version ID of the backup version we're deleting
|
||||
If missing, we delete the current backup version info.
|
||||
Raises:
|
||||
StoreError: with code 404 if there are no e2e_room_keys_versions present,
|
||||
or if the version requested doesn't exist.
|
||||
"""
|
||||
|
||||
def _delete_e2e_room_keys_version_txn(txn):
|
||||
if version is None:
|
||||
this_version = self._get_current_version(txn, user_id)
|
||||
else:
|
||||
this_version = version
|
||||
|
||||
return self._simple_delete_one_txn(
|
||||
txn,
|
||||
table="e2e_room_keys_versions",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"version": this_version,
|
||||
},
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_e2e_room_keys_version",
|
||||
_delete_e2e_room_keys_version_txn
|
||||
)
|
||||
@@ -62,9 +62,3 @@ class PostgresEngine(object):
|
||||
|
||||
def lock_table(self, txn, table):
|
||||
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
|
||||
|
||||
def get_next_state_group_id(self, txn):
|
||||
"""Returns an int that can be used as a new state_group ID
|
||||
"""
|
||||
txn.execute("SELECT nextval('state_group_id_seq')")
|
||||
return txn.fetchone()[0]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from synapse.storage.prepare_database import prepare_database
|
||||
|
||||
import struct
|
||||
import threading
|
||||
|
||||
|
||||
class Sqlite3Engine(object):
|
||||
@@ -25,11 +24,6 @@ class Sqlite3Engine(object):
|
||||
def __init__(self, database_module, database_config):
|
||||
self.module = database_module
|
||||
|
||||
# The current max state_group, or None if we haven't looked
|
||||
# in the DB yet.
|
||||
self._current_state_group_id = None
|
||||
self._current_state_group_id_lock = threading.Lock()
|
||||
|
||||
def check_database(self, txn):
|
||||
pass
|
||||
|
||||
@@ -49,19 +43,6 @@ class Sqlite3Engine(object):
|
||||
def lock_table(self, txn, table):
|
||||
return
|
||||
|
||||
def get_next_state_group_id(self, txn):
|
||||
"""Returns an int that can be used as a new state_group ID
|
||||
"""
|
||||
# We do application locking here since if we're using sqlite then
|
||||
# we are a single process synapse.
|
||||
with self._current_state_group_id_lock:
|
||||
if self._current_state_group_id is None:
|
||||
txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
|
||||
self._current_state_group_id = txn.fetchone()[0]
|
||||
|
||||
self._current_state_group_id += 1
|
||||
return self._current_state_group_id
|
||||
|
||||
|
||||
# Following functions taken from: https://github.com/coleifer/peewee
|
||||
|
||||
|
||||
@@ -87,8 +87,6 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
self._rotate_notif_loop = self._clock.looping_call(
|
||||
self._rotate_notifs, 30 * 60 * 1000
|
||||
)
|
||||
self._rotate_delay = 3
|
||||
self._rotate_count = 10000
|
||||
|
||||
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
|
||||
"""
|
||||
@@ -631,7 +629,7 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
)
|
||||
if caught_up:
|
||||
break
|
||||
yield sleep(self._rotate_delay)
|
||||
yield sleep(5)
|
||||
finally:
|
||||
self._doing_notif_rotation = False
|
||||
|
||||
@@ -652,8 +650,8 @@ class EventPushActionsStore(SQLBaseStore):
|
||||
txn.execute("""
|
||||
SELECT stream_ordering FROM event_push_actions
|
||||
WHERE stream_ordering > ?
|
||||
ORDER BY stream_ordering ASC LIMIT 1 OFFSET ?
|
||||
""", (old_rotate_stream_ordering, self._rotate_count))
|
||||
ORDER BY stream_ordering ASC LIMIT 1 OFFSET 50000
|
||||
""", (old_rotate_stream_ordering,))
|
||||
stream_row = txn.fetchone()
|
||||
if stream_row:
|
||||
offset_stream_ordering, = stream_row
|
||||
|
||||
@@ -27,6 +27,7 @@ from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.state import resolve_events
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
@@ -109,7 +110,7 @@ class _EventPeristenceQueue(object):
|
||||
end_item.events_and_contexts.extend(events_and_contexts)
|
||||
return end_item.deferred.observe()
|
||||
|
||||
deferred = ObservableDeferred(defer.Deferred(), consumeErrors=True)
|
||||
deferred = ObservableDeferred(defer.Deferred())
|
||||
|
||||
queue.append(self._EventPersistQueueItem(
|
||||
events_and_contexts=events_and_contexts,
|
||||
@@ -145,25 +146,18 @@ class _EventPeristenceQueue(object):
|
||||
try:
|
||||
queue = self._get_drainining_queue(room_id)
|
||||
for item in queue:
|
||||
# handle_queue_loop runs in the sentinel logcontext, so
|
||||
# there is no need to preserve_fn when running the
|
||||
# callbacks on the deferred.
|
||||
try:
|
||||
ret = yield per_item_callback(item)
|
||||
item.deferred.callback(ret)
|
||||
except Exception:
|
||||
item.deferred.errback()
|
||||
except Exception as e:
|
||||
item.deferred.errback(e)
|
||||
finally:
|
||||
queue = self._event_persist_queues.pop(room_id, None)
|
||||
if queue:
|
||||
self._event_persist_queues[room_id] = queue
|
||||
self._currently_persisting_rooms.discard(room_id)
|
||||
|
||||
# set handle_queue_loop off on the background. We don't want to
|
||||
# attribute work done in it to the current request, so we drop the
|
||||
# logcontext altogether.
|
||||
with PreserveLoggingContext():
|
||||
handle_queue_loop()
|
||||
preserve_fn(handle_queue_loop)()
|
||||
|
||||
def _get_drainining_queue(self, room_id):
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
@@ -236,8 +230,6 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
self._event_persist_queue = _EventPeristenceQueue()
|
||||
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
|
||||
def persist_events(self, events_and_contexts, backfilled=False):
|
||||
"""
|
||||
Write events to the database
|
||||
@@ -343,20 +335,8 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
# NB: Assumes that we are only persisting events for one room
|
||||
# at a time.
|
||||
|
||||
# map room_id->list[event_ids] giving the new forward
|
||||
# extremities in each room
|
||||
new_forward_extremeties = {}
|
||||
|
||||
# map room_id->(type,state_key)->event_id tracking the full
|
||||
# state in each room after adding these events
|
||||
current_state_for_room = {}
|
||||
|
||||
# map room_id->(to_delete, to_insert) where each entry is
|
||||
# a map (type,key)->event_id giving the state delta in each
|
||||
# room
|
||||
state_delta_for_room = {}
|
||||
|
||||
if not backfilled:
|
||||
with Measure(self._clock, "_calculate_state_and_extrem"):
|
||||
# Work out the new "current state" for each room.
|
||||
@@ -399,20 +379,11 @@ class EventsStore(SQLBaseStore):
|
||||
if all_single_prev_not_state:
|
||||
continue
|
||||
|
||||
logger.info(
|
||||
"Calculating state delta for room %s", room_id,
|
||||
state = yield self._calculate_state_delta(
|
||||
room_id, ev_ctx_rm, new_latest_event_ids
|
||||
)
|
||||
current_state = yield self._get_new_state_after_events(
|
||||
room_id,
|
||||
ev_ctx_rm, new_latest_event_ids,
|
||||
)
|
||||
if current_state is not None:
|
||||
current_state_for_room[room_id] = current_state
|
||||
delta = yield self._calculate_state_delta(
|
||||
room_id, current_state,
|
||||
)
|
||||
if delta is not None:
|
||||
state_delta_for_room[room_id] = delta
|
||||
if state:
|
||||
current_state_for_room[room_id] = state
|
||||
|
||||
yield self.runInteraction(
|
||||
"persist_events",
|
||||
@@ -420,7 +391,7 @@ class EventsStore(SQLBaseStore):
|
||||
events_and_contexts=chunk,
|
||||
backfilled=backfilled,
|
||||
delete_existing=delete_existing,
|
||||
state_delta_for_room=state_delta_for_room,
|
||||
current_state_for_room=current_state_for_room,
|
||||
new_forward_extremeties=new_forward_extremeties,
|
||||
)
|
||||
persist_event_counter.inc_by(len(chunk))
|
||||
@@ -437,7 +408,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
event_counter.inc(event.type, origin_type, origin_entity)
|
||||
|
||||
for room_id, new_state in current_state_for_room.iteritems():
|
||||
for room_id, (_, _, new_state) in current_state_for_room.iteritems():
|
||||
self.get_current_state_ids.prefill(
|
||||
(room_id, ), new_state
|
||||
)
|
||||
@@ -489,31 +460,22 @@ class EventsStore(SQLBaseStore):
|
||||
defer.returnValue(new_latest_event_ids)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_new_state_after_events(self, room_id, events_context, new_latest_event_ids):
|
||||
"""Calculate the current state dict after adding some new events to
|
||||
a room
|
||||
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
|
||||
"""Calculate the new state deltas for a room.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
room to which the events are being added. Used for logging etc
|
||||
|
||||
events_context (list[(EventBase, EventContext)]):
|
||||
events and contexts which are being added to the room
|
||||
|
||||
new_latest_event_ids (iterable[str]):
|
||||
the new forward extremities for the room.
|
||||
Assumes that we are only persisting events for one room at a time.
|
||||
|
||||
Returns:
|
||||
Deferred[dict[(str,str), str]|None]:
|
||||
None if there are no changes to the room state, or
|
||||
a dict of (type, state_key) -> event_id].
|
||||
3-tuple (to_delete, to_insert, new_state) where both are state dicts,
|
||||
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
|
||||
first be deleted from current_state_events, `to_insert` are entries
|
||||
to insert. `new_state` is the full set of state.
|
||||
May return None if there are no changes to be applied.
|
||||
"""
|
||||
|
||||
if not new_latest_event_ids:
|
||||
defer.returnValue({})
|
||||
|
||||
# map from state_group to ((type, key) -> event_id) state map
|
||||
state_groups = {}
|
||||
# Now we need to work out the different state sets for
|
||||
# each state extremities
|
||||
state_sets = []
|
||||
state_groups = set()
|
||||
missing_event_ids = []
|
||||
was_updated = False
|
||||
for event_id in new_latest_event_ids:
|
||||
@@ -524,19 +486,16 @@ class EventsStore(SQLBaseStore):
|
||||
if ctx.current_state_ids is None:
|
||||
raise Exception("Unknown current state")
|
||||
|
||||
if ctx.state_group is None:
|
||||
# I don't think this can happen, but let's double-check
|
||||
raise Exception(
|
||||
"Context for new extremity event %s has no state "
|
||||
"group" % (event_id, ),
|
||||
)
|
||||
|
||||
# If we've already seen the state group don't bother adding
|
||||
# it to the state sets again
|
||||
if ctx.state_group not in state_groups:
|
||||
state_groups[ctx.state_group] = ctx.current_state_ids
|
||||
state_sets.append(ctx.current_state_ids)
|
||||
if ctx.delta_ids or hasattr(ev, "state_key"):
|
||||
was_updated = True
|
||||
if ctx.state_group:
|
||||
# Add this as a seen state group (if it has a state
|
||||
# group)
|
||||
state_groups.add(ctx.state_group)
|
||||
break
|
||||
else:
|
||||
# If we couldn't find it, then we'll need to pull
|
||||
@@ -544,50 +503,60 @@ class EventsStore(SQLBaseStore):
|
||||
was_updated = True
|
||||
missing_event_ids.append(event_id)
|
||||
|
||||
if not was_updated:
|
||||
return
|
||||
|
||||
if missing_event_ids:
|
||||
# Now pull out the state for any missing events from DB
|
||||
event_to_groups = yield self._get_state_group_for_events(
|
||||
missing_event_ids,
|
||||
)
|
||||
|
||||
groups = set(event_to_groups.itervalues()) - set(state_groups.iterkeys())
|
||||
groups = set(event_to_groups.itervalues()) - state_groups
|
||||
|
||||
if groups:
|
||||
group_to_state = yield self._get_state_for_groups(groups)
|
||||
state_groups.update(group_to_state)
|
||||
state_sets.extend(group_to_state.itervalues())
|
||||
|
||||
if len(state_groups) == 1:
|
||||
# If there is only one state group, then we know what the current
|
||||
# state is.
|
||||
defer.returnValue(state_groups.values()[0])
|
||||
if not new_latest_event_ids:
|
||||
current_state = {}
|
||||
elif was_updated:
|
||||
if len(state_sets) == 1:
|
||||
# If there is only one state set, then we know what the current
|
||||
# state is.
|
||||
current_state = state_sets[0]
|
||||
else:
|
||||
# We work out the current state by passing the state sets to the
|
||||
# state resolution algorithm. It may ask for some events, including
|
||||
# the events we have yet to persist, so we need a slightly more
|
||||
# complicated event lookup function than simply looking the events
|
||||
# up in the db.
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
|
||||
def get_events(ev_ids):
|
||||
return self.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
events_map = {ev.event_id: ev for ev, _ in events_context}
|
||||
logger.debug("calling resolve_state_groups from preserve_events")
|
||||
res = yield self._state_resolution_handler.resolve_state_groups(
|
||||
room_id, state_groups, events_map, get_events
|
||||
)
|
||||
@defer.inlineCallbacks
|
||||
def get_events(ev_ids):
|
||||
# We get the events by first looking at the list of events we
|
||||
# are trying to persist, and then fetching the rest from the DB.
|
||||
db = []
|
||||
to_return = {}
|
||||
for ev_id in ev_ids:
|
||||
ev = events_map.get(ev_id, None)
|
||||
if ev:
|
||||
to_return[ev_id] = ev
|
||||
else:
|
||||
db.append(ev_id)
|
||||
|
||||
defer.returnValue(res.state)
|
||||
if db:
|
||||
evs = yield self.get_events(
|
||||
ev_ids, get_prev_content=False, check_redacted=False,
|
||||
)
|
||||
to_return.update(evs)
|
||||
defer.returnValue(to_return)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _calculate_state_delta(self, room_id, current_state):
|
||||
"""Calculate the new state deltas for a room.
|
||||
current_state = yield resolve_events(
|
||||
state_sets,
|
||||
state_map_factory=get_events,
|
||||
)
|
||||
else:
|
||||
return
|
||||
|
||||
Assumes that we are only persisting events for one room at a time.
|
||||
|
||||
Returns:
|
||||
2-tuple (to_delete, to_insert) where both are state dicts,
|
||||
i.e. (type, state_key) -> event_id. `to_delete` are the entries to
|
||||
first be deleted from current_state_events, `to_insert` are entries
|
||||
to insert.
|
||||
"""
|
||||
existing_state = yield self.get_current_state_ids(room_id)
|
||||
|
||||
existing_events = set(existing_state.itervalues())
|
||||
@@ -607,7 +576,7 @@ class EventsStore(SQLBaseStore):
|
||||
if ev_id in events_to_insert
|
||||
}
|
||||
|
||||
defer.returnValue((to_delete, to_insert))
|
||||
defer.returnValue((to_delete, to_insert, current_state))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
@@ -667,7 +636,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
@log_function
|
||||
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
|
||||
delete_existing=False, state_delta_for_room={},
|
||||
delete_existing=False, current_state_for_room={},
|
||||
new_forward_extremeties={}):
|
||||
"""Insert some number of room events into the necessary database tables.
|
||||
|
||||
@@ -683,7 +652,7 @@ class EventsStore(SQLBaseStore):
|
||||
delete_existing (bool): True to purge existing table rows for the
|
||||
events from the database. This is useful when retrying due to
|
||||
IntegrityError.
|
||||
state_delta_for_room (dict[str, (list[str], list[str])]):
|
||||
current_state_for_room (dict[str, (list[str], list[str])]):
|
||||
The current-state delta for each room. For each room, a tuple
|
||||
(to_delete, to_insert), being a list of event ids to be removed
|
||||
from the current state, and a list of event ids to be added to
|
||||
@@ -695,7 +664,7 @@ class EventsStore(SQLBaseStore):
|
||||
"""
|
||||
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
|
||||
|
||||
self._update_current_state_txn(txn, state_delta_for_room, max_stream_order)
|
||||
self._update_current_state_txn(txn, current_state_for_room, max_stream_order)
|
||||
|
||||
self._update_forward_extremities_txn(
|
||||
txn,
|
||||
@@ -739,8 +708,9 @@ class EventsStore(SQLBaseStore):
|
||||
events_and_contexts=events_and_contexts,
|
||||
)
|
||||
|
||||
# Insert into event_to_state_groups.
|
||||
self._store_event_state_mappings_txn(txn, events_and_contexts)
|
||||
# Insert into the state_groups, state_groups_state, and
|
||||
# event_to_state_groups tables.
|
||||
self._store_mult_state_groups_txn(txn, events_and_contexts)
|
||||
|
||||
# _store_rejected_events_txn filters out any events which were
|
||||
# rejected, and returns the filtered list.
|
||||
@@ -760,7 +730,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
def _update_current_state_txn(self, txn, state_delta_by_room, max_stream_order):
|
||||
for room_id, current_state_tuple in state_delta_by_room.iteritems():
|
||||
to_delete, to_insert = current_state_tuple
|
||||
to_delete, to_insert, _ = current_state_tuple
|
||||
txn.executemany(
|
||||
"DELETE FROM current_state_events WHERE event_id = ?",
|
||||
[(ev_id,) for ev_id in to_delete.itervalues()],
|
||||
@@ -975,9 +945,10 @@ class EventsStore(SQLBaseStore):
|
||||
# an outlier in the database. We now have some state at that
|
||||
# so we need to update the state_groups table with that state.
|
||||
|
||||
# insert into event_to_state_groups.
|
||||
# insert into the state_group, state_groups_state and
|
||||
# event_to_state_groups tables.
|
||||
try:
|
||||
self._store_event_state_mappings_txn(txn, ((event, context),))
|
||||
self._store_mult_state_groups_txn(txn, ((event, context),))
|
||||
except Exception:
|
||||
logger.exception("")
|
||||
raise
|
||||
@@ -1043,6 +1014,7 @@ class EventsStore(SQLBaseStore):
|
||||
"event_edge_hashes",
|
||||
"event_edges",
|
||||
"event_forward_extremities",
|
||||
"event_push_actions",
|
||||
"event_reference_hashes",
|
||||
"event_search",
|
||||
"event_signatures",
|
||||
@@ -1062,14 +1034,6 @@ class EventsStore(SQLBaseStore):
|
||||
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
for table in (
|
||||
"event_push_actions",
|
||||
):
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE room_id = ? AND event_id = ?" % (table,),
|
||||
[(ev.event_id,) for ev, _ in events_and_contexts]
|
||||
)
|
||||
|
||||
def _store_event_txn(self, txn, events_and_contexts):
|
||||
"""Insert new events into the event and event_json tables
|
||||
|
||||
@@ -2054,32 +2018,16 @@ class EventsStore(SQLBaseStore):
|
||||
)
|
||||
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
|
||||
|
||||
def purge_history(
|
||||
self, room_id, topological_ordering, delete_local_events,
|
||||
):
|
||||
"""Deletes room history before a certain point
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
|
||||
topological_ordering (int):
|
||||
minimum topo ordering to preserve
|
||||
|
||||
delete_local_events (bool):
|
||||
if True, we will delete local events as well as remote ones
|
||||
(instead of just marking them as outliers and deleting their
|
||||
state groups).
|
||||
"""
|
||||
|
||||
def delete_old_state(self, room_id, topological_ordering):
|
||||
return self.runInteraction(
|
||||
"purge_history",
|
||||
self._purge_history_txn, room_id, topological_ordering,
|
||||
delete_local_events,
|
||||
"delete_old_state",
|
||||
self._delete_old_state_txn, room_id, topological_ordering
|
||||
)
|
||||
|
||||
def _purge_history_txn(
|
||||
self, txn, room_id, topological_ordering, delete_local_events,
|
||||
):
|
||||
def _delete_old_state_txn(self, txn, room_id, topological_ordering):
|
||||
"""Deletes old room state
|
||||
"""
|
||||
|
||||
# Tables that should be pruned:
|
||||
# event_auth
|
||||
# event_backward_extremities
|
||||
@@ -2100,30 +2048,6 @@ class EventsStore(SQLBaseStore):
|
||||
# state_groups
|
||||
# state_groups_state
|
||||
|
||||
# we will build a temporary table listing the events so that we don't
|
||||
# have to keep shovelling the list back and forth across the
|
||||
# connection. Annoyingly the python sqlite driver commits the
|
||||
# transaction on CREATE, so let's do this first.
|
||||
#
|
||||
# furthermore, we might already have the table from a previous (failed)
|
||||
# purge attempt, so let's drop the table first.
|
||||
|
||||
txn.execute("DROP TABLE IF EXISTS events_to_purge")
|
||||
|
||||
txn.execute(
|
||||
"CREATE TEMPORARY TABLE events_to_purge ("
|
||||
" event_id TEXT NOT NULL,"
|
||||
" should_delete BOOLEAN NOT NULL"
|
||||
")"
|
||||
)
|
||||
|
||||
# create an index on should_delete because later we'll be looking for
|
||||
# the should_delete / shouldn't_delete subsets
|
||||
txn.execute(
|
||||
"CREATE INDEX events_to_purge_should_delete"
|
||||
" ON events_to_purge(should_delete)",
|
||||
)
|
||||
|
||||
# First ensure that we're not about to delete all the forward extremeties
|
||||
txn.execute(
|
||||
"SELECT e.event_id, e.depth FROM events as e "
|
||||
@@ -2144,48 +2068,42 @@ class EventsStore(SQLBaseStore):
|
||||
400, "topological_ordering is greater than forward extremeties"
|
||||
)
|
||||
|
||||
logger.info("[purge] looking for events to delete")
|
||||
|
||||
should_delete_expr = "state_key IS NULL"
|
||||
should_delete_params = ()
|
||||
if not delete_local_events:
|
||||
should_delete_expr += " AND event_id NOT LIKE ?"
|
||||
should_delete_params += ("%:" + self.hs.hostname, )
|
||||
|
||||
should_delete_params += (room_id, topological_ordering)
|
||||
logger.debug("[purge] looking for events to delete")
|
||||
|
||||
txn.execute(
|
||||
"INSERT INTO events_to_purge"
|
||||
" SELECT event_id, %s"
|
||||
" FROM events AS e LEFT JOIN state_events USING (event_id)"
|
||||
" WHERE e.room_id = ? AND topological_ordering < ?" % (
|
||||
should_delete_expr,
|
||||
),
|
||||
should_delete_params,
|
||||
)
|
||||
txn.execute(
|
||||
"SELECT event_id, should_delete FROM events_to_purge"
|
||||
"SELECT event_id, state_key FROM events"
|
||||
" LEFT JOIN state_events USING (room_id, event_id)"
|
||||
" WHERE room_id = ? AND topological_ordering < ?",
|
||||
(room_id, topological_ordering,)
|
||||
)
|
||||
event_rows = txn.fetchall()
|
||||
logger.info(
|
||||
"[purge] found %i events before cutoff, of which %i can be deleted",
|
||||
len(event_rows), sum(1 for e in event_rows if e[1]),
|
||||
)
|
||||
|
||||
logger.info("[purge] Finding new backward extremities")
|
||||
to_delete = [
|
||||
(event_id,) for event_id, state_key in event_rows
|
||||
if state_key is None and not self.hs.is_mine_id(event_id)
|
||||
]
|
||||
logger.info(
|
||||
"[purge] found %i events before cutoff, of which %i are remote"
|
||||
" non-state events to delete", len(event_rows), len(to_delete))
|
||||
|
||||
for event_id, state_key in event_rows:
|
||||
txn.call_after(self._get_state_group_for_event.invalidate, (event_id,))
|
||||
|
||||
logger.debug("[purge] Finding new backward extremities")
|
||||
|
||||
# We calculate the new entries for the backward extremeties by finding
|
||||
# all events that point to events that are to be purged
|
||||
txn.execute(
|
||||
"SELECT DISTINCT e.event_id FROM events_to_purge AS e"
|
||||
" INNER JOIN event_edges AS ed ON e.event_id = ed.prev_event_id"
|
||||
" INNER JOIN events AS e2 ON e2.event_id = ed.event_id"
|
||||
" WHERE e2.topological_ordering >= ?",
|
||||
(topological_ordering, )
|
||||
"SELECT DISTINCT e.event_id FROM events as e"
|
||||
" INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id"
|
||||
" INNER JOIN events as e2 ON e2.event_id = ed.event_id"
|
||||
" WHERE e.room_id = ? AND e.topological_ordering < ?"
|
||||
" AND e2.topological_ordering >= ?",
|
||||
(room_id, topological_ordering, topological_ordering)
|
||||
)
|
||||
new_backwards_extrems = txn.fetchall()
|
||||
|
||||
logger.info("[purge] replacing backward extremities: %r", new_backwards_extrems)
|
||||
logger.debug("[purge] replacing backward extremities: %r", new_backwards_extrems)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM event_backward_extremities WHERE room_id = ?",
|
||||
@@ -2201,7 +2119,7 @@ class EventsStore(SQLBaseStore):
|
||||
]
|
||||
)
|
||||
|
||||
logger.info("[purge] finding redundant state groups")
|
||||
logger.debug("[purge] finding redundant state groups")
|
||||
|
||||
# Get all state groups that are only referenced by events that are
|
||||
# to be deleted.
|
||||
@@ -2209,23 +2127,24 @@ class EventsStore(SQLBaseStore):
|
||||
"SELECT state_group FROM event_to_state_groups"
|
||||
" INNER JOIN events USING (event_id)"
|
||||
" WHERE state_group IN ("
|
||||
" SELECT DISTINCT state_group FROM events_to_purge"
|
||||
" SELECT DISTINCT state_group FROM events"
|
||||
" INNER JOIN event_to_state_groups USING (event_id)"
|
||||
" WHERE room_id = ? AND topological_ordering < ?"
|
||||
" )"
|
||||
" GROUP BY state_group HAVING MAX(topological_ordering) < ?",
|
||||
(topological_ordering, )
|
||||
(room_id, topological_ordering, topological_ordering)
|
||||
)
|
||||
|
||||
state_rows = txn.fetchall()
|
||||
logger.info("[purge] found %i redundant state groups", len(state_rows))
|
||||
logger.debug("[purge] found %i redundant state groups", len(state_rows))
|
||||
|
||||
# make a set of the redundant state groups, so that we can look them up
|
||||
# efficiently
|
||||
state_groups_to_delete = set([sg for sg, in state_rows])
|
||||
|
||||
# Now we get all the state groups that rely on these state groups
|
||||
logger.info("[purge] finding state groups which depend on redundant"
|
||||
" state groups")
|
||||
logger.debug("[purge] finding state groups which depend on redundant"
|
||||
" state groups")
|
||||
remaining_state_groups = []
|
||||
for i in xrange(0, len(state_rows), 100):
|
||||
chunk = [sg for sg, in state_rows[i:i + 100]]
|
||||
@@ -2250,7 +2169,7 @@ class EventsStore(SQLBaseStore):
|
||||
# Now we turn the state groups that reference to-be-deleted state
|
||||
# groups to non delta versions.
|
||||
for sg in remaining_state_groups:
|
||||
logger.info("[purge] de-delta-ing remaining state group %s", sg)
|
||||
logger.debug("[purge] de-delta-ing remaining state group %s", sg)
|
||||
curr_state = self._get_state_groups_from_groups_txn(
|
||||
txn, [sg], types=None
|
||||
)
|
||||
@@ -2287,7 +2206,7 @@ class EventsStore(SQLBaseStore):
|
||||
],
|
||||
)
|
||||
|
||||
logger.info("[purge] removing redundant state groups")
|
||||
logger.debug("[purge] removing redundant state groups")
|
||||
txn.executemany(
|
||||
"DELETE FROM state_groups_state WHERE state_group = ?",
|
||||
state_rows
|
||||
@@ -2297,15 +2216,18 @@ class EventsStore(SQLBaseStore):
|
||||
state_rows
|
||||
)
|
||||
|
||||
logger.info("[purge] removing events from event_to_state_groups")
|
||||
txn.execute(
|
||||
"DELETE FROM event_to_state_groups "
|
||||
"WHERE event_id IN (SELECT event_id from events_to_purge)"
|
||||
# Delete all non-state
|
||||
logger.debug("[purge] removing events from event_to_state_groups")
|
||||
txn.executemany(
|
||||
"DELETE FROM event_to_state_groups WHERE event_id = ?",
|
||||
[(event_id,) for event_id, _ in event_rows]
|
||||
)
|
||||
|
||||
logger.debug("[purge] updating room_depth")
|
||||
txn.execute(
|
||||
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||
(topological_ordering, room_id,)
|
||||
)
|
||||
for event_id, _ in event_rows:
|
||||
txn.call_after(self._get_state_group_for_event.invalidate, (
|
||||
event_id,
|
||||
))
|
||||
|
||||
# Delete all remote non-state events
|
||||
for table in (
|
||||
@@ -2323,55 +2245,22 @@ class EventsStore(SQLBaseStore):
|
||||
"event_signatures",
|
||||
"rejections",
|
||||
):
|
||||
logger.info("[purge] removing events from %s", table)
|
||||
logger.debug("[purge] removing remote non-state events from %s", table)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM %s WHERE event_id IN ("
|
||||
" SELECT event_id FROM events_to_purge WHERE should_delete"
|
||||
")" % (table,),
|
||||
)
|
||||
|
||||
# event_push_actions lacks an index on event_id, and has one on
|
||||
# (room_id, event_id) instead.
|
||||
for table in (
|
||||
"event_push_actions",
|
||||
):
|
||||
logger.info("[purge] removing events from %s", table)
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM %s WHERE room_id = ? AND event_id IN ("
|
||||
" SELECT event_id FROM events_to_purge WHERE should_delete"
|
||||
")" % (table,),
|
||||
(room_id, )
|
||||
txn.executemany(
|
||||
"DELETE FROM %s WHERE event_id = ?" % (table,),
|
||||
to_delete
|
||||
)
|
||||
|
||||
# Mark all state and own events as outliers
|
||||
logger.info("[purge] marking remaining events as outliers")
|
||||
txn.execute(
|
||||
logger.debug("[purge] marking remaining events as outliers")
|
||||
txn.executemany(
|
||||
"UPDATE events SET outlier = ?"
|
||||
" WHERE event_id IN ("
|
||||
" SELECT event_id FROM events_to_purge "
|
||||
" WHERE NOT should_delete"
|
||||
")",
|
||||
(True,),
|
||||
)
|
||||
|
||||
# synapse tries to take out an exclusive lock on room_depth whenever it
|
||||
# persists events (because upsert), and once we run this update, we
|
||||
# will block that for the rest of our transaction.
|
||||
#
|
||||
# So, let's stick it at the end so that we don't block event
|
||||
# persistence.
|
||||
logger.info("[purge] updating room_depth")
|
||||
txn.execute(
|
||||
"UPDATE room_depth SET min_depth = ? WHERE room_id = ?",
|
||||
(topological_ordering, room_id,)
|
||||
)
|
||||
|
||||
# finally, drop the temp table. this will commit the txn in sqlite,
|
||||
# so make sure to keep this actually last.
|
||||
txn.execute(
|
||||
"DROP TABLE events_to_purge"
|
||||
" WHERE event_id = ?",
|
||||
[
|
||||
(True, event_id,) for event_id, state_key in event_rows
|
||||
if state_key is not None or self.hs.is_mine_id(event_id)
|
||||
]
|
||||
)
|
||||
|
||||
logger.info("[purge] done")
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user