1
0

Compare commits

..

1 Commits

Author SHA1 Message Date
Erik Johnston
171829bb94 Add logging to initialSync 2015-05-01 11:24:03 +01:00
198 changed files with 4407 additions and 10619 deletions

4
.gitignore vendored
View File

@@ -42,7 +42,3 @@ build/
localhost-800*/
static/client/register/register_config.js
.tox
env/
*.config

View File

@@ -35,13 +35,3 @@ Turned to Dust <dwinslow86 at gmail.com>
Brabo <brabo at riseup.net>
* Installation instruction fixes
Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration
Eric Myhre <hash at exultant.us>
* Fix bug where ``media_store_path`` config option was ignored by v0 content
repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins.

View File

@@ -1,265 +1,9 @@
Changes in synapse v0.10.0 (2015-09-03)
=======================================
Changes in synapse vX
=====================
No change from release candidate.
* Changed config option from ``disable_registration`` to
``enable_registration``. Old option will be ignored.
Changes in synapse v0.10.0-rc6 (2015-09-02)
===========================================
* Remove some of the old database upgrade scripts.
* Fix database port script to work with newly created sqlite databases.
Changes in synapse v0.10.0-rc5 (2015-08-27)
===========================================
* Fix bug that broke downloading files with ascii filenames across federation.
Changes in synapse v0.10.0-rc4 (2015-08-27)
===========================================
* Allow UTF-8 filenames for upload. (PR #259)
Changes in synapse v0.10.0-rc3 (2015-08-25)
===========================================
* Add ``--keys-directory`` config option to specify where files such as
certs and signing keys should be stored in, when using ``--generate-config``
or ``--generate-keys``. (PR #250)
* Allow ``--config-path`` to specify a directory, causing synapse to use all
\*.yaml files in the directory as config files. (PR #249)
* Add ``web_client_location`` config option to specify static files to be
hosted by synapse under ``/_matrix/client``. (PR #245)
* Add helper utility to synapse to read and parse the config files and extract
the value of a given key. For example::
$ python -m synapse.config read server_name -c homeserver.yaml
localhost
(PR #246)
Changes in synapse v0.10.0-rc2 (2015-08-24)
===========================================
* Fix bug where we incorrectly populated the ``event_forward_extremities``
table, resulting in problems joining large remote rooms (e.g.
``#matrix:matrix.org``)
* Reduce the number of times we wake up pushers by not listening for presence
or typing events, reducing the CPU cost of each pusher.
Changes in synapse v0.10.0-rc1 (2015-08-21)
===========================================
Also see v0.9.4-rc1 changelog, which has been amalgamated into this release.
General:
* Upgrade to Twisted 15 (PR #173)
* Add support for serving and fetching encryption keys over federation.
(PR #208)
* Add support for logging in with email address (PR #234)
* Add support for new ``m.room.canonical_alias`` event. (PR #233)
* Change synapse to treat user IDs case insensitively during registration and
login. (If two users already exist with case insensitive matching user ids,
synapse will continue to require them to specify their user ids exactly.)
* Error if a user tries to register with an email already in use. (PR #211)
* Add extra and improve existing caches (PR #212, #219, #226, #228)
* Batch various storage request (PR #226, #228)
* Fix bug where we didn't correctly log the entity that triggered the request
if the request came in via an application service (PR #230)
* Fix bug where we needlessly regenerated the full list of rooms an AS is
interested in. (PR #232)
* Add support for AS's to use v2_alpha registration API (PR #210)
Configuration:
* Add ``--generate-keys`` that will generate any missing cert and key files in
the configuration files. This is equivalent to running ``--generate-config``
on an existing configuration file. (PR #220)
* ``--generate-config`` now no longer requires a ``--server-name`` parameter
when used on existing configuration files. (PR #220)
* Add ``--print-pidfile`` flag that controls the printing of the pid to stdout
of the demonised process. (PR #213)
Media Repository:
* Fix bug where we picked a lower resolution image than requested. (PR #205)
* Add support for specifying if a the media repository should dynamically
thumbnail images or not. (PR #206)
Metrics:
* Add statistics from the reactor to the metrics API. (PR #224, #225)
Demo Homeservers:
* Fix starting the demo homeservers without rate-limiting enabled. (PR #182)
* Fix enabling registration on demo homeservers (PR #223)
Changes in synapse v0.9.4-rc1 (2015-07-21)
==========================================
General:
* Add basic implementation of receipts. (SPEC-99)
* Add support for configuration presets in room creation API. (PR #203)
* Add auth event that limits the visibility of history for new users.
(SPEC-134)
* Add SAML2 login/registration support. (PR #201. Thanks Muthu Subramanian!)
* Add client side key management APIs for end to end encryption. (PR #198)
* Change power level semantics so that you cannot kick, ban or change power
levels of users that have equal or greater power level than you. (SYN-192)
* Improve performance by bulk inserting events where possible. (PR #193)
* Improve performance by bulk verifying signatures where possible. (PR #194)
Configuration:
* Add support for including TLS certificate chains.
Media Repository:
* Add Content-Disposition headers to content repository responses. (SYN-150)
Changes in synapse v0.9.3 (2015-07-01)
======================================
No changes from v0.9.3 Release Candidate 1.
Changes in synapse v0.9.3-rc1 (2015-06-23)
==========================================
General:
* Fix a memory leak in the notifier. (SYN-412)
* Improve performance of room initial sync. (SYN-418)
* General improvements to logging.
* Remove ``access_token`` query params from ``INFO`` level logging.
Configuration:
* Add support for specifying and configuring multiple listeners. (SYN-389)
Application services:
* Fix bug where synapse failed to send user queries to application services.
Changes in synapse v0.9.2-r2 (2015-06-15)
=========================================
Fix packaging so that schema delta python files get included in the package.
Changes in synapse v0.9.2 (2015-06-12)
======================================
General:
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26)
======================================
General:
* Add support for backfilling when a client paginates. This allows servers to
request history for a room from remote servers when a client tries to
paginate history the server does not have - SYN-36
* Fix bug where you couldn't disable non-default pushrules - SYN-378
* Fix ``register_new_user`` script - SYN-359
* Improve performance of fetching events from the database, this improves both
initialSync and sending of events.
* Improve performance of event streams, allowing synapse to handle more
simultaneous connected clients.
Federation:
* Fix bug with existing backfill implementation where it returned the wrong
selection of events in some circumstances.
* Improve performance of joining remote rooms.
Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
* Add more database caches to reduce amount of work done for each pusher. This
radically reduces CPU usage when multiple pushers are set up in the same room.
Changes in synapse v0.9.0 (2015-05-07)
======================================
General:
* Add support for using a PostgreSQL database instead of SQLite. See
`docs/postgres.rst`_ for details.
* Add password change and reset APIs. See `Registration`_ in the spec.
* Fix memory leak due to not releasing stale notifiers - SYN-339.
* Fix race in caches that occasionally caused some presence updates to be
dropped - SYN-369.
* Check server name has not changed on restart.
* Add a sample systemd unit file and a logger configuration in
contrib/systemd. Contributed Ivan Shapovalov.
Federation:
* Add key distribution mechanisms for fetching public keys of unavailable
remote home servers. See `Retrieving Server Keys`_ in the spec.
Configuration:
* Add support for multiple config files.
* Add support for dictionaries in config files.
* Remove support for specifying config options on the command line, except
for:
* ``--daemonize`` - Daemonize the home server.
* ``--manhole`` - Turn on the twisted telnet manhole service on the given
port.
* ``--database-path`` - The path to a sqlite database to use.
* ``--verbose`` - The verbosity level.
* ``--log-file`` - File to log to.
* ``--log-config`` - Python logging config file.
* ``--enable-registration`` - Enable registration for new users.
Application services:
* Reliably retry sending of events from Synapse to application services, as per
`Application Services`_ spec.
* Application services can no longer register via the ``/register`` API,
instead their configuration should be saved to a file and listed in the
synapse ``app_service_config_files`` config option. The AS configuration file
has the same format as the old ``/register`` request.
See `docs/application_services.rst`_ for more information.
.. _`docs/postgres.rst`: docs/postgres.rst
.. _`docs/application_services.rst`: docs/application_services.rst
.. _`Registration`: https://github.com/matrix-org/matrix-doc/blob/master/specification/10_client_server_api.rst#registration
.. _`Retrieving Server Keys`: https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys
.. _`Application Services`: https://github.com/matrix-org/matrix-doc/blob/0c6bd9/specification/25_application_service_api.rst#home-server---application-service-api
Changes in synapse v0.8.1 (2015-03-18)
======================================

View File

@@ -3,20 +3,12 @@ include LICENSE
include VERSION
include *.rst
include demo/README
include demo/demo.tls.dh
include demo/*.py
include demo/*.sh
recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.py
recursive-include demo *.dh
recursive-include demo *.py
recursive-include demo *.sh
recursive-include docs *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include tests *.py
recursive-include static *.css
recursive-include static *.html
recursive-include static *.js
prune demo/etc

View File

@@ -7,7 +7,7 @@ Matrix is an ambitious new ecosystem for open federated Instant Messaging and
VoIP. The basics you need to know to get up and running are:
- Everything in Matrix happens in a room. Rooms are distributed and do not
exist on any single server. Rooms can be located using convenience aliases
exist on any single server. Rooms can be located using convenience aliases
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
@@ -23,7 +23,7 @@ The overall architecture is::
accessed by the web client at http://matrix.org/beta or via an IRC bridge at
irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it
Synapse is currently in rapid development, but as of version 0.5 we believe it
is sufficiently stable to be run as an internet-facing service for real usage!
About Matrix
@@ -94,7 +94,6 @@ Synapse is the reference python/twisted Matrix homeserver implementation.
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- Python 2.7
- At least 512 MB RAM.
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the
@@ -102,41 +101,36 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
$ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X::
xcode-select --install
sudo easy_install pip
sudo pip install virtualenv
$ xcode-select --install
$ sudo pip install virtualenv
To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate
pip install --upgrade setuptools
pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
$ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. Feel free to pick a different directory
if you prefer.
In case of problems, please see the _Troubleshooting section below.
environment under ``~/.synapse``.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
To set up your homeserver, run (in your virtualenv, as before)::
cd ~/.synapse
python -m synapse.app.homeserver \
$ cd ~/.synapse
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
@@ -176,13 +170,13 @@ traditionally used for convenience and simplicity.
The advantages of Postgres include:
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
@@ -192,12 +186,12 @@ For information on how to install and use PostgreSQL, please see
Running Synapse
===============
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. ``~/.synapse``), and::
To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
cd ~/.synapse
source ./bin/activate
synctl start
$ cd ~/.synapse
$ source ./bin/activate
$ synctl start
Platform Specific Instructions
==============================
@@ -215,50 +209,48 @@ defaults to python 3, but synapse currently assumes python 2.7 by default:
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
sudo pip2.7 install --upgrade pip
$ sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
pip2.7 install --process-dependency-links \
$ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
sudo pip2.7 uninstall py-bcrypt
sudo pip2.7 install py-bcrypt
$ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt
During setup of Synapse you need to call python2.7 directly again::
cd ~/.synapse
python2.7 -m synapse.app.homeserver \
$ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
...substituting your host and domain name as appropriate.
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- gcc
- git
- libffi-devel
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
- gcc
- git
- libffi-devel
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
The content repository requires additional packages and will be unable to process
uploads without them:
- libjpeg8
- libjpeg8-devel
- zlib
- libjpeg8
- libjpeg8-devel
- zlib
If you choose to install Synapse without these packages, you will need to reinstall
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
pillow --user``
@@ -280,38 +272,33 @@ Troubleshooting
Troubleshooting Installation
----------------------------
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools::
pip install --upgrade setuptools
$ sudo pip install --upgrade pip
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
rm -rf /tmp/pip_install_matrix
$ rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
pip install twisted
$ pip install twisted
On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
Troubleshooting Running
-----------------------
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
encryption and digital signatures.
Unfortunately PyNACL currently has a few issues
(https://github.com/pyca/pynacl/issues/53) and
@@ -320,11 +307,10 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl)::
# Install from PyPI
pip install --user --upgrade --force pynacl
# Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
$ # Install from PyPI
$ pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
~~~~~~~~~
@@ -332,8 +318,8 @@ ArchLinux
If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::
python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
...or by editing synctl with the correct python executable.
Synapse Development
@@ -342,16 +328,16 @@ Synapse Development
To check out a synapse for development, clone the git repo into a working
directory of your choice::
git clone https://github.com/matrix-org/synapse.git
cd synapse
$ git clone https://github.com/matrix-org/synapse.git
$ cd synapse
Synapse has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
virtualenv env
source env/bin/activate
python synapse/python_dependencies.py | xargs -n1 pip install
pip install setuptools_trial mock
$ virtualenv env
$ source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
@@ -359,7 +345,7 @@ dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to
check that everything is installed as it should be::
python setup.py test
$ python setup.py test
This should end with a 'PASSED' result::
@@ -371,11 +357,14 @@ This should end with a 'PASSED' result::
Upgrading an existing Synapse
=============================
The instructions for upgrading synapse are in `UPGRADE.rst`_.
Please check these instructions as upgrading may require extra steps for some
versions of synapse.
IMPORTANT: Before upgrading an existing synapse to a new version, please
refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy.
.. _UPGRADE.rst: UPGRADE.rst
Setting up Federation
=====================
@@ -397,11 +386,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter::
python -m synapse.app.homeserver \
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process.
@@ -418,11 +407,12 @@ record would then look something like::
At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have::
python -m synapse.app.homeserver \
$ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--bind-port 8448 \
--config-path homeserver.yaml \
--generate-config
python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to
@@ -436,8 +426,8 @@ private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run::
demo/start.sh
$ demo/start.sh
This is mainly useful just for development purposes.
Running The Demo Web Client
@@ -500,7 +490,7 @@ time.
Where's the spec?!
==================
The source of the matrix spec lives at https://github.com/matrix-org/matrix-doc.
The source of the matrix spec lives at https://github.com/matrix-org/matrix-doc.
A recent HTML snapshot of this lives at http://matrix.org/docs/spec
@@ -510,10 +500,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and
sphinxcontrib-napoleon::
pip install sphinx
pip install sphinxcontrib-napoleon
$ pip install sphinx
$ pip install sphinxcontrib-napoleon
Building internal API documentation::
python setup.py build_sphinx
$ python setup.py build_sphinx

View File

@@ -1,37 +1,4 @@
Upgrading Synapse
=================
Before upgrading check if any special steps are required to upgrade from the
what you currently have installed to current version of synapse. The extra
instructions that may be required are listed later in this document.
If synapse was installed in a virtualenv then active that virtualenv before
upgrading. If synapse is installed in a virtualenv in ``~/.synapse/`` then run:
.. code:: bash
source ~/.synapse/bin/activate
If synapse was installed using pip then upgrade to the latest version by
running:
.. code:: bash
pip install --upgrade --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
If synapse was installed using git then upgrade to the latest version by
running:
.. code:: bash
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.9.0
Upgrading to v0.x.x
===================
Application services have had a breaking API change in this version.

View File

@@ -21,5 +21,3 @@ handlers:
root:
level: INFO
handlers: [journal]
disable_existing_loggers: False

View File

@@ -126,26 +126,12 @@ sub on_unknown_event
if (!$bridgestate->{$room_id}->{gathered_candidates}) {
$bridgestate->{$room_id}->{gathered_candidates} = 1;
my $offer = $bridgestate->{$room_id}->{offer};
my $candidate_block = {
audio => '',
video => '',
};
my $candidate_block = "";
foreach (@{$event->{content}->{candidates}}) {
if ($_->{sdpMid}) {
$candidate_block->{$_->{sdpMid}} .= "a=" . $_->{candidate} . "\r\n";
}
else {
$candidate_block->{audio} .= "a=" . $_->{candidate} . "\r\n";
$candidate_block->{video} .= "a=" . $_->{candidate} . "\r\n";
}
$candidate_block .= "a=" . $_->{candidate} . "\r\n";
}
# XXX: assumes audio comes first
#$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{audio}/;
#$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{video}/;
$offer =~ s/(m=video)/$candidate_block->{audio}$1/;
$offer =~ s/(.$)/$1\n$candidate_block->{video}$1/;
# XXX: collate using the right m= line - for now assume audio call
$offer =~ s/(a=rtcp.*[\r\n]+)/$1$candidate_block/;
my $f = send_verto_json_request("verto.invite", {
"sdp" => $offer,
@@ -186,18 +172,22 @@ sub on_room_message
warn "[Matrix] in $room_id: $from: " . $content->{body} . "\n";
}
my $verto_connecting = $loop->new_future;
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
)->then( sub {
warn("[Verto] connected to websocket");
$verto_connecting->done($bot_verto) if not $verto_connecting->is_done;
});
Future->needs_all(
$bot_matrix->login( %{ $CONFIG{"matrix-bot"} } )->then( sub {
$bot_matrix->start;
}),
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
)->on_done( sub {
warn("[Verto] connected to websocket");
}),
$verto_connecting,
)->get;
$loop->attach_signal(

View File

@@ -11,4 +11,7 @@ requires 'YAML', 0;
requires 'JSON', 0;
requires 'Getopt::Long', 0;
on 'test' => sub {
requires 'Test::More', '>= 0.98';
};

View File

@@ -11,9 +11,7 @@ if [ -f $PID_FILE ]; then
exit 1
fi
for port in 8080 8081 8082; do
rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
find "$DIR" -name "*.log" -delete
find "$DIR" -name "*.db" -delete
rm -rf $DIR/etc

View File

@@ -8,41 +8,38 @@ cd "$DIR/.."
mkdir -p demo/etc
export PYTHONPATH=$(readlink -f $(pwd))
echo $PYTHONPATH
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
for port in 8080 8081 8082; do
echo "Starting server on port $port... "
https_port=$((port + 400))
mkdir -p demo/$port
pushd demo/$port
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \
--generate-config \
--config-path "demo/etc/$port.config" \
-p "$https_port" \
--unsecure-port "$port" \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
-f "$DIR/$port.log" \
-d "$DIR/$port.db" \
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
--media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
--enable-registration
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
--config-path "demo/etc/$port.config" \
-vv \
popd
done
cd "$CWD"

View File

@@ -1,36 +0,0 @@
Registering an Application Service
==================================
The registration of new application services depends on the homeserver used.
In synapse, you need to create a new configuration file for your AS and add it
to the list specified under the ``app_service_config_files`` config
option in your synapse config.
For example:
.. code-block:: yaml
app_service_config_files:
- /home/matrix/.synapse/<your-AS>.yaml
The format of the AS configuration file is as follows:
.. code-block:: yaml
url: <base url of AS>
as_token: <token AS will add to requests to HS>
hs_token: <token HS will add to requests to AS>
sender_localpart: <localpart of AS user>
namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
See the spec_ for further details on how application services work.
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

View File

@@ -34,15 +34,19 @@ Synapse config
When you are ready to start using PostgreSQL, add the following line to your
config file::
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
database_config: <db_config_file>
Where ``<db_config_file>`` is the file name that points to a yaml file of the
following form::
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted
@@ -55,8 +59,9 @@ Porting from SQLite
Overview
~~~~~~~~
The script ``synapse_port_db`` allows porting an existing synapse server
backed by SQLite to using PostgreSQL. This is done in as a two phase process:
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing
synapse server backed by SQLite to using PostgreSQL. This is done in as a two
phase process:
1. Copy the existing SQLite database to a separate location (while the server
is down) and running the port script against that offline database.
@@ -81,12 +86,13 @@ complete, restart synapse. For instance::
cp homeserver.db homeserver.db.snapshot
./synctl start
Assuming your new config file (as described in the section *Synapse config*)
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
Assuming your database config file (as described in the section *Synapse
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::
synapse_port_db --sqlite-database homeserver.db.snapshot \
--postgres-config homeserver-postgres.yaml
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db.snapshot \
--postgres-config database_config.yaml
The flag ``--curses`` displays a coloured curses progress UI.
@@ -98,7 +104,8 @@ To complete the conversion shut down the synapse server and run the port
script one last time, e.g. if the SQLite database is at ``homeserver.db``
run::
synapse_port_db --sqlite-database homeserver.db \
python scripts/port_from_sqlite_to_postgres.py \
--sqlite-database homeserver.db \
--postgres-config database_config.yaml
Once that has completed, change the synapse config to point at the PostgreSQL

View File

@@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
).hexdigest()
data = {
"user": user,
"username": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
@@ -44,7 +43,7 @@ def request_registration(user, password, server_location, shared_secret):
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,),
"%s/_matrix/client/v2_alpha/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)

View File

@@ -1,116 +0,0 @@
import psycopg2
import yaml
import sys
import json
import time
import hashlib
from unpaddedbase64 import encode_base64
from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from canonicaljson import encode_canonical_json
def select_v1_keys(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, verify_key in rows:
results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
return results
def select_v1_certs(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, tls_certificate in rows:
results[server_name] = tls_certificate
return results
def select_v2_json(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
return results
def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return {
"old_verify_keys": {},
"server_name": server_name,
"verify_keys": {
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)],
}
def fingerprint(certificate):
finger = hashlib.sha256(certificate)
return {"sha256": encode_base64(finger.digest())}
def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
def main():
config = yaml.load(open(sys.argv[1]))
valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
server_name = config["server_name"]
signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
database = config["database"]
assert database["name"] == "psycopg2", "Can only convert for postgresql"
args = database["args"]
args.pop("cp_max")
args.pop("cp_min")
connection = psycopg2.connect(**args)
keys = select_v1_keys(connection)
certificates = select_v1_certs(connection)
json = select_v2_json(connection)
result = {}
for server in keys:
if not server in json:
v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server]
)
v2_json = sign_json(v2_json, server_name, signing_key)
result[server] = v2_json
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list(
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor()
cursor.executemany(
"INSERT INTO server_keys_json ("
" server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)",
rows
)
connection.commit()
if __name__ == '__main__':
main()

View File

@@ -56,9 +56,10 @@ if __name__ == '__main__':
js = json.load(args.json)
auth = Auth(Mock())
check_auth(
auth,
[FrozenEvent(d) for d in js["auth_chain"]],
[FrozenEvent(d) for d in js.get("pdus", [])],
[FrozenEvent(d) for d in js["pdus"]],
)

View File

@@ -1,5 +1,5 @@
from synapse.crypto.event_signing import *
from unpaddedbase64 import encode_base64
from syutil.base64util import encode_base64
import argparse
import hashlib

View File

@@ -1,7 +1,9 @@
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes, write_signing_keys
from unpaddedbase64 import decode_base64
from syutil.crypto.jsonsign import verify_signed_json
from syutil.crypto.signing_key import (
decode_verify_key_bytes, write_signing_keys
)
from syutil.base64util import decode_base64
import urllib2
import json

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.0.1 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.5.0 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

View File

@@ -6,8 +6,8 @@ from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from unpaddedbase64 import encode_base64, decode_base64
from canonicaljson import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
import sqlite3
import sys

View File

@@ -1,4 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
@@ -29,7 +28,7 @@ import traceback
import yaml
logger = logging.getLogger("synapse_port_db")
logger = logging.getLogger("port_from_sqlite_to_postgres")
BOOLEAN_COLUMNS = {
@@ -106,7 +105,7 @@ class Store(object):
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine, []),
LoggingTransaction(txn, desc, self.database_engine),
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
@@ -378,7 +377,9 @@ class Porter(object):
for i, row in enumerate(rows):
rows[i] = tuple(
conv(j, col)
self.postgres_store.database_engine.encode_parameter(
conv(j, col)
)
for j, col in enumerate(row)
if j > 0
)
@@ -412,17 +413,14 @@ class Porter(object):
self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows)
if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
yield self.postgres_store.execute(insert)
else:
max_inserted_rowid = 0
yield self.postgres_store.execute(insert)
def get_start_id(txn):
txn.execute(
@@ -726,9 +724,6 @@ if __name__ == "__main__":
postgres_config = yaml.safe_load(args.postgres_config)
if "database" in postgres_config:
postgres_config = postgres_config["database"]
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'")
sys.exit(2)

View File

@@ -0,0 +1,331 @@
from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
# import dns.resolver
import hashlib
import httplib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
PRAGMA user_version = 10;
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
# def get_key(server_name):
# print "Getting keys for: %s" % (server_name,)
# targets = []
# if ":" in server_name:
# target, port = server_name.split(":")
# targets.append((target, int(port)))
# try:
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
# for srv in answers:
# targets.append((srv.target, srv.port))
# except dns.resolver.NXDOMAIN:
# targets.append((server_name, 8448))
# except:
# print "Failed to lookup keys for %s" % (server_name,)
# return {}
#
# for target, port in targets:
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
# try:
# keys = json.load(urllib2.urlopen(url, timeout=2))
# verify_keys = {}
# for key_id, key_base64 in keys["verify_keys"].items():
# verify_key = decode_verify_key_bytes(
# key_id, decode_base64(key_base64)
# )
# verify_signed_json(keys, server_name, verify_key)
# verify_keys[key_id] = verify_key
# print "Got keys for: %s" % (server_name,)
# return verify_keys
# except urllib2.URLError:
# pass
# except urllib2.HTTPError:
# pass
# except httplib.HTTPException:
# pass
#
# print "Failed to get keys for %s" % (server_name,)
# return {}
def reinsert_events(cursor, server_name, signing_key):
print "Running delta: v10"
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
print "Getting events..."
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
i = 0
N = len(events)
for event in events:
if i % 100 == 0:
print "Processed: %d/%d events" % (i,N,)
i += 1
# for alg_name in event.hashes:
# if check_event_content_hash(event, algorithms[alg_name]):
# pass
# else:
# pass
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = {} # get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
# Do other deltas:
cursor.execute("PRAGMA user_version")
row = cursor.fetchone()
if row and row[0]:
user_version = row[0]
# Run every version since after the current version.
for v in range(user_version + 1, 10):
print "Running delta: %d" % (v,)
sql_script = read_schema("delta/v%d" % (v,))
cursor.executescript(sql_script)
reinsert_events(cursor, server_name, signing_key)
conn.commit()
print "Success!"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])

View File

@@ -3,6 +3,9 @@ source-dir = docs/sphinx
build-dir = docs/build
all_files = 1
[aliases]
test = trial
[trial]
test_suite = tests
@@ -13,6 +16,3 @@ ignore =
docs/*
pylint.cfg
tox.ini
[flake8]
max-line-length = 90

View File

@@ -14,10 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from setuptools import setup, find_packages, Command
import sys
from setuptools import setup, find_packages
here = os.path.abspath(os.path.dirname(__file__))
@@ -38,39 +36,6 @@ def exec_file(path_segments):
exec(code, result)
return result
class Tox(Command):
user_options = [('tox-args=', 'a', "Arguments to pass to tox")]
def initialize_options(self):
self.tox_args = None
def finalize_options(self):
self.test_args = []
self.test_suite = True
def run(self):
#import here, cause outside the eggs aren't loaded
try:
import tox
except ImportError:
try:
self.distribution.fetch_build_eggs("tox")
import tox
except:
raise RuntimeError(
"The tests need 'tox' to run. Please install 'tox'."
)
import shlex
args = self.tox_args
if args:
args = shlex.split(self.tox_args)
else:
args = []
errno = tox.cmdline(args=args)
sys.exit(errno)
version = exec_file(("synapse", "__init__.py"))["__version__"]
dependencies = exec_file(("synapse", "python_dependencies.py"))
long_description = read_file(("README.rst",))
@@ -81,10 +46,14 @@ setup(
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(),
dependency_links=dependencies["DEPENDENCY_LINKS"].values(),
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
"mock"
],
dependency_links=dependencies["DEPENDENCY_LINKS"],
include_package_data=True,
zip_safe=False,
long_description=long_description,
scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={'test': Tox},
scripts=["synctl", "register_new_matrix_user"],
)

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.10.0"
__version__ = "0.8.1-r4"

View File

@@ -20,7 +20,8 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.types import EventID, RoomID, UserID
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging
@@ -29,7 +30,7 @@ logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
EventTypes.JoinRules,
)
@@ -44,11 +45,6 @@ class Auth(object):
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
@@ -65,35 +61,11 @@ class Auth(object):
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = RoomID.from_string(event.room_id).domain
originating_domain = EventID.from_string(event.event_id).domain
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise SynapseError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
alias_domain = UserID.from_string(event.state_key).domain
if alias_domain != originating_domain:
raise AuthError(
403,
"Can only set aliases for own domain"
)
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
@@ -112,7 +84,7 @@ class Auth(object):
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
except AuthError as e:
@@ -174,11 +146,6 @@ class Auth(object):
user_id, room_id, repr(member)
))
def can_federate(self, event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
@@ -218,9 +185,6 @@ class Auth(object):
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
@@ -292,12 +256,12 @@ class Auth(object):
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
@@ -348,9 +312,9 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
tuple of:
UserID (str)
Access token ID (str)
tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -378,15 +342,16 @@ class Auth(object):
if not user_id:
raise KeyError
request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), ""))
defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", ""))
)
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
pass # normal users won't have this query parameter set
user_info = yield self.get_user_by_access_token(access_token)
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
@@ -395,16 +360,15 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
yield self.store.insert_client_ip(
user=user,
access_token=access_token,
device_id=user_info["device_id"],
ip=ip_addr,
user_agent=user_agent
)
request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id,))
defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
@@ -412,23 +376,26 @@ class Auth(object):
)
@defer.inlineCallbacks
def get_user_by_access_token(self, token):
def get_user_by_token(self, token):
""" Get a registered user's ID.
Args:
token (str): The access token to get the user by.
Returns:
dict : dict that includes the user and the ID of their access token.
dict : dict that includes the user, device_id, and whether the
user is a server admin.
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
ret = yield self.store.get_user_by_access_token(token)
ret = yield self.store.get_user_by_token(token)
if not ret:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
user_info = {
"admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
@@ -446,7 +413,6 @@ class Auth(object):
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
request.authenticated_entity = service.sender
defer.returnValue(service)
except KeyError:
raise AuthError(
@@ -458,6 +424,8 @@ class Auth(object):
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(
@@ -548,54 +516,36 @@ class Auth(object):
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
def _check_redaction(self, event, auth_events):
user_level = self._get_user_power_level(event.user_id, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level > redact_level:
return False
redacter_domain = EventID.from_string(event.event_id).domain
redactee_domain = EventID.from_string(event.redacts).domain
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
if user_level < redact_level:
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
@@ -621,26 +571,25 @@ class Auth(object):
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
("users_default", []),
("events_default", []),
("ban", []),
("redact", []),
("kick", []),
("invite", []),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
(user, ["users"])
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
(ev_id, ["events"])
)
old_state = current_state.content
@@ -648,10 +597,12 @@ class Auth(object):
for level_to_check, dir in levels_to_check:
old_loc = old_state
for d in dir:
old_loc = old_loc.get(d, {})
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
for d in dir:
new_loc = new_loc.get(d, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
@@ -667,14 +618,6 @@ class Auth(object):
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,

View File

@@ -75,10 +75,6 @@ class EventTypes(object):
Redaction = "m.room.redaction"
Feedback = "m.room.message.feedback"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
# These are used for validation
Message = "m.room.message"
Topic = "m.room.topic"
@@ -89,8 +85,3 @@ class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"

View File

@@ -40,7 +40,6 @@ class Codes(object):
TOO_LARGE = "M_TOO_LARGE"
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE"
class CodeMessageException(RuntimeError):

View File

@@ -16,7 +16,7 @@
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import check_requirements, DEPENDENCY_LINKS
from synapse.python_dependencies import check_requirements
if __name__ == '__main__':
check_requirements()
@@ -32,9 +32,10 @@ from synapse.server import HomeServer
from twisted.internet import reactor
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
from twisted.web.server import Site
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
@@ -53,35 +54,21 @@ from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
import synapse
import contextlib
import logging
import os
import re
import resource
import subprocess
import time
logger = logging.getLogger("synapse.app.homeserver")
class GzipFile(File):
def getChild(self, path, request):
child = File.getChild(self, path, request)
return EncodingResourceWrapper(child, [GzipEncoderFactory()])
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
class SynapseHomeServer(HomeServer):
def build_http_client(self):
@@ -97,40 +84,17 @@ class SynapseHomeServer(HomeServer):
return JsonResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# return GzipFile(webclient_path) # TODO configurable?
import syweb
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File("static")
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
self, self.upload_dir, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
@@ -156,105 +120,149 @@ class SynapseHomeServer(HomeServer):
**self.db_config.get("args", {})
)
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
def create_resource_tree(self, redirect_root_to_web_client):
"""Create the resource tree for this Home Server.
if tls and config.no_tls:
return
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
config = self.get_config()
web_client = config.web_client
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
]
if web_client:
logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX,
self.get_resource_for_web_client()))
if web_client and redirect_root_to_web_client:
self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
self.root_resource = Resource()
metrics_resource = self.get_resource_for_metrics()
if config.metrics_port is None and metrics_resource is not None:
desired_tree.append((METRICS_PREFIX, metrics_resource))
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
if res["compress"]:
client_v1 = gz_wrap(self.get_resource_for_client())
client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
else:
client_v1 = self.get_resource_for_client()
client_v2 = self.get_resource_for_client_v2_alpha()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree:
logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = self._resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = self._resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
resources.update({
CLIENT_PREFIX: client_v1,
CLIENT_V2_ALPHA_PREFIX: client_v2,
})
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
if name == "federation":
resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(),
})
# if there is already a resource here, thieve its children and
# replace it
res_id = self._resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = self._resource_id(existing_dummy_resource,
child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(),
})
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
if name in ["media", "federation", "client"]:
resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
})
return self.root_resource
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
})
def _resource_id(self, resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
if name == "metrics" and metrics_resource:
resources[METRICS_PREFIX] = metrics_resource
root_resource = create_resource_tree(resources)
if tls:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_context_factory,
interface=bind_address
)
else:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
logger.info("Synapse now listening on port %d", port)
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def start_listening(self):
config = self.get_config()
for listener in config.listeners:
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = self
reactor.listenTCP(
listener["port"],
f,
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
if not config.no_tls and config.bind_port is not None:
reactor.listenSSL(
config.bind_port,
SynapseSite(
"synapse.access.https",
config,
self.root_resource,
),
self.tls_context_factory,
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.bind_port)
if config.unsecure_port is not None:
reactor.listenTCP(
config.unsecure_port,
SynapseSite(
"synapse.access.http",
config,
self.root_resource,
),
interface=config.bind_host
)
logger.info("Synapse now listening on port %d", config.unsecure_port)
metrics_resource = self.get_resource_for_metrics()
if metrics_resource and config.metrics_port is not None:
reactor.listenTCP(
config.metrics_port,
SynapseSite(
"synapse.access.metrics",
config,
metrics_resource,
),
interface="127.0.0.1",
)
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
@@ -275,10 +283,11 @@ class SynapseHomeServer(HomeServer):
def quit_with_error(error_string):
message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
line_length = max([len(l) for l in message_lines]) + 2
sys.stderr.write("*" * line_length + '\n')
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
if line.strip():
sys.stderr.write(" %s\n" % (line.strip(),))
sys.stderr.write("*" * line_length + '\n')
sys.exit(1)
@@ -341,7 +350,7 @@ def get_version_string():
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
logger.warn("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
@@ -386,7 +395,10 @@ def setup(config_options):
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
tls_context_factory = context_factory.ServerContextFactory(config)
@@ -395,6 +407,9 @@ def setup(config_options):
hs = SynapseHomeServer(
config.server_name,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
db_name=config.database_path,
db_config=config.database_config,
tls_context_factory=tls_context_factory,
config=config,
@@ -403,7 +418,13 @@ def setup(config_options):
database_engine=database_engine,
)
logger.info("Preparing database: %s...", config.database_config['name'])
hs.create_resource_tree(
redirect_root_to_web_client=True,
)
db_name = hs.get_db_name()
logger.info("Preparing database: %s...", db_name)
try:
db_conn = database_engine.module.connect(
@@ -425,7 +446,14 @@ def setup(config_options):
)
sys.exit(1)
logger.info("Database prepared in %s.", config.database_config['name'])
logger.info("Database prepared in %s.", db_name)
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
hs.start_listening()
@@ -452,228 +480,35 @@ class SynapseService(service.Service):
return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return re.sub(
r'(\?.*access_token=)[^&]*(.*)$',
r'\1<redacted>\2',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
def __init__(self, logger_name, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
if config.captcha_ip_origin_is_x_forwarded:
self._log_formatter = proxiedLogFormatter
else:
self._log_formatter = combinedLogFormatter
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
line = self._log_formatter(self._logDateTime, request)
self.access_logger.info(line)
def run(hs):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
def profiled(*args, **kargs):
profile = Profile()
profile.enable()
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
hs.hostname, func.__name__, ident
))
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
def in_thread():
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
if hs.config.daemonize:
if hs.config.print_pidfile:
print hs.config.pid_file
print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",

View File

@@ -18,33 +18,29 @@ import sys
import os
import subprocess
import signal
import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m"
NORMAL = "\x1b[m"
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
CONFIG = yaml.load(open(CONFIGFILE))
PIDFILE = CONFIG["pid_file"]
def start():
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
print "Starting ...",
args = SYNAPSE
args.extend(["--daemonize", "-c", CONFIGFILE])
args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE])
subprocess.check_call(args)
print GREEN + "started" + NORMAL

View File

@@ -148,8 +148,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
for member in member_list:
if self.is_interested_in_user(member.state_key):
return True
return False
@@ -173,7 +173,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
member_list(list): A list of all joined room members in this room.
Returns:
bool: True if this service would like to know about this event.
"""

View File

@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if __name__ == "__main__":
import sys
from homeserver import HomeServerConfig
action = sys.argv[1]
if action == "read":
key = sys.argv[2]
config = HomeServerConfig.load_config("", sys.argv[3:])
print getattr(config, key)
sys.exit(0)
else:
sys.stderr.write("Unknown command %r\n" % (action,))
sys.exit(1)

View File

@@ -14,10 +14,9 @@
# limitations under the License.
import argparse
import sys
import os
import yaml
import sys
from textwrap import dedent
class ConfigError(Exception):
@@ -25,35 +24,18 @@ class ConfigError(Exception):
class Config(object):
def __init__(self, args):
pass
@staticmethod
def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
return value
def parse_size(string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = value[-1]
suffix = string[-1]
if suffix in sizes:
value = value[:-1]
string = string[:-1]
size = sizes[suffix]
return int(value) * size
@staticmethod
def parse_duration(value):
if isinstance(value, int) or isinstance(value, long):
return value
second = 1000
hour = 60 * 60 * second
day = 24 * hour
week = 7 * day
year = 365 * day
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
size = 1
suffix = value[-1]
if suffix in sizes:
value = value[:-1]
size = sizes[suffix]
return int(value) * size
return int(string) * size
@staticmethod
def abspath(file_path):
@@ -95,6 +77,17 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
@classmethod
def read_yaml_file(cls, file_path, config_name):
cls.check_file(file_path, config_name)
with open(file_path) as file_stream:
try:
return yaml.load(file_stream)
except:
raise ConfigError(
"Error parsing yaml in file %r" % (file_path,)
)
@staticmethod
def default_path(name):
return os.path.abspath(os.path.join(os.path.curdir, name))
@@ -104,173 +97,84 @@ class Config(object):
with open(file_path) as file_stream:
return yaml.load(file_stream)
def invoke_all(self, name, *args, **kargs):
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
@classmethod
def add_arguments(cls, parser):
pass
def generate_config(self, config_dir_path, server_name):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config", config_dir_path, server_name
))
config = yaml.load(default_config)
return default_config, config
@classmethod
def generate_config(cls, args, config_dir_path):
pass
@classmethod
def load_config(cls, description, argv, generate_section=None):
obj = cls()
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
help="Specify config file"
)
config_parser.add_argument(
"--generate-config",
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly"
" specified in the config."
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
help="Generate config file"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
config_files = []
if config_args.config_path:
for config_path in config_args.config_path:
if os.path.isdir(config_path):
# We accept specifying directories as config paths, we search
# inside that directory for all files matching *.yaml, and then
# we apply them in *sorted* order.
files = []
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
) % (entry_path, )
continue
if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
) % (entry_path, )
continue
files.append(entry_path)
config_files.extend(sorted(files))
else:
config_files.append(config_path)
if config_args.generate_config:
if not config_files:
if not config_args.config_path:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
"Must specify where to generate the config file"
)
(config_path,) = config_files
if not os.path.exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path, server_name
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
) % (config_path, server_name)
print (
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
sys.exit(0)
config_dir_path = os.path.dirname(config_args.config_path)
if os.path.exists(config_args.config_path):
defaults = cls.read_config_file(config_args.config_path)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
defaults = {}
else:
if config_args.config_path:
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
cls.add_arguments(parser)
parser.set_defaults(**defaults)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
if config_args.generate_config:
config_dir_path = os.path.dirname(config_args.config_path)
config_dir_path = os.path.abspath(config_dir_path)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
cls.generate_config(args, config_dir_path)
config = {}
for key, value in vars(args).items():
if (key not in set(["config_path", "generate_config"])
and value is not None):
config[key] = value
with open(config_args.config_path, "w") as config_file:
# TODO(mark/paul) We might want to output emacs-style mode
# markers as well as vim-style mode markers into the file,
# to further hint to people this is a YAML file.
config_file.write("# vim:ft=yaml\n")
yaml.dump(config, config_file, default_flow_style=False)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
) % (
config_args.config_path, config['server_name']
)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
)
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
for config_file in config_files:
yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config)
server_name = specified_config["server_name"]
_, config = obj.generate_config(config_dir_path, server_name)
config.pop("log_config")
config.update(specified_config)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
return obj
return cls(args)

View File

@@ -17,11 +17,15 @@ from ._base import Config
class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
def __init__(self, args):
super(AppServiceConfig, self).__init__(args)
self.app_service_config_files = args.app_service_config_files
def default_config(cls, config_dir_path, server_name):
return """\
# A list of application service config file to use
app_service_config_files: []
"""
@classmethod
def add_arguments(cls, parser):
super(AppServiceConfig, cls).add_arguments(parser)
group = parser.add_argument_group("appservice")
group.add_argument(
"--app-service-config-files", type=str, nargs='+',
help="A list of application service config files to use."
)

View File

@@ -17,31 +17,42 @@ from ._base import Config
class CaptchaConfig(Config):
def read_config(self, config):
self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"]
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def __init__(self, args):
super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key
self.recaptcha_public_key = args.recaptcha_public_key
self.enable_registration_captcha = args.enable_registration_captcha
def default_config(self, config_dir_path, server_name):
return """\
## Captcha ##
# XXX: This is used for more than just captcha
self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded
)
self.captcha_bypass_secret = args.captcha_bypass_secret
# This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PRIVATE_KEY"
# This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PUBLIC_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
enable_registration_captcha: False
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
"""
@classmethod
def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
help="This Home Server's ReCAPTCHA public key."
)
group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="This Home Server's ReCAPTCHA private key."
)
group.add_argument(
"--enable-registration-captcha", type=bool, default=False,
help="Enables ReCaptcha checks when registering, preventing signup"
+ " unless a captcha is answered. Requires a valid ReCaptcha "
+ "public/private key."
)
group.add_argument(
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
help="When checking captchas, use the X-Forwarded-For (XFF) header"
+ " as the client IP and not the actual client IP."
)
group.add_argument(
"--captcha_bypass_secret", type=str,
help="A secret key used to bypass the captcha test entirely."
)

View File

@@ -14,21 +14,28 @@
# limitations under the License.
from ._base import Config
import os
import yaml
class DatabaseConfig(Config):
def __init__(self, args):
super(DatabaseConfig, self).__init__(args)
if args.database_path == ":memory:":
self.database_path = ":memory:"
else:
self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
def read_config(self, config):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", "10K")
)
self.database_config = config.get("database")
if self.database_config is None:
if args.database_config:
with open(args.database_config) as f:
self.database_config = yaml.safe_load(f)
else:
self.database_config = {
"name": "sqlite3",
"args": {},
"args": {
"database": self.database_path,
},
}
name = self.database_config.get("name", None)
@@ -43,37 +50,24 @@ class DatabaseConfig(Config):
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.set_databasepath(config.get("database_path"))
def default_config(self, config, config_dir_path):
database_path = self.abspath("homeserver.db")
return """\
# Database configuration
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
# Number of events to cache in memory.
event_cache_size: "10K"
""" % locals()
def read_arguments(self, args):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
if database_path != ":memory:":
database_path = self.abspath(database_path)
if self.database_config.get("name", None) == "sqlite3":
if database_path is not None:
self.database_config["args"]["database"] = database_path
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(DatabaseConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use."
"-d", "--database-path", default="homeserver.db",
metavar="SQLITE_DATABASE_PATH", help="The database name."
)
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
db_group.add_argument(
"--database-config", default=None,
help="Location of the database configuration file."
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
args.database_path = os.path.abspath(args.database_path)

View File

@@ -25,18 +25,15 @@ from .registration import RegistrationConfig
from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
from .saml2 import SAML2Config
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ):
VoipConfig, RegistrationConfig,
MetricsConfig, AppServiceConfig, KeyConfig,):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
)
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")

View File

@@ -13,68 +13,55 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
from synapse.util.stringutils import random_string
from signedjson.key import (
generate_signing_key, is_signing_algorithm_supported,
decode_signing_key_base64, decode_verify_key_bytes,
read_signing_keys, write_signing_keys, NACL_ED25519
)
from unpaddedbase64 import decode_base64
import os
from ._base import Config, ConfigError
import syutil.crypto.signing_key
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from syutil.base64util import decode_base64
class KeyConfig(Config):
def read_config(self, config):
self.signing_key = self.read_signing_key(config["signing_key_path"])
def __init__(self, args):
super(KeyConfig, self).__init__(args)
self.signing_key = self.read_signing_key(args.signing_key_path)
self.old_signing_keys = self.read_old_signing_keys(
config["old_signing_keys"]
)
self.key_refresh_interval = self.parse_duration(
config["key_refresh_interval"]
args.old_signing_key_path
)
self.key_refresh_interval = args.key_refresh_interval
self.perspectives = self.read_perspectives(
config["perspectives"]
args.perspectives_config_path
)
def default_config(self, config_dir_path, server_name):
base_key_name = os.path.join(config_dir_path, server_name)
return """\
## Signing Keys ##
@classmethod
def add_arguments(cls, parser):
super(KeyConfig, cls).add_arguments(parser)
key_group = parser.add_argument_group("keys")
key_group.add_argument("--signing-key-path",
help="The signing key to sign messages with")
key_group.add_argument("--old-signing-key-path",
help="The keys that the server used to sign"
" sign messages with but won't use"
" to sign new messages. E.g. it has"
" lost its private key")
key_group.add_argument("--key-refresh-interval",
default=24 * 60 * 60 * 1000, # 1 Day
help="How long a key response is valid for."
" Used to set the exipiry in /key/v2/."
" Controls how frequently servers will"
" query what keys are still valid")
key_group.add_argument("--perspectives-config-path",
help="The trusted servers to download signing"
" keys from")
# Path to the signing key to sign messages with
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
# # Millisecond POSIX timestamp when the key expired.
# expired_ts: 123456789123
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
""" % locals()
def read_perspectives(self, perspectives_config):
def read_perspectives(self, perspectives_config_path):
config = self.read_yaml_file(
perspectives_config_path, "perspectives_config_path"
)
servers = {}
for server_name, server_config in perspectives_config["servers"].items():
for server_name, server_config in config["servers"].items():
for key_id, key_data in server_config["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
@@ -86,45 +73,75 @@ class KeyConfig(Config):
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return read_signing_keys(signing_keys.splitlines(True))
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
def read_old_signing_keys(self, old_signing_keys):
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired_ts = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
"Unsupported signing algorithm for old key: %r" % (key_id,)
)
return keys
def read_old_signing_keys(self, old_signing_key_path):
old_signing_keys = self.read_file(
old_signing_key_path, "old_signing_key"
)
try:
return syutil.crypto.signing_key.read_old_signing_keys(
old_signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading old signing keys."
)
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
if not os.path.exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(
signing_key_file, (generate_signing_key(key_id),),
@classmethod
def generate_config(cls, args, config_dir_path):
super(KeyConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_signing_key("auto"),),
)
else:
signing_keys = self.read_file(signing_key_path, "signing_key")
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key_id = "a_" + random_string(4)
key = decode_signing_key_base64(
NACL_ED25519, key_id, signing_keys.split("\n")[0]
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(signing_key_path, "w") as signing_key_file:
write_signing_keys(
signing_key_file, (key,),
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)
if not args.old_signing_key_path:
args.old_signing_key_path = base_key_name + ".old.signing.keys"
if not os.path.exists(args.old_signing_key_path):
with open(args.old_signing_key_path, "w"):
pass
if not args.perspectives_config_path:
args.perspectives_config_path = base_key_name + ".perspectives"
if not os.path.exists(args.perspectives_config_path):
with open(args.perspectives_config_path, "w") as perspectives_file:
perspectives_file.write(
'servers:\n'
' matrix.org:\n'
' verify_keys:\n'
' "ed25519:auto":\n'
' key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"\n'
)

View File

@@ -19,88 +19,25 @@ from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
import yaml
from string import Template
import os
DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
- %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
level: INFO
console:
class: logging.StreamHandler
formatter: precise
loggers:
synapse:
level: INFO
synapse.storage.SQL:
level: INFO
root:
level: INFO
handlers: [file, console]
""")
class LoggingConfig(Config):
def __init__(self, args):
super(LoggingConfig, self).__init__(args)
self.verbosity = int(args.verbose) if args.verbose else None
self.log_config = self.abspath(args.log_config)
self.log_file = self.abspath(args.log_file)
def read_config(self, config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
# Logging verbosity level.
verbose: 0
# File to write logging to
log_file: "%(log_file)s"
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
self.log_file = args.log_file
@classmethod
def add_arguments(cls, parser):
super(LoggingConfig, cls).add_arguments(parser)
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
help="The verbosity level."
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
'-f', '--log-file', dest="log_file", default="homeserver.log",
help="File to log to."
)
logging_group.add_argument(
@@ -108,14 +45,6 @@ class LoggingConfig(Config):
help="Python logging config file"
)
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
with open(log_config, "wb") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
)
def setup_logging(self):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"

View File

@@ -17,15 +17,20 @@ from ._base import Config
class MetricsConfig(Config):
def read_config(self, config):
self.enable_metrics = config["enable_metrics"]
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def __init__(self, args):
super(MetricsConfig, self).__init__(args)
self.enable_metrics = args.enable_metrics
self.metrics_port = args.metrics_port
def default_config(self, config_dir_path, server_name):
return """\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
"""
@classmethod
def add_arguments(cls, parser):
super(MetricsConfig, cls).add_arguments(parser)
metrics_group = parser.add_argument_group("metrics")
metrics_group.add_argument(
'--enable-metrics', dest="enable_metrics", action="store_true",
help="Enable collection and rendering of performance metrics"
)
metrics_group.add_argument(
'--metrics-port', metavar="PORT", type=int,
help="Separate port to accept metrics requests on (on localhost)"
)

View File

@@ -17,42 +17,56 @@ from ._base import Config
class RatelimitConfig(Config):
def read_config(self, config):
self.rc_messages_per_second = config["rc_messages_per_second"]
self.rc_message_burst_count = config["rc_message_burst_count"]
def __init__(self, args):
super(RatelimitConfig, self).__init__(args)
self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = config["federation_rc_window_size"]
self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"]
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
def default_config(self, config_dir_path, server_name):
return """\
## Ratelimiting ##
@classmethod
def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser)
rc_group = parser.add_argument_group("ratelimiting")
rc_group.add_argument(
"--rc-messages-per-second", type=float, default=0.2,
help="number of messages a client can send per second"
)
rc_group.add_argument(
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)
# Number of messages a client can send per second
rc_messages_per_second: 0.2
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
# Number of message a client can send before being throttled
rc_message_burst_count: 10.0
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
# The federation window size in milliseconds
federation_rc_window_size: 1000
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
federation_rc_sleep_limit: 10
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
federation_rc_concurrent: 3
"""
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@@ -17,48 +17,45 @@ from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
from distutils.util import strtobool
import distutils.util
class RegistrationConfig(Config):
def read_config(self, config):
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
# `args.enable_registration` may either be a bool or a string depending
# on if the option was given a value (e.g. --enable-registration=true
# would set `args.enable_registration` to "true" not True.)
self.disable_registration = not bool(
strtobool(str(config["enable_registration"]))
distutils.util.strtobool(str(args.enable_registration))
)
if "disable_registration" in config:
self.disable_registration = bool(
strtobool(str(config["disable_registration"]))
)
self.registration_shared_secret = args.registration_shared_secret
self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, config_dir, server_name):
registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\
## Registration ##
# Enable registration for new users.
enable_registration: False
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
""" % locals()
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration", action="store_true", default=None,
help="Enable registration for new users."
"--enable-registration",
const=True,
default=False,
nargs='?',
help="Enable registration for new users.",
)
reg_group.add_argument(
"--registration-shared-secret", type=str,
help="If set, allows registration by anyone who also has the shared"
" secret, even if registration is otherwise disabled.",
)
def read_arguments(self, args):
if args.enable_registration is not None:
self.disable_registration = not bool(
strtobool(str(args.enable_registration))
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(RegistrationConfig, cls).generate_config(args, config_dir_path)
if args.enable_registration is None:
args.enable_registration = False
if args.registration_shared_secret is None:
args.registration_shared_secret = random_string_with_symbols(50)

View File

@@ -14,87 +14,35 @@
# limitations under the License.
from ._base import Config
from collections import namedtuple
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumnailing
method, and thumbnail media type to precalculate
Args:
thumbnail_sizes(list): List of dicts with "width", "height", and
"method" keys
Returns:
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
method = size["method"]
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
media_type: tuple(thumbnails)
for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config):
def read_config(self, config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"]
def __init__(self, args):
super(ContentRepositoryConfig, self).__init__(args)
self.max_upload_size = self.parse_size(args.max_upload_size)
self.max_image_pixels = self.parse_size(args.max_image_pixels)
self.media_store_path = self.ensure_directory(args.media_store_path)
def parse_size(self, string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = string[-1]
if suffix in sizes:
string = string[:-1]
size = sizes[suffix]
return int(string) * size
@classmethod
def add_arguments(cls, parser):
super(ContentRepositoryConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("content_repository")
db_group.add_argument(
"--max-upload-size", default="10M"
)
db_group.add_argument(
"--media-store-path", default=cls.default_path("media_store")
)
db_group.add_argument(
"--max-image-pixels", default="32M",
help="Maximum number of pixels that will be thumbnailed"
)
def default_config(self, config_dir_path, server_name):
media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
return """
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
# The largest allowed upload size in bytes
max_upload_size: "10M"
# Maximum number of pixels that will be thumbnailed
max_image_pixels: "32M"
# Whether to generate new thumbnails on the fly to precisely match
# the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalcualted list.
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
thumbnail_sizes:
- width: 32
height: 32
method: crop
- width: 96
height: 96
method: crop
- width: 320
height: 240
method: scale
- width: 640
height: 480
method: scale
""" % locals()

View File

@@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Ericsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class SAML2Config(Config):
"""SAML2 Configuration
Synapse uses pysaml2 libraries for providing SAML2 support
config_path: Path to the sp_conf.py configuration file
idp_redirect_url: Identity provider URL which will redirect
the user back to /login/saml2 with proper info.
sp_conf.py file is something like:
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
More information: https://pythonhosted.org/pysaml2/howto/config.html
"""
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = True
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
self.saml2_enabled = False
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)

View File

@@ -17,212 +17,64 @@ from ._base import Config
class ServerConfig(Config):
def __init__(self, args):
super(ServerConfig, self).__init__(args)
self.server_name = args.server_name
self.bind_port = args.bind_port
self.bind_host = args.bind_host
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
self.web_client = args.web_client
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
def read_config(self, config):
self.server_name = config["server_name"]
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
bind_port = config.get("bind_port")
if bind_port:
self.listeners = []
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
names = ["client", "webclient"] if self.web_client else ["client"]
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"tls": True,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"tls": False,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
manhole = config.get("manhole")
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"type": "manhole",
})
metrics_port = config.get("metrics_port")
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"tls": False,
"type": "http",
"resources": [
{
"names": ["metrics"],
"compress": False,
},
]
})
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
if not content_addr:
for listener in self.listeners:
if listener["type"] == "http" and not listener.get("tls", False):
unsecure_port = listener["port"]
break
else:
raise RuntimeError("Could not determine 'content_addr'")
host = self.server_name
if not args.content_addr:
host = args.server_name
if ':' not in host:
host = "%s:%d" % (host, unsecure_port)
host = "%s:%d" % (host, args.unsecure_port)
else:
host = host.split(':')[0]
host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,)
host = "%s:%d" % (host, args.unsecure_port)
args.content_addr = "http://%s" % (host,)
self.content_addr = content_addr
self.content_addr = args.content_addr
def default_config(self, config_dir_path, server_name):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400
else:
bind_port = 8448
unsecure_port = 8008
pid_file = self.abspath("homeserver.pid")
return """\
## Server ##
# The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_address: ''
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
type: http
x_forwarded: false
resources:
- names: [client, webclient]
compress: true
- names: [federation]
compress: false
# Turn on the twisted telnet manhole service on localhost on the given
# port.
# - port: 9000
# bind_address: 127.0.0.1
# type: manhole
""" % locals()
def read_arguments(self, args):
if args.manhole is not None:
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(ServerConfig, cls).add_arguments(parser)
server_group = parser.add_argument_group("server")
server_group.add_argument(
"-H", "--server-name", default="localhost",
help="The domain name of the server, with optional explicit port. "
"This is used by remote servers to connect to this server, "
"e.g. matrix.org, localhost:8080, etc."
)
server_group.add_argument("-p", "--bind-port", metavar="PORT",
type=int, help="https port to listen on",
default=8448)
server_group.add_argument("--unsecure-port", metavar="PORT",
type=int, help="http port to listen on",
default=8008)
server_group.add_argument("--bind-host", default="",
help="Local interface to listen on")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
server_group.add_argument('--web_client', default=True, type=bool,
help="Whether or not to serve a web client")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
" service on the given port.")
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")

View File

@@ -23,49 +23,37 @@ GENERATE_DH_PARAMS = False
class TlsConfig(Config):
def read_config(self, config):
def __init__(self, args):
super(TlsConfig, self).__init__(args)
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
args.tls_certificate_path
)
self.tls_certificate_file = config.get("tls_certificate_path")
self.no_tls = config.get("no_tls", False)
self.no_tls = args.no_tls
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
args.tls_private_key_path
)
self.tls_dh_params_path = self.check_file(
config.get("tls_dh_params_path"), "tls_dh_params"
args.tls_dh_params_path, "tls_dh_params"
)
def default_config(self, config_dir_path, server_name):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# Don't bind to the https port
no_tls: False
""" % locals()
@classmethod
def add_arguments(cls, parser):
super(TlsConfig, cls).add_arguments(parser)
tls_group = parser.add_argument_group("tls")
tls_group.add_argument("--tls-certificate-path",
help="PEM encoded X509 certificate for TLS")
tls_group.add_argument("--tls-private-key-path",
help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -75,13 +63,22 @@ class TlsConfig(Config):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
@classmethod
def generate_config(cls, args, config_dir_path):
super(TlsConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
if not os.path.exists(tls_private_key_path):
with open(tls_private_key_path, "w") as private_key_file:
if args.tls_certificate_path is None:
args.tls_certificate_path = base_key_name + ".tls.crt"
if args.tls_private_key_path is None:
args.tls_private_key_path = base_key_name + ".tls.key"
if args.tls_dh_params_path is None:
args.tls_dh_params_path = base_key_name + ".tls.dh"
if not os.path.exists(args.tls_private_key_path):
with open(args.tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -89,17 +86,17 @@ class TlsConfig(Config):
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
with open(args.tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not os.path.exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certificate_file:
if not os.path.exists(args.tls_certificate_path):
with open(args.tls_certificate_path, "w") as certifcate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
subject.CN = args.server_name
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
@@ -111,18 +108,18 @@ class TlsConfig(Config):
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)
certifcate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path):
if not os.path.exists(args.tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
"-out", tls_dh_params_path,
"-out", args.tls_dh_params_path,
"2048"
])
else:
with open(tls_dh_params_path, "w") as dh_params_file:
with open(args.tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"

View File

@@ -17,21 +17,28 @@ from ._base import Config
class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def __init__(self, args):
super(VoipConfig, self).__init__(args)
self.turn_uris = args.turn_uris
self.turn_shared_secret = args.turn_shared_secret
self.turn_user_lifetime = args.turn_user_lifetime
def default_config(self, config_dir_path, server_name):
return """\
## Turn ##
# The public URIs of the TURN server to give to clients
turn_uris: []
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""
@classmethod
def add_arguments(cls, parser):
super(VoipConfig, cls).add_arguments(parser)
group = parser.add_argument_group("voip")
group.add_argument(
"--turn-uris", type=str, default=None, action='append',
help="The public URIs of the TURN server to give to clients"
)
group.add_argument(
"--turn-shared-secret", type=str, default=None,
help=(
"The shared secret used to compute passwords for the TURN"
" server"
)
)
group.add_argument(
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
help="How long generated TURN credentials last, in ms"
)

View File

@@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
_ecCurve.addECKeyToContext(context)
except:
logger.exception("Failed to enable elliptic curve for TLS")
logger.exception("Failed to enable eliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate_chain_file(config.tls_certificate_file)
context.use_certificate(config.tls_certificate)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)

View File

@@ -15,12 +15,11 @@
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
from synapse.events.utils import prune_event
from canonicaljson import encode_canonical_json
from unpaddedbase64 import encode_base64, decode_base64
from signedjson.sign import sign_json
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json
from synapse.api.errors import SynapseError, Codes
import hashlib
import logging

View File

@@ -18,9 +18,7 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
from synapse.util.logcontext import PreserveLoggingContext
import simplejson as json
import logging
@@ -42,14 +40,11 @@ def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
for i in range(5):
try:
protocol = yield preserve_context_over_fn(
endpoint.connect, factory
)
server_response, server_certificate = yield preserve_context_over_deferred(
protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
with PreserveLoggingContext():
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):

View File

@@ -14,24 +14,22 @@
# limitations under the License.
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from twisted.internet import defer
from signedjson.sign import (
from syutil.crypto.jsonsign import (
verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
from signedjson.key import (
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from unpaddedbase64 import decode_base64, encode_base64
from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util.async import create_observer
from OpenSSL import crypto
from collections import namedtuple
import urllib
import hashlib
import logging
@@ -40,9 +38,6 @@ import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -54,332 +49,131 @@ class Keyring(object):
self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verfies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
deferreds[group_id] = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
for g_id in group_ids
]
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Trys to fetch the key from a trusted perspective server first.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
server_name(str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred)
self.key_downloads[server_name] = d
if cached:
defer.returnValue(cached[0])
return
def rm(r, server_name):
self.key_downloads.pop(server_name, None)
return r
download = self.key_downloads.get(server_name)
d.addBoth(rm, server_name)
if download is None:
download = self._get_server_verify_key_impl(server_name, key_ids)
self.key_downloads[server_name] = download
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
@download.addBoth
def callback(ret):
del self.key_downloads[server_name]
return ret
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
r = yield create_observer(download)
defer.returnValue(r)
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
@defer.inlineCallbacks
def _get_server_verify_key_impl(self, server_name, key_ids):
keys = None
missing_keys = {
group.server_name: set(group.key_ids)
for group in group_id_to_group.values()
}
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
perspective_results = []
for perspective_name, perspective_keys in self.perspective_servers.items():
@defer.inlineCallbacks
def get_key():
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name, key_ids, perspective_name, perspective_keys
)
for server_name, groups in missing_groups.items()
}
defer.returnValue(result)
except:
logging.info(
"Unable to getting key %r for %r from %r",
key_ids, server_name, perspective_name,
)
perspective_results.append(get_key())
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
perspective_results = yield defer.gatherResults(perspective_results)
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
for results in perspective_results:
if results is not None:
keys = results
do_iterations().addErrback(on_err)
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
type(e).__name__, str(e.message),
)
defer.returnValue({})
results = yield defer.gatherResults(
[
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
keys = None
with limiter:
if keys is None:
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
except:
pass
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = {server_name: keys}
defer.returnValue(keys)
results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
for key_id in key_ids:
if key_id in keys:
defer.returnValue(keys[key_id])
return
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@@ -390,7 +184,7 @@ class Keyring(object):
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
query_response = yield self.client.post_json(
responses = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
@@ -400,15 +194,12 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
},
)
keys = {}
responses = query_response["server_keys"]
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
@@ -440,29 +231,23 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
processed_response = yield self.process_v2_response(
perspective_name, response
response_keys = yield self.process_v2_response(
server_name, perspective_name, response
)
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
keys.update(response_keys)
yield defer.gatherResults(
[
self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
yield self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
@@ -498,30 +283,25 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
server_name=server_name,
from_server=server_name,
requested_ids=[requested_key_id],
requested_id=requested_key_id,
response_json=response,
)
keys.update(response_keys)
yield defer.gatherResults(
[
self.store_keys(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=keys,
)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
def process_v2_response(self, server_name, from_server, response_json,
requested_id=None):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -543,9 +323,7 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
for key_id in response_json["signatures"][server_name]:
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
@@ -567,31 +345,28 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set(requested_ids)
updated_key_ids = set()
if requested_id is not None:
updated_key_ids.add(requested_id)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield defer.gatherResults(
[
self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for key_id in updated_key_ids:
yield self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
results[server_name] = response_keys
defer.returnValue(response_keys)
defer.returnValue(results)
raise ValueError("No verification key found for given key ids")
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -675,13 +450,8 @@ class Keyring(object):
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
self.store.store_server_verify_key(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for key_id, key in verify_keys.items():
# TODO(markjh): Store whether the keys have expired.
yield self.store.store_server_verify_key(
server_name, server_name, key.time_added, key
)

View File

@@ -16,12 +16,6 @@
from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
# a dict to frozen_dicts is expensive.
USE_FROZEN_DICTS = True
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@@ -90,7 +84,7 @@ class EventBase(object):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": dict(self.unsigned),
"unsigned": self.unsigned,
})
return d
@@ -109,9 +103,6 @@ class EventBase(object):
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"]
# This may be a frozen event
pdu_json["unsigned"].pop("redacted_because", None)
return pdu_json
def __set__(self, instance, value):
@@ -131,10 +122,7 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {}))
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,

View File

@@ -74,8 +74,6 @@ def prune_event(event):
)
elif event_type == EventTypes.Aliases:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = {
k: v

View File

@@ -18,12 +18,12 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -51,108 +50,84 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu):
return pdu
signed_pdus = []
def errback(failure, pdu):
failure.trap(SynapseError)
return None
@defer.inlineCallbacks
def do(pdu):
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def try_local_db(res, pdu):
if not res:
# Check local db.
return self.store.get_event(
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
return res
if new_pdu:
signed_pdus.append(new_pdu)
return
def try_remote(res, pdu):
if not res and pdu.origin != origin:
return self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
).addErrback(lambda e: None)
return res
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
return res
for pdu, deferred in zip(pdus, deferreds):
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
consumeErrors=True
).addErrback(unwrapFirstError)
)
if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s",
pdu.event_id,
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
return failure
raise
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
return deferreds
defer.returnValue(pdu)

View File

@@ -22,15 +22,13 @@ from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools
import logging
import random
@@ -134,36 +132,6 @@ class FederationClient(FederationBase):
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@log_function
def query_client_keys(self, destination, content):
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)
@log_function
def claim_client_keys(self, destination, content):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
@@ -196,17 +164,16 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -222,8 +189,6 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns:
Deferred: Results in the requested PDU.
@@ -247,7 +212,7 @@ class FederationClient(FederationBase):
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
@@ -257,11 +222,11 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hashes([pdu])[0]
pdu = yield self._check_sigs_and_hash(pdu)
break
@@ -290,7 +255,7 @@ class FederationClient(FederationBase):
)
continue
if self._get_pdu_cache is not None and pdu:
if self._get_pdu_cache is not None:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu)
@@ -358,9 +323,6 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
if destination == self.server_name:
continue
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -387,9 +349,6 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -411,39 +370,13 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus.values(),
outlier=True,
signed_state = yield self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
@@ -455,7 +388,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
logger.exception(
logger.warn(
"Failed to send_join via %s: %s",
destination, e.message
)
@@ -558,7 +491,7 @@ class FederationClient(FederationBase):
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False
destination, events, outlier=True
)
have_gotten_all_from_destination = True
@@ -585,7 +518,7 @@ class FederationClient(FederationBase):
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events if e)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
@@ -628,7 +561,7 @@ class FederationClient(FederationBase):
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)

View File

@@ -20,6 +20,7 @@ from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import synapse.metrics
@@ -27,7 +28,6 @@ from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging
@@ -123,28 +123,29 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
results = []
with PreserveLoggingContext():
results = []
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
@@ -313,48 +314,6 @@ class FederationServer(FederationBase):
(200, send_content)
)
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,

View File

@@ -23,6 +23,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import logging
@@ -69,7 +71,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
response,
encode_canonical_json(response)
)
@defer.inlineCallbacks
@@ -99,5 +101,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
response_dict,
encode_canonical_json(response_dict)
)

View File

@@ -104,6 +104,7 @@ class TransactionQueue(object):
return not destination.startswith("localhost")
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@@ -207,13 +208,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
logger.info(
"TX [%s] Transaction already in progress",
destination
)
return
logger.debug("TX [%s] _attempt_new_transaction", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@@ -221,11 +222,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
logger.info("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
@@ -242,8 +243,6 @@ class TransactionQueue(object):
try:
self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self._clock,
@@ -251,9 +250,9 @@ class TransactionQueue(object):
)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
@@ -263,7 +262,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id,
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
@@ -277,13 +276,9 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
@@ -319,10 +314,7 @@ class TransactionQueue(object):
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)

View File

@@ -50,15 +50,13 @@ class TransportLayerClient(object):
)
@log_function
def get_event(self, destination, event_id, timeout=None):
def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Results in a dict received from the remote homeserver.
@@ -67,7 +65,7 @@ class TransportLayerClient(object):
destination, event_id)
path = PREFIX + "/event/%s/" % (event_id, )
return self.client.get_json(destination, path=path, timeout=timeout)
return self.client.get_json(destination, path=path)
@log_function
def backfill(self, destination, room_id, event_tuples, limit):
@@ -222,76 +220,6 @@ class TransportLayerClient(object):
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.
Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,

View File

@@ -93,9 +93,6 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue((origin, content))
@log_function
@@ -199,14 +196,6 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data)
)
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
# origin = body["origin"]
@@ -325,24 +314,6 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_claim_client_keys(origin, content)
defer.returnValue((200, response))
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"
@@ -391,6 +362,4 @@ SERVLET_CLASSES = (
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
)

View File

@@ -22,6 +22,7 @@ from .room import (
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
@@ -31,7 +32,6 @@ from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
class Handlers(object):
@@ -53,10 +53,10 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(

View File

@@ -15,12 +15,10 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.api.errors import LimitExceededError, SynapseError
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import logging
@@ -78,9 +76,7 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context)
@@ -107,25 +103,7 @@ class BaseHandler(object):
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
yield self.store.persist_event(event, context=context)
federation_handler = self.hs.get_handlers().federation_handler
@@ -146,21 +124,6 @@ class BaseHandler(object):
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
destinations = set(extra_destinations)
for k, s in context.current_state.items():
try:
@@ -174,12 +137,10 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -189,6 +150,8 @@ class BaseHandler(object):
notify_d.addErrback(log_failure)
federation_handler.handle_new_event(
fed_d = federation_handler.handle_new_event(
event, destinations=destinations,
)
fed_d.addErrback(log_failure)

View File

@@ -34,7 +34,6 @@ class AdminHandler(BaseHandler):
d = {}
for r in res:
# Note that device_id is always None
device = d.setdefault(r["device_id"], {})
session = device.setdefault(r["access_token"], [])
session.append({

View File

@@ -15,7 +15,7 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.appservice import ApplicationService
from synapse.types import UserID
@@ -147,7 +147,10 @@ class ApplicationServicesHandler(object):
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services()
interested_list = [
@@ -177,7 +180,7 @@ class ApplicationServicesHandler(object):
return
user_info = yield self.store.get_user_by_id(user_id)
if user_info:
if len(user_info) > 0:
defer.returnValue(False)
return

View File

@@ -26,7 +26,6 @@ from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import pymacaroons
import simplejson
import synapse.util.stringutils as stringutils
@@ -48,24 +47,17 @@ class AuthHandler(BaseHandler):
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
def check_auth(self, flows, clientdict, clientip=None):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
flows: list of list of stages
authdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
Returns:
A tuple of (authed, dict, dict) where authed is true if the client
A tuple of authed, dict, dict where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
@@ -83,7 +75,7 @@ class AuthHandler(BaseHandler):
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
session = self._get_session_info(sid)
sess = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
@@ -93,21 +85,20 @@ class AuthHandler(BaseHandler):
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
session['clientdict'] = clientdict
self._save_session(session)
elif 'clientdict' in session:
clientdict = session['clientdict']
# sess['clientdict'] = clientdict
# self._save_session(sess)
pass
elif 'clientdict' in sess:
clientdict = sess['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict)
(False, self._auth_dict_for_flows(flows, sess), clientdict)
)
if 'creds' not in session:
session['creds'] = {}
creds = session['creds']
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
# check auth type currently being presented
if 'type' in authdict:
@@ -116,15 +107,15 @@ class AuthHandler(BaseHandler):
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(session)
self._save_session(sess)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(session)
self._remove_session(sess)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, session)
ret = self._auth_dict_for_flows(flows, sess)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@@ -158,14 +149,22 @@ class AuthHandler(BaseHandler):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"]
user = authdict["user"]
password = authdict["password"]
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
if not user.startswith('@'):
user = UserID.create(user, self.hs.hostname).to_string()
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
defer.returnValue(user_id)
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
defer.returnValue(user)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
@@ -188,8 +187,8 @@ class AuthHandler(BaseHandler):
# each request
try:
client = SimpleHttpClient(self.hs)
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
data = yield client.post_urlencoded_get_json(
"https://www.google.com/recaptcha/api/siteverify",
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
@@ -199,8 +198,7 @@ class AuthHandler(BaseHandler):
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@@ -269,120 +267,6 @@ class AuthHandler(BaseHandler):
return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches.
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def issue_access_token(self, user_id):
access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword)
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
@@ -391,26 +275,3 @@ class AuthHandler(BaseHandler):
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
def hash(self, password):
"""Computes a secure hash of password.
Args:
password (str): Password to hash.
Returns:
Hashed password (str).
"""
return bcrypt.hashpw(password, bcrypt.gensalt())
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
Args:
password (str): Password to hash.
stored_hash (str): Expected hash value.
Returns:
Whether self.hash(password) == stored_hash (bool).
"""
return bcrypt.checkpw(password, stored_hash)

View File

@@ -22,7 +22,6 @@ from synapse.api.constants import EventTypes
from synapse.types import RoomAlias
import logging
import string
logger = logging.getLogger(__name__)
@@ -41,10 +40,6 @@ class DirectoryHandler(BaseHandler):
def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.

View File

@@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
@@ -49,12 +50,7 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True,
only_room_events=False):
"""Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned.
"""
as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id)
try:
@@ -75,15 +71,7 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
@@ -93,10 +81,10 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout,
only_room_events=only_room_events
)
with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout
)
time_now = self.clock.time_msec()

View File

@@ -18,11 +18,9 @@
from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
AuthError, FederationError, StoreError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -31,10 +29,6 @@ from synapse.crypto.event_signing import (
)
from synapse.types import UserID
from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer
import itertools
@@ -79,6 +73,8 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@log_function
@defer.inlineCallbacks
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
@@ -93,7 +89,9 @@ class FederationHandler(BaseHandler):
processing.
"""
return self.replication_layer.send_pdu(event, destinations)
yield run_on_reactor()
self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
@@ -140,32 +138,29 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
seen_ids.add(e.event_id)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
event,
state=state,
@@ -206,11 +201,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -227,300 +220,38 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id
)
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
def redact_disallowed(event, state):
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
domain = UserID.from_string(ev.state_key).domain
except:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
def backfill(self, dest, room_id, limit):
""" Trigger a backfill request to `dest` for the given `room_id`
"""
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
pdus = yield self.replication_layer.backfill(
dest,
room_id,
limit=limit,
limit,
extremities=extremities,
)
event_map = {e.event_id: e for e in events}
events = []
event_ids = set(e.event_id for e in events)
for pdu in pdus:
event = pdu
edges = [
ev.event_id
for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
# FIXME (erikj): Not sure this actually works :/
context = yield self.state_handler.compute_event_context(event)
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
events.append((event, context))
# For each edge get the current state.
auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id
yield self.store.persist_event(
event,
context=context,
backfilled=True
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
ev_infos = []
for a in auth_events.values():
if a.event_id in seen_events:
continue
ev_infos.append({
"event": a,
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
}
})
for e_id in events_to_state:
ev_infos.append({
"event": event_map[e_id],
"state": events_to_state[e_id],
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
}
})
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
ev_infos.append({
"event": event,
})
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
defer.returnValue(events)
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(
room_id
)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
return
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),
key=lambda e: -int(e[1])
)
max_depth = sorted_extremeties_tuple[0][1]
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
max_depth, current_depth,
)
return
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
joined_domains = {}
for u, d in joined_users:
try:
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
]
@defer.inlineCallbacks
def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
except SynapseError:
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
dom, e,
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
if success:
defer.returnValue(True)
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e])
for e in event_ids
])
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
dom for dom in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
tried_domains.update(likely_domains)
defer.returnValue(False)
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -649,22 +380,46 @@ class FederationHandler(BaseHandler):
# FIXME
pass
ev_infos = []
for e in itertools.chain(state, auth_chain):
for e in auth_chain:
e.internal_metadata.outlier = True
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
for e in state:
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
@@ -672,7 +427,7 @@ class FederationHandler(BaseHandler):
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
new_event,
state=state,
@@ -680,11 +435,9 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
new_event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
d = self.notifier.on_new_room_event(
new_event, extra_users=[joinee]
)
def log_failure(f):
logger.warn(
@@ -749,9 +502,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -765,10 +516,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
d = self.notifier.on_new_room_event(
event, extra_users=extra_users
)
def log_failure(f):
logger.warn(
@@ -841,18 +591,16 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
event_stream_id, max_stream_id = yield self.store.persist_event(
yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
d = self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
def log_failure(f):
logger.warn(
@@ -874,7 +622,7 @@ class FederationHandler(BaseHandler):
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
[event_id]
)
if state_groups:
@@ -921,8 +669,6 @@ class FederationHandler(BaseHandler):
limit
)
events = yield self._filter_events_for_server(origin, room_id, events)
defer.returnValue(events)
@defer.inlineCallbacks
@@ -981,72 +727,31 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
logger.debug(
"_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
]
)
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier,
event, old_state=state
)
if not auth_events:
auth_events = context.current_state
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1 and event.depth < 5:
c = yield self.store.get_event(
event.prev_events[0][0],
allow_none=True,
)
if c and c.type == EventTypes.Create:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
@@ -1061,6 +766,25 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
defer.returnValue(context)
@defer.inlineCallbacks
@@ -1124,24 +848,14 @@ class FederationHandler(BaseHandler):
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
)
else:
have_events = {}
have_events.update({
e.event_id: ""
for e in auth_events.values()
})
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events - current_state
missing_auth = event_auth_events - seen_events
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
@@ -1211,7 +925,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
).addErrback(unwrapFirstError)
)
if different_events:
local_view = dict(auth_events)
@@ -1456,52 +1170,3 @@ class FederationHandler(BaseHandler):
},
"missing": [e.event_id for e in missing_locals],
})
@defer.inlineCallbacks
def _handle_auth_events(self, origin, auth_events):
auth_ids_to_deferred = {}
def process_auth_ev(ev):
auth_ids = [e_id for e_id, _ in ev.auth_events]
prev_ds = [
auth_ids_to_deferred[i]
for i in auth_ids
if i in auth_ids_to_deferred
]
d = defer.Deferred()
auth_ids_to_deferred[ev.event_id] = d
@defer.inlineCallbacks
def f(*_):
ev.internal_metadata.outlier = True
try:
auth = {
(e.type, e.state_key): e for e in auth_events
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, ev, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
ev.event_id,
)
d.callback(None)
if prev_ds:
dx = defer.DeferredList(prev_ds)
dx.addBoth(f)
else:
f()
for e in auth_events:
process_auth_ev(e)
yield defer.DeferredList(auth_ids_to_deferred.values())

View File

@@ -44,7 +44,7 @@ class IdentityHandler(BaseHandler):
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
trustedIdServers = ['matrix.org']
if 'id_server' in creds:
id_server = creds['id_server']
@@ -117,28 +117,3 @@ class IdentityHandler(BaseHandler):
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
http_client = SimpleHttpClient(self.hs)
params = {
'email': email,
'client_secret': client_secret,
'send_attempt': send_attempt,
}
params.update(kwargs)
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
),
params
)
defer.returnValue(data)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e

83
synapse/handlers/login.py Normal file
View File

@@ -0,0 +1,83 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes
import bcrypt
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def set_password(self, user_id, newpassword, token_id=None):
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
user_id, token_id
)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)

View File

@@ -20,9 +20,8 @@ from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken
from synapse.types import UserID
from ._base import BaseHandler
@@ -90,19 +89,9 @@ class MessageHandler(BaseHandler):
if not pagin_config.from_token:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
yield self.hs.get_event_sources().get_current_token()
)
room_token = RoomStreamToken.parse(pagin_config.from_token.room_key)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows(
@@ -113,21 +102,11 @@ class MessageHandler(BaseHandler):
"room_key", next_key
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield self._filter_events_for_client(user_id, room_id, events)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now, as_client_event)
for e in events
serialize_event(e, time_now, as_client_event) for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
@@ -135,55 +114,9 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None):
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -217,8 +150,11 @@ class MessageHandler(BaseHandler):
builder.content
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
@@ -314,36 +250,47 @@ class MessageHandler(BaseHandler):
is joined on, may return a "messages" key with messages, depending
on the specified PaginationConfig.
"""
start_time = self.clock.time_msec()
def delta():
return self.clock.time_msec() - start_time
logger.info("initial_sync: start")
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id,
membership_list=[Membership.INVITE, Membership.JOIN]
)
logger.info("initial_sync: got_rooms %d", delta())
user = UserID.from_string(user_id)
rooms_ret = []
now_token = yield self.hs.get_event_sources().get_current_token()
logger.info("initial_sync: now_token %d", delta())
presence_stream = self.hs.get_event_sources().sources["presence"]
pagination_config = PaginationConfig(from_token=now_token)
presence, _ = yield presence_stream.get_pagination_rows(
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
logger.info("initial_sync: presence_done %d", delta())
public_room_ids = yield self.store.get_public_room_ids()
logger.info("initial_sync: public_rooms %d", delta())
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
logger.info("initial_sync: start: %s %d", event.room_id, delta())
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -372,10 +319,6 @@ class MessageHandler(BaseHandler):
event.room_id
),
]
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, event.room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -398,20 +341,19 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
# Only do N rooms at once
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
logger.info("initial_sync: end: %s %d", event.room_id, delta())
yield defer.gatherResults(
[handle_room(e) for e in room_list],
consumeErrors=True
)
logger.info("initial_sync: done", delta())
ret = {
"rooms": rooms_ret,
"presence": presence,
"receipts": receipt,
"end": now_token.to_string(),
"end": now_token.to_string()
}
defer.returnValue(ret)
@@ -447,6 +389,15 @@ class MessageHandler(BaseHandler):
if limit is None:
limit = 10
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
@@ -454,39 +405,19 @@ class MessageHandler(BaseHandler):
]
presence_handler = self.hs.get_handlers().presence_handler
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
as_event=True,
check_auth=False,
)
defer.returnValue(states.values())
receipts_handler = self.hs.get_handlers().receipts_handler
presence, receipts, (messages, token) = yield defer.gatherResults(
[
get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
presence = []
for m in room_members:
try:
member_presence = yield presence_handler.get_state(
target_user=UserID.from_string(m.user_id),
auth_user=auth_user,
as_event=True,
)
presence.append(member_presence)
except SynapseError:
logger.exception(
"Failed to get member presence of %r", m.user_id
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
@@ -499,6 +430,5 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
"presence": presence
})

View File

@@ -18,8 +18,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import synapse.metrics
@@ -146,10 +146,6 @@ class PresenceHandler(BaseHandler):
self._user_cachemap = {}
self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback(
"userCachemap:size",
lambda: len(self._user_cachemap),
@@ -191,38 +187,24 @@ class PresenceHandler(BaseHandler):
defer.returnValue(False)
@defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
"""Get the current presence state of the given user.
Args:
target_user (UserID): The user whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_user`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: The presence state of the `target_user`, whose format depends
on the `as_event` argument.
"""
def get_state(self, target_user, auth_user, as_event=False):
if self.hs.is_mine(target_user):
if check_auth:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=target_user
)
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=target_user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
if not visible:
raise SynapseError(404, "Presence information not visible")
state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
if target_user in self._user_cachemap:
state = self._user_cachemap[target_user].get_state()
else:
state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
cached_state = self._user_cachemap[target_user].get_state()
if "last_active" in cached_state:
state["last_active"] = cached_state["last_active"]
else:
# TODO(paul): Have remote server send us permissions set
state = self._get_or_offline_usercache(target_user).get_state()
@@ -246,81 +228,6 @@ class PresenceHandler(BaseHandler):
else:
defer.returnValue(state)
@defer.inlineCallbacks
def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
"""A batched version of the `get_state` method that accepts a list of
`target_users`
Args:
target_users (list): The list of UserID's whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_users`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: A mapping from user -> presence_state
"""
local_users, remote_users = partitionbool(
target_users,
lambda u: self.hs.is_mine(u)
)
if check_auth:
for user in local_users:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
results = {}
if local_users:
for user in local_users:
if user in self._user_cachemap:
results[user] = self._user_cachemap[user].get_state()
local_to_user = {u.localpart: u for u in local_users}
states = yield self.store.get_presence_states(
[u.localpart for u in local_users if u not in results]
)
for local_part, state in states.items():
if state is None:
continue
res = {"presence": state["state"]}
if "status_msg" in state and state["status_msg"]:
res["status_msg"] = state["status_msg"]
results[local_to_user[local_part]] = res
for user in remote_users:
# TODO(paul): Have remote server send us permissions set
results[user] = self._get_or_offline_usercache(user).get_state()
for state in results.values():
if "last_active" in state:
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
if as_event:
for user, state in results.items():
content = state
content["user_id"] = user.to_string()
if "last_active" in content:
content["last_active_ago"] = int(
self._clock.time_msec() - content.pop("last_active")
)
results[user] = {"type": "m.presence", "content": content}
defer.returnValue(results)
@defer.inlineCallbacks
@log_function
def set_state(self, target_user, auth_user, state):
@@ -371,14 +278,15 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
with PreserveLoggingContext():
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None):
if now is None:
@@ -390,34 +298,13 @@ class PresenceHandler(BaseHandler):
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state):
"""Updates the presence state of a local user.
statuscache = self._get_or_make_usercache(user)
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
return self.push_presence(user, statuscache=statuscache)
@log_function
def started_user_eventstream(self, user):
@@ -431,21 +318,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
"""Called via the distributor whenever a user joins a room.
Notifies the new member of the presence of the current members.
Notifies the current members of the room of the new member's presence.
Args:
user(UserID): The user who joined the room.
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the
# event source
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
statuscache.update({}, serial=self._user_cachemap_latest_serial)
self.push_update_to_local_and_remote(
observed_user=user,
room_ids=[room_id],
@@ -453,22 +333,18 @@ class PresenceHandler(BaseHandler):
)
# We also want to tell them about current presence of people.
curr_users = yield self.get_joined_users_for_room_id(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = yield self.update_presence_cache(
local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
statuscache=statuscache,
statuscache=self._get_or_offline_usercache(local_user),
)
@defer.inlineCallbacks
def send_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -503,15 +379,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user):
"""Handles a m.presence_invite EDU. A remote or local user has
requested presence updates for a local user. If the invite is accepted
then allow the local or remote user to see the presence of the local
user.
Args:
observed_user(UserID): The local user whose presence is requested.
observer_user(UserID): The remote or local user requesting presence.
"""
accept = yield self._should_accept_invite(observed_user, observer_user)
if accept:
@@ -538,34 +405,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
self.start_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.start_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string()
)
@@ -574,16 +423,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -591,66 +430,34 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
self.stop_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.stop_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Get the presence list for a local user. The retured list includes
the current presence state for each user listed.
Args:
observer_user(UserID): The local user whose presence list to fetch.
accepted(bool or None): If not none then only include users who
have or have not accepted the presence invite request.
Returns:
A Deferred list of presence state events.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
presence = yield self.store.get_presence_list(
observer_user.localpart, accepted=accepted
)
results = []
for row in presence_list:
observed_user = UserID.from_string(row["observed_user_id"])
result = {
"observed_user": observed_user, "accepted": row["accepted"]
}
result.update(
self._get_or_offline_usercache(observed_user).get_state()
)
if "last_active" in result:
result["last_active_ago"] = int(
self.clock.time_msec() - result.pop("last_active")
for p in presence:
observed_user = UserID.from_string(p.pop("observed_user_id"))
p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state())
if "last_active" in p:
p["last_active_ago"] = int(
self.clock.time_msec() - p.pop("last_active")
)
results.append(result)
defer.returnValue(results)
defer.returnValue(presence)
@defer.inlineCallbacks
@log_function
def start_polling_presence(self, user, target_user=None, state=None):
"""Subscribe a local user to presence updates from a local or remote
user. If no target_user is supplied then subscribe to all users stored
in the presence list for the local user.
Additonally this pushes the current presence state of this user to all
target_users. That state can be provided directly or will be read from
the stored state for the local user.
Also this attempts to notify the local user of the current state of
any local target users.
Args:
user(UserID): The local user that whishes for presence updates.
target_user(UserID): The local or remote user whose updates are
wanted.
state(dict): Optional presence state for the local user.
"""
logger.debug("Start polling for presence from %s", user)
if target_user:
@@ -666,7 +473,8 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@@ -690,7 +498,9 @@ class PresenceHandler(BaseHandler):
# We want to tell the person that just came online
# presence state of people they are interested in?
self.push_update_to_clients(
observed_user=target_user,
users_to_push=[user],
statuscache=self._get_or_offline_usercache(target_user),
)
deferreds = []
@@ -707,12 +517,6 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user):
"""Subscribe a local user to presence updates for a local user
Args:
user(UserId): The local user that wishes for updates.
target_user(UserId): The local users whose updates are wanted.
"""
target_localpart = target_user.localpart
if target_localpart not in self._local_pushmap:
@@ -721,17 +525,6 @@ class PresenceHandler(BaseHandler):
self._local_pushmap[target_localpart].add(user)
def _start_polling_remote(self, user, domain, remoteusers):
"""Subscribe a local user to presence updates for remote users on a
given remote domain.
Args:
user(UserID): The local user that wishes for updates.
domain(str): The remote server the local user wants updates from.
remoteusers(UserID): The remote users that local user wants to be
told about.
Returns:
A Deferred.
"""
to_poll = set()
for u in remoteusers:
@@ -752,17 +545,6 @@ class PresenceHandler(BaseHandler):
@log_function
def stop_polling_presence(self, user, target_user=None):
"""Unsubscribe a local user from presence updates from a local or
remote user. If no target user is supplied then unsubscribe the user
from all presence updates that the user had subscribed to.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID or None): The user whose updates are no longer
wanted.
Returns:
A Deferred.
"""
logger.debug("Stop polling for presence from %s", user)
if not target_user or self.hs.is_mine(target_user):
@@ -791,13 +573,6 @@ class PresenceHandler(BaseHandler):
return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user):
"""Unsubscribe a local user from presence updates from a local user on
this server.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID): The user whose updates are no longer wanted.
"""
for localpart in self._local_pushmap.keys():
if target_user and localpart != target_user.localpart:
continue
@@ -810,17 +585,6 @@ class PresenceHandler(BaseHandler):
@log_function
def _stop_polling_remote(self, user, domain, remoteusers):
"""Unsubscribe a local user from presence updates from remote users on
a given domain.
Args:
user(UserID): The local user that no longer wishes for updates.
domain(str): The remote server to unsubscribe from.
remoteusers([UserID]): The users on that remote server that the
local user no longer wishes to be updated about.
Returns:
A Deferred.
"""
to_unpoll = set()
for u in remoteusers:
@@ -842,19 +606,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def push_presence(self, user, statuscache):
"""
Notify local and remote users of a change in presence of a local user.
Pushes the update to local clients and remote domains that are directly
subscribed to the presence of the local user.
Also pushes that update to any local user or remote domain that shares
a room with the local user.
Args:
user(UserID): The local user whose presence was updated.
statuscache(UserPresenceCache): Cache of the user's presence state
Returns:
A Deferred.
"""
assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user)
@@ -866,7 +617,8 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes
localusers.add(user)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@@ -881,23 +633,44 @@ class PresenceHandler(BaseHandler):
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
"""Handle an incoming m.presence EDU.
For each presence update in the "push" list update our local cache and
notify the appropriate local clients. Only clients that share a room
or are directly subscribed to the presence for a user should be
notified of the update.
For each subscription request in the "poll" list start pushing presence
updates to the remote server.
For unsubscribe request in the "unpoll" list stop pushing presence
updates to the remote server.
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
Args:
orgin(str): The source of this m.presence EDU.
content(dict): The content of this m.presence EDU.
Returns:
A Deferred.
"""
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {
"user_id": user.to_string(),
}
user_state.update(**state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={
"push": [
user_state,
],
}
)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
deferreds = []
for push in content.get("push", []):
@@ -911,7 +684,8 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers
)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@@ -930,15 +704,20 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago")
)
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1
yield self.update_presence_cache(user, state, room_ids=room_ids)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
continue
self.push_update_to_clients(
users_to_push=observers, room_ids=room_ids
observed_user=user,
users_to_push=observers,
room_ids=room_ids,
statuscache=statuscache,
)
user_id = user.to_string()
@@ -987,58 +766,13 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]:
del self._remote_sendmap[user]
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
with PreserveLoggingContext():
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[],
remote_domains=[]):
"""Notify local clients and remote servers of a change in the presence
of a user.
Args:
observed_user(UserID): The user to push the presence state for.
statuscache(UserPresenceCache): The cache for the presence state to
push.
users_to_push([UserID]): A list of local and remote users to
notify.
room_ids([str]): Notify the local and remote occupants of these
rooms.
remote_domains([str]): A list of remote servers to notify in
addition to those implied by the users_to_push and the
room_ids.
Returns:
A Deferred.
"""
localusers, remoteusers = partitionbool(
users_to_push,
@@ -1048,7 +782,10 @@ class PresenceHandler(BaseHandler):
localusers = set(localusers)
self.push_update_to_clients(
users_to_push=localusers, room_ids=room_ids
observed_user=observed_user,
users_to_push=localusers,
room_ids=room_ids,
statuscache=statuscache,
)
remote_domains = set(remote_domains)
@@ -1073,65 +810,11 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, users_to_push=[], room_ids=[]):
"""Notify clients of a new presence event.
Args:
users_to_push([UserID]): List of users to notify.
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext():
self.notifier.on_new_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push,
room_ids,
)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
"""Push a user's presence to a remote server. If a presence state event
that event is sent. Otherwise a new state event is constructed from the
stored presence state.
The last_active is replaced with last_active_ago in case the wallclock
time on the remote server is different to the time on this server.
Sends an EDU to the remote server with the current presence state.
Args:
user(UserID): The user to push the presence state for.
destination(str): The remote server to send state to.
state(dict): The state to push, or None to use the current stored
state.
Returns:
A Deferred.
"""
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {"user_id": user.to_string(), }
user_state.update(state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={"push": [user_state, ], }
def push_update_to_clients(self, observed_user, users_to_push=[],
room_ids=[], statuscache=None):
self.notifier.on_new_user_event(
users_to_push,
room_ids,
)
@@ -1140,11 +823,39 @@ class PresenceEventSource(object):
self.hs = hs
self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks
@log_function
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
@@ -1153,27 +864,17 @@ class PresenceEventSource(object):
clock = self.clock
latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial:
continue
if not (yield self.is_visible(observer_user, observed_user)):
continue
latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock))
@@ -1210,6 +911,8 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key)
if pagination_config.to_key:
@@ -1220,26 +923,14 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
if not (to_key < cachemap[observed_user].serial <= from_key):
continue
updates.append((observed_user, cachemap[observed_user]))
if (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit

View File

@@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler
@@ -88,9 +88,6 @@ class ProfileHandler(BaseHandler):
if target_user != auth_user:
raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
new_displayname = None
yield self.store.set_profile_displayname(
target_user.localpart, new_displayname
)
@@ -157,13 +154,14 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
with PreserveLoggingContext():
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
)
state["displayname"] = displayname
state["avatar_url"] = avatar_url

View File

@@ -1,210 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
self._push_remotes([receipt])
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": user_values["event_ids"],
"data": user_values.get("data", {}),
}
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
yield self._handle_new_receipts(receipts)
@defer.inlineCallbacks
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
res = yield self.store.insert_receipt(
room_id, receipt_type, user_id, event_ids, data
)
if not res:
# res will be None if this read receipt is 'old'
defer.returnValue(False)
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
}
},
},
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=to_key,
)
if not result:
defer.returnValue([])
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
to_key = yield self.get_current_key()
if from_key == to_key:
defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))
def get_current_key(self, direction='f'):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))

View File

@@ -25,6 +25,8 @@ import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
import base64
import bcrypt
import logging
import urllib
@@ -55,8 +57,8 @@ class RegistrationHandler(BaseHandler):
yield self.check_user_id_is_valid(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
u = yield self.store.get_user_by_id(user_id)
if u:
raise SynapseError(
400,
"User ID already taken.",
@@ -71,8 +73,7 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
login again.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -81,7 +82,7 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor()
password_hash = None
if password:
password_hash = self.auth_handler().hash(password)
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
yield self.check_username(localpart)
@@ -89,7 +90,7 @@ class RegistrationHandler(BaseHandler):
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -109,7 +110,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self.auth_handler().generate_access_token(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -159,7 +160,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
token = self.auth_handler().generate_access_token(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -191,35 +192,6 @@ class RegistrationHandler(BaseHandler):
else:
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_saml2(self, localpart):
"""
Registers email_id as SAML2 Based Auth.
"""
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield self.distributor.fire("registered_user", user)
except Exception, e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
logger.exception(e)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
@@ -271,6 +243,13 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
def _generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as
# query params.
return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' +
stringutils.random_string(18))
def _generate_user_id(self):
return "-" + stringutils.random_string(18)
@@ -313,6 +292,3 @@ class RegistrationHandler(BaseHandler):
}
)
defer.returnValue(data)
def auth_handler(self):
return self.hs.get_handlers().auth_handler

View File

@@ -19,36 +19,19 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from collections import OrderedDict
import logging
import string
logger = logging.getLogger(__name__)
class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "invited",
"original_invitees_have_ops": False,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
}
@defer.inlineCallbacks
def create_room(self, user_id, room_id, config):
""" Creates a new room.
@@ -67,10 +50,6 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id)
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias.create(
config["room_alias_name"],
self.hs.hostname,
@@ -137,28 +116,9 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname],
)
preset_config = config.get(
"preset",
RoomCreationPreset.PUBLIC_CHAT
if is_public
else RoomCreationPreset.PRIVATE_CHAT
)
raw_initial_state = config.get("initial_state", [])
initial_state = OrderedDict()
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
creation_content = config.get("creation_content", {})
user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room(
user, room_id,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
creation_content=creation_content,
user, room_id, is_public=is_public
)
msg_handler = self.hs.get_handlers().message_handler
@@ -205,10 +165,7 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state, creation_content):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
def _create_events_for_new_room(self, creator, room_id, is_public=False):
creator_id = creator.to_string()
event_keys = {
@@ -228,10 +185,9 @@ class RoomCreationHandler(BaseHandler):
return e
creation_content.update({"creator": creator.to_string()})
creation_event = create(
etype=EventTypes.Create,
content=creation_content,
content={"creator": creator.to_string()},
)
join_event = create(
@@ -242,20 +198,16 @@ class RoomCreationHandler(BaseHandler):
},
)
returned_events = [creation_event, join_event]
if (EventTypes.PowerLevels, '') not in initial_state:
power_level_content = {
power_levels_event = create(
etype=EventTypes.PowerLevels,
content={
"users": {
creator.to_string(): 100,
},
"users_default": 0,
"events": {
EventTypes.Name: 50,
EventTypes.Name: 100,
EventTypes.PowerLevels: 100,
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
},
"events_default": 0,
"state_default": 50,
@@ -263,43 +215,21 @@ class RoomCreationHandler(BaseHandler):
"kick": 50,
"redact": 50,
"invite": 0,
}
},
)
if config["original_invitees_have_ops"]:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
join_rules_event = create(
etype=EventTypes.JoinRules,
content={"join_rule": join_rule},
)
power_levels_event = create(
etype=EventTypes.PowerLevels,
content=power_level_content,
)
returned_events.append(power_levels_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create(
etype=EventTypes.JoinRules,
content={"join_rule": config["join_rules"]},
)
returned_events.append(join_rules_event)
if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
history_event = create(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}
)
returned_events.append(history_event)
for (etype, state_key), content in initial_state.items():
returned_events.append(create(
etype=etype,
state_key=state_key,
content=content,
))
return returned_events
return [
creation_event,
join_event,
power_levels_event,
join_rules_event,
]
class RoomMemberHandler(BaseHandler):
@@ -563,9 +493,15 @@ class RoomMemberHandler(BaseHandler):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
@@ -593,17 +529,11 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True)
results = yield defer.gatherResults(
[
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i])
for room in chunk:
joined_users = yield self.store.get_users_in_room(
room_id=room["room_id"],
)
room["num_joined_members"] = len(joined_users)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
@@ -639,8 +569,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):
return self.store.get_room_events_max_id(direction)
def get_current_key(self):
return self.store.get_room_events_max_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):

View File

@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"client_info",
"limit",
"gap",
"sort",
@@ -91,22 +92,13 @@ class SyncHandler(BaseHandler):
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback(before_token, after_token):
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
@@ -237,16 +229,7 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
@@ -309,52 +292,6 @@ class SyncHandler(BaseHandler):
next_batch=now_token,
))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
@@ -376,9 +313,6 @@ class SyncHandler(BaseHandler):
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), room_id, loaded_recents,
)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:

View File

@@ -18,7 +18,6 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import SynapseError, AuthError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import logging
@@ -204,19 +203,20 @@ class TypingNotificationHandler(BaseHandler):
)
def _push_update_local(self, room_id, user, typing):
room_set = self._room_typing.setdefault(room_id, set())
if room_id not in self._room_serials:
self._room_serials[room_id] = 0
self._room_typing[room_id] = set()
room_set = self._room_typing[room_id]
if typing:
room_set.add(user)
else:
room_set.discard(user)
elif user in room_set:
room_set.remove(user)
self._latest_room_serial += 1
self._room_serials[room_id] = self._latest_room_serial
with PreserveLoggingContext():
self.notifier.on_new_event(
"typing_key", self._latest_room_serial, rooms=[room_id]
)
self.notifier.on_new_user_event(rooms=[room_id])
class TypingNotificationEventSource(object):
@@ -256,8 +256,8 @@ class TypingNotificationEventSource(object):
)
events = []
for room_id in joined_room_ids:
if room_id not in handler._room_serials:
for room_id in handler._room_serials:
if room_id not in joined_room_ids:
continue
if handler._room_serials[room_id] <= from_key:
continue

View File

@@ -14,15 +14,12 @@
# limitations under the License.
from synapse.api.errors import CodeMessageException
from synapse.util.logcontext import preserve_context_over_fn
from syutil.jsonutil import encode_canonical_json
import synapse.metrics
from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError,
HTTPConnectionPool,
Agent, readBody, FileBodyProducer, PartialDownloadError
)
from twisted.web.http_headers import Headers
@@ -57,36 +54,21 @@ class SimpleHttpClient(object):
# The default context factory in Twisted 14.0.0 (which we require) is
# BrowserLikePolicyForHTTPS which will do regular cert validation
# 'like a browser'
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent(reactor, pool=pool)
self.agent = Agent(reactor)
self.version_string = hs.version_string
def request(self, method, uri, *args, **kwargs):
def request(self, method, *args, **kwargs):
# A small wrapper around self.agent.request() so we can easily attach
# counters to it
outgoing_requests_counter.inc(method)
d = preserve_context_over_fn(
self.agent.request,
method, uri, *args, **kwargs
)
logger.info("Sending request %s %s", method, uri)
d = self.agent.request(method, *args, **kwargs)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
logger.info(
"Received response to %s %s: %s",
method, uri, response.code
)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
logger.info(
"Error sending request to %s %s: %s %s",
method, uri, failure.type, failure.getErrorMessage()
)
return failure
d.addCallbacks(_cb, _eb)
@@ -95,9 +77,7 @@ class SimpleHttpClient(object):
@defer.inlineCallbacks
def post_urlencoded_get_json(self, uri, args={}):
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)
query_bytes = urllib.urlencode(args, True)
response = yield self.request(
@@ -110,7 +90,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(query_bytes))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@@ -118,7 +98,7 @@ class SimpleHttpClient(object):
def post_json_get_json(self, uri, post_json):
json_str = encode_canonical_json(post_json)
logger.debug("HTTP POST %s -> %s", json_str, uri)
logger.info("HTTP POST %s -> %s", json_str, uri)
response = yield self.request(
"POST",
@@ -129,7 +109,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
defer.returnValue(json.loads(body))
@@ -162,7 +142,7 @@ class SimpleHttpClient(object):
})
)
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -205,7 +185,7 @@ class SimpleHttpClient(object):
bodyProducer=FileBodyProducer(StringIO(json_str))
)
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
if 200 <= response.code < 300:
defer.returnValue(json.loads(body))
@@ -239,7 +219,7 @@ class CaptchaServerHttpClient(SimpleHttpClient):
)
try:
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
defer.returnValue(body)
except PartialDownloadError as e:
# twisted dislikes google's response, no content length.

View File

@@ -16,32 +16,30 @@
from twisted.internet import defer, reactor, protocol
from twisted.internet.error import DNSLookupError
from twisted.web.client import readBody, HTTPConnectionPool, Agent
from twisted.web.client import readBody, _AgentBase, _URI
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.async import sleep
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logcontext import PreserveLoggingContext
import synapse.metrics
from canonicaljson import encode_canonical_json
from syutil.jsonutil import encode_canonical_json
from synapse.api.errors import (
SynapseError, Codes, HttpResponseException,
)
from signedjson.sign import sign_json
from syutil.crypto.jsonsign import sign_json
import simplejson as json
import logging
import sys
import urllib
import urlparse
logger = logging.getLogger(__name__)
outbound_logger = logging.getLogger("synapse.http.outbound")
metrics = synapse.metrics.get_metrics_for(__name__)
@@ -55,17 +53,41 @@ incoming_responses_counter = metrics.register_counter(
)
class MatrixFederationEndpointFactory(object):
def __init__(self, hs):
self.tls_context_factory = hs.tls_context_factory
class MatrixFederationHttpAgent(_AgentBase):
def endpointForURI(self, uri):
destination = uri.netloc
def __init__(self, reactor, pool=None):
_AgentBase.__init__(self, reactor, pool)
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.tls_context_factory
)
def request(self, destination, endpoint, method, path, params, query,
headers, body_producer):
outgoing_requests_counter.inc(method)
host = b""
port = 0
fragment = b""
parsed_URI = _URI(b"http", destination, host, port, path, params,
query, fragment)
# Set the connection pool key to be the destination.
key = destination
d = self._requestWithEndpoint(key, endpoint, method, parsed_URI,
headers, body_producer,
parsed_URI.originForm)
def _cb(response):
incoming_responses_counter.inc(method, response.code)
return response
def _eb(failure):
incoming_responses_counter.inc(method, "ERR")
return failure
d.addCallbacks(_cb, _eb)
return d
class MatrixFederationHttpClient(object):
@@ -81,123 +103,105 @@ class MatrixFederationHttpClient(object):
self.hs = hs
self.signing_key = hs.config.signing_key[0]
self.server_name = hs.hostname
pool = HTTPConnectionPool(reactor)
pool.maxPersistentPerHost = 10
self.agent = Agent.usingEndpointFactory(
reactor, MatrixFederationEndpointFactory(hs), pool=pool
)
self.agent = MatrixFederationHttpAgent(reactor)
self.clock = hs.get_clock()
self.version_string = hs.version_string
self._next_id = 1
def _create_url(self, destination, path_bytes, param_bytes, query_bytes):
return urlparse.urlunparse(
("matrix", destination, path_bytes, param_bytes, query_bytes, "")
)
@defer.inlineCallbacks
def _create_request(self, destination, method, path_bytes,
body_callback, headers_dict={}, param_bytes=b"",
query_bytes=b"", retry_on_dns_fail=True,
timeout=None):
query_bytes=b"", retry_on_dns_fail=True):
""" Creates and sends a request to the given url
"""
headers_dict[b"User-Agent"] = [self.version_string]
headers_dict[b"Host"] = [destination]
url_bytes = self._create_url(
destination, path_bytes, param_bytes, query_bytes
url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "",)
)
txn_id = "%s-O-%s" % (method, self._next_id)
self._next_id = (self._next_id + 1) % (sys.maxint - 1)
logger.info("Sending request to %s: %s %s",
destination, method, url_bytes)
outbound_logger.info(
"{%s} [%s] Sending request: %s %s",
txn_id, destination, method, url_bytes
logger.debug(
"Types: %s",
[
type(destination), type(method), type(path_bytes),
type(param_bytes),
type(query_bytes)
]
)
# XXX: Would be much nicer to retry only at the transaction-layer
# (once we have reliable transactions in place)
retries_left = 5
http_url_bytes = urlparse.urlunparse(
("", "", path_bytes, param_bytes, query_bytes, "")
)
endpoint = self._getEndpoint(reactor, destination)
log_result = None
try:
while True:
producer = None
if body_callback:
producer = body_callback(method, http_url_bytes, headers_dict)
while True:
producer = None
if body_callback:
producer = body_callback(method, url_bytes, headers_dict)
try:
def send_request():
request_deferred = preserve_context_over_fn(
self.agent.request,
method,
url_bytes,
Headers(headers_dict),
producer
)
return self.clock.time_bound_deferred(
request_deferred,
time_out=timeout/1000. if timeout else 60,
)
response = yield preserve_context_over_fn(
send_request,
)
log_result = "%d %s" % (response.code, response.phrase,)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
log_result = "DNS Lookup failed to %s with %s" % (
destination, e
)
raise
logger.warn(
"{%s} Sending request failed to %s: %s %s: %s - %s",
txn_id,
try:
with PreserveLoggingContext():
request_deferred = self.agent.request(
destination,
endpoint,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
path_bytes,
param_bytes,
query_bytes,
Headers(headers_dict),
producer
)
log_result = "%s - %s" % (
type(e).__name__, _flatten_response_never_received(e),
response = yield self.clock.time_bound_deferred(
request_deferred,
time_out=60,
)
if retries_left and not timeout:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
finally:
outbound_logger.info(
"{%s} [%s] Result: %s",
txn_id,
destination,
log_result,
)
logger.debug("Got response to %s", method)
break
except Exception as e:
if not retry_on_dns_fail and isinstance(e, DNSLookupError):
logger.warn(
"DNS Lookup failed to %s with %s",
destination,
e
)
raise
logger.warn(
"Sending request failed to %s: %s %s: %s - %s",
destination,
method,
url_bytes,
type(e).__name__,
_flatten_response_never_received(e),
)
if retries_left:
yield sleep(2 ** (5 - retries_left))
retries_left -= 1
else:
raise
logger.info(
"Received response %d %s for %s: %s %s",
response.code,
response.phrase,
destination,
method,
url_bytes
)
if 200 <= response.code < 300:
pass
else:
# :'(
# Update transactions table?
body = yield preserve_context_over_fn(readBody, response)
body = yield readBody(response)
raise HttpResponseException(
response.code, response.phrase, body
)
@@ -277,7 +281,10 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
@@ -320,13 +327,14 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
@defer.inlineCallbacks
def get_json(self, destination, path, args={}, retry_on_dns_fail=True,
timeout=None):
def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
""" GETs some json from the given host homeserver and path
Args:
@@ -335,9 +343,6 @@ class MatrixFederationHttpClient(object):
path (str): The HTTP path.
args (dict): A dictionary used to create query strings, defaults to
None.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout and that the request will
be retried.
Returns:
Deferred: Succeeds when we get *any* HTTP response.
@@ -365,8 +370,7 @@ class MatrixFederationHttpClient(object):
path.encode("ascii"),
query_bytes=query_bytes,
body_callback=body_callback,
retry_on_dns_fail=retry_on_dns_fail,
timeout=timeout,
retry_on_dns_fail=retry_on_dns_fail
)
if 200 <= response.code < 300:
@@ -378,7 +382,9 @@ class MatrixFederationHttpClient(object):
"Content-Type not application/json"
)
body = yield preserve_context_over_fn(readBody, response)
logger.debug("Getting resp body")
body = yield readBody(response)
logger.debug("Got resp body")
defer.returnValue(json.loads(body))
@@ -421,16 +427,19 @@ class MatrixFederationHttpClient(object):
headers = dict(response.headers.getAllRawHeaders())
try:
length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
length = yield _readBodyToFile(response, output_stream, max_size)
except:
logger.exception("Failed to download body")
raise
defer.returnValue((length, headers))
def _getEndpoint(self, reactor, destination):
return matrix_federation_endpoint(
reactor, destination, timeout=10,
ssl_context_factory=self.hs.tls_context_factory
)
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):

View File

@@ -17,11 +17,10 @@
from synapse.api.errors import (
cs_exception, SynapseError, CodeMessageException, UnrecognizedRequestError
)
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
from synapse.util.logcontext import LoggingContext
import synapse.metrics
import synapse.events
from canonicaljson import (
from syutil.jsonutil import (
encode_canonical_json, encode_pretty_printed_json
)
@@ -33,7 +32,6 @@ from twisted.web.util import redirectTo
import collections
import logging
import urllib
import ujson
logger = logging.getLogger(__name__)
@@ -80,39 +78,51 @@ def request_handler(request_handler):
_next_request_id += 1
with LoggingContext(request_id) as request_context:
request_context.request = request_id
with request.processing():
try:
d = request_handler(self, request)
with PreserveLoggingContext():
yield d
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
code = None
start = self.clock.time_msec()
try:
logger.info(
"Received request: %s %s",
request.method, request.path
)
yield request_handler(self, request)
code = request.code
except CodeMessageException as e:
code = e.code
if isinstance(e, SynapseError):
logger.info(
"%s SynapseError: %s - %s", request, code, e.msg
)
else:
logger.exception(e)
outgoing_responses_counter.inc(request.method, str(code))
respond_with_json(
request, code, cs_exception(e), send_cors=True,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
)
except:
code = 500
logger.exception(
"Failed handle request %s.%s on %r: %r",
request_handler.__module__,
request_handler.__name__,
self,
request
)
respond_with_json(
request,
500,
{"error": "Internal server error"},
send_cors=True
)
finally:
code = str(code) if code else "-"
end = self.clock.time_msec()
logger.info(
"Processed request: %dms %s %s %s",
end-start, code, request.method, request.path
)
return wrapped_request_handler
@@ -156,10 +166,9 @@ class JsonResource(HttpServer, resource.Resource):
_PathEntry = collections.namedtuple("_PathEntry", ["pattern", "callback"])
def __init__(self, hs, canonical_json=True):
def __init__(self, hs):
resource.Resource.__init__(self)
self.canonical_json = canonical_json
self.clock = hs.get_clock()
self.path_regexs = {}
self.version_string = hs.version_string
@@ -208,7 +217,7 @@ class JsonResource(HttpServer, resource.Resource):
incoming_requests_counter.inc(request.method, servlet_classname)
args = [
urllib.unquote(u).decode("UTF-8") if u else u for u in m.groups()
urllib.unquote(u).decode("UTF-8") for u in m.groups()
]
callback_return = yield callback(request, *args)
@@ -245,7 +254,6 @@ class JsonResource(HttpServer, resource.Resource):
response_code_message=response_code_message,
pretty_print=_request_user_agent_is_curl(request),
version_string=self.version_string,
canonical_json=self.canonical_json,
)
@@ -267,15 +275,11 @@ class RootRedirect(resource.Resource):
def respond_with_json(request, code, json_object, send_cors=False,
response_code_message=None, pretty_print=False,
version_string="", canonical_json=True):
version_string=""):
if pretty_print:
json_bytes = encode_pretty_printed_json(json_object) + "\n"
else:
if canonical_json or synapse.events.USE_FROZEN_DICTS:
json_bytes = encode_canonical_json(json_object)
else:
# ujson doesn't like frozen_dicts.
json_bytes = ujson.dumps(json_object, ensure_ascii=False)
json_bytes = encode_canonical_json(json_object)
return respond_with_json_bytes(
request, code, json_bytes,

View File

@@ -17,13 +17,9 @@
from __future__ import absolute_import
import logging
from resource import getrusage, RUSAGE_SELF
import functools
from resource import getrusage, getpagesize, RUSAGE_SELF
import os
import stat
import time
from twisted.internet import reactor
from .metric import (
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
@@ -100,6 +96,7 @@ def render_all():
# process resource usage
rusage = None
PAGE_SIZE = getpagesize()
def update_resource_metrics():
@@ -112,8 +109,8 @@ resource_metrics = get_metrics_for("process.resource")
resource_metrics.register_callback("utime", lambda: rusage.ru_utime * 1000)
resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
# kilobytes
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * 1024)
# pages
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
TYPES = {
stat.S_IFSOCK: "SOCK",
@@ -130,10 +127,6 @@ def _process_fds():
counts = {(k,): 0 for k in TYPES.values()}
counts[("other",)] = 0
# Not every OS will have a /proc/self/fd directory
if not os.path.exists("/proc/self/fd"):
return counts
for fd in os.listdir("/proc/self/fd"):
try:
s = os.stat("/proc/self/fd/%s" % (fd))
@@ -151,50 +144,3 @@ def _process_fds():
return counts
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
reactor_metrics = get_metrics_for("reactor")
tick_time = reactor_metrics.register_distribution("tick_time")
pending_calls_metric = reactor_metrics.register_distribution("pending_calls")
def runUntilCurrentTimer(func):
@functools.wraps(func)
def f(*args, **kwargs):
now = reactor.seconds()
num_pending = 0
# _newTimedCalls is one long list of *all* pending calls. Below loop
# is based off of impl of reactor.runUntilCurrent
for delayed_call in reactor._newTimedCalls:
if delayed_call.time > now:
break
if delayed_call.delayed_time > 0:
continue
num_pending += 1
num_pending += len(reactor.threadCallQueue)
start = time.time() * 1000
ret = func(*args, **kwargs)
end = time.time() * 1000
tick_time.inc_by(end - start)
pending_calls_metric.inc_by(num_pending)
return ret
return f
try:
# Ensure the reactor has all the attributes we expect
reactor.runUntilCurrent
reactor._newTimedCalls
reactor.threadCallQueue
# runUntilCurrent is called when we have pending calls. It is called once
# per iteratation after fd polling.
reactor.runUntilCurrent = runUntilCurrentTimer(reactor.runUntilCurrent)
except AttributeError:
pass

View File

@@ -16,7 +16,8 @@
from twisted.internet import defer
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor, ObservableDeferred
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.async import run_on_reactor
from synapse.types import StreamToken
import synapse.metrics
@@ -42,78 +43,62 @@ def count(func, l):
class _NotificationListener(object):
""" This represents a single client connection to the events stream.
The events stream handler will have yielded to the deferred, so to
notify the handler it is sufficient to resolve the deferred.
"""
__slots__ = ["deferred"]
def __init__(self, deferred):
self.deferred = deferred
class _NotifierUserStream(object):
"""This represents a user connected to the event stream.
It tracks the most recent stream token for that user.
At a given point a user may have a number of streams listening for
events.
This listener will also keep track of which rooms it is listening in
so that it can remove itself from the indexes in the Notifier class.
"""
def __init__(self, user, rooms, current_token, time_now_ms,
def __init__(self, user, rooms, from_token, limit, timeout, deferred,
appservice=None):
self.user = str(user)
self.user = user
self.appservice = appservice
self.rooms = set(rooms)
self.current_token = current_token
self.last_notified_ms = time_now_ms
self.from_token = from_token
self.limit = limit
self.timeout = timeout
self.deferred = deferred
self.rooms = rooms
self.timer = None
self.notify_deferred = ObservableDeferred(defer.Deferred())
def notified(self):
return self.deferred.called
def notify(self, stream_key, stream_id, time_now_ms):
"""Notify any listeners for this user of a new event from an
event source.
Args:
stream_key(str): The stream the event came from.
stream_id(str): The new id for the stream the event came from.
time_now_ms(int): The current time in milliseconds.
"""
self.current_token = self.current_token.copy_and_advance(
stream_key, stream_id
)
self.last_notified_ms = time_now_ms
noify_deferred = self.notify_deferred
self.notify_deferred = ObservableDeferred(defer.Deferred())
noify_deferred.callback(self.current_token)
def remove(self, notifier):
""" Remove this listener from all the indexes in the Notifier
def notify(self, notifier, events, start_token, end_token):
""" Inform whoever is listening about the new events. This will
also remove this listener from all the indexes in the Notifier
it knows about.
"""
result = (events, (start_token, end_token))
try:
self.deferred.callback(result)
notified_events_counter.inc_by(len(events))
except defer.AlreadyCalledError:
pass
# Should the following be done be using intrusively linked lists?
# -- erikj
for room in self.rooms:
lst = notifier.room_to_user_streams.get(room, set())
lst = notifier.room_to_listeners.get(room, set())
lst.discard(self)
notifier.user_to_user_stream.pop(self.user)
notifier.user_to_listeners.get(self.user, set()).discard(self)
if self.appservice:
notifier.appservice_to_user_streams.get(
notifier.appservice_to_listeners.get(
self.appservice, set()
).discard(self)
def count_listeners(self):
return len(self.notify_deferred.observers())
def new_listener(self, token):
"""Returns a deferred that is resolved when there is a new token
greater than the given token.
"""
if self.current_token.is_after(token):
return _NotificationListener(defer.succeed(self.current_token))
else:
return _NotificationListener(self.notify_deferred.observe())
# Cancel the timeout for this notifer if one exists.
if self.timer is not None:
try:
notifier.clock.cancel_call_later(self.timer)
except:
logger.warn("Failed to cancel notifier timer")
class Notifier(object):
@@ -123,18 +108,14 @@ class Notifier(object):
Primarily used from the /events stream.
"""
UNUSED_STREAM_EXPIRY_MS = 10 * 60 * 1000
def __init__(self, hs):
self.hs = hs
self.user_to_user_stream = {}
self.room_to_user_streams = {}
self.appservice_to_user_streams = {}
self.room_to_listeners = {}
self.user_to_listeners = {}
self.appservice_to_listeners = {}
self.event_sources = hs.get_event_sources()
self.store = hs.get_datastore()
self.pending_new_room_events = []
self.clock = hs.get_clock()
@@ -142,272 +123,370 @@ class Notifier(object):
"user_joined_room", self._user_joined_room
)
self.clock.looping_call(
self.remove_expired_streams, self.UNUSED_STREAM_EXPIRY_MS
)
# This is not a very cheap test to perform, but it's only executed
# when rendering the metrics page, which is likely once per minute at
# most when scraping it.
def count_listeners():
all_user_streams = set()
all_listeners = set()
for x in self.room_to_user_streams.values():
all_user_streams |= x
for x in self.user_to_user_stream.values():
all_user_streams.add(x)
for x in self.appservice_to_user_streams.values():
all_user_streams |= x
for x in self.room_to_listeners.values():
all_listeners |= x
for x in self.user_to_listeners.values():
all_listeners |= x
for x in self.appservice_to_listeners.values():
all_listeners |= x
return sum(stream.count_listeners() for stream in all_user_streams)
return len(all_listeners)
metrics.register_callback("listeners", count_listeners)
metrics.register_callback(
"rooms",
lambda: count(bool, self.room_to_user_streams.values()),
lambda: count(bool, self.room_to_listeners.values()),
)
metrics.register_callback(
"users",
lambda: len(self.user_to_user_stream),
lambda: count(bool, self.user_to_listeners.values()),
)
metrics.register_callback(
"appservices",
lambda: count(bool, self.appservice_to_user_streams.values()),
lambda: count(bool, self.appservice_to_listeners.values()),
)
@log_function
@defer.inlineCallbacks
def on_new_room_event(self, event, room_stream_id, max_room_stream_id,
extra_users=[]):
def on_new_room_event(self, event, extra_users=[]):
""" Used by handlers to inform the notifier something has happened
in the room, room event wise.
This triggers the notifier to wake up any listeners that are
listening to the room, and any listeners for the users in the
`extra_users` param.
The events can be peristed out of order. The notifier will wait
until all previous events have been persisted before notifying
the client streams.
"""
yield run_on_reactor()
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
self._notify_pending_new_room_events(max_room_stream_id)
def _notify_pending_new_room_events(self, max_room_stream_id):
"""Notify for the room events that were queued waiting for a previous
event to be persisted.
Args:
max_room_stream_id(int): The highest stream_id below which all
events have been persisted.
"""
pending = self.pending_new_room_events
self.pending_new_room_events = []
for room_stream_id, event, extra_users in pending:
if room_stream_id > max_room_stream_id:
self.pending_new_room_events.append((
room_stream_id, event, extra_users
))
else:
self._on_new_room_event(event, room_stream_id, extra_users)
def _on_new_room_event(self, event, room_stream_id, extra_users=[]):
"""Notify any user streams that are interested in this room event"""
# poke any interested application service.
self.hs.get_handlers().appservice_handler.notify_interested_services(
event
)
app_streams = set()
room_id = event.room_id
for appservice in self.appservice_to_user_streams:
room_source = self.event_sources.sources["room"]
room_listeners = self.room_to_listeners.get(room_id, set())
_discard_if_notified(room_listeners)
listeners = room_listeners.copy()
for user in extra_users:
user_listeners = self.user_to_listeners.get(user, set())
_discard_if_notified(user_listeners)
listeners |= user_listeners
for appservice in self.appservice_to_listeners:
# TODO (kegan): Redundant appservice listener checks?
# App services will already be in the room_to_user_streams set, but
# App services will already be in the room_to_listeners set, but
# that isn't enough. They need to be checked here in order to
# receive *invites* for users they are interested in. Does this
# make the room_to_user_streams check somewhat obselete?
# make the room_to_listeners check somewhat obselete?
if appservice.is_interested(event):
app_user_streams = self.appservice_to_user_streams.get(
app_listeners = self.appservice_to_listeners.get(
appservice, set()
)
app_streams |= app_user_streams
self.on_new_event(
"room_key", room_stream_id,
users=extra_users,
rooms=[event.room_id],
extra_streams=app_streams,
)
_discard_if_notified(app_listeners)
listeners |= app_listeners
logger.debug("on_new_room_event listeners %s", listeners)
# TODO (erikj): Can we make this more efficient by hitting the
# db once?
@defer.inlineCallbacks
def notify(listener):
events, end_key = yield room_source.get_new_events_for_user(
listener.user,
listener.from_token.room_key,
listener.limit,
)
if events:
end_token = listener.from_token.copy_and_replace(
"room_key", end_key
)
listener.notify(
self, events, listener.from_token, end_token
)
def eb(failure):
logger.exception("Failed to notify listener", failure)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks
@log_function
def on_new_event(self, stream_key, new_token, users=[], rooms=[],
extra_streams=set()):
""" Used to inform listeners that something has happend event wise.
def on_new_user_event(self, users=[], rooms=[]):
""" Used to inform listeners that something has happend
presence/user event wise.
Will wake up all listeners for the given users and rooms.
"""
yield run_on_reactor()
user_streams = set()
# TODO(paul): This is horrible, having to manually list every event
# source here individually
presence_source = self.event_sources.sources["presence"]
typing_source = self.event_sources.sources["typing"]
listeners = set()
for user in users:
user_stream = self.user_to_user_stream.get(str(user))
if user_stream is not None:
user_streams.add(user_stream)
user_listeners = self.user_to_listeners.get(user, set())
_discard_if_notified(user_listeners)
listeners |= user_listeners
for room in rooms:
user_streams |= self.room_to_user_streams.get(room, set())
room_listeners = self.room_to_listeners.get(room, set())
time_now_ms = self.clock.time_msec()
for user_stream in user_streams:
try:
user_stream.notify(stream_key, new_token, time_now_ms)
except:
logger.exception("Failed to notify listener")
_discard_if_notified(room_listeners)
listeners |= room_listeners
@defer.inlineCallbacks
def notify(listener):
presence_events, presence_end_key = (
yield presence_source.get_new_events_for_user(
listener.user,
listener.from_token.presence_key,
listener.limit,
)
)
typing_events, typing_end_key = (
yield typing_source.get_new_events_for_user(
listener.user,
listener.from_token.typing_key,
listener.limit,
)
)
if presence_events or typing_events:
end_token = listener.from_token.copy_and_replace(
"presence_key", presence_end_key
).copy_and_replace(
"typing_key", typing_end_key
)
listener.notify(
self,
presence_events + typing_events,
listener.from_token,
end_token
)
def eb(failure):
logger.error(
"Failed to notify listener",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject())
)
with PreserveLoggingContext():
yield defer.DeferredList(
[notify(l).addErrback(eb) for l in listeners],
consumeErrors=True,
)
@defer.inlineCallbacks
def wait_for_events(self, user, rooms, timeout, callback,
from_token=StreamToken("s0", "0", "0", "0")):
def wait_for_events(self, user, rooms, filter, timeout, callback):
"""Wait until the callback returns a non empty response or the
timeout fires.
"""
user = str(user)
user_stream = self.user_to_user_stream.get(user)
if user_stream is None:
appservice = yield self.store.get_app_service_by_user_id(user)
current_token = yield self.event_sources.get_current_token()
rooms = yield self.store.get_rooms_for_user(user)
rooms = [room.room_id for room in rooms]
user_stream = _NotifierUserStream(
user=user,
rooms=rooms,
appservice=appservice,
current_token=current_token,
time_now_ms=self.clock.time_msec(),
)
self._register_with_keys(user_stream)
result = None
deferred = defer.Deferred()
from_token = StreamToken("s0", "0", "0")
listener = [_NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)]
if timeout:
# Will be set to a _NotificationListener that we'll be waiting on.
# Allows us to cancel it.
listener = None
self._register_with_keys(listener[0])
def timed_out():
if listener:
listener.deferred.cancel()
timer = self.clock.call_later(timeout/1000., timed_out)
result = yield callback()
timer = [None]
prev_token = from_token
while not result:
try:
current_token = user_stream.current_token
if timeout:
timed_out = [False]
result = yield callback(prev_token, current_token)
if result:
break
def _timeout_listener():
timed_out[0] = True
timer[0] = None
listener[0].notify(self, [], from_token, from_token)
# Now we wait for the _NotifierUserStream to be told there
# is a new token.
# We need to supply the token we supplied to callback so
# that we don't miss any current_token updates.
prev_token = current_token
listener = user_stream.new_listener(prev_token)
yield listener.deferred
except defer.CancelledError:
break
# We create multiple notification listeners so we have to manage
# canceling the timeout ourselves.
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
self.clock.cancel_call_later(timer, ignore_errs=True)
else:
current_token = user_stream.current_token
result = yield callback(from_token, current_token)
while not result and not timed_out[0]:
yield deferred
deferred = defer.Deferred()
listener[0] = _NotificationListener(
user=user,
rooms=rooms,
from_token=from_token,
limit=1,
timeout=timeout,
deferred=deferred,
)
self._register_with_keys(listener[0])
result = yield callback()
if timer[0] is not None:
try:
self.clock.cancel_call_later(timer[0])
except:
logger.exception("Failed to cancel notifer timer")
defer.returnValue(result)
@defer.inlineCallbacks
def get_events_for(self, user, rooms, pagination_config, timeout,
only_room_events=False):
def get_events_for(self, user, rooms, pagination_config, timeout):
""" For the given user and rooms, return any new events for them. If
there are no new events wait for up to `timeout` milliseconds for any
new events to happen before returning.
If `only_room_events` is `True` only room events will be returned.
"""
from_token = pagination_config.from_token
deferred = defer.Deferred()
self._get_events(
deferred, user, rooms, pagination_config.from_token,
pagination_config.limit, timeout
).addErrback(deferred.errback)
return deferred
@defer.inlineCallbacks
def _get_events(self, deferred, user, rooms, from_token, limit, timeout):
if not from_token:
from_token = yield self.event_sources.get_current_token()
limit = pagination_config.limit
@defer.inlineCallbacks
def check_for_updates(before_token, after_token):
if not after_token.is_after(before_token):
defer.returnValue(None)
events = []
end_token = from_token
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
before_id = getattr(before_token, keyname)
after_id = getattr(after_token, keyname)
if before_id == after_id:
continue
if only_room_events and name != "room":
continue
new_events, new_key = yield source.get_new_events_for_user(
user, getattr(from_token, keyname), limit,
)
events.extend(new_events)
end_token = end_token.copy_and_replace(keyname, new_key)
if events:
defer.returnValue((events, (from_token, end_token)))
else:
defer.returnValue(None)
result = yield self.wait_for_events(
user, rooms, timeout, check_for_updates, from_token=from_token
appservice = yield self.hs.get_datastore().get_app_service_by_user_id(
user.to_string()
)
if result is None:
result = ([], (from_token, from_token))
listener = _NotificationListener(
user,
rooms,
from_token,
limit,
timeout,
deferred,
appservice=appservice
)
defer.returnValue(result)
def _timeout_listener():
# TODO (erikj): We should probably set to_token to the current
# max rather than reusing from_token.
# Remove the timer from the listener so we don't try to cancel it.
listener.timer = None
listener.notify(
self,
[],
listener.from_token,
listener.from_token,
)
if timeout:
self._register_with_keys(listener)
yield self._check_for_updates(listener)
if not timeout:
_timeout_listener()
else:
# Only add the timer if the listener hasn't been notified
if not listener.notified():
listener.timer = self.clock.call_later(
timeout/1000.0, _timeout_listener
)
return
@log_function
def remove_expired_streams(self):
time_now_ms = self.clock.time_msec()
expired_streams = []
expire_before_ts = time_now_ms - self.UNUSED_STREAM_EXPIRY_MS
for stream in self.user_to_user_stream.values():
if stream.count_listeners():
continue
if stream.last_notified_ms < expire_before_ts:
expired_streams.append(stream)
def _register_with_keys(self, listener):
for room in listener.rooms:
s = self.room_to_listeners.setdefault(room, set())
s.add(listener)
for expired_stream in expired_streams:
expired_stream.remove(self)
self.user_to_listeners.setdefault(listener.user, set()).add(listener)
if listener.appservice:
self.appservice_to_listeners.setdefault(
listener.appservice, set()
).add(listener)
@defer.inlineCallbacks
@log_function
def _register_with_keys(self, user_stream):
self.user_to_user_stream[user_stream.user] = user_stream
def _check_for_updates(self, listener):
# TODO (erikj): We need to think about limits across multiple sources
events = []
for room in user_stream.rooms:
s = self.room_to_user_streams.setdefault(room, set())
s.add(user_stream)
from_token = listener.from_token
limit = listener.limit
if user_stream.appservice:
self.appservice_to_user_stream.setdefault(
user_stream.appservice, set()
).add(user_stream)
# TODO (erikj): DeferredList?
for name, source in self.event_sources.sources.items():
keyname = "%s_key" % name
stuff, new_key = yield source.get_new_events_for_user(
listener.user,
getattr(from_token, keyname),
limit,
)
events.extend(stuff)
from_token = from_token.copy_and_replace(keyname, new_key)
end_token = from_token
if events:
listener.notify(self, events, listener.from_token, end_token)
defer.returnValue(listener)
def _user_joined_room(self, user, room_id):
user = str(user)
new_user_stream = self.user_to_user_stream.get(user)
if new_user_stream is not None:
room_streams = self.room_to_user_streams.setdefault(room_id, set())
room_streams.add(new_user_stream)
new_user_stream.rooms.add(room_id)
new_listeners = self.user_to_listeners.get(user, set())
listeners = self.room_to_listeners.setdefault(room_id, set())
listeners |= new_listeners
for l in new_listeners:
l.rooms.add(room_id)
def _discard_if_notified(listener_set):
"""Remove any 'stale' listeners from the given set.
"""
to_discard = set()
for l in listener_set:
if l.notified():
to_discard.add(l)
listener_set -= to_discard

View File

@@ -24,7 +24,6 @@ import baserules
import logging
import simplejson as json
import re
import random
logger = logging.getLogger(__name__)
@@ -75,33 +74,35 @@ class Pusher(object):
rawrules = yield self.store.get_push_rules_for_user(self.user_name)
rules = []
for rawrule in rawrules:
rule = dict(rawrule)
rule['conditions'] = json.loads(rawrule['conditions'])
rule['actions'] = json.loads(rawrule['actions'])
rules.append(rule)
for r in rawrules:
r['conditions'] = json.loads(r['conditions'])
r['actions'] = json.loads(r['actions'])
enabled_map = yield self.store.get_push_rules_enabled_for_user(self.user_name)
user = UserID.from_string(self.user_name)
rules = baserules.list_with_base_rules(rules, user)
room_id = ev['room_id']
rules = baserules.list_with_base_rules(rawrules, user)
# get *our* member event for display name matching
my_display_name = None
our_member_event = yield self.store.get_current_state(
room_id=room_id,
member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=self.user_name,
state_key=None
)
if our_member_event:
my_display_name = our_member_event[0].content.get("displayname")
my_display_name = None
room_member_count = 0
for mev in member_events_for_room:
if mev.content['membership'] != 'join':
continue
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
# This loop does two things:
# 1) Find our current display name
if mev.state_key == self.user_name and 'displayname' in mev.content:
my_display_name = mev.content['displayname']
# and 2) Get the number of people in that room
room_member_count += 1
for r in rules:
if r['rule_id'] in enabled_map:
@@ -249,9 +250,7 @@ class Pusher(object):
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=0, affect_presence=False,
only_room_events=True
)
self.user_name, config, timeout=0)
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_name, self.last_token
@@ -259,160 +258,132 @@ class Pusher(object):
logger.info("Pusher %s for user %s starting from token %s",
self.pushkey, self.user_name, self.last_token)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config,
timeout=100*365*24*60*60*1000, affect_presence=False
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
continue
if not self.alive:
continue
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_name, config, timeout=timeout, affect_presence=False,
only_room_events=True
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
break
if not single_event:
self.last_token = chunk['end']
logger.debug("Event stream timeout for pushkey %s", self.pushkey)
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
return
if not self.alive:
return
processed = False
actions = yield self._actions_for_event(single_event)
tweaks = _tweaks_for_actions(actions)
if len(actions) == 0:
logger.warn("Empty actions! Using default action.")
actions = Pusher.DEFAULT_ACTIONS
if 'notify' not in actions and 'dont_notify' not in actions:
logger.warn("Neither notify nor dont_notify in actions: adding default")
actions.extend(Pusher.DEFAULT_ACTIONS)
if 'dont_notify' in actions:
logger.debug(
"%s for %s: dont_notify",
single_event['event_id'], self.user_name
)
processed = True
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
)
else:
rejected = yield self.dispatch_push(single_event, tweaks)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_name
)
if not self.alive:
return
if not self.alive:
continue
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
if not self.failing_since:
self.failing_since = self.clock.time_msec()
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_name, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_name,
self.last_token
)
self.failing_since = None
self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_name,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_name,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False

View File

@@ -164,7 +164,7 @@ def make_base_append_underride_rules(user):
]
},
{
'rule_id': 'global/underride/.m.rule.contains_display_name',
'rule_id': 'global/override/.m.rule.contains_display_name',
'conditions': [
{
'kind': 'contains_display_name'

View File

@@ -94,14 +94,17 @@ class PusherPool:
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
@defer.inlineCallbacks
def remove_pushers_by_user(self, user_id):
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
all = yield self.store.get_all_pushers()
logger.info(
"Removing all pushers for user %s",
user_id,
"Removing all pushers for user %s except access token %s",
user_id, not_access_token_id
)
for p in all:
if p['user_name'] == user_id:
if (
p['user_name'] == user_id and
p['access_token'] != not_access_token_id
):
logger.info(
"Removing pusher for app id %s, pushkey %s, user %s",
p['app_id'], p['pushkey'], p['user_name']

View File

@@ -18,36 +18,30 @@ from distutils.version import LooseVersion
logger = logging.getLogger(__name__)
REQUIREMENTS = {
"unpaddedbase64>=1.0.1": ["unpaddedbase64>=1.0.1"],
"canonicaljson>=1.0.0": ["canonicaljson>=1.0.0"],
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"syutil>=0.0.6": ["syutil>=0.0.6"],
"Twisted==14.0.2": ["twisted==14.0.2"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],
"pynacl>=0.3.0": ["nacl>=0.3.0"],
"pynacl": ["nacl"],
"daemonize": ["daemonize"],
"py-bcrypt": ["bcrypt"],
"frozendict>=0.4": ["frozendict"],
"pillow": ["PIL"],
"pydenticon": ["pydenticon"],
"ujson": ["ujson"],
"blist": ["blist"],
"pysaml2": ["saml2"],
"pymacaroons-pynacl": ["pymacaroons"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.6": ["syweb>=0.6.6"],
"matrix_angular_sdk>=0.6.5": ["syweb>=0.6.5"],
}
}
def requirements(config=None, include_conditional=False):
reqs = REQUIREMENTS.copy()
if include_conditional:
for _, req in CONDITIONAL_REQUIREMENTS.items():
for key, req in CONDITIONAL_REQUIREMENTS.items():
if (config and getattr(config, key)) or include_conditional:
reqs.update(req)
return reqs
@@ -55,8 +49,23 @@ def requirements(config=None, include_conditional=False):
def github_link(project, version, egg):
return "https://github.com/%s/tarball/%s/#egg=%s" % (project, version, egg)
DEPENDENCY_LINKS = {
}
DEPENDENCY_LINKS = [
github_link(
project="pyca/pynacl",
version="d4d3175589b892f6ea7c22f466e0e223853516fa",
egg="pynacl-0.3.0",
),
github_link(
project="matrix-org/syutil",
version="v0.0.6",
egg="syutil-0.0.6",
),
github_link(
project="matrix-org/matrix-angular-sdk",
version="v0.6.5",
egg="matrix_angular_sdk-0.6.5",
),
]
class MissingRequirementError(Exception):
@@ -124,7 +133,7 @@ def check_requirements(config=None):
def list_requirements():
result = []
linked = []
for link in DEPENDENCY_LINKS.values():
for link in DEPENDENCY_LINKS:
egg = link.split("#egg=")[1]
linked.append(egg.split('-')[0])
result.append(link)

View File

@@ -25,7 +25,7 @@ class ClientV1RestResource(JsonResource):
"""A resource for version 1 of the matrix client API."""
def __init__(self, hs):
JsonResource.__init__(self, hs, canonical_json=False)
JsonResource.__init__(self, hs)
self.register_servlets(self, hs)
@staticmethod

View File

@@ -31,7 +31,7 @@ class WhoisRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, user_id):
target_user = UserID.from_string(user_id)
auth_user, _ = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(auth_user)
if not is_admin and target_user != auth_user:

View File

@@ -69,7 +69,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
try:
# try to auth as a user
user, _ = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
try:
user_id = user.to_string()
yield dir_handler.create_association(
@@ -116,7 +116,7 @@ class ClientDirectoryServer(ClientV1RestServlet):
# fallback to default user behaviour if they aren't an AS
pass
user, _ = yield self.auth.get_user_by_req(request)
user, client = yield self.auth.get_user_by_req(request)
is_admin = yield self.auth.is_server_admin(user)
if not is_admin:

View File

@@ -34,7 +34,7 @@ class EventStreamRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request):
auth_user, _ = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
try:
handler = self.handlers.event_stream_handler
pagin_config = PaginationConfig.from_request(request)
@@ -71,7 +71,7 @@ class EventRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def on_GET(self, request, event_id):
auth_user, _ = yield self.auth.get_user_by_req(request)
auth_user, client = yield self.auth.get_user_by_req(request)
handler = self.handlers.event_handler
event = yield handler.get_event(auth_user, event_id)

Some files were not shown because too many files have changed in this diff Show More