1
0

Compare commits

..

2 Commits

Author SHA1 Message Date
Erik Johnston
75bf48b905 Update tracer to give more information 2015-03-13 11:38:57 +00:00
Erik Johnston
d1ae594ae5 Add a utility class that can be used to generate a twisted deferred aware call graph 2015-03-12 16:52:02 +00:00
292 changed files with 7081 additions and 20536 deletions

4
.gitignore vendored
View File

@@ -42,7 +42,3 @@ build/
localhost-800*/
static/client/register/register_config.js
.tox
env/
*.config

View File

@@ -1,47 +0,0 @@
Erik Johnston <erik at matrix.org>
* HS core
* Federation API impl
Mark Haines <mark at matrix.org>
* HS core
* Crypto
* Content repository
* CS v2 API impl
Kegan Dougal <kegan at matrix.org>
* HS core
* CS v1 API impl
* AS API impl
Paul "LeoNerd" Evans <paul at matrix.org>
* HS core
* Presence
* Typing Notifications
* Performance metrics and caching layer
Dave Baker <dave at matrix.org>
* Push notifications
* Auth CS v2 impl
Matthew Hodgson <matthew at matrix.org>
* General doc & housekeeping
* Vertobot/vertobridge matrix<->verto PoC
Emmanuel Rohee <manu at matrix.org>
* Supporting iOS clients (testability and fallback registration)
Turned to Dust <dwinslow86 at gmail.com>
* ArchLinux installation instructions
Brabo <brabo at riseup.net>
* Installation instruction fixes
Ivan Shapovalov <intelfx100 at gmail.com>
* contrib/systemd: a sample systemd unit file and a logger configuration
Eric Myhre <hash at exultant.us>
* Fix bug where ``media_store_path`` config option was ignored by v0 content
repository API.
Muthu Subramanian <muthu.subramanian.karunanidhi at ericsson.com>
* Add SAML2 support for registration and logins.

View File

@@ -1,289 +1,3 @@
Changes in synapse v0.10.0-r2 (2015-09-16)
==========================================
* Fix bug where we always fetched remote server signing keys instead of using
ones in our cache.
* Fix adding threepids to an existing account.
* Fix bug with invinting over federation where remote server was already in
the room. (PR #281, SYN-392)
Changes in synapse v0.10.0-r1 (2015-09-08)
==========================================
* Fix bug with python packaging
Changes in synapse v0.10.0 (2015-09-03)
=======================================
No change from release candidate.
Changes in synapse v0.10.0-rc6 (2015-09-02)
===========================================
* Remove some of the old database upgrade scripts.
* Fix database port script to work with newly created sqlite databases.
Changes in synapse v0.10.0-rc5 (2015-08-27)
===========================================
* Fix bug that broke downloading files with ascii filenames across federation.
Changes in synapse v0.10.0-rc4 (2015-08-27)
===========================================
* Allow UTF-8 filenames for upload. (PR #259)
Changes in synapse v0.10.0-rc3 (2015-08-25)
===========================================
* Add ``--keys-directory`` config option to specify where files such as
certs and signing keys should be stored in, when using ``--generate-config``
or ``--generate-keys``. (PR #250)
* Allow ``--config-path`` to specify a directory, causing synapse to use all
\*.yaml files in the directory as config files. (PR #249)
* Add ``web_client_location`` config option to specify static files to be
hosted by synapse under ``/_matrix/client``. (PR #245)
* Add helper utility to synapse to read and parse the config files and extract
the value of a given key. For example::
$ python -m synapse.config read server_name -c homeserver.yaml
localhost
(PR #246)
Changes in synapse v0.10.0-rc2 (2015-08-24)
===========================================
* Fix bug where we incorrectly populated the ``event_forward_extremities``
table, resulting in problems joining large remote rooms (e.g.
``#matrix:matrix.org``)
* Reduce the number of times we wake up pushers by not listening for presence
or typing events, reducing the CPU cost of each pusher.
Changes in synapse v0.10.0-rc1 (2015-08-21)
===========================================
Also see v0.9.4-rc1 changelog, which has been amalgamated into this release.
General:
* Upgrade to Twisted 15 (PR #173)
* Add support for serving and fetching encryption keys over federation.
(PR #208)
* Add support for logging in with email address (PR #234)
* Add support for new ``m.room.canonical_alias`` event. (PR #233)
* Change synapse to treat user IDs case insensitively during registration and
login. (If two users already exist with case insensitive matching user ids,
synapse will continue to require them to specify their user ids exactly.)
* Error if a user tries to register with an email already in use. (PR #211)
* Add extra and improve existing caches (PR #212, #219, #226, #228)
* Batch various storage request (PR #226, #228)
* Fix bug where we didn't correctly log the entity that triggered the request
if the request came in via an application service (PR #230)
* Fix bug where we needlessly regenerated the full list of rooms an AS is
interested in. (PR #232)
* Add support for AS's to use v2_alpha registration API (PR #210)
Configuration:
* Add ``--generate-keys`` that will generate any missing cert and key files in
the configuration files. This is equivalent to running ``--generate-config``
on an existing configuration file. (PR #220)
* ``--generate-config`` now no longer requires a ``--server-name`` parameter
when used on existing configuration files. (PR #220)
* Add ``--print-pidfile`` flag that controls the printing of the pid to stdout
of the demonised process. (PR #213)
Media Repository:
* Fix bug where we picked a lower resolution image than requested. (PR #205)
* Add support for specifying if a the media repository should dynamically
thumbnail images or not. (PR #206)
Metrics:
* Add statistics from the reactor to the metrics API. (PR #224, #225)
Demo Homeservers:
* Fix starting the demo homeservers without rate-limiting enabled. (PR #182)
* Fix enabling registration on demo homeservers (PR #223)
Changes in synapse v0.9.4-rc1 (2015-07-21)
==========================================
General:
* Add basic implementation of receipts. (SPEC-99)
* Add support for configuration presets in room creation API. (PR #203)
* Add auth event that limits the visibility of history for new users.
(SPEC-134)
* Add SAML2 login/registration support. (PR #201. Thanks Muthu Subramanian!)
* Add client side key management APIs for end to end encryption. (PR #198)
* Change power level semantics so that you cannot kick, ban or change power
levels of users that have equal or greater power level than you. (SYN-192)
* Improve performance by bulk inserting events where possible. (PR #193)
* Improve performance by bulk verifying signatures where possible. (PR #194)
Configuration:
* Add support for including TLS certificate chains.
Media Repository:
* Add Content-Disposition headers to content repository responses. (SYN-150)
Changes in synapse v0.9.3 (2015-07-01)
======================================
No changes from v0.9.3 Release Candidate 1.
Changes in synapse v0.9.3-rc1 (2015-06-23)
==========================================
General:
* Fix a memory leak in the notifier. (SYN-412)
* Improve performance of room initial sync. (SYN-418)
* General improvements to logging.
* Remove ``access_token`` query params from ``INFO`` level logging.
Configuration:
* Add support for specifying and configuring multiple listeners. (SYN-389)
Application services:
* Fix bug where synapse failed to send user queries to application services.
Changes in synapse v0.9.2-r2 (2015-06-15)
=========================================
Fix packaging so that schema delta python files get included in the package.
Changes in synapse v0.9.2 (2015-06-12)
======================================
General:
* Use ultrajson for json (de)serialisation when a canonical encoding is not
required. Ultrajson is significantly faster than simplejson in certain
circumstances.
* Use connection pools for outgoing HTTP connections.
* Process thumbnails on separate threads.
Configuration:
* Add option, ``gzip_responses``, to disable HTTP response compression.
Federation:
* Improve resilience of backfill by ensuring we fetch any missing auth events.
* Improve performance of backfill and joining remote rooms by removing
unnecessary computations. This included handling events we'd previously
handled as well as attempting to compute the current state for outliers.
Changes in synapse v0.9.1 (2015-05-26)
======================================
General:
* Add support for backfilling when a client paginates. This allows servers to
request history for a room from remote servers when a client tries to
paginate history the server does not have - SYN-36
* Fix bug where you couldn't disable non-default pushrules - SYN-378
* Fix ``register_new_user`` script - SYN-359
* Improve performance of fetching events from the database, this improves both
initialSync and sending of events.
* Improve performance of event streams, allowing synapse to handle more
simultaneous connected clients.
Federation:
* Fix bug with existing backfill implementation where it returned the wrong
selection of events in some circumstances.
* Improve performance of joining remote rooms.
Configuration:
* Add support for changing the bind host of the metrics listener via the
``metrics_bind_host`` option.
Changes in synapse v0.9.0-r5 (2015-05-21)
=========================================
* Add more database caches to reduce amount of work done for each pusher. This
radically reduces CPU usage when multiple pushers are set up in the same room.
Changes in synapse v0.9.0 (2015-05-07)
======================================
General:
* Add support for using a PostgreSQL database instead of SQLite. See
`docs/postgres.rst`_ for details.
* Add password change and reset APIs. See `Registration`_ in the spec.
* Fix memory leak due to not releasing stale notifiers - SYN-339.
* Fix race in caches that occasionally caused some presence updates to be
dropped - SYN-369.
* Check server name has not changed on restart.
* Add a sample systemd unit file and a logger configuration in
contrib/systemd. Contributed Ivan Shapovalov.
Federation:
* Add key distribution mechanisms for fetching public keys of unavailable
remote home servers. See `Retrieving Server Keys`_ in the spec.
Configuration:
* Add support for multiple config files.
* Add support for dictionaries in config files.
* Remove support for specifying config options on the command line, except
for:
* ``--daemonize`` - Daemonize the home server.
* ``--manhole`` - Turn on the twisted telnet manhole service on the given
port.
* ``--database-path`` - The path to a sqlite database to use.
* ``--verbose`` - The verbosity level.
* ``--log-file`` - File to log to.
* ``--log-config`` - Python logging config file.
* ``--enable-registration`` - Enable registration for new users.
Application services:
* Reliably retry sending of events from Synapse to application services, as per
`Application Services`_ spec.
* Application services can no longer register via the ``/register`` API,
instead their configuration should be saved to a file and listed in the
synapse ``app_service_config_files`` config option. The AS configuration file
has the same format as the old ``/register`` request.
See `docs/application_services.rst`_ for more information.
.. _`docs/postgres.rst`: docs/postgres.rst
.. _`docs/application_services.rst`: docs/application_services.rst
.. _`Registration`: https://github.com/matrix-org/matrix-doc/blob/master/specification/10_client_server_api.rst#registration
.. _`Retrieving Server Keys`: https://github.com/matrix-org/matrix-doc/blob/6f2698/specification/30_server_server_api.rst#retrieving-server-keys
.. _`Application Services`: https://github.com/matrix-org/matrix-doc/blob/0c6bd9/specification/25_application_service_api.rst#home-server---application-service-api
Changes in synapse v0.8.1 (2015-03-18)
======================================
* Disable registration by default. New users can be added using the command
``register_new_matrix_user`` or by enabling registration in the config.
* Add metrics to synapse. To enable metrics use config options
``enable_metrics`` and ``metrics_port``.
* Fix bug where banning only kicked the user.
Changes in synapse v0.8.0 (2015-03-06)
======================================

View File

@@ -1,118 +0,0 @@
Contributing code to Matrix
===========================
Everyone is welcome to contribute code to Matrix
(https://github.com/matrix-org), provided that they are willing to license
their contributions under the same license as the project itself. We follow a
simple 'inbound=outbound' model for contributions: the act of submitting an
'inbound' contribution means that the contributor agrees to license the code
under the same terms as the project's overall 'outbound' license - in our
case, this is almost always Apache Software License v2 (see LICENSE).
How to contribute
~~~~~~~~~~~~~~~~~
The preferred and easiest way to contribute changes to Matrix is to fork the
relevant project on github, and then create a pull request to ask us to pull
your changes into our repo
(https://help.github.com/articles/using-pull-requests/)
**The single biggest thing you need to know is: please base your changes on
the develop branch - /not/ master.**
We use the master branch to track the most recent release, so that folks who
blindly clone the repo and automatically check out master get something that
works. Develop is the unstable branch where all the development actually
happens: the workflow is that contributors should fork the develop branch to
make a 'feature' branch for a particular contribution, and then make a pull
request to merge this back into the matrix.org 'official' develop branch. We
use github's pull request workflow to review the contribution, and either ask
you to make any refinements needed or merge it and make them ourselves. The
changes will then land on master when we next do a release.
We use Jenkins for continuous integration (http://matrix.org/jenkins), and
typically all pull requests get automatically tested Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
Code style
~~~~~~~~~~
All Matrix projects have a well-defined code-style - and sometimes we've even
got as far as documenting it... For instance, synapse's code style doc lives
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.rst.
Please ensure your changes match the cosmetic style of the existing project,
and **never** mix cosmetic and functional changes in the same commit, as it
makes it horribly hard to review otherwise.
Attribution
~~~~~~~~~~~
Everyone who contributes anything to Matrix is welcome to be listed in the
AUTHORS.rst file for the project in question. Please feel free to include a
change to AUTHORS.rst in your pull request to list yourself and a short
description of the area(s) you've worked on. Also, we sometimes have swag to
give away to contributors - if you feel that Matrix-branded apparel is missing
from your life, please mail us your shipping address to matrix at matrix.org and we'll try to fix it :)
Sign off
~~~~~~~~
In order to have a concrete record that your contribution is intentional
and you agree to license it under the same terms as the project's license, we've adopted the
same lightweight approach that the Linux Kernel
(https://www.kernel.org/doc/Documentation/SubmittingPatches), Docker
(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
projects use: the DCO (Developer Certificate of Origin:
http://developercertificate.org/). This is a simple declaration that you wrote
the contribution or otherwise have the right to contribute it to Matrix::
Developer Certificate of Origin
Version 1.1
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
660 York Street, Suite 102,
San Francisco, CA 94110 USA
Everyone is permitted to copy and distribute verbatim copies of this
license document, but changing it is not allowed.
Developer's Certificate of Origin 1.1
By making a contribution to this project, I certify that:
(a) The contribution was created in whole or in part by me and I
have the right to submit it under the open source license
indicated in the file; or
(b) The contribution is based upon previous work that, to the best
of my knowledge, is covered under an appropriate open source
license and I have the right under that license to submit that
work with modifications, whether created in whole or in part
by me, under the same open source license (unless I am
permitted to submit under a different license), as indicated
in the file; or
(c) The contribution was provided directly to me by some other
person who certified (a), (b) or (c) and I have not modified
it.
(d) I understand and agree that this project and the contribution
are public and that a record of the contribution (including all
personal information I submit with it, including my sign-off) is
maintained indefinitely and may be redistributed consistent with
this project or the open source license(s) involved.
If you agree to this for your contribution, then all that's needed is to
include the line in your commit or pull request comment::
Signed-off-by: Your Name <your@email.example.org>
...using your real name; unfortunately pseudonyms and anonymous contributions
can't be accepted. Git makes this trivial - just use the -s flag when you do
``git commit``, having first set ``user.name`` and ``user.email`` git configs
(which you should have done anyway :)
Conclusion
~~~~~~~~~~
That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!

View File

@@ -3,20 +3,12 @@ include LICENSE
include VERSION
include *.rst
include demo/README
include demo/demo.tls.dh
include demo/*.py
include demo/*.sh
recursive-include synapse/storage/schema *.sql
recursive-include synapse/storage/schema *.py
recursive-include demo *.dh
recursive-include demo *.py
recursive-include demo *.sh
recursive-include docs *
recursive-include scripts *
recursive-include scripts-dev *
recursive-include tests *.py
recursive-include static *.css
recursive-include static *.html
recursive-include static *.js
prune demo/etc

View File

@@ -1,5 +1,3 @@
.. contents::
Introduction
============
@@ -7,7 +5,7 @@ Matrix is an ambitious new ecosystem for open federated Instant Messaging and
VoIP. The basics you need to know to get up and running are:
- Everything in Matrix happens in a room. Rooms are distributed and do not
exist on any single server. Rooms can be located using convenience aliases
exist on any single server. Rooms can be located using convenience aliases
like ``#matrix:matrix.org`` or ``#test:localhost:8448``.
- Matrix user IDs look like ``@matthew:matrix.org`` (although in the future
@@ -20,10 +18,10 @@ The overall architecture is::
https://somewhere.org/_matrix https://elsewhere.net/_matrix
``#matrix:matrix.org`` is the official support room for Matrix, and can be
accessed by the web client at http://matrix.org/beta or via an IRC bridge at
accessed by the web client at http://matrix.org/alpha or via an IRC bridge at
irc://irc.freenode.net/matrix.
Synapse is currently in rapid development, but as of version 0.5 we believe it
Synapse is currently in rapid development, but as of version 0.5 we believe it
is sufficiently stable to be run as an internet-facing service for real usage!
About Matrix
@@ -69,32 +67,25 @@ Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
web client demo implemented in AngularJS) and cmdclient (a basic Python
command line utility which lets you easily see what the JSON APIs are up to).
Meanwhile, iOS and Android SDKs and clients are available from:
Meanwhile, iOS and Android SDKs and clients are currently in development and available from:
- https://github.com/matrix-org/matrix-ios-sdk
- https://github.com/matrix-org/matrix-ios-kit
- https://github.com/matrix-org/matrix-ios-console
- https://github.com/matrix-org/matrix-android-sdk
We'd like to invite you to join #matrix:matrix.org (via
https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at
https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api,
experiment with the APIs and the demo clients, and report any bugs via
https://matrix.org/jira.
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at
http://matrix.org/docs/spec, experiment with the APIs and the demo
clients, and report any bugs via http://matrix.org/jira.
Thanks for using Matrix!
[1] End-to-end encryption is currently in development
Synapse Installation
====================
Synapse is the reference python/twisted Matrix homeserver implementation.
Homeserver Installation
=======================
System requirements:
- POSIX-compliant system (tested on Linux & OS X)
- POSIX-compliant system (tested on Linux & OSX)
- Python 2.7
- At least 512 MB RAM.
Synapse is written in python but some of the libraries is uses are written in
C. So before we can install synapse itself we need a working C compiler and the
@@ -102,163 +93,117 @@ header files for python C extensions.
Installing prerequisites on Ubuntu or Debian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
$ sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
Installing prerequisites on ArchLinux::
sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
$ sudo pacman -S base-devel python2 python-pip \
python-setuptools python-virtualenv sqlite3
Installing prerequisites on Mac OS X::
xcode-select --install
sudo easy_install pip
sudo pip install virtualenv
$ xcode-select --install
$ sudo pip install virtualenv
To install the synapse homeserver run::
virtualenv -p python2.7 ~/.synapse
source ~/.synapse/bin/activate
pip install --upgrade setuptools
pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
$ virtualenv ~/.synapse
$ source ~/.synapse/bin/activate
$ pip install --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
This installs synapse, along with the libraries it uses, into a virtual
environment under ``~/.synapse``. Feel free to pick a different directory
if you prefer.
In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
environment under ``~/.synapse``.
To set up your homeserver, run (in your virtualenv, as before)::
cd ~/.synapse
python -m synapse.app.homeserver \
$ cd ~/.synapse
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
Substituting your host and domain name as appropriate.
This will generate you a config file that you can then customise, but it will
also generate a set of keys for you. These keys will allow your Home Server to
identify itself to other Home Servers, so don't lose or delete them. It would be
wise to back them up somewhere safe. If, for whatever reason, you do need to
change your Home Server's keys, you may find that other Home Servers have the
old key cached. If you update the signing key, you should change the name of the
key in the <server name>.signing.key file (the second word, which by default is
, 'auto') to something different.
By default, registration of new users is disabled. You can either enable
registration in the config by specifying ``enable_registration: true``
(it is then recommended to also set up CAPTCHA), or
you can use the command line to register new users::
$ source ~/.synapse/bin/activate
$ register_new_matrix_user -c homeserver.yaml https://localhost:8448
New user localpart: erikj
Password:
Confirm password:
Success!
For reliable VoIP calls to be routed via this homeserver, you MUST configure
a TURN server. See docs/turn-howto.rst for details.
Using PostgreSQL
================
Troubleshooting Installation
----------------------------
As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an
alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has
traditionally used for convenience and simplicity.
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
The advantages of Postgres include:
$ sudo pip install --upgrade pip
* significant performance improvements due to the superior threading and
caching model, smarter query optimiser
* allowing the DB to be run on separate hardware
* allowing basic active/backup high-availability with a "hot spare" synapse
pointing at the same DB master, as well as enabling DB replication in
synapse itself.
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
The only disadvantage is that the code is relatively new as of April 2015 and
may have a few regressions relative to SQLite.
$ rm -rf /tmp/pip_install_matrix
For information on how to install and use PostgreSQL, please see
`docs/postgres.rst <docs/postgres.rst>`_.
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
Running Synapse
===============
$ pip install twisted
To actually run your new homeserver, pick a working directory for Synapse to
run (e.g. ``~/.synapse``), and::
cd ~/.synapse
source ./bin/activate
synctl start
Platform Specific Instructions
==============================
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
ArchLinux
---------
The quickest way to get up and running with ArchLinux is probably with Ivan
Shapovalov's AUR package from
https://aur.archlinux.org/packages/matrix-synapse/, which should pull in all
the necessary dependencies.
Alternatively, to install using pip a few changes may be needed as ArchLinux
defaults to python 3, but synapse currently assumes python 2.7 by default:
Installation on ArchLinux may encounter a few hiccups as Arch defaults to
python 3, but synapse currently assumes python 2.7 by default.
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
sudo pip2.7 install --upgrade pip
$ sudo pip2.7 install --upgrade pip
You also may need to explicitly specify python 2.7 again during the install
request::
pip2.7 install --process-dependency-links \
$ pip2.7 install --process-dependency-links \
https://github.com/matrix-org/synapse/tarball/master
If you encounter an error with lib bcrypt causing an Wrong ELF Class:
ELFCLASS32 (x64 Systems), you may need to reinstall py-bcrypt to correctly
compile it under the right architecture. (This should not be needed if
installing under virtualenv)::
sudo pip2.7 uninstall py-bcrypt
sudo pip2.7 install py-bcrypt
$ sudo pip2.7 uninstall py-bcrypt
$ sudo pip2.7 install py-bcrypt
During setup of homeserver you need to call python2.7 directly again::
During setup of Synapse you need to call python2.7 directly again::
cd ~/.synapse
python2.7 -m synapse.app.homeserver \
$ cd ~/.synapse
$ python2.7 -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
...substituting your host and domain name as appropriate.
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
- gcc
- git
- libffi-devel
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
- gcc
- git
- libffi-devel
- openssl (and openssl-devel, python-openssl)
- python
- python-setuptools
The content repository requires additional packages and will be unable to process
uploads without them:
- libjpeg8
- libjpeg8-devel
- zlib
- libjpeg8
- libjpeg8-devel
- zlib
If you choose to install Synapse without these packages, you will need to reinstall
``pillow`` for changes to be applied, e.g. ``pip uninstall pillow`` ``pip install
pillow --user``
@@ -274,44 +219,21 @@ Troubleshooting:
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
Troubleshooting
===============
Running Your Homeserver
=======================
Troubleshooting Installation
----------------------------
To actually run your new homeserver, pick a working directory for Synapse to run
(e.g. ``~/.synapse``), and::
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
you get errors about ``error: no such option: --process-dependency-links`` you
may need to manually upgrade it::
sudo pip install --upgrade pip
Installing may fail with ``mock requires setuptools>=17.1. Aborting installation``.
You can fix this by upgrading setuptools::
pip install --upgrade setuptools
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
refuse to run until you remove the temporary installation directory it
created. To reset the installation::
rm -rf /tmp/pip_install_matrix
pip seems to leak *lots* of memory during installation. For instance, a Linux
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
happens, you will have to individually install the dependencies which are
failing, e.g.::
pip install twisted
On OS X, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
will need to export CFLAGS=-Qunused-arguments.
$ cd ~/.synapse
$ source ./bin/activate
$ synctl start
Troubleshooting Running
-----------------------
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
If synapse fails with ``missing "sodium.h"`` crypto errors, you may need
to manually upgrade PyNaCL, as synapse uses NaCl (http://nacl.cr.yp.to/) for
encryption and digital signatures.
Unfortunately PyNACL currently has a few issues
(https://github.com/pyca/pynacl/issues/53) and
@@ -320,46 +242,44 @@ correctly, causing all tests to fail with errors about missing "sodium.h". To
fix try re-installing from PyPI or directly from
(https://github.com/pyca/pynacl)::
# Install from PyPI
pip install --user --upgrade --force pynacl
# Install from github
pip install --user https://github.com/pyca/pynacl/tarball/master
$ # Install from PyPI
$ pip install --user --upgrade --force pynacl
$ # Install from github
$ pip install --user https://github.com/pyca/pynacl/tarball/master
ArchLinux
~~~~~~~~~
---------
If running `$ synctl start` fails with 'returned non-zero exit status 1',
you will need to explicitly call Python2.7 - either running as::
python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml
If running `$ synctl start` fails wit 'returned non-zero exit status 1', you will need to explicitly call Python2.7 - either running as::
$ python2.7 -m synapse.app.homeserver --daemonize -c homeserver.yaml --pid-file homeserver.pid
...or by editing synctl with the correct python executable.
Synapse Development
===================
Homeserver Development
======================
To check out a synapse for development, clone the git repo into a working
To check out a homeserver for development, clone the git repo into a working
directory of your choice::
git clone https://github.com/matrix-org/synapse.git
cd synapse
$ git clone https://github.com/matrix-org/synapse.git
$ cd synapse
Synapse has a number of external dependencies, that are easiest
The homeserver has a number of external dependencies, that are easiest
to install using pip and a virtualenv::
virtualenv env
source env/bin/activate
python synapse/python_dependencies.py | xargs -n1 pip install
pip install setuptools_trial mock
$ virtualenv env
$ source env/bin/activate
$ python synapse/python_dependencies.py | xargs -n1 pip install
$ pip install setuptools_trial mock
This will run a process of downloading and installing all the needed
dependencies into a virtual env.
Once this is done, you may wish to run Synapse's unit tests, to
Once this is done, you may wish to run the homeserver's unit tests, to
check that everything is installed as it should be::
python setup.py test
$ python setup.py test
This should end with a 'PASSED' result::
@@ -368,14 +288,17 @@ This should end with a 'PASSED' result::
PASSED (successes=143)
Upgrading an existing Synapse
=============================
Upgrading an existing homeserver
================================
The instructions for upgrading synapse are in `UPGRADE.rst`_.
Please check these instructions as upgrading may require extra steps for some
versions of synapse.
IMPORTANT: Before upgrading an existing homeserver to a new version, please
refer to UPGRADE.rst for any additional instructions.
Otherwise, simply re-install the new codebase over the current one - e.g.
by ``pip install --process-dependency-links
https://github.com/matrix-org/synapse/tarball/master``
if using pip, or by ``git pull`` if running off a git working copy.
.. _UPGRADE.rst: UPGRADE.rst
Setting up Federation
=====================
@@ -397,11 +320,11 @@ IDs:
For the first form, simply pass the required hostname (of the machine) as the
--server-name parameter::
python -m synapse.app.homeserver \
$ python -m synapse.app.homeserver \
--server-name machine.my.domain.name \
--config-path homeserver.yaml \
--generate-config
python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.yaml
Alternatively, you can run ``synctl start`` to guide you through the process.
@@ -411,33 +334,38 @@ and port where the server is running. (At the current time synapse does not
support clustering multiple servers into a single logical homeserver). The DNS
record would then look something like::
$ dig -t srv _matrix._tcp.machine.my.domain.name
$ dig -t srv _matrix._tcp.machine.my.domaine.name
_matrix._tcp IN SRV 10 0 8448 machine.my.domain.name.
At this point, you should then run the homeserver with the hostname of this
SRV record, as that is the name other machines will expect it to have::
python -m synapse.app.homeserver \
$ python -m synapse.app.homeserver \
--server-name YOURDOMAIN \
--bind-port 8448 \
--config-path homeserver.yaml \
--generate-config
python -m synapse.app.homeserver --config-path homeserver.yaml
$ python -m synapse.app.homeserver --config-path homeserver.yaml
You may additionally want to pass one or more "-v" options, in order to
increase the verbosity of logging output; at least for initial testing.
Running a Demo Federation of Synapses
-------------------------------------
For the initial alpha release, the homeserver is not speaking TLS for
either client-server or server-server traffic for ease of debugging. We have
also not spent any time yet getting the homeserver to run behind loadbalancers.
Running a Demo Federation of Homeservers
----------------------------------------
If you want to get up and running quickly with a trio of homeservers in a
private federation (``localhost:8080``, ``localhost:8081`` and
``localhost:8082``) which you can then access through the webclient running at
http://localhost:8080. Simply run::
demo/start.sh
$ demo/start.sh
This is mainly useful just for development purposes.
Running The Demo Web Client
@@ -464,10 +392,7 @@ account. Your name will take the form of::
Specify your desired localpart in the topmost box of the "Register for an
account" form, and click the "Register" button. Hostnames can contain ports if
required due to lack of SRV records (e.g. @matthew:localhost:8448 on an
internal synapse sandbox running on localhost).
If registration fails, you may need to enable it in the homeserver (see
`Synapse Installation`_ above)
internal synapse sandbox running on localhost)
Logging In To An Existing Account
@@ -493,14 +418,14 @@ track 3PID logins and publish end-user public keys.
It's currently early days for identity servers as Matrix is not yet using 3PIDs
as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (https://matrix.org) at the current
we are running a single identity server (http://matrix.org:8090) at the current
time.
Where's the spec?!
==================
The source of the matrix spec lives at https://github.com/matrix-org/matrix-doc.
The source of the matrix spec lives at https://github.com/matrix-org/matrix-doc.
A recent HTML snapshot of this lives at http://matrix.org/docs/spec
@@ -510,10 +435,10 @@ Building Internal API Documentation
Before building internal API documentation install sphinx and
sphinxcontrib-napoleon::
pip install sphinx
pip install sphinxcontrib-napoleon
$ pip install sphinx
$ pip install sphinxcontrib-napoleon
Building internal API documentation::
python setup.py build_sphinx
$ python setup.py build_sphinx

View File

@@ -1,70 +1,3 @@
Upgrading Synapse
=================
Before upgrading check if any special steps are required to upgrade from the
what you currently have installed to current version of synapse. The extra
instructions that may be required are listed later in this document.
If synapse was installed in a virtualenv then active that virtualenv before
upgrading. If synapse is installed in a virtualenv in ``~/.synapse/`` then run:
.. code:: bash
source ~/.synapse/bin/activate
If synapse was installed using pip then upgrade to the latest version by
running:
.. code:: bash
pip install --upgrade --process-dependency-links https://github.com/matrix-org/synapse/tarball/master
If synapse was installed using git then upgrade to the latest version by
running:
.. code:: bash
# Pull the latest version of the master branch.
git pull
# Update the versions of synapse's python dependencies.
python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.9.0
===================
Application services have had a breaking API change in this version.
They can no longer register themselves with a home server using the AS HTTP API. This
decision was made because a compromised application service with free reign to register
any regex in effect grants full read/write access to the home server if a regex of ``.*``
is used. An attack where a compromised AS re-registers itself with ``.*`` was deemed too
big of a security risk to ignore, and so the ability to register with the HS remotely has
been removed.
It has been replaced by specifying a list of application service registrations in
``homeserver.yaml``::
app_service_config_files: ["registration-01.yaml", "registration-02.yaml"]
Where ``registration-01.yaml`` looks like::
url: <String> # e.g. "https://my.application.service.com"
as_token: <String>
hs_token: <String>
sender_localpart: <String> # This is a new field which denotes the user_id localpart when using the AS token
namespaces:
users:
- exclusive: <Boolean>
regex: <String> # e.g. "@prefix_.*"
aliases:
- exclusive: <Boolean>
regex: <String>
rooms:
- exclusive: <Boolean>
regex: <String>
Upgrading to v0.8.0
===================

View File

@@ -1,93 +0,0 @@
#!/usr/bin/env python
from argparse import ArgumentParser
import json
import requests
import sys
import urllib
def _mkurl(template, kws):
for key in kws:
template = template.replace(key, kws[key])
return template
def main(hs, room_id, access_token, user_id_prefix, why):
if not why:
why = "Automated kick."
print "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)
room_state_url = _mkurl(
"$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN",
{
"$HS": hs,
"$ROOM": room_id,
"$TOKEN": access_token
}
)
print "Getting room state => %s" % room_state_url
res = requests.get(room_state_url)
print "HTTP %s" % res.status_code
state_events = res.json()
if "error" in state_events:
print "FATAL"
print state_events
return
kick_list = []
room_name = room_id
for event in state_events:
if not event["type"] == "m.room.member":
if event["type"] == "m.room.name":
room_name = event["content"].get("name")
continue
if not event["content"].get("membership") == "join":
continue
if event["state_key"].startswith(user_id_prefix):
kick_list.append(event["state_key"])
if len(kick_list) == 0:
print "No user IDs match the prefix '%s'" % user_id_prefix
return
print "The following user IDs will be kicked from %s" % room_name
for uid in kick_list:
print uid
doit = raw_input("Continue? [Y]es\n")
if len(doit) > 0 and doit.lower() == 'y':
print "Kicking members..."
# encode them all
kick_list = [urllib.quote(uid) for uid in kick_list]
for uid in kick_list:
kick_url = _mkurl(
"$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN",
{
"$HS": hs,
"$UID": uid,
"$ROOM": room_id,
"$TOKEN": access_token
}
)
kick_body = {
"membership": "leave",
"reason": why
}
print "Kicking %s" % uid
res = requests.put(kick_url, data=json.dumps(kick_body))
if res.status_code != 200:
print "ERROR: HTTP %s" % res.status_code
if res.json().get("error"):
print "ERROR: JSON %s" % res.json()
if __name__ == "__main__":
parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.")
parser.add_argument("-u","--user-id",help="The user ID prefix e.g. '@irc_'")
parser.add_argument("-t","--token",help="Your access_token")
parser.add_argument("-r","--room",help="The room ID to kick members in")
parser.add_argument("-s","--homeserver",help="The base HS url e.g. http://matrix.org")
parser.add_argument("-w","--why",help="Reason for the kick. Optional.")
args = parser.parse_args()
if not args.room or not args.token or not args.user_id or not args.homeserver:
parser.print_help()
sys.exit(1)
else:
main(args.homeserver, args.room, args.token, args.user_id, args.why)

View File

@@ -1,25 +0,0 @@
version: 1
# In systemd's journal, loglevel is implicitly stored, so let's omit it
# from the message text.
formatters:
journal_fmt:
format: '%(name)s: [%(request)s] %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
journal:
class: systemd.journal.JournalHandler
formatter: journal_fmt
filters: [context]
SYSLOG_IDENTIFIER: synapse
root:
level: INFO
handlers: [journal]
disable_existing_loggers: False

View File

@@ -1,16 +0,0 @@
# This assumes that Synapse has been installed as a system package
# (e.g. https://aur.archlinux.org/packages/matrix-synapse/ for ArchLinux)
# rather than in a user home directory or similar under virtualenv.
[Unit]
Description=Synapse Matrix homeserver
[Service]
Type=simple
User=synapse
Group=synapse
WorkingDirectory=/var/lib/synapse
ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml
[Install]
WantedBy=multi-user.target

View File

@@ -126,26 +126,12 @@ sub on_unknown_event
if (!$bridgestate->{$room_id}->{gathered_candidates}) {
$bridgestate->{$room_id}->{gathered_candidates} = 1;
my $offer = $bridgestate->{$room_id}->{offer};
my $candidate_block = {
audio => '',
video => '',
};
my $candidate_block = "";
foreach (@{$event->{content}->{candidates}}) {
if ($_->{sdpMid}) {
$candidate_block->{$_->{sdpMid}} .= "a=" . $_->{candidate} . "\r\n";
}
else {
$candidate_block->{audio} .= "a=" . $_->{candidate} . "\r\n";
$candidate_block->{video} .= "a=" . $_->{candidate} . "\r\n";
}
$candidate_block .= "a=" . $_->{candidate} . "\r\n";
}
# XXX: assumes audio comes first
#$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{audio}/;
#$offer =~ s/(a=rtcp-mux[\r\n]+)/$1$candidate_block->{video}/;
$offer =~ s/(m=video)/$candidate_block->{audio}$1/;
$offer =~ s/(.$)/$1\n$candidate_block->{video}$1/;
# XXX: collate using the right m= line - for now assume audio call
$offer =~ s/(a=rtcp.*[\r\n]+)/$1$candidate_block/;
my $f = send_verto_json_request("verto.invite", {
"sdp" => $offer,
@@ -186,18 +172,23 @@ sub on_room_message
warn "[Matrix] in $room_id: $from: " . $content->{body} . "\n";
}
my $verto_connecting = $loop->new_future;
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connected => sub {
warn("[Verto] connected to websocket");
$verto_connecting->done($bot_verto) if not $verto_connecting->is_done;
},
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
);
Future->needs_all(
$bot_matrix->login( %{ $CONFIG{"matrix-bot"} } )->then( sub {
$bot_matrix->start;
}),
$bot_verto->connect(
%{ $CONFIG{"verto-bot"} },
on_connect_error => sub { die "Cannot connect to verto - $_[-1]" },
on_resolve_error => sub { die "Cannot resolve to verto - $_[-1]" },
)->on_done( sub {
warn("[Verto] connected to websocket");
}),
$verto_connecting,
)->get;
$loop->attach_signal(

View File

@@ -86,7 +86,7 @@ sub create_virtual_user
"user": "$localpart"
}
EOT
)->get;
)->get;
warn $response->as_string if ($response->code != 200);
}
@@ -266,21 +266,17 @@ my $as_url = $CONFIG{"matrix-bot"}->{as_url};
Future->needs_all(
$http->do_request(
method => "POST",
uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
content_type => "application/json",
content => <<EOT
method => "POST",
uri => URI->new( $CONFIG{"matrix"}->{server}."/_matrix/appservice/v1/register" ),
content_type => "application/json",
content => <<EOT
{
"as_token": "$as_token",
"url": "$as_url",
"namespaces": { "users": [ { "regex": "\@\\\\+.*", "exclusive": false } ] }
"namespaces": { "users": ["\@\\\\+.*"] }
}
EOT
)->then( sub{
my ($response) = (@_);
warn $response->as_string if ($response->code != 200);
return Future->done;
}),
),
$verto_connecting,
)->get;

View File

@@ -7,9 +7,6 @@ matrix:
matrix-bot:
user_id: '@vertobot:matrix.org'
password: ''
domain: 'matrix.org"
as_url: 'http://localhost:8009'
as_token: 'vertobot123'
verto-bot:
host: webrtc.freeswitch.org

View File

@@ -11,4 +11,7 @@ requires 'YAML', 0;
requires 'JSON', 0;
requires 'Getopt::Long', 0;
on 'test' => sub {
requires 'Test::More', '>= 0.98';
};

View File

@@ -11,9 +11,7 @@ if [ -f $PID_FILE ]; then
exit 1
fi
for port in 8080 8081 8082; do
rm -rf $DIR/$port
rm -rf $DIR/media_store.$port
done
find "$DIR" -name "*.log" -delete
find "$DIR" -name "*.db" -delete
rm -rf $DIR/etc

View File

@@ -8,42 +8,37 @@ cd "$DIR/.."
mkdir -p demo/etc
export PYTHONPATH=$(readlink -f $(pwd))
echo $PYTHONPATH
# Check the --no-rate-limit param
PARAMS=""
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
PARAMS="--rc-messages-per-second 1000 --rc-message-burst-count 1000"
fi
fi
for port in 8080 8081 8082; do
echo "Starting server on port $port... "
https_port=$((port + 400))
mkdir -p demo/$port
pushd demo/$port
#rm $DIR/etc/$port.config
python -m synapse.app.homeserver \
--generate-config \
--config-path "demo/etc/$port.config" \
-p "$https_port" \
--unsecure-port "$port" \
-H "localhost:$https_port" \
--config-path "$DIR/etc/$port.config" \
--report-stats no
# Check script parameters
if [ $# -eq 1 ]; then
if [ $1 = "--no-rate-limit" ]; then
# Set high limits in config file to disable rate limiting
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
fi
fi
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
-f "$DIR/$port.log" \
-d "$DIR/$port.db" \
-D --pid-file "$DIR/$port.pid" \
--manhole $((port + 1000)) \
--tls-dh-params-path "demo/demo.tls.dh" \
--media-store-path "demo/media_store.$port" \
$PARAMS $SYNAPSE_PARAMS \
python -m synapse.app.homeserver \
--config-path "$DIR/etc/$port.config" \
-D \
--config-path "demo/etc/$port.config" \
-vv \
popd
done
cd "$CWD"

View File

@@ -1,31 +0,0 @@
Captcha can be enabled for this home server. This file explains how to do that.
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
Getting keys
------------
Requires a public/private key pair from:
https://developers.google.com/recaptcha/
Setting ReCaptcha Keys
----------------------
The keys are a config option on the home server config. If they are not
visible, you can generate them via --generate-config. Set the following value:
recaptcha_public_key: YOUR_PUBLIC_KEY
recaptcha_private_key: YOUR_PRIVATE_KEY
In addition, you MUST enable captchas via:
enable_registration_captcha: true
Configuring IP used for auth
----------------------------
The ReCaptcha API requires that the IP address of the user who solved the
captcha is sent. If the client is connecting through a proxy or load balancer,
it may be required to use the X-Forwarded-For (XFF) header instead of the origin
IP address. This can be configured as an option on the home server like so:
captcha_ip_origin_is_x_forwarded: true

View File

@@ -1,36 +0,0 @@
Registering an Application Service
==================================
The registration of new application services depends on the homeserver used.
In synapse, you need to create a new configuration file for your AS and add it
to the list specified under the ``app_service_config_files`` config
option in your synapse config.
For example:
.. code-block:: yaml
app_service_config_files:
- /home/matrix/.synapse/<your-AS>.yaml
The format of the AS configuration file is as follows:
.. code-block:: yaml
url: <base url of AS>
as_token: <token AS will add to requests to HS>
hs_token: <token HS will add to requests to AS>
sender_localpart: <localpart of AS user>
namespaces:
users: # List of users we're interested in
- exclusive: <bool>
regex: <regex>
- ...
aliases: [] # List of aliases we're interested in
rooms: [] # List of room ids we're interested in
See the spec_ for further details on how application services work.
.. _spec: https://github.com/matrix-org/matrix-doc/blob/master/specification/25_application_service_api.rst#application-service-api

View File

@@ -1,50 +0,0 @@
How to monitor Synapse metrics using Prometheus
===============================================
1: Install prometheus:
Follow instructions at http://prometheus.io/docs/introduction/install/
2: Enable synapse metrics:
Simply setting a (local) port number will enable it. Pick a port.
prometheus itself defaults to 9090, so starting just above that for
locally monitored services seems reasonable. E.g. 9092:
Add to homeserver.yaml
metrics_port: 9092
Restart synapse
3: Check out synapse-prometheus-config
https://github.com/matrix-org/synapse-prometheus-config
4: Add ``synapse.html`` and ``synapse.rules``
The ``.html`` file needs to appear in prometheus's ``consoles`` directory,
and the ``.rules`` file needs to be invoked somewhere in the main config
file. A symlink to each from the git checkout into the prometheus directory
might be easiest to ensure ``git pull`` keeps it updated.
5: Add a prometheus target for synapse
This is easiest if prometheus runs on the same machine as synapse, as it can
then just use localhost::
global: {
rule_file: "synapse.rules"
}
job: {
name: "synapse"
target_group: {
target: "http://localhost:9092/"
}
}
6: Start prometheus::
./prometheus -config.file=prometheus.conf
7: Wait a few seconds for it to start and perform the first scrape,
then visit the console:
http://server-where-prometheus-runs:9090/consoles/synapse.html

View File

@@ -1,107 +0,0 @@
Using Postgres
--------------
Set up database
===============
The PostgreSQL database used *must* have the correct encoding set, otherwise
would not be able to store UTF8 strings. To create a database with the correct
encoding use, e.g.::
CREATE DATABASE synapse
ENCODING 'UTF8'
LC_COLLATE='C'
LC_CTYPE='C'
template=template0
OWNER synapse_user;
This would create an appropriate database named ``synapse`` owned by the
``synapse_user`` user (which must already exist).
Set up client
=============
Postgres support depends on the postgres python connector ``psycopg2``. In the
virtual env::
sudo apt-get install libpq-dev
pip install psycopg2
Synapse config
==============
When you are ready to start using PostgreSQL, add the following line to your
config file::
database:
name: psycopg2
args:
user: <user>
password: <pass>
database: <db>
host: <host>
cp_min: 5
cp_max: 10
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
function, except keys beginning with ``cp_``, which are consumed by the twisted
adbapi connection pool.
Porting from SQLite
===================
Overview
~~~~~~~~
The script ``synapse_port_db`` allows porting an existing synapse server
backed by SQLite to using PostgreSQL. This is done in as a two phase process:
1. Copy the existing SQLite database to a separate location (while the server
is down) and running the port script against that offline database.
2. Shut down the server. Rerun the port script to port any data that has come
in since taking the first snapshot. Restart server against the PostgreSQL
database.
The port script is designed to be run repeatedly against newer snapshots of the
SQLite database file. This makes it safe to repeat step 1 if there was a delay
between taking the previous snapshot and being ready to do step 2.
It is safe to at any time kill the port script and restart it.
Using the port script
~~~~~~~~~~~~~~~~~~~~~
Firstly, shut down the currently running synapse server and copy its database
file (typically ``homeserver.db``) to another location. Once the copy is
complete, restart synapse. For instance::
./synctl stop
cp homeserver.db homeserver.db.snapshot
./synctl start
Assuming your new config file (as described in the section *Synapse config*)
is named ``homeserver-postgres.yaml`` and the SQLite snapshot is at
``homeserver.db.snapshot`` then simply run::
synapse_port_db --sqlite-database homeserver.db.snapshot \
--postgres-config homeserver-postgres.yaml
The flag ``--curses`` displays a coloured curses progress UI.
If the script took a long time to complete, or time has otherwise passed since
the original snapshot was taken, repeat the previous steps with a newer
snapshot.
To complete the conversion shut down the synapse server and run the port
script one last time, e.g. if the SQLite database is at ``homeserver.db``
run::
synapse_port_db --sqlite-database homeserver.db \
--postgres-config database_config.yaml
Once that has completed, change the synapse config to point at the PostgreSQL
database configuration file using the ``database_config`` parameter (see
`Synapse Config`_) and restart synapse. Synapse should now be running against
PostgreSQL.

View File

@@ -1,116 +0,0 @@
import psycopg2
import yaml
import sys
import json
import time
import hashlib
from unpaddedbase64 import encode_base64
from signedjson.key import read_signing_keys
from signedjson.sign import sign_json
from canonicaljson import encode_canonical_json
def select_v1_keys(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, verify_key FROM server_signature_keys")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, verify_key in rows:
results.setdefault(server_name, {})[key_id] = encode_base64(verify_key)
return results
def select_v1_certs(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, tls_certificate FROM server_tls_certificates")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, tls_certificate in rows:
results[server_name] = tls_certificate
return results
def select_v2_json(connection):
cursor = connection.cursor()
cursor.execute("SELECT server_name, key_id, key_json FROM server_keys_json")
rows = cursor.fetchall()
cursor.close()
results = {}
for server_name, key_id, key_json in rows:
results.setdefault(server_name, {})[key_id] = json.loads(str(key_json).decode("utf-8"))
return results
def convert_v1_to_v2(server_name, valid_until, keys, certificate):
return {
"old_verify_keys": {},
"server_name": server_name,
"verify_keys": {
key_id: {"key": key}
for key_id, key in keys.items()
},
"valid_until_ts": valid_until,
"tls_fingerprints": [fingerprint(certificate)],
}
def fingerprint(certificate):
finger = hashlib.sha256(certificate)
return {"sha256": encode_base64(finger.digest())}
def rows_v2(server, json):
valid_until = json["valid_until_ts"]
key_json = encode_canonical_json(json)
for key_id in json["verify_keys"]:
yield (server, key_id, "-", valid_until, valid_until, buffer(key_json))
def main():
config = yaml.load(open(sys.argv[1]))
valid_until = int(time.time() / (3600 * 24)) * 1000 * 3600 * 24
server_name = config["server_name"]
signing_key = read_signing_keys(open(config["signing_key_path"]))[0]
database = config["database"]
assert database["name"] == "psycopg2", "Can only convert for postgresql"
args = database["args"]
args.pop("cp_max")
args.pop("cp_min")
connection = psycopg2.connect(**args)
keys = select_v1_keys(connection)
certificates = select_v1_certs(connection)
json = select_v2_json(connection)
result = {}
for server in keys:
if not server in json:
v2_json = convert_v1_to_v2(
server, valid_until, keys[server], certificates[server]
)
v2_json = sign_json(v2_json, server_name, signing_key)
result[server] = v2_json
yaml.safe_dump(result, sys.stdout, default_flow_style=False)
rows = list(
row for server, json in result.items()
for row in rows_v2(server, json)
)
cursor = connection.cursor()
cursor.executemany(
"INSERT INTO server_keys_json ("
" server_name, key_id, from_server,"
" ts_added_ms, ts_valid_until_ms, key_json"
") VALUES (%s, %s, %s, %s, %s, %s)",
rows
)
connection.commit()
if __name__ == '__main__':
main()

View File

@@ -1,142 +0,0 @@
#! /usr/bin/python
import ast
import yaml
class DefinitionVisitor(ast.NodeVisitor):
def __init__(self):
super(DefinitionVisitor, self).__init__()
self.functions = {}
self.classes = {}
self.names = {}
self.attrs = set()
self.definitions = {
'def': self.functions,
'class': self.classes,
'names': self.names,
'attrs': self.attrs,
}
def visit_Name(self, node):
self.names.setdefault(type(node.ctx).__name__, set()).add(node.id)
def visit_Attribute(self, node):
self.attrs.add(node.attr)
for child in ast.iter_child_nodes(node):
self.visit(child)
def visit_ClassDef(self, node):
visitor = DefinitionVisitor()
self.classes[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def visit_FunctionDef(self, node):
visitor = DefinitionVisitor()
self.functions[node.name] = visitor.definitions
for child in ast.iter_child_nodes(node):
visitor.visit(child)
def non_empty(defs):
functions = {name: non_empty(f) for name, f in defs['def'].items()}
classes = {name: non_empty(f) for name, f in defs['class'].items()}
result = {}
if functions: result['def'] = functions
if classes: result['class'] = classes
names = defs['names']
uses = []
for name in names.get('Load', ()):
if name not in names.get('Param', ()) and name not in names.get('Store', ()):
uses.append(name)
uses.extend(defs['attrs'])
if uses: result['uses'] = uses
result['names'] = names
result['attrs'] = defs['attrs']
return result
def definitions_in_code(input_code):
input_ast = ast.parse(input_code)
visitor = DefinitionVisitor()
visitor.visit(input_ast)
definitions = non_empty(visitor.definitions)
return definitions
def definitions_in_file(filepath):
with open(filepath) as f:
return definitions_in_code(f.read())
def defined_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
names.setdefault(name, {'defined': []})['defined'].append(prefix + name)
defined_names(prefix + name + ".", funcs, names)
def used_names(prefix, defs, names):
for name, funcs in defs.get('def', {}).items():
used_names(prefix + name + ".", funcs, names)
for name, funcs in defs.get('class', {}).items():
used_names(prefix + name + ".", funcs, names)
for used in defs.get('uses', ()):
if used in names:
names[used].setdefault('used', []).append(prefix.rstrip('.'))
if __name__ == '__main__':
import sys, os, argparse, re
parser = argparse.ArgumentParser(description='Find definitions.')
parser.add_argument(
"--unused", action="store_true", help="Only list unused definitions"
)
parser.add_argument(
"--ignore", action="append", metavar="REGEXP", help="Ignore a pattern"
)
parser.add_argument(
"--pattern", action="append", metavar="REGEXP",
help="Search for a pattern"
)
parser.add_argument(
"directories", nargs='+', metavar="DIR",
help="Directories to search for definitions"
)
args = parser.parse_args()
definitions = {}
for directory in args.directories:
for root, dirs, files in os.walk(directory):
for filename in files:
if filename.endswith(".py"):
filepath = os.path.join(root, filename)
definitions[filepath] = definitions_in_file(filepath)
names = {}
for filepath, defs in definitions.items():
defined_names(filepath + ":", defs, names)
for filepath, defs in definitions.items():
used_names(filepath + ":", defs, names)
patterns = [re.compile(pattern) for pattern in args.pattern or ()]
ignore = [re.compile(pattern) for pattern in args.ignore or ()]
result = {}
for name, definition in names.items():
if patterns and not any(pattern.match(name) for pattern in patterns):
continue
if ignore and any(pattern.match(name) for pattern in ignore):
continue
if args.unused and definition.get('used'):
continue
result[name] = definition
yaml.dump(result, sys.stdout, default_flow_style=False)

View File

@@ -56,9 +56,10 @@ if __name__ == '__main__':
js = json.load(args.json)
auth = Auth(Mock())
check_auth(
auth,
[FrozenEvent(d) for d in js["auth_chain"]],
[FrozenEvent(d) for d in js.get("pdus", [])],
[FrozenEvent(d) for d in js["pdus"]],
)

View File

@@ -1,5 +1,5 @@
from synapse.crypto.event_signing import *
from unpaddedbase64 import encode_base64
from syutil.base64util import encode_base64
import argparse
import hashlib

View File

@@ -1,7 +1,9 @@
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes, write_signing_keys
from unpaddedbase64 import decode_base64
from syutil.crypto.jsonsign import verify_signed_json
from syutil.crypto.signing_key import (
decode_verify_key_bytes, write_signing_keys
)
from syutil.base64util import decode_base64
import urllib2
import json

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.0.1 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

View File

@@ -0,0 +1,21 @@
#!/bin/bash
# This is will prepare a synapse database for running with v0.5.0 of synapse.
# It will store all the user information, but will *delete* all messages and
# room data.
set -e
cp "$1" "$1.bak"
DUMP=$(sqlite3 "$1" << 'EOF'
.dump users
.dump access_tokens
.dump presence
.dump profiles
EOF
)
rm "$1"
sqlite3 "$1" <<< "$DUMP"

213
scripts/graph_tracer.py Normal file
View File

@@ -0,0 +1,213 @@
import fileinput
import pydot
import sys
import itertools
import json
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
return itertools.izip(a, b)
nodes = {}
edges = set()
graph = pydot.Dot(graph_name="call_graph", graph_type="digraph")
names = {}
starts = {}
ends = {}
deferreds = set()
deferreds_map = {}
deferred_edges = set()
root_id = None
for line in fileinput.input():
line = line.strip()
try:
if " calls " in line:
start, end = line.split(" calls ")
start, end = start.strip(), end.strip()
edges.add((start, end))
# print start, end
if " named " in line:
node_id, name = line.split(" named ")
names[node_id.strip()] = name.strip()
if name.strip() == "synapse.rest.client.v1.room.RoomSendEventRestServlet.on_PUT":
root_id = node_id
if " in " in line:
node_id, d = line.split(" in ")
deferreds_map[node_id.strip()] = d.strip()
if " is deferred" in line:
node_id, _ = line.split(" is deferred")
deferreds.add(node_id)
if " start " in line:
node_id, ms = line.split(" start ")
starts[node_id.strip()] = int(ms.strip())
if " end " in line:
node_id, ms = line.split(" end ")
ends[node_id.strip()] = int(ms.strip())
if " waits on " in line:
start, end = line.split(" waits on ")
start, end = start.strip(), end.strip()
deferred_edges.add((start, end))
# print start, end
except Exception as e:
sys.stderr.write("failed %s to parse '%s'\n" % (e.message, line))
if not root_id:
sys.stderr.write("Could not find root")
sys.exit(1)
# deferreds_root = set(deferreds.values())
# for parent, child in deferred_edges:
# deferreds_root.discard(child)
#
# deferred_tree = {
# d: {}
# for d in deferreds_root
# }
#
# def populate(root, tree):
# for leaf in deferred_edges.get(root, []):
# populate(leaf, tree.setdefault(leaf, {}))
#
#
# for d in deferreds_root:
# tree = deferred_tree.setdefault(d, {})
# populate(d, tree)
# print deferred_edges
# print root_id
def is_in_deferred(d):
while True:
if d == root_id:
return True
for start, end in deferred_edges:
if d == end:
d = start
break
else:
return False
def walk_graph(d):
res = [d]
while d != root_id:
for start, end in edges:
if d == end:
d = start
res.append(d)
break
else:
return res
return res
def make_tree_el(node_id):
return {
"id": node_id,
"name": names[node_id],
"children": [],
"start": starts[node_id],
"end": ends[node_id],
"size": ends[node_id] - starts[node_id],
}
tree = make_tree_el(root_id)
tree_index = {
root_id: tree,
}
viz_out = {
"nodes": [],
"edges": [],
}
for node_id, name in names.items():
# if times.get(node_id, 100) < 5:
# continue
walk = walk_graph(node_id)
# print walk
if root_id not in walk:
continue
if node_id in deferreds:
if not is_in_deferred(node_id):
continue
elif node_id in deferreds_map:
if not is_in_deferred(deferreds_map[node_id]):
continue
walk_names = [
names[w].split("synapse.", 1)[1] for w in walk
]
for child, parent in reversed(list(pairwise(walk))):
if parent in tree_index and child not in tree_index:
el = make_tree_el(child)
tree_index[parent]["children"].append(el)
tree_index[child] = el
# print "-".join(reversed(["end"] + walk_names)) + ", " + str(ends[node_id] - starts[node_id])
# print "%d,%s,%s,%s" % (len(walk), walk_names[0], starts[node_id], ends[node_id])
viz_out["nodes"].append({
"id": node_id,
"label": names[node_id].split("synapse.", 1)[1],
"value": ends[node_id] - starts[node_id],
"level": len(walk),
})
node = pydot.Node(node_id, label=name)
# if node_id in deferreds:
# clusters[deferreds[node_id]].add_node(node)
# elif node_id in clusters:
# clusters[node_id].add_node(node)
# else:
# graph.add_node(node)
graph.add_node(node)
nodes[node_id] = node
# print node_id
# for el in tree_index.values():
# el["children"].sort(key=lambda e: e["start"])
#
# print json.dumps(tree)
for parent, child in edges:
if child not in nodes:
# sys.stderr.write(child + " not a node\n")
continue
if parent not in nodes:
# sys.stderr.write(parent + " not a node\n")
continue
viz_out["edges"].append({
"from": parent,
"to": child,
"value": ends[child] - starts[child],
})
edge = pydot.Edge(nodes[parent], nodes[child])
graph.add_edge(edge)
print json.dumps(viz_out)
file_prefix = "call_graph_out"
graph.write('%s.dot' % file_prefix, format='raw', prog='dot')
graph.write_svg("%s.svg" % file_prefix, prog='dot')

View File

@@ -6,8 +6,8 @@ from synapse.crypto.event_signing import (
add_event_pdu_content_hash, compute_pdu_event_reference_hash
)
from synapse.api.events.utils import prune_pdu
from unpaddedbase64 import encode_base64, decode_base64
from canonicaljson import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.jsonutil import encode_canonical_json
import sqlite3
import sys

View File

@@ -1,154 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import getpass
import hashlib
import hmac
import json
import sys
import urllib2
import yaml
def request_registration(user, password, server_location, shared_secret):
mac = hmac.new(
key=shared_secret,
msg=user,
digestmod=hashlib.sha1,
).hexdigest()
data = {
"user": user,
"password": password,
"mac": mac,
"type": "org.matrix.login.shared_secret",
}
server_location = server_location.rstrip("/")
print "Sending registration request..."
req = urllib2.Request(
"%s/_matrix/client/api/v1/register" % (server_location,),
data=json.dumps(data),
headers={'Content-Type': 'application/json'}
)
try:
if sys.version_info[:3] >= (2, 7, 9):
# As of version 2.7.9, urllib2 now checks SSL certs
import ssl
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
else:
f = urllib2.urlopen(req)
f.read()
f.close()
print "Success."
except urllib2.HTTPError as e:
print "ERROR! Received %d %s" % (e.code, e.reason,)
if 400 <= e.code < 500:
if e.info().type == "application/json":
resp = json.load(e)
if "error" in resp:
print resp["error"]
sys.exit(1)
def register_new_user(user, password, server_location, shared_secret):
if not user:
try:
default_user = getpass.getuser()
except:
default_user = None
if default_user:
user = raw_input("New user localpart [%s]: " % (default_user,))
if not user:
user = default_user
else:
user = raw_input("New user localpart: ")
if not user:
print "Invalid user name"
sys.exit(1)
if not password:
password = getpass.getpass("Password: ")
if not password:
print "Password cannot be blank."
sys.exit(1)
confirm_password = getpass.getpass("Confirm password: ")
if password != confirm_password:
print "Passwords do not match"
sys.exit(1)
request_registration(user, password, server_location, shared_secret)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Used to register new users with a given home server when"
" registration has been disabled. The home server must be"
" configured with the 'registration_shared_secret' option"
" set.",
)
parser.add_argument(
"-u", "--user",
default=None,
help="Local part of the new user. Will prompt if omitted.",
)
parser.add_argument(
"-p", "--password",
default=None,
help="New password for user. Will prompt if omitted.",
)
group = parser.add_mutually_exclusive_group(required=True)
group.add_argument(
"-c", "--config",
type=argparse.FileType('r'),
help="Path to server config file. Used to read in shared secret.",
)
group.add_argument(
"-k", "--shared-secret",
help="Shared secret as defined in server config file.",
)
parser.add_argument(
"server_url",
default="https://localhost:8448",
nargs='?',
help="URL to use to talk to the home server. Defaults to "
" 'https://localhost:8448'.",
)
args = parser.parse_args()
if "config" in args and args.config:
config = yaml.safe_load(args.config)
secret = config.get("registration_shared_secret", None)
if not secret:
print "No 'registration_shared_secret' defined in config."
sys.exit(1)
else:
secret = args.shared_secret
register_new_user(args.user, args.password, args.server_url, secret)

View File

@@ -1,761 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
import argparse
import curses
import logging
import sys
import time
import traceback
import yaml
logger = logging.getLogger("synapse_port_db")
BOOLEAN_COLUMNS = {
"events": ["processed", "outlier"],
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
}
APPEND_ONLY_TABLES = [
"event_content_hashes",
"event_reference_hashes",
"event_signatures",
"event_edge_hashes",
"events",
"event_json",
"state_events",
"room_memberships",
"feedback",
"topics",
"room_names",
"rooms",
"local_media_repository",
"local_media_repository_thumbnails",
"remote_media_cache",
"remote_media_cache_thumbnails",
"redactions",
"event_edges",
"event_auth",
"received_transactions",
"sent_transactions",
"transaction_id_to_pdu",
"users",
"state_groups",
"state_groups_state",
"event_to_state_groups",
"rejections",
]
end_error_exec_info = None
class Store(object):
"""This object is used to pull out some of the convenience API from the
Storage layer.
*All* database interactions should go through this object.
"""
def __init__(self, db_pool, engine):
self.db_pool = db_pool
self.database_engine = engine
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
_simple_insert = SQLBaseStore.__dict__["_simple_insert"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
def runInteraction(self, desc, func, *args, **kwargs):
def r(conn):
try:
i = 0
N = 5
while True:
try:
txn = conn.cursor()
return func(
LoggingTransaction(txn, desc, self.database_engine, []),
*args, **kwargs
)
except self.database_engine.module.DatabaseError as e:
if self.database_engine.is_deadlock(e):
logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
if i < N:
i += 1
conn.rollback()
continue
raise
except Exception as e:
logger.debug("[TXN FAIL] {%s} %s", desc, e)
raise
return self.db_pool.runWithConnection(r)
def execute(self, f, *args, **kwargs):
return self.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql, *args):
def r(txn):
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction("execute_sql", r)
def insert_many_txn(self, txn, table, headers, rows):
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
table,
", ".join(k for k in headers),
", ".join("%s" for _ in headers)
)
try:
txn.executemany(sql, rows)
except:
logger.exception(
"Failed to insert: %s",
table,
)
raise
class Porter(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)
@defer.inlineCallbacks
def setup_table(self, table):
if table in APPEND_ONLY_TABLES:
# It's safe to just carry on inserting.
next_chunk = yield self.postgres_store._simple_select_one_onecol(
table="port_from_sqlite3",
keyvalues={"table_name": table},
retcol="rowid",
allow_none=True,
)
total_to_port = None
if next_chunk is None:
if table == "sent_transactions":
next_chunk, already_ported, total_to_port = (
yield self._setup_sent_transactions()
)
else:
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 1}
)
next_chunk = 1
already_ported = 0
if total_to_port is None:
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
else:
def delete_all(txn):
txn.execute(
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
(table,)
)
txn.execute("TRUNCATE %s CASCADE" % (table,))
yield self.postgres_store.execute(delete_all)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": table, "rowid": 0}
)
next_chunk = 1
already_ported, total_to_port = yield self._get_total_count_to_port(
table, next_chunk
)
defer.returnValue((table, already_ported, total_to_port, next_chunk))
@defer.inlineCallbacks
def handle_table(self, table, postgres_size, table_size, next_chunk):
if not table_size:
return
self.progress.add_table(table, postgres_size, table_size)
select = (
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
% (table,)
)
while True:
def r(txn):
txn.execute(select, (next_chunk, self.batch_size,))
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
return headers, rows
headers, rows = yield self.sqlite_store.runInteraction("select", r)
if rows:
next_chunk = rows[-1][0] + 1
self._convert_rows(table, headers, rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, table, headers[1:], rows
)
self.postgres_store._simple_update_one_txn(
txn,
table="port_from_sqlite3",
keyvalues={"table_name": table},
updatevalues={"rowid": next_chunk},
)
yield self.postgres_store.execute(insert)
postgres_size += len(rows)
self.progress.update(table, postgres_size)
else:
return
def setup_db(self, db_config, database_engine):
db_conn = database_engine.module.connect(
**{
k: v for k, v in db_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
db_conn.commit()
@defer.inlineCallbacks
def run(self):
try:
sqlite_db_pool = adbapi.ConnectionPool(
self.sqlite_config["name"],
**self.sqlite_config["args"]
)
postgres_db_pool = adbapi.ConnectionPool(
self.postgres_config["name"],
**self.postgres_config["args"]
)
sqlite_engine = create_engine("sqlite3")
postgres_engine = create_engine("psycopg2")
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
yield self.postgres_store.execute(
postgres_engine.check_database
)
# Step 1. Set up databases.
self.progress.set_state("Preparing SQLite3")
self.setup_db(sqlite_config, sqlite_engine)
self.progress.set_state("Preparing PostgreSQL")
self.setup_db(postgres_config, postgres_engine)
# Step 2. Get tables.
self.progress.set_state("Fetching tables")
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
table="sqlite_master",
keyvalues={
"type": "table",
},
retcol="name",
)
postgres_tables = yield self.postgres_store._simple_select_onecol(
table="information_schema.tables",
keyvalues={
"table_schema": "public",
},
retcol="distinct table_name",
)
tables = set(sqlite_tables) & set(postgres_tables)
self.progress.set_state("Creating tables")
logger.info("Found %d tables", len(tables))
def create_port_table(txn):
txn.execute(
"CREATE TABLE port_from_sqlite3 ("
" table_name varchar(100) NOT NULL UNIQUE,"
" rowid bigint NOT NULL"
")"
)
try:
yield self.postgres_store.runInteraction(
"create_port_table", create_port_table
)
except Exception as e:
logger.info("Failed to create port table: %s", e)
self.progress.set_state("Setting up")
# Set up tables.
setup_res = yield defer.gatherResults(
[
self.setup_table(table)
for table in tables
if table not in ["schema_version", "applied_schema_deltas"]
and not table.startswith("sqlite_")
],
consumeErrors=True,
)
# Process tables.
yield defer.gatherResults(
[
self.handle_table(*res)
for res in setup_res
],
consumeErrors=True,
)
self.progress.done()
except:
global end_error_exec_info
end_error_exec_info = sys.exc_info()
logger.exception("")
finally:
reactor.stop()
def _convert_rows(self, table, headers, rows):
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
bool_cols = [
i for i, h in enumerate(headers) if h in bool_col_names
]
def conv(j, col):
if j in bool_cols:
return bool(col)
return col
for i, row in enumerate(rows):
rows[i] = tuple(
conv(j, col)
for j, col in enumerate(row)
if j > 0
)
@defer.inlineCallbacks
def _setup_sent_transactions(self):
# Only save things from the last day
yesterday = int(time.time()*1000) - 86400000
# And save the max transaction id from each destination
select = (
"SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
"SELECT max(rowid) FROM sent_transactions"
" GROUP BY destination"
")"
)
def r(txn):
txn.execute(select)
rows = txn.fetchall()
headers = [column[0] for column in txn.description]
ts_ind = headers.index('ts')
return headers, [r for r in rows if r[ts_ind] < yesterday]
headers, rows = yield self.sqlite_store.runInteraction(
"select", r,
)
self._convert_rows("sent_transactions", headers, rows)
inserted_rows = len(rows)
if inserted_rows:
max_inserted_rowid = max(r[0] for r in rows)
def insert(txn):
self.postgres_store.insert_many_txn(
txn, "sent_transactions", headers[1:], rows
)
yield self.postgres_store.execute(insert)
else:
max_inserted_rowid = 0
def get_start_id(txn):
txn.execute(
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
" ORDER BY rowid ASC LIMIT 1",
(yesterday,)
)
rows = txn.fetchall()
if rows:
return rows[0][0]
else:
return 1
next_chunk = yield self.sqlite_store.execute(get_start_id)
next_chunk = max(max_inserted_rowid + 1, next_chunk)
yield self.postgres_store._simple_insert(
table="port_from_sqlite3",
values={"table_name": "sent_transactions", "rowid": next_chunk}
)
def get_sent_table_size(txn):
txn.execute(
"SELECT count(*) FROM sent_transactions"
" WHERE ts >= ?",
(yesterday,)
)
size, = txn.fetchone()
return int(size)
remaining_count = yield self.sqlite_store.execute(
get_sent_table_size
)
total_count = remaining_count + inserted_rows
defer.returnValue((next_chunk, inserted_rows, total_count))
@defer.inlineCallbacks
def _get_remaining_count_to_port(self, table, next_chunk):
rows = yield self.sqlite_store.execute_sql(
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
next_chunk,
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_already_ported_count(self, table):
rows = yield self.postgres_store.execute_sql(
"SELECT count(*) FROM %s" % (table,),
)
defer.returnValue(rows[0][0])
@defer.inlineCallbacks
def _get_total_count_to_port(self, table, next_chunk):
remaining, done = yield defer.gatherResults(
[
self._get_remaining_count_to_port(table, next_chunk),
self._get_already_ported_count(table),
],
consumeErrors=True,
)
remaining = int(remaining) if remaining else 0
done = int(done) if done else 0
defer.returnValue((done, remaining + done))
##############################################
###### The following is simply UI stuff ######
##############################################
class Progress(object):
"""Used to report progress of the port
"""
def __init__(self):
self.tables = {}
self.start_time = int(time.time())
def add_table(self, table, cur, size):
self.tables[table] = {
"start": cur,
"num_done": cur,
"total": size,
"perc": int(cur * 100 / size),
}
def update(self, table, num_done):
data = self.tables[table]
data["num_done"] = num_done
data["perc"] = int(num_done * 100 / data["total"])
def done(self):
pass
class CursesProgress(Progress):
"""Reports progress to a curses window
"""
def __init__(self, stdscr):
self.stdscr = stdscr
curses.use_default_colors()
curses.curs_set(0)
curses.init_pair(1, curses.COLOR_RED, -1)
curses.init_pair(2, curses.COLOR_GREEN, -1)
self.last_update = 0
self.finished = False
self.total_processed = 0
self.total_remaining = 0
super(CursesProgress, self).__init__()
def update(self, table, num_done):
super(CursesProgress, self).update(table, num_done)
self.total_processed = 0
self.total_remaining = 0
for table, data in self.tables.items():
self.total_processed += data["num_done"] - data["start"]
self.total_remaining += data["total"] - data["num_done"]
self.render()
def render(self, force=False):
now = time.time()
if not force and now - self.last_update < 0.2:
# reactor.callLater(1, self.render)
return
self.stdscr.clear()
rows, cols = self.stdscr.getmaxyx()
duration = int(now) - int(self.start_time)
minutes, seconds = divmod(duration, 60)
duration_str = '%02dm %02ds' % (minutes, seconds,)
if self.finished:
status = "Time spent: %s (Done!)" % (duration_str,)
else:
if self.total_processed > 0:
left = float(self.total_remaining) / self.total_processed
est_remaining = (int(now) - self.start_time) * left
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
else:
est_remaining_str = "Unknown"
status = (
"Time spent: %s (est. remaining: %s)"
% (duration_str, est_remaining_str,)
)
self.stdscr.addstr(
0, 0,
status,
curses.A_BOLD,
)
max_len = max([len(t) for t in self.tables.keys()])
left_margin = 5
middle_space = 1
items = self.tables.items()
items.sort(
key=lambda i: (i[1]["perc"], i[0]),
)
for i, (table, data) in enumerate(items):
if i + 2 >= rows:
break
perc = data["perc"]
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
self.stdscr.addstr(
i+2, left_margin + max_len - len(table),
table,
curses.A_BOLD | color,
)
size = 20
progress = "[%s%s]" % (
"#" * int(perc*size/100),
" " * (size - int(perc*size/100)),
)
self.stdscr.addstr(
i+2, left_margin + max_len + middle_space,
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
)
if self.finished:
self.stdscr.addstr(
rows-1, 0,
"Press any key to exit...",
)
self.stdscr.refresh()
self.last_update = time.time()
def done(self):
self.finished = True
self.render(True)
self.stdscr.getch()
def set_state(self, state):
self.stdscr.clear()
self.stdscr.addstr(
0, 0,
state + "...",
curses.A_BOLD,
)
self.stdscr.refresh()
class TerminalProgress(Progress):
"""Just prints progress to the terminal
"""
def update(self, table, num_done):
super(TerminalProgress, self).update(table, num_done)
data = self.tables[table]
print "%s: %d%% (%d/%d)" % (
table, data["perc"],
data["num_done"], data["total"],
)
def set_state(self, state):
print state + "..."
##############################################
##############################################
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="A script to port an existing synapse SQLite database to"
" a new PostgreSQL database."
)
parser.add_argument("-v", action='store_true')
parser.add_argument(
"--sqlite-database", required=True,
help="The snapshot of the SQLite database file. This must not be"
" currently used by a running synapse server"
)
parser.add_argument(
"--postgres-config", type=argparse.FileType('r'), required=True,
help="The database config file for the PostgreSQL database"
)
parser.add_argument(
"--curses", action='store_true',
help="display a curses based progress UI"
)
parser.add_argument(
"--batch-size", type=int, default=1000,
help="The number of rows to select from the SQLite table each"
" iteration [default=1000]",
)
args = parser.parse_args()
logging_config = {
"level": logging.DEBUG if args.v else logging.INFO,
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
}
if args.curses:
logging_config["filename"] = "port-synapse.log"
logging.basicConfig(**logging_config)
sqlite_config = {
"name": "sqlite3",
"args": {
"database": args.sqlite_database,
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
},
}
postgres_config = yaml.safe_load(args.postgres_config)
if "database" in postgres_config:
postgres_config = postgres_config["database"]
if "name" not in postgres_config:
sys.stderr.write("Malformed database config: no 'name'")
sys.exit(2)
if postgres_config["name"] != "psycopg2":
sys.stderr.write("Database must use 'psycopg2' connector.")
sys.exit(3)
def start(stdscr=None):
if stdscr:
progress = CursesProgress(stdscr)
else:
progress = TerminalProgress()
porter = Porter(
sqlite_config=sqlite_config,
postgres_config=postgres_config,
progress=progress,
batch_size=args.batch_size,
)
reactor.callWhenRunning(porter.run)
reactor.run()
if args.curses:
curses.wrapper(start)
else:
start()
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)

View File

@@ -0,0 +1,331 @@
from synapse.storage import SCHEMA_VERSION, read_schema
from synapse.storage._base import SQLBaseStore
from synapse.storage.signatures import SignatureStore
from synapse.storage.event_federation import EventFederationStore
from syutil.base64util import encode_base64, decode_base64
from synapse.crypto.event_signing import compute_event_signature
from synapse.events.builder import EventBuilder
from synapse.events.utils import prune_event
from synapse.crypto.event_signing import check_event_content_hash
from syutil.crypto.jsonsign import (
verify_signed_json, SignatureVerifyException,
)
from syutil.crypto.signing_key import decode_verify_key_bytes
from syutil.jsonutil import encode_canonical_json
import argparse
# import dns.resolver
import hashlib
import httplib
import json
import sqlite3
import syutil
import urllib2
delta_sql = """
CREATE TABLE IF NOT EXISTS event_json(
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
internal_metadata NOT NULL,
json BLOB NOT NULL,
CONSTRAINT ev_j_uniq UNIQUE (event_id)
);
CREATE INDEX IF NOT EXISTS event_json_id ON event_json(event_id);
CREATE INDEX IF NOT EXISTS event_json_room_id ON event_json(room_id);
PRAGMA user_version = 10;
"""
class Store(object):
_get_event_signatures_txn = SignatureStore.__dict__["_get_event_signatures_txn"]
_get_event_content_hashes_txn = SignatureStore.__dict__["_get_event_content_hashes_txn"]
_get_event_reference_hashes_txn = SignatureStore.__dict__["_get_event_reference_hashes_txn"]
_get_prev_event_hashes_txn = SignatureStore.__dict__["_get_prev_event_hashes_txn"]
_get_prev_events_and_state = EventFederationStore.__dict__["_get_prev_events_and_state"]
_get_auth_events = EventFederationStore.__dict__["_get_auth_events"]
cursor_to_dict = SQLBaseStore.__dict__["cursor_to_dict"]
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
_simple_select_list_txn = SQLBaseStore.__dict__["_simple_select_list_txn"]
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
def _generate_event_json(self, txn, rows):
events = []
for row in rows:
d = dict(row)
d.pop("stream_ordering", None)
d.pop("topological_ordering", None)
d.pop("processed", None)
if "origin_server_ts" not in d:
d["origin_server_ts"] = d.pop("ts", 0)
else:
d.pop("ts", 0)
d.pop("prev_state", None)
d.update(json.loads(d.pop("unrecognized_keys")))
d["sender"] = d.pop("user_id")
d["content"] = json.loads(d["content"])
if "age_ts" not in d:
# For compatibility
d["age_ts"] = d.get("origin_server_ts", 0)
d.setdefault("unsigned", {})["age_ts"] = d.pop("age_ts")
outlier = d.pop("outlier", False)
# d.pop("membership", None)
d.pop("state_hash", None)
d.pop("replaces_state", None)
b = EventBuilder(d)
b.internal_metadata.outlier = outlier
events.append(b)
for i, ev in enumerate(events):
signatures = self._get_event_signatures_txn(
txn, ev.event_id,
)
ev.signatures = {
n: {
k: encode_base64(v) for k, v in s.items()
}
for n, s in signatures.items()
}
hashes = self._get_event_content_hashes_txn(
txn, ev.event_id,
)
ev.hashes = {
k: encode_base64(v) for k, v in hashes.items()
}
prevs = self._get_prev_events_and_state(txn, ev.event_id)
ev.prev_events = [
(e_id, h)
for e_id, h, is_state in prevs
if is_state == 0
]
# ev.auth_events = self._get_auth_events(txn, ev.event_id)
hashes = dict(ev.auth_events)
for e_id, hash in ev.prev_events:
if e_id in hashes and not hash:
hash.update(hashes[e_id])
#
# if hasattr(ev, "state_key"):
# ev.prev_state = [
# (e_id, h)
# for e_id, h, is_state in prevs
# if is_state == 1
# ]
return [e.build() for e in events]
store = Store()
# def get_key(server_name):
# print "Getting keys for: %s" % (server_name,)
# targets = []
# if ":" in server_name:
# target, port = server_name.split(":")
# targets.append((target, int(port)))
# try:
# answers = dns.resolver.query("_matrix._tcp." + server_name, "SRV")
# for srv in answers:
# targets.append((srv.target, srv.port))
# except dns.resolver.NXDOMAIN:
# targets.append((server_name, 8448))
# except:
# print "Failed to lookup keys for %s" % (server_name,)
# return {}
#
# for target, port in targets:
# url = "https://%s:%i/_matrix/key/v1" % (target, port)
# try:
# keys = json.load(urllib2.urlopen(url, timeout=2))
# verify_keys = {}
# for key_id, key_base64 in keys["verify_keys"].items():
# verify_key = decode_verify_key_bytes(
# key_id, decode_base64(key_base64)
# )
# verify_signed_json(keys, server_name, verify_key)
# verify_keys[key_id] = verify_key
# print "Got keys for: %s" % (server_name,)
# return verify_keys
# except urllib2.URLError:
# pass
# except urllib2.HTTPError:
# pass
# except httplib.HTTPException:
# pass
#
# print "Failed to get keys for %s" % (server_name,)
# return {}
def reinsert_events(cursor, server_name, signing_key):
print "Running delta: v10"
cursor.executescript(delta_sql)
cursor.execute(
"SELECT * FROM events ORDER BY rowid ASC"
)
print "Getting events..."
rows = store.cursor_to_dict(cursor)
events = store._generate_event_json(cursor, rows)
print "Got events from DB."
algorithms = {
"sha256": hashlib.sha256,
}
key_id = "%s:%s" % (signing_key.alg, signing_key.version)
verify_key = signing_key.verify_key
verify_key.alg = signing_key.alg
verify_key.version = signing_key.version
server_keys = {
server_name: {
key_id: verify_key
}
}
i = 0
N = len(events)
for event in events:
if i % 100 == 0:
print "Processed: %d/%d events" % (i,N,)
i += 1
# for alg_name in event.hashes:
# if check_event_content_hash(event, algorithms[alg_name]):
# pass
# else:
# pass
# print "FAIL content hash %s %s" % (alg_name, event.event_id, )
have_own_correctly_signed = False
for host, sigs in event.signatures.items():
pruned = prune_event(event)
for key_id in sigs:
if host not in server_keys:
server_keys[host] = {} # get_key(host)
if key_id in server_keys[host]:
try:
verify_signed_json(
pruned.get_pdu_json(),
host,
server_keys[host][key_id]
)
if host == server_name:
have_own_correctly_signed = True
except SignatureVerifyException:
print "FAIL signature check %s %s" % (
key_id, event.event_id
)
# TODO: Re sign with our own server key
if not have_own_correctly_signed:
sigs = compute_event_signature(event, server_name, signing_key)
event.signatures.update(sigs)
pruned = prune_event(event)
for key_id in event.signatures[server_name]:
verify_signed_json(
pruned.get_pdu_json(),
server_name,
server_keys[server_name][key_id]
)
event_json = encode_canonical_json(
event.get_dict()
).decode("UTF-8")
metadata_json = encode_canonical_json(
event.internal_metadata.get_dict()
).decode("UTF-8")
store._simple_insert_txn(
cursor,
table="event_json",
values={
"event_id": event.event_id,
"room_id": event.room_id,
"internal_metadata": metadata_json,
"json": event_json,
},
or_replace=True,
)
def main(database, server_name, signing_key):
conn = sqlite3.connect(database)
cursor = conn.cursor()
# Do other deltas:
cursor.execute("PRAGMA user_version")
row = cursor.fetchone()
if row and row[0]:
user_version = row[0]
# Run every version since after the current version.
for v in range(user_version + 1, 10):
print "Running delta: %d" % (v,)
sql_script = read_schema("delta/v%d" % (v,))
cursor.executescript(sql_script)
reinsert_events(cursor, server_name, signing_key)
conn.commit()
print "Success!"
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("database")
parser.add_argument("server_name")
parser.add_argument(
"signing_key", type=argparse.FileType('r'),
)
args = parser.parse_args()
signing_key = syutil.crypto.signing_key.read_signing_keys(
args.signing_key
)
main(args.database, args.server_name, signing_key[0])

View File

@@ -3,6 +3,9 @@ source-dir = docs/sphinx
build-dir = docs/build
all_files = 1
[aliases]
test = trial
[trial]
test_suite = tests
@@ -13,6 +16,3 @@ ignore =
docs/*
pylint.cfg
tox.ini
[flake8]
max-line-length = 90

View File

@@ -14,10 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import glob
import os
from setuptools import setup, find_packages, Command
import sys
from setuptools import setup, find_packages
here = os.path.abspath(os.path.dirname(__file__))
@@ -38,39 +36,6 @@ def exec_file(path_segments):
exec(code, result)
return result
class Tox(Command):
user_options = [('tox-args=', 'a', "Arguments to pass to tox")]
def initialize_options(self):
self.tox_args = None
def finalize_options(self):
self.test_args = []
self.test_suite = True
def run(self):
#import here, cause outside the eggs aren't loaded
try:
import tox
except ImportError:
try:
self.distribution.fetch_build_eggs("tox")
import tox
except:
raise RuntimeError(
"The tests need 'tox' to run. Please install 'tox'."
)
import shlex
args = self.tox_args
if args:
args = shlex.split(self.tox_args)
else:
args = []
errno = tox.cmdline(args=args)
sys.exit(errno)
version = exec_file(("synapse", "__init__.py"))["__version__"]
dependencies = exec_file(("synapse", "python_dependencies.py"))
long_description = read_file(("README.rst",))
@@ -80,11 +45,15 @@ setup(
version=version,
packages=find_packages(exclude=["tests", "tests.*"]),
description="Reference Synapse Home Server",
install_requires=dependencies['requirements'](include_conditional=True).keys(),
dependency_links=dependencies["DEPENDENCY_LINKS"].values(),
install_requires=dependencies["REQUIREMENTS"].keys(),
setup_requires=[
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
"setuptools_trial",
"mock"
],
dependency_links=dependencies["DEPENDENCY_LINKS"],
include_package_data=True,
zip_safe=False,
long_description=long_description,
scripts=["synctl"] + glob.glob("scripts/*"),
cmdclass={'test': Tox},
scripts=["synctl"],
)

View File

@@ -37,13 +37,9 @@ textarea, input {
margin: auto
}
.g-recaptcha div {
margin: auto;
}
#registrationForm {
text-align: left;
padding: 5px;
padding: 1em;
margin-bottom: 40px;
display: inline-block;

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.10.0-r2"
__version__ = "0.8.0"

View File

@@ -18,44 +18,26 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
from synapse.util.logutils import log_function
from synapse.types import UserID, EventID
from synapse.util.async import run_on_reactor
from synapse.types import UserID, ClientInfo
import logging
import pymacaroons
logger = logging.getLogger(__name__)
AuthEventTypes = (
EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
EventTypes.JoinRules, EventTypes.RoomHistoryVisibility,
)
class Auth(object):
def __init__(self, hs):
self.hs = hs
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
self._KNOWN_CAVEAT_PREFIXES = set([
"gen = ",
"type = ",
"time < ",
"user_id = ",
])
def check(self, event, auth_events):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
@@ -72,22 +54,11 @@ class Auth(object):
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
logger.debug("Auth events: %s", auth_events)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
@@ -106,7 +77,7 @@ class Auth(object):
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
self._check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
except AuthError as e:
@@ -119,20 +90,6 @@ class Auth(object):
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
"""Check if the user is currently joined in the room
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user is not in the room.
Returns:
A deferred membership event for the user if the user is in
the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
@@ -148,43 +105,6 @@ class Auth(object):
self._check_joined_room(member, user_id, room_id)
defer.returnValue(member)
@defer.inlineCallbacks
def check_user_was_in_room(self, room_id, user_id, current_state=None):
"""Check if the user was in the room at some point.
Args:
room_id(str): The room to check.
user_id(str): The user to check.
current_state(dict): Optional map of the current state of the room.
If provided then that map is used to check whether they are a
member of the room. Otherwise the current membership is
loaded from the database.
Raises:
AuthError if the user was never in the room.
Returns:
A deferred membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
"""
if current_state:
member = current_state.get(
(EventTypes.Member, user_id),
None
)
else:
member = yield self.state.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership not in (Membership.JOIN, Membership.LEAVE):
raise AuthError(403, "User %s not in room %s" % (
user_id, room_id
))
defer.returnValue(member)
@defer.inlineCallbacks
def check_host_in_room(self, room_id, host):
curr_state = yield self.state.get_current_state(room_id)
@@ -246,7 +166,6 @@ class Auth(object):
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
@@ -257,20 +176,24 @@ class Auth(object):
else:
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
ban_level, kick_level, redact_level = (
self._get_ops_level_from_event_state(
event,
auth_events,
)
)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
@@ -279,41 +202,25 @@ class Auth(object):
}
)
if Membership.JOIN != membership:
# JOIN is the only action you can perform if you're not in the room
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
if not caller_in_room: # caller isn't joined
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
@@ -325,61 +232,63 @@ class Auth(object):
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
if not caller_in_room: # trying to leave a room you aren't joined
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
403,
"%s not in room %s." % (target_user_id, event.room_id,)
)
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if kick_level:
kick_level = int(kick_level)
else:
kick_level = 50 # FIXME (erikj): What should we do here?
if user_level < kick_level or user_level <= target_level:
if user_level < kick_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
if ban_level:
ban_level = int(ban_level)
else:
ban_level = 50 # FIXME (erikj): What should we do here?
if user_level < ban_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _get_power_level_event(self, auth_events):
def _get_power_level_from_event_state(self, event, user_id, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
power_level_event = auth_events.get(key)
level = None
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
return level
if not power_level_event:
return default
def _get_ops_level_from_event_state(self, event, auth_events):
key = (EventTypes.PowerLevels, "", )
power_level_event = auth_events.get(key)
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
if power_level_event:
return (
power_level_event.content.get("ban", 50),
power_level_event.content.get("kick", 50),
power_level_event.content.get("redact", 50),
)
return None, None, None,
@defer.inlineCallbacks
def get_user_by_req(self, request):
@@ -388,9 +297,9 @@ class Auth(object):
Args:
request - An HTTP request with an access_token query parameter.
Returns:
tuple of:
UserID (str)
Access token ID (str)
tuple : of UserID and device string:
User ID object of the user making the request
Client ID object of the client instance the user is using
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
@@ -418,15 +327,16 @@ class Auth(object):
if not user_id:
raise KeyError
request.authenticated_entity = user_id
defer.returnValue((UserID.from_string(user_id), ""))
defer.returnValue(
(UserID.from_string(user_id), ClientInfo("", ""))
)
return
except KeyError:
pass # normal users won't have the user_id query parameter set.
pass # normal users won't have this query parameter set
user_info = yield self._get_user_by_access_token(access_token)
user_info = yield self.get_user_by_token(access_token)
user = user_info["user"]
device_id = user_info["device_id"]
token_id = user_info["token_id"]
ip_addr = self.hs.get_ip_from_request(request)
@@ -435,124 +345,45 @@ class Auth(object):
default=[""]
)[0]
if user and access_token and ip_addr:
self.store.insert_client_ip(
yield self.store.insert_client_ip(
user=user,
access_token=access_token,
device_id=user_info["device_id"],
ip=ip_addr,
user_agent=user_agent
)
request.authenticated_entity = user.to_string()
defer.returnValue((user, token_id,))
defer.returnValue((user, ClientInfo(device_id, token_id)))
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
errcode=Codes.MISSING_TOKEN
)
raise AuthError(403, "Missing access token.")
@defer.inlineCallbacks
def _get_user_by_access_token(self, token):
def get_user_by_token(self, token):
""" Get a registered user's ID.
Args:
token (str): The access token to get the user by.
Returns:
dict : dict that includes the user and the ID of their access token.
dict : dict that includes the user, device_id, and whether the
user is a server admin.
Raises:
AuthError if no user by that token exists or the token is invalid.
"""
try:
ret = yield self._get_user_from_macaroon(token)
except AuthError:
# TODO(daniel): Remove this fallback when all existing access tokens
# have been re-issued as macaroons.
ret = yield self._look_up_user_by_access_token(token)
defer.returnValue(ret)
ret = yield self.store.get_user_by_token(token=token)
if not ret:
raise StoreError(400, "Unknown token")
user_info = {
"admin": bool(ret.get("admin", False)),
"device_id": ret.get("device_id"),
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
@defer.inlineCallbacks
def _get_user_from_macaroon(self, macaroon_str):
try:
macaroon = pymacaroons.Macaroon.deserialize(macaroon_str)
self._validate_macaroon(macaroon)
user_prefix = "user_id = "
for caveat in macaroon.caveats:
if caveat.caveat_id.startswith(user_prefix):
user = UserID.from_string(caveat.caveat_id[len(user_prefix):])
# This codepath exists so that we can actually return a
# token ID, because we use token IDs in place of device
# identifiers throughout the codebase.
# TODO(daniel): Remove this fallback when device IDs are
# properly implemented.
ret = yield self._look_up_user_by_access_token(macaroon_str)
if ret["user"] != user:
logger.error(
"Macaroon user (%s) != DB user (%s)",
user,
ret["user"]
)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"User mismatch in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
defer.returnValue(ret)
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "No user caveat in macaroon",
errcode=Codes.UNKNOWN_TOKEN
)
except (pymacaroons.exceptions.MacaroonException, TypeError, ValueError):
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Invalid macaroon passed.",
errcode=Codes.UNKNOWN_TOKEN
)
def _validate_macaroon(self, macaroon):
v = pymacaroons.Verifier()
v.satisfy_exact("gen = 1")
v.satisfy_exact("type = access")
v.satisfy_general(lambda c: c.startswith("user_id = "))
v.satisfy_general(self._verify_expiry)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
v = pymacaroons.Verifier()
v.satisfy_general(self._verify_recognizes_caveats)
v.verify(macaroon, self.hs.config.macaroon_secret_key)
def _verify_expiry(self, caveat):
prefix = "time < "
if not caveat.startswith(prefix):
return False
# TODO(daniel): Enable expiry check when clients actually know how to
# refresh tokens. (And remember to enable the tests)
return True
expiry = int(caveat[len(prefix):])
now = self.hs.get_clock().time_msec()
return now < expiry
def _verify_recognizes_caveats(self, caveat):
first_space = caveat.find(" ")
if first_space < 0:
return False
second_space = caveat.find(" ", first_space + 1)
if second_space < 0:
return False
return caveat[:second_space + 1] in self._KNOWN_CAVEAT_PREFIXES
@defer.inlineCallbacks
def _look_up_user_by_access_token(self, token):
ret = yield self.store.get_user_by_access_token(token)
if not ret:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
user_info = {
"user": UserID.from_string(ret.get("name")),
"token_id": ret.get("token_id", None),
}
defer.returnValue(user_info)
defer.returnValue(user_info)
except StoreError:
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
@defer.inlineCallbacks
def get_appservice_by_req(self, request):
@@ -560,23 +391,19 @@ class Auth(object):
token = request.args["access_token"][0]
service = yield self.store.get_app_service_by_token(token)
if not service:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS,
"Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN
)
request.authenticated_entity = service.sender
raise AuthError(403, "Unrecognised access token.",
errcode=Codes.UNKNOWN_TOKEN)
defer.returnValue(service)
except KeyError:
raise AuthError(
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
)
raise AuthError(403, "Missing access token.")
def is_server_admin(self, user):
return self.store.is_server_admin(user)
@defer.inlineCallbacks
def add_auth_events(self, builder, context):
yield run_on_reactor()
auth_ids = self.compute_auth_events(builder, context.current_state)
auth_events_entries = yield self.store.add_event_hashes(
@@ -585,6 +412,12 @@ class Auth(object):
builder.auth_events = auth_events_entries
context.auth_events = {
k: v
for k, v in context.current_state.items()
if v.event_id in auth_ids
}
def compute_auth_events(self, event, current_state):
if event.type == EventTypes.Create:
return []
@@ -641,7 +474,7 @@ class Auth(object):
send_level = send_level_event.content.get("events", {}).get(
event.type
)
if send_level is None:
if not send_level:
if hasattr(event, "state_key"):
send_level = send_level_event.content.get(
"state_default", 50
@@ -656,7 +489,16 @@ class Auth(object):
else:
send_level = 0
user_level = self._get_user_power_level(event.user_id, auth_events)
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
if user_level:
user_level = int(user_level)
else:
user_level = 0
if user_level < send_level:
raise AuthError(
@@ -667,55 +509,44 @@ class Auth(object):
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
if not event.state_key.startswith("_"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
else:
sender_domain = UserID.from_string(
event.user_id
).domain
if sender_domain != event.state_key:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = self._get_user_power_level(event.user_id, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level > redact_level:
return False
redacter_domain = EventID.from_string(event.event_id).domain
redactee_domain = EventID.from_string(event.redacts).domain
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
def _check_redaction(self, event, auth_events):
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
_, _, redact_level = self._get_ops_level_from_event_state(
event,
auth_events,
)
if user_level < redact_level:
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
@@ -736,30 +567,32 @@ class Auth(object):
if not current_state:
return
user_level = self._get_user_power_level(event.user_id, auth_events)
user_level = self._get_power_level_from_event_state(
event,
event.user_id,
auth_events,
)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
("users_default", []),
("events_default", []),
("ban", []),
("redact", []),
("kick", []),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
(user, ["users"])
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
(ev_id, ["events"])
)
old_state = current_state.content
@@ -767,10 +600,12 @@ class Auth(object):
for level_to_check, dir in levels_to_check:
old_loc = old_state
for d in dir:
old_loc = old_loc.get(d, {})
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
for d in dir:
new_loc = new_loc.get(d, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
@@ -786,14 +621,6 @@ class Auth(object):
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,

View File

@@ -27,6 +27,16 @@ class Membership(object):
LIST = (INVITE, JOIN, KNOCK, LEAVE, BAN)
class Feedback(object):
"""Represents the types of feedback a user can send in response to a
message."""
DELIVERED = u"delivered"
READ = u"read"
LIST = (DELIVERED, READ)
class PresenceState(object):
"""Represents the presence state of a user."""
OFFLINE = u"offline"
@@ -49,11 +59,7 @@ class LoginType(object):
EMAIL_URL = u"m.login.email.url"
EMAIL_IDENTITY = u"m.login.email.identity"
RECAPTCHA = u"m.login.recaptcha"
DUMMY = u"m.login.dummy"
# Only for C/S API v1
APPLICATION_SERVICE = u"m.login.application_service"
SHARED_SECRET = u"org.matrix.login.shared_secret"
class EventTypes(object):
@@ -63,10 +69,7 @@ class EventTypes(object):
PowerLevels = "m.room.power_levels"
Aliases = "m.room.aliases"
Redaction = "m.room.redaction"
RoomHistoryVisibility = "m.room.history_visibility"
CanonicalAlias = "m.room.canonical_alias"
RoomAvatar = "m.room.avatar"
Feedback = "m.room.message.feedback"
# These are used for validation
Message = "m.room.message"
@@ -78,8 +81,3 @@ class RejectedReason(object):
AUTH_ERROR = "auth_error"
REPLACED = "replaced"
NOT_ANCESTOR = "not_ancestor"
class RoomCreationPreset(object):
PRIVATE_CHAT = "private_chat"
PUBLIC_CHAT = "public_chat"

View File

@@ -31,16 +31,13 @@ class Codes(object):
BAD_PAGINATION = "M_BAD_PAGINATION"
UNKNOWN = "M_UNKNOWN"
NOT_FOUND = "M_NOT_FOUND"
MISSING_TOKEN = "M_MISSING_TOKEN"
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
MISSING_PARAM = "M_MISSING_PARAM"
TOO_LARGE = "M_TOO_LARGE"
MISSING_PARAM = "M_MISSING_PARAM",
TOO_LARGE = "M_TOO_LARGE",
EXCLUSIVE = "M_EXCLUSIVE"
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
THREEPID_IN_USE = "THREEPID_IN_USE"
class CodeMessageException(RuntimeError):
@@ -77,6 +74,11 @@ class SynapseError(CodeMessageException):
)
class RoomError(SynapseError):
"""An error raised when a room event fails."""
pass
class RegistrationError(SynapseError):
"""An error raised when a registration event fails."""
pass

View File

@@ -22,6 +22,5 @@ STATIC_PREFIX = "/_matrix/static"
WEB_CLIENT_PREFIX = "/_matrix/client"
CONTENT_REPO_PREFIX = "/_matrix/content"
SERVER_KEY_PREFIX = "/_matrix/key/v1"
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
MEDIA_PREFIX = "/_matrix/media/v1"
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"

View File

@@ -16,80 +16,55 @@
import sys
sys.dont_write_bytecode = True
from synapse.python_dependencies import (
check_requirements, DEPENDENCY_LINKS, MissingRequirementError
)
if __name__ == '__main__':
try:
check_requirements()
except MissingRequirementError as e:
message = "\n".join([
"Missing Requirement: %s" % (e.message,),
"To install run:",
" pip install --upgrade --force \"%s\"" % (e.dependency,),
"",
])
sys.stderr.writelines(message)
sys.exit(1)
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import (
are_all_users_on_domain, UpgradeDatabaseException,
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
)
from synapse.server import HomeServer
from synapse.python_dependencies import check_requirements
from twisted.internet import reactor, task, defer
from twisted.internet import reactor
from twisted.application import service
from twisted.enterprise import adbapi
from twisted.web.resource import Resource, EncodingResourceWrapper
from twisted.web.resource import Resource
from twisted.web.static import File
from twisted.web.server import Site, GzipEncoderFactory, Request
from twisted.web.server import Site
from synapse.http.server import JsonResource, RootRedirect
from synapse.rest.appservice.v1 import AppServiceRestResource
from synapse.rest.media.v0.content_repository import ContentRepoResource
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
from synapse.rest.key.v1.server_key_resource import LocalKey
from synapse.rest.key.v2 import KeyApiV2Resource
from synapse.http.server_key_resource import LocalKey
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.api.urls import (
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX,
SERVER_KEY_V2_PREFIX,
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
STATIC_PREFIX
)
from synapse.config.homeserver import HomeServerConfig
from synapse.crypto import context_factory
from synapse.util.logcontext import LoggingContext
from synapse.rest.client.v1 import ClientV1RestResource
from synapse.rest.client.v2_alpha import ClientV2AlphaRestResource
from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse import events
from daemonize import Daemonize
import twisted.manhole.telnet
from multiprocessing import Process
from synapse.util.traceutil import Tracer
import synapse
import contextlib
import logging
import os
import re
import resource
import signal
import subprocess
import time
import sqlite3
import syweb
logger = logging.getLogger("synapse.app.homeserver")
def gz_wrap(r):
return EncodingResourceWrapper(r, [GzipEncoderFactory()])
logger = logging.getLogger(__name__)
class SynapseHomeServer(HomeServer):
@@ -106,42 +81,20 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_federation(self):
return JsonResource(self)
def build_resource_for_app_services(self):
return AppServiceRestResource(self)
def build_resource_for_web_client(self):
webclient_path = self.get_config().web_client_location
if not webclient_path:
try:
import syweb
except ImportError:
quit_with_error(
"Could not find a webclient.\n\n"
"Please either install the matrix-angular-sdk or configure\n"
"the location of the source to serve via the configuration\n"
"option `web_client_location`\n\n"
"To install the `matrix-angular-sdk` via pip, run:\n\n"
" pip install '%(dep)s'\n"
"\n"
"You can also disable hosting of the webclient via the\n"
"configuration option `web_client`\n"
% {"dep": DEPENDENCY_LINKS["matrix-angular-sdk"]}
)
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
# GZip is disabled here due to
# https://twistedmatrix.com/trac/ticket/7678
# (It can stay enabled for the API resources: they call
# write() with the whole body and then finish() straight
# after and so do not trigger the bug.
# GzipFile was removed in commit 184ba09
# return GzipFile(webclient_path) # TODO configurable?
syweb_path = os.path.dirname(syweb.__file__)
webclient_path = os.path.join(syweb_path, "webclient")
return File(webclient_path) # TODO configurable?
def build_resource_for_static_content(self):
# This is old and should go away: not going to bother adding gzip
return File("static")
def build_resource_for_content_repo(self):
return ContentRepoResource(
self, self.config.uploads_path, self.auth, self.content_addr
self, self.upload_dir, self.auth, self.content_addr
)
def build_resource_for_media_repository(self):
@@ -150,148 +103,123 @@ class SynapseHomeServer(HomeServer):
def build_resource_for_server_key(self):
return LocalKey(self)
def build_resource_for_server_key_v2(self):
return KeyApiV2Resource(self)
def build_resource_for_metrics(self):
if self.get_config().enable_metrics:
return MetricsResource(self)
else:
return None
def build_db_pool(self):
name = self.db_config["name"]
return adbapi.ConnectionPool(
name,
**self.db_config.get("args", {})
"sqlite3", self.get_db_name(),
check_same_thread=False,
cp_min=1,
cp_max=1,
cp_openfun=prepare_database, # Prepare the database for each conn
# so that :memory: sqlite works
)
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
def create_resource_tree(self, web_client, redirect_root_to_web_client):
"""Create the resource tree for this Home Server.
if tls and config.no_tls:
return
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
metrics_resource = self.get_resource_for_metrics()
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
# list containing (path_str, Resource) e.g:
# [ ("/aaa/bbb/cc", Resource1), ("/aaa/dummy", Resource2) ]
desired_tree = [
(CLIENT_PREFIX, self.get_resource_for_client()),
(CLIENT_V2_ALPHA_PREFIX, self.get_resource_for_client_v2_alpha()),
(FEDERATION_PREFIX, self.get_resource_for_federation()),
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
(STATIC_PREFIX, self.get_resource_for_static_content()),
]
resources = {}
for res in listener_config["resources"]:
for name in res["names"]:
if name == "client":
if res["compress"]:
client_v1 = gz_wrap(self.get_resource_for_client())
client_v2 = gz_wrap(self.get_resource_for_client_v2_alpha())
else:
client_v1 = self.get_resource_for_client()
client_v2 = self.get_resource_for_client_v2_alpha()
if web_client:
logger.info("Adding the web client.")
desired_tree.append((WEB_CLIENT_PREFIX,
self.get_resource_for_web_client()))
resources.update({
CLIENT_PREFIX: client_v1,
CLIENT_V2_ALPHA_PREFIX: client_v2,
})
if name == "federation":
resources.update({
FEDERATION_PREFIX: self.get_resource_for_federation(),
})
if name in ["static", "client"]:
resources.update({
STATIC_PREFIX: self.get_resource_for_static_content(),
})
if name in ["media", "federation", "client"]:
resources.update({
MEDIA_PREFIX: self.get_resource_for_media_repository(),
CONTENT_REPO_PREFIX: self.get_resource_for_content_repo(),
})
if name in ["keys", "federation"]:
resources.update({
SERVER_KEY_PREFIX: self.get_resource_for_server_key(),
SERVER_KEY_V2_PREFIX: self.get_resource_for_server_key_v2(),
})
if name == "webclient":
resources[WEB_CLIENT_PREFIX] = self.get_resource_for_web_client()
if name == "metrics" and metrics_resource:
resources[METRICS_PREFIX] = metrics_resource
root_resource = create_resource_tree(resources)
if tls:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=bind_address
)
if web_client and redirect_root_to_web_client:
self.root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
self.root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree:
logger.info("Attaching %s to path %s", res, full_path)
last_resource = self.root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = self._resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = self._resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = self._resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = self._resource_id(existing_dummy_resource,
child_name)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = self._resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return self.root_resource
def _resource_id(self, resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def start_listening(self, secure_port, unsecure_port):
if secure_port is not None:
reactor.listenSSL(
secure_port, Site(self.root_resource), self.tls_context_factory
)
logger.info("Synapse now listening on port %d", secure_port)
if unsecure_port is not None:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
unsecure_port, Site(self.root_resource)
)
logger.info("Synapse now listening on port %d", port)
def start_listening(self):
config = self.get_config()
for listener in config.listeners:
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = self
reactor.listenTCP(
listener["port"],
f,
interface=listener.get("bind_address", '127.0.0.1')
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
def run_startup_checks(self, db_conn, database_engine):
all_users_native = are_all_users_on_domain(
db_conn.cursor(), database_engine, self.hostname
)
if not all_users_native:
quit_with_error(
"Found users in database not native to %s!\n"
"You cannot changed a synapse server_name after it's been configured"
% (self.hostname,)
)
try:
database_engine.check_database(db_conn.cursor())
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def quit_with_error(error_string):
message_lines = error_string.split("\n")
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
sys.stderr.write("*" * line_length + '\n')
for line in message_lines:
sys.stderr.write(" %s\n" % (line.rstrip(),))
sys.stderr.write("*" * line_length + '\n')
sys.exit(1)
logger.info("Synapse now listening on port %d", unsecure_port)
def get_version_string():
@@ -352,7 +280,7 @@ def get_version_string():
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
logger.warn("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
@@ -371,16 +299,16 @@ def change_resource_limit(soft_file_no):
logger.warn("Failed to set file limit: %s", e)
def load_config(config_options):
def setup(config_options):
"""
Args:
config_options_options: The options passed to Synapse. Usually
`sys.argv[1:]`.
should_run (bool): Whether to start the reactor.
Returns:
HomeServerConfig
HomeServer
"""
config = HomeServerConfig.load_config(
"Synapse Homeserver",
config_options,
@@ -389,53 +317,44 @@ def load_config(config_options):
config.setup_logging()
return config
check_requirements()
def setup(config):
"""
Args:
config (Homeserver)
Returns:
HomeServer
"""
version_string = get_version_string()
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
events.USE_FROZEN_DICTS = config.use_frozen_dicts
if re.search(":[0-9]+$", config.server_name):
domain_with_port = config.server_name
else:
domain_with_port = "%s:%s" % (config.server_name, config.bind_port)
tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config.database_config["name"])
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
tls_context_factory = context_factory.ServerContextFactory(config)
hs = SynapseHomeServer(
config.server_name,
db_config=config.database_config,
tls_server_context_factory=tls_server_context_factory,
domain_with_port=domain_with_port,
upload_dir=os.path.abspath("uploads"),
db_name=config.database_path,
tls_context_factory=tls_context_factory,
config=config,
content_addr=config.content_addr,
version_string=version_string,
database_engine=database_engine,
)
logger.info("Preparing database: %s...", config.database_config['name'])
hs.create_resource_tree(
web_client=config.webclient,
redirect_root_to_web_client=True,
)
db_name = hs.get_db_name()
logger.info("Preparing database: %s...", db_name)
try:
db_conn = database_engine.module.connect(
**{
k: v for k, v in config.database_config.get("args", {}).items()
if not k.startswith("cp_")
}
)
database_engine.prepare_database(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()
with sqlite3.connect(db_name) as db_conn:
prepare_sqlite3_database(db_conn)
prepare_database(db_conn)
except UpgradeDatabaseException:
sys.stderr.write(
"\nFailed to upgrade database.\n"
@@ -444,51 +363,26 @@ def setup(config):
)
sys.exit(1)
logger.info("Database prepared in %s.", config.database_config['name'])
logger.info("Database prepared in %s.", db_name)
hs.start_listening()
if config.manhole:
f = twisted.manhole.telnet.ShellFactory()
f.username = "matrix"
f.password = "rabbithole"
f.namespace['hs'] = hs
reactor.listenTCP(config.manhole, f, interface='127.0.0.1')
bind_port = config.bind_port
if config.no_tls:
bind_port = None
hs.start_listening(bind_port, config.unsecure_port)
hs.get_pusherpool().start()
hs.get_state_handler().start_caching()
hs.get_datastore().start_profiling()
hs.get_replication_layer().start_get_pdu_cache()
start_time = time.time()
@defer.inlineCallbacks
def phone_stats_home():
now = int(time.time())
uptime = int(now - start_time)
if uptime < 0:
uptime = 0
stats = {}
stats["homeserver"] = config.server_name
stats["timestamp"] = now
stats["uptime_seconds"] = uptime
stats["total_users"] = yield hs.get_datastore().count_all_users()
all_rooms = yield hs.get_datastore().get_rooms(False)
stats["total_room_count"] = len(all_rooms)
stats["daily_active_users"] = yield hs.get_datastore().count_daily_users()
daily_messages = yield hs.get_datastore().count_daily_messages()
if daily_messages is not None:
stats["daily_messages"] = daily_messages
logger.info("Reporting stats to matrix.org: %s" % (stats,))
try:
yield hs.get_simple_http_client().put_json(
"https://matrix.org/report-usage-stats/push",
stats
)
except Exception as e:
logger.warn("Error reporting stats: %s", e)
if hs.config.report_stats:
phone_home_task = task.LoopingCall(phone_stats_home)
phone_home_task.start(60 * 60 * 24, now=False)
return hs
@@ -507,265 +401,27 @@ class SynapseService(service.Service):
return self._port.stopListening()
class SynapseRequest(Request):
def __init__(self, site, *args, **kw):
Request.__init__(self, *args, **kw)
self.site = site
self.authenticated_entity = None
self.start_time = 0
def __repr__(self):
# We overwrite this so that we don't log ``access_token``
return '<%s at 0x%x method=%s uri=%s clientproto=%s site=%s>' % (
self.__class__.__name__,
id(self),
self.method,
self.get_redacted_uri(),
self.clientproto,
self.site.site_tag,
)
def get_redacted_uri(self):
return re.sub(
r'(\?.*access_token=)[^&]*(.*)$',
r'\1<redacted>\2',
self.uri
)
def get_user_agent(self):
return self.requestHeaders.getRawHeaders("User-Agent", [None])[-1]
def started_processing(self):
self.site.access_logger.info(
"%s - %s - Received request: %s %s",
self.getClientIP(),
self.site.site_tag,
self.method,
self.get_redacted_uri()
)
self.start_time = int(time.time() * 1000)
def finished_processing(self):
self.site.access_logger.info(
"%s - %s - {%s}"
" Processed request: %dms %sB %s \"%s %s %s\" \"%s\"",
self.getClientIP(),
self.site.site_tag,
self.authenticated_entity,
int(time.time() * 1000) - self.start_time,
self.sentLength,
self.code,
self.method,
self.get_redacted_uri(),
self.clientproto,
self.get_user_agent(),
)
@contextlib.contextmanager
def processing(self):
self.started_processing()
yield
self.finished_processing()
class XForwardedForRequest(SynapseRequest):
def __init__(self, *args, **kw):
SynapseRequest.__init__(self, *args, **kw)
"""
Add a layer on top of another request that only uses the value of an
X-Forwarded-For header as the result of C{getClientIP}.
"""
def getClientIP(self):
"""
@return: The client address (the first address) in the value of the
I{X-Forwarded-For header}. If the header is not present, return
C{b"-"}.
"""
return self.requestHeaders.getRawHeaders(
b"x-forwarded-for", [b"-"])[0].split(b",")[0].strip()
class SynapseRequestFactory(object):
def __init__(self, site, x_forwarded_for):
self.site = site
self.x_forwarded_for = x_forwarded_for
def __call__(self, *args, **kwargs):
if self.x_forwarded_for:
return XForwardedForRequest(self.site, *args, **kwargs)
else:
return SynapseRequest(self.site, *args, **kwargs)
class SynapseSite(Site):
"""
Subclass of a twisted http Site that does access logging with python's
standard logging
"""
def __init__(self, logger_name, site_tag, config, resource, *args, **kwargs):
Site.__init__(self, resource, *args, **kwargs)
self.site_tag = site_tag
proxied = config.get("x_forwarded", False)
self.requestFactory = SynapseRequestFactory(self, proxied)
self.access_logger = logging.getLogger(logger_name)
def log(self, request):
pass
def create_resource_tree(desired_tree, redirect_root_to_web_client=True):
"""Create the resource tree for this Home Server.
This in unduly complicated because Twisted does not support putting
child resources more than 1 level deep at a time.
Args:
web_client (bool): True to enable the web client.
redirect_root_to_web_client (bool): True to redirect '/' to the
location of the web client. This does nothing if web_client is not
True.
"""
if redirect_root_to_web_client and WEB_CLIENT_PREFIX in desired_tree:
root_resource = RootRedirect(WEB_CLIENT_PREFIX)
else:
root_resource = Resource()
# ideally we'd just use getChild and putChild but getChild doesn't work
# unless you give it a Request object IN ADDITION to the name :/ So
# instead, we'll store a copy of this mapping so we can actually add
# extra resources to existing nodes. See self._resource_id for the key.
resource_mappings = {}
for full_path, res in desired_tree.items():
logger.info("Attaching %s to path %s", res, full_path)
last_resource = root_resource
for path_seg in full_path.split('/')[1:-1]:
if path_seg not in last_resource.listNames():
# resource doesn't exist, so make a "dummy resource"
child_resource = Resource()
last_resource.putChild(path_seg, child_resource)
res_id = _resource_id(last_resource, path_seg)
resource_mappings[res_id] = child_resource
last_resource = child_resource
else:
# we have an existing Resource, use that instead.
res_id = _resource_id(last_resource, path_seg)
last_resource = resource_mappings[res_id]
# ===========================
# now attach the actual desired resource
last_path_seg = full_path.split('/')[-1]
# if there is already a resource here, thieve its children and
# replace it
res_id = _resource_id(last_resource, last_path_seg)
if res_id in resource_mappings:
# there is a dummy resource at this path already, which needs
# to be replaced with the desired resource.
existing_dummy_resource = resource_mappings[res_id]
for child_name in existing_dummy_resource.listNames():
child_res_id = _resource_id(
existing_dummy_resource, child_name
)
child_resource = resource_mappings[child_res_id]
# steal the children
res.putChild(child_name, child_resource)
# finally, insert the desired resource in the right place
last_resource.putChild(last_path_seg, res)
res_id = _resource_id(last_resource, last_path_seg)
resource_mappings[res_id] = res
return root_resource
def _resource_id(resource, path_seg):
"""Construct an arbitrary resource ID so you can retrieve the mapping
later.
If you want to represent resource A putChild resource B with path C,
the mapping should looks like _resource_id(A,C) = B.
Args:
resource (Resource): The *parent* Resource
path_seg (str): The name of the child Resource to be attached.
Returns:
str: A unique string which can be a key to the child Resource.
"""
return "%s-%s" % (resource, path_seg)
def run(config):
PROFILE_SYNAPSE = False
if PROFILE_SYNAPSE:
def profile(func):
from cProfile import Profile
from threading import current_thread
def profiled(*args, **kargs):
profile = Profile()
profile.enable()
func(*args, **kargs)
profile.disable()
ident = current_thread().ident
profile.dump_stats("/tmp/%s.%s.%i.pstat" % (
config.server_name, func.__name__, ident
))
return profiled
from twisted.python.threadpool import ThreadPool
ThreadPool._worker = profile(ThreadPool._worker)
reactor.run = profile(reactor.run)
def run(hs):
def in_thread():
hs = setup(config)
try:
tracer = Tracer()
sys.settrace(tracer.process)
except Exception:
logger.exception("Failed to start tracer")
with LoggingContext("run"):
change_resource_limit(hs.config.soft_file_limit)
reactor.run()
def start_in_process_checker():
p = None
should_restart = [True]
if hs.config.daemonize:
def proxy_signal(signum, stack):
logger.info("Got signal: %r", signum)
if p is not None:
os.kill(p.pid, signum)
if signum == signal.SIGTERM:
should_restart[0] = False
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, proxy_signal)
signal.signal(signal.SIGTERM, proxy_signal)
last_start = 0
next_delay = 1
while should_restart[0]:
last_start = time.time()
p = Process(target=in_thread, args=())
p.start()
p.join()
if time.time() - last_start < 120:
next_delay = min(next_delay * 5, 5 * 60)
else:
next_delay = 1
time.sleep(next_delay)
if config.daemonize:
if config.print_pidfile:
print config.pid_file
print hs.config.pid_file
daemon = Daemonize(
app="synapse-homeserver",
pid=config.pid_file,
action=lambda: start_in_process_checker(),
pid=hs.config.pid_file,
action=lambda: in_thread(),
auto_close_fds=False,
verbose=True,
logger=logger,
@@ -778,10 +434,9 @@ def run(config):
def main():
with LoggingContext("main"):
# check base requirements
check_requirements()
config = load_config(sys.argv[1:])
run(config)
hs = setup(sys.argv[1:])
run(hs)
if __name__ == '__main__':

View File

@@ -16,67 +16,53 @@
import sys
import os
import os.path
import subprocess
import signal
import yaml
SYNAPSE = ["python", "-B", "-m", "synapse.app.homeserver"]
CONFIGFILE = "homeserver.yaml"
PIDFILE = "homeserver.pid"
GREEN = "\x1b[1;32m"
RED = "\x1b[1;31m"
NORMAL = "\x1b[m"
def start(configfile):
def start():
if not os.path.exists(CONFIGFILE):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), CONFIGFILE
)
)
sys.exit(1)
print "Starting ...",
args = SYNAPSE
args.extend(["--daemonize", "-c", configfile])
cwd = os.path.dirname(os.path.abspath(__file__))
try:
subprocess.check_call(args, cwd=cwd)
print GREEN + "started" + NORMAL
except subprocess.CalledProcessError as e:
print (
RED +
"error starting (exit code: %d); see above for logs" % e.returncode +
NORMAL
)
args.extend(["--daemonize", "-c", CONFIGFILE, "--pid-file", PIDFILE])
subprocess.check_call(args)
print GREEN + "started" + NORMAL
def stop(pidfile):
if os.path.exists(pidfile):
pid = int(open(pidfile).read())
def stop():
if os.path.exists(PIDFILE):
pid = int(open(PIDFILE).read())
os.kill(pid, signal.SIGTERM)
print GREEN + "stopped" + NORMAL
def main():
configfile = sys.argv[2] if len(sys.argv) == 3 else "homeserver.yaml"
if not os.path.exists(configfile):
sys.stderr.write(
"No config file found\n"
"To generate a config file, run '%s -c %s --generate-config"
" --server-name=<server name>'\n" % (
" ".join(SYNAPSE), configfile
)
)
sys.exit(1)
config = yaml.load(open(configfile))
pidfile = config["pid_file"]
action = sys.argv[1] if sys.argv[1:] else "usage"
if action == "start":
start(configfile)
start()
elif action == "stop":
stop(pidfile)
stop()
elif action == "restart":
stop(pidfile)
start(configfile)
stop()
start()
else:
sys.stderr.write("Usage: %s [start|stop|restart] [configfile]\n" % (sys.argv[0],))
sys.stderr.write("Usage: %s [start|stop|restart]\n" % (sys.argv[0],))
sys.exit(1)

View File

@@ -20,50 +20,6 @@ import re
logger = logging.getLogger(__name__)
class ApplicationServiceState(object):
DOWN = "down"
UP = "up"
class AppServiceTransaction(object):
"""Represents an application service transaction."""
def __init__(self, service, id, events):
self.service = service
self.id = id
self.events = events
def send(self, as_api):
"""Sends this transaction using the provided AS API interface.
Args:
as_api(ApplicationServiceApi): The API to use to send.
Returns:
A Deferred which resolves to True if the transaction was sent.
"""
return as_api.push_bulk(
service=self.service,
events=self.events,
txn_id=self.id
)
def complete(self, store):
"""Completes this transaction as successful.
Marks this transaction ID on the application service and removes the
transaction contents from the database.
Args:
store: The database store to operate on.
Returns:
A Deferred which resolves to True if the transaction was completed.
"""
return store.complete_appservice_txn(
service=self.service,
txn_id=self.id
)
class ApplicationService(object):
"""Defines an application service. This definition is mostly what is
provided to the /register AS API.
@@ -79,13 +35,13 @@ class ApplicationService(object):
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
def __init__(self, token, url=None, namespaces=None, hs_token=None,
sender=None, id=None):
sender=None, txn_id=None):
self.token = token
self.url = url
self.hs_token = hs_token
self.sender = sender
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.txn_id = txn_id
def _check_namespaces(self, namespaces):
# Sanity check that it is of the form:
@@ -95,7 +51,7 @@ class ApplicationService(object):
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
# }
if not namespaces:
namespaces = {}
return None
for ns in ApplicationService.NS_LIST:
if ns not in namespaces:
@@ -148,8 +104,8 @@ class ApplicationService(object):
and self.is_interested_in_user(event.state_key)):
return True
# check joined member events
for user_id in member_list:
if self.is_interested_in_user(user_id):
for member in member_list:
if self.is_interested_in_user(member.state_key):
return True
return False
@@ -173,7 +129,7 @@ class ApplicationService(object):
restrict_to(str): The namespace to restrict regex tests to.
aliases_for_event(list): A list of all the known room aliases for
this event.
member_list(list): A list of all joined user_ids in this room.
member_list(list): A list of all joined room members in this room.
Returns:
bool: True if this service would like to know about this event.
"""
@@ -199,10 +155,7 @@ class ApplicationService(object):
return self._matches_user(event, member_list)
def is_interested_in_user(self, user_id):
return (
self._matches_regex(user_id, ApplicationService.NS_USERS)
or user_id == self.sender
)
return self._matches_regex(user_id, ApplicationService.NS_USERS)
def is_interested_in_alias(self, alias):
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
@@ -211,10 +164,7 @@ class ApplicationService(object):
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
def is_exclusive_user(self, user_id):
return (
self._is_exclusive(ApplicationService.NS_USERS, user_id)
or user_id == self.sender
)
return self._is_exclusive(ApplicationService.NS_USERS, user_id)
def is_exclusive_alias(self, alias):
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)

View File

@@ -72,19 +72,14 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
def push_bulk(self, service, events, txn_id=None):
def push_bulk(self, service, events):
events = self._serialize(events)
if txn_id is None:
logger.warning("push_bulk: Missing txn ID sending events to %s",
service.url)
txn_id = str(0)
txn_id = str(txn_id)
uri = service.url + ("/transactions/%s" %
urllib.quote(txn_id))
urllib.quote(str(0))) # TODO txn_ids
response = None
try:
yield self.put_json(
response = yield self.put_json(
uri=uri,
json_body={
"events": events
@@ -92,8 +87,9 @@ class ApplicationServiceApi(SimpleHttpClient):
args={
"access_token": service.hs_token
})
defer.returnValue(True)
return
if response: # just an empty json object
# TODO: Mark txn as sent successfully
defer.returnValue(True)
except CodeMessageException as e:
logger.warning("push_bulk to %s received %s", uri, e.code)
except Exception as ex:
@@ -101,8 +97,8 @@ class ApplicationServiceApi(SimpleHttpClient):
defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event, txn_id=None):
response = yield self.push_bulk(service, [event], txn_id)
def push(self, service, event):
response = yield self.push_bulk(service, [event])
defer.returnValue(response)
def _serialize(self, events):

View File

@@ -1,254 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module controls the reliability for application service transactions.
The nominal flow through this module looks like:
__________
1---ASa[e]-->| Service |--> Queue ASa[f]
2----ASb[e]->| Queuer |
3--ASa[f]--->|__________|-----------+ ASa[e], ASb[e]
V
-````````- +------------+
|````````|<--StoreTxn-|Transaction |
|Database| | Controller |---> SEND TO AS
`--------` +------------+
What happens on SEND TO AS depends on the state of the Application Service:
- If the AS is marked as DOWN, do nothing.
- If the AS is marked as UP, send the transaction.
* SUCCESS : Increment where the AS is up to txn-wise and nuke the txn
contents from the db.
* FAILURE : Marked AS as DOWN and start Recoverer.
Recoverer attempts to recover ASes who have died. The flow for this looks like:
,--------------------- backoff++ --------------.
V |
START ---> Wait exp ------> Get oldest txn ID from ----> FAILURE
backoff DB and try to send it
^ |___________
Mark AS as | V
UP & quit +---------- YES SUCCESS
| | |
NO <--- Have more txns? <------ Mark txn success & nuke <-+
from db; incr AS pos.
Reset backoff.
This is all tied together by the AppServiceScheduler which DIs the required
components.
"""
from synapse.appservice import ApplicationServiceState
from twisted.internet import defer
import logging
logger = logging.getLogger(__name__)
class AppServiceScheduler(object):
""" Public facing API for this module. Does the required DI to tie the
components together. This also serves as the "event_pool", which in this
case is a simple array.
"""
def __init__(self, clock, store, as_api):
self.clock = clock
self.store = store
self.as_api = as_api
def create_recoverer(service, callback):
return _Recoverer(clock, store, as_api, service, callback)
self.txn_ctrl = _TransactionController(
clock, store, as_api, create_recoverer
)
self.queuer = _ServiceQueuer(self.txn_ctrl)
@defer.inlineCallbacks
def start(self):
logger.info("Starting appservice scheduler")
# check for any DOWN ASes and start recoverers for them.
recoverers = yield _Recoverer.start(
self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
)
self.txn_ctrl.add_recoverers(recoverers)
def submit_event_for_as(self, service, event):
self.queuer.enqueue(service, event)
class _ServiceQueuer(object):
"""Queues events for the same application service together, sending
transactions as soon as possible. Once a transaction is sent successfully,
this schedules any other events in the queue to run.
"""
def __init__(self, txn_ctrl):
self.queued_events = {} # dict of {service_id: [events]}
self.pending_requests = {} # dict of {service_id: Deferred}
self.txn_ctrl = txn_ctrl
def enqueue(self, service, event):
# if this service isn't being sent something
if not self.pending_requests.get(service.id):
self._send_request(service, [event])
else:
# add to queue for this service
if service.id not in self.queued_events:
self.queued_events[service.id] = []
self.queued_events[service.id].append(event)
def _send_request(self, service, events):
# send request and add callbacks
d = self.txn_ctrl.send(service, events)
d.addBoth(self._on_request_finish)
d.addErrback(self._on_request_fail)
self.pending_requests[service.id] = d
def _on_request_finish(self, service):
self.pending_requests[service.id] = None
# if there are queued events, then send them.
if (service.id in self.queued_events
and len(self.queued_events[service.id]) > 0):
self._send_request(service, self.queued_events[service.id])
self.queued_events[service.id] = []
def _on_request_fail(self, err):
logger.error("AS request failed: %s", err)
class _TransactionController(object):
def __init__(self, clock, store, as_api, recoverer_fn):
self.clock = clock
self.store = store
self.as_api = as_api
self.recoverer_fn = recoverer_fn
# keep track of how many recoverers there are
self.recoverers = []
@defer.inlineCallbacks
def send(self, service, events):
try:
txn = yield self.store.create_appservice_txn(
service=service,
events=events
)
service_is_up = yield self._is_service_up(service)
if service_is_up:
sent = yield txn.send(self.as_api)
if sent:
txn.complete(self.store)
else:
self._start_recoverer(service)
except Exception as e:
logger.exception(e)
self._start_recoverer(service)
# request has finished
defer.returnValue(service)
@defer.inlineCallbacks
def on_recovered(self, recoverer):
self.recoverers.remove(recoverer)
logger.info("Successfully recovered application service AS ID %s",
recoverer.service.id)
logger.info("Remaining active recoverers: %s", len(self.recoverers))
yield self.store.set_appservice_state(
recoverer.service,
ApplicationServiceState.UP
)
def add_recoverers(self, recoverers):
for r in recoverers:
self.recoverers.append(r)
if len(recoverers) > 0:
logger.info("New active recoverers: %s", len(self.recoverers))
@defer.inlineCallbacks
def _start_recoverer(self, service):
yield self.store.set_appservice_state(
service,
ApplicationServiceState.DOWN
)
logger.info(
"Application service falling behind. Starting recoverer. AS ID %s",
service.id
)
recoverer = self.recoverer_fn(service, self.on_recovered)
self.add_recoverers([recoverer])
recoverer.recover()
@defer.inlineCallbacks
def _is_service_up(self, service):
state = yield self.store.get_appservice_state(service)
defer.returnValue(state == ApplicationServiceState.UP or state is None)
class _Recoverer(object):
@staticmethod
@defer.inlineCallbacks
def start(clock, store, as_api, callback):
services = yield store.get_appservices_by_state(
ApplicationServiceState.DOWN
)
recoverers = [
_Recoverer(clock, store, as_api, s, callback) for s in services
]
for r in recoverers:
logger.info("Starting recoverer for AS ID %s which was marked as "
"DOWN", r.service.id)
r.recover()
defer.returnValue(recoverers)
def __init__(self, clock, store, as_api, service, callback):
self.clock = clock
self.store = store
self.as_api = as_api
self.service = service
self.callback = callback
self.backoff_counter = 1
def recover(self):
self.clock.call_later((2 ** self.backoff_counter), self.retry)
def _backoff(self):
# cap the backoff to be around 18h => (2^16) = 65536 secs
if self.backoff_counter < 16:
self.backoff_counter += 1
self.recover()
@defer.inlineCallbacks
def retry(self):
try:
txn = yield self.store.get_oldest_unsent_txn(self.service)
if txn:
logger.info("Retrying transaction %s for AS ID %s",
txn.id, txn.service.id)
sent = yield txn.send(self.as_api)
if sent:
yield txn.complete(self.store)
# reset the backoff counter and retry immediately
self.backoff_counter = 1
yield self.retry()
else:
self._backoff()
else:
self._set_service_recovered()
except Exception as e:
logger.exception(e)
self._backoff()
def _set_service_recovered(self):
self.callback(self)

View File

@@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if __name__ == "__main__":
import sys
from homeserver import HomeServerConfig
action = sys.argv[1]
if action == "read":
key = sys.argv[2]
config = HomeServerConfig.load_config("", sys.argv[3:])
print getattr(config, key)
sys.exit(0)
else:
sys.stderr.write("Unknown command %r\n" % (action,))
sys.exit(1)

View File

@@ -14,10 +14,9 @@
# limitations under the License.
import argparse
import sys
import os
import yaml
import sys
from textwrap import dedent
class ConfigError(Exception):
@@ -25,45 +24,18 @@ class ConfigError(Exception):
class Config(object):
stats_reporting_begging_spiel = (
"We would really appreciate it if you could help our project out by"
" reporting anonymized usage statistics from your homeserver. Only very"
" basic aggregate data (e.g. number of users) will be reported, but it"
" helps us to track the growth of the Matrix community, and helps us to"
" make Matrix a success, as well as to convince other networks that they"
" should peer with us."
"\nThank you."
)
def __init__(self, args):
pass
@staticmethod
def parse_size(value):
if isinstance(value, int) or isinstance(value, long):
return value
def parse_size(string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = value[-1]
suffix = string[-1]
if suffix in sizes:
value = value[:-1]
string = string[:-1]
size = sizes[suffix]
return int(value) * size
@staticmethod
def parse_duration(value):
if isinstance(value, int) or isinstance(value, long):
return value
second = 1000
hour = 60 * 60 * second
day = 24 * hour
week = 7 * day
year = 365 * day
sizes = {"s": second, "h": hour, "d": day, "w": week, "y": year}
size = 1
suffix = value[-1]
if suffix in sizes:
value = value[:-1]
size = sizes[suffix]
return int(value) * size
return int(string) * size
@staticmethod
def abspath(file_path):
@@ -114,200 +86,83 @@ class Config(object):
with open(file_path) as file_stream:
return yaml.load(file_stream)
def invoke_all(self, name, *args, **kargs):
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
@classmethod
def add_arguments(cls, parser):
pass
def generate_config(self, config_dir_path, server_name, report_stats=None):
default_config = "# vim:ft=yaml\n"
default_config += "\n\n".join(dedent(conf) for conf in self.invoke_all(
"default_config",
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=report_stats,
))
config = yaml.load(default_config)
return default_config, config
@classmethod
def generate_config(cls, args, config_dir_path):
pass
@classmethod
def load_config(cls, description, argv, generate_section=None):
obj = cls()
config_parser = argparse.ArgumentParser(add_help=False)
config_parser.add_argument(
"-c", "--config-path",
action="append",
metavar="CONFIG_FILE",
help="Specify config file. Can be given multiple times and"
" may specify directories containing *.yaml files."
help="Specify config file"
)
config_parser.add_argument(
"--generate-config",
action="store_true",
help="Generate a config file for the server name"
)
config_parser.add_argument(
"--report-stats",
action="store",
help="Stuff",
choices=["yes", "no"]
)
config_parser.add_argument(
"--generate-keys",
action="store_true",
help="Generate any missing key files then exit"
)
config_parser.add_argument(
"--keys-directory",
metavar="DIRECTORY",
help="Used with 'generate-*' options to specify where files such as"
" certs and signing keys should be stored in, unless explicitly"
" specified in the config."
)
config_parser.add_argument(
"-H", "--server-name",
help="The server name to generate a config file for"
help="Generate config file"
)
config_args, remaining_args = config_parser.parse_known_args(argv)
generate_keys = config_args.generate_keys
config_files = []
if config_args.config_path:
for config_path in config_args.config_path:
if os.path.isdir(config_path):
# We accept specifying directories as config paths, we search
# inside that directory for all files matching *.yaml, and then
# we apply them in *sorted* order.
files = []
for entry in os.listdir(config_path):
entry_path = os.path.join(config_path, entry)
if not os.path.isfile(entry_path):
print (
"Found subdirectory in config directory: %r. IGNORING."
) % (entry_path, )
continue
if not entry.endswith(".yaml"):
print (
"Found file in config directory that does not"
" end in '.yaml': %r. IGNORING."
) % (entry_path, )
continue
files.append(entry_path)
config_files.extend(sorted(files))
else:
config_files.append(config_path)
if config_args.generate_config:
if config_args.report_stats is None:
if not config_args.config_path:
config_parser.error(
"Please specify either --report-stats=yes or --report-stats=no\n\n" +
cls.stats_reporting_begging_spiel
"Must specify where to generate the config file"
)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
)
(config_path,) = config_files
if not os.path.exists(config_path):
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_path)
config_dir_path = os.path.abspath(config_dir_path)
server_name = config_args.server_name
if not server_name:
print "Must specify a server_name to a generate config for."
sys.exit(1)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "wb") as config_file:
config_bytes, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name,
report_stats=(config_args.report_stats == "yes"),
)
obj.invoke_all("generate_files", config)
config_file.write(config_bytes)
print (
"A config file has been generated in %r for server name"
" %r with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it"
" to your needs."
) % (config_path, server_name)
print (
"If this server name is incorrect, you will need to"
" regenerate the SSL certificates"
)
sys.exit(0)
config_dir_path = os.path.dirname(config_args.config_path)
if os.path.exists(config_args.config_path):
defaults = cls.read_config_file(config_args.config_path)
else:
print (
"Config file %r already exists. Generating any missing key"
" files."
) % (config_path,)
generate_keys = True
defaults = {}
else:
if config_args.config_path:
defaults = cls.read_config_file(config_args.config_path)
else:
defaults = {}
parser = argparse.ArgumentParser(
parents=[config_parser],
description=description,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
cls.add_arguments(parser)
parser.set_defaults(**defaults)
obj.invoke_all("add_arguments", parser)
args = parser.parse_args(remaining_args)
if not config_files:
config_parser.error(
"Must supply a config file.\nA config file can be automatically"
" generated using \"--generate-config -H SERVER_NAME"
" -c CONFIG-FILE\""
if config_args.generate_config:
config_dir_path = os.path.dirname(config_args.config_path)
config_dir_path = os.path.abspath(config_dir_path)
if not os.path.exists(config_dir_path):
os.makedirs(config_dir_path)
cls.generate_config(args, config_dir_path)
config = {}
for key, value in vars(args).items():
if (key not in set(["config_path", "generate_config"])
and value is not None):
config[key] = value
with open(config_args.config_path, "w") as config_file:
# TODO(paul) it would be lovely if we wrote out vim- and emacs-
# style mode markers into the file, to hint to people that
# this is a YAML file.
yaml.dump(config, config_file, default_flow_style=False)
print (
"A config file has been generated in %s for server name"
" '%s' with corresponding SSL keys and self-signed"
" certificates. Please review this file and customise it to"
" your needs."
) % (
config_args.config_path, config['server_name']
)
print (
"If this server name is incorrect, you will need to regenerate"
" the SSL certificates"
)
if config_args.keys_directory:
config_dir_path = config_args.keys_directory
else:
config_dir_path = os.path.dirname(config_args.config_path[-1])
config_dir_path = os.path.abspath(config_dir_path)
specified_config = {}
for config_file in config_files:
yaml_config = cls.read_config_file(config_file)
specified_config.update(yaml_config)
server_name = specified_config["server_name"]
_, config = obj.generate_config(
config_dir_path=config_dir_path,
server_name=server_name
)
config.pop("log_config")
config.update(specified_config)
if "report_stats" not in config:
sys.stderr.write(
"Please opt in or out of reporting anonymized homeserver usage "
"statistics, by setting the report_stats key in your config file "
" ( " + config_path + " ) " +
"to either True or False.\n\n" +
Config.stats_reporting_begging_spiel + "\n")
sys.exit(1)
if generate_keys:
obj.invoke_all("generate_files", config)
sys.exit(0)
obj.invoke_all("read_config", config)
obj.invoke_all("read_arguments", args)
return obj
return cls(args)

View File

@@ -1,27 +0,0 @@
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class AppServiceConfig(Config):
def read_config(self, config):
self.app_service_config_files = config.get("app_service_config_files", [])
def default_config(cls, **kwargs):
return """\
# A list of application service config file to use
app_service_config_files: []
"""

View File

@@ -17,31 +17,35 @@ from ._base import Config
class CaptchaConfig(Config):
def read_config(self, config):
self.recaptcha_private_key = config["recaptcha_private_key"]
self.recaptcha_public_key = config["recaptcha_public_key"]
self.enable_registration_captcha = config["enable_registration_captcha"]
self.captcha_bypass_secret = config.get("captcha_bypass_secret")
self.recaptcha_siteverify_api = config["recaptcha_siteverify_api"]
def __init__(self, args):
super(CaptchaConfig, self).__init__(args)
self.recaptcha_private_key = args.recaptcha_private_key
self.enable_registration_captcha = args.enable_registration_captcha
self.captcha_ip_origin_is_x_forwarded = (
args.captcha_ip_origin_is_x_forwarded
)
self.captcha_bypass_secret = args.captcha_bypass_secret
def default_config(self, **kwargs):
return """\
## Captcha ##
# This Home Server's ReCAPTCHA public key.
recaptcha_private_key: "YOUR_PRIVATE_KEY"
# This Home Server's ReCAPTCHA private key.
recaptcha_public_key: "YOUR_PUBLIC_KEY"
# Enables ReCaptcha checks when registering, preventing signup
# unless a captcha is answered. Requires a valid ReCaptcha
# public/private key.
enable_registration_captcha: False
# A secret key used to bypass the captcha test entirely.
#captcha_bypass_secret: "YOUR_SECRET_HERE"
# The API endpoint to use for verifying m.login.recaptcha responses.
recaptcha_siteverify_api: "https://www.google.com/recaptcha/api/siteverify"
"""
@classmethod
def add_arguments(cls, parser):
super(CaptchaConfig, cls).add_arguments(parser)
group = parser.add_argument_group("recaptcha")
group.add_argument(
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
help="The matching private key for the web client's public key."
)
group.add_argument(
"--enable-registration-captcha", type=bool, default=False,
help="Enables ReCaptcha checks when registering, preventing signup"
+ " unless a captcha is answered. Requires a valid ReCaptcha "
+ "public/private key."
)
group.add_argument(
"--captcha_ip_origin_is_x_forwarded", type=bool, default=False,
help="When checking captchas, use the X-Forwarded-For (XFF) header"
+ " as the client IP and not the actual client IP."
)
group.add_argument(
"--captcha_bypass_secret", type=str,
help="A secret key used to bypass the captcha test entirely."
)

View File

@@ -14,66 +14,32 @@
# limitations under the License.
from ._base import Config
import os
class DatabaseConfig(Config):
def read_config(self, config):
self.event_cache_size = self.parse_size(
config.get("event_cache_size", "10K")
)
self.database_config = config.get("database")
if self.database_config is None:
self.database_config = {
"name": "sqlite3",
"args": {},
}
name = self.database_config.get("name", None)
if name == "psycopg2":
pass
elif name == "sqlite3":
self.database_config.setdefault("args", {}).update({
"cp_min": 1,
"cp_max": 1,
"check_same_thread": False,
})
def __init__(self, args):
super(DatabaseConfig, self).__init__(args)
if args.database_path == ":memory:":
self.database_path = ":memory:"
else:
raise RuntimeError("Unsupported database type '%s'" % (name,))
self.database_path = self.abspath(args.database_path)
self.event_cache_size = self.parse_size(args.event_cache_size)
self.set_databasepath(config.get("database_path"))
def default_config(self, **kwargs):
database_path = self.abspath("homeserver.db")
return """\
# Database configuration
database:
# The database engine name
name: "sqlite3"
# Arguments to pass to the engine
args:
# Path to the database
database: "%(database_path)s"
# Number of events to cache in memory.
event_cache_size: "10K"
""" % locals()
def read_arguments(self, args):
self.set_databasepath(args.database_path)
def set_databasepath(self, database_path):
if database_path != ":memory:":
database_path = self.abspath(database_path)
if self.database_config.get("name", None) == "sqlite3":
if database_path is not None:
self.database_config["args"]["database"] = database_path
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(DatabaseConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("database")
db_group.add_argument(
"-d", "--database-path", metavar="SQLITE_DATABASE_PATH",
help="The path to a sqlite database to use."
"-d", "--database-path", default="homeserver.db",
help="The database name."
)
db_group.add_argument(
"--event-cache-size", default="100K",
help="Number of events to cache in memory."
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(DatabaseConfig, cls).generate_config(args, config_dir_path)
args.database_path = os.path.abspath(args.database_path)

42
synapse/config/email.py Normal file
View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class EmailConfig(Config):
def __init__(self, args):
super(EmailConfig, self).__init__(args)
self.email_from_address = args.email_from_address
self.email_smtp_server = args.email_smtp_server
@classmethod
def add_arguments(cls, parser):
super(EmailConfig, cls).add_arguments(parser)
email_group = parser.add_argument_group("email")
email_group.add_argument(
"--email-from-address",
default="FROM@EXAMPLE.COM",
help="The address to send emails from (e.g. for password resets)."
)
email_group.add_argument(
"--email-smtp-server",
default="",
help=(
"The SMTP server to send emails from (e.g. for password"
" resets)."
)
)

View File

@@ -20,23 +20,17 @@ from .database import DatabaseConfig
from .ratelimiting import RatelimitConfig
from .repository import ContentRepositoryConfig
from .captcha import CaptchaConfig
from .email import EmailConfig
from .voip import VoipConfig
from .registration import RegistrationConfig
from .metrics import MetricsConfig
from .appservice import AppServiceConfig
from .key import KeyConfig
from .saml2 import SAML2Config
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig,
AppServiceConfig, KeyConfig, SAML2Config, ):
EmailConfig, VoipConfig, RegistrationConfig,):
pass
if __name__ == '__main__':
import sys
sys.stdout.write(
HomeServerConfig().generate_config(sys.argv[1], sys.argv[2])[0]
)
HomeServerConfig.load_config("Generate config", sys.argv[1:], "HomeServer")

View File

@@ -1,130 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config, ConfigError
from synapse.util.stringutils import random_string
from signedjson.key import (
generate_signing_key, is_signing_algorithm_supported,
decode_signing_key_base64, decode_verify_key_bytes,
read_signing_keys, write_signing_keys, NACL_ED25519
)
from unpaddedbase64 import decode_base64
import os
class KeyConfig(Config):
def read_config(self, config):
self.signing_key = self.read_signing_key(config["signing_key_path"])
self.old_signing_keys = self.read_old_signing_keys(
config["old_signing_keys"]
)
self.key_refresh_interval = self.parse_duration(
config["key_refresh_interval"]
)
self.perspectives = self.read_perspectives(
config["perspectives"]
)
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
return """\
## Signing Keys ##
# Path to the signing key to sign messages with
signing_key_path: "%(base_key_name)s.signing.key"
# The keys that the server used to sign messages with but won't use
# to sign new messages. E.g. it has lost its private key
old_signing_keys: {}
# "ed25519:auto":
# # Base64 encoded public key
# key: "The public part of your old signing key."
# # Millisecond POSIX timestamp when the key expired.
# expired_ts: 123456789123
# How long key response published by this server is valid for.
# Used to set the valid_until_ts in /key/v2 APIs.
# Determines how quickly servers will query to check which keys
# are still valid.
key_refresh_interval: "1d" # 1 Day.
# The trusted servers to download signing keys from.
perspectives:
servers:
"matrix.org":
verify_keys:
"ed25519:auto":
key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"
""" % locals()
def read_perspectives(self, perspectives_config):
servers = {}
for server_name, server_config in perspectives_config["servers"].items():
for key_id, key_data in server_config["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
servers.setdefault(server_name, {})[key_id] = verify_key
return servers
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return read_signing_keys(signing_keys.splitlines(True))
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
def read_old_signing_keys(self, old_signing_keys):
keys = {}
for key_id, key_data in old_signing_keys.items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired_ts = key_data["expired_ts"]
keys[key_id] = verify_key
else:
raise ConfigError(
"Unsupported signing algorithm for old key: %r" % (key_id,)
)
return keys
def generate_files(self, config):
signing_key_path = config["signing_key_path"]
if not os.path.exists(signing_key_path):
with open(signing_key_path, "w") as signing_key_file:
key_id = "a_" + random_string(4)
write_signing_keys(
signing_key_file, (generate_signing_key(key_id),),
)
else:
signing_keys = self.read_file(signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key_id = "a_" + random_string(4)
key = decode_signing_key_base64(
NACL_ED25519, key_id, signing_keys.split("\n")[0]
)
with open(signing_key_path, "w") as signing_key_file:
write_signing_keys(
signing_key_file, (key,),
)

View File

@@ -19,89 +19,25 @@ from twisted.python.log import PythonLoggingObserver
import logging
import logging.config
import yaml
from string import Template
import os
import signal
DEFAULT_LOG_CONFIG = Template("""
version: 1
formatters:
precise:
format: '%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s\
- %(message)s'
filters:
context:
(): synapse.util.logcontext.LoggingContextFilter
request: ""
handlers:
file:
class: logging.handlers.RotatingFileHandler
formatter: precise
filename: ${log_file}
maxBytes: 104857600
backupCount: 10
filters: [context]
level: INFO
console:
class: logging.StreamHandler
formatter: precise
loggers:
synapse:
level: INFO
synapse.storage.SQL:
level: INFO
root:
level: INFO
handlers: [file, console]
""")
class LoggingConfig(Config):
def __init__(self, args):
super(LoggingConfig, self).__init__(args)
self.verbosity = int(args.verbose) if args.verbose else None
self.log_config = self.abspath(args.log_config)
self.log_file = self.abspath(args.log_file)
def read_config(self, config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
log_config = self.abspath(
os.path.join(config_dir_path, server_name + ".log.config")
)
return """
# Logging verbosity level.
verbose: 0
# File to write logging to
log_file: "%(log_file)s"
# A yaml python logging config file
log_config: "%(log_config)s"
""" % locals()
def read_arguments(self, args):
if args.verbose is not None:
self.verbosity = args.verbose
if args.log_config is not None:
self.log_config = args.log_config
if args.log_file is not None:
self.log_file = args.log_file
@classmethod
def add_arguments(cls, parser):
super(LoggingConfig, cls).add_arguments(parser)
logging_group = parser.add_argument_group("logging")
logging_group.add_argument(
'-v', '--verbose', dest="verbose", action='count',
help="The verbosity level."
)
logging_group.add_argument(
'-f', '--log-file', dest="log_file",
'-f', '--log-file', dest="log_file", default="homeserver.log",
help="File to log to."
)
logging_group.add_argument(
@@ -109,14 +45,6 @@ class LoggingConfig(Config):
help="Python logging config file"
)
def generate_files(self, config):
log_config = config.get("log_config")
if log_config and not os.path.exists(log_config):
with open(log_config, "wb") as log_config_file:
log_config_file.write(
DEFAULT_LOG_CONFIG.substitute(log_file=config["log_file"])
)
def setup_logging(self):
log_format = (
"%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(request)s"
@@ -143,19 +71,6 @@ class LoggingConfig(Config):
handler = logging.handlers.RotatingFileHandler(
self.log_file, maxBytes=(1000 * 1000 * 100), backupCount=3
)
def sighup(signum, stack):
logger.info("Closing log file due to SIGHUP")
handler.doRollover()
logger.info("Opened new log file due to SIGHUP")
# TODO(paul): obviously this is a terrible mechanism for
# stealing SIGHUP, because it means no other part of synapse
# can use it instead. If we want to catch SIGHUP anywhere
# else as well, I'd suggest we find a nicer way to broadcast
# it around.
if getattr(signal, "SIGHUP"):
signal.signal(signal.SIGHUP, sighup)
else:
handler = logging.StreamHandler()
handler.setFormatter(formatter)
@@ -163,6 +78,7 @@ class LoggingConfig(Config):
handler.addFilter(LoggingContextFilter(request=""))
logger.addHandler(handler)
logger.info("Test")
else:
with open(self.log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))

View File

@@ -1,33 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class MetricsConfig(Config):
def read_config(self, config):
self.enable_metrics = config["enable_metrics"]
self.report_stats = config.get("report_stats", None)
self.metrics_port = config.get("metrics_port")
self.metrics_bind_host = config.get("metrics_bind_host", "127.0.0.1")
def default_config(self, report_stats=None, **kwargs):
suffix = "" if report_stats is None else "report_stats: %(report_stats)s\n"
return ("""\
## Metrics ###
# Enable collection and rendering of performance metrics
enable_metrics: False
""" + suffix) % locals()

View File

@@ -17,42 +17,56 @@ from ._base import Config
class RatelimitConfig(Config):
def read_config(self, config):
self.rc_messages_per_second = config["rc_messages_per_second"]
self.rc_message_burst_count = config["rc_message_burst_count"]
def __init__(self, args):
super(RatelimitConfig, self).__init__(args)
self.rc_messages_per_second = args.rc_messages_per_second
self.rc_message_burst_count = args.rc_message_burst_count
self.federation_rc_window_size = config["federation_rc_window_size"]
self.federation_rc_sleep_limit = config["federation_rc_sleep_limit"]
self.federation_rc_sleep_delay = config["federation_rc_sleep_delay"]
self.federation_rc_reject_limit = config["federation_rc_reject_limit"]
self.federation_rc_concurrent = config["federation_rc_concurrent"]
self.federation_rc_window_size = args.federation_rc_window_size
self.federation_rc_sleep_limit = args.federation_rc_sleep_limit
self.federation_rc_sleep_delay = args.federation_rc_sleep_delay
self.federation_rc_reject_limit = args.federation_rc_reject_limit
self.federation_rc_concurrent = args.federation_rc_concurrent
def default_config(self, **kwargs):
return """\
## Ratelimiting ##
@classmethod
def add_arguments(cls, parser):
super(RatelimitConfig, cls).add_arguments(parser)
rc_group = parser.add_argument_group("ratelimiting")
rc_group.add_argument(
"--rc-messages-per-second", type=float, default=0.2,
help="number of messages a client can send per second"
)
rc_group.add_argument(
"--rc-message-burst-count", type=float, default=10,
help="number of message a client can send before being throttled"
)
# Number of messages a client can send per second
rc_messages_per_second: 0.2
rc_group.add_argument(
"--federation-rc-window-size", type=int, default=10000,
help="The federation window size in milliseconds",
)
# Number of message a client can send before being throttled
rc_message_burst_count: 10.0
rc_group.add_argument(
"--federation-rc-sleep-limit", type=int, default=10,
help="The number of federation requests from a single server"
" in a window before the server will delay processing the"
" request.",
)
# The federation window size in milliseconds
federation_rc_window_size: 1000
rc_group.add_argument(
"--federation-rc-sleep-delay", type=int, default=500,
help="The duration in milliseconds to delay processing events from"
" remote servers by if they go over the sleep limit.",
)
# The number of federation requests from a single server in a window
# before the server will delay processing the request.
federation_rc_sleep_limit: 10
rc_group.add_argument(
"--federation-rc-reject-limit", type=int, default=50,
help="The maximum number of concurrent federation requests allowed"
" from a single server",
)
# The duration in milliseconds to delay processing events from
# remote servers by if they go over the sleep limit.
federation_rc_sleep_delay: 500
# The maximum number of concurrent federation requests allowed
# from a single server
federation_rc_reject_limit: 50
# The number of federation requests to concurrently process from a
# single server
federation_rc_concurrent: 3
"""
rc_group.add_argument(
"--federation-rc-concurrent", type=int, default=3,
help="The number of federation requests to concurrently process"
" from a single server",
)

View File

@@ -15,50 +15,19 @@
from ._base import Config
from synapse.util.stringutils import random_string_with_symbols
from distutils.util import strtobool
class RegistrationConfig(Config):
def read_config(self, config):
self.disable_registration = not bool(
strtobool(str(config["enable_registration"]))
)
if "disable_registration" in config:
self.disable_registration = bool(
strtobool(str(config["disable_registration"]))
)
def __init__(self, args):
super(RegistrationConfig, self).__init__(args)
self.disable_registration = args.disable_registration
self.registration_shared_secret = config.get("registration_shared_secret")
self.macaroon_secret_key = config.get("macaroon_secret_key")
def default_config(self, **kwargs):
registration_shared_secret = random_string_with_symbols(50)
macaroon_secret_key = random_string_with_symbols(50)
return """\
## Registration ##
# Enable registration for new users.
enable_registration: False
# If set, allows registration by anyone who also has the shared
# secret, even if registration is otherwise disabled.
registration_shared_secret: "%(registration_shared_secret)s"
macaroon_secret_key: "%(macaroon_secret_key)s"
""" % locals()
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(RegistrationConfig, cls).add_arguments(parser)
reg_group = parser.add_argument_group("registration")
reg_group.add_argument(
"--enable-registration", action="store_true", default=None,
help="Enable registration for new users."
"--disable-registration",
action='store_true',
help="Disable registration of new users."
)
def read_arguments(self, args):
if args.enable_registration is not None:
self.disable_registration = not bool(
strtobool(str(args.enable_registration))
)

View File

@@ -14,87 +14,35 @@
# limitations under the License.
from ._base import Config
from collections import namedtuple
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumnailing
method, and thumbnail media type to precalculate
Args:
thumbnail_sizes(list): List of dicts with "width", "height", and
"method" keys
Returns:
Dictionary mapping from media type string to list of
ThumbnailRequirement tuples.
"""
requirements = {}
for size in thumbnail_sizes:
width = size["width"]
height = size["height"]
method = size["method"]
jpeg_thumbnail = ThumbnailRequirement(width, height, method, "image/jpeg")
png_thumbnail = ThumbnailRequirement(width, height, method, "image/png")
requirements.setdefault("image/jpeg", []).append(jpeg_thumbnail)
requirements.setdefault("image/gif", []).append(png_thumbnail)
requirements.setdefault("image/png", []).append(png_thumbnail)
return {
media_type: tuple(thumbnails)
for media_type, thumbnails in requirements.items()
}
class ContentRepositoryConfig(Config):
def read_config(self, config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"]
def __init__(self, args):
super(ContentRepositoryConfig, self).__init__(args)
self.max_upload_size = self.parse_size(args.max_upload_size)
self.max_image_pixels = self.parse_size(args.max_image_pixels)
self.media_store_path = self.ensure_directory(args.media_store_path)
def parse_size(self, string):
sizes = {"K": 1024, "M": 1024 * 1024}
size = 1
suffix = string[-1]
if suffix in sizes:
string = string[:-1]
size = sizes[suffix]
return int(string) * size
@classmethod
def add_arguments(cls, parser):
super(ContentRepositoryConfig, cls).add_arguments(parser)
db_group = parser.add_argument_group("content_repository")
db_group.add_argument(
"--max-upload-size", default="10M"
)
db_group.add_argument(
"--media-store-path", default=cls.default_path("media_store")
)
db_group.add_argument(
"--max-image-pixels", default="32M",
help="Maximum number of pixels that will be thumbnailed"
)
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
uploads_path = self.default_path("uploads")
return """
# Directory where uploaded images and attachments are stored.
media_store_path: "%(media_store)s"
# Directory where in-progress uploads are stored.
uploads_path: "%(uploads_path)s"
# The largest allowed upload size in bytes
max_upload_size: "10M"
# Maximum number of pixels that will be thumbnailed
max_image_pixels: "32M"
# Whether to generate new thumbnails on the fly to precisely match
# the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalcualted list.
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
thumbnail_sizes:
- width: 32
height: 32
method: crop
- width: 96
height: 96
method: crop
- width: 320
height: 240
method: scale
- width: 640
height: 480
method: scale
""" % locals()

View File

@@ -1,54 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Ericsson
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class SAML2Config(Config):
"""SAML2 Configuration
Synapse uses pysaml2 libraries for providing SAML2 support
config_path: Path to the sp_conf.py configuration file
idp_redirect_url: Identity provider URL which will redirect
the user back to /login/saml2 with proper info.
sp_conf.py file is something like:
https://github.com/rohe/pysaml2/blob/master/example/sp-repoze/sp_conf.py.example
More information: https://pythonhosted.org/pysaml2/howto/config.html
"""
def read_config(self, config):
saml2_config = config.get("saml2_config", None)
if saml2_config:
self.saml2_enabled = True
self.saml2_config_path = saml2_config["config_path"]
self.saml2_idp_redirect_url = saml2_config["idp_redirect_url"]
else:
self.saml2_enabled = False
self.saml2_config_path = None
self.saml2_idp_redirect_url = None
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable SAML2 for registration and login. Uses pysaml2
# config_path: Path to the sp_conf.py configuration file
# idp_redirect_url: Identity provider URL which will redirect
# the user back to /login/saml2 with proper info.
# See pysaml2 docs for format of config.
#saml2_config:
# config_path: "%s/sp_conf.py"
# idp_redirect_url: "http://%s/idp"
""" % (config_dir_path, server_name)

View File

@@ -13,216 +13,114 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
import os
from ._base import Config, ConfigError
import syutil.crypto.signing_key
class ServerConfig(Config):
def __init__(self, args):
super(ServerConfig, self).__init__(args)
self.server_name = args.server_name
self.signing_key = self.read_signing_key(args.signing_key_path)
self.bind_port = args.bind_port
self.bind_host = args.bind_host
self.unsecure_port = args.unsecure_port
self.daemonize = args.daemonize
self.pid_file = self.abspath(args.pid_file)
self.webclient = True
self.manhole = args.manhole
self.soft_file_limit = args.soft_file_limit
def read_config(self, config):
self.server_name = config["server_name"]
self.pid_file = self.abspath(config.get("pid_file"))
self.web_client = config["web_client"]
self.web_client_location = config.get("web_client_location", None)
self.soft_file_limit = config["soft_file_limit"]
self.daemonize = config.get("daemonize")
self.print_pidfile = config.get("print_pidfile")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.listeners = config.get("listeners", [])
bind_port = config.get("bind_port")
if bind_port:
self.listeners = []
bind_host = config.get("bind_host", "")
gzip_responses = config.get("gzip_responses", True)
names = ["client", "webclient"] if self.web_client else ["client"]
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"tls": True,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
unsecure_port = config.get("unsecure_port", bind_port - 400)
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"tls": False,
"type": "http",
"resources": [
{
"names": names,
"compress": gzip_responses,
},
{
"names": ["federation"],
"compress": False,
}
]
})
manhole = config.get("manhole")
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"type": "manhole",
})
metrics_port = config.get("metrics_port")
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"tls": False,
"type": "http",
"resources": [
{
"names": ["metrics"],
"compress": False,
},
]
})
# Attempt to guess the content_addr for the v0 content repostitory
content_addr = config.get("content_addr")
if not content_addr:
for listener in self.listeners:
if listener["type"] == "http" and not listener.get("tls", False):
unsecure_port = listener["port"]
break
else:
raise RuntimeError("Could not determine 'content_addr'")
host = self.server_name
if not args.content_addr:
host = args.server_name
if ':' not in host:
host = "%s:%d" % (host, unsecure_port)
host = "%s:%d" % (host, args.unsecure_port)
else:
host = host.split(':')[0]
host = "%s:%d" % (host, unsecure_port)
content_addr = "http://%s" % (host,)
host = "%s:%d" % (host, args.unsecure_port)
args.content_addr = "http://%s" % (host,)
self.content_addr = content_addr
self.content_addr = args.content_addr
def default_config(self, server_name, **kwargs):
if ":" in server_name:
bind_port = int(server_name.split(":")[1])
unsecure_port = bind_port - 400
else:
bind_port = 8448
unsecure_port = 8008
pid_file = self.abspath("homeserver.pid")
return """\
## Server ##
# The domain name of the server, with optional explicit port.
# This is used by remote servers to connect to this server,
# e.g. matrix.org, localhost:8080, etc.
server_name: "%(server_name)s"
# When running as a daemon, the file to store the pid in
pid_file: %(pid_file)s
# Whether to serve a web client from the HTTP/HTTPS root resource.
web_client: True
# Set the soft limit on the number of file descriptors synapse can use
# Zero is used to indicate synapse should set the soft limit to the
# hard limit.
soft_file_limit: 0
# List of ports that Synapse should listen on, their purpose and their
# configuration.
listeners:
# Main HTTPS listener
# For when matrix traffic is sent directly to synapse.
-
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_address: ''
# This is a 'http' listener, allows us to specify 'resources'.
type: http
tls: true
# Use the X-Forwarded-For (XFF) header as the client IP and not the
# actual client IP.
x_forwarded: false
# List of HTTP resources to serve on this listener.
resources:
-
# List of resources to host on this listener.
names:
- client # The client-server APIs, both v1 and v2
- webclient # The bundled webclient.
# Should synapse compress HTTP responses to clients that support it?
# This should be disabled if running synapse behind a load balancer
# that can do automatic compression.
compress: true
- names: [federation] # Federation APIs
compress: false
# Unsecure HTTP listener,
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
type: http
x_forwarded: false
resources:
- names: [client, webclient]
compress: true
- names: [federation]
compress: false
# Turn on the twisted telnet manhole service on localhost on the given
# port.
# - port: 9000
# bind_address: 127.0.0.1
# type: manhole
""" % locals()
def read_arguments(self, args):
if args.manhole is not None:
self.manhole = args.manhole
if args.daemonize is not None:
self.daemonize = args.daemonize
if args.print_pidfile is not None:
self.print_pidfile = args.print_pidfile
def add_arguments(self, parser):
@classmethod
def add_arguments(cls, parser):
super(ServerConfig, cls).add_arguments(parser)
server_group = parser.add_argument_group("server")
server_group.add_argument(
"-H", "--server-name", default="localhost",
help="The domain name of the server, with optional explicit port. "
"This is used by remote servers to connect to this server, "
"e.g. matrix.org, localhost:8080, etc."
)
server_group.add_argument("--signing-key-path",
help="The signing key to sign messages with")
server_group.add_argument("-p", "--bind-port", metavar="PORT",
type=int, help="https port to listen on",
default=8448)
server_group.add_argument("--unsecure-port", metavar="PORT",
type=int, help="http port to listen on",
default=8008)
server_group.add_argument("--bind-host", default="",
help="Local interface to listen on")
server_group.add_argument("-D", "--daemonize", action='store_true',
default=None,
help="Daemonize the home server")
server_group.add_argument("--print-pidfile", action='store_true',
default=None,
help="Print the path to the pidfile just"
" before daemonizing")
server_group.add_argument('--pid-file', default="homeserver.pid",
help="When running as a daemon, the file to"
" store the pid in")
server_group.add_argument("--manhole", metavar="PORT", dest="manhole",
type=int,
help="Turn on the twisted telnet manhole"
" service on the given port.")
server_group.add_argument("--content-addr", default=None,
help="The host and scheme to use for the "
"content repository")
server_group.add_argument("--soft-file-limit", type=int, default=0,
help="Set the soft limit on the number of "
"file descriptors synapse can use. "
"Zero is used to indicate synapse "
"should set the soft limit to the hard"
"limit.")
def read_signing_key(self, signing_key_path):
signing_keys = self.read_file(signing_key_path, "signing_key")
try:
return syutil.crypto.signing_key.read_signing_keys(
signing_keys.splitlines(True)
)
except Exception:
raise ConfigError(
"Error reading signing_key."
" Try running again with --generate-config"
)
@classmethod
def generate_config(cls, args, config_dir_path):
super(ServerConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
args.pid_file = os.path.abspath(args.pid_file)
if not args.signing_key_path:
args.signing_key_path = base_key_name + ".signing.key"
if not os.path.exists(args.signing_key_path):
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(syutil.crypto.signing_key.generate_singing_key("auto"),),
)
else:
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
if len(signing_keys.split("\n")[0].split()) == 1:
# handle keys in the old format.
key = syutil.crypto.signing_key.decode_signing_key_base64(
syutil.crypto.signing_key.NACL_ED25519,
"auto",
signing_keys.split("\n")[0]
)
with open(args.signing_key_path, "w") as signing_key_file:
syutil.crypto.signing_key.write_signing_keys(
signing_key_file,
(key,),
)

View File

@@ -23,57 +23,37 @@ GENERATE_DH_PARAMS = False
class TlsConfig(Config):
def read_config(self, config):
def __init__(self, args):
super(TlsConfig, self).__init__(args)
self.tls_certificate = self.read_tls_certificate(
config.get("tls_certificate_path")
args.tls_certificate_path
)
self.tls_certificate_file = config.get("tls_certificate_path")
self.no_tls = config.get("no_tls", False)
self.no_tls = args.no_tls
if self.no_tls:
self.tls_private_key = None
else:
self.tls_private_key = self.read_tls_private_key(
config.get("tls_private_key_path")
args.tls_private_key_path
)
self.tls_dh_params_path = self.check_file(
config.get("tls_dh_params_path"), "tls_dh_params"
args.tls_dh_params_path, "tls_dh_params"
)
# This config option applies to non-federation HTTP clients
# (e.g. for talking to recaptcha, identity servers, and such)
# It should never be used in production, and is intended for
# use only when running tests.
self.use_insecure_ssl_client_just_for_testing_do_not_use = config.get(
"use_insecure_ssl_client_just_for_testing_do_not_use"
)
def default_config(self, config_dir_path, server_name, **kwargs):
base_key_name = os.path.join(config_dir_path, server_name)
tls_certificate_path = base_key_name + ".tls.crt"
tls_private_key_path = base_key_name + ".tls.key"
tls_dh_params_path = base_key_name + ".tls.dh"
return """\
# PEM encoded X509 certificate for TLS.
# You can replace the self-signed certificate that synapse
# autogenerates on launch with your own SSL certificate + key pair
# if you like. Any required intermediary certificates can be
# appended after the primary certificate in hierarchical order.
tls_certificate_path: "%(tls_certificate_path)s"
# PEM encoded private key for TLS
tls_private_key_path: "%(tls_private_key_path)s"
# PEM dh parameters for ephemeral keys
tls_dh_params_path: "%(tls_dh_params_path)s"
# Don't bind to the https port
no_tls: False
""" % locals()
@classmethod
def add_arguments(cls, parser):
super(TlsConfig, cls).add_arguments(parser)
tls_group = parser.add_argument_group("tls")
tls_group.add_argument("--tls-certificate-path",
help="PEM encoded X509 certificate for TLS")
tls_group.add_argument("--tls-private-key-path",
help="PEM encoded private key for TLS")
tls_group.add_argument("--tls-dh-params-path",
help="PEM dh parameters for ephemeral keys")
tls_group.add_argument("--no-tls", action='store_true',
help="Don't bind to the https port.")
def read_tls_certificate(self, cert_path):
cert_pem = self.read_file(cert_path, "tls_certificate")
@@ -83,13 +63,22 @@ class TlsConfig(Config):
private_key_pem = self.read_file(private_key_path, "tls_private_key")
return crypto.load_privatekey(crypto.FILETYPE_PEM, private_key_pem)
def generate_files(self, config):
tls_certificate_path = config["tls_certificate_path"]
tls_private_key_path = config["tls_private_key_path"]
tls_dh_params_path = config["tls_dh_params_path"]
@classmethod
def generate_config(cls, args, config_dir_path):
super(TlsConfig, cls).generate_config(args, config_dir_path)
base_key_name = os.path.join(config_dir_path, args.server_name)
if not os.path.exists(tls_private_key_path):
with open(tls_private_key_path, "w") as private_key_file:
if args.tls_certificate_path is None:
args.tls_certificate_path = base_key_name + ".tls.crt"
if args.tls_private_key_path is None:
args.tls_private_key_path = base_key_name + ".tls.key"
if args.tls_dh_params_path is None:
args.tls_dh_params_path = base_key_name + ".tls.dh"
if not os.path.exists(args.tls_private_key_path):
with open(args.tls_private_key_path, "w") as private_key_file:
tls_private_key = crypto.PKey()
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
private_key_pem = crypto.dump_privatekey(
@@ -97,17 +86,17 @@ class TlsConfig(Config):
)
private_key_file.write(private_key_pem)
else:
with open(tls_private_key_path) as private_key_file:
with open(args.tls_private_key_path) as private_key_file:
private_key_pem = private_key_file.read()
tls_private_key = crypto.load_privatekey(
crypto.FILETYPE_PEM, private_key_pem
)
if not os.path.exists(tls_certificate_path):
with open(tls_certificate_path, "w") as certificate_file:
if not os.path.exists(args.tls_certificate_path):
with open(args.tls_certificate_path, "w") as certifcate_file:
cert = crypto.X509()
subject = cert.get_subject()
subject.CN = config["server_name"]
subject.CN = args.server_name
cert.set_serial_number(1000)
cert.gmtime_adj_notBefore(0)
@@ -119,18 +108,18 @@ class TlsConfig(Config):
cert_pem = crypto.dump_certificate(crypto.FILETYPE_PEM, cert)
certificate_file.write(cert_pem)
certifcate_file.write(cert_pem)
if not os.path.exists(tls_dh_params_path):
if not os.path.exists(args.tls_dh_params_path):
if GENERATE_DH_PARAMS:
subprocess.check_call([
"openssl", "dhparam",
"-outform", "PEM",
"-out", tls_dh_params_path,
"-out", args.tls_dh_params_path,
"2048"
])
else:
with open(tls_dh_params_path, "w") as dh_params_file:
with open(args.tls_dh_params_path, "w") as dh_params_file:
dh_params_file.write(
"2048-bit DH parameters taken from rfc3526\n"
"-----BEGIN DH PARAMETERS-----\n"

View File

@@ -17,21 +17,28 @@ from ._base import Config
class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def __init__(self, args):
super(VoipConfig, self).__init__(args)
self.turn_uris = args.turn_uris
self.turn_shared_secret = args.turn_shared_secret
self.turn_user_lifetime = args.turn_user_lifetime
def default_config(self, **kwargs):
return """\
## Turn ##
# The public URIs of the TURN server to give to clients
turn_uris: []
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""
@classmethod
def add_arguments(cls, parser):
super(VoipConfig, cls).add_arguments(parser)
group = parser.add_argument_group("voip")
group.add_argument(
"--turn-uris", type=str, default=None, action='append',
help="The public URIs of the TURN server to give to clients"
)
group.add_argument(
"--turn-shared-secret", type=str, default=None,
help=(
"The shared secret used to compute passwords for the TURN"
" server"
)
)
group.add_argument(
"--turn-user-lifetime", type=int, default=(1000 * 60 * 60),
help="How long generated TURN credentials last, in ms"
)

View File

@@ -35,9 +35,9 @@ class ServerContextFactory(ssl.ContextFactory):
_ecCurve = _OpenSSLECCurve(_defaultCurveName)
_ecCurve.addECKeyToContext(context)
except:
logger.exception("Failed to enable elliptic curve for TLS")
logger.exception("Failed to enable eliptic curve for TLS")
context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
context.use_certificate_chain_file(config.tls_certificate_file)
context.use_certificate(config.tls_certificate)
if not config.no_tls:
context.use_privatekey(config.tls_private_key)

View File

@@ -15,12 +15,11 @@
# limitations under the License.
from synapse.api.errors import SynapseError, Codes
from synapse.events.utils import prune_event
from canonicaljson import encode_canonical_json
from unpaddedbase64 import encode_base64, decode_base64
from signedjson.sign import sign_json
from syutil.jsonutil import encode_canonical_json
from syutil.base64util import encode_base64, decode_base64
from syutil.crypto.jsonsign import sign_json
from synapse.api.errors import SynapseError, Codes
import hashlib
import logging

View File

@@ -18,51 +18,37 @@ from twisted.web.http import HTTPClient
from twisted.internet.protocol import Factory
from twisted.internet import defer, reactor
from synapse.http.endpoint import matrix_federation_endpoint
from synapse.util.logcontext import (
preserve_context_over_fn, preserve_context_over_deferred
)
from synapse.util.logcontext import PreserveLoggingContext
import simplejson as json
import logging
logger = logging.getLogger(__name__)
KEY_API_V1 = b"/_matrix/key/v1/"
@defer.inlineCallbacks
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
def fetch_server_key(server_name, ssl_context_factory):
"""Fetch the keys for a remote server."""
factory = SynapseKeyClientFactory()
factory.path = path
endpoint = matrix_federation_endpoint(
reactor, server_name, ssl_context_factory, timeout=30
)
for i in range(5):
try:
protocol = yield preserve_context_over_fn(
endpoint.connect, factory
)
server_response, server_certificate = yield preserve_context_over_deferred(
protocol.remote_key
)
defer.returnValue((server_response, server_certificate))
return
except SynapseKeyClientError as e:
logger.exception("Error getting key for %r" % (server_name,))
if e.status.startswith("4"):
# Don't retry for 4xx responses.
raise IOError("Cannot get key for %r" % server_name)
with PreserveLoggingContext():
protocol = yield endpoint.connect(factory)
server_response, server_certificate = yield protocol.remote_key
defer.returnValue((server_response, server_certificate))
return
except Exception as e:
logger.exception(e)
raise IOError("Cannot get key for %r" % server_name)
raise IOError("Cannot get key for %s" % server_name)
class SynapseKeyClientError(Exception):
"""The key wasn't retrieved from the remote server."""
status = None
pass
@@ -80,30 +66,17 @@ class SynapseKeyClientProtocol(HTTPClient):
def connectionMade(self):
self.host = self.transport.getHost()
logger.debug("Connected to %s", self.host)
self.sendCommand(b"GET", self.path)
self.sendCommand(b"GET", b"/_matrix/key/v1/")
self.endHeaders()
self.timer = reactor.callLater(
self.timeout,
self.on_timeout
)
def errback(self, error):
if not self.remote_key.called:
self.remote_key.errback(error)
def callback(self, result):
if not self.remote_key.called:
self.remote_key.callback(result)
def handleStatus(self, version, status, message):
if status != b"200":
# logger.info("Non-200 response from %s: %s %s",
# self.transport.getHost(), status, message)
error = SynapseKeyClientError(
"Non-200 response %r from %r" % (status, self.host)
)
error.status = status
self.errback(error)
self.transport.abortConnection()
def handleResponse(self, response_body_bytes):
@@ -116,18 +89,15 @@ class SynapseKeyClientProtocol(HTTPClient):
return
certificate = self.transport.getPeerCertificate()
self.callback((json_response, certificate))
self.remote_key.callback((json_response, certificate))
self.transport.abortConnection()
self.timer.cancel()
def on_timeout(self):
logger.debug("Timeout waiting for response from %s", self.host)
self.errback(IOError("Timeout waiting for response"))
self.remote_key.errback(IOError("Timeout waiting for response"))
self.transport.abortConnection()
class SynapseKeyClientFactory(Factory):
def protocol(self):
protocol = SynapseKeyClientProtocol()
protocol.path = self.path
return protocol
protocol = SynapseKeyClientProtocol

View File

@@ -14,598 +14,102 @@
# limitations under the License.
from synapse.crypto.keyclient import fetch_server_key
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from synapse.util import unwrapFirstError
from synapse.util.async import ObservableDeferred
from twisted.internet import defer
from signedjson.sign import (
verify_signed_json, signature_ids, sign_json, encode_canonical_json
)
from signedjson.key import (
from syutil.crypto.jsonsign import verify_signed_json, signature_ids
from syutil.crypto.signing_key import (
is_signing_algorithm_supported, decode_verify_key_bytes
)
from unpaddedbase64 import decode_base64, encode_base64
from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
from OpenSSL import crypto
from collections import namedtuple
import urllib
import hashlib
import logging
logger = logging.getLogger(__name__)
KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.client = hs.get_http_client()
self.config = hs.get_config()
self.perspective_servers = self.config.perspectives
self.hs = hs
self.key_downloads = {}
@defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
return self.verify_json_objects_for_server(
[(server_name, json_object)]
)[0]
def verify_json_objects_for_server(self, server_and_json):
"""Bulk verfies signatures of json objects, bulk fetching keys as
necessary.
Args:
server_and_json (list): List of pairs of (server_name, json_object)
Returns:
list of deferreds indicating success or failure to verify each
json object's signature for the given server_name.
"""
group_id_to_json = {}
group_id_to_group = {}
group_ids = []
next_group_id = 0
deferreds = {}
for server_name, json_object in server_and_json:
logger.debug("Verifying for %s", server_name)
group_id = next_group_id
next_group_id += 1
group_ids.append(group_id)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
deferreds[group_id] = defer.fail(SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
))
else:
deferreds[group_id] = defer.Deferred()
group = KeyGroup(server_name, group_id, key_ids)
group_id_to_group[group_id] = group
group_id_to_json[group_id] = json_object
@defer.inlineCallbacks
def handle_key_deferred(group, deferred):
server_name = group.server_name
try:
_, _, key_id, verify_key = yield deferred
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.exception(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
json_object = group_id_to_json[group.group_id]
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
server_to_deferred = {
server_name: defer.Deferred()
for server_name, _ in server_and_json
}
# We want to wait for any previous lookups to complete before
# proceeding.
wait_on_deferred = self.wait_for_previous_lookups(
[server_name for server_name, _ in server_and_json],
server_to_deferred,
)
# Actually start fetching keys.
wait_on_deferred.addBoth(
lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
)
# When we've finished fetching all the keys for a given server_name,
# resolve the deferred passed to `wait_for_previous_lookups` so that
# any lookups waiting will proceed.
server_to_gids = {}
def remove_deferreds(res, server_name, group_id):
server_to_gids[server_name].discard(group_id)
if not server_to_gids[server_name]:
d = server_to_deferred.pop(server_name, None)
if d:
d.callback(None)
return res
for g_id, deferred in deferreds.items():
server_name = group_id_to_group[g_id].server_name
server_to_gids.setdefault(server_name, set()).add(g_id)
deferred.addBoth(remove_deferreds, server_name, g_id)
# Pass those keys to handle_key_deferred so that the json object
# signatures can be verified
return [
handle_key_deferred(
group_id_to_group[g_id],
deferreds[g_id],
logger.debug("Verifying for %s", server_name)
key_ids = signature_ids(json_object, server_name)
if not key_ids:
raise SynapseError(
400,
"Not signed with a supported algorithm",
Codes.UNAUTHORIZED,
)
for g_id in group_ids
]
@defer.inlineCallbacks
def wait_for_previous_lookups(self, server_names, server_to_deferred):
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (list): list of server_names we want to lookup
server_to_deferred (dict): server_name to deferred which gets
resolved once we've finished looking up keys for that server
"""
while True:
wait_on = [
self.key_downloads[server_name]
for server_name in server_names
if server_name in self.key_downloads
]
if wait_on:
yield defer.DeferredList(wait_on)
else:
break
for server_name, deferred in server_to_deferred.items():
d = ObservableDeferred(deferred)
self.key_downloads[server_name] = d
def rm(r, server_name):
self.key_downloads.pop(server_name, None)
return r
d.addBoth(rm, server_name)
def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
"""Takes a dict of KeyGroups and tries to find at least one key for
each group.
"""
# These are functions that produce keys given a list of key ids
key_fetch_fns = (
self.get_keys_from_store, # First try the local store
self.get_keys_from_perspectives, # Then try via perspectives
self.get_keys_from_server, # Then try directly
)
@defer.inlineCallbacks
def do_iterations():
merged_results = {}
missing_keys = {}
for group in group_id_to_group.values():
missing_keys.setdefault(group.server_name, set()).union(group.key_ids)
for fn in key_fetch_fns:
results = yield fn(missing_keys.items())
merged_results.update(results)
# We now need to figure out which groups we have keys for
# and which we don't
missing_groups = {}
for group in group_id_to_group.values():
for key_id in group.key_ids:
if key_id in merged_results[group.server_name]:
group_id_to_deferred[group.group_id].callback((
group.group_id,
group.server_name,
key_id,
merged_results[group.server_name][key_id],
))
break
else:
missing_groups.setdefault(
group.server_name, []
).append(group)
if not missing_groups:
break
missing_keys = {
server_name: set(
key_id for group in groups for key_id in group.key_ids
)
for server_name, groups in missing_groups.items()
}
for group in missing_groups.values():
group_id_to_deferred[group.group_id].errback(SynapseError(
401,
"No key for %s with id %s" % (
group.server_name, group.key_ids,
),
Codes.UNAUTHORIZED,
))
def on_err(err):
for deferred in group_id_to_deferred.values():
if not deferred.called:
deferred.errback(err)
do_iterations().addErrback(on_err)
return group_id_to_deferred
@defer.inlineCallbacks
def get_keys_from_store(self, server_name_and_key_ids):
res = yield defer.gatherResults(
[
self.store.get_server_verify_keys(
server_name, key_ids
).addCallback(lambda ks, server: (server, ks), server_name)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
defer.returnValue(dict(res))
@defer.inlineCallbacks
def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
logger.exception(
"Unable to get key from %r: %s %s",
perspective_name,
type(e).__name__, str(e.message),
)
defer.returnValue({})
results = yield defer.gatherResults(
[
get_key(p_name, p_keys)
for p_name, p_keys in self.perspective_servers.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
union_of_keys = {}
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
defer.returnValue(union_of_keys)
@defer.inlineCallbacks
def get_keys_from_server(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(server_name, key_ids):
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
try:
verify_key = yield self.get_server_verify_key(server_name, key_ids)
except IOError as e:
logger.warn(
"Got IOError when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
with limiter:
keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
if not keys:
keys = yield self.get_server_verify_key_v1_direct(
server_name, key_ids
)
keys = {server_name: keys}
defer.returnValue(keys)
results = yield defer.gatherResults(
[
get_key(server_name, key_ids)
for server_name, key_ids in server_name_and_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
merged = {}
for result in results:
merged.update(result)
defer.returnValue({
server_name: keys
for server_name, keys in merged.items()
if keys
})
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
perspective_name, self.clock, self.store
)
with limiter:
# TODO(mark): Set the minimum_valid_until_ts to that needed by
# the events being validated or the current time if validating
# an incoming request.
query_response = yield self.client.post_json(
destination=perspective_name,
path=b"/_matrix/key/v2/query",
data={
u"server_keys": {
server_name: {
key_id: {
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
for server_name, key_ids in server_names_and_key_ids
}
},
raise SynapseError(
502,
"Error downloading keys for %s" % (server_name,),
Codes.UNAUTHORIZED,
)
except Exception as e:
logger.warn(
"Got Exception when downloading keys for %s: %s %s",
server_name, type(e).__name__, str(e.message),
)
raise SynapseError(
401,
"No key for %s with id %s" % (server_name, key_ids),
Codes.UNAUTHORIZED,
)
keys = {}
responses = query_response["server_keys"]
for response in responses:
if (u"signatures" not in response
or perspective_name not in response[u"signatures"]):
raise ValueError(
"Key response not signed by perspective server"
" %r" % (perspective_name,)
)
verified = False
for key_id in response[u"signatures"][perspective_name]:
if key_id in perspective_keys:
verify_signed_json(
response,
perspective_name,
perspective_keys[key_id]
)
verified = True
if not verified:
logging.info(
"Response from perspective server %r not signed with a"
" known key, signed with: %r, known keys: %r",
perspective_name,
list(response[u"signatures"][perspective_name]),
list(perspective_keys)
)
raise ValueError(
"Response not signed with a known key for perspective"
" server %r" % (perspective_name,)
)
processed_response = yield self.process_v2_response(
perspective_name, response
try:
verify_signed_json(json_object, server_name, verify_key)
except:
raise SynapseError(
401,
"Invalid signature for server %s with key %s:%s" % (
server_name, verify_key.alg, verify_key.version
),
Codes.UNAUTHORIZED,
)
for server_name, response_keys in processed_response.items():
keys.setdefault(server_name, {}).update(response_keys)
yield defer.gatherResults(
[
self.store_keys(
server_name=server_name,
from_server=perspective_name,
verify_keys=response_keys,
)
for server_name, response_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
keys = {}
for requested_key_id in key_ids:
if requested_key_id in keys:
continue
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory,
path=(b"/_matrix/key/v2/server/%s" % (
urllib.quote(requested_key_id),
)).encode("ascii"),
)
if (u"signatures" not in response
or server_name not in response[u"signatures"]):
raise ValueError("Key response not signed by remote server")
if "tls_fingerprints" not in response:
raise ValueError("Key response missing TLS fingerprints")
certificate_bytes = crypto.dump_certificate(
crypto.FILETYPE_ASN1, tls_certificate
)
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
response_sha256_fingerprints = set()
for fingerprint in response[u"tls_fingerprints"]:
if u"sha256" in fingerprint:
response_sha256_fingerprints.add(fingerprint[u"sha256"])
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
from_server=server_name,
requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
yield defer.gatherResults(
[
self.store_keys(
server_name=key_server_name,
from_server=server_name,
verify_keys=verify_keys,
)
for key_server_name, verify_keys in keys.items()
],
consumeErrors=True
).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def process_v2_response(self, from_server, response_json,
requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
for key_id, key_data in response_json["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
old_verify_keys = {}
for key_id, key_data in response_json["old_verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_base64 = key_data["key"]
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.expired = key_data["expired_ts"]
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
results = {}
server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
"Key response must include verification keys for all"
" signatures"
)
if key_id in verify_keys:
verify_signed_json(
response_json,
server_name,
verify_keys[key_id]
)
signed_key_json = sign_json(
response_json,
self.config.server_name,
self.config.signing_key[0],
)
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
yield defer.gatherResults(
[
self.store.store_server_keys_json(
server_name=server_name,
key_id=key_id,
from_server=server_name,
ts_now_ms=time_now_ms,
ts_expires_ms=ts_valid_until_ms,
key_json_bytes=signed_key_json_bytes,
)
for key_id in updated_key_ids
],
consumeErrors=True,
).addErrback(unwrapFirstError)
results[server_name] = response_keys
defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
def get_server_verify_key(self, server_name, key_ids):
"""Finds a verification key for the server with one of the key ids.
Args:
server_name (str): The name of the server to fetch a key for.
keys_ids (list of str): The key_ids to check for.
"""
# Check the datastore to see if we have one cached.
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
if cached:
defer.returnValue(cached[0])
return
# Try to fetch the key from the remote server.
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_server_context_factory
limiter = yield get_retry_limiter(
server_name,
self.clock,
self.store,
)
with limiter:
(response, tls_certificate) = yield fetch_server_key(
server_name, self.hs.tls_context_factory
)
# Check the response.
x509_certificate_bytes = crypto.dump_certificate(
@@ -624,16 +128,11 @@ class Keyring(object):
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
raise ValueError("TLS certificate doesn't match")
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
verify_keys = {}
for key_id, key_base64 in response["verify_keys"].items():
if is_signing_algorithm_supported(key_id):
key_bytes = decode_base64(key_base64)
verify_key = decode_verify_key_bytes(key_id, key_bytes)
verify_key.time_added = time_now_ms
verify_keys[key_id] = verify_key
for key_id in response["signatures"][server_name]:
@@ -649,6 +148,10 @@ class Keyring(object):
verify_keys[key_id]
)
# Cache the result in the datastore.
time_now_ms = self.clock.time_msec()
yield self.store.store_server_certificate(
server_name,
server_name,
@@ -656,31 +159,14 @@ class Keyring(object):
tls_certificate,
)
yield self.store_keys(
server_name=server_name,
from_server=server_name,
verify_keys=verify_keys,
)
for key_id, key in verify_keys.items():
yield self.store.store_server_verify_key(
server_name, server_name, time_now_ms, key
)
defer.returnValue(verify_keys)
for key_id in key_ids:
if key_id in verify_keys:
defer.returnValue(verify_keys[key_id])
return
@defer.inlineCallbacks
def store_keys(self, server_name, from_server, verify_keys):
"""Store a collection of verify keys for a given server
Args:
server_name(str): The name of the server the keys are for.
from_server(str): The server the keys were downloaded from.
verify_keys(dict): A mapping of key_id to VerifyKey.
Returns:
A deferred that completes when the keys are stored.
"""
# TODO(markjh): Store whether the keys have expired.
yield defer.gatherResults(
[
self.store.store_server_verify_key(
server_name, server_name, key.time_added, key
)
for key_id, key in verify_keys.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
raise ValueError("No verification key found for given key ids")

View File

@@ -16,12 +16,6 @@
from synapse.util.frozenutils import freeze
# Whether we should use frozen_dict in FrozenEvent. Using frozen_dicts prevents
# bugs where we accidentally share e.g. signature dicts. However, converting
# a dict to frozen_dicts is expensive.
USE_FROZEN_DICTS = True
class _EventInternalMetadata(object):
def __init__(self, internal_metadata_dict):
self.__dict__ = dict(internal_metadata_dict)
@@ -52,10 +46,9 @@ def _event_dict_property(key):
class EventBase(object):
def __init__(self, event_dict, signatures={}, unsigned={},
internal_metadata_dict={}, rejected_reason=None):
internal_metadata_dict={}):
self.signatures = signatures
self.unsigned = unsigned
self.rejected_reason = rejected_reason
self._event_dict = event_dict
@@ -90,7 +83,7 @@ class EventBase(object):
d = dict(self._event_dict)
d.update({
"signatures": self.signatures,
"unsigned": dict(self.unsigned),
"unsigned": self.unsigned,
})
return d
@@ -109,9 +102,6 @@ class EventBase(object):
pdu_json.setdefault("unsigned", {})["age"] = int(age)
del pdu_json["unsigned"]["age_ts"]
# This may be a frozen event
pdu_json["unsigned"].pop("redacted_because", None)
return pdu_json
def __set__(self, instance, value):
@@ -119,7 +109,7 @@ class EventBase(object):
class FrozenEvent(EventBase):
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
def __init__(self, event_dict, internal_metadata_dict={}):
event_dict = dict(event_dict)
# Signatures is a dict of dicts, and this is faster than doing a
@@ -131,17 +121,13 @@ class FrozenEvent(EventBase):
unsigned = dict(event_dict.pop("unsigned", {}))
if USE_FROZEN_DICTS:
frozen_dict = freeze(event_dict)
else:
frozen_dict = event_dict
frozen_dict = freeze(event_dict)
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,
unsigned=unsigned,
internal_metadata_dict=internal_metadata_dict,
rejected_reason=rejected_reason,
)
@staticmethod

View File

@@ -16,7 +16,8 @@
class EventContext(object):
def __init__(self, current_state=None):
def __init__(self, current_state=None, auth_events=None):
self.current_state = current_state
self.auth_events = auth_events
self.state_group = None
self.rejected = False

View File

@@ -74,8 +74,6 @@ def prune_event(event):
)
elif event_type == EventTypes.Aliases:
add_fields("aliases")
elif event_type == EventTypes.RoomHistoryVisibility:
add_fields("history_visibility")
allowed_fields = {
k: v

View File

@@ -18,12 +18,12 @@ from twisted.internet import defer
from synapse.events.utils import prune_event
from syutil.jsonutil import encode_canonical_json
from synapse.crypto.event_signing import check_event_content_hash
from synapse.api.errors import SynapseError
from synapse.util import unwrapFirstError
import logging
@@ -32,8 +32,7 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
include_none=False):
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -51,108 +50,84 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
deferreds = self._check_sigs_and_hashes(pdus)
def callback(pdu):
return pdu
signed_pdus = []
def errback(failure, pdu):
failure.trap(SynapseError)
return None
@defer.inlineCallbacks
def do(pdu):
try:
new_pdu = yield self._check_sigs_and_hash(pdu)
signed_pdus.append(new_pdu)
except SynapseError:
# FIXME: We should handle signature failures more gracefully.
def try_local_db(res, pdu):
if not res:
# Check local db.
return self.store.get_event(
new_pdu = yield self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
return res
if new_pdu:
signed_pdus.append(new_pdu)
return
def try_remote(res, pdu):
if not res and pdu.origin != origin:
return self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
timeout=10000,
).addErrback(lambda e: None)
return res
# Check pdu.origin
if pdu.origin != origin:
try:
new_pdu = yield self.get_pdu(
destinations=[pdu.origin],
event_id=pdu.event_id,
outlier=outlier,
)
if new_pdu:
signed_pdus.append(new_pdu)
return
except:
pass
def warn(res, pdu):
if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
return res
for pdu, deferred in zip(pdus, deferreds):
deferred.addCallbacks(
callback, errback, errbackArgs=[pdu]
).addCallback(
try_local_db, pdu
).addCallback(
try_remote, pdu
).addCallback(
warn, pdu
)
valid_pdus = yield defer.gatherResults(
deferreds,
yield defer.gatherResults(
[do(pdu) for pdu in pdus],
consumeErrors=True
).addErrback(unwrapFirstError)
)
if include_none:
defer.returnValue(valid_pdus)
else:
defer.returnValue([p for p in valid_pdus if p])
defer.returnValue(signed_pdus)
@defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
return self._check_sigs_and_hashes([pdu])[0]
def _check_sigs_and_hashes(self, pdus):
"""Throws a SynapseError if a PDU does not have the correct
"""Throws a SynapseError if the PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
# Check signatures are correct.
redacted_event = prune_event(pdu)
redacted_pdu_json = redacted_event.get_pdu_json()
redacted_pdus = [
prune_event(pdu)
for pdu in pdus
]
deferreds = self.keyring.verify_json_objects_for_server([
(p.origin, p.get_pdu_json())
for p in redacted_pdus
])
def callback(_, pdu, redacted):
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s: %s",
pdu.event_id, pdu.get_pdu_json()
)
return redacted
return pdu
def errback(failure, pdu):
failure.trap(SynapseError)
try:
yield self.keyring.verify_json_for_server(
pdu.origin, redacted_pdu_json
)
except SynapseError:
logger.warn(
"Signature check failed for %s",
pdu.event_id,
"Signature check failed for %s redacted to %s",
encode_canonical_json(pdu.get_pdu_json()),
encode_canonical_json(redacted_pdu_json),
)
return failure
raise
for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
deferred.addCallbacks(
callback, errback,
callbackArgs=[pdu, redacted],
errbackArgs=[pdu],
if not check_event_content_hash(pdu):
logger.warn(
"Event content has been tampered, redacting %s, %s",
pdu.event_id, encode_canonical_json(pdu.get_dict())
)
defer.returnValue(redacted_event)
return deferreds
defer.returnValue(pdu)

View File

@@ -22,15 +22,12 @@ from .units import Edu
from synapse.api.errors import (
CodeMessageException, HttpResponseException, SynapseError,
)
from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
import copy
import itertools
import logging
import random
@@ -39,17 +36,9 @@ import random
logger = logging.getLogger(__name__)
# synapse.federation.federation_client is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.client")
sent_pdus_destination_dist = metrics.register_distribution("sent_pdu_destinations")
sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
class FederationClient(FederationBase):
def __init__(self):
self._get_pdu_cache = None
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
@@ -79,8 +68,6 @@ class FederationClient(FederationBase):
order = self._order
self._order += 1
sent_pdus_destination_dist.inc_by(len(destinations))
logger.debug("[%s] transaction_layer.enqueue_pdu... ", pdu.event_id)
# TODO, add errback, etc.
@@ -100,8 +87,6 @@ class FederationClient(FederationBase):
content=content,
)
sent_edus_counter.inc()
# TODO, add errback, etc.
self._transaction_queue.enqueue_edu(edu)
return defer.succeed(None)
@@ -128,42 +113,10 @@ class FederationClient(FederationBase):
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc(query_type)
return self.transport_layer.make_query(
destination, query_type, args, retry_on_dns_fail=retry_on_dns_fail
)
@log_function
def query_client_keys(self, destination, content):
"""Query device keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_device_keys")
return self.transport_layer.query_client_keys(destination, content)
@log_function
def claim_client_keys(self, destination, content):
"""Claims one-time keys for a device hosted on a remote server.
Args:
destination (str): Domain name of the remote homeserver
content (dict): The query content.
Returns:
a Deferred which will eventually yield a JSON object from the
response
"""
sent_queries_counter.inc("client_one_time_keys")
return self.transport_layer.claim_client_keys(destination, content)
@defer.inlineCallbacks
@log_function
def backfill(self, dest, context, limit, extremities):
@@ -196,17 +149,16 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, pdu in enumerate(pdus):
pdus[i] = yield self._check_sigs_and_hash(pdu)
# FIXME: We should handle signature failures more gracefully.
defer.returnValue(pdus)
@defer.inlineCallbacks
@log_function
def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
def get_pdu(self, destinations, event_id, outlier=False):
"""Requests the PDU with given origin and ID from the remote home
servers.
@@ -222,8 +174,6 @@ class FederationClient(FederationBase):
outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
it's from an arbitary point in the context as opposed to part
of the current block of PDUs. Defaults to `False`
timeout (int): How long to try (in ms) each destination for before
moving to the next destination. None indicates no timeout.
Returns:
Deferred: Results in the requested PDU.
@@ -247,7 +197,7 @@ class FederationClient(FederationBase):
with limiter:
transaction_data = yield self.transport_layer.get_event(
destination, event_id, timeout=timeout,
destination, event_id
)
logger.debug("transaction_data %r", transaction_data)
@@ -257,11 +207,11 @@ class FederationClient(FederationBase):
for p in transaction_data["pdus"]
]
if pdu_list and pdu_list[0]:
if pdu_list:
pdu = pdu_list[0]
# Check signatures are correct.
pdu = yield self._check_sigs_and_hashes([pdu])[0]
pdu = yield self._check_sigs_and_hash(pdu)
break
@@ -290,7 +240,7 @@ class FederationClient(FederationBase):
)
continue
if self._get_pdu_cache is not None and pdu:
if self._get_pdu_cache is not None:
self._get_pdu_cache[event_id] = pdu
defer.returnValue(pdu)
@@ -358,9 +308,6 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
if destination == self.server_name:
continue
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -387,9 +334,6 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
if destination == self.server_name:
continue
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -411,39 +355,13 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
pdus = {
p.event_id: p
for p in itertools.chain(state, auth_chain)
}
valid_pdus = yield self._check_sigs_and_hash_and_fetch(
destination, pdus.values(),
outlier=True,
signed_state = yield self._check_sigs_and_hash_and_fetch(
destination, state, outlier=True
)
valid_pdus_map = {
p.event_id: p
for p in valid_pdus
}
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
signed_state = [
copy.copy(valid_pdus_map[p.event_id])
for p in state
if p.event_id in valid_pdus_map
]
signed_auth = [
valid_pdus_map[p.event_id]
for p in auth_chain
if p.event_id in valid_pdus_map
]
# NB: We *need* to copy to ensure that we don't have multiple
# references being passed on, as that causes... issues.
for s in signed_state:
s.internal_metadata = copy.deepcopy(s.internal_metadata)
signed_auth = yield self._check_sigs_and_hash_and_fetch(
destination, auth_chain, outlier=True
)
auth_chain.sort(key=lambda e: e.depth)
@@ -455,7 +373,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
logger.exception(
logger.warn(
"Failed to send_join via %s: %s",
destination, e.message
)
@@ -558,7 +476,7 @@ class FederationClient(FederationBase):
]
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False
destination, events, outlier=True
)
have_gotten_all_from_destination = True
@@ -585,7 +503,7 @@ class FederationClient(FederationBase):
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events if e)
seen_events.update(e.event_id for e in signed_events)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
@@ -628,7 +546,7 @@ class FederationClient(FederationBase):
res = yield defer.DeferredList(deferreds, consumeErrors=True)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
if result:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)

View File

@@ -20,28 +20,18 @@ from .federation_base import FederationBase
from .units import Transaction, Edu
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.events import FrozenEvent
import synapse.metrics
from synapse.api.errors import FederationError, SynapseError
from synapse.crypto.event_signing import compute_event_signature
import simplejson as json
import logging
logger = logging.getLogger(__name__)
# synapse.federation.federation_server is a silly name
metrics = synapse.metrics.get_metrics_for("synapse.federation.server")
received_pdus_counter = metrics.register_counter("received_pdus")
received_edus_counter = metrics.register_counter("received_edus")
received_queries_counter = metrics.register_counter("received_queries", labels=["type"])
class FederationServer(FederationBase):
def set_handler(self, handler):
@@ -94,8 +84,6 @@ class FederationServer(FederationBase):
def on_incoming_transaction(self, transaction_data):
transaction = Transaction(**transaction_data)
received_pdus_counter.inc_by(len(transaction.pdus))
for p in transaction.pdus:
if "unsigned" in p:
unsigned = p["unsigned"]
@@ -123,28 +111,29 @@ class FederationServer(FederationBase):
logger.debug("[%s] Transaction is new", transaction.transaction_id)
results = []
with PreserveLoggingContext():
results = []
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
for pdu in pdu_list:
d = self._handle_new_pdu(transaction.origin, pdu)
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
try:
yield d
results.append({})
except FederationError as e:
self.send_failure(e, transaction.origin)
results.append({"error": str(e)})
except Exception as e:
results.append({"error": str(e)})
logger.exception("Failed to handle PDU")
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
if hasattr(transaction, "edus"):
for edu in [Edu(**x) for x in transaction.edus]:
self.received_edu(
transaction.origin,
edu.edu_type,
edu.content
)
for failure in getattr(transaction, "pdu_failures", []):
logger.info("Got failure %r", failure)
@@ -164,8 +153,6 @@ class FederationServer(FederationBase):
defer.returnValue((200, response))
def received_edu(self, origin, edu_type, content):
received_edus_counter.inc()
if edu_type in self.edu_handlers:
self.edu_handlers[edu_type](origin, content)
else:
@@ -217,8 +204,6 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
def on_query_request(self, query_type, args):
received_queries_counter.inc(query_type)
if query_type in self.query_handlers:
response = yield self.query_handlers[query_type](args)
defer.returnValue((200, response))
@@ -313,48 +298,6 @@ class FederationServer(FederationBase):
(200, send_content)
)
@defer.inlineCallbacks
@log_function
def on_query_client_keys(self, origin, content):
query = []
for user_id, device_ids in content.get("device_keys", {}).items():
if not device_ids:
query.append((user_id, None))
else:
for device_id in device_ids:
query.append((user_id, device_id))
results = yield self.store.get_e2e_device_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, json_bytes in device_keys.items():
json_result.setdefault(user_id, {})[device_id] = json.loads(
json_bytes
)
defer.returnValue({"device_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
query = []
for user_id, device_keys in content.get("one_time_keys", {}).items():
for device_id, algorithm in device_keys.items():
query.append((user_id, device_id, algorithm))
results = yield self.store.claim_e2e_one_time_keys(query)
json_result = {}
for user_id, device_keys in results.items():
for device_id, keys in device_keys.items():
for key_id, json_bytes in keys.items():
json_result.setdefault(user_id, {})[device_id] = {
key_id: json.loads(json_bytes)
}
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
@log_function
def on_get_missing_events(self, origin, room_id, earliest_events,
@@ -458,13 +401,13 @@ class FederationServer(FederationBase):
pdu.internal_metadata.outlier = True
elif min_depth and pdu.depth > min_depth:
if get_missing and prevs - seen:
latest = yield self.store.get_latest_event_ids_in_room(
latest_tuples = yield self.store.get_latest_events_in_room(
pdu.room_id
)
# We add the prev events that we have seen to the latest
# list to ensure the remote server doesn't give them to us
latest = set(latest)
latest = set(e_id for e_id, _, _ in latest_tuples)
latest |= seen
missing_events = yield self.get_missing_events(

View File

@@ -23,6 +23,8 @@ from twisted.internet import defer
from synapse.util.logutils import log_function
from syutil.jsonutil import encode_canonical_json
import logging
@@ -69,7 +71,7 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.origin,
code,
response,
encode_canonical_json(response)
)
@defer.inlineCallbacks
@@ -99,5 +101,5 @@ class TransactionActions(object):
transaction.transaction_id,
transaction.destination,
response_code,
response_dict,
encode_canonical_json(response_dict)
)

View File

@@ -25,15 +25,12 @@ from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.retryutils import (
get_retry_limiter, NotRetryingDestination,
)
import synapse.metrics
import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
class TransactionQueue(object):
"""This class makes sure we only have one transaction in flight at
@@ -57,25 +54,11 @@ class TransactionQueue(object):
# done
self.pending_transactions = {}
metrics.register_callback(
"pending_destinations",
lambda: len(self.pending_transactions),
)
# Is a mapping from destination -> list of
# tuple(pending pdus, deferred, order)
self.pending_pdus_by_dest = pdus = {}
self.pending_pdus_by_dest = {}
# destination -> list of tuple(edu, deferred)
self.pending_edus_by_dest = edus = {}
metrics.register_callback(
"pending_pdus",
lambda: sum(map(len, pdus.values())),
)
metrics.register_callback(
"pending_edus",
lambda: sum(map(len, edus.values())),
)
self.pending_edus_by_dest = {}
# destination -> list of tuple(failure, deferred)
self.pending_failures_by_dest = {}
@@ -104,6 +87,7 @@ class TransactionQueue(object):
return not destination.startswith("localhost")
@defer.inlineCallbacks
@log_function
def enqueue_pdu(self, pdu, destinations, order):
# We loop through all destinations to see whether we already have
# a transaction in progress. If we do, stick it in the pending_pdus
@@ -207,13 +191,13 @@ class TransactionQueue(object):
# request at which point pending_pdus_by_dest just keeps growing.
# we need application-layer timeouts of some flavour of these
# requests
logger.debug(
logger.info(
"TX [%s] Transaction already in progress",
destination
)
return
logger.debug("TX [%s] _attempt_new_transaction", destination)
logger.info("TX [%s] _attempt_new_transaction", destination)
# list of (pending_pdu, deferred, order)
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@@ -221,11 +205,11 @@ class TransactionQueue(object):
pending_failures = self.pending_failures_by_dest.pop(destination, [])
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
logger.info("TX [%s] Nothing to send", destination)
return
# Sort based on the order field
@@ -242,8 +226,6 @@ class TransactionQueue(object):
try:
self.pending_transactions[destination] = 1
txn_id = str(self._next_txn_id)
limiter = yield get_retry_limiter(
destination,
self._clock,
@@ -251,9 +233,9 @@ class TransactionQueue(object):
)
logger.debug(
"TX [%s] {%s} Attempting new transaction"
"TX [%s] Attempting new transaction"
" (pdus: %d, edus: %d, failures: %d)",
destination, txn_id,
destination,
len(pending_pdus),
len(pending_edus),
len(pending_failures)
@@ -263,7 +245,7 @@ class TransactionQueue(object):
transaction = Transaction.create_new(
origin_server_ts=int(self._clock.time_msec()),
transaction_id=txn_id,
transaction_id=str(self._next_txn_id),
origin=self.server_name,
destination=destination,
pdus=pdus,
@@ -277,13 +259,9 @@ class TransactionQueue(object):
logger.debug("TX [%s] Persisted transaction", destination)
logger.info(
"TX [%s] {%s} Sending transaction [%s],"
" (PDUs: %d, EDUs: %d, failures: %d)",
destination, txn_id,
"TX [%s] Sending transaction [%s]",
destination,
transaction.transaction_id,
len(pending_pdus),
len(pending_edus),
len(pending_failures),
)
with limiter:
@@ -319,10 +297,7 @@ class TransactionQueue(object):
code = e.code
response = e.response
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
)
logger.info("TX [%s] got %d response", destination, code)
logger.debug("TX [%s] Sent transaction", destination)
logger.debug("TX [%s] Marking as delivered...", destination)

View File

@@ -50,15 +50,13 @@ class TransportLayerClient(object):
)
@log_function
def get_event(self, destination, event_id, timeout=None):
def get_event(self, destination, event_id):
""" Requests the pdu with give id and origin from the given server.
Args:
destination (str): The host name of the remote home server we want
to get the state from.
event_id (str): The id of the event being requested.
timeout (int): How long to try (in ms) the destination for before
giving up. None indicates no timeout.
Returns:
Deferred: Results in a dict received from the remote homeserver.
@@ -67,7 +65,7 @@ class TransportLayerClient(object):
destination, event_id)
path = PREFIX + "/event/%s/" % (event_id, )
return self.client.get_json(destination, path=path, timeout=timeout)
return self.client.get_json(destination, path=path)
@log_function
def backfill(self, destination, room_id, event_tuples, limit):
@@ -222,76 +220,6 @@ class TransportLayerClient(object):
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_client_keys(self, destination, query_content):
"""Query the device keys for a list of user ids hosted on a remote
server.
Request:
{
"device_keys": {
"<user_id>": ["<device_id>"]
} }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {...}
} } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/keys/query"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content):
"""Claim one-time keys for a list of devices hosted on a remote server.
Request:
{
"one_time_keys": {
"<user_id>": {
"<device_id>": "<algorithm>"
} } }
Response:
{
"device_keys": {
"<user_id>": {
"<device_id>": {
"<algorithm>:<key_id>": "<key_base64>"
} } } }
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the one-time keys.
"""
path = PREFIX + "/user/keys/claim"
content = yield self.client.post_json(
destination=destination,
path=path,
data=query_content,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,

View File

@@ -93,9 +93,6 @@ class TransportLayerServer(object):
yield self.keyring.verify_json_for_server(origin, json_request)
logger.info("Request from %s", origin)
request.authenticated_entity = origin
defer.returnValue((origin, content))
@log_function
@@ -151,10 +148,6 @@ class BaseFederationServlet(object):
logger.exception("authenticate_request failed")
raise
defer.returnValue(response)
# Extra logic that functools.wraps() doesn't finish
new_code.__self__ = code.__self__
return new_code
def register(self, server):
@@ -199,14 +192,6 @@ class FederationSendServlet(BaseFederationServlet):
transaction_id, str(transaction_data)
)
logger.info(
"Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
transaction_id, origin,
len(transaction_data.get("pdus", [])),
len(transaction_data.get("edus", [])),
len(transaction_data.get("failures", [])),
)
# We should ideally be getting this from the security layer.
# origin = body["origin"]
@@ -325,24 +310,6 @@ class FederationInviteServlet(BaseFederationServlet):
defer.returnValue((200, content))
class FederationClientKeysQueryServlet(BaseFederationServlet):
PATH = "/user/keys/query"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_query_client_keys(origin, content)
defer.returnValue((200, response))
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@defer.inlineCallbacks
def on_POST(self, origin, content, query):
response = yield self.handler.on_claim_client_keys(origin, content)
defer.returnValue((200, response))
class FederationQueryAuthServlet(BaseFederationServlet):
PATH = "/query_auth/([^/]*)/([^/]*)"
@@ -390,7 +357,4 @@ SERVLET_CLASSES = (
FederationInviteServlet,
FederationQueryAuthServlet,
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationClientKeysClaimServlet,
)

View File

@@ -13,7 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
@@ -22,6 +21,7 @@ from .room import (
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler
from .login import LoginHandler
from .profile import ProfileHandler
from .presence import PresenceHandler
from .directory import DirectoryHandler
@@ -29,9 +29,6 @@ from .typing import TypingNotificationHandler
from .admin import AdminHandler
from .appservice import ApplicationServicesHandler
from .sync import SyncHandler
from .auth import AuthHandler
from .identity import IdentityHandler
from .receipts import ReceiptsHandler
class Handlers(object):
@@ -53,18 +50,11 @@ class Handlers(object):
self.profile_handler = ProfileHandler(hs)
self.presence_handler = PresenceHandler(hs)
self.room_list_handler = RoomListHandler(hs)
self.login_handler = LoginHandler(hs)
self.directory_handler = DirectoryHandler(hs)
self.typing_notification_handler = TypingNotificationHandler(hs)
self.admin_handler = AdminHandler(hs)
self.receipts_handler = ReceiptsHandler(hs)
asapi = ApplicationServiceApi(hs)
self.appservice_handler = ApplicationServicesHandler(
hs, asapi, AppServiceScheduler(
clock=hs.get_clock(),
store=hs.get_datastore(),
as_api=asapi
)
hs, ApplicationServiceApi(hs)
)
self.sync_handler = SyncHandler(hs)
self.auth_handler = AuthHandler(hs)
self.identity_handler = IdentityHandler(hs)

View File

@@ -15,12 +15,11 @@
from twisted.internet import defer
from synapse.api.errors import LimitExceededError, SynapseError, AuthError
from synapse.api.errors import LimitExceededError, SynapseError
from synapse.util.async import run_on_reactor
from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import logging
@@ -59,6 +58,8 @@ class BaseHandler(object):
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
yield run_on_reactor()
latest_ret = yield self.store.get_latest_events_in_room(
builder.room_id,
)
@@ -78,9 +79,7 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
)
builder.prev_state = context.prev_state_events
yield self.auth.add_auth_events(builder, context)
@@ -91,8 +90,8 @@ class BaseHandler(object):
event = builder.build()
logger.debug(
"Created event %s with current state: %s",
event.event_id, context.current_state,
"Created event %s with auth_events: %s, current state: %s",
event.event_id, context.auth_events, context.current_state,
)
defer.returnValue(
@@ -102,30 +101,14 @@ class BaseHandler(object):
@defer.inlineCallbacks
def handle_new_client_event(self, event, context, extra_destinations=[],
extra_users=[], suppress_auth=False):
yield run_on_reactor()
# We now need to go and hit out to wherever we need to hit out to.
if not suppress_auth:
self.auth.check(event, auth_events=context.current_state)
self.auth.check(event, auth_events=context.auth_events)
if event.type == EventTypes.CanonicalAlias:
# Check the alias is acually valid (at this time at least)
room_alias_str = event.content.get("alias", None)
if room_alias_str:
room_alias = RoomAlias.from_string(room_alias_str)
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if mapping["room_id"] != event.room_id:
raise SynapseError(
400,
"Room alias %s does not point to the room" % (
room_alias_str,
)
)
(event_stream_id, max_stream_id) = yield self.store.persist_event(
event, context=context
)
yield self.store.persist_event(event, context=context)
federation_handler = self.hs.get_handlers().federation_handler
@@ -146,21 +129,6 @@ class BaseHandler(object):
returned_invite.signatures
)
if event.type == EventTypes.Redaction:
if self.auth.check_redaction(event, auth_events=context.current_state):
original_event = yield self.store.get_event(
event.redacts,
check_redacted=False,
get_prev_content=False,
allow_rejected=False,
allow_none=False
)
if event.user_id != original_event.user_id:
raise AuthError(
403,
"You don't have permission to redact events"
)
destinations = set(extra_destinations)
for k, s in context.current_state.items():
try:
@@ -174,21 +142,8 @@ class BaseHandler(object):
"Failed to get destination from event %s", s.event_id
)
with PreserveLoggingContext():
# Don't block waiting on waking up all the listeners.
notify_d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
yield self.notifier.on_new_room_event(event, extra_users=extra_users)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
notify_d.addErrback(log_failure)
federation_handler.handle_new_event(
yield federation_handler.handle_new_event(
event, destinations=destinations,
)

View File

@@ -34,7 +34,6 @@ class AdminHandler(BaseHandler):
d = {}
for r in res:
# Note that device_id is always None
device = d.setdefault(r["device_id"], {})
session = device.setdefault(r["access_token"], [])
session.append({

View File

@@ -15,37 +15,58 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.appservice import ApplicationService
from synapse.types import UserID
import synapse.util.stringutils as stringutils
import logging
logger = logging.getLogger(__name__)
def log_failure(failure):
logger.error(
"Application Services Failure",
exc_info=(
failure.type,
failure.value,
failure.getTracebackObject()
)
)
# NB: Purposefully not inheriting BaseHandler since that contains way too much
# setup code which this handler does not need or use. This makes testing a lot
# easier.
class ApplicationServicesHandler(object):
def __init__(self, hs, appservice_api, appservice_scheduler):
def __init__(self, hs, appservice_api):
self.store = hs.get_datastore()
self.hs = hs
self.appservice_api = appservice_api
self.scheduler = appservice_scheduler
self.started_scheduler = False
@defer.inlineCallbacks
def register(self, app_service):
logger.info("Register -> %s", app_service)
# check the token is recognised
try:
stored_service = yield self.store.get_app_service_by_token(
app_service.token
)
if not stored_service:
raise StoreError(404, "Application service not found")
except StoreError:
raise SynapseError(
403, "Unrecognised application services token. "
"Consult the home server admin.",
errcode=Codes.FORBIDDEN
)
app_service.hs_token = self._generate_hs_token()
# create a sender for this application service which is used when
# creating rooms, etc..
account = yield self.hs.get_handlers().registration_handler.register()
app_service.sender = account[0]
yield self.store.update_app_service(app_service)
defer.returnValue(app_service)
@defer.inlineCallbacks
def unregister(self, token):
logger.info("Unregister as_token=%s", token)
yield self.store.unregister_app_service(token)
@defer.inlineCallbacks
def notify_interested_services(self, event):
@@ -69,13 +90,9 @@ class ApplicationServicesHandler(object):
if event.type == EventTypes.Member:
yield self._check_user_exists(event.state_key)
if not self.started_scheduler:
self.scheduler.start().addErrback(log_failure)
self.started_scheduler = True
# Fork off pushes to these services
# Fork off pushes to these services - XXX First cut, best effort
for service in services:
self.scheduler.submit_event_for_as(service, event)
self.appservice_api.push(service, event)
@defer.inlineCallbacks
def query_user_exists(self, user_id):
@@ -147,7 +164,10 @@ class ApplicationServicesHandler(object):
)
# We need to know the members associated with this event.room_id,
# if any.
member_list = yield self.store.get_users_in_room(event.room_id)
member_list = yield self.store.get_room_members(
room_id=event.room_id,
membership=Membership.JOIN
)
services = yield self.store.get_app_services()
interested_list = [
@@ -177,14 +197,7 @@ class ApplicationServicesHandler(object):
return
user_info = yield self.store.get_user_by_id(user_id)
if user_info:
defer.returnValue(False)
return
# user not found; could be the AS though, so check.
services = yield self.store.get_app_services()
service_list = [s for s in services if s.sender == user_id]
defer.returnValue(len(service_list) == 0)
defer.returnValue(len(user_info) == 0)
@defer.inlineCallbacks
def _check_user_exists(self, user_id):
@@ -193,3 +206,6 @@ class ApplicationServicesHandler(object):
exists = yield self.query_user_exists(user_id)
defer.returnValue(exists)
defer.returnValue(True)
def _generate_hs_token(self):
return stringutils.random_string(24)

View File

@@ -1,415 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.constants import LoginType
from synapse.types import UserID
from synapse.api.errors import LoginError, Codes
from synapse.util.async import run_on_reactor
from twisted.web.client import PartialDownloadError
import logging
import bcrypt
import pymacaroons
import simplejson
import synapse.util.stringutils as stringutils
logger = logging.getLogger(__name__)
class AuthHandler(BaseHandler):
def __init__(self, hs):
super(AuthHandler, self).__init__(hs)
self.checkers = {
LoginType.PASSWORD: self._check_password_auth,
LoginType.RECAPTCHA: self._check_recaptcha,
LoginType.EMAIL_IDENTITY: self._check_email_identity,
LoginType.DUMMY: self._check_dummy_auth,
}
self.sessions = {}
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
Takes a dictionary sent by the client in the login / registration
protocol and handles the login flow.
As a side effect, this function fills in the 'creds' key on the user's
session with a map, which maps each auth-type (str) to the relevant
identity authenticated by that auth-type (mostly str, but for captcha, bool).
Args:
flows (list): A list of login flows. Each flow is an ordered list of
strings representing auth-types. At least one full
flow must be completed in order for auth to be successful.
clientdict: The dictionary from the client root level, not the
'auth' key: this method prompts for auth if none is sent.
clientip (str): The IP address of the client.
Returns:
A tuple of (authed, dict, dict) where authed is true if the client
has successfully completed an auth flow. If it is true, the first
dict contains the authenticated credentials of each stage.
If authed is false, the first dictionary is the server response to
the login request and should be passed back to the client.
In either case, the second dict contains the parameters for this
request (which may have been given only in a previous call).
"""
authdict = None
sid = None
if clientdict and 'auth' in clientdict:
authdict = clientdict['auth']
del clientdict['auth']
if 'session' in authdict:
sid = authdict['session']
session = self._get_session_info(sid)
if len(clientdict) > 0:
# This was designed to allow the client to omit the parameters
# and just supply the session in subsequent calls so it split
# auth between devices by just sharing the session, (eg. so you
# could continue registration from your phone having clicked the
# email auth link on there). It's probably too open to abuse
# because it lets unauthenticated clients store arbitrary objects
# on a home server.
# Revisit: Assumimg the REST APIs do sensible validation, the data
# isn't arbintrary.
session['clientdict'] = clientdict
self._save_session(session)
elif 'clientdict' in session:
clientdict = session['clientdict']
if not authdict:
defer.returnValue(
(False, self._auth_dict_for_flows(flows, session), clientdict)
)
if 'creds' not in session:
session['creds'] = {}
creds = session['creds']
# check auth type currently being presented
if 'type' in authdict:
if authdict['type'] not in self.checkers:
raise LoginError(400, "", Codes.UNRECOGNIZED)
result = yield self.checkers[authdict['type']](authdict, clientip)
if result:
creds[authdict['type']] = result
self._save_session(session)
for f in flows:
if len(set(f) - set(creds.keys())) == 0:
logger.info("Auth completed with creds: %r", creds)
self._remove_session(session)
defer.returnValue((True, creds, clientdict))
ret = self._auth_dict_for_flows(flows, session)
ret['completed'] = creds.keys()
defer.returnValue((False, ret, clientdict))
@defer.inlineCallbacks
def add_oob_auth(self, stagetype, authdict, clientip):
"""
Adds the result of out-of-band authentication into an existing auth
session. Currently used for adding the result of fallback auth.
"""
if stagetype not in self.checkers:
raise LoginError(400, "", Codes.MISSING_PARAM)
if 'session' not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
sess = self._get_session_info(
authdict['session']
)
if 'creds' not in sess:
sess['creds'] = {}
creds = sess['creds']
result = yield self.checkers[stagetype](authdict, clientip)
if result:
creds[stagetype] = result
self._save_session(sess)
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_password_auth(self, authdict, _):
if "user" not in authdict or "password" not in authdict:
raise LoginError(400, "", Codes.MISSING_PARAM)
user_id = authdict["user"]
password = authdict["password"]
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
defer.returnValue(user_id)
@defer.inlineCallbacks
def _check_recaptcha(self, authdict, clientip):
try:
user_response = authdict["response"]
except KeyError:
# Client tried to provide captcha but didn't give the parameter:
# bad request.
raise LoginError(
400, "Captcha response is required",
errcode=Codes.CAPTCHA_NEEDED
)
logger.info(
"Submitting recaptcha response %s with remoteip %s",
user_response, clientip
)
# TODO: get this from the homeserver rather than creating a new one for
# each request
try:
client = self.hs.get_simple_http_client()
resp_body = yield client.post_urlencoded_get_json(
self.hs.config.recaptcha_siteverify_api,
args={
'secret': self.hs.config.recaptcha_private_key,
'response': user_response,
'remoteip': clientip,
}
)
except PartialDownloadError as pde:
# Twisted is silly
data = pde.response
resp_body = simplejson.loads(data)
if 'success' in resp_body and resp_body['success']:
defer.returnValue(True)
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
@defer.inlineCallbacks
def _check_email_identity(self, authdict, _):
yield run_on_reactor()
if 'threepid_creds' not in authdict:
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
threepid_creds = authdict['threepid_creds']
identity_handler = self.hs.get_handlers().identity_handler
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
if not threepid:
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
threepid['threepid_creds'] = authdict['threepid_creds']
defer.returnValue(threepid)
@defer.inlineCallbacks
def _check_dummy_auth(self, authdict, _):
yield run_on_reactor()
defer.returnValue(True)
def _get_params_recaptcha(self):
return {"public_key": self.hs.config.recaptcha_public_key}
def _auth_dict_for_flows(self, flows, session):
public_flows = []
for f in flows:
public_flows.append(f)
get_params = {
LoginType.RECAPTCHA: self._get_params_recaptcha,
}
params = {}
for f in public_flows:
for stage in f:
if stage in get_params and stage not in params:
params[stage] = get_params[stage]()
return {
"session": session['id'],
"flows": [{"stages": f} for f in public_flows],
"params": params
}
def _get_session_info(self, session_id):
if session_id not in self.sessions:
session_id = None
if not session_id:
# create a new session
while session_id is None or session_id in self.sessions:
session_id = stringutils.random_string(24)
self.sessions[session_id] = {
"id": session_id,
}
return self.sessions[session_id]
@defer.inlineCallbacks
def login_with_password(self, user_id, password):
"""
Authenticates the user with their username and password.
Used only by the v1 login API.
Args:
user_id (str): User ID
password (str): Password
Returns:
A tuple of:
The user's ID.
The access token for the user's session.
The refresh token for the user's session.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
refresh_token = yield self.issue_refresh_token(user_id)
defer.returnValue((user_id, access_token, refresh_token))
@defer.inlineCallbacks
def _find_user_id_and_pwd_hash(self, user_id):
"""Checks to see if a user with the given id exists. Will check case
insensitively, but will throw if there are multiple inexact matches.
Returns:
tuple: A 2-tuple of `(canonical_user_id, password_hash)`
"""
user_infos = yield self.store.get_users_by_id_case_insensitive(user_id)
if not user_infos:
logger.warn("Attempted to login as %s but they do not exist", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
if len(user_infos) > 1:
if user_id not in user_infos:
logger.warn(
"Attempted to login as %s but it matches more than one user "
"inexactly: %r",
user_id, user_infos.keys()
)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue((user_id, user_infos[user_id]))
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def issue_access_token(self, user_id):
access_token = self.generate_access_token(user_id)
yield self.store.add_access_token_to_user(user_id, access_token)
defer.returnValue(access_token)
@defer.inlineCallbacks
def issue_refresh_token(self, user_id):
refresh_token = self.generate_refresh_token(user_id)
yield self.store.add_refresh_token_to_user(user_id, refresh_token)
defer.returnValue(refresh_token)
def generate_access_token(self, user_id):
macaroon = self._generate_base_macaroon(user_id)
macaroon.add_first_party_caveat("type = access")
now = self.hs.get_clock().time_msec()
expiry = now + (60 * 60 * 1000)
macaroon.add_first_party_caveat("time < %d" % (expiry,))
return macaroon.serialize()
def generate_refresh_token(self, user_id):
m = self._generate_base_macaroon(user_id)
m.add_first_party_caveat("type = refresh")
# Important to add a nonce, because otherwise every refresh token for a
# user will be the same.
m.add_first_party_caveat("nonce = %s" % (
stringutils.random_string_with_symbols(16),
))
return m.serialize()
def _generate_base_macaroon(self, user_id):
macaroon = pymacaroons.Macaroon(
location=self.hs.config.server_name,
identifier="key",
key=self.hs.config.macaroon_secret_key)
macaroon.add_first_party_caveat("gen = 1")
macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
return macaroon
@defer.inlineCallbacks
def set_password(self, user_id, newpassword):
password_hash = self.hash(newpassword)
yield self.store.user_set_password_hash(user_id, password_hash)
yield self.store.user_delete_access_tokens(user_id)
yield self.hs.get_pusherpool().remove_pushers_by_user(user_id)
yield self.store.flush_user(user_id)
@defer.inlineCallbacks
def add_threepid(self, user_id, medium, address, validated_at):
yield self.store.user_add_threepid(
user_id, medium, address, validated_at,
self.hs.get_clock().time_msec()
)
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
self.sessions[session["id"]] = session
def _remove_session(self, session):
logger.debug("Removing session %s", session)
del self.sessions[session["id"]]
def hash(self, password):
"""Computes a secure hash of password.
Args:
password (str): Password to hash.
Returns:
Hashed password (str).
"""
return bcrypt.hashpw(password, bcrypt.gensalt())
def validate_hash(self, password, stored_hash):
"""Validates that self.hash(password) == stored_hash.
Args:
password (str): Password to hash.
stored_hash (str): Expected hash value.
Returns:
Whether self.hash(password) == stored_hash (bool).
"""
return bcrypt.checkpw(password, stored_hash)

View File

@@ -22,7 +22,6 @@ from synapse.api.constants import EventTypes
from synapse.types import RoomAlias
import logging
import string
logger = logging.getLogger(__name__)
@@ -41,10 +40,6 @@ class DirectoryHandler(BaseHandler):
def _create_association(self, room_alias, room_id, servers=None):
# general association creation for both human users and app services
for wchar in string.whitespace:
if wchar in room_alias.localpart:
raise SynapseError(400, "Invalid characters in room alias")
if not self.hs.is_mine(room_alias):
raise SynapseError(400, "Room alias must be local")
# TODO(erikj): Change this.

View File

@@ -15,6 +15,7 @@
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.types import UserID
from synapse.events.utils import serialize_event
@@ -49,12 +50,7 @@ class EventStreamHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def get_stream(self, auth_user_id, pagin_config, timeout=0,
as_client_event=True, affect_presence=True,
only_room_events=False):
"""Fetches the events stream for a given user.
If `only_room_events` is `True` only room events will be returned.
"""
as_client_event=True, affect_presence=True):
auth_user = UserID.from_string(auth_user_id)
try:
@@ -75,15 +71,7 @@ class EventStreamHandler(BaseHandler):
self._streams_per_user[auth_user] += 1
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
auth_user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
room_ids = yield rm_handler.get_joined_rooms_for_user(auth_user)
if timeout:
# If they've set a timeout set a minimum limit.
@@ -93,10 +81,10 @@ class EventStreamHandler(BaseHandler):
# thundering herds on restart.
timeout = random.randint(int(timeout*0.9), int(timeout*1.1))
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout,
only_room_events=only_room_events
)
with PreserveLoggingContext():
events, tokens = yield self.notifier.get_events_for(
auth_user, room_ids, pagin_config, timeout
)
time_now = self.clock.time_msec()

View File

@@ -18,11 +18,9 @@
from ._base import BaseHandler
from synapse.api.errors import (
AuthError, FederationError, StoreError, CodeMessageException, SynapseError,
AuthError, FederationError, StoreError,
)
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -31,10 +29,6 @@ from synapse.crypto.event_signing import (
)
from synapse.types import UserID
from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
from twisted.internet import defer
import itertools
@@ -79,6 +73,7 @@ class FederationHandler(BaseHandler):
# When joining a room we need to queue any events for that room up
self.room_queues = {}
@defer.inlineCallbacks
def handle_new_event(self, event, destinations):
""" Takes in an event from the client to server side, that has already
been authed and handled by the state module, and sends it to any
@@ -93,7 +88,9 @@ class FederationHandler(BaseHandler):
processing.
"""
return self.replication_layer.send_pdu(event, destinations)
yield run_on_reactor()
self.replication_layer.send_pdu(event, destinations)
@log_function
@defer.inlineCallbacks
@@ -140,32 +137,29 @@ class FederationHandler(BaseHandler):
if state and auth_chain is not None:
# If we have any state or auth_chain given to us by the replication
# layer, then we should handle them (if we haven't before.)
event_infos = []
for e in itertools.chain(auth_chain, state):
if e.event_id in seen_ids:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
event_infos.append({
"event": e,
"auth_events": auth,
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
e.internal_metadata.outlier = True
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
seen_ids.add(e.event_id)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
try:
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
event,
state=state,
@@ -184,7 +178,7 @@ class FederationHandler(BaseHandler):
# it's probably a good idea to mark it as not in retry-state
# for sending (although this is a bit of a leap)
retry_timings = yield self.store.get_destination_retry_timings(origin)
if retry_timings and retry_timings["retry_last_ts"]:
if (retry_timings and retry_timings.retry_last_ts):
self.store.set_destination_retry_timings(origin, 0, 0)
room = yield self.store.get_room(event.room_id)
@@ -206,19 +200,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)
if event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -227,300 +211,38 @@ class FederationHandler(BaseHandler):
"user_joined_room", user=user, room_id=event.room_id
)
@defer.inlineCallbacks
def _filter_events_for_server(self, server_name, room_id, events):
event_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, None),
)
)
def redact_disallowed(event, state):
if not state:
return event
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
if visibility in ["invited", "joined"]:
# We now loop through all state events looking for
# membership states for the requesting server to determine
# if the server is either in the room or has been invited
# into the room.
for ev in state.values():
if ev.type != EventTypes.Member:
continue
try:
domain = UserID.from_string(ev.state_key).domain
except:
continue
if domain != server_name:
continue
memtype = ev.membership
if memtype == Membership.JOIN:
return event
elif memtype == Membership.INVITE:
if visibility == "invited":
return event
else:
return prune_event(event)
return event
defer.returnValue([
redact_disallowed(e, event_to_state[e.event_id])
for e in events
])
@log_function
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
def backfill(self, dest, room_id, limit):
""" Trigger a backfill request to `dest` for the given `room_id`
"""
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
extremities = yield self.store.get_oldest_events_in_room(room_id)
events = yield self.replication_layer.backfill(
pdus = yield self.replication_layer.backfill(
dest,
room_id,
limit=limit,
limit,
extremities=extremities,
)
event_map = {e.event_id: e for e in events}
events = []
event_ids = set(e.event_id for e in events)
for pdu in pdus:
event = pdu
edges = [
ev.event_id
for ev in events
if set(e_id for e_id, _ in ev.prev_events) - event_ids
]
# FIXME (erikj): Not sure this actually works :/
context = yield self.state_handler.compute_event_context(event)
logger.info(
"backfill: Got %d events with %d edges",
len(events), len(edges),
)
events.append((event, context))
# For each edge get the current state.
auth_events = {}
state_events = {}
events_to_state = {}
for e_id in edges:
state, auth = yield self.replication_layer.get_state_for_room(
destination=dest,
room_id=room_id,
event_id=e_id
yield self.store.persist_event(
event,
context=context,
backfilled=True
)
auth_events.update({a.event_id: a for a in auth})
auth_events.update({s.event_id: s for s in state})
state_events.update({s.event_id: s for s in state})
events_to_state[e_id] = state
seen_events = yield self.store.have_events(
set(auth_events.keys()) | set(state_events.keys())
)
all_events = events + state_events.values() + auth_events.values()
required_auth = set(
a_id for event in all_events for a_id, _ in event.auth_events
)
missing_auth = required_auth - set(auth_events)
results = yield defer.gatherResults(
[
self.replication_layer.get_pdu(
[dest],
event_id,
outlier=True,
timeout=10000,
)
for event_id in missing_auth
],
consumeErrors=True
).addErrback(unwrapFirstError)
auth_events.update({a.event_id: a for a in results})
ev_infos = []
for a in auth_events.values():
if a.event_id in seen_events:
continue
ev_infos.append({
"event": a,
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in a.auth_events
}
})
for e_id in events_to_state:
ev_infos.append({
"event": event_map[e_id],
"state": events_to_state[e_id],
"auth_events": {
(auth_events[a_id].type, auth_events[a_id].state_key):
auth_events[a_id]
for a_id, _ in event_map[e_id].auth_events
}
})
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
ev_infos.append({
"event": event,
})
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
defer.returnValue(events)
@defer.inlineCallbacks
def maybe_backfill(self, room_id, current_depth):
"""Checks the database to see if we should backfill before paginating,
and if so do.
"""
extremities = yield self.store.get_oldest_events_with_depth_in_room(
room_id
)
if not extremities:
logger.debug("Not backfilling as no extremeties found.")
return
# Check if we reached a point where we should start backfilling.
sorted_extremeties_tuple = sorted(
extremities.items(),
key=lambda e: -int(e[1])
)
max_depth = sorted_extremeties_tuple[0][1]
if current_depth > max_depth:
logger.debug(
"Not backfilling as we don't need to. %d < %d",
max_depth, current_depth,
)
return
# Now we need to decide which hosts to hit first.
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = yield self.state_handler.get_current_state(room_id)
def get_domains_from_state(state):
joined_users = [
(state_key, int(event.depth))
for (e_type, state_key), event in state.items()
if e_type == EventTypes.Member
and event.membership == Membership.JOIN
]
joined_domains = {}
for u, d in joined_users:
try:
dom = UserID.from_string(u).domain
old_d = joined_domains.get(dom)
if old_d:
joined_domains[dom] = min(d, old_d)
else:
joined_domains[dom] = d
except:
pass
return sorted(joined_domains.items(), key=lambda d: d[1])
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
]
@defer.inlineCallbacks
def try_backfill(domains):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
except SynapseError:
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except CodeMessageException as e:
if 400 <= e.code < 500:
raise
logger.info(
"Failed to backfill from %s because %s",
dom, e,
)
continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
logger.exception(
"Failed to backfill from %s because %s",
dom, e,
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
if success:
defer.returnValue(True)
# Huh, well *those* domains didn't work out. Lets try some domains
# from the time.
tried_domains = set(likely_domains)
tried_domains.add(self.server_name)
event_ids = list(extremities.keys())
states = yield defer.gatherResults([
self.state_handler.resolve_state_groups(room_id, [e])
for e in event_ids
])
states = dict(zip(event_ids, [s[1] for s in states]))
for e_id, _ in sorted_extremeties_tuple:
likely_domains = get_domains_from_state(states[e_id])
success = yield try_backfill([
dom for dom in likely_domains
if dom not in tried_domains
])
if success:
defer.returnValue(True)
tried_domains.update(likely_domains)
defer.returnValue(False)
@defer.inlineCallbacks
def send_invite(self, target_host, event):
""" Sends the invite to the remote server for signing.
@@ -567,8 +289,6 @@ class FederationHandler(BaseHandler):
"""
logger.debug("Joining %s to %s", joinee, room_id)
yield self.store.clean_room_for_join(room_id)
origin, pdu = yield self.replication_layer.make_join(
target_hosts,
room_id,
@@ -649,22 +369,46 @@ class FederationHandler(BaseHandler):
# FIXME
pass
ev_infos = []
for e in itertools.chain(state, auth_chain):
for e in auth_chain:
e.internal_metadata.outlier = True
if e.event_id == event.event_id:
continue
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle auth event %s",
e.event_id,
)
for e in state:
if e.event_id == event.event_id:
continue
e.internal_metadata.outlier = True
auth_ids = [e_id for e_id, _ in e.auth_events]
ev_infos.append({
"event": e,
"auth_events": {
try:
auth_ids = [e_id for e_id, _ in e.auth_events]
auth = {
(e.type, e.state_key): e for e in auth_chain
if e.event_id in auth_ids
}
})
yield self._handle_new_events(origin, ev_infos, outliers=True)
yield self._handle_new_event(
origin, e, auth_events=auth
)
except:
logger.exception(
"Failed to handle state event %s",
e.event_id,
)
auth_ids = [e_id for e_id, _ in event.auth_events]
auth_events = {
@@ -672,7 +416,7 @@ class FederationHandler(BaseHandler):
if e.event_id in auth_ids
}
_, event_stream_id, max_stream_id = yield self._handle_new_event(
yield self._handle_new_event(
origin,
new_event,
state=state,
@@ -680,19 +424,9 @@ class FederationHandler(BaseHandler):
auth_events=auth_events,
)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
new_event, event_stream_id, max_stream_id,
extra_users=[joinee]
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
new_event.event_id, f.value
)
d.addErrback(log_failure)
yield self.notifier.on_new_room_event(
new_event, extra_users=[joinee]
)
logger.debug("Finished joining %s to %s", joinee, room_id)
finally:
@@ -729,9 +463,11 @@ class FederationHandler(BaseHandler):
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
self.auth.check(event, auth_events=context.auth_events)
defer.returnValue(event)
pdu = event
defer.returnValue(pdu)
@defer.inlineCallbacks
@log_function
@@ -749,9 +485,7 @@ class FederationHandler(BaseHandler):
event.internal_metadata.outlier = False
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
)
context = yield self._handle_new_event(origin, event)
logger.debug(
"on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -765,18 +499,9 @@ class FederationHandler(BaseHandler):
target_user = UserID.from_string(target_user_id)
extra_users.append(target_user)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id, extra_users=extra_users
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
yield self.notifier.on_new_room_event(
event, extra_users=extra_users
)
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
@@ -841,26 +566,16 @@ class FederationHandler(BaseHandler):
context = yield self.state_handler.compute_event_context(event)
event_stream_id, max_stream_id = yield self.store.persist_event(
yield self.store.persist_event(
event,
context=context,
backfilled=False,
)
target_user = UserID.from_string(event.state_key)
with PreserveLoggingContext():
d = self.notifier.on_new_room_event(
event, event_stream_id, max_stream_id,
extra_users=[target_user],
)
def log_failure(f):
logger.warn(
"Failed to notify about %s: %s",
event.event_id, f.value
)
d.addErrback(log_failure)
yield self.notifier.on_new_room_event(
event, extra_users=[target_user],
)
defer.returnValue(event)
@@ -874,7 +589,7 @@ class FederationHandler(BaseHandler):
raise AuthError(403, "Host not in room.")
state_groups = yield self.store.get_state_groups(
room_id, [event_id]
[event_id]
)
if state_groups:
@@ -921,8 +636,6 @@ class FederationHandler(BaseHandler):
limit
)
events = yield self._filter_events_for_server(origin, room_id, events)
defer.returnValue(events)
@defer.inlineCallbacks
@@ -981,72 +694,31 @@ class FederationHandler(BaseHandler):
def _handle_new_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self._prep_event(
origin, event,
state=state,
backfilled=backfilled,
current_state=current_state,
auth_events=auth_events,
logger.debug(
"_handle_new_event: %s, sigs: %s",
event.event_id, event.signatures,
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(not outlier and not backfilled),
current_state=current_state,
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
contexts = yield defer.gatherResults(
[
self._prep_event(
origin,
ev_info["event"],
state=ev_info.get("state"),
backfilled=backfilled,
auth_events=ev_info.get("auth_events"),
)
for ev_info in event_infos
]
)
yield self.store.persist_events(
[
(ev_info["event"], context)
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, backfilled=False,
current_state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier,
event, old_state=state
)
if not auth_events:
auth_events = context.current_state
auth_events = context.auth_events
logger.debug(
"_handle_new_event: %s, auth_events: %s",
event.event_id, auth_events,
)
is_new_state = not event.internal_metadata.is_outlier()
# This is a hack to fix some old rooms where the initial join event
# didn't reference the create event in its auth events.
if event.type == EventTypes.Member and not event.auth_events:
if len(event.prev_events) == 1 and event.depth < 5:
c = yield self.store.get_event(
event.prev_events[0][0],
allow_none=True,
)
if c and c.type == EventTypes.Create:
if len(event.prev_events) == 1:
c = yield self.store.get_event(event.prev_events[0][0])
if c.type == EventTypes.Create:
auth_events[(c.type, c.state_key)] = c
try:
@@ -1061,6 +733,25 @@ class FederationHandler(BaseHandler):
context.rejected = RejectedReason.AUTH_ERROR
# FIXME: Don't store as rejected with AUTH_ERROR if we haven't
# seen all the auth events.
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=False,
current_state=current_state,
)
raise
yield self.store.persist_event(
event,
context=context,
backfilled=backfilled,
is_new_state=(is_new_state and not backfilled),
current_state=current_state,
)
defer.returnValue(context)
@defer.inlineCallbacks
@@ -1124,24 +815,14 @@ class FederationHandler(BaseHandler):
@log_function
def do_auth(self, origin, event, context, auth_events):
# Check if we have all the auth events.
current_state = set(e.event_id for e in auth_events.values())
have_events = yield self.store.have_events(
[e_id for e_id, _ in event.auth_events]
)
event_auth_events = set(e_id for e_id, _ in event.auth_events)
if event_auth_events - current_state:
have_events = yield self.store.have_events(
event_auth_events - current_state
)
else:
have_events = {}
have_events.update({
e.event_id: ""
for e in auth_events.values()
})
seen_events = set(have_events.keys())
missing_auth = event_auth_events - seen_events - current_state
missing_auth = event_auth_events - seen_events
if missing_auth:
logger.info("Missing auth: %s", missing_auth)
@@ -1211,7 +892,7 @@ class FederationHandler(BaseHandler):
if d in have_events and not have_events[d]
],
consumeErrors=True
).addErrback(unwrapFirstError)
)
if different_events:
local_view = dict(auth_events)

View File

@@ -1,144 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for interacting with Identity Servers"""
from twisted.internet import defer
from synapse.api.errors import (
CodeMessageException
)
from ._base import BaseHandler
from synapse.http.client import SimpleHttpClient
from synapse.util.async import run_on_reactor
from synapse.api.errors import SynapseError
import json
import logging
logger = logging.getLogger(__name__)
class IdentityHandler(BaseHandler):
def __init__(self, hs):
super(IdentityHandler, self).__init__(hs)
@defer.inlineCallbacks
def threepid_from_creds(self, creds):
yield run_on_reactor()
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
# trustedIdServers = ['matrix.org', 'localhost:8090']
trustedIdServers = ['matrix.org', 'vector.im']
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
else:
raise SynapseError(400, "No client_secret in creds")
if id_server not in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', id_server)
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'client_secret': client_secret}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def bind_threepid(self, creds, mxid):
yield run_on_reactor()
logger.debug("binding threepid %r to %s", creds, mxid)
http_client = SimpleHttpClient(self.hs)
data = None
if 'id_server' in creds:
id_server = creds['id_server']
elif 'idServer' in creds:
id_server = creds['idServer']
else:
raise SynapseError(400, "No id_server in creds")
if 'client_secret' in creds:
client_secret = creds['client_secret']
elif 'clientSecret' in creds:
client_secret = creds['clientSecret']
else:
raise SynapseError(400, "No client_secret in creds")
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server, "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'client_secret': client_secret,
'mxid': mxid,
}
)
logger.debug("bound threepid %r to %s", creds, mxid)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
yield run_on_reactor()
http_client = SimpleHttpClient(self.hs)
params = {
'email': email,
'client_secret': client_secret,
'send_attempt': send_attempt,
}
params.update(kwargs)
try:
data = yield http_client.post_urlencoded_get_json(
"https://%s%s" % (
id_server,
"/_matrix/identity/api/v1/validate/email/requestToken"
),
params
)
defer.returnValue(data)
except CodeMessageException as e:
logger.info("Proxied requestToken failed: %r", e)
raise e

116
synapse/handlers/login.py Normal file
View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
# Copyright 2014, 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.api.errors import LoginError, Codes, CodeMessageException
from synapse.http.client import SimpleHttpClient
from synapse.util.emailutils import EmailException
import synapse.util.emailutils as emailutils
import bcrypt
import json
import logging
logger = logging.getLogger(__name__)
class LoginHandler(BaseHandler):
def __init__(self, hs):
super(LoginHandler, self).__init__(hs)
self.hs = hs
@defer.inlineCallbacks
def login(self, user, password):
"""Login as the specified user with the specified password.
Args:
user (str): The user ID.
password (str): The password.
Returns:
The newly allocated access token.
Raises:
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
# TODO do this better, it can't go in __init__ else it cyclic loops
if not hasattr(self, "reg_handler"):
self.reg_handler = self.hs.get_handlers().registration_handler
# pull out the hash for this user if they exist
user_info = yield self.store.get_user_by_id(user_id=user)
if not user_info:
logger.warn("Attempted to login as %s but they do not exist", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
stored_hash = user_info[0]["password_hash"]
if bcrypt.checkpw(password, stored_hash):
# generate an access token and store it.
token = self.reg_handler._generate_token(user)
logger.info("Adding token %s for user %s", token, user)
yield self.store.add_access_token_to_user(user, token)
defer.returnValue(token)
else:
logger.warn("Failed password login for user %s", user)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def reset_password(self, user_id, email):
is_valid = yield self._check_valid_association(user_id, email)
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
is_valid)
if is_valid:
try:
# send an email out
emailutils.send_email(
smtp_server=self.hs.config.email_smtp_server,
from_addr=self.hs.config.email_from_address,
to_addr=email,
subject="Password Reset",
body="TODO."
)
except EmailException as e:
logger.exception(e)
@defer.inlineCallbacks
def _check_valid_association(self, user_id, email):
identity = yield self._query_email(email)
if identity and "mxid" in identity:
if identity["mxid"] == user_id:
defer.returnValue(True)
return
defer.returnValue(False)
@defer.inlineCallbacks
def _query_email(self, email):
http_client = SimpleHttpClient(self.hs)
try:
data = yield http_client.get_json(
# TODO FIXME This should be configurable.
# XXX: ID servers need to use HTTPS
"http://%s%s" % (
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
),
{
'medium': 'email',
'address': email
}
)
defer.returnValue(data)
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)

View File

@@ -16,13 +16,12 @@
from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import SynapseError
from synapse.api.errors import RoomError, SynapseError
from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID, RoomStreamToken, StreamToken
from synapse.types import UserID
from ._base import BaseHandler
@@ -71,7 +70,7 @@ class MessageHandler(BaseHandler):
@defer.inlineCallbacks
def get_messages(self, user_id=None, room_id=None, pagin_config=None,
as_client_event=True):
feedback=False, as_client_event=True):
"""Get messages in a room.
Args:
@@ -79,81 +78,35 @@ class MessageHandler(BaseHandler):
room_id (str): The room they want messages from.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config rules to apply, if any.
feedback (bool): True to get compressed feedback with the messages
as_client_event (bool): True to get events in client-server format.
Returns:
dict: Pagination API results
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
yield self.auth.check_joined_room(room_id, user_id)
data_source = self.hs.get_event_sources().sources["room"]
if pagin_config.from_token:
room_token = pagin_config.from_token.room_key
else:
if not pagin_config.from_token:
pagin_config.from_token = (
yield self.hs.get_event_sources().get_current_token(
direction='b'
)
yield self.hs.get_event_sources().get_current_token()
)
room_token = pagin_config.from_token.room_key
room_token = RoomStreamToken.parse(room_token)
if room_token.topological is None:
raise SynapseError(400, "Invalid token")
pagin_config.from_token = pagin_config.from_token.copy_and_replace(
"room_key", str(room_token)
)
source_config = pagin_config.get_source_config("room")
if member_event.membership == Membership.LEAVE:
# If they have left the room then clamp the token to be before
# they left the room
leave_token = yield self.store.get_topological_token_for_event(
member_event.event_id
)
leave_token = RoomStreamToken.parse(leave_token)
if leave_token.topological < room_token.topological:
source_config.from_key = str(leave_token)
if source_config.direction == "f":
if source_config.to_key is None:
source_config.to_key = str(leave_token)
else:
to_token = RoomStreamToken.parse(source_config.to_key)
if leave_token.topological < to_token.topological:
source_config.to_key = str(leave_token)
yield self.hs.get_handlers().federation_handler.maybe_backfill(
room_id, room_token.topological
)
user = UserID.from_string(user_id)
events, next_key = yield data_source.get_pagination_rows(
user, source_config, room_id
user, pagin_config.get_source_config("room"), room_id
)
next_token = pagin_config.from_token.copy_and_replace(
"room_key", next_key
)
if not events:
defer.returnValue({
"chunk": [],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
})
events = yield self._filter_events_for_client(user_id, room_id, events)
time_now = self.clock.time_msec()
chunk = {
"chunk": [
serialize_event(e, time_now, as_client_event)
for e in events
serialize_event(e, time_now, as_client_event) for e in events
],
"start": pagin_config.from_token.to_string(),
"end": next_token.to_string(),
@@ -161,55 +114,9 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def create_and_send_event(self, event_dict, ratelimit=True,
token_id=None, txn_id=None):
client=None, txn_id=None):
""" Given a dict from a client, create and handle a new event.
Creates an FrozenEvent object, filling out auth_events, prev_events,
@@ -243,8 +150,11 @@ class MessageHandler(BaseHandler):
builder.content
)
if token_id is not None:
builder.internal_metadata.token_id = token_id
if client is not None:
if client.token_id is not None:
builder.internal_metadata.token_id = client.token_id
if client.device_id is not None:
builder.internal_metadata.device_id = client.device_id
if txn_id is not None:
builder.internal_metadata.txn_id = txn_id
@@ -281,26 +191,29 @@ class MessageHandler(BaseHandler):
Raises:
SynapseError if something went wrong.
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
elif member_event.membership == Membership.LEAVE:
key = (event_type, state_key)
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], [key]
)
data = room_state[member_event.event_id].get(key)
have_joined = yield self.auth.check_joined_room(room_id, user_id)
if not have_joined:
raise RoomError(403, "User not in room.")
data = yield self.state_handler.get_current_state(
room_id, event_type, state_key
)
defer.returnValue(data)
@defer.inlineCallbacks
def get_feedback(self, event_id):
# yield self.auth.check_joined_room(room_id, user_id)
# Pull out the feedback from the db
fb = yield self.store.get_feedback(event_id)
if fb:
defer.returnValue(fb)
defer.returnValue(None)
@defer.inlineCallbacks
def get_state_events(self, user_id, room_id):
"""Retrieve all state events for a given room. If the user is
joined to the room then return the current state. If the user has
left the room return the state events from when they left.
"""Retrieve all state events for a given room.
Args:
user_id(str): The user requesting state events.
@@ -308,23 +221,18 @@ class MessageHandler(BaseHandler):
Returns:
A list of dicts representing state events. [{}, {}, {}]
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
room_state = yield self.state_handler.get_current_state(room_id)
elif member_event.membership == Membership.LEAVE:
room_state = yield self.store.get_state_for_events(
room_id, [member_event.event_id], None
)
room_state = room_state[member_event.event_id]
yield self.auth.check_joined_room(room_id, user_id)
# TODO: This is duplicating logic from snapshot_all_rooms
current_state = yield self.state_handler.get_current_state(room_id)
now = self.clock.time_msec()
defer.returnValue(
[serialize_event(c, now) for c in room_state.values()]
[serialize_event(c, now) for c in current_state.values()]
)
@defer.inlineCallbacks
def snapshot_all_rooms(self, user_id=None, pagin_config=None, as_client_event=True):
def snapshot_all_rooms(self, user_id=None, pagin_config=None,
feedback=False, as_client_event=True):
"""Retrieve a snapshot of all rooms the user is invited or has joined.
This snapshot may include messages for all rooms where the user is
@@ -334,6 +242,7 @@ class MessageHandler(BaseHandler):
user_id (str): The ID of the user making the request.
pagin_config (synapse.api.streams.PaginationConfig): The pagination
config used to determine how many messages *PER ROOM* to return.
feedback (bool): True to get feedback along with these messages.
as_client_event (bool): True to get events in client-server format.
Returns:
A list of dicts with "room_id" and "membership" keys for all rooms
@@ -343,9 +252,7 @@ class MessageHandler(BaseHandler):
"""
room_list = yield self.store.get_rooms_for_user_where_membership_is(
user_id=user_id,
membership_list=[
Membership.INVITE, Membership.JOIN, Membership.LEAVE
]
membership_list=[Membership.INVITE, Membership.JOIN]
)
user = UserID.from_string(user_id)
@@ -360,19 +267,14 @@ class MessageHandler(BaseHandler):
user, pagination_config.get_source_config("presence"), None
)
receipt_stream = self.hs.get_event_sources().sources["receipt"]
receipt, _ = yield receipt_stream.get_pagination_rows(
user, pagination_config.get_source_config("receipt"), None
)
public_room_ids = yield self.store.get_public_room_ids()
public_rooms = yield self.store.get_rooms(is_public=True)
public_room_ids = [r["room_id"] for r in public_rooms]
limit = pagin_config.limit
if limit is None:
limit = 10
@defer.inlineCallbacks
def handle_room(event):
for event in room_list:
d = {
"room_id": event.room_id,
"membership": event.membership,
@@ -387,37 +289,13 @@ class MessageHandler(BaseHandler):
rooms_ret.append(d)
if event.membership not in (Membership.JOIN, Membership.LEAVE):
return
if event.membership != Membership.JOIN:
continue
try:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = self.state_handler.get_current_state(
event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = "s%d" % (event.stream_ordering,)
deferred_room_state = self.store.get_state_for_events(
event.room_id, [event.event_id], None
)
deferred_room_state.addCallback(
lambda states: states[event.event_id]
)
(messages, token), current_state = yield defer.gatherResults(
[
self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=room_end_token,
),
deferred_room_state,
]
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, event.room_id, messages
messages, token = yield self.store.get_recent_events_for_room(
event.room_id,
limit=limit,
end_token=now_token.room_key,
)
start_token = now_token.copy_and_replace("room_key", token[0])
@@ -433,6 +311,9 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
}
current_state = yield self.state_handler.get_current_state(
event.room_id
)
d["state"] = [
serialize_event(c, time_now, as_client_event)
for c in current_state.values()
@@ -440,106 +321,26 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
# Only do N rooms at once
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
ret = {
"rooms": rooms_ret,
"presence": presence,
"receipts": receipt,
"end": now_token.to_string(),
"end": now_token.to_string()
}
defer.returnValue(ret)
@defer.inlineCallbacks
def room_initial_sync(self, user_id, room_id, pagin_config=None):
"""Capture the a snapshot of a room. If user is currently a member of
the room this will be what is currently in the room. If the user left
the room this will be what was in the room when they left.
Args:
user_id(str): The user to get a snapshot for.
room_id(str): The room to get a snapshot of.
pagin_config(synapse.streams.config.PaginationConfig):
The pagination config used to determine how many messages to
return.
Raises:
AuthError if the user wasn't in the room.
Returns:
A JSON serialisable dict with the snapshot of the room.
"""
member_event = yield self.auth.check_user_was_in_room(room_id, user_id)
if member_event.membership == Membership.JOIN:
result = yield self._room_initial_sync_joined(
user_id, room_id, pagin_config, member_event
)
elif member_event.membership == Membership.LEAVE:
result = yield self._room_initial_sync_parted(
user_id, room_id, pagin_config, member_event
)
defer.returnValue(result)
@defer.inlineCallbacks
def _room_initial_sync_parted(self, user_id, room_id, pagin_config,
member_event):
room_state = yield self.store.get_state_for_events(
member_event.room_id, [member_event.event_id], None
)
room_state = room_state[member_event.event_id]
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
stream_token = yield self.store.get_stream_token_for_event(
member_event.event_id
)
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=stream_token
)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = StreamToken(token[0], 0, 0, 0)
end_token = StreamToken(token[1], 0, 0, 0)
time_now = self.clock.time_msec()
defer.returnValue({
"membership": member_event.membership,
"room_id": room_id,
"messages": {
"chunk": [serialize_event(m, time_now) for m in messages],
"start": start_token.to_string(),
"end": end_token.to_string(),
},
"state": [serialize_event(s, time_now) for s in room_state.values()],
"presence": [],
"receipts": [],
})
@defer.inlineCallbacks
def _room_initial_sync_joined(self, user_id, room_id, pagin_config,
member_event):
def room_initial_sync(self, user_id, room_id, pagin_config=None,
feedback=False):
current_state = yield self.state.get_current_state(
room_id=room_id,
)
yield self.auth.check_joined_room(
room_id, user_id,
current_state=current_state
)
# TODO(paul): I wish I was called with user objects not user_id
# strings...
auth_user = UserID.from_string(user_id)
@@ -551,12 +352,23 @@ class MessageHandler(BaseHandler):
for x in current_state.values()
]
member_event = current_state.get((EventTypes.Member, user_id,))
now_token = yield self.hs.get_event_sources().get_current_token()
limit = pagin_config.limit if pagin_config else None
if limit is None:
limit = 10
messages, token = yield self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
room_members = [
m for m in current_state.values()
if m.type == EventTypes.Member
@@ -564,39 +376,19 @@ class MessageHandler(BaseHandler):
]
presence_handler = self.hs.get_handlers().presence_handler
@defer.inlineCallbacks
def get_presence():
states = yield presence_handler.get_states(
target_users=[UserID.from_string(m.user_id) for m in room_members],
auth_user=auth_user,
as_event=True,
check_auth=False,
)
defer.returnValue(states.values())
receipts_handler = self.hs.get_handlers().receipts_handler
presence, receipts, (messages, token) = yield defer.gatherResults(
[
get_presence(),
receipts_handler.get_receipts_for_room(room_id, now_token.receipt_key),
self.store.get_recent_events_for_room(
room_id,
limit=limit,
end_token=now_token.room_key,
presence = []
for m in room_members:
try:
member_presence = yield presence_handler.get_state(
target_user=UserID.from_string(m.user_id),
auth_user=auth_user,
as_event=True,
)
presence.append(member_presence)
except SynapseError:
logger.exception(
"Failed to get member presence of %r", m.user_id
)
],
consumeErrors=True,
).addErrback(unwrapFirstError)
messages = yield self._filter_events_for_client(
user_id, room_id, messages
)
start_token = now_token.copy_and_replace("room_key", token[0])
end_token = now_token.copy_and_replace("room_key", token[1])
time_now = self.clock.time_msec()
@@ -609,6 +401,5 @@ class MessageHandler(BaseHandler):
"end": end_token.to_string(),
},
"state": state,
"presence": presence,
"receipts": receipts,
"presence": presence
})

View File

@@ -18,10 +18,9 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError
from synapse.api.constants import PresenceState
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
import synapse.metrics
from ._base import BaseHandler
@@ -30,15 +29,6 @@ import logging
logger = logging.getLogger(__name__)
metrics = synapse.metrics.get_metrics_for(__name__)
# Don't bother bumping "last active" time if it differs by less than 60 seconds
LAST_ACTIVE_GRANULARITY = 60*1000
# Keep no more than this number of offline serial revisions
MAX_OFFLINE_SERIALS = 1000
# TODO(paul): Maybe there's one of these I can steal from somewhere
def partition(l, func):
@@ -138,23 +128,11 @@ class PresenceHandler(BaseHandler):
self._remote_sendmap = {}
# map remote users to sets of local users who're interested in them
self._remote_recvmap = {}
# list of (serial, set of(userids)) tuples, ordered by serial, latest
# first
self._remote_offline_serials = []
# map any user to a UserPresenceCache
self._user_cachemap = {}
self._user_cachemap_latest_serial = 0
# map room_ids to the latest presence serial for a member of that
# room
self._room_serials = {}
metrics.register_callback(
"userCachemap:size",
lambda: len(self._user_cachemap),
)
def _get_or_make_usercache(self, user):
"""If the cache entry doesn't exist, initialise a new one."""
if user not in self._user_cachemap:
@@ -191,38 +169,24 @@ class PresenceHandler(BaseHandler):
defer.returnValue(False)
@defer.inlineCallbacks
def get_state(self, target_user, auth_user, as_event=False, check_auth=True):
"""Get the current presence state of the given user.
Args:
target_user (UserID): The user whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_user`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: The presence state of the `target_user`, whose format depends
on the `as_event` argument.
"""
def get_state(self, target_user, auth_user, as_event=False):
if self.hs.is_mine(target_user):
if check_auth:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=target_user
)
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=target_user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
if not visible:
raise SynapseError(404, "Presence information not visible")
state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
if target_user in self._user_cachemap:
state = self._user_cachemap[target_user].get_state()
else:
state = yield self.store.get_presence_state(target_user.localpart)
if "mtime" in state:
del state["mtime"]
state["presence"] = state.pop("state")
cached_state = self._user_cachemap[target_user].get_state()
if "last_active" in cached_state:
state["last_active"] = cached_state["last_active"]
else:
# TODO(paul): Have remote server send us permissions set
state = self._get_or_offline_usercache(target_user).get_state()
@@ -246,81 +210,6 @@ class PresenceHandler(BaseHandler):
else:
defer.returnValue(state)
@defer.inlineCallbacks
def get_states(self, target_users, auth_user, as_event=False, check_auth=True):
"""A batched version of the `get_state` method that accepts a list of
`target_users`
Args:
target_users (list): The list of UserID's whose presence we want
auth_user (UserID): The user requesting the presence, used for
checking if said user is allowed to see the persence of the
`target_users`
as_event (bool): Format the return as an event or not?
check_auth (bool): Perform the auth checks or not?
Returns:
dict: A mapping from user -> presence_state
"""
local_users, remote_users = partitionbool(
target_users,
lambda u: self.hs.is_mine(u)
)
if check_auth:
for user in local_users:
visible = yield self.is_presence_visible(
observer_user=auth_user,
observed_user=user
)
if not visible:
raise SynapseError(404, "Presence information not visible")
results = {}
if local_users:
for user in local_users:
if user in self._user_cachemap:
results[user] = self._user_cachemap[user].get_state()
local_to_user = {u.localpart: u for u in local_users}
states = yield self.store.get_presence_states(
[u.localpart for u in local_users if u not in results]
)
for local_part, state in states.items():
if state is None:
continue
res = {"presence": state["state"]}
if "status_msg" in state and state["status_msg"]:
res["status_msg"] = state["status_msg"]
results[local_to_user[local_part]] = res
for user in remote_users:
# TODO(paul): Have remote server send us permissions set
results[user] = self._get_or_offline_usercache(user).get_state()
for state in results.values():
if "last_active" in state:
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
if as_event:
for user, state in results.items():
content = state
content["user_id"] = user.to_string()
if "last_active" in content:
content["last_active_ago"] = int(
self._clock.time_msec() - content.pop("last_active")
)
results[user] = {"type": "m.presence", "content": content}
defer.returnValue(results)
@defer.inlineCallbacks
@log_function
def set_state(self, target_user, auth_user, state):
@@ -371,53 +260,29 @@ class PresenceHandler(BaseHandler):
now_online = state["presence"] != PresenceState.OFFLINE
was_polling = target_user in self._user_cachemap
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
with PreserveLoggingContext():
if now_online and not was_polling:
self.start_polling_presence(target_user, state=state)
elif not now_online and was_polling:
self.stop_polling_presence(target_user)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
# TODO(paul): perform a presence push as part of start/stop poll so
# we don't have to do this all the time
self.changed_presencelike_data(target_user, state)
def bump_presence_active_time(self, user, now=None):
if now is None:
now = self.clock.time_msec()
prev_state = self._get_or_make_usercache(user)
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
return
self.changed_presencelike_data(user, {"last_active": now})
def get_joined_rooms_for_user(self, user):
"""Get the list of rooms a user is joined to.
Args:
user(UserID): The user.
Returns:
A Deferred of a list of room id strings.
"""
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_joined_rooms_for_user(user)
def get_joined_users_for_room_id(self, room_id):
rm_handler = self.homeserver.get_handlers().room_member_handler
return rm_handler.get_room_members(room_id)
@defer.inlineCallbacks
def changed_presencelike_data(self, user, state):
"""Updates the presence state of a local user.
statuscache = self._get_or_make_usercache(user)
Args:
user(UserID): The user being updated.
state(dict): The new presence state for the user.
Returns:
A Deferred
"""
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(user, state)
yield self.push_presence(user, statuscache=statuscache)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
return self.push_presence(user, statuscache=statuscache)
@log_function
def started_user_eventstream(self, user):
@@ -431,21 +296,14 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def user_joined_room(self, user, room_id):
"""Called via the distributor whenever a user joins a room.
Notifies the new member of the presence of the current members.
Notifies the current members of the room of the new member's presence.
Args:
user(UserID): The user who joined the room.
room_id(str): The room id the user joined.
"""
if self.hs.is_mine(user):
statuscache = self._get_or_make_usercache(user)
# No actual update but we need to bump the serial anyway for the
# event source
self._user_cachemap_latest_serial += 1
statuscache = yield self.update_presence_cache(
user, room_ids=[room_id]
)
statuscache.update({}, serial=self._user_cachemap_latest_serial)
self.push_update_to_local_and_remote(
observed_user=user,
room_ids=[room_id],
@@ -453,22 +311,18 @@ class PresenceHandler(BaseHandler):
)
# We also want to tell them about current presence of people.
curr_users = yield self.get_joined_users_for_room_id(room_id)
rm_handler = self.homeserver.get_handlers().room_member_handler
curr_users = yield rm_handler.get_room_members(room_id)
for local_user in [c for c in curr_users if self.hs.is_mine(c)]:
statuscache = yield self.update_presence_cache(
local_user, room_ids=[room_id], add_to_cache=False
)
self.push_update_to_local_and_remote(
observed_user=local_user,
users_to_push=[user],
statuscache=statuscache,
statuscache=self._get_or_offline_usercache(local_user),
)
@defer.inlineCallbacks
def send_invite(self, observer_user, observed_user):
"""Request the presence of a local or remote user for a local user"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -503,15 +357,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def invite_presence(self, observed_user, observer_user):
"""Handles a m.presence_invite EDU. A remote or local user has
requested presence updates for a local user. If the invite is accepted
then allow the local or remote user to see the presence of the local
user.
Args:
observed_user(UserID): The local user whose presence is requested.
observer_user(UserID): The remote or local user requesting presence.
"""
accept = yield self._should_accept_invite(observed_user, observer_user)
if accept:
@@ -538,34 +383,16 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def accept_presence(self, observed_user, observer_user):
"""Handles a m.presence_accept EDU. Mark a presence invite from a
local or remote user as accepted in a local user's presence list.
Starts polling for presence updates from the local or remote user.
Args:
observed_user(UserID): The user to update in the presence list.
observer_user(UserID): The owner of the presence list to update.
"""
yield self.store.set_presence_list_accepted(
observer_user.localpart, observed_user.to_string()
)
self.start_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.start_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def deny_presence(self, observed_user, observer_user):
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
local user's presence list.
Args:
observed_user(UserID): The local or remote user to remove from the
list.
observer_user(UserID): The local owner of the presence list.
Returns:
A Deferred.
"""
yield self.store.del_presence_list(
observer_user.localpart, observed_user.to_string()
)
@@ -574,16 +401,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
def drop(self, observed_user, observer_user):
"""Remove a local or remote user from a local user's presence list and
unsubscribe the local user from updates that user.
Args:
observed_user(UserId): The local or remote user to remove from the
list.
observer_user(UserId): The local owner of the presence list.
Returns:
A Deferred.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
@@ -591,66 +408,34 @@ class PresenceHandler(BaseHandler):
observer_user.localpart, observed_user.to_string()
)
self.stop_polling_presence(
observer_user, target_user=observed_user
)
with PreserveLoggingContext():
self.stop_polling_presence(
observer_user, target_user=observed_user
)
@defer.inlineCallbacks
def get_presence_list(self, observer_user, accepted=None):
"""Get the presence list for a local user. The retured list includes
the current presence state for each user listed.
Args:
observer_user(UserID): The local user whose presence list to fetch.
accepted(bool or None): If not none then only include users who
have or have not accepted the presence invite request.
Returns:
A Deferred list of presence state events.
"""
if not self.hs.is_mine(observer_user):
raise SynapseError(400, "User is not hosted on this Home Server")
presence_list = yield self.store.get_presence_list(
presence = yield self.store.get_presence_list(
observer_user.localpart, accepted=accepted
)
results = []
for row in presence_list:
observed_user = UserID.from_string(row["observed_user_id"])
result = {
"observed_user": observed_user, "accepted": row["accepted"]
}
result.update(
self._get_or_offline_usercache(observed_user).get_state()
)
if "last_active" in result:
result["last_active_ago"] = int(
self.clock.time_msec() - result.pop("last_active")
for p in presence:
observed_user = UserID.from_string(p.pop("observed_user_id"))
p["observed_user"] = observed_user
p.update(self._get_or_offline_usercache(observed_user).get_state())
if "last_active" in p:
p["last_active_ago"] = int(
self.clock.time_msec() - p.pop("last_active")
)
results.append(result)
defer.returnValue(results)
defer.returnValue(presence)
@defer.inlineCallbacks
@log_function
def start_polling_presence(self, user, target_user=None, state=None):
"""Subscribe a local user to presence updates from a local or remote
user. If no target_user is supplied then subscribe to all users stored
in the presence list for the local user.
Additonally this pushes the current presence state of this user to all
target_users. That state can be provided directly or will be read from
the stored state for the local user.
Also this attempts to notify the local user of the current state of
any local target users.
Args:
user(UserID): The local user that whishes for presence updates.
target_user(UserID): The local or remote user whose updates are
wanted.
state(dict): Optional presence state for the local user.
"""
logger.debug("Start polling for presence from %s", user)
if target_user:
@@ -666,7 +451,8 @@ class PresenceHandler(BaseHandler):
# Also include people in all my rooms
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if state is None:
state = yield self.store.get_presence_state(user.localpart)
@@ -690,7 +476,9 @@ class PresenceHandler(BaseHandler):
# We want to tell the person that just came online
# presence state of people they are interested in?
self.push_update_to_clients(
observed_user=target_user,
users_to_push=[user],
statuscache=self._get_or_offline_usercache(target_user),
)
deferreds = []
@@ -707,12 +495,6 @@ class PresenceHandler(BaseHandler):
yield defer.DeferredList(deferreds, consumeErrors=True)
def _start_polling_local(self, user, target_user):
"""Subscribe a local user to presence updates for a local user
Args:
user(UserId): The local user that wishes for updates.
target_user(UserId): The local users whose updates are wanted.
"""
target_localpart = target_user.localpart
if target_localpart not in self._local_pushmap:
@@ -721,17 +503,6 @@ class PresenceHandler(BaseHandler):
self._local_pushmap[target_localpart].add(user)
def _start_polling_remote(self, user, domain, remoteusers):
"""Subscribe a local user to presence updates for remote users on a
given remote domain.
Args:
user(UserID): The local user that wishes for updates.
domain(str): The remote server the local user wants updates from.
remoteusers(UserID): The remote users that local user wants to be
told about.
Returns:
A Deferred.
"""
to_poll = set()
for u in remoteusers:
@@ -752,17 +523,6 @@ class PresenceHandler(BaseHandler):
@log_function
def stop_polling_presence(self, user, target_user=None):
"""Unsubscribe a local user from presence updates from a local or
remote user. If no target user is supplied then unsubscribe the user
from all presence updates that the user had subscribed to.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID or None): The user whose updates are no longer
wanted.
Returns:
A Deferred.
"""
logger.debug("Stop polling for presence from %s", user)
if not target_user or self.hs.is_mine(target_user):
@@ -791,13 +551,6 @@ class PresenceHandler(BaseHandler):
return defer.DeferredList(deferreds, consumeErrors=True)
def _stop_polling_local(self, user, target_user):
"""Unsubscribe a local user from presence updates from a local user on
this server.
Args:
user(UserID): The local user that no longer wishes for updates.
target_user(UserID): The user whose updates are no longer wanted.
"""
for localpart in self._local_pushmap.keys():
if target_user and localpart != target_user.localpart:
continue
@@ -810,17 +563,6 @@ class PresenceHandler(BaseHandler):
@log_function
def _stop_polling_remote(self, user, domain, remoteusers):
"""Unsubscribe a local user from presence updates from remote users on
a given domain.
Args:
user(UserID): The local user that no longer wishes for updates.
domain(str): The remote server to unsubscribe from.
remoteusers([UserID]): The users on that remote server that the
local user no longer wishes to be updated about.
Returns:
A Deferred.
"""
to_unpoll = set()
for u in remoteusers:
@@ -842,19 +584,6 @@ class PresenceHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def push_presence(self, user, statuscache):
"""
Notify local and remote users of a change in presence of a local user.
Pushes the update to local clients and remote domains that are directly
subscribed to the presence of the local user.
Also pushes that update to any local user or remote domain that shares
a room with the local user.
Args:
user(UserID): The local user whose presence was updated.
statuscache(UserPresenceCache): Cache of the user's presence state
Returns:
A Deferred.
"""
assert(self.hs.is_mine(user))
logger.debug("Pushing presence update from %s", user)
@@ -866,7 +595,8 @@ class PresenceHandler(BaseHandler):
# and also user is informed of server-forced pushes
localusers.add(user)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if not localusers and not room_ids:
defer.returnValue(None)
@@ -881,23 +611,44 @@ class PresenceHandler(BaseHandler):
yield self.distributor.fire("user_presence_changed", user, statuscache)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
"""Handle an incoming m.presence EDU.
For each presence update in the "push" list update our local cache and
notify the appropriate local clients. Only clients that share a room
or are directly subscribed to the presence for a user should be
notified of the update.
For each subscription request in the "poll" list start pushing presence
updates to the remote server.
For unsubscribe request in the "unpoll" list stop pushing presence
updates to the remote server.
def _push_presence_remote(self, user, destination, state=None):
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
Args:
orgin(str): The source of this m.presence EDU.
content(dict): The content of this m.presence EDU.
Returns:
A Deferred.
"""
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {
"user_id": user.to_string(),
}
user_state.update(**state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={
"push": [
user_state,
],
}
)
@defer.inlineCallbacks
def incoming_presence(self, origin, content):
deferreds = []
for push in content.get("push", []):
@@ -911,7 +662,8 @@ class PresenceHandler(BaseHandler):
" | %d interested local observers %r", len(observers), observers
)
room_ids = yield self.get_joined_rooms_for_user(user)
rm_handler = self.homeserver.get_handlers().room_member_handler
room_ids = yield rm_handler.get_joined_rooms_for_user(user)
if room_ids:
logger.debug(" | %d interested room IDs %r", len(room_ids), room_ids)
@@ -930,35 +682,24 @@ class PresenceHandler(BaseHandler):
self.clock.time_msec() - state.pop("last_active_ago")
)
statuscache = self._get_or_make_usercache(user)
self._user_cachemap_latest_serial += 1
yield self.update_presence_cache(user, state, room_ids=room_ids)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
if not observers and not room_ids:
logger.debug(" | no interested observers or room IDs")
continue
self.push_update_to_clients(
users_to_push=observers, room_ids=room_ids
observed_user=user,
users_to_push=observers,
room_ids=room_ids,
statuscache=statuscache,
)
user_id = user.to_string()
if state["presence"] == PresenceState.OFFLINE:
self._remote_offline_serials.insert(
0,
(self._user_cachemap_latest_serial, set([user_id]))
)
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
self._remote_offline_serials.pop() # remove the oldest
del self._user_cachemap[user]
else:
# Remove the user from remote_offline_serials now that they're
# no longer offline
for idx, elem in enumerate(self._remote_offline_serials):
(_, user_ids) = elem
user_ids.discard(user_id)
if not user_ids:
self._remote_offline_serials.pop(idx)
for poll in content.get("poll", []):
user = UserID.from_string(poll)
@@ -987,58 +728,13 @@ class PresenceHandler(BaseHandler):
if not self._remote_sendmap[user]:
del self._remote_sendmap[user]
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def update_presence_cache(self, user, state={}, room_ids=None,
add_to_cache=True):
"""Update the presence cache for a user with a new state and bump the
serial to the latest value.
Args:
user(UserID): The user being updated
state(dict): The presence state being updated
room_ids(None or list of str): A list of room_ids to update. If
room_ids is None then fetch the list of room_ids the user is
joined to.
add_to_cache: Whether to add an entry to the presence cache if the
user isn't already in the cache.
Returns:
A Deferred UserPresenceCache for the user being updated.
"""
if room_ids is None:
room_ids = yield self.get_joined_rooms_for_user(user)
for room_id in room_ids:
self._room_serials[room_id] = self._user_cachemap_latest_serial
if add_to_cache:
statuscache = self._get_or_make_usercache(user)
else:
statuscache = self._get_or_offline_usercache(user)
statuscache.update(state, serial=self._user_cachemap_latest_serial)
defer.returnValue(statuscache)
with PreserveLoggingContext():
yield defer.DeferredList(deferreds, consumeErrors=True)
@defer.inlineCallbacks
def push_update_to_local_and_remote(self, observed_user, statuscache,
users_to_push=[], room_ids=[],
remote_domains=[]):
"""Notify local clients and remote servers of a change in the presence
of a user.
Args:
observed_user(UserID): The user to push the presence state for.
statuscache(UserPresenceCache): The cache for the presence state to
push.
users_to_push([UserID]): A list of local and remote users to
notify.
room_ids([str]): Notify the local and remote occupants of these
rooms.
remote_domains([str]): A list of remote servers to notify in
addition to those implied by the users_to_push and the
room_ids.
Returns:
A Deferred.
"""
localusers, remoteusers = partitionbool(
users_to_push,
@@ -1048,7 +744,10 @@ class PresenceHandler(BaseHandler):
localusers = set(localusers)
self.push_update_to_clients(
users_to_push=localusers, room_ids=room_ids
observed_user=observed_user,
users_to_push=localusers,
room_ids=room_ids,
statuscache=statuscache,
)
remote_domains = set(remote_domains)
@@ -1073,65 +772,11 @@ class PresenceHandler(BaseHandler):
defer.returnValue((localusers, remote_domains))
def push_update_to_clients(self, users_to_push=[], room_ids=[]):
"""Notify clients of a new presence event.
Args:
users_to_push([UserID]): List of users to notify.
room_ids([str]): List of room_ids to notify.
"""
with PreserveLoggingContext():
self.notifier.on_new_event(
"presence_key",
self._user_cachemap_latest_serial,
users_to_push,
room_ids,
)
@defer.inlineCallbacks
def _push_presence_remote(self, user, destination, state=None):
"""Push a user's presence to a remote server. If a presence state event
that event is sent. Otherwise a new state event is constructed from the
stored presence state.
The last_active is replaced with last_active_ago in case the wallclock
time on the remote server is different to the time on this server.
Sends an EDU to the remote server with the current presence state.
Args:
user(UserID): The user to push the presence state for.
destination(str): The remote server to send state to.
state(dict): The state to push, or None to use the current stored
state.
Returns:
A Deferred.
"""
if state is None:
state = yield self.store.get_presence_state(user.localpart)
del state["mtime"]
state["presence"] = state.pop("state")
if user in self._user_cachemap:
state["last_active"] = (
self._user_cachemap[user].get_state()["last_active"]
)
yield self.distributor.fire(
"collect_presencelike_data", user, state
)
if "last_active" in state:
state = dict(state)
state["last_active_ago"] = int(
self.clock.time_msec() - state.pop("last_active")
)
user_state = {"user_id": user.to_string(), }
user_state.update(state)
yield self.federation.send_edu(
destination=destination,
edu_type="m.presence",
content={"push": [user_state, ], }
def push_update_to_clients(self, observed_user, users_to_push=[],
room_ids=[], statuscache=None):
self.notifier.on_new_user_event(
users_to_push,
room_ids,
)
@@ -1140,65 +785,62 @@ class PresenceEventSource(object):
self.hs = hs
self.clock = hs.get_clock()
@defer.inlineCallbacks
def is_visible(self, observer_user, observed_user):
if observer_user == observed_user:
defer.returnValue(True)
presence = self.hs.get_handlers().presence_handler
if (yield presence.store.user_rooms_intersect(
[u.to_string() for u in observer_user, observed_user])):
defer.returnValue(True)
if self.hs.is_mine(observed_user):
pushmap = presence._local_pushmap
defer.returnValue(
observed_user.localpart in pushmap and
observer_user in pushmap[observed_user.localpart]
)
else:
recvmap = presence._remote_recvmap
defer.returnValue(
observed_user in recvmap and
observer_user in recvmap[observed_user]
)
@defer.inlineCallbacks
@log_function
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
observer_user = user
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
max_serial = presence._user_cachemap_latest_serial
clock = self.clock
latest_serial = 0
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] > from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
cached = cachemap[observed_user]
if cached.serial <= from_key or cached.serial > max_serial:
if cached.serial <= from_key:
continue
latest_serial = max(cached.serial, latest_serial)
updates.append(cached.make_event(user=observed_user, clock=clock))
if (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cached))
# TODO(paul): limit
for serial, user_ids in presence._remote_offline_serials:
if serial <= from_key:
break
if serial > max_serial:
continue
latest_serial = max(latest_serial, serial)
for u in user_ids:
updates.append({
"type": "m.presence",
"content": {"user_id": u, "presence": PresenceState.OFFLINE},
})
# TODO(paul): For the v2 API we want to tell the client their from_key
# is too old if we fell off the end of the _remote_offline_serials
# list, and get them to invalidate+resync. In v1 we have no such
# concept so this is a best-effort result.
if updates:
defer.returnValue((updates, latest_serial))
clock = self.clock
latest_serial = max([x[1].serial for x in updates])
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
defer.returnValue((data, latest_serial))
else:
defer.returnValue(([], presence._user_cachemap_latest_serial))
@@ -1210,6 +852,8 @@ class PresenceEventSource(object):
def get_pagination_rows(self, user, pagination_config, key):
# TODO (erikj): Does this make sense? Ordering?
observer_user = user
from_key = int(pagination_config.from_key)
if pagination_config.to_key:
@@ -1220,26 +864,14 @@ class PresenceEventSource(object):
presence = self.hs.get_handlers().presence_handler
cachemap = presence._user_cachemap
user_ids_to_check = {user}
presence_list = yield presence.store.get_presence_list(
user.localpart, accepted=True
)
if presence_list is not None:
user_ids_to_check |= set(
UserID.from_string(p["observed_user_id"]) for p in presence_list
)
room_ids = yield presence.get_joined_rooms_for_user(user)
for room_id in set(room_ids) & set(presence._room_serials):
if presence._room_serials[room_id] >= from_key:
joined = yield presence.get_joined_users_for_room_id(room_id)
user_ids_to_check |= set(joined)
updates = []
for observed_user in user_ids_to_check & set(cachemap):
# TODO(paul): use a DeferredList ? How to limit concurrency.
for observed_user in cachemap.keys():
if not (to_key < cachemap[observed_user].serial <= from_key):
continue
updates.append((observed_user, cachemap[observed_user]))
if (yield self.is_visible(observer_user, observed_user)):
updates.append((observed_user, cachemap[observed_user]))
# TODO(paul): limit

View File

@@ -17,8 +17,8 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.api.constants import EventTypes, Membership
from synapse.util.logcontext import PreserveLoggingContext
from synapse.types import UserID
from synapse.util import unwrapFirstError
from ._base import BaseHandler
@@ -88,9 +88,6 @@ class ProfileHandler(BaseHandler):
if target_user != auth_user:
raise AuthError(400, "Cannot set another user's displayname")
if new_displayname == '':
new_displayname = None
yield self.store.set_profile_displayname(
target_user.localpart, new_displayname
)
@@ -157,13 +154,14 @@ class ProfileHandler(BaseHandler):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
with PreserveLoggingContext():
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
)
state["displayname"] = displayname
state["avatar_url"] = avatar_url

View File

@@ -1,210 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseHandler
from twisted.internet import defer
from synapse.util.logcontext import PreserveLoggingContext
import logging
logger = logging.getLogger(__name__)
class ReceiptsHandler(BaseHandler):
def __init__(self, hs):
super(ReceiptsHandler, self).__init__(hs)
self.hs = hs
self.federation = hs.get_replication_layer()
self.federation.register_edu_handler(
"m.receipt", self._received_remote_receipt
)
self.clock = self.hs.get_clock()
self._receipt_cache = None
@defer.inlineCallbacks
def received_client_receipt(self, room_id, receipt_type, user_id,
event_id):
"""Called when a client tells us a local user has read up to the given
event_id in the room.
"""
receipt = {
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": [event_id],
"data": {
"ts": int(self.clock.time_msec()),
}
}
is_new = yield self._handle_new_receipts([receipt])
if is_new:
self._push_remotes([receipt])
@defer.inlineCallbacks
def _received_remote_receipt(self, origin, content):
"""Called when we receive an EDU of type m.receipt from a remote HS.
"""
receipts = [
{
"room_id": room_id,
"receipt_type": receipt_type,
"user_id": user_id,
"event_ids": user_values["event_ids"],
"data": user_values.get("data", {}),
}
for room_id, room_values in content.items()
for receipt_type, users in room_values.items()
for user_id, user_values in users.items()
]
yield self._handle_new_receipts(receipts)
@defer.inlineCallbacks
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
res = yield self.store.insert_receipt(
room_id, receipt_type, user_id, event_ids, data
)
if not res:
# res will be None if this read receipt is 'old'
defer.returnValue(False)
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
defer.returnValue(True)
@defer.inlineCallbacks
def _push_remotes(self, receipts):
"""Given a list of receipts, works out which remote servers should be
poked and pokes them.
"""
# TODO: Some of this stuff should be coallesced.
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
user_id = receipt["user_id"]
event_ids = receipt["event_ids"]
data = receipt["data"]
remotedomains = set()
rm_handler = self.hs.get_handlers().room_member_handler
yield rm_handler.fetch_room_distributions_into(
room_id, localusers=None, remotedomains=remotedomains
)
logger.debug("Sending receipt to: %r", remotedomains)
for domain in remotedomains:
self.federation.send_edu(
destination=domain,
edu_type="m.receipt",
content={
room_id: {
receipt_type: {
user_id: {
"event_ids": event_ids,
"data": data,
}
}
},
},
)
@defer.inlineCallbacks
def get_receipts_for_room(self, room_id, to_key):
"""Gets all receipts for a room, upto the given key.
"""
result = yield self.store.get_linearized_receipts_for_room(
room_id,
to_key=to_key,
)
if not result:
defer.returnValue([])
event = {
"type": "m.receipt",
"room_id": room_id,
"content": result,
}
defer.returnValue([event])
class ReceiptEventSource(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@defer.inlineCallbacks
def get_new_events_for_user(self, user, from_key, limit):
from_key = int(from_key)
to_key = yield self.get_current_key()
if from_key == to_key:
defer.returnValue(([], to_key))
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))
def get_current_key(self, direction='f'):
return self.store.get_max_receipt_stream_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
to_key = int(config.from_key)
if config.to_key:
from_key = int(config.to_key)
else:
from_key = None
rooms = yield self.store.get_rooms_for_user(user.to_string())
rooms = [room.room_id for room in rooms]
events = yield self.store.get_linearized_receipts_for_rooms(
rooms,
from_key=from_key,
to_key=to_key,
)
defer.returnValue((events, to_key))

View File

@@ -18,15 +18,19 @@ from twisted.internet import defer
from synapse.types import UserID
from synapse.api.errors import (
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError,
CodeMessageException
)
from ._base import BaseHandler
import synapse.util.stringutils as stringutils
from synapse.util.async import run_on_reactor
from synapse.http.client import SimpleHttpClient
from synapse.http.client import CaptchaServerHttpClient
import base64
import bcrypt
import json
import logging
import urllib
logger = logging.getLogger(__name__)
@@ -39,30 +43,6 @@ class RegistrationHandler(BaseHandler):
self.distributor = hs.get_distributor()
self.distributor.declare("registered_user")
@defer.inlineCallbacks
def check_username(self, localpart):
yield run_on_reactor()
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
users = yield self.store.get_users_by_id_case_insensitive(user_id)
if users:
raise SynapseError(
400,
"User ID already taken.",
errcode=Codes.USER_IN_USE,
)
@defer.inlineCallbacks
def register(self, localpart=None, password=None):
"""Registers a new client on the server.
@@ -71,8 +51,7 @@ class RegistrationHandler(BaseHandler):
localpart : The local part of the user ID to register. If None,
one will be randomly generated.
password (str) : The password to assign to this user so they can
login again. This can be None which means they cannot login again
via a password (e.g. the user is an application service user).
login again.
Returns:
A tuple of (user_id, access_token).
Raises:
@@ -81,15 +60,15 @@ class RegistrationHandler(BaseHandler):
yield run_on_reactor()
password_hash = None
if password:
password_hash = self.auth_handler().hash(password)
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
if localpart:
yield self.check_username(localpart)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
token = self.auth_handler().generate_access_token(user_id)
yield self.check_user_id_is_valid(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -109,7 +88,7 @@ class RegistrationHandler(BaseHandler):
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self.auth_handler().generate_access_token(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -159,7 +138,7 @@ class RegistrationHandler(BaseHandler):
400, "Invalid user localpart for this application service.",
errcode=Codes.EXCLUSIVE
)
token = self.auth_handler().generate_access_token(user_id)
token = self._generate_token(user_id)
yield self.store.register(
user_id=user_id,
token=token,
@@ -170,11 +149,7 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def check_recaptcha(self, ip, private_key, challenge, response):
"""
Checks a recaptcha is correct.
Used only by c/s api v1
"""
"""Checks a recaptcha is correct."""
captcha_response = yield self._validate_captcha(
ip,
@@ -191,49 +166,15 @@ class RegistrationHandler(BaseHandler):
else:
logger.info("Valid captcha entered from %s", ip)
@defer.inlineCallbacks
def register_saml2(self, localpart):
"""
Registers email_id as SAML2 Based Auth.
"""
if urllib.quote(localpart) != localpart:
raise SynapseError(
400,
"User ID must only contain characters which do not"
" require URL encoding."
)
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_is_valid(user_id)
token = self.auth_handler().generate_access_token(user_id)
try:
yield self.store.register(
user_id=user_id,
token=token,
password_hash=None
)
yield self.distributor.fire("registered_user", user)
except Exception, e:
yield self.store.add_access_token_to_user(user_id, token)
# Ignore Registration errors
logger.exception(e)
defer.returnValue((user_id, token))
@defer.inlineCallbacks
def register_email(self, threepidCreds):
"""
Registers emails with an identity server.
Used only by c/s api v1
"""
"""Registers emails with an identity server."""
for c in threepidCreds:
logger.info("validating theeepidcred sid %s on id server %s",
c['sid'], c['idServer'])
try:
identity_handler = self.hs.get_handlers().identity_handler
threepid = yield identity_handler.threepid_from_creds(c)
threepid = yield self._threepid_from_creds(c)
except:
logger.exception("Couldn't validate 3pid")
raise RegistrationError(400, "Couldn't validate 3pid")
@@ -245,16 +186,12 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def bind_emails(self, user_id, threepidCreds):
"""Links emails with a user ID and informs an identity server.
Used only by c/s api v1
"""
"""Links emails with a user ID and informs an identity server."""
# Now we have a matrix ID, bind it to the threepids we were given
for c in threepidCreds:
identity_handler = self.hs.get_handlers().identity_handler
# XXX: This should be a deferred list, shouldn't it?
yield identity_handler.bind_threepid(c, user_id)
yield self._bind_threepid(c, user_id)
@defer.inlineCallbacks
def check_user_id_is_valid(self, user_id):
@@ -271,15 +208,72 @@ class RegistrationHandler(BaseHandler):
errcode=Codes.EXCLUSIVE
)
def _generate_token(self, user_id):
# urlsafe variant uses _ and - so use . as the separator and replace
# all =s with .s so http clients don't quote =s when it is used as
# query params.
return (base64.urlsafe_b64encode(user_id).replace('=', '.') + '.' +
stringutils.random_string(18))
def _generate_user_id(self):
return "-" + stringutils.random_string(18)
@defer.inlineCallbacks
def _threepid_from_creds(self, creds):
# TODO: get this from the homeserver rather than creating a new one for
# each request
http_client = SimpleHttpClient(self.hs)
# XXX: make this configurable!
trustedIdServers = ['matrix.org:8090', 'matrix.org']
if not creds['idServer'] in trustedIdServers:
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
'credentials', creds['idServer'])
defer.returnValue(None)
data = {}
try:
data = yield http_client.get_json(
# XXX: This should be HTTPS
"http://%s%s" % (
creds['idServer'],
"/_matrix/identity/api/v1/3pid/getValidated3pid"
),
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
)
except CodeMessageException as e:
data = json.loads(e.msg)
if 'medium' in data:
defer.returnValue(data)
defer.returnValue(None)
@defer.inlineCallbacks
def _bind_threepid(self, creds, mxid):
yield
logger.debug("binding threepid")
http_client = SimpleHttpClient(self.hs)
data = None
try:
data = yield http_client.post_urlencoded_get_json(
# XXX: Change when ID servers are all HTTPS
"http://%s%s" % (
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
),
{
'sid': creds['sid'],
'clientSecret': creds['clientSecret'],
'mxid': mxid,
}
)
logger.debug("bound threepid")
except CodeMessageException as e:
data = json.loads(e.msg)
defer.returnValue(data)
@defer.inlineCallbacks
def _validate_captcha(self, ip_addr, private_key, challenge, response):
"""Validates the captcha provided.
Used only by c/s api v1
Returns:
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
@@ -297,9 +291,6 @@ class RegistrationHandler(BaseHandler):
@defer.inlineCallbacks
def _submit_captcha(self, ip_addr, private_key, challenge, response):
"""
Used only by c/s api v1
"""
# TODO: get this from the homeserver rather than creating a new one for
# each request
client = CaptchaServerHttpClient(self.hs)
@@ -313,6 +304,3 @@ class RegistrationHandler(BaseHandler):
}
)
defer.returnValue(data)
def auth_handler(self):
return self.hs.get_handlers().auth_handler

View File

@@ -19,35 +19,19 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
)
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import StoreError, SynapseError
from synapse.util import stringutils, unwrapFirstError
from synapse.util import stringutils
from synapse.util.async import run_on_reactor
from synapse.events.utils import serialize_event
from collections import OrderedDict
import logging
import string
logger = logging.getLogger(__name__)
class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
RoomCreationPreset.PRIVATE_CHAT: {
"join_rules": JoinRules.INVITE,
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
RoomCreationPreset.PUBLIC_CHAT: {
"join_rules": JoinRules.PUBLIC,
"history_visibility": "shared",
"original_invitees_have_ops": False,
},
}
@defer.inlineCallbacks
def create_room(self, user_id, room_id, config):
""" Creates a new room.
@@ -66,10 +50,6 @@ class RoomCreationHandler(BaseHandler):
self.ratelimit(user_id)
if "room_alias_name" in config:
for wchar in string.whitespace:
if wchar in config["room_alias_name"]:
raise SynapseError(400, "Invalid characters in room alias")
room_alias = RoomAlias.create(
config["room_alias_name"],
self.hs.hostname,
@@ -136,31 +116,15 @@ class RoomCreationHandler(BaseHandler):
servers=[self.hs.hostname],
)
preset_config = config.get(
"preset",
RoomCreationPreset.PUBLIC_CHAT
if is_public
else RoomCreationPreset.PRIVATE_CHAT
)
raw_initial_state = config.get("initial_state", [])
initial_state = OrderedDict()
for val in raw_initial_state:
initial_state[(val["type"], val.get("state_key", ""))] = val["content"]
user = UserID.from_string(user_id)
creation_events = self._create_events_for_new_room(
user, room_id,
preset_config=preset_config,
invite_list=invite_list,
initial_state=initial_state,
user, room_id, is_public=is_public
)
msg_handler = self.hs.get_handlers().message_handler
for event in creation_events:
yield msg_handler.create_and_send_event(event, ratelimit=False)
yield msg_handler.create_and_send_event(event)
if "name" in config:
name = config["name"]
@@ -170,7 +134,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id,
"state_key": "",
"content": {"name": name},
}, ratelimit=False)
})
if "topic" in config:
topic = config["topic"]
@@ -180,7 +144,7 @@ class RoomCreationHandler(BaseHandler):
"sender": user_id,
"state_key": "",
"content": {"topic": topic},
}, ratelimit=False)
})
for invitee in invite_list:
yield msg_handler.create_and_send_event({
@@ -189,7 +153,7 @@ class RoomCreationHandler(BaseHandler):
"room_id": room_id,
"sender": user_id,
"content": {"membership": Membership.INVITE},
}, ratelimit=False)
})
result = {"room_id": room_id}
@@ -201,10 +165,7 @@ class RoomCreationHandler(BaseHandler):
defer.returnValue(result)
def _create_events_for_new_room(self, creator, room_id, preset_config,
invite_list, initial_state):
config = RoomCreationHandler.PRESETS_DICT[preset_config]
def _create_events_for_new_room(self, creator, room_id, is_public=False):
creator_id = creator.to_string()
event_keys = {
@@ -237,64 +198,37 @@ class RoomCreationHandler(BaseHandler):
},
)
returned_events = [creation_event, join_event]
if (EventTypes.PowerLevels, '') not in initial_state:
power_level_content = {
power_levels_event = create(
etype=EventTypes.PowerLevels,
content={
"users": {
creator.to_string(): 100,
},
"users_default": 0,
"events": {
EventTypes.Name: 50,
EventTypes.Name: 100,
EventTypes.PowerLevels: 100,
EventTypes.RoomHistoryVisibility: 100,
EventTypes.CanonicalAlias: 50,
EventTypes.RoomAvatar: 50,
},
"events_default": 0,
"state_default": 50,
"ban": 50,
"kick": 50,
"redact": 50,
"invite": 0,
}
"redact": 50
},
)
if config["original_invitees_have_ops"]:
for invitee in invite_list:
power_level_content["users"][invitee] = 100
join_rule = JoinRules.PUBLIC if is_public else JoinRules.INVITE
join_rules_event = create(
etype=EventTypes.JoinRules,
content={"join_rule": join_rule},
)
power_levels_event = create(
etype=EventTypes.PowerLevels,
content=power_level_content,
)
returned_events.append(power_levels_event)
if (EventTypes.JoinRules, '') not in initial_state:
join_rules_event = create(
etype=EventTypes.JoinRules,
content={"join_rule": config["join_rules"]},
)
returned_events.append(join_rules_event)
if (EventTypes.RoomHistoryVisibility, '') not in initial_state:
history_event = create(
etype=EventTypes.RoomHistoryVisibility,
content={"history_visibility": config["history_visibility"]}
)
returned_events.append(history_event)
for (etype, state_key), content in initial_state.items():
returned_events.append(create(
etype=etype,
state_key=state_key,
content=content,
))
return returned_events
return [
creation_event,
join_event,
power_levels_event,
join_rules_event,
]
class RoomMemberHandler(BaseHandler):
@@ -341,6 +275,60 @@ class RoomMemberHandler(BaseHandler):
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def get_room_members_as_pagination_chunk(self, room_id=None, user_id=None,
limit=0, start_tok=None,
end_tok=None):
"""Retrieve a list of room members in the room.
Args:
room_id (str): The room to get the member list for.
user_id (str): The ID of the user making the request.
limit (int): The max number of members to return.
start_tok (str): Optional. The start token if known.
end_tok (str): Optional. The end token if known.
Returns:
dict: A Pagination streamable dict.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
member_list = yield self.store.get_room_members(room_id=room_id)
time_now = self.clock.time_msec()
event_list = [
serialize_event(entry, time_now)
for entry in member_list
]
chunk_data = {
"start": "START", # FIXME (erikj): START is no longer valid
"end": "END",
"chunk": event_list
}
# TODO honor Pagination stream params
# TODO snapshot this list to return on subsequent requests when
# paginating
defer.returnValue(chunk_data)
@defer.inlineCallbacks
def get_room_member(self, room_id, member_user_id, auth_user_id):
"""Retrieve a room member from a room.
Args:
room_id : The room the member is in.
member_user_id : The member's user ID
auth_user_id : The user ID of the user making this request.
Returns:
The room member, or None if this member does not exist.
Raises:
SynapseError if something goes wrong.
"""
yield self.auth.check_joined_room(room_id, auth_user_id)
member = yield self.store.get_room_member(user_id=member_user_id,
room_id=room_id)
defer.returnValue(member)
@defer.inlineCallbacks
def change_membership(self, event, context, do_auth=True):
""" Change the membership status of a user in a room.
@@ -492,14 +480,46 @@ class RoomMemberHandler(BaseHandler):
"user_joined_room", user=user, room_id=room_id
)
@defer.inlineCallbacks
def _should_invite_join(self, room_id, prev_state, do_auth):
logger.debug("_should_invite_join: room_id: %s", room_id)
# XXX: We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
# Only do an invite join dance if a) we were invited,
# b) the person inviting was from a differnt HS and c) we are
# not currently in the room
room_host = None
if prev_state and prev_state.membership == Membership.INVITE:
room = yield self.store.get_room(room_id)
inviter = UserID.from_string(
prev_state.sender
)
is_remote_invite_join = not self.hs.is_mine(inviter) and not room
room_host = inviter.domain
else:
is_remote_invite_join = False
defer.returnValue((is_remote_invite_join, room_host))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
app_service = yield self.store.get_app_service_by_user_id(
user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
else:
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
@@ -527,17 +547,11 @@ class RoomListHandler(BaseHandler):
@defer.inlineCallbacks
def get_public_room_list(self):
chunk = yield self.store.get_rooms(is_public=True)
results = yield defer.gatherResults(
[
self.store.get_users_in_room(room["room_id"])
for room in chunk
],
consumeErrors=True,
).addErrback(unwrapFirstError)
for i, room in enumerate(chunk):
room["num_joined_members"] = len(results[i])
for room in chunk:
joined_users = yield self.store.get_users_in_room(
room_id=room["room_id"],
)
room["num_joined_members"] = len(joined_users)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": chunk})
@@ -573,8 +587,8 @@ class RoomEventSource(object):
defer.returnValue((events, end_key))
def get_current_key(self, direction='f'):
return self.store.get_room_events_max_id(direction)
def get_current_key(self):
return self.store.get_room_events_max_id()
@defer.inlineCallbacks
def get_pagination_rows(self, user, config, key):
@@ -584,6 +598,7 @@ class RoomEventSource(object):
to_key=config.to_key,
direction=config.direction,
limit=config.limit,
with_feedback=True
)
defer.returnValue((events, next_key))

View File

@@ -28,6 +28,7 @@ logger = logging.getLogger(__name__)
SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"client_info",
"limit",
"gap",
"sort",
@@ -91,22 +92,13 @@ class SyncHandler(BaseHandler):
result = yield self.current_sync_for_user(sync_config, since_token)
defer.returnValue(result)
else:
def current_sync_callback(before_token, after_token):
def current_sync_callback():
return self.current_sync_for_user(sync_config, since_token)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
result = yield self.notifier.wait_for_events(
sync_config.user, room_ids,
sync_config.filter, timeout, current_sync_callback
@@ -237,16 +229,7 @@ class SyncHandler(BaseHandler):
logger.debug("Typing %r", typing_by_room)
rm_handler = self.hs.get_handlers().room_member_handler
app_service = yield self.store.get_app_service_by_user_id(
sync_config.user.to_string()
)
if app_service:
rooms = yield self.store.get_app_service_rooms(app_service)
room_ids = set(r.room_id for r in rooms)
else:
room_ids = yield rm_handler.get_joined_rooms_for_user(
sync_config.user
)
room_ids = yield rm_handler.get_joined_rooms_for_user(sync_config.user)
# TODO (mjark): Does public mean "published"?
published_rooms = yield self.store.get_rooms(is_public=True)
@@ -309,52 +292,6 @@ class SyncHandler(BaseHandler):
next_batch=now_token,
))
@defer.inlineCallbacks
def _filter_events_for_client(self, user_id, room_id, events):
event_id_to_state = yield self.store.get_state_for_events(
room_id, frozenset(e.event_id for e in events),
types=(
(EventTypes.RoomHistoryVisibility, ""),
(EventTypes.Member, user_id),
)
)
def allowed(event, state):
if event.type == EventTypes.RoomHistoryVisibility:
return True
membership_ev = state.get((EventTypes.Member, user_id), None)
if membership_ev:
membership = membership_ev.membership
else:
membership = Membership.LEAVE
if membership == Membership.JOIN:
return True
history = state.get((EventTypes.RoomHistoryVisibility, ''), None)
if history:
visibility = history.content.get("history_visibility", "shared")
else:
visibility = "shared"
if visibility == "public":
return True
elif visibility == "shared":
return True
elif visibility == "joined":
return membership == Membership.JOIN
elif visibility == "invited":
return membership == Membership.INVITE
return True
defer.returnValue([
event
for event in events
if allowed(event, event_id_to_state[event.event_id])
])
@defer.inlineCallbacks
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None):
@@ -376,9 +313,6 @@ class SyncHandler(BaseHandler):
(room_key, _) = keys
end_key = "s" + room_key.split('-')[-1]
loaded_recents = sync_config.filter.filter_room_events(events)
loaded_recents = yield self._filter_events_for_client(
sync_config.user.to_string(), room_id, loaded_recents,
)
loaded_recents.extend(recents)
recents = loaded_recents
if len(events) <= load_limit:

Some files were not shown because too many files have changed in this diff Show More