Compare commits
437 Commits
v0.8.1
...
erikj/init
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
171829bb94 | ||
|
|
657298cebd | ||
|
|
fabb7acd45 | ||
|
|
23c639ff32 | ||
|
|
8be5284e91 | ||
|
|
503e4d3d52 | ||
|
|
00718ae7a9 | ||
|
|
0465560c1a | ||
|
|
61d05daab1 | ||
|
|
6ead27ddda | ||
|
|
f383d5a801 | ||
|
|
69d4063651 | ||
|
|
5b02f33451 | ||
|
|
054aa0d58c | ||
|
|
3c4c229788 | ||
|
|
29400b45b9 | ||
|
|
2366d28780 | ||
|
|
d89a9f7283 | ||
|
|
0c1b7f843b | ||
|
|
4b46fbec5b | ||
|
|
1d7702833d | ||
|
|
b1ca784aca | ||
|
|
4a9dc5b2f5 | ||
|
|
0ade2712d1 | ||
|
|
50f96f256f | ||
|
|
d2d61a8288 | ||
|
|
3e71d13acf | ||
|
|
e7a6edb0ee | ||
|
|
c27d6ad6b5 | ||
|
|
46daf2d200 | ||
|
|
3864b3a8e6 | ||
|
|
0618978238 | ||
|
|
09177f4f2e | ||
|
|
472be88674 | ||
|
|
a6e62cf6d0 | ||
|
|
12d381bd5d | ||
|
|
f8c30faf25 | ||
|
|
61cd5d9045 | ||
|
|
fb95035a65 | ||
|
|
4669def000 | ||
|
|
0337eaf321 | ||
|
|
884fb88e28 | ||
|
|
d76c058eea | ||
|
|
9927170787 | ||
|
|
109c8aafd2 | ||
|
|
b7788f80a3 | ||
|
|
c8ed9bd278 | ||
|
|
f2d90d5c02 | ||
|
|
845b0b2c97 | ||
|
|
c0036ced54 | ||
|
|
970a9b9d2b | ||
|
|
64991b0c8b | ||
|
|
e26a3d8d9e | ||
|
|
1319905d7a | ||
|
|
a9549fdce3 | ||
|
|
4ad8b45155 | ||
|
|
19167fd21f | ||
|
|
9c4ea42e79 | ||
|
|
74874ffda7 | ||
|
|
cd0864121b | ||
|
|
4932a7e2d9 | ||
|
|
0baf923153 | ||
|
|
9894da6a29 | ||
|
|
46d200a3a1 | ||
|
|
72443572bf | ||
|
|
a08bf11138 | ||
|
|
204132a998 | ||
|
|
f4c9ebbc34 | ||
|
|
45278eaa19 | ||
|
|
478e511db0 | ||
|
|
68c0603946 | ||
|
|
e3005d3ddb | ||
|
|
cc5d68f4c4 | ||
|
|
cc52f02d74 | ||
|
|
3151afee9e | ||
|
|
f41a9a1ffc | ||
|
|
1783c7ca92 | ||
|
|
0126ef7f3c | ||
|
|
d98edb548a | ||
|
|
9fbcf19188 | ||
|
|
073b891ec1 | ||
|
|
327ca883ec | ||
|
|
a1d4813a54 | ||
|
|
18f8247701 | ||
|
|
af27b84ff7 | ||
|
|
4a13ae7201 | ||
|
|
9182f87664 | ||
|
|
55e1bc8920 | ||
|
|
55fcf62e9c | ||
|
|
b96c133034 | ||
|
|
ce8b0b2868 | ||
|
|
252e6f6869 | ||
|
|
1ccaea5b92 | ||
|
|
0bc71103e1 | ||
|
|
f8b865264a | ||
|
|
5b8b1a43bd | ||
|
|
40cbd6b6ee | ||
|
|
4e49f52375 | ||
|
|
38432d8c25 | ||
|
|
42b7139dec | ||
|
|
1ef66cc3bd | ||
|
|
416a3e6c4f | ||
|
|
8558e1ec73 | ||
|
|
56f518d279 | ||
|
|
6f8e2d517e | ||
|
|
1c1d67dfef | ||
|
|
2c70849dc3 | ||
|
|
0a016b0525 | ||
|
|
e701aec2d1 | ||
|
|
412ece18e7 | ||
|
|
1c82fbd2eb | ||
|
|
2732be83d9 | ||
|
|
e4c4664d73 | ||
|
|
03c4f0ed67 | ||
|
|
f1acb9fd40 | ||
|
|
8a5be236e0 | ||
|
|
df75914791 | ||
|
|
b02e1006b9 | ||
|
|
f8152f2708 | ||
|
|
2f475bd5d5 | ||
|
|
a7b51f4539 | ||
|
|
288702170d | ||
|
|
f46eee838a | ||
|
|
a654f3fe49 | ||
|
|
44ccfa6258 | ||
|
|
04d1725752 | ||
|
|
1bac74b9ae | ||
|
|
ed83638668 | ||
|
|
7ac8a60c6f | ||
|
|
c253b14f6e | ||
|
|
bdcb23ca25 | ||
|
|
b2c2dc8940 | ||
|
|
869dc94cbb | ||
|
|
a218619626 | ||
|
|
b1e68add19 | ||
|
|
31e262e6b4 | ||
|
|
eede182df7 | ||
|
|
4e2f8b8722 | ||
|
|
c8c710eca7 | ||
|
|
149ed9f151 | ||
|
|
f7a79a37be | ||
|
|
6532b6e607 | ||
|
|
74270defda | ||
|
|
e1e5e53127 | ||
|
|
b3bda8a75f | ||
|
|
8a785c3006 | ||
|
|
191f7f09ce | ||
|
|
03eb4adc6e | ||
|
|
4bbf7156ef | ||
|
|
6d15401341 | ||
|
|
8c78414284 | ||
|
|
6c99491347 | ||
|
|
0eb61a3d16 | ||
|
|
a2c10d37d7 | ||
|
|
2e0d9219b9 | ||
|
|
f30d47c876 | ||
|
|
a16eaa0c33 | ||
|
|
f43063158a | ||
|
|
7c50e3b816 | ||
|
|
2808c040ef | ||
|
|
48b6ee2b67 | ||
|
|
bc41f0398f | ||
|
|
d3309933f5 | ||
|
|
b568c0231c | ||
|
|
3a7d7a3f22 | ||
|
|
3ba522bb23 | ||
|
|
6080830bef | ||
|
|
8b183781cb | ||
|
|
1f651ba6ec | ||
|
|
812a99100b | ||
|
|
1967650bc4 | ||
|
|
1ebff9736b | ||
|
|
24d21887ed | ||
|
|
db8d4e8dd6 | ||
|
|
2f9157b427 | ||
|
|
91c8f828e1 | ||
|
|
8db6832db8 | ||
|
|
117f35ac4a | ||
|
|
4eea5cf6c2 | ||
|
|
f96ab9d18d | ||
|
|
865398b4a9 | ||
|
|
e3417bbbe0 | ||
|
|
2492efb162 | ||
|
|
4a5990ff8f | ||
|
|
5e7a90316d | ||
|
|
231498ac45 | ||
|
|
fd4fa9097f | ||
|
|
0b1a8500a2 | ||
|
|
cb03fafdf1 | ||
|
|
bf5e54f255 | ||
|
|
94e1e58b4d | ||
|
|
ced39d019f | ||
|
|
16dcdedc8a | ||
|
|
83b554437e | ||
|
|
dfc46c6220 | ||
|
|
6ba2e3df4e | ||
|
|
427bcb7608 | ||
|
|
0ec346d942 | ||
|
|
1352ae2c07 | ||
|
|
4cd5fb13a3 | ||
|
|
ea1776f556 | ||
|
|
e1c0970c11 | ||
|
|
b8092fbc82 | ||
|
|
bc9e69e160 | ||
|
|
f2cf37518b | ||
|
|
04c7f3576e | ||
|
|
0268d40281 | ||
|
|
399b5add58 | ||
|
|
e6e130b9ba | ||
|
|
766bd8e880 | ||
|
|
ffad75bd62 | ||
|
|
a429515bdd | ||
|
|
8d761134c2 | ||
|
|
cf04cedf21 | ||
|
|
5b31afcbd1 | ||
|
|
6e91f14d09 | ||
|
|
ed26e4012b | ||
|
|
806f380a8b | ||
|
|
a19b739909 | ||
|
|
a5c72780e6 | ||
|
|
e19f794fee | ||
|
|
d5ff9effcf | ||
|
|
e845434028 | ||
|
|
4af32a2817 | ||
|
|
bc6cef823f | ||
|
|
1ec6fa98c9 | ||
|
|
25d2914fba | ||
|
|
cce5d057d3 | ||
|
|
6606f7c659 | ||
|
|
a971fa9d58 | ||
|
|
ded4128965 | ||
|
|
f9e12f79ca | ||
|
|
c756dfeb14 | ||
|
|
63677d1f47 | ||
|
|
32e14d8181 | ||
|
|
4847a9534d | ||
|
|
88cb06e996 | ||
|
|
d488463fa3 | ||
|
|
127fad17dd | ||
|
|
5a95cd4442 | ||
|
|
58d8339966 | ||
|
|
be7ead6946 | ||
|
|
3cbc286d06 | ||
|
|
3c741682e5 | ||
|
|
86fc9b617c | ||
|
|
90bcb86957 | ||
|
|
1bede47843 | ||
|
|
93937b2b31 | ||
|
|
c5365dee56 | ||
|
|
4103b1c470 | ||
|
|
4d5b098626 | ||
|
|
7ed2ec3061 | ||
|
|
ce797ad373 | ||
|
|
7e863c51e6 | ||
|
|
0f12772e32 | ||
|
|
d5d4281647 | ||
|
|
cda4a6f93f | ||
|
|
e2722f58ee | ||
|
|
a1665c5094 | ||
|
|
2ded344620 | ||
|
|
9707acfc40 | ||
|
|
8bf285e082 | ||
|
|
cf59d68b17 | ||
|
|
1280a47fc6 | ||
|
|
23d285ad57 | ||
|
|
8ad0f4912e | ||
|
|
6f9dea7483 | ||
|
|
22d7a59306 | ||
|
|
279a547a8b | ||
|
|
83f5125d52 | ||
|
|
a43b40449b | ||
|
|
9cef051ce2 | ||
|
|
0575840866 | ||
|
|
45131b2bca | ||
|
|
ccda401dbf | ||
|
|
5b89052d2f | ||
|
|
3887350e47 | ||
|
|
19234cc6c3 | ||
|
|
e8f1521605 | ||
|
|
605941ee26 | ||
|
|
5bc41fe9f8 | ||
|
|
638be5a6b9 | ||
|
|
830d07db82 | ||
|
|
65f5e4e3e4 | ||
|
|
07d4041709 | ||
|
|
c1b34af441 | ||
|
|
9a05795619 | ||
|
|
24d8134ac1 | ||
|
|
7f911ef4e3 | ||
|
|
d5e7e6b9b6 | ||
|
|
0775c62469 | ||
|
|
38928c6609 | ||
|
|
a2a93a4fa7 | ||
|
|
4fe95094d1 | ||
|
|
ae8ff92e05 | ||
|
|
49d6aa1394 | ||
|
|
0bfa78b39b | ||
|
|
6bc9edd8b2 | ||
|
|
05a35d62b6 | ||
|
|
8574bf62dc | ||
|
|
0af5f5efaf | ||
|
|
c8d3f6486d | ||
|
|
304111afd0 | ||
|
|
d0e444a648 | ||
|
|
65fd446b4d | ||
|
|
364c7f92b4 | ||
|
|
4eb6d66b45 | ||
|
|
6b59650753 | ||
|
|
41cd778d66 | ||
|
|
70a84f17f3 | ||
|
|
779f7b0f44 | ||
|
|
ef1e019840 | ||
|
|
5583e29513 | ||
|
|
c5bf0343e8 | ||
|
|
e24c32e6f3 | ||
|
|
e9c908ebc0 | ||
|
|
9236136f3a | ||
|
|
813e54bd5b | ||
|
|
80a620a83a | ||
|
|
9fa8bda099 | ||
|
|
f129ee1e18 | ||
|
|
09cbff174a | ||
|
|
d18e7779ca | ||
|
|
5e88a09a42 | ||
|
|
cf1fa59f4b | ||
|
|
3470cb36a8 | ||
|
|
c217504949 | ||
|
|
b59aa74556 | ||
|
|
d33ae65efc | ||
|
|
9f642a93ec | ||
|
|
e7887e37a8 | ||
|
|
af853a4cdb | ||
|
|
4891c4ff72 | ||
|
|
46183cc69f | ||
|
|
59bf16eddc | ||
|
|
9a506a191a | ||
|
|
8675ea03de | ||
|
|
8366fde82f | ||
|
|
3e420aebd8 | ||
|
|
ff1fa0fbf8 | ||
|
|
abcd03af02 | ||
|
|
5116946ae9 | ||
|
|
6f4f7e4e22 | ||
|
|
a32e876ef4 | ||
|
|
a198894bf7 | ||
|
|
5b999e206e | ||
|
|
32206dde3f | ||
|
|
4edcbcee3b | ||
|
|
953e40f9dc | ||
|
|
df4c12c762 | ||
|
|
c1a256cc4c | ||
|
|
f173d40a32 | ||
|
|
1b988b051b | ||
|
|
033a517feb | ||
|
|
9ba6487b3f | ||
|
|
d6b3ea75d4 | ||
|
|
7ab9f91a60 | ||
|
|
0e8f5095c7 | ||
|
|
ce2766d19c | ||
|
|
438a21c87b | ||
|
|
9aa0224cdf | ||
|
|
c7023f2155 | ||
|
|
0ba393924a | ||
|
|
f488293d96 | ||
|
|
1aa44939fc | ||
|
|
5a447098dd | ||
|
|
9e98f1022a | ||
|
|
9115421ace | ||
|
|
d19e79ecc9 | ||
|
|
ed008e85a8 | ||
|
|
6e7131f02f | ||
|
|
9a7f496298 | ||
|
|
78adccfaf4 | ||
|
|
d98660a60d | ||
|
|
d5272b1d2c | ||
|
|
278149f533 | ||
|
|
72d8406409 | ||
|
|
a63b4f7101 | ||
|
|
0f86312c4c | ||
|
|
b1022ed8b5 | ||
|
|
f6583796fe | ||
|
|
4848fdbf59 | ||
|
|
80cd08c190 | ||
|
|
9517f4da4d | ||
|
|
dc0c989ef4 | ||
|
|
ceb61daa70 | ||
|
|
fce0114005 | ||
|
|
7e282a53a5 | ||
|
|
91cb46191d | ||
|
|
87db64b839 | ||
|
|
cb8162d3d1 | ||
|
|
d288d273e1 | ||
|
|
d4f50f3ae5 | ||
|
|
455579ca90 | ||
|
|
0d0610870d | ||
|
|
532ebc4a82 | ||
|
|
56f2d31676 | ||
|
|
c178e4e6ca | ||
|
|
d7a0496f3e | ||
|
|
58ed393235 | ||
|
|
fae059cc18 | ||
|
|
b26d85c30f | ||
|
|
0dcb145c7e | ||
|
|
64cf1483e5 | ||
|
|
89036579ed | ||
|
|
f0d6f724a2 | ||
|
|
d04fa1f712 | ||
|
|
6279285b2a | ||
|
|
c9c444f562 | ||
|
|
835e01fc70 | ||
|
|
f9232c7917 | ||
|
|
db1fbc6c6f | ||
|
|
7e0bba555c | ||
|
|
04c9751f24 | ||
|
|
b98cd03193 | ||
|
|
21fd84dcb8 | ||
|
|
0a60bbf4fa | ||
|
|
1ead1caa18 | ||
|
|
1c2dcf762a | ||
|
|
406d32f8b5 | ||
|
|
34ce2ca62f | ||
|
|
4a6afa6abf | ||
|
|
01c099d9ef | ||
|
|
64345b7559 | ||
|
|
10766f1e93 | ||
|
|
2602ddc379 | ||
|
|
0354659f9d | ||
|
|
7d3491c741 | ||
|
|
f260cb72cd | ||
|
|
141ec04d19 | ||
|
|
0fbfe1b08a | ||
|
|
192e228a98 | ||
|
|
d516d68b29 | ||
|
|
0c838f9f5e | ||
|
|
773cb3b688 | ||
|
|
e319071191 | ||
|
|
be09c23ff0 |
37
AUTHORS.rst
Normal file
37
AUTHORS.rst
Normal file
@@ -0,0 +1,37 @@
|
||||
Erik Johnston <erik at matrix.org>
|
||||
* HS core
|
||||
* Federation API impl
|
||||
|
||||
Mark Haines <mark at matrix.org>
|
||||
* HS core
|
||||
* Crypto
|
||||
* Content repository
|
||||
* CS v2 API impl
|
||||
|
||||
Kegan Dougal <kegan at matrix.org>
|
||||
* HS core
|
||||
* CS v1 API impl
|
||||
* AS API impl
|
||||
|
||||
Paul "LeoNerd" Evans <paul at matrix.org>
|
||||
* HS core
|
||||
* Presence
|
||||
* Typing Notifications
|
||||
* Performance metrics and caching layer
|
||||
|
||||
Dave Baker <dave at matrix.org>
|
||||
* Push notifications
|
||||
* Auth CS v2 impl
|
||||
|
||||
Matthew Hodgson <matthew at matrix.org>
|
||||
* General doc & housekeeping
|
||||
* Vertobot/vertobridge matrix<->verto PoC
|
||||
|
||||
Emmanuel Rohee <manu at matrix.org>
|
||||
* Supporting iOS clients (testability and fallback registration)
|
||||
|
||||
Turned to Dust <dwinslow86 at gmail.com>
|
||||
* ArchLinux installation instructions
|
||||
|
||||
Brabo <brabo at riseup.net>
|
||||
* Installation instruction fixes
|
||||
31
CAPTCHA_SETUP
Normal file
31
CAPTCHA_SETUP
Normal file
@@ -0,0 +1,31 @@
|
||||
Captcha can be enabled for this home server. This file explains how to do that.
|
||||
The captcha mechanism used is Google's ReCaptcha. This requires API keys from Google.
|
||||
|
||||
Getting keys
|
||||
------------
|
||||
Requires a public/private key pair from:
|
||||
|
||||
https://developers.google.com/recaptcha/
|
||||
|
||||
|
||||
Setting ReCaptcha Keys
|
||||
----------------------
|
||||
The keys are a config option on the home server config. If they are not
|
||||
visible, you can generate them via --generate-config. Set the following value:
|
||||
|
||||
recaptcha_public_key: YOUR_PUBLIC_KEY
|
||||
recaptcha_private_key: YOUR_PRIVATE_KEY
|
||||
|
||||
In addition, you MUST enable captchas via:
|
||||
|
||||
enable_registration_captcha: true
|
||||
|
||||
Configuring IP used for auth
|
||||
----------------------------
|
||||
The ReCaptcha API requires that the IP address of the user who solved the
|
||||
captcha is sent. If the client is connecting through a proxy or load balancer,
|
||||
it may be required to use the X-Forwarded-For (XFF) header instead of the origin
|
||||
IP address. This can be configured as an option on the home server like so:
|
||||
|
||||
captcha_ip_origin_is_x_forwarded: true
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
Changes in synapse vX
|
||||
=====================
|
||||
|
||||
* Changed config option from ``disable_registration`` to
|
||||
``enable_registration``. Old option will be ignored.
|
||||
|
||||
|
||||
Changes in synapse v0.8.1 (2015-03-18)
|
||||
======================================
|
||||
|
||||
|
||||
118
CONTRIBUTING.rst
Normal file
118
CONTRIBUTING.rst
Normal file
@@ -0,0 +1,118 @@
|
||||
Contributing code to Matrix
|
||||
===========================
|
||||
|
||||
Everyone is welcome to contribute code to Matrix
|
||||
(https://github.com/matrix-org), provided that they are willing to license
|
||||
their contributions under the same license as the project itself. We follow a
|
||||
simple 'inbound=outbound' model for contributions: the act of submitting an
|
||||
'inbound' contribution means that the contributor agrees to license the code
|
||||
under the same terms as the project's overall 'outbound' license - in our
|
||||
case, this is almost always Apache Software License v2 (see LICENSE).
|
||||
|
||||
How to contribute
|
||||
~~~~~~~~~~~~~~~~~
|
||||
|
||||
The preferred and easiest way to contribute changes to Matrix is to fork the
|
||||
relevant project on github, and then create a pull request to ask us to pull
|
||||
your changes into our repo
|
||||
(https://help.github.com/articles/using-pull-requests/)
|
||||
|
||||
**The single biggest thing you need to know is: please base your changes on
|
||||
the develop branch - /not/ master.**
|
||||
|
||||
We use the master branch to track the most recent release, so that folks who
|
||||
blindly clone the repo and automatically check out master get something that
|
||||
works. Develop is the unstable branch where all the development actually
|
||||
happens: the workflow is that contributors should fork the develop branch to
|
||||
make a 'feature' branch for a particular contribution, and then make a pull
|
||||
request to merge this back into the matrix.org 'official' develop branch. We
|
||||
use github's pull request workflow to review the contribution, and either ask
|
||||
you to make any refinements needed or merge it and make them ourselves. The
|
||||
changes will then land on master when we next do a release.
|
||||
|
||||
We use Jenkins for continuous integration (http://matrix.org/jenkins), and
|
||||
typically all pull requests get automatically tested Jenkins: if your change breaks the build, Jenkins will yell about it in #matrix-dev:matrix.org so please lurk there and keep an eye open.
|
||||
|
||||
Code style
|
||||
~~~~~~~~~~
|
||||
|
||||
All Matrix projects have a well-defined code-style - and sometimes we've even
|
||||
got as far as documenting it... For instance, synapse's code style doc lives
|
||||
at https://github.com/matrix-org/synapse/tree/master/docs/code_style.rst.
|
||||
|
||||
Please ensure your changes match the cosmetic style of the existing project,
|
||||
and **never** mix cosmetic and functional changes in the same commit, as it
|
||||
makes it horribly hard to review otherwise.
|
||||
|
||||
Attribution
|
||||
~~~~~~~~~~~
|
||||
|
||||
Everyone who contributes anything to Matrix is welcome to be listed in the
|
||||
AUTHORS.rst file for the project in question. Please feel free to include a
|
||||
change to AUTHORS.rst in your pull request to list yourself and a short
|
||||
description of the area(s) you've worked on. Also, we sometimes have swag to
|
||||
give away to contributors - if you feel that Matrix-branded apparel is missing
|
||||
from your life, please mail us your shipping address to matrix at matrix.org and we'll try to fix it :)
|
||||
|
||||
Sign off
|
||||
~~~~~~~~
|
||||
|
||||
In order to have a concrete record that your contribution is intentional
|
||||
and you agree to license it under the same terms as the project's license, we've adopted the
|
||||
same lightweight approach that the Linux Kernel
|
||||
(https://www.kernel.org/doc/Documentation/SubmittingPatches), Docker
|
||||
(https://github.com/docker/docker/blob/master/CONTRIBUTING.md), and many other
|
||||
projects use: the DCO (Developer Certificate of Origin:
|
||||
http://developercertificate.org/). This is a simple declaration that you wrote
|
||||
the contribution or otherwise have the right to contribute it to Matrix::
|
||||
|
||||
Developer Certificate of Origin
|
||||
Version 1.1
|
||||
|
||||
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
||||
660 York Street, Suite 102,
|
||||
San Francisco, CA 94110 USA
|
||||
|
||||
Everyone is permitted to copy and distribute verbatim copies of this
|
||||
license document, but changing it is not allowed.
|
||||
|
||||
Developer's Certificate of Origin 1.1
|
||||
|
||||
By making a contribution to this project, I certify that:
|
||||
|
||||
(a) The contribution was created in whole or in part by me and I
|
||||
have the right to submit it under the open source license
|
||||
indicated in the file; or
|
||||
|
||||
(b) The contribution is based upon previous work that, to the best
|
||||
of my knowledge, is covered under an appropriate open source
|
||||
license and I have the right under that license to submit that
|
||||
work with modifications, whether created in whole or in part
|
||||
by me, under the same open source license (unless I am
|
||||
permitted to submit under a different license), as indicated
|
||||
in the file; or
|
||||
|
||||
(c) The contribution was provided directly to me by some other
|
||||
person who certified (a), (b) or (c) and I have not modified
|
||||
it.
|
||||
|
||||
(d) I understand and agree that this project and the contribution
|
||||
are public and that a record of the contribution (including all
|
||||
personal information I submit with it, including my sign-off) is
|
||||
maintained indefinitely and may be redistributed consistent with
|
||||
this project or the open source license(s) involved.
|
||||
|
||||
If you agree to this for your contribution, then all that's needed is to
|
||||
include the line in your commit or pull request comment::
|
||||
|
||||
Signed-off-by: Your Name <your@email.example.org>
|
||||
|
||||
...using your real name; unfortunately pseudonyms and anonymous contributions
|
||||
can't be accepted. Git makes this trivial - just use the -s flag when you do
|
||||
``git commit``, having first set ``user.name`` and ``user.email`` git configs
|
||||
(which you should have done anyway :)
|
||||
|
||||
Conclusion
|
||||
~~~~~~~~~~
|
||||
|
||||
That's it! Matrix is a very open and collaborative project as you might expect given our obsession with open communication. If we're going to successfully matrix together all the fragmented communication technologies out there we are reliant on contributions and collaboration from the community to do so. So please get involved - and we hope you have as much fun hacking on Matrix as we do!
|
||||
159
README.rst
159
README.rst
@@ -20,7 +20,7 @@ The overall architecture is::
|
||||
https://somewhere.org/_matrix https://elsewhere.net/_matrix
|
||||
|
||||
``#matrix:matrix.org`` is the official support room for Matrix, and can be
|
||||
accessed by the web client at http://matrix.org/alpha or via an IRC bridge at
|
||||
accessed by the web client at http://matrix.org/beta or via an IRC bridge at
|
||||
irc://irc.freenode.net/matrix.
|
||||
|
||||
Synapse is currently in rapid development, but as of version 0.5 we believe it
|
||||
@@ -69,24 +69,30 @@ Synapse ships with two basic demo Matrix clients: webclient (a basic group chat
|
||||
web client demo implemented in AngularJS) and cmdclient (a basic Python
|
||||
command line utility which lets you easily see what the JSON APIs are up to).
|
||||
|
||||
Meanwhile, iOS and Android SDKs and clients are currently in development and available from:
|
||||
Meanwhile, iOS and Android SDKs and clients are available from:
|
||||
|
||||
- https://github.com/matrix-org/matrix-ios-sdk
|
||||
- https://github.com/matrix-org/matrix-ios-kit
|
||||
- https://github.com/matrix-org/matrix-ios-console
|
||||
- https://github.com/matrix-org/matrix-android-sdk
|
||||
|
||||
We'd like to invite you to join #matrix:matrix.org (via http://matrix.org/alpha), run a homeserver, take a look at the Matrix spec at
|
||||
http://matrix.org/docs/spec, experiment with the APIs and the demo
|
||||
clients, and report any bugs via http://matrix.org/jira.
|
||||
We'd like to invite you to join #matrix:matrix.org (via
|
||||
https://matrix.org/beta), run a homeserver, take a look at the Matrix spec at
|
||||
https://matrix.org/docs/spec and API docs at https://matrix.org/docs/api,
|
||||
experiment with the APIs and the demo clients, and report any bugs via
|
||||
https://matrix.org/jira.
|
||||
|
||||
Thanks for using Matrix!
|
||||
|
||||
[1] End-to-end encryption is currently in development
|
||||
|
||||
Homeserver Installation
|
||||
=======================
|
||||
Synapse Installation
|
||||
====================
|
||||
|
||||
Synapse is the reference python/twisted Matrix homeserver implementation.
|
||||
|
||||
System requirements:
|
||||
- POSIX-compliant system (tested on Linux & OSX)
|
||||
- POSIX-compliant system (tested on Linux & OS X)
|
||||
- Python 2.7
|
||||
|
||||
Synapse is written in python but some of the libraries is uses are written in
|
||||
@@ -118,6 +124,9 @@ To install the synapse homeserver run::
|
||||
This installs synapse, along with the libraries it uses, into a virtual
|
||||
environment under ``~/.synapse``.
|
||||
|
||||
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
|
||||
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
|
||||
|
||||
To set up your homeserver, run (in your virtualenv, as before)::
|
||||
|
||||
$ cd ~/.synapse
|
||||
@@ -128,8 +137,18 @@ To set up your homeserver, run (in your virtualenv, as before)::
|
||||
|
||||
Substituting your host and domain name as appropriate.
|
||||
|
||||
This will generate you a config file that you can then customise, but it will
|
||||
also generate a set of keys for you. These keys will allow your Home Server to
|
||||
identify itself to other Home Servers, so don't lose or delete them. It would be
|
||||
wise to back them up somewhere safe. If, for whatever reason, you do need to
|
||||
change your Home Server's keys, you may find that other Home Servers have the
|
||||
old key cached. If you update the signing key, you should change the name of the
|
||||
key in the <server name>.signing.key file (the second word, which by default is
|
||||
, 'auto') to something different.
|
||||
|
||||
By default, registration of new users is disabled. You can either enable
|
||||
registration in the config (it is then recommended to also set up CAPTCHA), or
|
||||
registration in the config by specifying ``enable_registration: true``
|
||||
(it is then recommended to also set up CAPTCHA), or
|
||||
you can use the command line to register new users::
|
||||
|
||||
$ source ~/.synapse/bin/activate
|
||||
@@ -142,36 +161,51 @@ you can use the command line to register new users::
|
||||
For reliable VoIP calls to be routed via this homeserver, you MUST configure
|
||||
a TURN server. See docs/turn-howto.rst for details.
|
||||
|
||||
Troubleshooting Installation
|
||||
----------------------------
|
||||
Using PostgreSQL
|
||||
================
|
||||
|
||||
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
|
||||
you get errors about ``error: no such option: --process-dependency-links`` you
|
||||
may need to manually upgrade it::
|
||||
As of Synapse 0.9, `PostgreSQL <http://www.postgresql.org>`_ is supported as an
|
||||
alternative to the `SQLite <http://sqlite.org/>`_ database that Synapse has
|
||||
traditionally used for convenience and simplicity.
|
||||
|
||||
$ sudo pip install --upgrade pip
|
||||
The advantages of Postgres include:
|
||||
|
||||
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
||||
refuse to run until you remove the temporary installation directory it
|
||||
created. To reset the installation::
|
||||
* significant performance improvements due to the superior threading and
|
||||
caching model, smarter query optimiser
|
||||
* allowing the DB to be run on separate hardware
|
||||
* allowing basic active/backup high-availability with a "hot spare" synapse
|
||||
pointing at the same DB master, as well as enabling DB replication in
|
||||
synapse itself.
|
||||
|
||||
The only disadvantage is that the code is relatively new as of April 2015 and
|
||||
may have a few regressions relative to SQLite.
|
||||
|
||||
$ rm -rf /tmp/pip_install_matrix
|
||||
For information on how to install and use PostgreSQL, please see
|
||||
`docs/postgres.rst <docs/postgres.rst>`_.
|
||||
|
||||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.::
|
||||
Running Synapse
|
||||
===============
|
||||
|
||||
$ pip install twisted
|
||||
To actually run your new homeserver, pick a working directory for Synapse to run
|
||||
(e.g. ``~/.synapse``), and::
|
||||
|
||||
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||
will need to export CFLAGS=-Qunused-arguments.
|
||||
$ cd ~/.synapse
|
||||
$ source ./bin/activate
|
||||
$ synctl start
|
||||
|
||||
Platform Specific Instructions
|
||||
==============================
|
||||
|
||||
ArchLinux
|
||||
---------
|
||||
|
||||
Installation on ArchLinux may encounter a few hiccups as Arch defaults to
|
||||
python 3, but synapse currently assumes python 2.7 by default.
|
||||
The quickest way to get up and running with ArchLinux is probably with Ivan
|
||||
Shapovalov's AUR package from
|
||||
https://aur.archlinux.org/packages/matrix-synapse/, which should pull in all
|
||||
the necessary dependencies.
|
||||
|
||||
Alternatively, to install using pip a few changes may be needed as ArchLinux
|
||||
defaults to python 3, but synapse currently assumes python 2.7 by default:
|
||||
|
||||
pip may be outdated (6.0.7-1 and needs to be upgraded to 6.0.8-1 )::
|
||||
|
||||
@@ -191,7 +225,7 @@ installing under virtualenv)::
|
||||
$ sudo pip2.7 uninstall py-bcrypt
|
||||
$ sudo pip2.7 install py-bcrypt
|
||||
|
||||
During setup of homeserver you need to call python2.7 directly again::
|
||||
During setup of Synapse you need to call python2.7 directly again::
|
||||
|
||||
$ cd ~/.synapse
|
||||
$ python2.7 -m synapse.app.homeserver \
|
||||
@@ -232,15 +266,33 @@ Troubleshooting:
|
||||
you do, you may need to create a symlink to ``libsodium.a`` so ``ld`` can find
|
||||
it: ``ln -s /usr/local/lib/libsodium.a /usr/lib/libsodium.a``
|
||||
|
||||
Running Your Homeserver
|
||||
=======================
|
||||
Troubleshooting
|
||||
===============
|
||||
|
||||
To actually run your new homeserver, pick a working directory for Synapse to run
|
||||
(e.g. ``~/.synapse``), and::
|
||||
Troubleshooting Installation
|
||||
----------------------------
|
||||
|
||||
$ cd ~/.synapse
|
||||
$ source ./bin/activate
|
||||
$ synctl start
|
||||
Synapse requires pip 1.7 or later, so if your OS provides too old a version and
|
||||
you get errors about ``error: no such option: --process-dependency-links`` you
|
||||
may need to manually upgrade it::
|
||||
|
||||
$ sudo pip install --upgrade pip
|
||||
|
||||
If pip crashes mid-installation for reason (e.g. lost terminal), pip may
|
||||
refuse to run until you remove the temporary installation directory it
|
||||
created. To reset the installation::
|
||||
|
||||
$ rm -rf /tmp/pip_install_matrix
|
||||
|
||||
pip seems to leak *lots* of memory during installation. For instance, a Linux
|
||||
host with 512MB of RAM may run out of memory whilst installing Twisted. If this
|
||||
happens, you will have to individually install the dependencies which are
|
||||
failing, e.g.::
|
||||
|
||||
$ pip install twisted
|
||||
|
||||
On OSX, if you encounter clang: error: unknown argument: '-mno-fused-madd' you
|
||||
will need to export CFLAGS=-Qunused-arguments.
|
||||
|
||||
Troubleshooting Running
|
||||
-----------------------
|
||||
@@ -261,7 +313,7 @@ fix try re-installing from PyPI or directly from
|
||||
$ pip install --user https://github.com/pyca/pynacl/tarball/master
|
||||
|
||||
ArchLinux
|
||||
---------
|
||||
~~~~~~~~~
|
||||
|
||||
If running `$ synctl start` fails with 'returned non-zero exit status 1',
|
||||
you will need to explicitly call Python2.7 - either running as::
|
||||
@@ -270,16 +322,16 @@ you will need to explicitly call Python2.7 - either running as::
|
||||
|
||||
...or by editing synctl with the correct python executable.
|
||||
|
||||
Homeserver Development
|
||||
======================
|
||||
Synapse Development
|
||||
===================
|
||||
|
||||
To check out a homeserver for development, clone the git repo into a working
|
||||
To check out a synapse for development, clone the git repo into a working
|
||||
directory of your choice::
|
||||
|
||||
$ git clone https://github.com/matrix-org/synapse.git
|
||||
$ cd synapse
|
||||
|
||||
The homeserver has a number of external dependencies, that are easiest
|
||||
Synapse has a number of external dependencies, that are easiest
|
||||
to install using pip and a virtualenv::
|
||||
|
||||
$ virtualenv env
|
||||
@@ -290,7 +342,7 @@ to install using pip and a virtualenv::
|
||||
This will run a process of downloading and installing all the needed
|
||||
dependencies into a virtual env.
|
||||
|
||||
Once this is done, you may wish to run the homeserver's unit tests, to
|
||||
Once this is done, you may wish to run Synapse's unit tests, to
|
||||
check that everything is installed as it should be::
|
||||
|
||||
$ python setup.py test
|
||||
@@ -302,10 +354,10 @@ This should end with a 'PASSED' result::
|
||||
PASSED (successes=143)
|
||||
|
||||
|
||||
Upgrading an existing homeserver
|
||||
================================
|
||||
Upgrading an existing Synapse
|
||||
=============================
|
||||
|
||||
IMPORTANT: Before upgrading an existing homeserver to a new version, please
|
||||
IMPORTANT: Before upgrading an existing synapse to a new version, please
|
||||
refer to UPGRADE.rst for any additional instructions.
|
||||
|
||||
Otherwise, simply re-install the new codebase over the current one - e.g.
|
||||
@@ -348,7 +400,7 @@ and port where the server is running. (At the current time synapse does not
|
||||
support clustering multiple servers into a single logical homeserver). The DNS
|
||||
record would then look something like::
|
||||
|
||||
$ dig -t srv _matrix._tcp.machine.my.domaine.name
|
||||
$ dig -t srv _matrix._tcp.machine.my.domain.name
|
||||
_matrix._tcp IN SRV 10 0 8448 machine.my.domain.name.
|
||||
|
||||
|
||||
@@ -366,12 +418,8 @@ SRV record, as that is the name other machines will expect it to have::
|
||||
You may additionally want to pass one or more "-v" options, in order to
|
||||
increase the verbosity of logging output; at least for initial testing.
|
||||
|
||||
For the initial alpha release, the homeserver is not speaking TLS for
|
||||
either client-server or server-server traffic for ease of debugging. We have
|
||||
also not spent any time yet getting the homeserver to run behind loadbalancers.
|
||||
|
||||
Running a Demo Federation of Homeservers
|
||||
----------------------------------------
|
||||
Running a Demo Federation of Synapses
|
||||
-------------------------------------
|
||||
|
||||
If you want to get up and running quickly with a trio of homeservers in a
|
||||
private federation (``localhost:8080``, ``localhost:8081`` and
|
||||
@@ -406,7 +454,10 @@ account. Your name will take the form of::
|
||||
Specify your desired localpart in the topmost box of the "Register for an
|
||||
account" form, and click the "Register" button. Hostnames can contain ports if
|
||||
required due to lack of SRV records (e.g. @matthew:localhost:8448 on an
|
||||
internal synapse sandbox running on localhost)
|
||||
internal synapse sandbox running on localhost).
|
||||
|
||||
If registration fails, you may need to enable it in the homeserver (see
|
||||
`Synapse Installation`_ above)
|
||||
|
||||
|
||||
Logging In To An Existing Account
|
||||
@@ -432,7 +483,7 @@ track 3PID logins and publish end-user public keys.
|
||||
|
||||
It's currently early days for identity servers as Matrix is not yet using 3PIDs
|
||||
as the primary means of identity and E2E encryption is not complete. As such,
|
||||
we are running a single identity server (http://matrix.org:8090) at the current
|
||||
we are running a single identity server (https://matrix.org) at the current
|
||||
time.
|
||||
|
||||
|
||||
|
||||
34
UPGRADE.rst
34
UPGRADE.rst
@@ -1,3 +1,37 @@
|
||||
Upgrading to v0.x.x
|
||||
===================
|
||||
|
||||
Application services have had a breaking API change in this version.
|
||||
|
||||
They can no longer register themselves with a home server using the AS HTTP API. This
|
||||
decision was made because a compromised application service with free reign to register
|
||||
any regex in effect grants full read/write access to the home server if a regex of ``.*``
|
||||
is used. An attack where a compromised AS re-registers itself with ``.*`` was deemed too
|
||||
big of a security risk to ignore, and so the ability to register with the HS remotely has
|
||||
been removed.
|
||||
|
||||
It has been replaced by specifying a list of application service registrations in
|
||||
``homeserver.yaml``::
|
||||
|
||||
app_service_config_files: ["registration-01.yaml", "registration-02.yaml"]
|
||||
|
||||
Where ``registration-01.yaml`` looks like::
|
||||
|
||||
url: <String> # e.g. "https://my.application.service.com"
|
||||
as_token: <String>
|
||||
hs_token: <String>
|
||||
sender_localpart: <String> # This is a new field which denotes the user_id localpart when using the AS token
|
||||
namespaces:
|
||||
users:
|
||||
- exclusive: <Boolean>
|
||||
regex: <String> # e.g. "@prefix_.*"
|
||||
aliases:
|
||||
- exclusive: <Boolean>
|
||||
regex: <String>
|
||||
rooms:
|
||||
- exclusive: <Boolean>
|
||||
regex: <String>
|
||||
|
||||
Upgrading to v0.8.0
|
||||
===================
|
||||
|
||||
|
||||
93
contrib/scripts/kick_users.py
Executable file
93
contrib/scripts/kick_users.py
Executable file
@@ -0,0 +1,93 @@
|
||||
#!/usr/bin/env python
|
||||
from argparse import ArgumentParser
|
||||
import json
|
||||
import requests
|
||||
import sys
|
||||
import urllib
|
||||
|
||||
def _mkurl(template, kws):
|
||||
for key in kws:
|
||||
template = template.replace(key, kws[key])
|
||||
return template
|
||||
|
||||
def main(hs, room_id, access_token, user_id_prefix, why):
|
||||
if not why:
|
||||
why = "Automated kick."
|
||||
print "Kicking members on %s in room %s matching %s" % (hs, room_id, user_id_prefix)
|
||||
room_state_url = _mkurl(
|
||||
"$HS/_matrix/client/api/v1/rooms/$ROOM/state?access_token=$TOKEN",
|
||||
{
|
||||
"$HS": hs,
|
||||
"$ROOM": room_id,
|
||||
"$TOKEN": access_token
|
||||
}
|
||||
)
|
||||
print "Getting room state => %s" % room_state_url
|
||||
res = requests.get(room_state_url)
|
||||
print "HTTP %s" % res.status_code
|
||||
state_events = res.json()
|
||||
if "error" in state_events:
|
||||
print "FATAL"
|
||||
print state_events
|
||||
return
|
||||
|
||||
kick_list = []
|
||||
room_name = room_id
|
||||
for event in state_events:
|
||||
if not event["type"] == "m.room.member":
|
||||
if event["type"] == "m.room.name":
|
||||
room_name = event["content"].get("name")
|
||||
continue
|
||||
if not event["content"].get("membership") == "join":
|
||||
continue
|
||||
if event["state_key"].startswith(user_id_prefix):
|
||||
kick_list.append(event["state_key"])
|
||||
|
||||
if len(kick_list) == 0:
|
||||
print "No user IDs match the prefix '%s'" % user_id_prefix
|
||||
return
|
||||
|
||||
print "The following user IDs will be kicked from %s" % room_name
|
||||
for uid in kick_list:
|
||||
print uid
|
||||
doit = raw_input("Continue? [Y]es\n")
|
||||
if len(doit) > 0 and doit.lower() == 'y':
|
||||
print "Kicking members..."
|
||||
# encode them all
|
||||
kick_list = [urllib.quote(uid) for uid in kick_list]
|
||||
for uid in kick_list:
|
||||
kick_url = _mkurl(
|
||||
"$HS/_matrix/client/api/v1/rooms/$ROOM/state/m.room.member/$UID?access_token=$TOKEN",
|
||||
{
|
||||
"$HS": hs,
|
||||
"$UID": uid,
|
||||
"$ROOM": room_id,
|
||||
"$TOKEN": access_token
|
||||
}
|
||||
)
|
||||
kick_body = {
|
||||
"membership": "leave",
|
||||
"reason": why
|
||||
}
|
||||
print "Kicking %s" % uid
|
||||
res = requests.put(kick_url, data=json.dumps(kick_body))
|
||||
if res.status_code != 200:
|
||||
print "ERROR: HTTP %s" % res.status_code
|
||||
if res.json().get("error"):
|
||||
print "ERROR: JSON %s" % res.json()
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser("Kick members in a room matching a certain user ID prefix.")
|
||||
parser.add_argument("-u","--user-id",help="The user ID prefix e.g. '@irc_'")
|
||||
parser.add_argument("-t","--token",help="Your access_token")
|
||||
parser.add_argument("-r","--room",help="The room ID to kick members in")
|
||||
parser.add_argument("-s","--homeserver",help="The base HS url e.g. http://matrix.org")
|
||||
parser.add_argument("-w","--why",help="Reason for the kick. Optional.")
|
||||
args = parser.parse_args()
|
||||
if not args.room or not args.token or not args.user_id or not args.homeserver:
|
||||
parser.print_help()
|
||||
sys.exit(1)
|
||||
else:
|
||||
main(args.homeserver, args.room, args.token, args.user_id, args.why)
|
||||
23
contrib/systemd/log_config.yaml
Normal file
23
contrib/systemd/log_config.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
version: 1
|
||||
|
||||
# In systemd's journal, loglevel is implicitly stored, so let's omit it
|
||||
# from the message text.
|
||||
formatters:
|
||||
journal_fmt:
|
||||
format: '%(name)s: [%(request)s] %(message)s'
|
||||
|
||||
filters:
|
||||
context:
|
||||
(): synapse.util.logcontext.LoggingContextFilter
|
||||
request: ""
|
||||
|
||||
handlers:
|
||||
journal:
|
||||
class: systemd.journal.JournalHandler
|
||||
formatter: journal_fmt
|
||||
filters: [context]
|
||||
SYSLOG_IDENTIFIER: synapse
|
||||
|
||||
root:
|
||||
level: INFO
|
||||
handlers: [journal]
|
||||
16
contrib/systemd/synapse.service
Normal file
16
contrib/systemd/synapse.service
Normal file
@@ -0,0 +1,16 @@
|
||||
# This assumes that Synapse has been installed as a system package
|
||||
# (e.g. https://aur.archlinux.org/packages/matrix-synapse/ for ArchLinux)
|
||||
# rather than in a user home directory or similar under virtualenv.
|
||||
|
||||
[Unit]
|
||||
Description=Synapse Matrix homeserver
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=synapse
|
||||
Group=synapse
|
||||
WorkingDirectory=/var/lib/synapse
|
||||
ExecStart=/usr/bin/python2.7 -m synapse.app.homeserver --config-path=/etc/synapse/homeserver.yaml --log-config=/etc/synapse/log_config.yaml
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
@@ -7,6 +7,9 @@ matrix:
|
||||
matrix-bot:
|
||||
user_id: '@vertobot:matrix.org'
|
||||
password: ''
|
||||
domain: 'matrix.org"
|
||||
as_url: 'http://localhost:8009'
|
||||
as_token: 'vertobot123'
|
||||
|
||||
verto-bot:
|
||||
host: webrtc.freeswitch.org
|
||||
|
||||
@@ -33,7 +33,8 @@ for port in 8080 8081 8082; do
|
||||
--manhole $((port + 1000)) \
|
||||
--tls-dh-params-path "demo/demo.tls.dh" \
|
||||
--media-store-path "demo/media_store.$port" \
|
||||
$PARAMS $SYNAPSE_PARAMS \
|
||||
$PARAMS $SYNAPSE_PARAMS \
|
||||
--enable-registration
|
||||
|
||||
python -m synapse.app.homeserver \
|
||||
--config-path "demo/etc/$port.config" \
|
||||
|
||||
50
docs/metrics-howto.rst
Normal file
50
docs/metrics-howto.rst
Normal file
@@ -0,0 +1,50 @@
|
||||
How to monitor Synapse metrics using Prometheus
|
||||
===============================================
|
||||
|
||||
1: Install prometheus:
|
||||
Follow instructions at http://prometheus.io/docs/introduction/install/
|
||||
|
||||
2: Enable synapse metrics:
|
||||
Simply setting a (local) port number will enable it. Pick a port.
|
||||
prometheus itself defaults to 9090, so starting just above that for
|
||||
locally monitored services seems reasonable. E.g. 9092:
|
||||
|
||||
Add to homeserver.yaml
|
||||
|
||||
metrics_port: 9092
|
||||
|
||||
Restart synapse
|
||||
|
||||
3: Check out synapse-prometheus-config
|
||||
https://github.com/matrix-org/synapse-prometheus-config
|
||||
|
||||
4: Add ``synapse.html`` and ``synapse.rules``
|
||||
The ``.html`` file needs to appear in prometheus's ``consoles`` directory,
|
||||
and the ``.rules`` file needs to be invoked somewhere in the main config
|
||||
file. A symlink to each from the git checkout into the prometheus directory
|
||||
might be easiest to ensure ``git pull`` keeps it updated.
|
||||
|
||||
5: Add a prometheus target for synapse
|
||||
This is easiest if prometheus runs on the same machine as synapse, as it can
|
||||
then just use localhost::
|
||||
|
||||
global: {
|
||||
rule_file: "synapse.rules"
|
||||
}
|
||||
|
||||
job: {
|
||||
name: "synapse"
|
||||
|
||||
target_group: {
|
||||
target: "http://localhost:9092/"
|
||||
}
|
||||
}
|
||||
|
||||
6: Start prometheus::
|
||||
|
||||
./prometheus -config.file=prometheus.conf
|
||||
|
||||
7: Wait a few seconds for it to start and perform the first scrape,
|
||||
then visit the console:
|
||||
|
||||
http://server-where-prometheus-runs:9090/consoles/synapse.html
|
||||
114
docs/postgres.rst
Normal file
114
docs/postgres.rst
Normal file
@@ -0,0 +1,114 @@
|
||||
Using Postgres
|
||||
--------------
|
||||
|
||||
Set up database
|
||||
===============
|
||||
|
||||
The PostgreSQL database used *must* have the correct encoding set, otherwise
|
||||
would not be able to store UTF8 strings. To create a database with the correct
|
||||
encoding use, e.g.::
|
||||
|
||||
CREATE DATABASE synapse
|
||||
ENCODING 'UTF8'
|
||||
LC_COLLATE='C'
|
||||
LC_CTYPE='C'
|
||||
template=template0
|
||||
OWNER synapse_user;
|
||||
|
||||
This would create an appropriate database named ``synapse`` owned by the
|
||||
``synapse_user`` user (which must already exist).
|
||||
|
||||
Set up client
|
||||
=============
|
||||
|
||||
Postgres support depends on the postgres python connector ``psycopg2``. In the
|
||||
virtual env::
|
||||
|
||||
sudo apt-get install libpq-dev
|
||||
pip install psycopg2
|
||||
|
||||
|
||||
Synapse config
|
||||
==============
|
||||
|
||||
When you are ready to start using PostgreSQL, add the following line to your
|
||||
config file::
|
||||
|
||||
database_config: <db_config_file>
|
||||
|
||||
Where ``<db_config_file>`` is the file name that points to a yaml file of the
|
||||
following form::
|
||||
|
||||
name: psycopg2
|
||||
args:
|
||||
user: <user>
|
||||
password: <pass>
|
||||
database: <db>
|
||||
host: <host>
|
||||
cp_min: 5
|
||||
cp_max: 10
|
||||
|
||||
All key, values in ``args`` are passed to the ``psycopg2.connect(..)``
|
||||
function, except keys beginning with ``cp_``, which are consumed by the twisted
|
||||
adbapi connection pool.
|
||||
|
||||
|
||||
Porting from SQLite
|
||||
===================
|
||||
|
||||
Overview
|
||||
~~~~~~~~
|
||||
|
||||
The script ``port_from_sqlite_to_postgres.py`` allows porting an existing
|
||||
synapse server backed by SQLite to using PostgreSQL. This is done in as a two
|
||||
phase process:
|
||||
|
||||
1. Copy the existing SQLite database to a separate location (while the server
|
||||
is down) and running the port script against that offline database.
|
||||
2. Shut down the server. Rerun the port script to port any data that has come
|
||||
in since taking the first snapshot. Restart server against the PostgreSQL
|
||||
database.
|
||||
|
||||
The port script is designed to be run repeatedly against newer snapshots of the
|
||||
SQLite database file. This makes it safe to repeat step 1 if there was a delay
|
||||
between taking the previous snapshot and being ready to do step 2.
|
||||
|
||||
It is safe to at any time kill the port script and restart it.
|
||||
|
||||
Using the port script
|
||||
~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Firstly, shut down the currently running synapse server and copy its database
|
||||
file (typically ``homeserver.db``) to another location. Once the copy is
|
||||
complete, restart synapse. For instance::
|
||||
|
||||
./synctl stop
|
||||
cp homeserver.db homeserver.db.snapshot
|
||||
./synctl start
|
||||
|
||||
Assuming your database config file (as described in the section *Synapse
|
||||
config*) is named ``database_config.yaml`` and the SQLite snapshot is at
|
||||
``homeserver.db.snapshot`` then simply run::
|
||||
|
||||
python scripts/port_from_sqlite_to_postgres.py \
|
||||
--sqlite-database homeserver.db.snapshot \
|
||||
--postgres-config database_config.yaml
|
||||
|
||||
The flag ``--curses`` displays a coloured curses progress UI.
|
||||
|
||||
If the script took a long time to complete, or time has otherwise passed since
|
||||
the original snapshot was taken, repeat the previous steps with a newer
|
||||
snapshot.
|
||||
|
||||
To complete the conversion shut down the synapse server and run the port
|
||||
script one last time, e.g. if the SQLite database is at ``homeserver.db``
|
||||
run::
|
||||
|
||||
python scripts/port_from_sqlite_to_postgres.py \
|
||||
--sqlite-database homeserver.db \
|
||||
--postgres-config database_config.yaml
|
||||
|
||||
Once that has completed, change the synapse config to point at the PostgreSQL
|
||||
database configuration file using the ``database_config`` parameter (see
|
||||
`Synapse Config`_) and restart synapse. Synapse should now be running against
|
||||
PostgreSQL.
|
||||
@@ -33,10 +33,9 @@ def request_registration(user, password, server_location, shared_secret):
|
||||
).hexdigest()
|
||||
|
||||
data = {
|
||||
"user": user,
|
||||
"username": user,
|
||||
"password": password,
|
||||
"mac": mac,
|
||||
"type": "org.matrix.login.shared_secret",
|
||||
}
|
||||
|
||||
server_location = server_location.rstrip("/")
|
||||
@@ -44,12 +43,17 @@ def request_registration(user, password, server_location, shared_secret):
|
||||
print "Sending registration request..."
|
||||
|
||||
req = urllib2.Request(
|
||||
"%s/_matrix/client/api/v1/register" % (server_location,),
|
||||
"%s/_matrix/client/v2_alpha/register" % (server_location,),
|
||||
data=json.dumps(data),
|
||||
headers={'Content-Type': 'application/json'}
|
||||
)
|
||||
try:
|
||||
f = urllib2.urlopen(req)
|
||||
if sys.version_info[:3] >= (2, 7, 9):
|
||||
# As of version 2.7.9, urllib2 now checks SSL certs
|
||||
import ssl
|
||||
f = urllib2.urlopen(req, context=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
|
||||
else:
|
||||
f = urllib2.urlopen(req)
|
||||
f.read()
|
||||
f.close()
|
||||
print "Success."
|
||||
|
||||
758
scripts/port_from_sqlite_to_postgres.py
Normal file
758
scripts/port_from_sqlite_to_postgres.py
Normal file
@@ -0,0 +1,758 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.enterprise import adbapi
|
||||
|
||||
from synapse.storage._base import LoggingTransaction, SQLBaseStore
|
||||
from synapse.storage.engines import create_engine
|
||||
|
||||
import argparse
|
||||
import curses
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import yaml
|
||||
|
||||
|
||||
logger = logging.getLogger("port_from_sqlite_to_postgres")
|
||||
|
||||
|
||||
BOOLEAN_COLUMNS = {
|
||||
"events": ["processed", "outlier"],
|
||||
"rooms": ["is_public"],
|
||||
"event_edges": ["is_state"],
|
||||
"presence_list": ["accepted"],
|
||||
}
|
||||
|
||||
|
||||
APPEND_ONLY_TABLES = [
|
||||
"event_content_hashes",
|
||||
"event_reference_hashes",
|
||||
"event_signatures",
|
||||
"event_edge_hashes",
|
||||
"events",
|
||||
"event_json",
|
||||
"state_events",
|
||||
"room_memberships",
|
||||
"feedback",
|
||||
"topics",
|
||||
"room_names",
|
||||
"rooms",
|
||||
"local_media_repository",
|
||||
"local_media_repository_thumbnails",
|
||||
"remote_media_cache",
|
||||
"remote_media_cache_thumbnails",
|
||||
"redactions",
|
||||
"event_edges",
|
||||
"event_auth",
|
||||
"received_transactions",
|
||||
"sent_transactions",
|
||||
"transaction_id_to_pdu",
|
||||
"users",
|
||||
"state_groups",
|
||||
"state_groups_state",
|
||||
"event_to_state_groups",
|
||||
"rejections",
|
||||
]
|
||||
|
||||
|
||||
end_error_exec_info = None
|
||||
|
||||
|
||||
class Store(object):
|
||||
"""This object is used to pull out some of the convenience API from the
|
||||
Storage layer.
|
||||
|
||||
*All* database interactions should go through this object.
|
||||
"""
|
||||
def __init__(self, db_pool, engine):
|
||||
self.db_pool = db_pool
|
||||
self.database_engine = engine
|
||||
|
||||
_simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
|
||||
_simple_insert = SQLBaseStore.__dict__["_simple_insert"]
|
||||
|
||||
_simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
|
||||
_simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
|
||||
_simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
|
||||
_simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]
|
||||
|
||||
_simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
|
||||
_simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]
|
||||
|
||||
_execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]
|
||||
|
||||
def runInteraction(self, desc, func, *args, **kwargs):
|
||||
def r(conn):
|
||||
try:
|
||||
i = 0
|
||||
N = 5
|
||||
while True:
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
return func(
|
||||
LoggingTransaction(txn, desc, self.database_engine),
|
||||
*args, **kwargs
|
||||
)
|
||||
except self.database_engine.module.DatabaseError as e:
|
||||
if self.database_engine.is_deadlock(e):
|
||||
logger.warn("[TXN DEADLOCK] {%s} %d/%d", desc, i, N)
|
||||
if i < N:
|
||||
i += 1
|
||||
conn.rollback()
|
||||
continue
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", desc, e)
|
||||
raise
|
||||
|
||||
return self.db_pool.runWithConnection(r)
|
||||
|
||||
def execute(self, f, *args, **kwargs):
|
||||
return self.runInteraction(f.__name__, f, *args, **kwargs)
|
||||
|
||||
def execute_sql(self, sql, *args):
|
||||
def r(txn):
|
||||
txn.execute(sql, args)
|
||||
return txn.fetchall()
|
||||
return self.runInteraction("execute_sql", r)
|
||||
|
||||
def insert_many_txn(self, txn, table, headers, rows):
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
", ".join(k for k in headers),
|
||||
", ".join("%s" for _ in headers)
|
||||
)
|
||||
|
||||
try:
|
||||
txn.executemany(sql, rows)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed to insert: %s",
|
||||
table,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
class Porter(object):
|
||||
def __init__(self, **kwargs):
|
||||
self.__dict__.update(kwargs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def setup_table(self, table):
|
||||
if table in APPEND_ONLY_TABLES:
|
||||
# It's safe to just carry on inserting.
|
||||
next_chunk = yield self.postgres_store._simple_select_one_onecol(
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
retcol="rowid",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
total_to_port = None
|
||||
if next_chunk is None:
|
||||
if table == "sent_transactions":
|
||||
next_chunk, already_ported, total_to_port = (
|
||||
yield self._setup_sent_transactions()
|
||||
)
|
||||
else:
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 1}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
already_ported = 0
|
||||
|
||||
if total_to_port is None:
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
)
|
||||
else:
|
||||
def delete_all(txn):
|
||||
txn.execute(
|
||||
"DELETE FROM port_from_sqlite3 WHERE table_name = %s",
|
||||
(table,)
|
||||
)
|
||||
txn.execute("TRUNCATE %s CASCADE" % (table,))
|
||||
|
||||
yield self.postgres_store.execute(delete_all)
|
||||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": table, "rowid": 0}
|
||||
)
|
||||
|
||||
next_chunk = 1
|
||||
|
||||
already_ported, total_to_port = yield self._get_total_count_to_port(
|
||||
table, next_chunk
|
||||
)
|
||||
|
||||
defer.returnValue((table, already_ported, total_to_port, next_chunk))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def handle_table(self, table, postgres_size, table_size, next_chunk):
|
||||
if not table_size:
|
||||
return
|
||||
|
||||
self.progress.add_table(table, postgres_size, table_size)
|
||||
|
||||
select = (
|
||||
"SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?"
|
||||
% (table,)
|
||||
)
|
||||
|
||||
while True:
|
||||
def r(txn):
|
||||
txn.execute(select, (next_chunk, self.batch_size,))
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
return headers, rows
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction("select", r)
|
||||
|
||||
if rows:
|
||||
next_chunk = rows[-1][0] + 1
|
||||
|
||||
self._convert_rows(table, headers, rows)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
txn, table, headers[1:], rows
|
||||
)
|
||||
|
||||
self.postgres_store._simple_update_one_txn(
|
||||
txn,
|
||||
table="port_from_sqlite3",
|
||||
keyvalues={"table_name": table},
|
||||
updatevalues={"rowid": next_chunk},
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
||||
postgres_size += len(rows)
|
||||
|
||||
self.progress.update(table, postgres_size)
|
||||
else:
|
||||
return
|
||||
|
||||
def setup_db(self, db_config, database_engine):
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
k: v for k, v in db_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
)
|
||||
|
||||
database_engine.prepare_database(db_conn)
|
||||
|
||||
db_conn.commit()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def run(self):
|
||||
try:
|
||||
sqlite_db_pool = adbapi.ConnectionPool(
|
||||
self.sqlite_config["name"],
|
||||
**self.sqlite_config["args"]
|
||||
)
|
||||
|
||||
postgres_db_pool = adbapi.ConnectionPool(
|
||||
self.postgres_config["name"],
|
||||
**self.postgres_config["args"]
|
||||
)
|
||||
|
||||
sqlite_engine = create_engine("sqlite3")
|
||||
postgres_engine = create_engine("psycopg2")
|
||||
|
||||
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
|
||||
self.postgres_store = Store(postgres_db_pool, postgres_engine)
|
||||
|
||||
yield self.postgres_store.execute(
|
||||
postgres_engine.check_database
|
||||
)
|
||||
|
||||
# Step 1. Set up databases.
|
||||
self.progress.set_state("Preparing SQLite3")
|
||||
self.setup_db(sqlite_config, sqlite_engine)
|
||||
|
||||
self.progress.set_state("Preparing PostgreSQL")
|
||||
self.setup_db(postgres_config, postgres_engine)
|
||||
|
||||
# Step 2. Get tables.
|
||||
self.progress.set_state("Fetching tables")
|
||||
sqlite_tables = yield self.sqlite_store._simple_select_onecol(
|
||||
table="sqlite_master",
|
||||
keyvalues={
|
||||
"type": "table",
|
||||
},
|
||||
retcol="name",
|
||||
)
|
||||
|
||||
postgres_tables = yield self.postgres_store._simple_select_onecol(
|
||||
table="information_schema.tables",
|
||||
keyvalues={
|
||||
"table_schema": "public",
|
||||
},
|
||||
retcol="distinct table_name",
|
||||
)
|
||||
|
||||
tables = set(sqlite_tables) & set(postgres_tables)
|
||||
|
||||
self.progress.set_state("Creating tables")
|
||||
|
||||
logger.info("Found %d tables", len(tables))
|
||||
|
||||
def create_port_table(txn):
|
||||
txn.execute(
|
||||
"CREATE TABLE port_from_sqlite3 ("
|
||||
" table_name varchar(100) NOT NULL UNIQUE,"
|
||||
" rowid bigint NOT NULL"
|
||||
")"
|
||||
)
|
||||
|
||||
try:
|
||||
yield self.postgres_store.runInteraction(
|
||||
"create_port_table", create_port_table
|
||||
)
|
||||
except Exception as e:
|
||||
logger.info("Failed to create port table: %s", e)
|
||||
|
||||
self.progress.set_state("Setting up")
|
||||
|
||||
# Set up tables.
|
||||
setup_res = yield defer.gatherResults(
|
||||
[
|
||||
self.setup_table(table)
|
||||
for table in tables
|
||||
if table not in ["schema_version", "applied_schema_deltas"]
|
||||
and not table.startswith("sqlite_")
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
# Process tables.
|
||||
yield defer.gatherResults(
|
||||
[
|
||||
self.handle_table(*res)
|
||||
for res in setup_res
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
self.progress.done()
|
||||
except:
|
||||
global end_error_exec_info
|
||||
end_error_exec_info = sys.exc_info()
|
||||
logger.exception("")
|
||||
finally:
|
||||
reactor.stop()
|
||||
|
||||
def _convert_rows(self, table, headers, rows):
|
||||
bool_col_names = BOOLEAN_COLUMNS.get(table, [])
|
||||
|
||||
bool_cols = [
|
||||
i for i, h in enumerate(headers) if h in bool_col_names
|
||||
]
|
||||
|
||||
def conv(j, col):
|
||||
if j in bool_cols:
|
||||
return bool(col)
|
||||
return col
|
||||
|
||||
for i, row in enumerate(rows):
|
||||
rows[i] = tuple(
|
||||
self.postgres_store.database_engine.encode_parameter(
|
||||
conv(j, col)
|
||||
)
|
||||
for j, col in enumerate(row)
|
||||
if j > 0
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _setup_sent_transactions(self):
|
||||
# Only save things from the last day
|
||||
yesterday = int(time.time()*1000) - 86400000
|
||||
|
||||
# And save the max transaction id from each destination
|
||||
select = (
|
||||
"SELECT rowid, * FROM sent_transactions WHERE rowid IN ("
|
||||
"SELECT max(rowid) FROM sent_transactions"
|
||||
" GROUP BY destination"
|
||||
")"
|
||||
)
|
||||
|
||||
def r(txn):
|
||||
txn.execute(select)
|
||||
rows = txn.fetchall()
|
||||
headers = [column[0] for column in txn.description]
|
||||
|
||||
ts_ind = headers.index('ts')
|
||||
|
||||
return headers, [r for r in rows if r[ts_ind] < yesterday]
|
||||
|
||||
headers, rows = yield self.sqlite_store.runInteraction(
|
||||
"select", r,
|
||||
)
|
||||
|
||||
self._convert_rows("sent_transactions", headers, rows)
|
||||
|
||||
inserted_rows = len(rows)
|
||||
max_inserted_rowid = max(r[0] for r in rows)
|
||||
|
||||
def insert(txn):
|
||||
self.postgres_store.insert_many_txn(
|
||||
txn, "sent_transactions", headers[1:], rows
|
||||
)
|
||||
|
||||
yield self.postgres_store.execute(insert)
|
||||
|
||||
def get_start_id(txn):
|
||||
txn.execute(
|
||||
"SELECT rowid FROM sent_transactions WHERE ts >= ?"
|
||||
" ORDER BY rowid ASC LIMIT 1",
|
||||
(yesterday,)
|
||||
)
|
||||
|
||||
rows = txn.fetchall()
|
||||
if rows:
|
||||
return rows[0][0]
|
||||
else:
|
||||
return 1
|
||||
|
||||
next_chunk = yield self.sqlite_store.execute(get_start_id)
|
||||
next_chunk = max(max_inserted_rowid + 1, next_chunk)
|
||||
|
||||
yield self.postgres_store._simple_insert(
|
||||
table="port_from_sqlite3",
|
||||
values={"table_name": "sent_transactions", "rowid": next_chunk}
|
||||
)
|
||||
|
||||
def get_sent_table_size(txn):
|
||||
txn.execute(
|
||||
"SELECT count(*) FROM sent_transactions"
|
||||
" WHERE ts >= ?",
|
||||
(yesterday,)
|
||||
)
|
||||
size, = txn.fetchone()
|
||||
return int(size)
|
||||
|
||||
remaining_count = yield self.sqlite_store.execute(
|
||||
get_sent_table_size
|
||||
)
|
||||
|
||||
total_count = remaining_count + inserted_rows
|
||||
|
||||
defer.returnValue((next_chunk, inserted_rows, total_count))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remaining_count_to_port(self, table, next_chunk):
|
||||
rows = yield self.sqlite_store.execute_sql(
|
||||
"SELECT count(*) FROM %s WHERE rowid >= ?" % (table,),
|
||||
next_chunk,
|
||||
)
|
||||
|
||||
defer.returnValue(rows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_already_ported_count(self, table):
|
||||
rows = yield self.postgres_store.execute_sql(
|
||||
"SELECT count(*) FROM %s" % (table,),
|
||||
)
|
||||
|
||||
defer.returnValue(rows[0][0])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_total_count_to_port(self, table, next_chunk):
|
||||
remaining, done = yield defer.gatherResults(
|
||||
[
|
||||
self._get_remaining_count_to_port(table, next_chunk),
|
||||
self._get_already_ported_count(table),
|
||||
],
|
||||
consumeErrors=True,
|
||||
)
|
||||
|
||||
remaining = int(remaining) if remaining else 0
|
||||
done = int(done) if done else 0
|
||||
|
||||
defer.returnValue((done, remaining + done))
|
||||
|
||||
|
||||
##############################################
|
||||
###### The following is simply UI stuff ######
|
||||
##############################################
|
||||
|
||||
|
||||
class Progress(object):
|
||||
"""Used to report progress of the port
|
||||
"""
|
||||
def __init__(self):
|
||||
self.tables = {}
|
||||
|
||||
self.start_time = int(time.time())
|
||||
|
||||
def add_table(self, table, cur, size):
|
||||
self.tables[table] = {
|
||||
"start": cur,
|
||||
"num_done": cur,
|
||||
"total": size,
|
||||
"perc": int(cur * 100 / size),
|
||||
}
|
||||
|
||||
def update(self, table, num_done):
|
||||
data = self.tables[table]
|
||||
data["num_done"] = num_done
|
||||
data["perc"] = int(num_done * 100 / data["total"])
|
||||
|
||||
def done(self):
|
||||
pass
|
||||
|
||||
|
||||
class CursesProgress(Progress):
|
||||
"""Reports progress to a curses window
|
||||
"""
|
||||
def __init__(self, stdscr):
|
||||
self.stdscr = stdscr
|
||||
|
||||
curses.use_default_colors()
|
||||
curses.curs_set(0)
|
||||
|
||||
curses.init_pair(1, curses.COLOR_RED, -1)
|
||||
curses.init_pair(2, curses.COLOR_GREEN, -1)
|
||||
|
||||
self.last_update = 0
|
||||
|
||||
self.finished = False
|
||||
|
||||
self.total_processed = 0
|
||||
self.total_remaining = 0
|
||||
|
||||
super(CursesProgress, self).__init__()
|
||||
|
||||
def update(self, table, num_done):
|
||||
super(CursesProgress, self).update(table, num_done)
|
||||
|
||||
self.total_processed = 0
|
||||
self.total_remaining = 0
|
||||
for table, data in self.tables.items():
|
||||
self.total_processed += data["num_done"] - data["start"]
|
||||
self.total_remaining += data["total"] - data["num_done"]
|
||||
|
||||
self.render()
|
||||
|
||||
def render(self, force=False):
|
||||
now = time.time()
|
||||
|
||||
if not force and now - self.last_update < 0.2:
|
||||
# reactor.callLater(1, self.render)
|
||||
return
|
||||
|
||||
self.stdscr.clear()
|
||||
|
||||
rows, cols = self.stdscr.getmaxyx()
|
||||
|
||||
duration = int(now) - int(self.start_time)
|
||||
|
||||
minutes, seconds = divmod(duration, 60)
|
||||
duration_str = '%02dm %02ds' % (minutes, seconds,)
|
||||
|
||||
if self.finished:
|
||||
status = "Time spent: %s (Done!)" % (duration_str,)
|
||||
else:
|
||||
|
||||
if self.total_processed > 0:
|
||||
left = float(self.total_remaining) / self.total_processed
|
||||
|
||||
est_remaining = (int(now) - self.start_time) * left
|
||||
est_remaining_str = '%02dm %02ds remaining' % divmod(est_remaining, 60)
|
||||
else:
|
||||
est_remaining_str = "Unknown"
|
||||
status = (
|
||||
"Time spent: %s (est. remaining: %s)"
|
||||
% (duration_str, est_remaining_str,)
|
||||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
0, 0,
|
||||
status,
|
||||
curses.A_BOLD,
|
||||
)
|
||||
|
||||
max_len = max([len(t) for t in self.tables.keys()])
|
||||
|
||||
left_margin = 5
|
||||
middle_space = 1
|
||||
|
||||
items = self.tables.items()
|
||||
items.sort(
|
||||
key=lambda i: (i[1]["perc"], i[0]),
|
||||
)
|
||||
|
||||
for i, (table, data) in enumerate(items):
|
||||
if i + 2 >= rows:
|
||||
break
|
||||
|
||||
perc = data["perc"]
|
||||
|
||||
color = curses.color_pair(2) if perc == 100 else curses.color_pair(1)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len - len(table),
|
||||
table,
|
||||
curses.A_BOLD | color,
|
||||
)
|
||||
|
||||
size = 20
|
||||
|
||||
progress = "[%s%s]" % (
|
||||
"#" * int(perc*size/100),
|
||||
" " * (size - int(perc*size/100)),
|
||||
)
|
||||
|
||||
self.stdscr.addstr(
|
||||
i+2, left_margin + max_len + middle_space,
|
||||
"%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]),
|
||||
)
|
||||
|
||||
if self.finished:
|
||||
self.stdscr.addstr(
|
||||
rows-1, 0,
|
||||
"Press any key to exit...",
|
||||
)
|
||||
|
||||
self.stdscr.refresh()
|
||||
self.last_update = time.time()
|
||||
|
||||
def done(self):
|
||||
self.finished = True
|
||||
self.render(True)
|
||||
self.stdscr.getch()
|
||||
|
||||
def set_state(self, state):
|
||||
self.stdscr.clear()
|
||||
self.stdscr.addstr(
|
||||
0, 0,
|
||||
state + "...",
|
||||
curses.A_BOLD,
|
||||
)
|
||||
self.stdscr.refresh()
|
||||
|
||||
|
||||
class TerminalProgress(Progress):
|
||||
"""Just prints progress to the terminal
|
||||
"""
|
||||
def update(self, table, num_done):
|
||||
super(TerminalProgress, self).update(table, num_done)
|
||||
|
||||
data = self.tables[table]
|
||||
|
||||
print "%s: %d%% (%d/%d)" % (
|
||||
table, data["perc"],
|
||||
data["num_done"], data["total"],
|
||||
)
|
||||
|
||||
def set_state(self, state):
|
||||
print state + "..."
|
||||
|
||||
|
||||
##############################################
|
||||
##############################################
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="A script to port an existing synapse SQLite database to"
|
||||
" a new PostgreSQL database."
|
||||
)
|
||||
parser.add_argument("-v", action='store_true')
|
||||
parser.add_argument(
|
||||
"--sqlite-database", required=True,
|
||||
help="The snapshot of the SQLite database file. This must not be"
|
||||
" currently used by a running synapse server"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--postgres-config", type=argparse.FileType('r'), required=True,
|
||||
help="The database config file for the PostgreSQL database"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--curses", action='store_true',
|
||||
help="display a curses based progress UI"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--batch-size", type=int, default=1000,
|
||||
help="The number of rows to select from the SQLite table each"
|
||||
" iteration [default=1000]",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging_config = {
|
||||
"level": logging.DEBUG if args.v else logging.INFO,
|
||||
"format": "%(asctime)s - %(name)s - %(lineno)d - %(levelname)s - %(message)s"
|
||||
}
|
||||
|
||||
if args.curses:
|
||||
logging_config["filename"] = "port-synapse.log"
|
||||
|
||||
logging.basicConfig(**logging_config)
|
||||
|
||||
sqlite_config = {
|
||||
"name": "sqlite3",
|
||||
"args": {
|
||||
"database": args.sqlite_database,
|
||||
"cp_min": 1,
|
||||
"cp_max": 1,
|
||||
"check_same_thread": False,
|
||||
},
|
||||
}
|
||||
|
||||
postgres_config = yaml.safe_load(args.postgres_config)
|
||||
|
||||
if "name" not in postgres_config:
|
||||
sys.stderr.write("Malformed database config: no 'name'")
|
||||
sys.exit(2)
|
||||
if postgres_config["name"] != "psycopg2":
|
||||
sys.stderr.write("Database must use 'psycopg2' connector.")
|
||||
sys.exit(3)
|
||||
|
||||
def start(stdscr=None):
|
||||
if stdscr:
|
||||
progress = CursesProgress(stdscr)
|
||||
else:
|
||||
progress = TerminalProgress()
|
||||
|
||||
porter = Porter(
|
||||
sqlite_config=sqlite_config,
|
||||
postgres_config=postgres_config,
|
||||
progress=progress,
|
||||
batch_size=args.batch_size,
|
||||
)
|
||||
|
||||
reactor.callWhenRunning(porter.run)
|
||||
|
||||
reactor.run()
|
||||
|
||||
if args.curses:
|
||||
curses.wrapper(start)
|
||||
else:
|
||||
start()
|
||||
|
||||
if end_error_exec_info:
|
||||
exc_type, exc_value, exc_traceback = end_error_exec_info
|
||||
traceback.print_exception(exc_type, exc_value, exc_traceback)
|
||||
2
setup.py
2
setup.py
@@ -45,7 +45,7 @@ setup(
|
||||
version=version,
|
||||
packages=find_packages(exclude=["tests", "tests.*"]),
|
||||
description="Reference Synapse Home Server",
|
||||
install_requires=dependencies["REQUIREMENTS"].keys(),
|
||||
install_requires=dependencies['requirements'](include_conditional=True).keys(),
|
||||
setup_requires=[
|
||||
"Twisted==14.0.2", # Here to override setuptools_trial's dependency on Twisted>=2.4.0
|
||||
"setuptools_trial",
|
||||
|
||||
@@ -37,9 +37,13 @@ textarea, input {
|
||||
margin: auto
|
||||
}
|
||||
|
||||
.g-recaptcha div {
|
||||
margin: auto;
|
||||
}
|
||||
|
||||
#registrationForm {
|
||||
text-align: left;
|
||||
padding: 1em;
|
||||
padding: 5px;
|
||||
margin-bottom: 40px;
|
||||
display: inline-block;
|
||||
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.8.1"
|
||||
__version__ = "0.8.1-r4"
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, JoinRules
|
||||
from synapse.api.errors import AuthError, StoreError, Codes, SynapseError
|
||||
from synapse.api.errors import AuthError, Codes, SynapseError
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.types import UserID, ClientInfo
|
||||
@@ -40,6 +40,7 @@ class Auth(object):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.state = hs.get_state_handler()
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
|
||||
|
||||
def check(self, event, auth_events):
|
||||
""" Checks if this event is correctly authed.
|
||||
@@ -183,18 +184,10 @@ class Auth(object):
|
||||
else:
|
||||
join_rule = JoinRules.INVITE
|
||||
|
||||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
|
||||
ban_level, kick_level, redact_level = (
|
||||
self._get_ops_level_from_event_state(
|
||||
event,
|
||||
auth_events,
|
||||
)
|
||||
)
|
||||
# FIXME (erikj): What should we do here as the default?
|
||||
ban_level = self._get_named_level(auth_events, "ban", 50)
|
||||
|
||||
logger.debug(
|
||||
"is_membership_change_allowed: %s",
|
||||
@@ -210,28 +203,33 @@ class Auth(object):
|
||||
}
|
||||
)
|
||||
|
||||
if ban_level:
|
||||
ban_level = int(ban_level)
|
||||
else:
|
||||
ban_level = 50 # FIXME (erikj): What should we do here?
|
||||
if Membership.JOIN != membership:
|
||||
# JOIN is the only action you can perform if you're not in the room
|
||||
if not caller_in_room: # caller isn't joined
|
||||
raise AuthError(
|
||||
403,
|
||||
"%s not in room %s." % (event.user_id, event.room_id,)
|
||||
)
|
||||
|
||||
if Membership.INVITE == membership:
|
||||
# TODO (erikj): We should probably handle this more intelligently
|
||||
# PRIVATE join rules.
|
||||
|
||||
# Invites are valid iff caller is in the room and target isn't.
|
||||
if not caller_in_room: # caller isn't joined
|
||||
raise AuthError(
|
||||
403,
|
||||
"%s not in room %s." % (event.user_id, event.room_id,)
|
||||
)
|
||||
elif target_banned:
|
||||
if target_banned:
|
||||
raise AuthError(
|
||||
403, "%s is banned from the room" % (target_user_id,)
|
||||
)
|
||||
elif target_in_room: # the target is already in the room.
|
||||
raise AuthError(403, "%s is already in the room." %
|
||||
target_user_id)
|
||||
else:
|
||||
invite_level = self._get_named_level(auth_events, "invite", 0)
|
||||
|
||||
if user_level < invite_level:
|
||||
raise AuthError(
|
||||
403, "You cannot invite user %s." % target_user_id
|
||||
)
|
||||
elif Membership.JOIN == membership:
|
||||
# Joins are valid iff caller == target and they were:
|
||||
# invited: They are accepting the invitation
|
||||
@@ -251,21 +249,12 @@ class Auth(object):
|
||||
raise AuthError(403, "You are not allowed to join this room")
|
||||
elif Membership.LEAVE == membership:
|
||||
# TODO (erikj): Implement kicks.
|
||||
|
||||
if not caller_in_room: # trying to leave a room you aren't joined
|
||||
raise AuthError(
|
||||
403,
|
||||
"%s not in room %s." % (target_user_id, event.room_id,)
|
||||
)
|
||||
elif target_banned and user_level < ban_level:
|
||||
if target_banned and user_level < ban_level:
|
||||
raise AuthError(
|
||||
403, "You cannot unban user &s." % (target_user_id,)
|
||||
)
|
||||
elif target_user_id != event.user_id:
|
||||
if kick_level:
|
||||
kick_level = int(kick_level)
|
||||
else:
|
||||
kick_level = 50 # FIXME (erikj): What should we do here?
|
||||
kick_level = self._get_named_level(auth_events, "kick", 50)
|
||||
|
||||
if user_level < kick_level:
|
||||
raise AuthError(
|
||||
@@ -279,34 +268,42 @@ class Auth(object):
|
||||
|
||||
return True
|
||||
|
||||
def _get_power_level_from_event_state(self, event, user_id, auth_events):
|
||||
def _get_power_level_event(self, auth_events):
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = auth_events.get(key)
|
||||
level = None
|
||||
return auth_events.get(key)
|
||||
|
||||
def _get_user_power_level(self, user_id, auth_events):
|
||||
power_level_event = self._get_power_level_event(auth_events)
|
||||
|
||||
if power_level_event:
|
||||
level = power_level_event.content.get("users", {}).get(user_id)
|
||||
if not level:
|
||||
level = power_level_event.content.get("users_default", 0)
|
||||
|
||||
if level is None:
|
||||
return 0
|
||||
else:
|
||||
return int(level)
|
||||
else:
|
||||
key = (EventTypes.Create, "", )
|
||||
create_event = auth_events.get(key)
|
||||
if (create_event is not None and
|
||||
create_event.content["creator"] == user_id):
|
||||
return 100
|
||||
else:
|
||||
return 0
|
||||
|
||||
return level
|
||||
def _get_named_level(self, auth_events, name, default):
|
||||
power_level_event = self._get_power_level_event(auth_events)
|
||||
|
||||
def _get_ops_level_from_event_state(self, event, auth_events):
|
||||
key = (EventTypes.PowerLevels, "", )
|
||||
power_level_event = auth_events.get(key)
|
||||
if not power_level_event:
|
||||
return default
|
||||
|
||||
if power_level_event:
|
||||
return (
|
||||
power_level_event.content.get("ban", 50),
|
||||
power_level_event.content.get("kick", 50),
|
||||
power_level_event.content.get("redact", 50),
|
||||
)
|
||||
return None, None, None,
|
||||
level = power_level_event.content.get(name, None)
|
||||
if level is not None:
|
||||
return int(level)
|
||||
else:
|
||||
return default
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_req(self, request):
|
||||
@@ -373,7 +370,10 @@ class Auth(object):
|
||||
|
||||
defer.returnValue((user, ClientInfo(device_id, token_id)))
|
||||
except KeyError:
|
||||
raise AuthError(403, "Missing access token.")
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.",
|
||||
errcode=Codes.MISSING_TOKEN
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_by_token(self, token):
|
||||
@@ -387,21 +387,20 @@ class Auth(object):
|
||||
Raises:
|
||||
AuthError if no user by that token exists or the token is invalid.
|
||||
"""
|
||||
try:
|
||||
ret = yield self.store.get_user_by_token(token)
|
||||
if not ret:
|
||||
raise StoreError(400, "Unknown token")
|
||||
user_info = {
|
||||
"admin": bool(ret.get("admin", False)),
|
||||
"device_id": ret.get("device_id"),
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
}
|
||||
ret = yield self.store.get_user_by_token(token)
|
||||
if not ret:
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
user_info = {
|
||||
"admin": bool(ret.get("admin", False)),
|
||||
"device_id": ret.get("device_id"),
|
||||
"user": UserID.from_string(ret.get("name")),
|
||||
"token_id": ret.get("token_id", None),
|
||||
}
|
||||
|
||||
defer.returnValue(user_info)
|
||||
except StoreError:
|
||||
raise AuthError(403, "Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN)
|
||||
defer.returnValue(user_info)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_appservice_by_req(self, request):
|
||||
@@ -409,11 +408,16 @@ class Auth(object):
|
||||
token = request.args["access_token"][0]
|
||||
service = yield self.store.get_app_service_by_token(token)
|
||||
if not service:
|
||||
raise AuthError(403, "Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN)
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS,
|
||||
"Unrecognised access token.",
|
||||
errcode=Codes.UNKNOWN_TOKEN
|
||||
)
|
||||
defer.returnValue(service)
|
||||
except KeyError:
|
||||
raise AuthError(403, "Missing access token.")
|
||||
raise AuthError(
|
||||
self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token."
|
||||
)
|
||||
|
||||
def is_server_admin(self, user):
|
||||
return self.store.is_server_admin(user)
|
||||
@@ -486,7 +490,7 @@ class Auth(object):
|
||||
send_level = send_level_event.content.get("events", {}).get(
|
||||
event.type
|
||||
)
|
||||
if not send_level:
|
||||
if send_level is None:
|
||||
if hasattr(event, "state_key"):
|
||||
send_level = send_level_event.content.get(
|
||||
"state_default", 50
|
||||
@@ -501,16 +505,7 @@ class Auth(object):
|
||||
else:
|
||||
send_level = 0
|
||||
|
||||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
|
||||
if user_level:
|
||||
user_level = int(user_level)
|
||||
else:
|
||||
user_level = 0
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
|
||||
if user_level < send_level:
|
||||
raise AuthError(
|
||||
@@ -542,16 +537,9 @@ class Auth(object):
|
||||
return True
|
||||
|
||||
def _check_redaction(self, event, auth_events):
|
||||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
|
||||
_, _, redact_level = self._get_ops_level_from_event_state(
|
||||
event,
|
||||
auth_events,
|
||||
)
|
||||
redact_level = self._get_named_level(auth_events, "redact", 50)
|
||||
|
||||
if user_level < redact_level:
|
||||
raise AuthError(
|
||||
@@ -579,11 +567,7 @@ class Auth(object):
|
||||
if not current_state:
|
||||
return
|
||||
|
||||
user_level = self._get_power_level_from_event_state(
|
||||
event,
|
||||
event.user_id,
|
||||
auth_events,
|
||||
)
|
||||
user_level = self._get_user_power_level(event.user_id, auth_events)
|
||||
|
||||
# Check other levels:
|
||||
levels_to_check = [
|
||||
@@ -592,6 +576,7 @@ class Auth(object):
|
||||
("ban", []),
|
||||
("redact", []),
|
||||
("kick", []),
|
||||
("invite", []),
|
||||
]
|
||||
|
||||
old_list = current_state.content.get("users")
|
||||
|
||||
@@ -59,6 +59,9 @@ class LoginType(object):
|
||||
EMAIL_URL = u"m.login.email.url"
|
||||
EMAIL_IDENTITY = u"m.login.email.identity"
|
||||
RECAPTCHA = u"m.login.recaptcha"
|
||||
DUMMY = u"m.login.dummy"
|
||||
|
||||
# Only for C/S API v1
|
||||
APPLICATION_SERVICE = u"m.login.application_service"
|
||||
SHARED_SECRET = u"org.matrix.login.shared_secret"
|
||||
|
||||
|
||||
@@ -31,13 +31,15 @@ class Codes(object):
|
||||
BAD_PAGINATION = "M_BAD_PAGINATION"
|
||||
UNKNOWN = "M_UNKNOWN"
|
||||
NOT_FOUND = "M_NOT_FOUND"
|
||||
MISSING_TOKEN = "M_MISSING_TOKEN"
|
||||
UNKNOWN_TOKEN = "M_UNKNOWN_TOKEN"
|
||||
LIMIT_EXCEEDED = "M_LIMIT_EXCEEDED"
|
||||
CAPTCHA_NEEDED = "M_CAPTCHA_NEEDED"
|
||||
CAPTCHA_INVALID = "M_CAPTCHA_INVALID"
|
||||
MISSING_PARAM = "M_MISSING_PARAM",
|
||||
TOO_LARGE = "M_TOO_LARGE",
|
||||
MISSING_PARAM = "M_MISSING_PARAM"
|
||||
TOO_LARGE = "M_TOO_LARGE"
|
||||
EXCLUSIVE = "M_EXCLUSIVE"
|
||||
THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED"
|
||||
|
||||
|
||||
class CodeMessageException(RuntimeError):
|
||||
|
||||
@@ -22,5 +22,6 @@ STATIC_PREFIX = "/_matrix/static"
|
||||
WEB_CLIENT_PREFIX = "/_matrix/client"
|
||||
CONTENT_REPO_PREFIX = "/_matrix/content"
|
||||
SERVER_KEY_PREFIX = "/_matrix/key/v1"
|
||||
SERVER_KEY_V2_PREFIX = "/_matrix/key/v2"
|
||||
MEDIA_PREFIX = "/_matrix/media/v1"
|
||||
APP_SERVICE_PREFIX = "/_matrix/appservice/v1"
|
||||
|
||||
@@ -16,14 +16,18 @@
|
||||
|
||||
import sys
|
||||
sys.dont_write_bytecode = True
|
||||
from synapse.python_dependencies import check_requirements
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_requirements()
|
||||
|
||||
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
|
||||
from synapse.storage import (
|
||||
prepare_database, prepare_sqlite3_database, UpgradeDatabaseException,
|
||||
are_all_users_on_domain, UpgradeDatabaseException,
|
||||
)
|
||||
|
||||
from synapse.server import HomeServer
|
||||
|
||||
from synapse.python_dependencies import check_requirements
|
||||
|
||||
from twisted.internet import reactor
|
||||
from twisted.application import service
|
||||
@@ -31,16 +35,17 @@ from twisted.enterprise import adbapi
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.static import File
|
||||
from twisted.web.server import Site
|
||||
from twisted.web.http import proxiedLogFormatter, combinedLogFormatter
|
||||
from synapse.http.server import JsonResource, RootRedirect
|
||||
from synapse.rest.appservice.v1 import AppServiceRestResource
|
||||
from synapse.rest.media.v0.content_repository import ContentRepoResource
|
||||
from synapse.rest.media.v1.media_repository import MediaRepositoryResource
|
||||
from synapse.http.server_key_resource import LocalKey
|
||||
from synapse.rest.key.v1.server_key_resource import LocalKey
|
||||
from synapse.rest.key.v2 import KeyApiV2Resource
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.api.urls import (
|
||||
CLIENT_PREFIX, FEDERATION_PREFIX, WEB_CLIENT_PREFIX, CONTENT_REPO_PREFIX,
|
||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, APP_SERVICE_PREFIX,
|
||||
STATIC_PREFIX
|
||||
SERVER_KEY_PREFIX, MEDIA_PREFIX, CLIENT_V2_ALPHA_PREFIX, STATIC_PREFIX,
|
||||
SERVER_KEY_V2_PREFIX,
|
||||
)
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
@@ -59,9 +64,9 @@ import os
|
||||
import re
|
||||
import resource
|
||||
import subprocess
|
||||
import sqlite3
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
logger = logging.getLogger("synapse.app.homeserver")
|
||||
|
||||
|
||||
class SynapseHomeServer(HomeServer):
|
||||
@@ -78,9 +83,6 @@ class SynapseHomeServer(HomeServer):
|
||||
def build_resource_for_federation(self):
|
||||
return JsonResource(self)
|
||||
|
||||
def build_resource_for_app_services(self):
|
||||
return AppServiceRestResource(self)
|
||||
|
||||
def build_resource_for_web_client(self):
|
||||
import syweb
|
||||
syweb_path = os.path.dirname(syweb.__file__)
|
||||
@@ -101,6 +103,9 @@ class SynapseHomeServer(HomeServer):
|
||||
def build_resource_for_server_key(self):
|
||||
return LocalKey(self)
|
||||
|
||||
def build_resource_for_server_key_v2(self):
|
||||
return KeyApiV2Resource(self)
|
||||
|
||||
def build_resource_for_metrics(self):
|
||||
if self.get_config().enable_metrics:
|
||||
return MetricsResource(self)
|
||||
@@ -108,13 +113,11 @@ class SynapseHomeServer(HomeServer):
|
||||
return None
|
||||
|
||||
def build_db_pool(self):
|
||||
name = self.db_config["name"]
|
||||
|
||||
return adbapi.ConnectionPool(
|
||||
"sqlite3", self.get_db_name(),
|
||||
check_same_thread=False,
|
||||
cp_min=1,
|
||||
cp_max=1,
|
||||
cp_openfun=prepare_database, # Prepare the database for each conn
|
||||
# so that :memory: sqlite works
|
||||
name,
|
||||
**self.db_config.get("args", {})
|
||||
)
|
||||
|
||||
def create_resource_tree(self, redirect_root_to_web_client):
|
||||
@@ -140,8 +143,8 @@ class SynapseHomeServer(HomeServer):
|
||||
(FEDERATION_PREFIX, self.get_resource_for_federation()),
|
||||
(CONTENT_REPO_PREFIX, self.get_resource_for_content_repo()),
|
||||
(SERVER_KEY_PREFIX, self.get_resource_for_server_key()),
|
||||
(SERVER_KEY_V2_PREFIX, self.get_resource_for_server_key_v2()),
|
||||
(MEDIA_PREFIX, self.get_resource_for_media_repository()),
|
||||
(APP_SERVICE_PREFIX, self.get_resource_for_app_services()),
|
||||
(STATIC_PREFIX, self.get_resource_for_static_content()),
|
||||
]
|
||||
|
||||
@@ -226,7 +229,11 @@ class SynapseHomeServer(HomeServer):
|
||||
if not config.no_tls and config.bind_port is not None:
|
||||
reactor.listenSSL(
|
||||
config.bind_port,
|
||||
Site(self.root_resource),
|
||||
SynapseSite(
|
||||
"synapse.access.https",
|
||||
config,
|
||||
self.root_resource,
|
||||
),
|
||||
self.tls_context_factory,
|
||||
interface=config.bind_host
|
||||
)
|
||||
@@ -235,7 +242,11 @@ class SynapseHomeServer(HomeServer):
|
||||
if config.unsecure_port is not None:
|
||||
reactor.listenTCP(
|
||||
config.unsecure_port,
|
||||
Site(self.root_resource),
|
||||
SynapseSite(
|
||||
"synapse.access.http",
|
||||
config,
|
||||
self.root_resource,
|
||||
),
|
||||
interface=config.bind_host
|
||||
)
|
||||
logger.info("Synapse now listening on port %d", config.unsecure_port)
|
||||
@@ -243,10 +254,43 @@ class SynapseHomeServer(HomeServer):
|
||||
metrics_resource = self.get_resource_for_metrics()
|
||||
if metrics_resource and config.metrics_port is not None:
|
||||
reactor.listenTCP(
|
||||
config.metrics_port, Site(metrics_resource), interface="127.0.0.1",
|
||||
config.metrics_port,
|
||||
SynapseSite(
|
||||
"synapse.access.metrics",
|
||||
config,
|
||||
metrics_resource,
|
||||
),
|
||||
interface="127.0.0.1",
|
||||
)
|
||||
logger.info("Metrics now running on 127.0.0.1 port %d", config.metrics_port)
|
||||
|
||||
def run_startup_checks(self, db_conn, database_engine):
|
||||
all_users_native = are_all_users_on_domain(
|
||||
db_conn.cursor(), database_engine, self.hostname
|
||||
)
|
||||
if not all_users_native:
|
||||
quit_with_error(
|
||||
"Found users in database not native to %s!\n"
|
||||
"You cannot changed a synapse server_name after it's been configured"
|
||||
% (self.hostname,)
|
||||
)
|
||||
|
||||
try:
|
||||
database_engine.check_database(db_conn.cursor())
|
||||
except IncorrectDatabaseSetup as e:
|
||||
quit_with_error(e.message)
|
||||
|
||||
|
||||
def quit_with_error(error_string):
|
||||
message_lines = error_string.split("\n")
|
||||
line_length = max([len(l) for l in message_lines]) + 2
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
for line in message_lines:
|
||||
if line.strip():
|
||||
sys.stderr.write(" %s\n" % (line.strip(),))
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_version_string():
|
||||
try:
|
||||
@@ -358,15 +402,20 @@ def setup(config_options):
|
||||
|
||||
tls_context_factory = context_factory.ServerContextFactory(config)
|
||||
|
||||
database_engine = create_engine(config.database_config["name"])
|
||||
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
|
||||
|
||||
hs = SynapseHomeServer(
|
||||
config.server_name,
|
||||
domain_with_port=domain_with_port,
|
||||
upload_dir=os.path.abspath("uploads"),
|
||||
db_name=config.database_path,
|
||||
db_config=config.database_config,
|
||||
tls_context_factory=tls_context_factory,
|
||||
config=config,
|
||||
content_addr=config.content_addr,
|
||||
version_string=version_string,
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
hs.create_resource_tree(
|
||||
@@ -378,9 +427,17 @@ def setup(config_options):
|
||||
logger.info("Preparing database: %s...", db_name)
|
||||
|
||||
try:
|
||||
with sqlite3.connect(db_name) as db_conn:
|
||||
prepare_sqlite3_database(db_conn)
|
||||
prepare_database(db_conn)
|
||||
db_conn = database_engine.module.connect(
|
||||
**{
|
||||
k: v for k, v in config.database_config.get("args", {}).items()
|
||||
if not k.startswith("cp_")
|
||||
}
|
||||
)
|
||||
|
||||
database_engine.prepare_database(db_conn)
|
||||
hs.run_startup_checks(db_conn, database_engine)
|
||||
|
||||
db_conn.commit()
|
||||
except UpgradeDatabaseException:
|
||||
sys.stderr.write(
|
||||
"\nFailed to upgrade database.\n"
|
||||
@@ -423,6 +480,24 @@ class SynapseService(service.Service):
|
||||
return self._port.stopListening()
|
||||
|
||||
|
||||
class SynapseSite(Site):
|
||||
"""
|
||||
Subclass of a twisted http Site that does access logging with python's
|
||||
standard logging
|
||||
"""
|
||||
def __init__(self, logger_name, config, resource, *args, **kwargs):
|
||||
Site.__init__(self, resource, *args, **kwargs)
|
||||
if config.captcha_ip_origin_is_x_forwarded:
|
||||
self._log_formatter = proxiedLogFormatter
|
||||
else:
|
||||
self._log_formatter = combinedLogFormatter
|
||||
self.access_logger = logging.getLogger(logger_name)
|
||||
|
||||
def log(self, request):
|
||||
line = self._log_formatter(self._logDateTime, request)
|
||||
self.access_logger.info(line)
|
||||
|
||||
|
||||
def run(hs):
|
||||
|
||||
def in_thread():
|
||||
|
||||
@@ -20,6 +20,50 @@ import re
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ApplicationServiceState(object):
|
||||
DOWN = "down"
|
||||
UP = "up"
|
||||
|
||||
|
||||
class AppServiceTransaction(object):
|
||||
"""Represents an application service transaction."""
|
||||
|
||||
def __init__(self, service, id, events):
|
||||
self.service = service
|
||||
self.id = id
|
||||
self.events = events
|
||||
|
||||
def send(self, as_api):
|
||||
"""Sends this transaction using the provided AS API interface.
|
||||
|
||||
Args:
|
||||
as_api(ApplicationServiceApi): The API to use to send.
|
||||
Returns:
|
||||
A Deferred which resolves to True if the transaction was sent.
|
||||
"""
|
||||
return as_api.push_bulk(
|
||||
service=self.service,
|
||||
events=self.events,
|
||||
txn_id=self.id
|
||||
)
|
||||
|
||||
def complete(self, store):
|
||||
"""Completes this transaction as successful.
|
||||
|
||||
Marks this transaction ID on the application service and removes the
|
||||
transaction contents from the database.
|
||||
|
||||
Args:
|
||||
store: The database store to operate on.
|
||||
Returns:
|
||||
A Deferred which resolves to True if the transaction was completed.
|
||||
"""
|
||||
return store.complete_appservice_txn(
|
||||
service=self.service,
|
||||
txn_id=self.id
|
||||
)
|
||||
|
||||
|
||||
class ApplicationService(object):
|
||||
"""Defines an application service. This definition is mostly what is
|
||||
provided to the /register AS API.
|
||||
@@ -35,13 +79,13 @@ class ApplicationService(object):
|
||||
NS_LIST = [NS_USERS, NS_ALIASES, NS_ROOMS]
|
||||
|
||||
def __init__(self, token, url=None, namespaces=None, hs_token=None,
|
||||
sender=None, txn_id=None):
|
||||
sender=None, id=None):
|
||||
self.token = token
|
||||
self.url = url
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.txn_id = txn_id
|
||||
self.id = id
|
||||
|
||||
def _check_namespaces(self, namespaces):
|
||||
# Sanity check that it is of the form:
|
||||
@@ -51,7 +95,7 @@ class ApplicationService(object):
|
||||
# rooms: [ {regex: "[A-z]+.*", exclusive: true}, ...],
|
||||
# }
|
||||
if not namespaces:
|
||||
return None
|
||||
namespaces = {}
|
||||
|
||||
for ns in ApplicationService.NS_LIST:
|
||||
if ns not in namespaces:
|
||||
@@ -155,7 +199,10 @@ class ApplicationService(object):
|
||||
return self._matches_user(event, member_list)
|
||||
|
||||
def is_interested_in_user(self, user_id):
|
||||
return self._matches_regex(user_id, ApplicationService.NS_USERS)
|
||||
return (
|
||||
self._matches_regex(user_id, ApplicationService.NS_USERS)
|
||||
or user_id == self.sender
|
||||
)
|
||||
|
||||
def is_interested_in_alias(self, alias):
|
||||
return self._matches_regex(alias, ApplicationService.NS_ALIASES)
|
||||
@@ -164,7 +211,10 @@ class ApplicationService(object):
|
||||
return self._matches_regex(room_id, ApplicationService.NS_ROOMS)
|
||||
|
||||
def is_exclusive_user(self, user_id):
|
||||
return self._is_exclusive(ApplicationService.NS_USERS, user_id)
|
||||
return (
|
||||
self._is_exclusive(ApplicationService.NS_USERS, user_id)
|
||||
or user_id == self.sender
|
||||
)
|
||||
|
||||
def is_exclusive_alias(self, alias):
|
||||
return self._is_exclusive(ApplicationService.NS_ALIASES, alias)
|
||||
|
||||
@@ -72,14 +72,19 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push_bulk(self, service, events):
|
||||
def push_bulk(self, service, events, txn_id=None):
|
||||
events = self._serialize(events)
|
||||
|
||||
if txn_id is None:
|
||||
logger.warning("push_bulk: Missing txn ID sending events to %s",
|
||||
service.url)
|
||||
txn_id = str(0)
|
||||
txn_id = str(txn_id)
|
||||
|
||||
uri = service.url + ("/transactions/%s" %
|
||||
urllib.quote(str(0))) # TODO txn_ids
|
||||
response = None
|
||||
urllib.quote(txn_id))
|
||||
try:
|
||||
response = yield self.put_json(
|
||||
yield self.put_json(
|
||||
uri=uri,
|
||||
json_body={
|
||||
"events": events
|
||||
@@ -87,9 +92,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
args={
|
||||
"access_token": service.hs_token
|
||||
})
|
||||
if response: # just an empty json object
|
||||
# TODO: Mark txn as sent successfully
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(True)
|
||||
return
|
||||
except CodeMessageException as e:
|
||||
logger.warning("push_bulk to %s received %s", uri, e.code)
|
||||
except Exception as ex:
|
||||
@@ -97,8 +101,8 @@ class ApplicationServiceApi(SimpleHttpClient):
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def push(self, service, event):
|
||||
response = yield self.push_bulk(service, [event])
|
||||
def push(self, service, event, txn_id=None):
|
||||
response = yield self.push_bulk(service, [event], txn_id)
|
||||
defer.returnValue(response)
|
||||
|
||||
def _serialize(self, events):
|
||||
|
||||
254
synapse/appservice/scheduler.py
Normal file
254
synapse/appservice/scheduler.py
Normal file
@@ -0,0 +1,254 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
This module controls the reliability for application service transactions.
|
||||
|
||||
The nominal flow through this module looks like:
|
||||
__________
|
||||
1---ASa[e]-->| Service |--> Queue ASa[f]
|
||||
2----ASb[e]->| Queuer |
|
||||
3--ASa[f]--->|__________|-----------+ ASa[e], ASb[e]
|
||||
V
|
||||
-````````- +------------+
|
||||
|````````|<--StoreTxn-|Transaction |
|
||||
|Database| | Controller |---> SEND TO AS
|
||||
`--------` +------------+
|
||||
What happens on SEND TO AS depends on the state of the Application Service:
|
||||
- If the AS is marked as DOWN, do nothing.
|
||||
- If the AS is marked as UP, send the transaction.
|
||||
* SUCCESS : Increment where the AS is up to txn-wise and nuke the txn
|
||||
contents from the db.
|
||||
* FAILURE : Marked AS as DOWN and start Recoverer.
|
||||
|
||||
Recoverer attempts to recover ASes who have died. The flow for this looks like:
|
||||
,--------------------- backoff++ --------------.
|
||||
V |
|
||||
START ---> Wait exp ------> Get oldest txn ID from ----> FAILURE
|
||||
backoff DB and try to send it
|
||||
^ |___________
|
||||
Mark AS as | V
|
||||
UP & quit +---------- YES SUCCESS
|
||||
| | |
|
||||
NO <--- Have more txns? <------ Mark txn success & nuke <-+
|
||||
from db; incr AS pos.
|
||||
Reset backoff.
|
||||
|
||||
This is all tied together by the AppServiceScheduler which DIs the required
|
||||
components.
|
||||
"""
|
||||
|
||||
from synapse.appservice import ApplicationServiceState
|
||||
from twisted.internet import defer
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppServiceScheduler(object):
|
||||
""" Public facing API for this module. Does the required DI to tie the
|
||||
components together. This also serves as the "event_pool", which in this
|
||||
case is a simple array.
|
||||
"""
|
||||
|
||||
def __init__(self, clock, store, as_api):
|
||||
self.clock = clock
|
||||
self.store = store
|
||||
self.as_api = as_api
|
||||
|
||||
def create_recoverer(service, callback):
|
||||
return _Recoverer(clock, store, as_api, service, callback)
|
||||
|
||||
self.txn_ctrl = _TransactionController(
|
||||
clock, store, as_api, create_recoverer
|
||||
)
|
||||
self.queuer = _ServiceQueuer(self.txn_ctrl)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
logger.info("Starting appservice scheduler")
|
||||
# check for any DOWN ASes and start recoverers for them.
|
||||
recoverers = yield _Recoverer.start(
|
||||
self.clock, self.store, self.as_api, self.txn_ctrl.on_recovered
|
||||
)
|
||||
self.txn_ctrl.add_recoverers(recoverers)
|
||||
|
||||
def submit_event_for_as(self, service, event):
|
||||
self.queuer.enqueue(service, event)
|
||||
|
||||
|
||||
class _ServiceQueuer(object):
|
||||
"""Queues events for the same application service together, sending
|
||||
transactions as soon as possible. Once a transaction is sent successfully,
|
||||
this schedules any other events in the queue to run.
|
||||
"""
|
||||
|
||||
def __init__(self, txn_ctrl):
|
||||
self.queued_events = {} # dict of {service_id: [events]}
|
||||
self.pending_requests = {} # dict of {service_id: Deferred}
|
||||
self.txn_ctrl = txn_ctrl
|
||||
|
||||
def enqueue(self, service, event):
|
||||
# if this service isn't being sent something
|
||||
if not self.pending_requests.get(service.id):
|
||||
self._send_request(service, [event])
|
||||
else:
|
||||
# add to queue for this service
|
||||
if service.id not in self.queued_events:
|
||||
self.queued_events[service.id] = []
|
||||
self.queued_events[service.id].append(event)
|
||||
|
||||
def _send_request(self, service, events):
|
||||
# send request and add callbacks
|
||||
d = self.txn_ctrl.send(service, events)
|
||||
d.addBoth(self._on_request_finish)
|
||||
d.addErrback(self._on_request_fail)
|
||||
self.pending_requests[service.id] = d
|
||||
|
||||
def _on_request_finish(self, service):
|
||||
self.pending_requests[service.id] = None
|
||||
# if there are queued events, then send them.
|
||||
if (service.id in self.queued_events
|
||||
and len(self.queued_events[service.id]) > 0):
|
||||
self._send_request(service, self.queued_events[service.id])
|
||||
self.queued_events[service.id] = []
|
||||
|
||||
def _on_request_fail(self, err):
|
||||
logger.error("AS request failed: %s", err)
|
||||
|
||||
|
||||
class _TransactionController(object):
|
||||
|
||||
def __init__(self, clock, store, as_api, recoverer_fn):
|
||||
self.clock = clock
|
||||
self.store = store
|
||||
self.as_api = as_api
|
||||
self.recoverer_fn = recoverer_fn
|
||||
# keep track of how many recoverers there are
|
||||
self.recoverers = []
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send(self, service, events):
|
||||
try:
|
||||
txn = yield self.store.create_appservice_txn(
|
||||
service=service,
|
||||
events=events
|
||||
)
|
||||
service_is_up = yield self._is_service_up(service)
|
||||
if service_is_up:
|
||||
sent = yield txn.send(self.as_api)
|
||||
if sent:
|
||||
txn.complete(self.store)
|
||||
else:
|
||||
self._start_recoverer(service)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._start_recoverer(service)
|
||||
# request has finished
|
||||
defer.returnValue(service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_recovered(self, recoverer):
|
||||
self.recoverers.remove(recoverer)
|
||||
logger.info("Successfully recovered application service AS ID %s",
|
||||
recoverer.service.id)
|
||||
logger.info("Remaining active recoverers: %s", len(self.recoverers))
|
||||
yield self.store.set_appservice_state(
|
||||
recoverer.service,
|
||||
ApplicationServiceState.UP
|
||||
)
|
||||
|
||||
def add_recoverers(self, recoverers):
|
||||
for r in recoverers:
|
||||
self.recoverers.append(r)
|
||||
if len(recoverers) > 0:
|
||||
logger.info("New active recoverers: %s", len(self.recoverers))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_recoverer(self, service):
|
||||
yield self.store.set_appservice_state(
|
||||
service,
|
||||
ApplicationServiceState.DOWN
|
||||
)
|
||||
logger.info(
|
||||
"Application service falling behind. Starting recoverer. AS ID %s",
|
||||
service.id
|
||||
)
|
||||
recoverer = self.recoverer_fn(service, self.on_recovered)
|
||||
self.add_recoverers([recoverer])
|
||||
recoverer.recover()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _is_service_up(self, service):
|
||||
state = yield self.store.get_appservice_state(service)
|
||||
defer.returnValue(state == ApplicationServiceState.UP or state is None)
|
||||
|
||||
|
||||
class _Recoverer(object):
|
||||
|
||||
@staticmethod
|
||||
@defer.inlineCallbacks
|
||||
def start(clock, store, as_api, callback):
|
||||
services = yield store.get_appservices_by_state(
|
||||
ApplicationServiceState.DOWN
|
||||
)
|
||||
recoverers = [
|
||||
_Recoverer(clock, store, as_api, s, callback) for s in services
|
||||
]
|
||||
for r in recoverers:
|
||||
logger.info("Starting recoverer for AS ID %s which was marked as "
|
||||
"DOWN", r.service.id)
|
||||
r.recover()
|
||||
defer.returnValue(recoverers)
|
||||
|
||||
def __init__(self, clock, store, as_api, service, callback):
|
||||
self.clock = clock
|
||||
self.store = store
|
||||
self.as_api = as_api
|
||||
self.service = service
|
||||
self.callback = callback
|
||||
self.backoff_counter = 1
|
||||
|
||||
def recover(self):
|
||||
self.clock.call_later((2 ** self.backoff_counter), self.retry)
|
||||
|
||||
def _backoff(self):
|
||||
# cap the backoff to be around 18h => (2^16) = 65536 secs
|
||||
if self.backoff_counter < 16:
|
||||
self.backoff_counter += 1
|
||||
self.recover()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def retry(self):
|
||||
try:
|
||||
txn = yield self.store.get_oldest_unsent_txn(self.service)
|
||||
if txn:
|
||||
logger.info("Retrying transaction %s for AS ID %s",
|
||||
txn.id, txn.service.id)
|
||||
sent = yield txn.send(self.as_api)
|
||||
if sent:
|
||||
yield txn.complete(self.store)
|
||||
# reset the backoff counter and retry immediately
|
||||
self.backoff_counter = 1
|
||||
yield self.retry()
|
||||
else:
|
||||
self._backoff()
|
||||
else:
|
||||
self._set_service_recovered()
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._backoff()
|
||||
|
||||
def _set_service_recovered(self):
|
||||
self.callback(self)
|
||||
@@ -77,6 +77,17 @@ class Config(object):
|
||||
with open(file_path) as file_stream:
|
||||
return file_stream.read()
|
||||
|
||||
@classmethod
|
||||
def read_yaml_file(cls, file_path, config_name):
|
||||
cls.check_file(file_path, config_name)
|
||||
with open(file_path) as file_stream:
|
||||
try:
|
||||
return yaml.load(file_stream)
|
||||
except:
|
||||
raise ConfigError(
|
||||
"Error parsing yaml in file %r" % (file_path,)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def default_path(name):
|
||||
return os.path.abspath(os.path.join(os.path.curdir, name))
|
||||
@@ -147,9 +158,10 @@ class Config(object):
|
||||
and value is not None):
|
||||
config[key] = value
|
||||
with open(config_args.config_path, "w") as config_file:
|
||||
# TODO(paul) it would be lovely if we wrote out vim- and emacs-
|
||||
# style mode markers into the file, to hint to people that
|
||||
# this is a YAML file.
|
||||
# TODO(mark/paul) We might want to output emacs-style mode
|
||||
# markers as well as vim-style mode markers into the file,
|
||||
# to further hint to people this is a YAML file.
|
||||
config_file.write("# vim:ft=yaml\n")
|
||||
yaml.dump(config, config_file, default_flow_style=False)
|
||||
print (
|
||||
"A config file has been generated in %s for server name"
|
||||
|
||||
31
synapse/config/appservice.py
Normal file
31
synapse/config/appservice.py
Normal file
@@ -0,0 +1,31 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class AppServiceConfig(Config):
|
||||
|
||||
def __init__(self, args):
|
||||
super(AppServiceConfig, self).__init__(args)
|
||||
self.app_service_config_files = args.app_service_config_files
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(AppServiceConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("appservice")
|
||||
group.add_argument(
|
||||
"--app-service-config-files", type=str, nargs='+',
|
||||
help="A list of application service config files to use."
|
||||
)
|
||||
@@ -20,7 +20,10 @@ class CaptchaConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(CaptchaConfig, self).__init__(args)
|
||||
self.recaptcha_private_key = args.recaptcha_private_key
|
||||
self.recaptcha_public_key = args.recaptcha_public_key
|
||||
self.enable_registration_captcha = args.enable_registration_captcha
|
||||
|
||||
# XXX: This is used for more than just captcha
|
||||
self.captcha_ip_origin_is_x_forwarded = (
|
||||
args.captcha_ip_origin_is_x_forwarded
|
||||
)
|
||||
@@ -30,9 +33,13 @@ class CaptchaConfig(Config):
|
||||
def add_arguments(cls, parser):
|
||||
super(CaptchaConfig, cls).add_arguments(parser)
|
||||
group = parser.add_argument_group("recaptcha")
|
||||
group.add_argument(
|
||||
"--recaptcha-public-key", type=str, default="YOUR_PUBLIC_KEY",
|
||||
help="This Home Server's ReCAPTCHA public key."
|
||||
)
|
||||
group.add_argument(
|
||||
"--recaptcha-private-key", type=str, default="YOUR_PRIVATE_KEY",
|
||||
help="The matching private key for the web client's public key."
|
||||
help="This Home Server's ReCAPTCHA private key."
|
||||
)
|
||||
group.add_argument(
|
||||
"--enable-registration-captcha", type=bool, default=False,
|
||||
|
||||
@@ -15,6 +15,7 @@
|
||||
|
||||
from ._base import Config
|
||||
import os
|
||||
import yaml
|
||||
|
||||
|
||||
class DatabaseConfig(Config):
|
||||
@@ -26,18 +27,45 @@ class DatabaseConfig(Config):
|
||||
self.database_path = self.abspath(args.database_path)
|
||||
self.event_cache_size = self.parse_size(args.event_cache_size)
|
||||
|
||||
if args.database_config:
|
||||
with open(args.database_config) as f:
|
||||
self.database_config = yaml.safe_load(f)
|
||||
else:
|
||||
self.database_config = {
|
||||
"name": "sqlite3",
|
||||
"args": {
|
||||
"database": self.database_path,
|
||||
},
|
||||
}
|
||||
|
||||
name = self.database_config.get("name", None)
|
||||
if name == "psycopg2":
|
||||
pass
|
||||
elif name == "sqlite3":
|
||||
self.database_config.setdefault("args", {}).update({
|
||||
"cp_min": 1,
|
||||
"cp_max": 1,
|
||||
"check_same_thread": False,
|
||||
})
|
||||
else:
|
||||
raise RuntimeError("Unsupported database type '%s'" % (name,))
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(DatabaseConfig, cls).add_arguments(parser)
|
||||
db_group = parser.add_argument_group("database")
|
||||
db_group.add_argument(
|
||||
"-d", "--database-path", default="homeserver.db",
|
||||
help="The database name."
|
||||
metavar="SQLITE_DATABASE_PATH", help="The database name."
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--event-cache-size", default="100K",
|
||||
help="Number of events to cache in memory."
|
||||
)
|
||||
db_group.add_argument(
|
||||
"--database-config", default=None,
|
||||
help="Location of the database configuration file."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class EmailConfig(Config):
|
||||
|
||||
def __init__(self, args):
|
||||
super(EmailConfig, self).__init__(args)
|
||||
self.email_from_address = args.email_from_address
|
||||
self.email_smtp_server = args.email_smtp_server
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(EmailConfig, cls).add_arguments(parser)
|
||||
email_group = parser.add_argument_group("email")
|
||||
email_group.add_argument(
|
||||
"--email-from-address",
|
||||
default="FROM@EXAMPLE.COM",
|
||||
help="The address to send emails from (e.g. for password resets)."
|
||||
)
|
||||
email_group.add_argument(
|
||||
"--email-smtp-server",
|
||||
default="",
|
||||
help=(
|
||||
"The SMTP server to send emails from (e.g. for password"
|
||||
" resets)."
|
||||
)
|
||||
)
|
||||
@@ -20,16 +20,17 @@ from .database import DatabaseConfig
|
||||
from .ratelimiting import RatelimitConfig
|
||||
from .repository import ContentRepositoryConfig
|
||||
from .captcha import CaptchaConfig
|
||||
from .email import EmailConfig
|
||||
from .voip import VoipConfig
|
||||
from .registration import RegistrationConfig
|
||||
from .metrics import MetricsConfig
|
||||
from .appservice import AppServiceConfig
|
||||
from .key import KeyConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
|
||||
EmailConfig, VoipConfig, RegistrationConfig,
|
||||
MetricsConfig,):
|
||||
VoipConfig, RegistrationConfig,
|
||||
MetricsConfig, AppServiceConfig, KeyConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
147
synapse/config/key.py
Normal file
147
synapse/config/key.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from ._base import Config, ConfigError
|
||||
import syutil.crypto.signing_key
|
||||
from syutil.crypto.signing_key import (
|
||||
is_signing_algorithm_supported, decode_verify_key_bytes
|
||||
)
|
||||
from syutil.base64util import decode_base64
|
||||
|
||||
|
||||
class KeyConfig(Config):
|
||||
|
||||
def __init__(self, args):
|
||||
super(KeyConfig, self).__init__(args)
|
||||
self.signing_key = self.read_signing_key(args.signing_key_path)
|
||||
self.old_signing_keys = self.read_old_signing_keys(
|
||||
args.old_signing_key_path
|
||||
)
|
||||
self.key_refresh_interval = args.key_refresh_interval
|
||||
self.perspectives = self.read_perspectives(
|
||||
args.perspectives_config_path
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser):
|
||||
super(KeyConfig, cls).add_arguments(parser)
|
||||
key_group = parser.add_argument_group("keys")
|
||||
key_group.add_argument("--signing-key-path",
|
||||
help="The signing key to sign messages with")
|
||||
key_group.add_argument("--old-signing-key-path",
|
||||
help="The keys that the server used to sign"
|
||||
" sign messages with but won't use"
|
||||
" to sign new messages. E.g. it has"
|
||||
" lost its private key")
|
||||
key_group.add_argument("--key-refresh-interval",
|
||||
default=24 * 60 * 60 * 1000, # 1 Day
|
||||
help="How long a key response is valid for."
|
||||
" Used to set the exipiry in /key/v2/."
|
||||
" Controls how frequently servers will"
|
||||
" query what keys are still valid")
|
||||
key_group.add_argument("--perspectives-config-path",
|
||||
help="The trusted servers to download signing"
|
||||
" keys from")
|
||||
|
||||
def read_perspectives(self, perspectives_config_path):
|
||||
config = self.read_yaml_file(
|
||||
perspectives_config_path, "perspectives_config_path"
|
||||
)
|
||||
servers = {}
|
||||
for server_name, server_config in config["servers"].items():
|
||||
for key_id, key_data in server_config["verify_keys"].items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_base64 = key_data["key"]
|
||||
key_bytes = decode_base64(key_base64)
|
||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||
servers.setdefault(server_name, {})[key_id] = verify_key
|
||||
return servers
|
||||
|
||||
def read_signing_key(self, signing_key_path):
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
try:
|
||||
return syutil.crypto.signing_key.read_signing_keys(
|
||||
signing_keys.splitlines(True)
|
||||
)
|
||||
except Exception:
|
||||
raise ConfigError(
|
||||
"Error reading signing_key."
|
||||
" Try running again with --generate-config"
|
||||
)
|
||||
|
||||
def read_old_signing_keys(self, old_signing_key_path):
|
||||
old_signing_keys = self.read_file(
|
||||
old_signing_key_path, "old_signing_key"
|
||||
)
|
||||
try:
|
||||
return syutil.crypto.signing_key.read_old_signing_keys(
|
||||
old_signing_keys.splitlines(True)
|
||||
)
|
||||
except Exception:
|
||||
raise ConfigError(
|
||||
"Error reading old signing keys."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(KeyConfig, cls).generate_config(args, config_dir_path)
|
||||
base_key_name = os.path.join(config_dir_path, args.server_name)
|
||||
|
||||
args.pid_file = os.path.abspath(args.pid_file)
|
||||
|
||||
if not args.signing_key_path:
|
||||
args.signing_key_path = base_key_name + ".signing.key"
|
||||
|
||||
if not os.path.exists(args.signing_key_path):
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(syutil.crypto.signing_key.generate_signing_key("auto"),),
|
||||
)
|
||||
else:
|
||||
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
|
||||
if len(signing_keys.split("\n")[0].split()) == 1:
|
||||
# handle keys in the old format.
|
||||
key = syutil.crypto.signing_key.decode_signing_key_base64(
|
||||
syutil.crypto.signing_key.NACL_ED25519,
|
||||
"auto",
|
||||
signing_keys.split("\n")[0]
|
||||
)
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(key,),
|
||||
)
|
||||
|
||||
if not args.old_signing_key_path:
|
||||
args.old_signing_key_path = base_key_name + ".old.signing.keys"
|
||||
|
||||
if not os.path.exists(args.old_signing_key_path):
|
||||
with open(args.old_signing_key_path, "w"):
|
||||
pass
|
||||
|
||||
if not args.perspectives_config_path:
|
||||
args.perspectives_config_path = base_key_name + ".perspectives"
|
||||
|
||||
if not os.path.exists(args.perspectives_config_path):
|
||||
with open(args.perspectives_config_path, "w") as perspectives_file:
|
||||
perspectives_file.write(
|
||||
'servers:\n'
|
||||
' matrix.org:\n'
|
||||
' verify_keys:\n'
|
||||
' "ed25519:auto":\n'
|
||||
' key: "Noi6WqcDj0QmPxCNQqgezwTlBKrfqehY1u2FyWP9uYw"\n'
|
||||
)
|
||||
@@ -78,7 +78,6 @@ class LoggingConfig(Config):
|
||||
handler.addFilter(LoggingContextFilter(request=""))
|
||||
|
||||
logger.addHandler(handler)
|
||||
logger.info("Test")
|
||||
else:
|
||||
with open(self.log_config, 'r') as f:
|
||||
logging.config.dictConfig(yaml.load(f))
|
||||
|
||||
@@ -25,11 +25,11 @@ class RegistrationConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(RegistrationConfig, self).__init__(args)
|
||||
|
||||
# `args.disable_registration` may either be a bool or a string depending
|
||||
# on if the option was given a value (e.g. --disable-registration=false
|
||||
# would set `args.disable_registration` to "false" not False.)
|
||||
self.disable_registration = bool(
|
||||
distutils.util.strtobool(str(args.disable_registration))
|
||||
# `args.enable_registration` may either be a bool or a string depending
|
||||
# on if the option was given a value (e.g. --enable-registration=true
|
||||
# would set `args.enable_registration` to "true" not True.)
|
||||
self.disable_registration = not bool(
|
||||
distutils.util.strtobool(str(args.enable_registration))
|
||||
)
|
||||
self.registration_shared_secret = args.registration_shared_secret
|
||||
|
||||
@@ -39,11 +39,11 @@ class RegistrationConfig(Config):
|
||||
reg_group = parser.add_argument_group("registration")
|
||||
|
||||
reg_group.add_argument(
|
||||
"--disable-registration",
|
||||
"--enable-registration",
|
||||
const=True,
|
||||
default=True,
|
||||
default=False,
|
||||
nargs='?',
|
||||
help="Disable registration of new users.",
|
||||
help="Enable registration for new users.",
|
||||
)
|
||||
reg_group.add_argument(
|
||||
"--registration-shared-secret", type=str,
|
||||
@@ -53,8 +53,9 @@ class RegistrationConfig(Config):
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
if args.disable_registration is None:
|
||||
args.disable_registration = True
|
||||
super(RegistrationConfig, cls).generate_config(args, config_dir_path)
|
||||
if args.enable_registration is None:
|
||||
args.enable_registration = False
|
||||
|
||||
if args.registration_shared_secret is None:
|
||||
args.registration_shared_secret = random_string_with_symbols(50)
|
||||
|
||||
@@ -13,16 +13,13 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from ._base import Config, ConfigError
|
||||
import syutil.crypto.signing_key
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class ServerConfig(Config):
|
||||
def __init__(self, args):
|
||||
super(ServerConfig, self).__init__(args)
|
||||
self.server_name = args.server_name
|
||||
self.signing_key = self.read_signing_key(args.signing_key_path)
|
||||
self.bind_port = args.bind_port
|
||||
self.bind_host = args.bind_host
|
||||
self.unsecure_port = args.unsecure_port
|
||||
@@ -53,8 +50,6 @@ class ServerConfig(Config):
|
||||
"This is used by remote servers to connect to this server, "
|
||||
"e.g. matrix.org, localhost:8080, etc."
|
||||
)
|
||||
server_group.add_argument("--signing-key-path",
|
||||
help="The signing key to sign messages with")
|
||||
server_group.add_argument("-p", "--bind-port", metavar="PORT",
|
||||
type=int, help="https port to listen on",
|
||||
default=8448)
|
||||
@@ -83,46 +78,3 @@ class ServerConfig(Config):
|
||||
"Zero is used to indicate synapse "
|
||||
"should set the soft limit to the hard"
|
||||
"limit.")
|
||||
|
||||
def read_signing_key(self, signing_key_path):
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
try:
|
||||
return syutil.crypto.signing_key.read_signing_keys(
|
||||
signing_keys.splitlines(True)
|
||||
)
|
||||
except Exception:
|
||||
raise ConfigError(
|
||||
"Error reading signing_key."
|
||||
" Try running again with --generate-config"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_config(cls, args, config_dir_path):
|
||||
super(ServerConfig, cls).generate_config(args, config_dir_path)
|
||||
base_key_name = os.path.join(config_dir_path, args.server_name)
|
||||
|
||||
args.pid_file = os.path.abspath(args.pid_file)
|
||||
|
||||
if not args.signing_key_path:
|
||||
args.signing_key_path = base_key_name + ".signing.key"
|
||||
|
||||
if not os.path.exists(args.signing_key_path):
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(syutil.crypto.signing_key.generate_singing_key("auto"),),
|
||||
)
|
||||
else:
|
||||
signing_keys = cls.read_file(args.signing_key_path, "signing_key")
|
||||
if len(signing_keys.split("\n")[0].split()) == 1:
|
||||
# handle keys in the old format.
|
||||
key = syutil.crypto.signing_key.decode_signing_key_base64(
|
||||
syutil.crypto.signing_key.NACL_ED25519,
|
||||
"auto",
|
||||
signing_keys.split("\n")[0]
|
||||
)
|
||||
with open(args.signing_key_path, "w") as signing_key_file:
|
||||
syutil.crypto.signing_key.write_signing_keys(
|
||||
signing_key_file,
|
||||
(key,),
|
||||
)
|
||||
|
||||
@@ -25,12 +25,15 @@ import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
KEY_API_V1 = b"/_matrix/key/v1/"
|
||||
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def fetch_server_key(server_name, ssl_context_factory):
|
||||
def fetch_server_key(server_name, ssl_context_factory, path=KEY_API_V1):
|
||||
"""Fetch the keys for a remote server."""
|
||||
|
||||
factory = SynapseKeyClientFactory()
|
||||
factory.path = path
|
||||
endpoint = matrix_federation_endpoint(
|
||||
reactor, server_name, ssl_context_factory, timeout=30
|
||||
)
|
||||
@@ -42,13 +45,19 @@ def fetch_server_key(server_name, ssl_context_factory):
|
||||
server_response, server_certificate = yield protocol.remote_key
|
||||
defer.returnValue((server_response, server_certificate))
|
||||
return
|
||||
except SynapseKeyClientError as e:
|
||||
logger.exception("Error getting key for %r" % (server_name,))
|
||||
if e.status.startswith("4"):
|
||||
# Don't retry for 4xx responses.
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise IOError("Cannot get key for %s" % server_name)
|
||||
raise IOError("Cannot get key for %r" % server_name)
|
||||
|
||||
|
||||
class SynapseKeyClientError(Exception):
|
||||
"""The key wasn't retrieved from the remote server."""
|
||||
status = None
|
||||
pass
|
||||
|
||||
|
||||
@@ -66,17 +75,30 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||
def connectionMade(self):
|
||||
self.host = self.transport.getHost()
|
||||
logger.debug("Connected to %s", self.host)
|
||||
self.sendCommand(b"GET", b"/_matrix/key/v1/")
|
||||
self.sendCommand(b"GET", self.path)
|
||||
self.endHeaders()
|
||||
self.timer = reactor.callLater(
|
||||
self.timeout,
|
||||
self.on_timeout
|
||||
)
|
||||
|
||||
def errback(self, error):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.errback(error)
|
||||
|
||||
def callback(self, result):
|
||||
if not self.remote_key.called:
|
||||
self.remote_key.callback(result)
|
||||
|
||||
def handleStatus(self, version, status, message):
|
||||
if status != b"200":
|
||||
# logger.info("Non-200 response from %s: %s %s",
|
||||
# self.transport.getHost(), status, message)
|
||||
error = SynapseKeyClientError(
|
||||
"Non-200 response %r from %r" % (status, self.host)
|
||||
)
|
||||
error.status = status
|
||||
self.errback(error)
|
||||
self.transport.abortConnection()
|
||||
|
||||
def handleResponse(self, response_body_bytes):
|
||||
@@ -89,15 +111,18 @@ class SynapseKeyClientProtocol(HTTPClient):
|
||||
return
|
||||
|
||||
certificate = self.transport.getPeerCertificate()
|
||||
self.remote_key.callback((json_response, certificate))
|
||||
self.callback((json_response, certificate))
|
||||
self.transport.abortConnection()
|
||||
self.timer.cancel()
|
||||
|
||||
def on_timeout(self):
|
||||
logger.debug("Timeout waiting for response from %s", self.host)
|
||||
self.remote_key.errback(IOError("Timeout waiting for response"))
|
||||
self.errback(IOError("Timeout waiting for response"))
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
||||
class SynapseKeyClientFactory(Factory):
|
||||
protocol = SynapseKeyClientProtocol
|
||||
def protocol(self):
|
||||
protocol = SynapseKeyClientProtocol()
|
||||
protocol.path = self.path
|
||||
return protocol
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
from synapse.crypto.keyclient import fetch_server_key
|
||||
from twisted.internet import defer
|
||||
from syutil.crypto.jsonsign import verify_signed_json, signature_ids
|
||||
from syutil.crypto.jsonsign import (
|
||||
verify_signed_json, signature_ids, sign_json, encode_canonical_json
|
||||
)
|
||||
from syutil.crypto.signing_key import (
|
||||
is_signing_algorithm_supported, decode_verify_key_bytes
|
||||
)
|
||||
@@ -24,8 +26,12 @@ from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
from synapse.util.retryutils import get_retry_limiter
|
||||
|
||||
from synapse.util.async import create_observer
|
||||
|
||||
from OpenSSL import crypto
|
||||
|
||||
import urllib
|
||||
import hashlib
|
||||
import logging
|
||||
|
||||
|
||||
@@ -36,8 +42,13 @@ class Keyring(object):
|
||||
def __init__(self, hs):
|
||||
self.store = hs.get_datastore()
|
||||
self.clock = hs.get_clock()
|
||||
self.client = hs.get_http_client()
|
||||
self.config = hs.get_config()
|
||||
self.perspective_servers = self.config.perspectives
|
||||
self.hs = hs
|
||||
|
||||
self.key_downloads = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_json_for_server(self, server_name, json_object):
|
||||
logger.debug("Verifying for %s", server_name)
|
||||
@@ -85,19 +96,56 @@ class Keyring(object):
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key(self, server_name, key_ids):
|
||||
"""Finds a verification key for the server with one of the key ids.
|
||||
Trys to fetch the key from a trusted perspective server first.
|
||||
Args:
|
||||
server_name (str): The name of the server to fetch a key for.
|
||||
server_name(str): The name of the server to fetch a key for.
|
||||
keys_ids (list of str): The key_ids to check for.
|
||||
"""
|
||||
|
||||
# Check the datastore to see if we have one cached.
|
||||
cached = yield self.store.get_server_verify_keys(server_name, key_ids)
|
||||
|
||||
if cached:
|
||||
defer.returnValue(cached[0])
|
||||
return
|
||||
|
||||
# Try to fetch the key from the remote server.
|
||||
download = self.key_downloads.get(server_name)
|
||||
|
||||
if download is None:
|
||||
download = self._get_server_verify_key_impl(server_name, key_ids)
|
||||
self.key_downloads[server_name] = download
|
||||
|
||||
@download.addBoth
|
||||
def callback(ret):
|
||||
del self.key_downloads[server_name]
|
||||
return ret
|
||||
|
||||
r = yield create_observer(download)
|
||||
defer.returnValue(r)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_server_verify_key_impl(self, server_name, key_ids):
|
||||
keys = None
|
||||
|
||||
perspective_results = []
|
||||
for perspective_name, perspective_keys in self.perspective_servers.items():
|
||||
@defer.inlineCallbacks
|
||||
def get_key():
|
||||
try:
|
||||
result = yield self.get_server_verify_key_v2_indirect(
|
||||
server_name, key_ids, perspective_name, perspective_keys
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except:
|
||||
logging.info(
|
||||
"Unable to getting key %r for %r from %r",
|
||||
key_ids, server_name, perspective_name,
|
||||
)
|
||||
perspective_results.append(get_key())
|
||||
|
||||
perspective_results = yield defer.gatherResults(perspective_results)
|
||||
|
||||
for results in perspective_results:
|
||||
if results is not None:
|
||||
keys = results
|
||||
|
||||
limiter = yield get_retry_limiter(
|
||||
server_name,
|
||||
@@ -106,10 +154,234 @@ class Keyring(object):
|
||||
)
|
||||
|
||||
with limiter:
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_context_factory
|
||||
if keys is None:
|
||||
try:
|
||||
keys = yield self.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except:
|
||||
pass
|
||||
|
||||
keys = yield self.get_server_verify_key_v1_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
|
||||
for key_id in key_ids:
|
||||
if key_id in keys:
|
||||
defer.returnValue(keys[key_id])
|
||||
return
|
||||
raise ValueError("No verification key found for given key ids")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_indirect(self, server_name, key_ids,
|
||||
perspective_name,
|
||||
perspective_keys):
|
||||
limiter = yield get_retry_limiter(
|
||||
perspective_name, self.clock, self.store
|
||||
)
|
||||
|
||||
with limiter:
|
||||
# TODO(mark): Set the minimum_valid_until_ts to that needed by
|
||||
# the events being validated or the current time if validating
|
||||
# an incoming request.
|
||||
responses = yield self.client.post_json(
|
||||
destination=perspective_name,
|
||||
path=b"/_matrix/key/v2/query",
|
||||
data={
|
||||
u"server_keys": {
|
||||
server_name: {
|
||||
key_id: {
|
||||
u"minimum_valid_until_ts": 0
|
||||
} for key_id in key_ids
|
||||
}
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
keys = {}
|
||||
|
||||
for response in responses:
|
||||
if (u"signatures" not in response
|
||||
or perspective_name not in response[u"signatures"]):
|
||||
raise ValueError(
|
||||
"Key response not signed by perspective server"
|
||||
" %r" % (perspective_name,)
|
||||
)
|
||||
|
||||
verified = False
|
||||
for key_id in response[u"signatures"][perspective_name]:
|
||||
if key_id in perspective_keys:
|
||||
verify_signed_json(
|
||||
response,
|
||||
perspective_name,
|
||||
perspective_keys[key_id]
|
||||
)
|
||||
verified = True
|
||||
|
||||
if not verified:
|
||||
logging.info(
|
||||
"Response from perspective server %r not signed with a"
|
||||
" known key, signed with: %r, known keys: %r",
|
||||
perspective_name,
|
||||
list(response[u"signatures"][perspective_name]),
|
||||
list(perspective_keys)
|
||||
)
|
||||
raise ValueError(
|
||||
"Response not signed with a known key for perspective"
|
||||
" server %r" % (perspective_name,)
|
||||
)
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
server_name, perspective_name, response
|
||||
)
|
||||
|
||||
keys.update(response_keys)
|
||||
|
||||
yield self.store_keys(
|
||||
server_name=server_name,
|
||||
from_server=perspective_name,
|
||||
verify_keys=keys,
|
||||
)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v2_direct(self, server_name, key_ids):
|
||||
|
||||
keys = {}
|
||||
|
||||
for requested_key_id in key_ids:
|
||||
if requested_key_id in keys:
|
||||
continue
|
||||
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_context_factory,
|
||||
path=(b"/_matrix/key/v2/server/%s" % (
|
||||
urllib.quote(requested_key_id),
|
||||
)).encode("ascii"),
|
||||
)
|
||||
|
||||
if (u"signatures" not in response
|
||||
or server_name not in response[u"signatures"]):
|
||||
raise ValueError("Key response not signed by remote server")
|
||||
|
||||
if "tls_fingerprints" not in response:
|
||||
raise ValueError("Key response missing TLS fingerprints")
|
||||
|
||||
certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1, tls_certificate
|
||||
)
|
||||
sha256_fingerprint = hashlib.sha256(certificate_bytes).digest()
|
||||
sha256_fingerprint_b64 = encode_base64(sha256_fingerprint)
|
||||
|
||||
response_sha256_fingerprints = set()
|
||||
for fingerprint in response[u"tls_fingerprints"]:
|
||||
if u"sha256" in fingerprint:
|
||||
response_sha256_fingerprints.add(fingerprint[u"sha256"])
|
||||
|
||||
if sha256_fingerprint_b64 not in response_sha256_fingerprints:
|
||||
raise ValueError("TLS certificate not allowed by fingerprints")
|
||||
|
||||
response_keys = yield self.process_v2_response(
|
||||
server_name=server_name,
|
||||
from_server=server_name,
|
||||
requested_id=requested_key_id,
|
||||
response_json=response,
|
||||
)
|
||||
|
||||
keys.update(response_keys)
|
||||
|
||||
yield self.store_keys(
|
||||
server_name=server_name,
|
||||
from_server=server_name,
|
||||
verify_keys=keys,
|
||||
)
|
||||
|
||||
defer.returnValue(keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def process_v2_response(self, server_name, from_server, response_json,
|
||||
requested_id=None):
|
||||
time_now_ms = self.clock.time_msec()
|
||||
response_keys = {}
|
||||
verify_keys = {}
|
||||
for key_id, key_data in response_json["verify_keys"].items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_base64 = key_data["key"]
|
||||
key_bytes = decode_base64(key_base64)
|
||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||
verify_key.time_added = time_now_ms
|
||||
verify_keys[key_id] = verify_key
|
||||
|
||||
old_verify_keys = {}
|
||||
for key_id, key_data in response_json["old_verify_keys"].items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_base64 = key_data["key"]
|
||||
key_bytes = decode_base64(key_base64)
|
||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||
verify_key.expired = key_data["expired_ts"]
|
||||
verify_key.time_added = time_now_ms
|
||||
old_verify_keys[key_id] = verify_key
|
||||
|
||||
for key_id in response_json["signatures"][server_name]:
|
||||
if key_id not in response_json["verify_keys"]:
|
||||
raise ValueError(
|
||||
"Key response must include verification keys for all"
|
||||
" signatures"
|
||||
)
|
||||
if key_id in verify_keys:
|
||||
verify_signed_json(
|
||||
response_json,
|
||||
server_name,
|
||||
verify_keys[key_id]
|
||||
)
|
||||
|
||||
signed_key_json = sign_json(
|
||||
response_json,
|
||||
self.config.server_name,
|
||||
self.config.signing_key[0],
|
||||
)
|
||||
|
||||
signed_key_json_bytes = encode_canonical_json(signed_key_json)
|
||||
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
|
||||
|
||||
updated_key_ids = set()
|
||||
if requested_id is not None:
|
||||
updated_key_ids.add(requested_id)
|
||||
updated_key_ids.update(verify_keys)
|
||||
updated_key_ids.update(old_verify_keys)
|
||||
|
||||
response_keys.update(verify_keys)
|
||||
response_keys.update(old_verify_keys)
|
||||
|
||||
for key_id in updated_key_ids:
|
||||
yield self.store.store_server_keys_json(
|
||||
server_name=server_name,
|
||||
key_id=key_id,
|
||||
from_server=server_name,
|
||||
ts_now_ms=time_now_ms,
|
||||
ts_expires_ms=ts_valid_until_ms,
|
||||
key_json_bytes=signed_key_json_bytes,
|
||||
)
|
||||
|
||||
defer.returnValue(response_keys)
|
||||
|
||||
raise ValueError("No verification key found for given key ids")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_server_verify_key_v1_direct(self, server_name, key_ids):
|
||||
"""Finds a verification key for the server with one of the key ids.
|
||||
Args:
|
||||
server_name (str): The name of the server to fetch a key for.
|
||||
keys_ids (list of str): The key_ids to check for.
|
||||
"""
|
||||
|
||||
# Try to fetch the key from the remote server.
|
||||
|
||||
(response, tls_certificate) = yield fetch_server_key(
|
||||
server_name, self.hs.tls_context_factory
|
||||
)
|
||||
|
||||
# Check the response.
|
||||
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
@@ -128,11 +400,16 @@ class Keyring(object):
|
||||
if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
|
||||
raise ValueError("TLS certificate doesn't match")
|
||||
|
||||
# Cache the result in the datastore.
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
verify_keys = {}
|
||||
for key_id, key_base64 in response["verify_keys"].items():
|
||||
if is_signing_algorithm_supported(key_id):
|
||||
key_bytes = decode_base64(key_base64)
|
||||
verify_key = decode_verify_key_bytes(key_id, key_bytes)
|
||||
verify_key.time_added = time_now_ms
|
||||
verify_keys[key_id] = verify_key
|
||||
|
||||
for key_id in response["signatures"][server_name]:
|
||||
@@ -148,10 +425,6 @@ class Keyring(object):
|
||||
verify_keys[key_id]
|
||||
)
|
||||
|
||||
# Cache the result in the datastore.
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
yield self.store.store_server_certificate(
|
||||
server_name,
|
||||
server_name,
|
||||
@@ -159,14 +432,26 @@ class Keyring(object):
|
||||
tls_certificate,
|
||||
)
|
||||
|
||||
yield self.store_keys(
|
||||
server_name=server_name,
|
||||
from_server=server_name,
|
||||
verify_keys=verify_keys,
|
||||
)
|
||||
|
||||
defer.returnValue(verify_keys)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def store_keys(self, server_name, from_server, verify_keys):
|
||||
"""Store a collection of verify keys for a given server
|
||||
Args:
|
||||
server_name(str): The name of the server the keys are for.
|
||||
from_server(str): The server the keys were downloaded from.
|
||||
verify_keys(dict): A mapping of key_id to VerifyKey.
|
||||
Returns:
|
||||
A deferred that completes when the keys are stored.
|
||||
"""
|
||||
for key_id, key in verify_keys.items():
|
||||
# TODO(markjh): Store whether the keys have expired.
|
||||
yield self.store.store_server_verify_key(
|
||||
server_name, server_name, time_now_ms, key
|
||||
server_name, server_name, key.time_added, key
|
||||
)
|
||||
|
||||
for key_id in key_ids:
|
||||
if key_id in verify_keys:
|
||||
defer.returnValue(verify_keys[key_id])
|
||||
return
|
||||
|
||||
raise ValueError("No verification key found for given key ids")
|
||||
|
||||
@@ -46,9 +46,10 @@ def _event_dict_property(key):
|
||||
|
||||
class EventBase(object):
|
||||
def __init__(self, event_dict, signatures={}, unsigned={},
|
||||
internal_metadata_dict={}):
|
||||
internal_metadata_dict={}, rejected_reason=None):
|
||||
self.signatures = signatures
|
||||
self.unsigned = unsigned
|
||||
self.rejected_reason = rejected_reason
|
||||
|
||||
self._event_dict = event_dict
|
||||
|
||||
@@ -109,7 +110,7 @@ class EventBase(object):
|
||||
|
||||
|
||||
class FrozenEvent(EventBase):
|
||||
def __init__(self, event_dict, internal_metadata_dict={}):
|
||||
def __init__(self, event_dict, internal_metadata_dict={}, rejected_reason=None):
|
||||
event_dict = dict(event_dict)
|
||||
|
||||
# Signatures is a dict of dicts, and this is faster than doing a
|
||||
@@ -128,6 +129,7 @@ class FrozenEvent(EventBase):
|
||||
signatures=signatures,
|
||||
unsigned=unsigned,
|
||||
internal_metadata_dict=internal_metadata_dict,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -417,13 +417,13 @@ class FederationServer(FederationBase):
|
||||
pdu.internal_metadata.outlier = True
|
||||
elif min_depth and pdu.depth > min_depth:
|
||||
if get_missing and prevs - seen:
|
||||
latest_tuples = yield self.store.get_latest_events_in_room(
|
||||
latest = yield self.store.get_latest_event_ids_in_room(
|
||||
pdu.room_id
|
||||
)
|
||||
|
||||
# We add the prev events that we have seen to the latest
|
||||
# list to ensure the remote server doesn't give them to us
|
||||
latest = set(e_id for e_id, _, _ in latest_tuples)
|
||||
latest = set(latest)
|
||||
latest |= seen
|
||||
|
||||
missing_events = yield self.get_missing_events(
|
||||
|
||||
@@ -361,4 +361,5 @@ SERVLET_CLASSES = (
|
||||
FederationInviteServlet,
|
||||
FederationQueryAuthServlet,
|
||||
FederationGetMissingEventsServlet,
|
||||
FederationEventAuthServlet,
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.appservice.scheduler import AppServiceScheduler
|
||||
from synapse.appservice.api import ApplicationServiceApi
|
||||
from .register import RegistrationHandler
|
||||
from .room import (
|
||||
@@ -29,6 +30,8 @@ from .typing import TypingNotificationHandler
|
||||
from .admin import AdminHandler
|
||||
from .appservice import ApplicationServicesHandler
|
||||
from .sync import SyncHandler
|
||||
from .auth import AuthHandler
|
||||
from .identity import IdentityHandler
|
||||
|
||||
|
||||
class Handlers(object):
|
||||
@@ -54,7 +57,14 @@ class Handlers(object):
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.typing_notification_handler = TypingNotificationHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
asapi = ApplicationServiceApi(hs)
|
||||
self.appservice_handler = ApplicationServicesHandler(
|
||||
hs, ApplicationServiceApi(hs)
|
||||
hs, asapi, AppServiceScheduler(
|
||||
clock=hs.get_clock(),
|
||||
store=hs.get_datastore(),
|
||||
as_api=asapi
|
||||
)
|
||||
)
|
||||
self.sync_handler = SyncHandler(hs)
|
||||
self.auth_handler = AuthHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import LimitExceededError, SynapseError
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.api.constants import Membership, EventTypes
|
||||
from synapse.types import UserID
|
||||
@@ -58,8 +57,6 @@ class BaseHandler(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder):
|
||||
yield run_on_reactor()
|
||||
|
||||
latest_ret = yield self.store.get_latest_events_in_room(
|
||||
builder.room_id,
|
||||
)
|
||||
@@ -101,8 +98,6 @@ class BaseHandler(object):
|
||||
@defer.inlineCallbacks
|
||||
def handle_new_client_event(self, event, context, extra_destinations=[],
|
||||
extra_users=[], suppress_auth=False):
|
||||
yield run_on_reactor()
|
||||
|
||||
# We now need to go and hit out to wherever we need to hit out to.
|
||||
|
||||
if not suppress_auth:
|
||||
@@ -143,7 +138,9 @@ class BaseHandler(object):
|
||||
)
|
||||
|
||||
# Don't block waiting on waking up all the listeners.
|
||||
d = self.notifier.on_new_room_event(event, extra_users=extra_users)
|
||||
notify_d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
@@ -151,8 +148,10 @@ class BaseHandler(object):
|
||||
event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
notify_d.addErrback(log_failure)
|
||||
|
||||
yield federation_handler.handle_new_event(
|
||||
fed_d = federation_handler.handle_new_event(
|
||||
event, destinations=destinations,
|
||||
)
|
||||
|
||||
fed_d.addErrback(log_failure)
|
||||
|
||||
@@ -16,57 +16,36 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import Codes, StoreError, SynapseError
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.types import UserID
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def log_failure(failure):
|
||||
logger.error(
|
||||
"Application Services Failure",
|
||||
exc_info=(
|
||||
failure.type,
|
||||
failure.value,
|
||||
failure.getTracebackObject()
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# NB: Purposefully not inheriting BaseHandler since that contains way too much
|
||||
# setup code which this handler does not need or use. This makes testing a lot
|
||||
# easier.
|
||||
class ApplicationServicesHandler(object):
|
||||
|
||||
def __init__(self, hs, appservice_api):
|
||||
def __init__(self, hs, appservice_api, appservice_scheduler):
|
||||
self.store = hs.get_datastore()
|
||||
self.hs = hs
|
||||
self.appservice_api = appservice_api
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, app_service):
|
||||
logger.info("Register -> %s", app_service)
|
||||
# check the token is recognised
|
||||
try:
|
||||
stored_service = yield self.store.get_app_service_by_token(
|
||||
app_service.token
|
||||
)
|
||||
if not stored_service:
|
||||
raise StoreError(404, "Application service not found")
|
||||
except StoreError:
|
||||
raise SynapseError(
|
||||
403, "Unrecognised application services token. "
|
||||
"Consult the home server admin.",
|
||||
errcode=Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
app_service.hs_token = self._generate_hs_token()
|
||||
|
||||
# create a sender for this application service which is used when
|
||||
# creating rooms, etc..
|
||||
account = yield self.hs.get_handlers().registration_handler.register()
|
||||
app_service.sender = account[0]
|
||||
|
||||
yield self.store.update_app_service(app_service)
|
||||
defer.returnValue(app_service)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def unregister(self, token):
|
||||
logger.info("Unregister as_token=%s", token)
|
||||
yield self.store.unregister_app_service(token)
|
||||
self.scheduler = appservice_scheduler
|
||||
self.started_scheduler = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify_interested_services(self, event):
|
||||
@@ -90,9 +69,13 @@ class ApplicationServicesHandler(object):
|
||||
if event.type == EventTypes.Member:
|
||||
yield self._check_user_exists(event.state_key)
|
||||
|
||||
# Fork off pushes to these services - XXX First cut, best effort
|
||||
if not self.started_scheduler:
|
||||
self.scheduler.start().addErrback(log_failure)
|
||||
self.started_scheduler = True
|
||||
|
||||
# Fork off pushes to these services
|
||||
for service in services:
|
||||
self.appservice_api.push(service, event)
|
||||
self.scheduler.submit_event_for_as(service, event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_user_exists(self, user_id):
|
||||
@@ -197,7 +180,14 @@ class ApplicationServicesHandler(object):
|
||||
return
|
||||
|
||||
user_info = yield self.store.get_user_by_id(user_id)
|
||||
defer.returnValue(len(user_info) == 0)
|
||||
if len(user_info) > 0:
|
||||
defer.returnValue(False)
|
||||
return
|
||||
|
||||
# user not found; could be the AS though, so check.
|
||||
services = yield self.store.get_app_services()
|
||||
service_list = [s for s in services if s.sender == user_id]
|
||||
defer.returnValue(len(service_list) == 0)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_user_exists(self, user_id):
|
||||
@@ -206,6 +196,3 @@ class ApplicationServicesHandler(object):
|
||||
exists = yield self.query_user_exists(user_id)
|
||||
defer.returnValue(exists)
|
||||
defer.returnValue(True)
|
||||
|
||||
def _generate_hs_token(self):
|
||||
return stringutils.random_string(24)
|
||||
|
||||
277
synapse/handlers/auth.py
Normal file
277
synapse/handlers/auth.py
Normal file
@@ -0,0 +1,277 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import LoginError, Codes
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
from twisted.web.client import PartialDownloadError
|
||||
|
||||
import logging
|
||||
import bcrypt
|
||||
import simplejson
|
||||
|
||||
import synapse.util.stringutils as stringutils
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(AuthHandler, self).__init__(hs)
|
||||
self.checkers = {
|
||||
LoginType.PASSWORD: self._check_password_auth,
|
||||
LoginType.RECAPTCHA: self._check_recaptcha,
|
||||
LoginType.EMAIL_IDENTITY: self._check_email_identity,
|
||||
LoginType.DUMMY: self._check_dummy_auth,
|
||||
}
|
||||
self.sessions = {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_auth(self, flows, clientdict, clientip=None):
|
||||
"""
|
||||
Takes a dictionary sent by the client in the login / registration
|
||||
protocol and handles the login flow.
|
||||
|
||||
Args:
|
||||
flows: list of list of stages
|
||||
authdict: The dictionary from the client root level, not the
|
||||
'auth' key: this method prompts for auth if none is sent.
|
||||
Returns:
|
||||
A tuple of authed, dict, dict where authed is true if the client
|
||||
has successfully completed an auth flow. If it is true, the first
|
||||
dict contains the authenticated credentials of each stage.
|
||||
|
||||
If authed is false, the first dictionary is the server response to
|
||||
the login request and should be passed back to the client.
|
||||
|
||||
In either case, the second dict contains the parameters for this
|
||||
request (which may have been given only in a previous call).
|
||||
"""
|
||||
|
||||
authdict = None
|
||||
sid = None
|
||||
if clientdict and 'auth' in clientdict:
|
||||
authdict = clientdict['auth']
|
||||
del clientdict['auth']
|
||||
if 'session' in authdict:
|
||||
sid = authdict['session']
|
||||
sess = self._get_session_info(sid)
|
||||
|
||||
if len(clientdict) > 0:
|
||||
# This was designed to allow the client to omit the parameters
|
||||
# and just supply the session in subsequent calls so it split
|
||||
# auth between devices by just sharing the session, (eg. so you
|
||||
# could continue registration from your phone having clicked the
|
||||
# email auth link on there). It's probably too open to abuse
|
||||
# because it lets unauthenticated clients store arbitrary objects
|
||||
# on a home server.
|
||||
# sess['clientdict'] = clientdict
|
||||
# self._save_session(sess)
|
||||
pass
|
||||
elif 'clientdict' in sess:
|
||||
clientdict = sess['clientdict']
|
||||
|
||||
if not authdict:
|
||||
defer.returnValue(
|
||||
(False, self._auth_dict_for_flows(flows, sess), clientdict)
|
||||
)
|
||||
|
||||
if 'creds' not in sess:
|
||||
sess['creds'] = {}
|
||||
creds = sess['creds']
|
||||
|
||||
# check auth type currently being presented
|
||||
if 'type' in authdict:
|
||||
if authdict['type'] not in self.checkers:
|
||||
raise LoginError(400, "", Codes.UNRECOGNIZED)
|
||||
result = yield self.checkers[authdict['type']](authdict, clientip)
|
||||
if result:
|
||||
creds[authdict['type']] = result
|
||||
self._save_session(sess)
|
||||
|
||||
for f in flows:
|
||||
if len(set(f) - set(creds.keys())) == 0:
|
||||
logger.info("Auth completed with creds: %r", creds)
|
||||
self._remove_session(sess)
|
||||
defer.returnValue((True, creds, clientdict))
|
||||
|
||||
ret = self._auth_dict_for_flows(flows, sess)
|
||||
ret['completed'] = creds.keys()
|
||||
defer.returnValue((False, ret, clientdict))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_oob_auth(self, stagetype, authdict, clientip):
|
||||
"""
|
||||
Adds the result of out-of-band authentication into an existing auth
|
||||
session. Currently used for adding the result of fallback auth.
|
||||
"""
|
||||
if stagetype not in self.checkers:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
if 'session' not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
sess = self._get_session_info(
|
||||
authdict['session']
|
||||
)
|
||||
if 'creds' not in sess:
|
||||
sess['creds'] = {}
|
||||
creds = sess['creds']
|
||||
|
||||
result = yield self.checkers[stagetype](authdict, clientip)
|
||||
if result:
|
||||
creds[stagetype] = result
|
||||
self._save_session(sess)
|
||||
defer.returnValue(True)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_password_auth(self, authdict, _):
|
||||
if "user" not in authdict or "password" not in authdict:
|
||||
raise LoginError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
user = authdict["user"]
|
||||
password = authdict["password"]
|
||||
if not user.startswith('@'):
|
||||
user = UserID.create(user, self.hs.hostname).to_string()
|
||||
|
||||
user_info = yield self.store.get_user_by_id(user_id=user)
|
||||
if not user_info:
|
||||
logger.warn("Attempted to login as %s but they do not exist", user)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
stored_hash = user_info[0]["password_hash"]
|
||||
if bcrypt.checkpw(password, stored_hash):
|
||||
defer.returnValue(user)
|
||||
else:
|
||||
logger.warn("Failed password login for user %s", user)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_recaptcha(self, authdict, clientip):
|
||||
try:
|
||||
user_response = authdict["response"]
|
||||
except KeyError:
|
||||
# Client tried to provide captcha but didn't give the parameter:
|
||||
# bad request.
|
||||
raise LoginError(
|
||||
400, "Captcha response is required",
|
||||
errcode=Codes.CAPTCHA_NEEDED
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Submitting recaptcha response %s with remoteip %s",
|
||||
user_response, clientip
|
||||
)
|
||||
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
try:
|
||||
client = SimpleHttpClient(self.hs)
|
||||
data = yield client.post_urlencoded_get_json(
|
||||
"https://www.google.com/recaptcha/api/siteverify",
|
||||
args={
|
||||
'secret': self.hs.config.recaptcha_private_key,
|
||||
'response': user_response,
|
||||
'remoteip': clientip,
|
||||
}
|
||||
)
|
||||
except PartialDownloadError as pde:
|
||||
# Twisted is silly
|
||||
data = pde.response
|
||||
resp_body = simplejson.loads(data)
|
||||
if 'success' in resp_body and resp_body['success']:
|
||||
defer.returnValue(True)
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_email_identity(self, authdict, _):
|
||||
yield run_on_reactor()
|
||||
|
||||
if 'threepid_creds' not in authdict:
|
||||
raise LoginError(400, "Missing threepid_creds", Codes.MISSING_PARAM)
|
||||
|
||||
threepid_creds = authdict['threepid_creds']
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
|
||||
logger.info("Getting validated threepid. threepidcreds: %r" % (threepid_creds,))
|
||||
threepid = yield identity_handler.threepid_from_creds(threepid_creds)
|
||||
|
||||
if not threepid:
|
||||
raise LoginError(401, "", errcode=Codes.UNAUTHORIZED)
|
||||
|
||||
threepid['threepid_creds'] = authdict['threepid_creds']
|
||||
|
||||
defer.returnValue(threepid)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_dummy_auth(self, authdict, _):
|
||||
yield run_on_reactor()
|
||||
defer.returnValue(True)
|
||||
|
||||
def _get_params_recaptcha(self):
|
||||
return {"public_key": self.hs.config.recaptcha_public_key}
|
||||
|
||||
def _auth_dict_for_flows(self, flows, session):
|
||||
public_flows = []
|
||||
for f in flows:
|
||||
public_flows.append(f)
|
||||
|
||||
get_params = {
|
||||
LoginType.RECAPTCHA: self._get_params_recaptcha,
|
||||
}
|
||||
|
||||
params = {}
|
||||
|
||||
for f in public_flows:
|
||||
for stage in f:
|
||||
if stage in get_params and stage not in params:
|
||||
params[stage] = get_params[stage]()
|
||||
|
||||
return {
|
||||
"session": session['id'],
|
||||
"flows": [{"stages": f} for f in public_flows],
|
||||
"params": params
|
||||
}
|
||||
|
||||
def _get_session_info(self, session_id):
|
||||
if session_id not in self.sessions:
|
||||
session_id = None
|
||||
|
||||
if not session_id:
|
||||
# create a new session
|
||||
while session_id is None or session_id in self.sessions:
|
||||
session_id = stringutils.random_string(24)
|
||||
self.sessions[session_id] = {
|
||||
"id": session_id,
|
||||
}
|
||||
|
||||
return self.sessions[session_id]
|
||||
|
||||
def _save_session(self, session):
|
||||
# TODO: Persistent storage
|
||||
logger.debug("Saving session %s", session)
|
||||
self.sessions[session["id"]] = session
|
||||
|
||||
def _remove_session(self, session):
|
||||
logger.debug("Removing session %s", session)
|
||||
del self.sessions[session["id"]]
|
||||
@@ -179,7 +179,7 @@ class FederationHandler(BaseHandler):
|
||||
# it's probably a good idea to mark it as not in retry-state
|
||||
# for sending (although this is a bit of a leap)
|
||||
retry_timings = yield self.store.get_destination_retry_timings(origin)
|
||||
if (retry_timings and retry_timings.retry_last_ts):
|
||||
if retry_timings and retry_timings["retry_last_ts"]:
|
||||
self.store.set_destination_retry_timings(origin, 0, 0)
|
||||
|
||||
room = yield self.store.get_room(event.room_id)
|
||||
@@ -201,10 +201,18 @@ class FederationHandler(BaseHandler):
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
yield self.notifier.on_new_room_event(
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
"Failed to notify about %s: %s",
|
||||
event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.membership == Membership.JOIN:
|
||||
user = UserID.from_string(event.state_key)
|
||||
@@ -427,10 +435,18 @@ class FederationHandler(BaseHandler):
|
||||
auth_events=auth_events,
|
||||
)
|
||||
|
||||
yield self.notifier.on_new_room_event(
|
||||
d = self.notifier.on_new_room_event(
|
||||
new_event, extra_users=[joinee]
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
"Failed to notify about %s: %s",
|
||||
new_event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
|
||||
logger.debug("Finished joining %s to %s", joinee, room_id)
|
||||
finally:
|
||||
room_queue = self.room_queues[room_id]
|
||||
@@ -500,10 +516,18 @@ class FederationHandler(BaseHandler):
|
||||
target_user = UserID.from_string(target_user_id)
|
||||
extra_users.append(target_user)
|
||||
|
||||
yield self.notifier.on_new_room_event(
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=extra_users
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
"Failed to notify about %s: %s",
|
||||
event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
if event.content["membership"] == Membership.JOIN:
|
||||
user = UserID.from_string(event.state_key)
|
||||
@@ -574,10 +598,18 @@ class FederationHandler(BaseHandler):
|
||||
)
|
||||
|
||||
target_user = UserID.from_string(event.state_key)
|
||||
yield self.notifier.on_new_room_event(
|
||||
d = self.notifier.on_new_room_event(
|
||||
event, extra_users=[target_user],
|
||||
)
|
||||
|
||||
def log_failure(f):
|
||||
logger.warn(
|
||||
"Failed to notify about %s: %s",
|
||||
event.event_id, f.value
|
||||
)
|
||||
|
||||
d.addErrback(log_failure)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
119
synapse/handlers/identity.py
Normal file
119
synapse/handlers/identity.py
Normal file
@@ -0,0 +1,119 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for interacting with Identity Servers"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class IdentityHandler(BaseHandler):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(IdentityHandler, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def threepid_from_creds(self, creds):
|
||||
yield run_on_reactor()
|
||||
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
http_client = SimpleHttpClient(self.hs)
|
||||
# XXX: make this configurable!
|
||||
# trustedIdServers = ['matrix.org', 'localhost:8090']
|
||||
trustedIdServers = ['matrix.org']
|
||||
|
||||
if 'id_server' in creds:
|
||||
id_server = creds['id_server']
|
||||
elif 'idServer' in creds:
|
||||
id_server = creds['idServer']
|
||||
else:
|
||||
raise SynapseError(400, "No id_server in creds")
|
||||
|
||||
if 'client_secret' in creds:
|
||||
client_secret = creds['client_secret']
|
||||
elif 'clientSecret' in creds:
|
||||
client_secret = creds['clientSecret']
|
||||
else:
|
||||
raise SynapseError(400, "No client_secret in creds")
|
||||
|
||||
if id_server not in trustedIdServers:
|
||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', id_server)
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
try:
|
||||
data = yield http_client.get_json(
|
||||
"https://%s%s" % (
|
||||
id_server,
|
||||
"/_matrix/identity/api/v1/3pid/getValidated3pid"
|
||||
),
|
||||
{'sid': creds['sid'], 'client_secret': client_secret}
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
|
||||
if 'medium' in data:
|
||||
defer.returnValue(data)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bind_threepid(self, creds, mxid):
|
||||
yield run_on_reactor()
|
||||
logger.debug("binding threepid %r to %s", creds, mxid)
|
||||
http_client = SimpleHttpClient(self.hs)
|
||||
data = None
|
||||
|
||||
if 'id_server' in creds:
|
||||
id_server = creds['id_server']
|
||||
elif 'idServer' in creds:
|
||||
id_server = creds['idServer']
|
||||
else:
|
||||
raise SynapseError(400, "No id_server in creds")
|
||||
|
||||
if 'client_secret' in creds:
|
||||
client_secret = creds['client_secret']
|
||||
elif 'clientSecret' in creds:
|
||||
client_secret = creds['clientSecret']
|
||||
else:
|
||||
raise SynapseError(400, "No client_secret in creds")
|
||||
|
||||
try:
|
||||
data = yield http_client.post_urlencoded_get_json(
|
||||
"https://%s%s" % (
|
||||
id_server, "/_matrix/identity/api/v1/3pid/bind"
|
||||
),
|
||||
{
|
||||
'sid': creds['sid'],
|
||||
'client_secret': client_secret,
|
||||
'mxid': mxid,
|
||||
}
|
||||
)
|
||||
logger.debug("bound threepid %r to %s", creds, mxid)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
defer.returnValue(data)
|
||||
@@ -16,13 +16,9 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import BaseHandler
|
||||
from synapse.api.errors import LoginError, Codes, CodeMessageException
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.util.emailutils import EmailException
|
||||
import synapse.util.emailutils as emailutils
|
||||
from synapse.api.errors import LoginError, Codes
|
||||
|
||||
import bcrypt
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -57,7 +53,7 @@ class LoginHandler(BaseHandler):
|
||||
logger.warn("Attempted to login as %s but they do not exist", user)
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
stored_hash = user_info[0]["password_hash"]
|
||||
stored_hash = user_info["password_hash"]
|
||||
if bcrypt.checkpw(password, stored_hash):
|
||||
# generate an access token and store it.
|
||||
token = self.reg_handler._generate_token(user)
|
||||
@@ -69,48 +65,19 @@ class LoginHandler(BaseHandler):
|
||||
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def reset_password(self, user_id, email):
|
||||
is_valid = yield self._check_valid_association(user_id, email)
|
||||
logger.info("reset_password user=%s email=%s valid=%s", user_id, email,
|
||||
is_valid)
|
||||
if is_valid:
|
||||
try:
|
||||
# send an email out
|
||||
emailutils.send_email(
|
||||
smtp_server=self.hs.config.email_smtp_server,
|
||||
from_addr=self.hs.config.email_from_address,
|
||||
to_addr=email,
|
||||
subject="Password Reset",
|
||||
body="TODO."
|
||||
)
|
||||
except EmailException as e:
|
||||
logger.exception(e)
|
||||
def set_password(self, user_id, newpassword, token_id=None):
|
||||
password_hash = bcrypt.hashpw(newpassword, bcrypt.gensalt())
|
||||
|
||||
yield self.store.user_set_password_hash(user_id, password_hash)
|
||||
yield self.store.user_delete_access_tokens_apart_from(user_id, token_id)
|
||||
yield self.hs.get_pusherpool().remove_pushers_by_user_access_token(
|
||||
user_id, token_id
|
||||
)
|
||||
yield self.store.flush_user(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_valid_association(self, user_id, email):
|
||||
identity = yield self._query_email(email)
|
||||
if identity and "mxid" in identity:
|
||||
if identity["mxid"] == user_id:
|
||||
defer.returnValue(True)
|
||||
return
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _query_email(self, email):
|
||||
http_client = SimpleHttpClient(self.hs)
|
||||
try:
|
||||
data = yield http_client.get_json(
|
||||
# TODO FIXME This should be configurable.
|
||||
# XXX: ID servers need to use HTTPS
|
||||
"http://%s%s" % (
|
||||
"matrix.org:8090", "/_matrix/identity/api/v1/lookup"
|
||||
),
|
||||
{
|
||||
'medium': 'email',
|
||||
'address': email
|
||||
}
|
||||
)
|
||||
defer.returnValue(data)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
defer.returnValue(data)
|
||||
def add_threepid(self, user_id, medium, address, validated_at):
|
||||
yield self.store.user_add_threepid(
|
||||
user_id, medium, address, validated_at,
|
||||
self.hs.get_clock().time_msec()
|
||||
)
|
||||
|
||||
@@ -250,31 +250,47 @@ class MessageHandler(BaseHandler):
|
||||
is joined on, may return a "messages" key with messages, depending
|
||||
on the specified PaginationConfig.
|
||||
"""
|
||||
start_time = self.clock.time_msec()
|
||||
|
||||
def delta():
|
||||
return self.clock.time_msec() - start_time
|
||||
|
||||
logger.info("initial_sync: start")
|
||||
room_list = yield self.store.get_rooms_for_user_where_membership_is(
|
||||
user_id=user_id,
|
||||
membership_list=[Membership.INVITE, Membership.JOIN]
|
||||
)
|
||||
|
||||
logger.info("initial_sync: got_rooms %d", delta())
|
||||
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
rooms_ret = []
|
||||
|
||||
now_token = yield self.hs.get_event_sources().get_current_token()
|
||||
|
||||
logger.info("initial_sync: now_token %d", delta())
|
||||
|
||||
presence_stream = self.hs.get_event_sources().sources["presence"]
|
||||
pagination_config = PaginationConfig(from_token=now_token)
|
||||
presence, _ = yield presence_stream.get_pagination_rows(
|
||||
user, pagination_config.get_source_config("presence"), None
|
||||
)
|
||||
|
||||
public_rooms = yield self.store.get_rooms(is_public=True)
|
||||
public_room_ids = [r["room_id"] for r in public_rooms]
|
||||
logger.info("initial_sync: presence_done %d", delta())
|
||||
|
||||
public_room_ids = yield self.store.get_public_room_ids()
|
||||
|
||||
logger.info("initial_sync: public_rooms %d", delta())
|
||||
|
||||
limit = pagin_config.limit
|
||||
if limit is None:
|
||||
limit = 10
|
||||
|
||||
for event in room_list:
|
||||
@defer.inlineCallbacks
|
||||
def handle_room(event):
|
||||
logger.info("initial_sync: start: %s %d", event.room_id, delta())
|
||||
|
||||
d = {
|
||||
"room_id": event.room_id,
|
||||
"membership": event.membership,
|
||||
@@ -290,12 +306,19 @@ class MessageHandler(BaseHandler):
|
||||
rooms_ret.append(d)
|
||||
|
||||
if event.membership != Membership.JOIN:
|
||||
continue
|
||||
return
|
||||
try:
|
||||
messages, token = yield self.store.get_recent_events_for_room(
|
||||
event.room_id,
|
||||
limit=limit,
|
||||
end_token=now_token.room_key,
|
||||
(messages, token), current_state = yield defer.gatherResults(
|
||||
[
|
||||
self.store.get_recent_events_for_room(
|
||||
event.room_id,
|
||||
limit=limit,
|
||||
end_token=now_token.room_key,
|
||||
),
|
||||
self.state_handler.get_current_state(
|
||||
event.room_id
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
start_token = now_token.copy_and_replace("room_key", token[0])
|
||||
@@ -311,9 +334,6 @@ class MessageHandler(BaseHandler):
|
||||
"end": end_token.to_string(),
|
||||
}
|
||||
|
||||
current_state = yield self.state_handler.get_current_state(
|
||||
event.room_id
|
||||
)
|
||||
d["state"] = [
|
||||
serialize_event(c, time_now, as_client_event)
|
||||
for c in current_state.values()
|
||||
@@ -321,6 +341,15 @@ class MessageHandler(BaseHandler):
|
||||
except:
|
||||
logger.exception("Failed to get snapshot")
|
||||
|
||||
logger.info("initial_sync: end: %s %d", event.room_id, delta())
|
||||
|
||||
yield defer.gatherResults(
|
||||
[handle_room(e) for e in room_list],
|
||||
consumeErrors=True
|
||||
)
|
||||
|
||||
logger.info("initial_sync: done", delta())
|
||||
|
||||
ret = {
|
||||
"rooms": rooms_ret,
|
||||
"presence": presence,
|
||||
|
||||
@@ -33,6 +33,13 @@ logger = logging.getLogger(__name__)
|
||||
metrics = synapse.metrics.get_metrics_for(__name__)
|
||||
|
||||
|
||||
# Don't bother bumping "last active" time if it differs by less than 60 seconds
|
||||
LAST_ACTIVE_GRANULARITY = 60*1000
|
||||
|
||||
# Keep no more than this number of offline serial revisions
|
||||
MAX_OFFLINE_SERIALS = 1000
|
||||
|
||||
|
||||
# TODO(paul): Maybe there's one of these I can steal from somewhere
|
||||
def partition(l, func):
|
||||
"""Partition the list by the result of func applied to each element."""
|
||||
@@ -131,6 +138,9 @@ class PresenceHandler(BaseHandler):
|
||||
self._remote_sendmap = {}
|
||||
# map remote users to sets of local users who're interested in them
|
||||
self._remote_recvmap = {}
|
||||
# list of (serial, set of(userids)) tuples, ordered by serial, latest
|
||||
# first
|
||||
self._remote_offline_serials = []
|
||||
|
||||
# map any user to a UserPresenceCache
|
||||
self._user_cachemap = {}
|
||||
@@ -282,6 +292,10 @@ class PresenceHandler(BaseHandler):
|
||||
if now is None:
|
||||
now = self.clock.time_msec()
|
||||
|
||||
prev_state = self._get_or_make_usercache(user)
|
||||
if now - prev_state.state.get("last_active", 0) < LAST_ACTIVE_GRANULARITY:
|
||||
return
|
||||
|
||||
self.changed_presencelike_data(user, {"last_active": now})
|
||||
|
||||
def changed_presencelike_data(self, user, state):
|
||||
@@ -706,8 +720,24 @@ class PresenceHandler(BaseHandler):
|
||||
statuscache=statuscache,
|
||||
)
|
||||
|
||||
user_id = user.to_string()
|
||||
|
||||
if state["presence"] == PresenceState.OFFLINE:
|
||||
self._remote_offline_serials.insert(
|
||||
0,
|
||||
(self._user_cachemap_latest_serial, set([user_id]))
|
||||
)
|
||||
while len(self._remote_offline_serials) > MAX_OFFLINE_SERIALS:
|
||||
self._remote_offline_serials.pop() # remove the oldest
|
||||
del self._user_cachemap[user]
|
||||
else:
|
||||
# Remove the user from remote_offline_serials now that they're
|
||||
# no longer offline
|
||||
for idx, elem in enumerate(self._remote_offline_serials):
|
||||
(_, user_ids) = elem
|
||||
user_ids.discard(user_id)
|
||||
if not user_ids:
|
||||
self._remote_offline_serials.pop(idx)
|
||||
|
||||
for poll in content.get("poll", []):
|
||||
user = UserID.from_string(poll)
|
||||
@@ -829,26 +859,47 @@ class PresenceEventSource(object):
|
||||
presence = self.hs.get_handlers().presence_handler
|
||||
cachemap = presence._user_cachemap
|
||||
|
||||
max_serial = presence._user_cachemap_latest_serial
|
||||
|
||||
clock = self.clock
|
||||
latest_serial = 0
|
||||
|
||||
updates = []
|
||||
# TODO(paul): use a DeferredList ? How to limit concurrency.
|
||||
for observed_user in cachemap.keys():
|
||||
cached = cachemap[observed_user]
|
||||
|
||||
if cached.serial <= from_key:
|
||||
if cached.serial <= from_key or cached.serial > max_serial:
|
||||
continue
|
||||
|
||||
if (yield self.is_visible(observer_user, observed_user)):
|
||||
updates.append((observed_user, cached))
|
||||
if not (yield self.is_visible(observer_user, observed_user)):
|
||||
continue
|
||||
|
||||
latest_serial = max(cached.serial, latest_serial)
|
||||
updates.append(cached.make_event(user=observed_user, clock=clock))
|
||||
|
||||
# TODO(paul): limit
|
||||
|
||||
for serial, user_ids in presence._remote_offline_serials:
|
||||
if serial <= from_key:
|
||||
break
|
||||
|
||||
if serial > max_serial:
|
||||
continue
|
||||
|
||||
latest_serial = max(latest_serial, serial)
|
||||
for u in user_ids:
|
||||
updates.append({
|
||||
"type": "m.presence",
|
||||
"content": {"user_id": u, "presence": PresenceState.OFFLINE},
|
||||
})
|
||||
# TODO(paul): For the v2 API we want to tell the client their from_key
|
||||
# is too old if we fell off the end of the _remote_offline_serials
|
||||
# list, and get them to invalidate+resync. In v1 we have no such
|
||||
# concept so this is a best-effort result.
|
||||
|
||||
if updates:
|
||||
clock = self.clock
|
||||
|
||||
latest_serial = max([x[1].serial for x in updates])
|
||||
data = [x[1].make_event(user=x[0], clock=clock) for x in updates]
|
||||
|
||||
defer.returnValue((data, latest_serial))
|
||||
defer.returnValue((updates, latest_serial))
|
||||
else:
|
||||
defer.returnValue(([], presence._user_cachemap_latest_serial))
|
||||
|
||||
|
||||
@@ -18,18 +18,15 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.types import UserID
|
||||
from synapse.api.errors import (
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError,
|
||||
CodeMessageException
|
||||
AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError
|
||||
)
|
||||
from ._base import BaseHandler
|
||||
import synapse.util.stringutils as stringutils
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.http.client import SimpleHttpClient
|
||||
from synapse.http.client import CaptchaServerHttpClient
|
||||
|
||||
import base64
|
||||
import bcrypt
|
||||
import json
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
@@ -44,6 +41,30 @@ class RegistrationHandler(BaseHandler):
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("registered_user")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_username(self, localpart):
|
||||
yield run_on_reactor()
|
||||
|
||||
if urllib.quote(localpart) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
yield self.check_user_id_is_valid(user_id)
|
||||
|
||||
u = yield self.store.get_user_by_id(user_id)
|
||||
if u:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID already taken.",
|
||||
errcode=Codes.USER_IN_USE,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, localpart=None, password=None):
|
||||
"""Registers a new client on the server.
|
||||
@@ -64,18 +85,11 @@ class RegistrationHandler(BaseHandler):
|
||||
password_hash = bcrypt.hashpw(password, bcrypt.gensalt())
|
||||
|
||||
if localpart:
|
||||
if localpart and urllib.quote(localpart) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User ID must only contain characters which do not"
|
||||
" require URL encoding."
|
||||
)
|
||||
yield self.check_username(localpart)
|
||||
|
||||
user = UserID(localpart, self.hs.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
yield self.check_user_id_is_valid(user_id)
|
||||
|
||||
token = self._generate_token(user_id)
|
||||
yield self.store.register(
|
||||
user_id=user_id,
|
||||
@@ -157,7 +171,11 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_recaptcha(self, ip, private_key, challenge, response):
|
||||
"""Checks a recaptcha is correct."""
|
||||
"""
|
||||
Checks a recaptcha is correct.
|
||||
|
||||
Used only by c/s api v1
|
||||
"""
|
||||
|
||||
captcha_response = yield self._validate_captcha(
|
||||
ip,
|
||||
@@ -176,13 +194,18 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register_email(self, threepidCreds):
|
||||
"""Registers emails with an identity server."""
|
||||
"""
|
||||
Registers emails with an identity server.
|
||||
|
||||
Used only by c/s api v1
|
||||
"""
|
||||
|
||||
for c in threepidCreds:
|
||||
logger.info("validating theeepidcred sid %s on id server %s",
|
||||
c['sid'], c['idServer'])
|
||||
try:
|
||||
threepid = yield self._threepid_from_creds(c)
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
threepid = yield identity_handler.threepid_from_creds(c)
|
||||
except:
|
||||
logger.exception("Couldn't validate 3pid")
|
||||
raise RegistrationError(400, "Couldn't validate 3pid")
|
||||
@@ -194,12 +217,16 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bind_emails(self, user_id, threepidCreds):
|
||||
"""Links emails with a user ID and informs an identity server."""
|
||||
"""Links emails with a user ID and informs an identity server.
|
||||
|
||||
Used only by c/s api v1
|
||||
"""
|
||||
|
||||
# Now we have a matrix ID, bind it to the threepids we were given
|
||||
for c in threepidCreds:
|
||||
identity_handler = self.hs.get_handlers().identity_handler
|
||||
# XXX: This should be a deferred list, shouldn't it?
|
||||
yield self._bind_threepid(c, user_id)
|
||||
yield identity_handler.bind_threepid(c, user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_user_id_is_valid(self, user_id):
|
||||
@@ -226,62 +253,12 @@ class RegistrationHandler(BaseHandler):
|
||||
def _generate_user_id(self):
|
||||
return "-" + stringutils.random_string(18)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _threepid_from_creds(self, creds):
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
http_client = SimpleHttpClient(self.hs)
|
||||
# XXX: make this configurable!
|
||||
trustedIdServers = ['matrix.org:8090', 'matrix.org']
|
||||
if not creds['idServer'] in trustedIdServers:
|
||||
logger.warn('%s is not a trusted ID server: rejecting 3pid ' +
|
||||
'credentials', creds['idServer'])
|
||||
defer.returnValue(None)
|
||||
|
||||
data = {}
|
||||
try:
|
||||
data = yield http_client.get_json(
|
||||
# XXX: This should be HTTPS
|
||||
"http://%s%s" % (
|
||||
creds['idServer'],
|
||||
"/_matrix/identity/api/v1/3pid/getValidated3pid"
|
||||
),
|
||||
{'sid': creds['sid'], 'clientSecret': creds['clientSecret']}
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
|
||||
if 'medium' in data:
|
||||
defer.returnValue(data)
|
||||
defer.returnValue(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _bind_threepid(self, creds, mxid):
|
||||
yield
|
||||
logger.debug("binding threepid")
|
||||
http_client = SimpleHttpClient(self.hs)
|
||||
data = None
|
||||
try:
|
||||
data = yield http_client.post_urlencoded_get_json(
|
||||
# XXX: Change when ID servers are all HTTPS
|
||||
"http://%s%s" % (
|
||||
creds['idServer'], "/_matrix/identity/api/v1/3pid/bind"
|
||||
),
|
||||
{
|
||||
'sid': creds['sid'],
|
||||
'clientSecret': creds['clientSecret'],
|
||||
'mxid': mxid,
|
||||
}
|
||||
)
|
||||
logger.debug("bound threepid")
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg)
|
||||
defer.returnValue(data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _validate_captcha(self, ip_addr, private_key, challenge, response):
|
||||
"""Validates the captcha provided.
|
||||
|
||||
Used only by c/s api v1
|
||||
|
||||
Returns:
|
||||
dict: Containing 'valid'(bool) and 'error_url'(str) if invalid.
|
||||
|
||||
@@ -299,6 +276,9 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _submit_captcha(self, ip_addr, private_key, challenge, response):
|
||||
"""
|
||||
Used only by c/s api v1
|
||||
"""
|
||||
# TODO: get this from the homeserver rather than creating a new one for
|
||||
# each request
|
||||
client = CaptchaServerHttpClient(self.hs)
|
||||
|
||||
@@ -124,7 +124,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
msg_handler = self.hs.get_handlers().message_handler
|
||||
|
||||
for event in creation_events:
|
||||
yield msg_handler.create_and_send_event(event)
|
||||
yield msg_handler.create_and_send_event(event, ratelimit=False)
|
||||
|
||||
if "name" in config:
|
||||
name = config["name"]
|
||||
@@ -134,7 +134,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"sender": user_id,
|
||||
"state_key": "",
|
||||
"content": {"name": name},
|
||||
})
|
||||
}, ratelimit=False)
|
||||
|
||||
if "topic" in config:
|
||||
topic = config["topic"]
|
||||
@@ -144,7 +144,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"sender": user_id,
|
||||
"state_key": "",
|
||||
"content": {"topic": topic},
|
||||
})
|
||||
}, ratelimit=False)
|
||||
|
||||
for invitee in invite_list:
|
||||
yield msg_handler.create_and_send_event({
|
||||
@@ -153,7 +153,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
"room_id": room_id,
|
||||
"sender": user_id,
|
||||
"content": {"membership": Membership.INVITE},
|
||||
})
|
||||
}, ratelimit=False)
|
||||
|
||||
result = {"room_id": room_id}
|
||||
|
||||
@@ -213,7 +213,8 @@ class RoomCreationHandler(BaseHandler):
|
||||
"state_default": 50,
|
||||
"ban": 50,
|
||||
"kick": 50,
|
||||
"redact": 50
|
||||
"redact": 50,
|
||||
"invite": 0,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -310,25 +311,6 @@ class RoomMemberHandler(BaseHandler):
|
||||
# paginating
|
||||
defer.returnValue(chunk_data)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_member(self, room_id, member_user_id, auth_user_id):
|
||||
"""Retrieve a room member from a room.
|
||||
|
||||
Args:
|
||||
room_id : The room the member is in.
|
||||
member_user_id : The member's user ID
|
||||
auth_user_id : The user ID of the user making this request.
|
||||
Returns:
|
||||
The room member, or None if this member does not exist.
|
||||
Raises:
|
||||
SynapseError if something goes wrong.
|
||||
"""
|
||||
yield self.auth.check_joined_room(room_id, auth_user_id)
|
||||
|
||||
member = yield self.store.get_room_member(user_id=member_user_id,
|
||||
room_id=room_id)
|
||||
defer.returnValue(member)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def change_membership(self, event, context, do_auth=True):
|
||||
""" Change the membership status of a user in a room.
|
||||
|
||||
@@ -223,6 +223,7 @@ class TypingNotificationEventSource(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self._handler = None
|
||||
self._room_member_handler = None
|
||||
|
||||
def handler(self):
|
||||
# Avoid cyclic dependency in handler setup
|
||||
@@ -230,6 +231,11 @@ class TypingNotificationEventSource(object):
|
||||
self._handler = self.hs.get_handlers().typing_notification_handler
|
||||
return self._handler
|
||||
|
||||
def room_member_handler(self):
|
||||
if not self._room_member_handler:
|
||||
self._room_member_handler = self.hs.get_handlers().room_member_handler
|
||||
return self._room_member_handler
|
||||
|
||||
def _make_event_for(self, room_id):
|
||||
typing = self.handler()._room_typing[room_id]
|
||||
return {
|
||||
@@ -240,19 +246,25 @@ class TypingNotificationEventSource(object):
|
||||
},
|
||||
}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_new_events_for_user(self, user, from_key, limit):
|
||||
from_key = int(from_key)
|
||||
handler = self.handler()
|
||||
|
||||
joined_room_ids = (
|
||||
yield self.room_member_handler().get_joined_rooms_for_user(user)
|
||||
)
|
||||
|
||||
events = []
|
||||
for room_id in handler._room_serials:
|
||||
if room_id not in joined_room_ids:
|
||||
continue
|
||||
if handler._room_serials[room_id] <= from_key:
|
||||
continue
|
||||
|
||||
# TODO: check if user is in room
|
||||
events.append(self._make_event_for(room_id))
|
||||
|
||||
return (events, handler._latest_room_serial)
|
||||
defer.returnValue((events, handler._latest_room_serial))
|
||||
|
||||
def get_current_key(self):
|
||||
return self.handler()._latest_room_serial
|
||||
|
||||
@@ -200,6 +200,8 @@ class CaptchaServerHttpClient(SimpleHttpClient):
|
||||
"""
|
||||
Separate HTTP client for talking to google's captcha servers
|
||||
Only slightly special because accepts partial download responses
|
||||
|
||||
used only by c/s api v1
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -24,7 +24,7 @@ from syutil.jsonutil import (
|
||||
encode_canonical_json, encode_pretty_printed_json
|
||||
)
|
||||
|
||||
from twisted.internet import defer, reactor
|
||||
from twisted.internet import defer
|
||||
from twisted.web import server, resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.web.util import redirectTo
|
||||
@@ -51,16 +51,90 @@ response_timer = metrics.register_distribution(
|
||||
labels=["method", "servlet"]
|
||||
)
|
||||
|
||||
_next_request_id = 0
|
||||
|
||||
|
||||
def request_handler(request_handler):
|
||||
"""Wraps a method that acts as a request handler with the necessary logging
|
||||
and exception handling.
|
||||
|
||||
The method must have a signature of "handle_foo(self, request)". The
|
||||
argument "self" must have "version_string" and "clock" attributes. The
|
||||
argument "request" must be a twisted HTTP request.
|
||||
|
||||
The method must return a deferred. If the deferred succeeds we assume that
|
||||
a response has been sent. If the deferred fails with a SynapseError we use
|
||||
it to send a JSON response with the appropriate HTTP reponse code. If the
|
||||
deferred fails with any other type of error we send a 500 reponse.
|
||||
|
||||
We insert a unique request-id into the logging context for this request and
|
||||
log the response and duration for this request.
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def wrapped_request_handler(self, request):
|
||||
global _next_request_id
|
||||
request_id = "%s-%s" % (request.method, _next_request_id)
|
||||
_next_request_id += 1
|
||||
with LoggingContext(request_id) as request_context:
|
||||
request_context.request = request_id
|
||||
code = None
|
||||
start = self.clock.time_msec()
|
||||
try:
|
||||
logger.info(
|
||||
"Received request: %s %s",
|
||||
request.method, request.path
|
||||
)
|
||||
yield request_handler(self, request)
|
||||
code = request.code
|
||||
except CodeMessageException as e:
|
||||
code = e.code
|
||||
if isinstance(e, SynapseError):
|
||||
logger.info(
|
||||
"%s SynapseError: %s - %s", request, code, e.msg
|
||||
)
|
||||
else:
|
||||
logger.exception(e)
|
||||
outgoing_responses_counter.inc(request.method, str(code))
|
||||
respond_with_json(
|
||||
request, code, cs_exception(e), send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
version_string=self.version_string,
|
||||
)
|
||||
except:
|
||||
code = 500
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r: %r",
|
||||
request_handler.__module__,
|
||||
request_handler.__name__,
|
||||
self,
|
||||
request
|
||||
)
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
finally:
|
||||
code = str(code) if code else "-"
|
||||
end = self.clock.time_msec()
|
||||
logger.info(
|
||||
"Processed request: %dms %s %s %s",
|
||||
end-start, code, request.method, request.path
|
||||
)
|
||||
return wrapped_request_handler
|
||||
|
||||
|
||||
class HttpServer(object):
|
||||
""" Interface for registering callbacks on a HTTP server
|
||||
"""
|
||||
|
||||
def register_path(self, method, path_pattern, callback):
|
||||
""" Register a callback that get's fired if we receive a http request
|
||||
""" Register a callback that gets fired if we receive a http request
|
||||
with the given method for a path that matches the given regex.
|
||||
|
||||
If the regex contains groups these get's passed to the calback via
|
||||
If the regex contains groups these gets passed to the calback via
|
||||
an unpacked tuple.
|
||||
|
||||
Args:
|
||||
@@ -79,6 +153,13 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
Resources.
|
||||
|
||||
Register callbacks via register_path()
|
||||
|
||||
Callbacks can return a tuple of status code and a dict in which case the
|
||||
the dict will automatically be sent to the client as a JSON object.
|
||||
|
||||
The JsonResource is primarily intended for returning JSON, but callbacks
|
||||
may send something other than JSON, they may do so by using the methods
|
||||
on the request object and instead returning None.
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
@@ -98,119 +179,61 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
self._PathEntry(path_pattern, callback)
|
||||
)
|
||||
|
||||
def start_listening(self, port):
|
||||
""" Registers the http server with the twisted reactor.
|
||||
|
||||
Args:
|
||||
port (int): The port to listen on.
|
||||
|
||||
"""
|
||||
reactor.listenTCP(
|
||||
port,
|
||||
server.Site(self),
|
||||
interface=self.hs.config.bind_host
|
||||
)
|
||||
|
||||
# Gets called by twisted
|
||||
def render(self, request):
|
||||
""" This get's called by twisted every time someone sends us a request.
|
||||
""" This gets called by twisted every time someone sends us a request.
|
||||
"""
|
||||
self._async_render_with_logging_context(request)
|
||||
self._async_render(request)
|
||||
return server.NOT_DONE_YET
|
||||
|
||||
_request_id = 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_with_logging_context(self, request):
|
||||
request_id = "%s-%s" % (request.method, JsonResource._request_id)
|
||||
JsonResource._request_id += 1
|
||||
with LoggingContext(request_id) as request_context:
|
||||
request_context.request = request_id
|
||||
yield self._async_render(request)
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render(self, request):
|
||||
""" This get's called by twisted every time someone sends us a request.
|
||||
""" This gets called from render() every time someone sends us a request.
|
||||
This checks if anyone has registered a callback for that method and
|
||||
path.
|
||||
"""
|
||||
code = None
|
||||
start = self.clock.time_msec()
|
||||
try:
|
||||
# Just say yes to OPTIONS.
|
||||
if request.method == "OPTIONS":
|
||||
self._send_response(request, 200, {})
|
||||
return
|
||||
if request.method == "OPTIONS":
|
||||
self._send_response(request, 200, {})
|
||||
return
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
m = path_entry.pattern.match(request.path)
|
||||
if not m:
|
||||
continue
|
||||
|
||||
# Loop through all the registered callbacks to check if the method
|
||||
# and path regex match
|
||||
for path_entry in self.path_regexs.get(request.method, []):
|
||||
m = path_entry.pattern.match(request.path)
|
||||
if not m:
|
||||
continue
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
|
||||
# We found a match! Trigger callback and then return the
|
||||
# returned response. We pass both the request and any
|
||||
# matched groups from the regex to the callback.
|
||||
callback = path_entry.callback
|
||||
|
||||
callback = path_entry.callback
|
||||
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
servlet_classname = "%r" % callback
|
||||
incoming_requests_counter.inc(request.method, servlet_classname)
|
||||
|
||||
args = [
|
||||
urllib.unquote(u).decode("UTF-8") for u in m.groups()
|
||||
]
|
||||
|
||||
logger.info(
|
||||
"Received request: %s %s",
|
||||
request.method, request.path
|
||||
)
|
||||
|
||||
code, response = yield callback(request, *args)
|
||||
|
||||
self._send_response(request, code, response)
|
||||
response_timer.inc_by(
|
||||
self.clock.time_msec() - start, request.method, servlet_classname
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
raise UnrecognizedRequestError()
|
||||
except CodeMessageException as e:
|
||||
if isinstance(e, SynapseError):
|
||||
logger.info("%s SynapseError: %s - %s", request, e.code, e.msg)
|
||||
servlet_instance = getattr(callback, "__self__", None)
|
||||
if servlet_instance is not None:
|
||||
servlet_classname = servlet_instance.__class__.__name__
|
||||
else:
|
||||
logger.exception(e)
|
||||
servlet_classname = "%r" % callback
|
||||
incoming_requests_counter.inc(request.method, servlet_classname)
|
||||
|
||||
code = e.code
|
||||
self._send_response(
|
||||
request,
|
||||
code,
|
||||
cs_exception(e),
|
||||
response_code_message=e.response_code_message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self._send_response(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"}
|
||||
)
|
||||
finally:
|
||||
code = str(code) if code else "-"
|
||||
args = [
|
||||
urllib.unquote(u).decode("UTF-8") for u in m.groups()
|
||||
]
|
||||
|
||||
end = self.clock.time_msec()
|
||||
logger.info(
|
||||
"Processed request: %dms %s %s %s",
|
||||
end-start, code, request.method, request.path
|
||||
callback_return = yield callback(request, *args)
|
||||
if callback_return is not None:
|
||||
code, response = callback_return
|
||||
self._send_response(request, code, response)
|
||||
|
||||
response_timer.inc_by(
|
||||
self.clock.time_msec() - start, request.method, servlet_classname
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
# Huh. No one wanted to handle that? Fiiiiiine. Send 400.
|
||||
raise UnrecognizedRequestError()
|
||||
|
||||
def _send_response(self, request, code, response_json_object,
|
||||
response_code_message=None):
|
||||
# could alternatively use request.notifyFinish() and flip a flag when
|
||||
@@ -229,20 +252,10 @@ class JsonResource(HttpServer, resource.Resource):
|
||||
request, code, response_json_object,
|
||||
send_cors=True,
|
||||
response_code_message=response_code_message,
|
||||
pretty_print=self._request_user_agent_is_curl,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
version_string=self.version_string,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _request_user_agent_is_curl(request):
|
||||
user_agents = request.requestHeaders.getRawHeaders(
|
||||
"User-Agent", default=[]
|
||||
)
|
||||
for user_agent in user_agents:
|
||||
if "curl" in user_agent:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class RootRedirect(resource.Resource):
|
||||
"""Redirects the root '/' path to another path."""
|
||||
@@ -263,8 +276,8 @@ class RootRedirect(resource.Resource):
|
||||
def respond_with_json(request, code, json_object, send_cors=False,
|
||||
response_code_message=None, pretty_print=False,
|
||||
version_string=""):
|
||||
if not pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object)
|
||||
if pretty_print:
|
||||
json_bytes = encode_pretty_printed_json(json_object) + "\n"
|
||||
else:
|
||||
json_bytes = encode_canonical_json(json_object)
|
||||
|
||||
@@ -304,3 +317,13 @@ def respond_with_json_bytes(request, code, json_bytes, send_cors=False,
|
||||
request.write(json_bytes)
|
||||
request.finish()
|
||||
return NOT_DONE_YET
|
||||
|
||||
|
||||
def _request_user_agent_is_curl(request):
|
||||
user_agents = request.requestHeaders.getRawHeaders(
|
||||
"User-Agent", default=[]
|
||||
)
|
||||
for user_agent in user_agents:
|
||||
if "curl" in user_agent:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -23,6 +23,61 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_integer(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return int(request.args[name][0])
|
||||
except:
|
||||
message = "Query parameter %r must be an integer" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing integer query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def parse_boolean(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
}[request.args[name][0]]
|
||||
except:
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of"
|
||||
" ['true', 'false']"
|
||||
) % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing boolean query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
def parse_string(request, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
if name in request.args:
|
||||
value = request.args[name][0]
|
||||
if allowed_values is not None and value not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name, ", ".join(repr(v) for v in allowed_values)
|
||||
)
|
||||
raise SynapseError(message)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
if required:
|
||||
message = "Missing %s query parameter %r" % (param_type, name)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
|
||||
class RestServlet(object):
|
||||
|
||||
""" A Synapse REST Servlet.
|
||||
@@ -56,58 +111,3 @@ class RestServlet(object):
|
||||
http_server.register_path(method, pattern, method_handler)
|
||||
else:
|
||||
raise NotImplementedError("RestServlet must register something.")
|
||||
|
||||
@staticmethod
|
||||
def parse_integer(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return int(request.args[name][0])
|
||||
except:
|
||||
message = "Query parameter %r must be an integer" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing integer query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def parse_boolean(request, name, default=None, required=False):
|
||||
if name in request.args:
|
||||
try:
|
||||
return {
|
||||
"true": True,
|
||||
"false": False,
|
||||
}[request.args[name][0]]
|
||||
except:
|
||||
message = (
|
||||
"Boolean query parameter %r must be one of"
|
||||
" ['true', 'false']"
|
||||
) % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
if required:
|
||||
message = "Missing boolean query parameter %r" % (name,)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
@staticmethod
|
||||
def parse_string(request, name, default=None, required=False,
|
||||
allowed_values=None, param_type="string"):
|
||||
if name in request.args:
|
||||
value = request.args[name][0]
|
||||
if allowed_values is not None and value not in allowed_values:
|
||||
message = "Query parameter %r must be one of [%s]" % (
|
||||
name, ", ".join(repr(v) for v in allowed_values)
|
||||
)
|
||||
raise SynapseError(message)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
if required:
|
||||
message = "Missing %s query parameter %r" % (param_type, name)
|
||||
raise SynapseError(400, message)
|
||||
else:
|
||||
return default
|
||||
|
||||
@@ -18,6 +18,8 @@ from __future__ import absolute_import
|
||||
|
||||
import logging
|
||||
from resource import getrusage, getpagesize, RUSAGE_SELF
|
||||
import os
|
||||
import stat
|
||||
|
||||
from .metric import (
|
||||
CounterMetric, CallbackMetric, DistributionMetric, CacheMetric
|
||||
@@ -109,3 +111,36 @@ resource_metrics.register_callback("stime", lambda: rusage.ru_stime * 1000)
|
||||
|
||||
# pages
|
||||
resource_metrics.register_callback("maxrss", lambda: rusage.ru_maxrss * PAGE_SIZE)
|
||||
|
||||
TYPES = {
|
||||
stat.S_IFSOCK: "SOCK",
|
||||
stat.S_IFLNK: "LNK",
|
||||
stat.S_IFREG: "REG",
|
||||
stat.S_IFBLK: "BLK",
|
||||
stat.S_IFDIR: "DIR",
|
||||
stat.S_IFCHR: "CHR",
|
||||
stat.S_IFIFO: "FIFO",
|
||||
}
|
||||
|
||||
|
||||
def _process_fds():
|
||||
counts = {(k,): 0 for k in TYPES.values()}
|
||||
counts[("other",)] = 0
|
||||
|
||||
for fd in os.listdir("/proc/self/fd"):
|
||||
try:
|
||||
s = os.stat("/proc/self/fd/%s" % (fd))
|
||||
fmt = stat.S_IFMT(s.st_mode)
|
||||
if fmt in TYPES:
|
||||
t = TYPES[fmt]
|
||||
else:
|
||||
t = "other"
|
||||
|
||||
counts[(t,)] += 1
|
||||
except OSError:
|
||||
# the dirh itself used by listdir() is usually missing by now
|
||||
pass
|
||||
|
||||
return counts
|
||||
|
||||
get_metrics_for("process").register_callback("fds", _process_fds, labels=["type"])
|
||||
|
||||
@@ -59,10 +59,11 @@ class _NotificationListener(object):
|
||||
self.limit = limit
|
||||
self.timeout = timeout
|
||||
self.deferred = deferred
|
||||
|
||||
self.rooms = rooms
|
||||
self.timer = None
|
||||
|
||||
self.pending_notifications = []
|
||||
def notified(self):
|
||||
return self.deferred.called
|
||||
|
||||
def notify(self, notifier, events, start_token, end_token):
|
||||
""" Inform whoever is listening about the new events. This will
|
||||
@@ -78,16 +79,27 @@ class _NotificationListener(object):
|
||||
except defer.AlreadyCalledError:
|
||||
pass
|
||||
|
||||
# Should the following be done be using intrusively linked lists?
|
||||
# -- erikj
|
||||
|
||||
for room in self.rooms:
|
||||
lst = notifier.room_to_listeners.get(room, set())
|
||||
lst.discard(self)
|
||||
|
||||
notifier.user_to_listeners.get(self.user, set()).discard(self)
|
||||
|
||||
if self.appservice:
|
||||
notifier.appservice_to_listeners.get(
|
||||
self.appservice, set()
|
||||
).discard(self)
|
||||
|
||||
# Cancel the timeout for this notifer if one exists.
|
||||
if self.timer is not None:
|
||||
try:
|
||||
notifier.clock.cancel_call_later(self.timer)
|
||||
except:
|
||||
logger.warn("Failed to cancel notifier timer")
|
||||
|
||||
|
||||
class Notifier(object):
|
||||
""" This class is responsible for notifying any listeners when there are
|
||||
@@ -161,10 +173,18 @@ class Notifier(object):
|
||||
|
||||
room_source = self.event_sources.sources["room"]
|
||||
|
||||
listeners = self.room_to_listeners.get(room_id, set()).copy()
|
||||
room_listeners = self.room_to_listeners.get(room_id, set())
|
||||
|
||||
_discard_if_notified(room_listeners)
|
||||
|
||||
listeners = room_listeners.copy()
|
||||
|
||||
for user in extra_users:
|
||||
listeners |= self.user_to_listeners.get(user, set()).copy()
|
||||
user_listeners = self.user_to_listeners.get(user, set())
|
||||
|
||||
_discard_if_notified(user_listeners)
|
||||
|
||||
listeners |= user_listeners
|
||||
|
||||
for appservice in self.appservice_to_listeners:
|
||||
# TODO (kegan): Redundant appservice listener checks?
|
||||
@@ -173,9 +193,13 @@ class Notifier(object):
|
||||
# receive *invites* for users they are interested in. Does this
|
||||
# make the room_to_listeners check somewhat obselete?
|
||||
if appservice.is_interested(event):
|
||||
listeners |= self.appservice_to_listeners.get(
|
||||
app_listeners = self.appservice_to_listeners.get(
|
||||
appservice, set()
|
||||
).copy()
|
||||
)
|
||||
|
||||
_discard_if_notified(app_listeners)
|
||||
|
||||
listeners |= app_listeners
|
||||
|
||||
logger.debug("on_new_room_event listeners %s", listeners)
|
||||
|
||||
@@ -226,10 +250,18 @@ class Notifier(object):
|
||||
listeners = set()
|
||||
|
||||
for user in users:
|
||||
listeners |= self.user_to_listeners.get(user, set()).copy()
|
||||
user_listeners = self.user_to_listeners.get(user, set())
|
||||
|
||||
_discard_if_notified(user_listeners)
|
||||
|
||||
listeners |= user_listeners
|
||||
|
||||
for room in rooms:
|
||||
listeners |= self.room_to_listeners.get(room, set()).copy()
|
||||
room_listeners = self.room_to_listeners.get(room, set())
|
||||
|
||||
_discard_if_notified(room_listeners)
|
||||
|
||||
listeners |= room_listeners
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def notify(listener):
|
||||
@@ -300,14 +332,20 @@ class Notifier(object):
|
||||
self._register_with_keys(listener[0])
|
||||
|
||||
result = yield callback()
|
||||
timer = [None]
|
||||
|
||||
if timeout:
|
||||
timed_out = [False]
|
||||
|
||||
def _timeout_listener():
|
||||
timed_out[0] = True
|
||||
timer[0] = None
|
||||
listener[0].notify(self, [], from_token, from_token)
|
||||
|
||||
self.clock.call_later(timeout/1000., _timeout_listener)
|
||||
# We create multiple notification listeners so we have to manage
|
||||
# canceling the timeout ourselves.
|
||||
timer[0] = self.clock.call_later(timeout/1000., _timeout_listener)
|
||||
|
||||
while not result and not timed_out[0]:
|
||||
yield deferred
|
||||
deferred = defer.Deferred()
|
||||
@@ -322,6 +360,12 @@ class Notifier(object):
|
||||
self._register_with_keys(listener[0])
|
||||
result = yield callback()
|
||||
|
||||
if timer[0] is not None:
|
||||
try:
|
||||
self.clock.cancel_call_later(timer[0])
|
||||
except:
|
||||
logger.exception("Failed to cancel notifer timer")
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
def get_events_for(self, user, rooms, pagination_config, timeout):
|
||||
@@ -360,6 +404,8 @@ class Notifier(object):
|
||||
def _timeout_listener():
|
||||
# TODO (erikj): We should probably set to_token to the current
|
||||
# max rather than reusing from_token.
|
||||
# Remove the timer from the listener so we don't try to cancel it.
|
||||
listener.timer = None
|
||||
listener.notify(
|
||||
self,
|
||||
[],
|
||||
@@ -375,8 +421,11 @@ class Notifier(object):
|
||||
if not timeout:
|
||||
_timeout_listener()
|
||||
else:
|
||||
self.clock.call_later(timeout/1000.0, _timeout_listener)
|
||||
|
||||
# Only add the timer if the listener hasn't been notified
|
||||
if not listener.notified():
|
||||
listener.timer = self.clock.call_later(
|
||||
timeout/1000.0, _timeout_listener
|
||||
)
|
||||
return
|
||||
|
||||
@log_function
|
||||
@@ -427,3 +476,17 @@ class Notifier(object):
|
||||
|
||||
listeners = self.room_to_listeners.setdefault(room_id, set())
|
||||
listeners |= new_listeners
|
||||
|
||||
for l in new_listeners:
|
||||
l.rooms.add(room_id)
|
||||
|
||||
|
||||
def _discard_if_notified(listener_set):
|
||||
"""Remove any 'stale' listeners from the given set.
|
||||
"""
|
||||
to_discard = set()
|
||||
for l in listener_set:
|
||||
if l.notified():
|
||||
to_discard.add(l)
|
||||
|
||||
listener_set -= to_discard
|
||||
|
||||
@@ -253,7 +253,8 @@ class Pusher(object):
|
||||
self.user_name, config, timeout=0)
|
||||
self.last_token = chunk['end']
|
||||
self.store.update_pusher_last_token(
|
||||
self.app_id, self.pushkey, self.last_token)
|
||||
self.app_id, self.pushkey, self.user_name, self.last_token
|
||||
)
|
||||
logger.info("Pusher %s for user %s starting from token %s",
|
||||
self.pushkey, self.user_name, self.last_token)
|
||||
|
||||
@@ -314,7 +315,7 @@ class Pusher(object):
|
||||
pk
|
||||
)
|
||||
yield self.hs.get_pusherpool().remove_pusher(
|
||||
self.app_id, pk
|
||||
self.app_id, pk, self.user_name
|
||||
)
|
||||
|
||||
if not self.alive:
|
||||
@@ -326,6 +327,7 @@ class Pusher(object):
|
||||
self.store.update_pusher_last_token_and_success(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.last_token,
|
||||
self.clock.time_msec()
|
||||
)
|
||||
@@ -334,6 +336,7 @@ class Pusher(object):
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since)
|
||||
else:
|
||||
if not self.failing_since:
|
||||
@@ -341,6 +344,7 @@ class Pusher(object):
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
)
|
||||
|
||||
@@ -358,6 +362,7 @@ class Pusher(object):
|
||||
self.store.update_pusher_last_token(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.last_token
|
||||
)
|
||||
|
||||
@@ -365,6 +370,7 @@ class Pusher(object):
|
||||
self.store.update_pusher_failing_since(
|
||||
self.app_id,
|
||||
self.pushkey,
|
||||
self.user_name,
|
||||
self.failing_since
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.push.rulekinds import PRIORITY_CLASS_MAP, PRIORITY_CLASS_INVERSE_MAP
|
||||
|
||||
|
||||
@@ -112,7 +126,25 @@ def make_base_prepend_override_rules():
|
||||
def make_base_append_override_rules():
|
||||
return [
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.call',
|
||||
'rule_id': 'global/override/.m.rule.suppress_notices',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.msgtype',
|
||||
'pattern': 'm.notice',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'dont_notify',
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def make_base_append_underride_rules(user):
|
||||
return [
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.call',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
@@ -131,19 +163,6 @@ def make_base_append_override_rules():
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.suppress_notices',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.msgtype',
|
||||
'pattern': 'm.notice',
|
||||
}
|
||||
],
|
||||
'actions': [
|
||||
'dont_notify',
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.contains_display_name',
|
||||
'conditions': [
|
||||
@@ -162,7 +181,7 @@ def make_base_append_override_rules():
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.room_one_to_one',
|
||||
'rule_id': 'global/underride/.m.rule.room_one_to_one',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'room_member_count',
|
||||
@@ -179,12 +198,7 @@ def make_base_append_override_rules():
|
||||
'value': False
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def make_base_append_underride_rules(user):
|
||||
return [
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/underride/.m.rule.invite_for_me',
|
||||
'conditions': [
|
||||
|
||||
@@ -19,10 +19,7 @@ from twisted.internet import defer
|
||||
from httppusher import HttpPusher
|
||||
from synapse.push import PusherConfigException
|
||||
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
|
||||
import simplejson as json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,12 +49,10 @@ class PusherPool:
|
||||
@defer.inlineCallbacks
|
||||
def start(self):
|
||||
pushers = yield self.store.get_all_pushers()
|
||||
for p in pushers:
|
||||
p['data'] = json.loads(p['data'])
|
||||
self._start_pushers(pushers)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_pusher(self, user_name, profile_tag, kind, app_id,
|
||||
def add_pusher(self, user_name, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name, pushkey, lang, data):
|
||||
# we try to create the pusher just to validate the config: it
|
||||
# will then get pulled out of the database,
|
||||
@@ -71,7 +66,7 @@ class PusherPool:
|
||||
"app_display_name": app_display_name,
|
||||
"device_display_name": device_display_name,
|
||||
"pushkey": pushkey,
|
||||
"pushkey_ts": self.hs.get_clock().time_msec(),
|
||||
"ts": self.hs.get_clock().time_msec(),
|
||||
"lang": lang,
|
||||
"data": data,
|
||||
"last_token": None,
|
||||
@@ -79,17 +74,50 @@ class PusherPool:
|
||||
"failing_since": None
|
||||
})
|
||||
yield self._add_pusher_to_store(
|
||||
user_name, profile_tag, kind, app_id,
|
||||
user_name, access_token, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name,
|
||||
pushkey, lang, data
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _add_pusher_to_store(self, user_name, profile_tag, kind, app_id,
|
||||
app_display_name, device_display_name,
|
||||
def remove_pushers_by_app_id_and_pushkey_not_user(self, app_id, pushkey,
|
||||
not_user_id):
|
||||
to_remove = yield self.store.get_pushers_by_app_id_and_pushkey(
|
||||
app_id, pushkey
|
||||
)
|
||||
for p in to_remove:
|
||||
if p['user_name'] != not_user_id:
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
app_id, pushkey, p['user_name']
|
||||
)
|
||||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pushers_by_user_access_token(self, user_id, not_access_token_id):
|
||||
all = yield self.store.get_all_pushers()
|
||||
logger.info(
|
||||
"Removing all pushers for user %s except access token %s",
|
||||
user_id, not_access_token_id
|
||||
)
|
||||
for p in all:
|
||||
if (
|
||||
p['user_name'] == user_id and
|
||||
p['access_token'] != not_access_token_id
|
||||
):
|
||||
logger.info(
|
||||
"Removing pusher for app id %s, pushkey %s, user %s",
|
||||
p['app_id'], p['pushkey'], p['user_name']
|
||||
)
|
||||
self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _add_pusher_to_store(self, user_name, access_token, profile_tag, kind,
|
||||
app_id, app_display_name, device_display_name,
|
||||
pushkey, lang, data):
|
||||
yield self.store.add_pusher(
|
||||
user_name=user_name,
|
||||
access_token=access_token,
|
||||
profile_tag=profile_tag,
|
||||
kind=kind,
|
||||
app_id=app_id,
|
||||
@@ -98,9 +126,9 @@ class PusherPool:
|
||||
pushkey=pushkey,
|
||||
pushkey_ts=self.hs.get_clock().time_msec(),
|
||||
lang=lang,
|
||||
data=encode_canonical_json(data).decode("UTF-8"),
|
||||
data=data,
|
||||
)
|
||||
self._refresh_pusher((app_id, pushkey))
|
||||
self._refresh_pusher(app_id, pushkey, user_name)
|
||||
|
||||
def _create_pusher(self, pusherdict):
|
||||
if pusherdict['kind'] == 'http':
|
||||
@@ -112,7 +140,7 @@ class PusherPool:
|
||||
app_display_name=pusherdict['app_display_name'],
|
||||
device_display_name=pusherdict['device_display_name'],
|
||||
pushkey=pusherdict['pushkey'],
|
||||
pushkey_ts=pusherdict['pushkey_ts'],
|
||||
pushkey_ts=pusherdict['ts'],
|
||||
data=pusherdict['data'],
|
||||
last_token=pusherdict['last_token'],
|
||||
last_success=pusherdict['last_success'],
|
||||
@@ -125,30 +153,48 @@ class PusherPool:
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _refresh_pusher(self, app_id_pushkey):
|
||||
p = yield self.store.get_pushers_by_app_id_and_pushkey(
|
||||
app_id_pushkey
|
||||
def _refresh_pusher(self, app_id, pushkey, user_name):
|
||||
resultlist = yield self.store.get_pushers_by_app_id_and_pushkey(
|
||||
app_id, pushkey
|
||||
)
|
||||
p['data'] = json.loads(p['data'])
|
||||
|
||||
self._start_pushers([p])
|
||||
p = None
|
||||
for r in resultlist:
|
||||
if r['user_name'] == user_name:
|
||||
p = r
|
||||
|
||||
if p:
|
||||
|
||||
self._start_pushers([p])
|
||||
|
||||
def _start_pushers(self, pushers):
|
||||
logger.info("Starting %d pushers", len(pushers))
|
||||
for pusherdict in pushers:
|
||||
p = self._create_pusher(pusherdict)
|
||||
try:
|
||||
p = self._create_pusher(pusherdict)
|
||||
except PusherConfigException:
|
||||
logger.exception("Couldn't start a pusher: caught PusherConfigException")
|
||||
continue
|
||||
if p:
|
||||
fullid = "%s:%s" % (pusherdict['app_id'], pusherdict['pushkey'])
|
||||
fullid = "%s:%s:%s" % (
|
||||
pusherdict['app_id'],
|
||||
pusherdict['pushkey'],
|
||||
pusherdict['user_name']
|
||||
)
|
||||
if fullid in self.pushers:
|
||||
self.pushers[fullid].stop()
|
||||
self.pushers[fullid] = p
|
||||
p.start()
|
||||
|
||||
logger.info("Started pushers")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pusher(self, app_id, pushkey):
|
||||
fullid = "%s:%s" % (app_id, pushkey)
|
||||
def remove_pusher(self, app_id, pushkey, user_name):
|
||||
fullid = "%s:%s:%s" % (app_id, pushkey, user_name)
|
||||
if fullid in self.pushers:
|
||||
logger.info("Stopping pusher %s", fullid)
|
||||
self.pushers[fullid].stop()
|
||||
del self.pushers[fullid]
|
||||
yield self.store.delete_pusher_by_app_id_pushkey(app_id, pushkey)
|
||||
yield self.store.delete_pusher_by_app_id_pushkey_user_name(
|
||||
app_id, pushkey, user_name
|
||||
)
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
PRIORITY_CLASS_MAP = {
|
||||
'underride': 1,
|
||||
'sender': 2,
|
||||
|
||||
@@ -1,10 +1,24 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
REQUIREMENTS = {
|
||||
"syutil>=0.0.3": ["syutil"],
|
||||
"syutil>=0.0.6": ["syutil>=0.0.6"],
|
||||
"Twisted==14.0.2": ["twisted==14.0.2"],
|
||||
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
|
||||
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
|
||||
@@ -43,8 +57,8 @@ DEPENDENCY_LINKS = [
|
||||
),
|
||||
github_link(
|
||||
project="matrix-org/syutil",
|
||||
version="v0.0.3",
|
||||
egg="syutil-0.0.3",
|
||||
version="v0.0.6",
|
||||
egg="syutil-0.0.6",
|
||||
),
|
||||
github_link(
|
||||
project="matrix-org/matrix-angular-sdk",
|
||||
|
||||
@@ -1,48 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains base REST classes for constructing client v1 servlets.
|
||||
"""
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.api.urls import APP_SERVICE_PREFIX
|
||||
import re
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def as_path_pattern(path_regex):
|
||||
"""Creates a regex compiled appservice path with the correct path
|
||||
prefix.
|
||||
|
||||
Args:
|
||||
path_regex (str): The regex string to match. This should NOT have a ^
|
||||
as this will be prefixed.
|
||||
Returns:
|
||||
SRE_Pattern
|
||||
"""
|
||||
return re.compile("^" + APP_SERVICE_PREFIX + path_regex)
|
||||
|
||||
|
||||
class AppServiceRestServlet(RestServlet):
|
||||
"""A base Synapse REST Servlet for the application services version 1 API.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.handler = hs.get_handlers().appservice_handler
|
||||
@@ -1,99 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""This module contains REST servlets to do with registration: /register"""
|
||||
from twisted.internet import defer
|
||||
|
||||
from base import AppServiceRestServlet, as_path_pattern
|
||||
from synapse.api.errors import CodeMessageException, SynapseError
|
||||
from synapse.storage.appservice import ApplicationService
|
||||
|
||||
import json
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterRestServlet(AppServiceRestServlet):
|
||||
"""Handles AS registration with the home server.
|
||||
"""
|
||||
|
||||
PATTERN = as_path_pattern("/register$")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
params = _parse_json(request)
|
||||
|
||||
# sanity check required params
|
||||
try:
|
||||
as_token = params["as_token"]
|
||||
as_url = params["url"]
|
||||
if (not isinstance(as_token, basestring) or
|
||||
not isinstance(as_url, basestring)):
|
||||
raise ValueError
|
||||
except (KeyError, ValueError):
|
||||
raise SynapseError(
|
||||
400, "Missed required keys: as_token(str) / url(str)."
|
||||
)
|
||||
|
||||
try:
|
||||
app_service = ApplicationService(
|
||||
as_token, as_url, params["namespaces"]
|
||||
)
|
||||
except ValueError as e:
|
||||
raise SynapseError(400, e.message)
|
||||
|
||||
app_service = yield self.handler.register(app_service)
|
||||
hs_token = app_service.hs_token
|
||||
|
||||
defer.returnValue((200, {
|
||||
"hs_token": hs_token
|
||||
}))
|
||||
|
||||
|
||||
class UnregisterRestServlet(AppServiceRestServlet):
|
||||
"""Handles AS registration with the home server.
|
||||
"""
|
||||
|
||||
PATTERN = as_path_pattern("/unregister$")
|
||||
|
||||
def on_POST(self, request):
|
||||
params = _parse_json(request)
|
||||
try:
|
||||
as_token = params["as_token"]
|
||||
if not isinstance(as_token, basestring):
|
||||
raise ValueError
|
||||
except (KeyError, ValueError):
|
||||
raise SynapseError(400, "Missing required key: as_token(str)")
|
||||
|
||||
yield self.handler.unregister(as_token)
|
||||
|
||||
raise CodeMessageException(500, "Not implemented")
|
||||
|
||||
|
||||
def _parse_json(request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise SynapseError(400, "Content must be a JSON object.")
|
||||
return content
|
||||
except ValueError as e:
|
||||
logger.warn(e)
|
||||
raise SynapseError(400, "Content not JSON.")
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
UnregisterRestServlet(hs).register(http_server)
|
||||
@@ -48,5 +48,5 @@ class ClientV1RestServlet(RestServlet):
|
||||
self.hs = hs
|
||||
self.handlers = hs.get_handlers()
|
||||
self.builder_factory = hs.get_event_builder_factory()
|
||||
self.auth = hs.get_auth()
|
||||
self.auth = hs.get_v1auth()
|
||||
self.txns = HttpTransactionStore()
|
||||
|
||||
@@ -27,7 +27,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
user, _ = yield self.auth.get_user_by_req(request)
|
||||
user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = _parse_json(request)
|
||||
|
||||
@@ -37,7 +37,7 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||
and 'kind' in content and
|
||||
content['kind'] is None):
|
||||
yield pusher_pool.remove_pusher(
|
||||
content['app_id'], content['pushkey']
|
||||
content['app_id'], content['pushkey'], user_name=user.to_string()
|
||||
)
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
@@ -51,9 +51,21 @@ class PusherRestServlet(ClientV1RestServlet):
|
||||
raise SynapseError(400, "Missing parameters: "+','.join(missing),
|
||||
errcode=Codes.MISSING_PARAM)
|
||||
|
||||
append = False
|
||||
if 'append' in content:
|
||||
append = content['append']
|
||||
|
||||
if not append:
|
||||
yield pusher_pool.remove_pushers_by_app_id_and_pushkey_not_user(
|
||||
app_id=content['app_id'],
|
||||
pushkey=content['pushkey'],
|
||||
not_user_id=user.to_string()
|
||||
)
|
||||
|
||||
try:
|
||||
yield pusher_pool.add_pusher(
|
||||
user_name=user.to_string(),
|
||||
access_token=client.token_id,
|
||||
profile_tag=content['profile_tag'],
|
||||
kind=content['kind'],
|
||||
app_id=content['app_id'],
|
||||
|
||||
@@ -15,7 +15,10 @@
|
||||
|
||||
from . import (
|
||||
sync,
|
||||
filter
|
||||
filter,
|
||||
account,
|
||||
register,
|
||||
auth
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
@@ -32,3 +35,6 @@ class ClientV2AlphaRestResource(JsonResource):
|
||||
def register_servlets(client_resource, hs):
|
||||
sync.register_servlets(hs, client_resource)
|
||||
filter.register_servlets(hs, client_resource)
|
||||
account.register_servlets(hs, client_resource)
|
||||
register.register_servlets(hs, client_resource)
|
||||
auth.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -17,9 +17,11 @@
|
||||
"""
|
||||
|
||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||
from synapse.api.errors import SynapseError
|
||||
import re
|
||||
|
||||
import logging
|
||||
import simplejson
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -36,3 +38,23 @@ def client_v2_pattern(path_regex):
|
||||
SRE_Pattern
|
||||
"""
|
||||
return re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)
|
||||
|
||||
|
||||
def parse_request_allow_empty(request):
|
||||
content = request.content.read()
|
||||
if content is None or content == '':
|
||||
return None
|
||||
try:
|
||||
return simplejson.loads(content)
|
||||
except simplejson.JSONDecodeError:
|
||||
raise SynapseError(400, "Content not JSON.")
|
||||
|
||||
|
||||
def parse_json_dict_from_request(request):
|
||||
try:
|
||||
content = simplejson.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise SynapseError(400, "Content must be a JSON object.")
|
||||
return content
|
||||
except simplejson.JSONDecodeError:
|
||||
raise SynapseError(400, "Content not JSON.")
|
||||
|
||||
159
synapse/rest/client/v2_alpha/account.py
Normal file
159
synapse/rest/client/v2_alpha/account.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import LoginError, SynapseError, Codes
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
from ._base import client_v2_pattern, parse_json_dict_from_request
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PasswordRestServlet(RestServlet):
|
||||
PATTERN = client_v2_pattern("/account/password")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PasswordRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.login_handler = hs.get_handlers().login_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
body = parse_json_dict_from_request(request)
|
||||
|
||||
authed, result, params = yield self.auth_handler.check_auth([
|
||||
[LoginType.PASSWORD],
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
], body)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
user_id = None
|
||||
|
||||
if LoginType.PASSWORD in result:
|
||||
# if using password, they should also be logged in
|
||||
auth_user, client = yield self.auth.get_user_by_req(request)
|
||||
if auth_user.to_string() != result[LoginType.PASSWORD]:
|
||||
raise LoginError(400, "", Codes.UNKNOWN)
|
||||
user_id = auth_user.to_string()
|
||||
elif LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
if 'medium' not in threepid or 'address' not in threepid:
|
||||
raise SynapseError(500, "Malformed threepid")
|
||||
# if using email, we must know about the email they're authing with!
|
||||
threepid_user = yield self.hs.get_datastore().get_user_by_threepid(
|
||||
threepid['medium'], threepid['address']
|
||||
)
|
||||
if not threepid_user:
|
||||
raise SynapseError(404, "Email address not found", Codes.NOT_FOUND)
|
||||
user_id = threepid_user
|
||||
else:
|
||||
logger.error("Auth succeeded but no known type!", result.keys())
|
||||
raise SynapseError(500, "", Codes.UNKNOWN)
|
||||
|
||||
if 'new_password' not in params:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
new_password = params['new_password']
|
||||
|
||||
yield self.login_handler.set_password(
|
||||
user_id, new_password, None
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
|
||||
class ThreepidRestServlet(RestServlet):
|
||||
PATTERN = client_v2_pattern("/account/3pid")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.login_handler = hs.get_handlers().login_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
auth_user, _ = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepids = yield self.hs.get_datastore().user_get_threepids(
|
||||
auth_user.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {'threepids': threepids}))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
body = parse_json_dict_from_request(request)
|
||||
|
||||
if 'threePidCreds' not in body:
|
||||
raise SynapseError(400, "Missing param", Codes.MISSING_PARAM)
|
||||
threePidCreds = body['threePidCreds']
|
||||
|
||||
auth_user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
threepid = yield self.identity_handler.threepid_from_creds(threePidCreds)
|
||||
|
||||
if not threepid:
|
||||
raise SynapseError(
|
||||
400, "Failed to auth 3pid", Codes.THREEPID_AUTH_FAILED
|
||||
)
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
if reqd not in threepid:
|
||||
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
|
||||
raise SynapseError(500, "Invalid response from ID Server")
|
||||
|
||||
yield self.login_handler.add_threepid(
|
||||
auth_user.to_string(),
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
if 'bind' in body and body['bind']:
|
||||
logger.debug(
|
||||
"Binding emails %s to %s",
|
||||
threepid, auth_user.to_string()
|
||||
)
|
||||
yield self.identity_handler.bind_threepid(
|
||||
threePidCreds, auth_user.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PasswordRestServlet(hs).register(http_server)
|
||||
ThreepidRestServlet(hs).register(http_server)
|
||||
190
synapse/rest/client/v2_alpha/auth.py
Normal file
190
synapse/rest/client/v2_alpha/auth.py
Normal file
@@ -0,0 +1,190 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.urls import CLIENT_V2_ALPHA_PREFIX
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_pattern
|
||||
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RECAPTCHA_TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Authentication</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<script src="https://www.google.com/recaptcha/api.js"
|
||||
async defer></script>
|
||||
<script src="//code.jquery.com/jquery-1.11.2.min.js"></script>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
<script>
|
||||
function captchaDone() {
|
||||
$('#registrationForm').submit();
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<form id="registrationForm" method="post" action="%(myurl)s">
|
||||
<div>
|
||||
<p>
|
||||
Hello! We need to prevent computer programs and other automated
|
||||
things from creating accounts on this server.
|
||||
</p>
|
||||
<p>
|
||||
Please verify that you're not a robot.
|
||||
</p>
|
||||
<input type="hidden" name="session" value="%(session)s" />
|
||||
<div class="g-recaptcha"
|
||||
data-sitekey="%(sitekey)s"
|
||||
data-callback="captchaDone">
|
||||
</div>
|
||||
<noscript>
|
||||
<input type="submit" value="All Done" />
|
||||
</noscript>
|
||||
</div>
|
||||
</div>
|
||||
</form>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
SUCCESS_TEMPLATE = """
|
||||
<html>
|
||||
<head>
|
||||
<title>Success!</title>
|
||||
<meta name='viewport' content='width=device-width, initial-scale=1,
|
||||
user-scalable=no, minimum-scale=1.0, maximum-scale=1.0'>
|
||||
<link rel="stylesheet" href="/_matrix/static/client/register/style.css">
|
||||
<script>
|
||||
if (window.onAuthDone != undefined) {
|
||||
window.onAuthDone();
|
||||
}
|
||||
</script>
|
||||
</head>
|
||||
<body>
|
||||
<div>
|
||||
<p>Thank you</p>
|
||||
<p>You may now close this window and return to the application</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
|
||||
class AuthRestServlet(RestServlet):
|
||||
"""
|
||||
Handles Client / Server API authentication in any situations where it
|
||||
cannot be handled in the normal flow (with requests to the same endpoint).
|
||||
Current use is for web fallback auth.
|
||||
"""
|
||||
PATTERN = client_v2_pattern("/auth/(?P<stagetype>[\w\.]*)/fallback/web")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(AuthRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, stagetype):
|
||||
yield
|
||||
if stagetype == LoginType.RECAPTCHA:
|
||||
if ('session' not in request.args or
|
||||
len(request.args['session']) == 0):
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
||||
session = request.args["session"][0]
|
||||
|
||||
html = RECAPTCHA_TEMPLATE % {
|
||||
'session': session,
|
||||
'myurl': "%s/auth/%s/fallback/web" % (
|
||||
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
|
||||
),
|
||||
'sitekey': self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Server", self.hs.version_string)
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
request.finish()
|
||||
defer.returnValue(None)
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, stagetype):
|
||||
yield
|
||||
if stagetype == "m.login.recaptcha":
|
||||
if ('g-recaptcha-response' not in request.args or
|
||||
len(request.args['g-recaptcha-response'])) == 0:
|
||||
raise SynapseError(400, "No captcha response supplied")
|
||||
if ('session' not in request.args or
|
||||
len(request.args['session'])) == 0:
|
||||
raise SynapseError(400, "No session supplied")
|
||||
|
||||
session = request.args['session'][0]
|
||||
|
||||
authdict = {
|
||||
'response': request.args['g-recaptcha-response'][0],
|
||||
'session': session,
|
||||
}
|
||||
|
||||
success = yield self.auth_handler.add_oob_auth(
|
||||
LoginType.RECAPTCHA,
|
||||
authdict,
|
||||
self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if success:
|
||||
html = SUCCESS_TEMPLATE
|
||||
else:
|
||||
html = RECAPTCHA_TEMPLATE % {
|
||||
'session': session,
|
||||
'myurl': "%s/auth/%s/fallback/web" % (
|
||||
CLIENT_V2_ALPHA_PREFIX, LoginType.RECAPTCHA
|
||||
),
|
||||
'sitekey': self.hs.config.recaptcha_public_key,
|
||||
}
|
||||
html_bytes = html.encode("utf8")
|
||||
request.setResponseCode(200)
|
||||
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
|
||||
request.setHeader(b"Server", self.hs.version_string)
|
||||
request.setHeader(b"Content-Length", b"%d" % (len(html_bytes),))
|
||||
|
||||
request.write(html_bytes)
|
||||
request.finish()
|
||||
|
||||
defer.returnValue(None)
|
||||
else:
|
||||
raise SynapseError(404, "Unknown auth stage type")
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
AuthRestServlet(hs).register(http_server)
|
||||
183
synapse/rest/client/v2_alpha/register.py
Normal file
183
synapse/rest/client/v2_alpha/register.py
Normal file
@@ -0,0 +1,183 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_pattern, parse_request_allow_empty
|
||||
|
||||
import logging
|
||||
import hmac
|
||||
from hashlib import sha1
|
||||
from synapse.util.async import run_on_reactor
|
||||
|
||||
|
||||
# We ought to be using hmac.compare_digest() but on older pythons it doesn't
|
||||
# exist. It's a _really minor_ security flaw to use plain string comparison
|
||||
# because the timing attack is so obscured by all the other code here it's
|
||||
# unlikely to make much difference
|
||||
if hasattr(hmac, "compare_digest"):
|
||||
compare_digest = hmac.compare_digest
|
||||
else:
|
||||
compare_digest = lambda a, b: a == b
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RegisterRestServlet(RestServlet):
|
||||
PATTERN = client_v2_pattern("/register")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RegisterRestServlet, self).__init__()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
self.auth_handler = hs.get_handlers().auth_handler
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.login_handler = hs.get_handlers().login_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield run_on_reactor()
|
||||
|
||||
body = parse_request_allow_empty(request)
|
||||
if 'password' not in body:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
|
||||
if 'username' in body:
|
||||
desired_username = body['username']
|
||||
yield self.registration_handler.check_username(desired_username)
|
||||
|
||||
is_using_shared_secret = False
|
||||
is_application_server = False
|
||||
|
||||
service = None
|
||||
if 'access_token' in request.args:
|
||||
service = yield self.auth.get_appservice_by_req(request)
|
||||
|
||||
if self.hs.config.enable_registration_captcha:
|
||||
flows = [
|
||||
[LoginType.RECAPTCHA],
|
||||
[LoginType.EMAIL_IDENTITY, LoginType.RECAPTCHA]
|
||||
]
|
||||
else:
|
||||
flows = [
|
||||
[LoginType.DUMMY],
|
||||
[LoginType.EMAIL_IDENTITY]
|
||||
]
|
||||
|
||||
if service:
|
||||
is_application_server = True
|
||||
elif 'mac' in body:
|
||||
# Check registration-specific shared secret auth
|
||||
if 'username' not in body:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
self._check_shared_secret_auth(
|
||||
body['username'], body['mac']
|
||||
)
|
||||
is_using_shared_secret = True
|
||||
else:
|
||||
authed, result, params = yield self.auth_handler.check_auth(
|
||||
flows, body, self.hs.get_ip_from_request(request)
|
||||
)
|
||||
|
||||
if not authed:
|
||||
defer.returnValue((401, result))
|
||||
|
||||
can_register = (
|
||||
not self.hs.config.disable_registration
|
||||
or is_application_server
|
||||
or is_using_shared_secret
|
||||
)
|
||||
if not can_register:
|
||||
raise SynapseError(403, "Registration has been disabled")
|
||||
|
||||
if 'password' not in params:
|
||||
raise SynapseError(400, "", Codes.MISSING_PARAM)
|
||||
desired_username = params['username'] if 'username' in params else None
|
||||
new_password = params['password']
|
||||
|
||||
(user_id, token) = yield self.registration_handler.register(
|
||||
localpart=desired_username,
|
||||
password=new_password
|
||||
)
|
||||
|
||||
if LoginType.EMAIL_IDENTITY in result:
|
||||
threepid = result[LoginType.EMAIL_IDENTITY]
|
||||
|
||||
for reqd in ['medium', 'address', 'validated_at']:
|
||||
if reqd not in threepid:
|
||||
logger.info("Can't add incomplete 3pid")
|
||||
else:
|
||||
yield self.login_handler.add_threepid(
|
||||
user_id,
|
||||
threepid['medium'],
|
||||
threepid['address'],
|
||||
threepid['validated_at'],
|
||||
)
|
||||
|
||||
if 'bind_email' in params and params['bind_email']:
|
||||
logger.info("bind_email specified: binding")
|
||||
|
||||
emailThreepid = result[LoginType.EMAIL_IDENTITY]
|
||||
threepid_creds = emailThreepid['threepid_creds']
|
||||
logger.debug("Binding emails %s to %s" % (
|
||||
emailThreepid, user_id
|
||||
))
|
||||
yield self.identity_handler.bind_threepid(threepid_creds, user_id)
|
||||
else:
|
||||
logger.info("bind_email not specified: not binding email")
|
||||
|
||||
result = {
|
||||
"user_id": user_id,
|
||||
"access_token": token,
|
||||
"home_server": self.hs.hostname,
|
||||
}
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
def _check_shared_secret_auth(self, username, mac):
|
||||
if not self.hs.config.registration_shared_secret:
|
||||
raise SynapseError(400, "Shared secret registration is not enabled")
|
||||
|
||||
user = username.encode("utf-8")
|
||||
|
||||
# str() because otherwise hmac complains that 'unicode' does not
|
||||
# have the buffer interface
|
||||
got_mac = str(mac)
|
||||
|
||||
want_mac = hmac.new(
|
||||
key=self.hs.config.registration_shared_secret,
|
||||
msg=user,
|
||||
digestmod=sha1,
|
||||
).hexdigest()
|
||||
|
||||
if compare_digest(want_mac, got_mac):
|
||||
return True
|
||||
else:
|
||||
raise SynapseError(
|
||||
403, "HMAC incorrect",
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
RegisterRestServlet(hs).register(http_server)
|
||||
@@ -15,7 +15,9 @@
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_string, parse_integer, parse_boolean
|
||||
)
|
||||
from synapse.handlers.sync import SyncConfig
|
||||
from synapse.types import StreamToken
|
||||
from synapse.events.utils import (
|
||||
@@ -87,20 +89,20 @@ class SyncRestServlet(RestServlet):
|
||||
def on_GET(self, request):
|
||||
user, client = yield self.auth.get_user_by_req(request)
|
||||
|
||||
timeout = self.parse_integer(request, "timeout", default=0)
|
||||
limit = self.parse_integer(request, "limit", required=True)
|
||||
gap = self.parse_boolean(request, "gap", default=True)
|
||||
sort = self.parse_string(
|
||||
timeout = parse_integer(request, "timeout", default=0)
|
||||
limit = parse_integer(request, "limit", required=True)
|
||||
gap = parse_boolean(request, "gap", default=True)
|
||||
sort = parse_string(
|
||||
request, "sort", default="timeline,asc",
|
||||
allowed_values=self.ALLOWED_SORT
|
||||
)
|
||||
since = self.parse_string(request, "since")
|
||||
set_presence = self.parse_string(
|
||||
since = parse_string(request, "since")
|
||||
set_presence = parse_string(
|
||||
request, "set_presence", default="online",
|
||||
allowed_values=self.ALLOWED_PRESENCE
|
||||
)
|
||||
backfill = self.parse_boolean(request, "backfill", default=False)
|
||||
filter_id = self.parse_string(request, "filter", default=None)
|
||||
backfill = parse_boolean(request, "backfill", default=False)
|
||||
filter_id = parse_string(request, "filter", default=None)
|
||||
|
||||
logger.info(
|
||||
"/sync: user=%r, timeout=%r, limit=%r, gap=%r, sort=%r, since=%r,"
|
||||
|
||||
14
synapse/rest/key/v1/__init__.py
Normal file
14
synapse/rest/key/v1/__init__.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -12,18 +12,14 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from . import register
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
from twisted.web.resource import Resource
|
||||
from .local_key_resource import LocalKey
|
||||
from .remote_key_resource import RemoteKey
|
||||
|
||||
|
||||
class AppServiceRestResource(JsonResource):
|
||||
"""A resource for version 1 of the matrix application service API."""
|
||||
|
||||
class KeyApiV2Resource(Resource):
|
||||
def __init__(self, hs):
|
||||
JsonResource.__init__(self, hs)
|
||||
self.register_servlets(self, hs)
|
||||
|
||||
@staticmethod
|
||||
def register_servlets(appservice_resource, hs):
|
||||
register.register_servlets(hs, appservice_resource)
|
||||
Resource.__init__(self)
|
||||
self.putChild("server", LocalKey(hs))
|
||||
self.putChild("query", RemoteKey(hs))
|
||||
125
synapse/rest/key/v2/local_key_resource.py
Normal file
125
synapse/rest/key/v2/local_key_resource.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from synapse.http.server import respond_with_json_bytes
|
||||
from syutil.crypto.jsonsign import sign_json
|
||||
from syutil.base64util import encode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
from hashlib import sha256
|
||||
from OpenSSL import crypto
|
||||
import logging
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LocalKey(Resource):
|
||||
"""HTTP resource containing encoding the TLS X.509 certificate and NACL
|
||||
signature verification keys for this server::
|
||||
|
||||
GET /_matrix/key/v2/server/a.key.id HTTP/1.1
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
{
|
||||
"valid_until_ts": # integer posix timestamp when this result expires.
|
||||
"server_name": "this.server.example.com"
|
||||
"verify_keys": {
|
||||
"algorithm:version": {
|
||||
"key": # base64 encoded NACL verification key.
|
||||
}
|
||||
},
|
||||
"old_verify_keys": {
|
||||
"algorithm:version": {
|
||||
"expired_ts": # integer posix timestamp when the key expired.
|
||||
"key": # base64 encoded NACL verification key.
|
||||
}
|
||||
}
|
||||
"tls_certificate": # base64 ASN.1 DER encoded X.509 tls cert.
|
||||
"signatures": {
|
||||
"this.server.example.com": {
|
||||
"algorithm:version": # NACL signature for this server
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
self.version_string = hs.version_string
|
||||
self.config = hs.config
|
||||
self.clock = hs.clock
|
||||
self.update_response_body(self.clock.time_msec())
|
||||
Resource.__init__(self)
|
||||
|
||||
def update_response_body(self, time_now_msec):
|
||||
refresh_interval = self.config.key_refresh_interval
|
||||
self.valid_until_ts = int(time_now_msec + refresh_interval)
|
||||
self.response_body = encode_canonical_json(self.response_json_object())
|
||||
|
||||
def response_json_object(self):
|
||||
verify_keys = {}
|
||||
for key in self.config.signing_key:
|
||||
verify_key_bytes = key.verify_key.encode()
|
||||
key_id = "%s:%s" % (key.alg, key.version)
|
||||
verify_keys[key_id] = {
|
||||
u"key": encode_base64(verify_key_bytes)
|
||||
}
|
||||
|
||||
old_verify_keys = {}
|
||||
for key in self.config.old_signing_keys:
|
||||
key_id = "%s:%s" % (key.alg, key.version)
|
||||
verify_key_bytes = key.encode()
|
||||
old_verify_keys[key_id] = {
|
||||
u"key": encode_base64(verify_key_bytes),
|
||||
u"expired_ts": key.expired,
|
||||
}
|
||||
|
||||
x509_certificate_bytes = crypto.dump_certificate(
|
||||
crypto.FILETYPE_ASN1,
|
||||
self.config.tls_certificate
|
||||
)
|
||||
|
||||
sha256_fingerprint = sha256(x509_certificate_bytes).digest()
|
||||
|
||||
json_object = {
|
||||
u"valid_until_ts": self.valid_until_ts,
|
||||
u"server_name": self.config.server_name,
|
||||
u"verify_keys": verify_keys,
|
||||
u"old_verify_keys": old_verify_keys,
|
||||
u"tls_fingerprints": [{
|
||||
u"sha256": encode_base64(sha256_fingerprint),
|
||||
}]
|
||||
}
|
||||
for key in self.config.signing_key:
|
||||
json_object = sign_json(
|
||||
json_object,
|
||||
self.config.server_name,
|
||||
key,
|
||||
)
|
||||
return json_object
|
||||
|
||||
def render_GET(self, request):
|
||||
time_now = self.clock.time_msec()
|
||||
# Update the expiry time if less than half the interval remains.
|
||||
if time_now + self.config.key_refresh_interval / 2 > self.valid_until_ts:
|
||||
self.update_response_body(time_now)
|
||||
return respond_with_json_bytes(
|
||||
request, 200, self.response_body,
|
||||
version_string=self.version_string
|
||||
)
|
||||
242
synapse/rest/key/v2/remote_key_resource.py
Normal file
242
synapse/rest/key/v2/remote_key_resource.py
Normal file
@@ -0,0 +1,242 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import request_handler, respond_with_json_bytes
|
||||
from synapse.http.servlet import parse_integer
|
||||
from synapse.api.errors import SynapseError, Codes
|
||||
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
|
||||
|
||||
from io import BytesIO
|
||||
import json
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RemoteKey(Resource):
|
||||
"""HTTP resource for retreiving the TLS certificate and NACL signature
|
||||
verification keys for a collection of servers. Checks that the reported
|
||||
X.509 TLS certificate matches the one used in the HTTPS connection. Checks
|
||||
that the NACL signature for the remote server is valid. Returns a dict of
|
||||
JSON signed by both the remote server and by this server.
|
||||
|
||||
Supports individual GET APIs and a bulk query POST API.
|
||||
|
||||
Requsts:
|
||||
|
||||
GET /_matrix/key/v2/query/remote.server.example.com HTTP/1.1
|
||||
|
||||
GET /_matrix/key/v2/query/remote.server.example.com/a.key.id HTTP/1.1
|
||||
|
||||
POST /_matrix/v2/query HTTP/1.1
|
||||
Content-Type: application/json
|
||||
{
|
||||
"server_keys": {
|
||||
"remote.server.example.com": {
|
||||
"a.key.id": {
|
||||
"minimum_valid_until_ts": 1234567890123
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Response:
|
||||
|
||||
HTTP/1.1 200 OK
|
||||
Content-Type: application/json
|
||||
{
|
||||
"server_keys": [
|
||||
{
|
||||
"server_name": "remote.server.example.com"
|
||||
"valid_until_ts": # posix timestamp
|
||||
"verify_keys": {
|
||||
"a.key.id": { # The identifier for a key.
|
||||
key: "" # base64 encoded verification key.
|
||||
}
|
||||
}
|
||||
"old_verify_keys": {
|
||||
"an.old.key.id": { # The identifier for an old key.
|
||||
key: "", # base64 encoded key
|
||||
"expired_ts": 0, # when the key stop being used.
|
||||
}
|
||||
}
|
||||
"tls_fingerprints": [
|
||||
{ "sha256": # fingerprint }
|
||||
]
|
||||
"signatures": {
|
||||
"remote.server.example.com": {...}
|
||||
"this.server.example.com": {...}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
"""
|
||||
|
||||
isLeaf = True
|
||||
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
self.store = hs.get_datastore()
|
||||
self.version_string = hs.version_string
|
||||
self.clock = hs.get_clock()
|
||||
|
||||
def render_GET(self, request):
|
||||
self.async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def async_render_GET(self, request):
|
||||
if len(request.postpath) == 1:
|
||||
server, = request.postpath
|
||||
query = {server: {}}
|
||||
elif len(request.postpath) == 2:
|
||||
server, key_id = request.postpath
|
||||
minimum_valid_until_ts = parse_integer(
|
||||
request, "minimum_valid_until_ts"
|
||||
)
|
||||
arguments = {}
|
||||
if minimum_valid_until_ts is not None:
|
||||
arguments["minimum_valid_until_ts"] = minimum_valid_until_ts
|
||||
query = {server: {key_id: arguments}}
|
||||
else:
|
||||
raise SynapseError(
|
||||
404, "Not found %r" % request.postpath, Codes.NOT_FOUND
|
||||
)
|
||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
def render_POST(self, request):
|
||||
self.async_render_POST(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def async_render_POST(self, request):
|
||||
try:
|
||||
content = json.loads(request.content.read())
|
||||
if type(content) != dict:
|
||||
raise ValueError()
|
||||
except ValueError:
|
||||
raise SynapseError(
|
||||
400, "Content must be JSON object.", errcode=Codes.NOT_JSON
|
||||
)
|
||||
|
||||
query = content["server_keys"]
|
||||
|
||||
yield self.query_keys(request, query, query_remote_on_cache_miss=True)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def query_keys(self, request, query, query_remote_on_cache_miss=False):
|
||||
logger.info("Handling query for keys %r", query)
|
||||
store_queries = []
|
||||
for server_name, key_ids in query.items():
|
||||
if not key_ids:
|
||||
key_ids = (None,)
|
||||
for key_id in key_ids:
|
||||
store_queries.append((server_name, key_id, None))
|
||||
|
||||
cached = yield self.store.get_server_keys_json(store_queries)
|
||||
|
||||
json_results = set()
|
||||
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
cache_misses = dict()
|
||||
for (server_name, key_id, from_server), results in cached.items():
|
||||
results = [
|
||||
(result["ts_added_ms"], result) for result in results
|
||||
]
|
||||
|
||||
if not results and key_id is not None:
|
||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
||||
continue
|
||||
|
||||
if key_id is not None:
|
||||
ts_added_ms, most_recent_result = max(results)
|
||||
ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
|
||||
req_key = query.get(server_name, {}).get(key_id, {})
|
||||
req_valid_until = req_key.get("minimum_valid_until_ts")
|
||||
miss = False
|
||||
if req_valid_until is not None:
|
||||
if ts_valid_until_ms < req_valid_until:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is older than requested"
|
||||
": valid_until (%r) < minimum_valid_until (%r)",
|
||||
server_name, key_id,
|
||||
ts_valid_until_ms, req_valid_until
|
||||
)
|
||||
miss = True
|
||||
else:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is newer than requested"
|
||||
": valid_until (%r) >= minimum_valid_until (%r)",
|
||||
server_name, key_id,
|
||||
ts_valid_until_ms, req_valid_until
|
||||
)
|
||||
elif (ts_added_ms + ts_valid_until_ms) / 2 < time_now_ms:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is too old"
|
||||
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||
server_name, key_id,
|
||||
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||
)
|
||||
# We more than half way through the lifetime of the
|
||||
# response. We should fetch a fresh copy.
|
||||
miss = True
|
||||
else:
|
||||
logger.debug(
|
||||
"Cached response for %r/%r is still valid"
|
||||
": (added (%r) + valid_until (%r)) / 2 < now (%r)",
|
||||
server_name, key_id,
|
||||
ts_added_ms, ts_valid_until_ms, time_now_ms
|
||||
)
|
||||
|
||||
if miss:
|
||||
cache_misses.setdefault(server_name, set()).add(key_id)
|
||||
json_results.add(bytes(most_recent_result["key_json"]))
|
||||
else:
|
||||
for ts_added, result in results:
|
||||
json_results.add(bytes(result["key_json"]))
|
||||
|
||||
if cache_misses and query_remote_on_cache_miss:
|
||||
for server_name, key_ids in cache_misses.items():
|
||||
try:
|
||||
yield self.keyring.get_server_verify_key_v2_direct(
|
||||
server_name, key_ids
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to get key for %r", server_name)
|
||||
pass
|
||||
yield self.query_keys(
|
||||
request, query, query_remote_on_cache_miss=False
|
||||
)
|
||||
else:
|
||||
result_io = BytesIO()
|
||||
result_io.write(b"{\"server_keys\":")
|
||||
sep = b"["
|
||||
for json_bytes in json_results:
|
||||
result_io.write(sep)
|
||||
result_io.write(json_bytes)
|
||||
sep = b","
|
||||
if sep == b"[":
|
||||
result_io.write(sep)
|
||||
result_io.write(b"]}")
|
||||
|
||||
respond_with_json_bytes(
|
||||
request, 200, result_io.getvalue(),
|
||||
version_string=self.version_string
|
||||
)
|
||||
@@ -18,13 +18,15 @@ from .thumbnailer import Thumbnailer
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, CodeMessageException, cs_error, Codes, SynapseError
|
||||
cs_error, Codes, SynapseError
|
||||
)
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.web.resource import Resource
|
||||
from twisted.protocols.basic import FileSender
|
||||
|
||||
from synapse.util.async import create_observer
|
||||
|
||||
import os
|
||||
|
||||
import logging
|
||||
@@ -32,6 +34,18 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Invalid media id token %r" % (request.postpath,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
|
||||
class BaseMediaResource(Resource):
|
||||
isLeaf = True
|
||||
|
||||
@@ -45,74 +59,9 @@ class BaseMediaResource(Resource):
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = filepaths
|
||||
self.version_string = hs.version_string
|
||||
self.downloads = {}
|
||||
|
||||
@staticmethod
|
||||
def catch_errors(request_handler):
|
||||
@defer.inlineCallbacks
|
||||
def wrapped_request_handler(self, request):
|
||||
try:
|
||||
yield request_handler(self, request)
|
||||
except CodeMessageException as e:
|
||||
logger.info("Responding with error: %r", e)
|
||||
respond_with_json(
|
||||
request, e.code, cs_exception(e), send_cors=True
|
||||
)
|
||||
except:
|
||||
logger.exception(
|
||||
"Failed handle request %s.%s on %r",
|
||||
request_handler.__module__,
|
||||
request_handler.__name__,
|
||||
self,
|
||||
)
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
)
|
||||
return wrapped_request_handler
|
||||
|
||||
@staticmethod
|
||||
def _parse_media_id(request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
return (server_name, media_id)
|
||||
except:
|
||||
raise SynapseError(
|
||||
404,
|
||||
"Invalid media id token %r" % (request.postpath,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_integer(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return int(request.args[arg_name][0])
|
||||
else:
|
||||
return int(request.args.get(arg_name, [default])[0])
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing integer argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _parse_string(request, arg_name, default=None):
|
||||
try:
|
||||
if default is None:
|
||||
return request.args[arg_name][0]
|
||||
else:
|
||||
return request.args.get(arg_name, [default])[0]
|
||||
except:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Missing string argument %r" % (arg_name,),
|
||||
Codes.UNKNOWN,
|
||||
)
|
||||
|
||||
def _respond_404(self, request):
|
||||
respond_with_json(
|
||||
request, 404,
|
||||
@@ -140,7 +89,7 @@ class BaseMediaResource(Resource):
|
||||
def callback(media_info):
|
||||
del self.downloads[key]
|
||||
return media_info
|
||||
return download
|
||||
return create_observer(download)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_remote_media_impl(self, server_name, media_id):
|
||||
|
||||
@@ -13,7 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
from .base_resource import BaseMediaResource, parse_media_id
|
||||
from synapse.http.server import request_handler
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
@@ -28,15 +29,10 @@ class DownloadResource(BaseMediaResource):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
try:
|
||||
server_name, media_id = request.postpath
|
||||
except:
|
||||
self._respond_404(request)
|
||||
return
|
||||
|
||||
server_name, media_id = parse_media_id(request)
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_file(request, media_id)
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydenticon import Generator
|
||||
from twisted.web.resource import Resource
|
||||
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from .base_resource import BaseMediaResource
|
||||
from .base_resource import BaseMediaResource, parse_media_id
|
||||
from synapse.http.servlet import parse_string, parse_integer
|
||||
from synapse.http.server import request_handler
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
@@ -31,14 +33,14 @@ class ThumbnailResource(BaseMediaResource):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
|
||||
@BaseMediaResource.catch_errors
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_GET(self, request):
|
||||
server_name, media_id = self._parse_media_id(request)
|
||||
width = self._parse_integer(request, "width")
|
||||
height = self._parse_integer(request, "height")
|
||||
method = self._parse_string(request, "method", "scale")
|
||||
m_type = self._parse_string(request, "type", "image/png")
|
||||
server_name, media_id = parse_media_id(request)
|
||||
width = parse_integer(request, "width")
|
||||
height = parse_integer(request, "height")
|
||||
method = parse_string(request, "method", "scale")
|
||||
m_type = parse_string(request, "type", "image/png")
|
||||
|
||||
if server_name == self.server_name:
|
||||
yield self._respond_local_thumbnail(
|
||||
|
||||
@@ -13,12 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.http.server import respond_with_json
|
||||
from synapse.http.server import respond_with_json, request_handler
|
||||
|
||||
from synapse.util.stringutils import random_string
|
||||
from synapse.api.errors import (
|
||||
cs_exception, SynapseError, CodeMessageException
|
||||
)
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
from twisted.web.server import NOT_DONE_YET
|
||||
from twisted.internet import defer
|
||||
@@ -69,53 +67,42 @@ class UploadResource(BaseMediaResource):
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@request_handler
|
||||
@defer.inlineCallbacks
|
||||
def _async_render_POST(self, request):
|
||||
try:
|
||||
auth_user, client = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader("Content-Type"):
|
||||
media_type = headers.getRawHeaders("Content-Type")[0]
|
||||
else:
|
||||
raise SynapseError(
|
||||
msg="Upload request missing 'Content-Type'",
|
||||
code=400,
|
||||
)
|
||||
|
||||
# if headers.hasHeader("Content-Disposition"):
|
||||
# disposition = headers.getRawHeaders("Content-Disposition")[0]
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.create_content(
|
||||
media_type, None, request.content.read(),
|
||||
content_length, auth_user
|
||||
auth_user, client = yield self.auth.get_user_by_req(request)
|
||||
# TODO: The checks here are a bit late. The content will have
|
||||
# already been uploaded to a tmp file at this point
|
||||
content_length = request.getHeader("Content-Length")
|
||||
if content_length is None:
|
||||
raise SynapseError(
|
||||
msg="Request must specify a Content-Length", code=400
|
||||
)
|
||||
if int(content_length) > self.max_upload_size:
|
||||
raise SynapseError(
|
||||
msg="Upload request body is too large",
|
||||
code=413,
|
||||
)
|
||||
|
||||
respond_with_json(
|
||||
request, 200, {"content_uri": content_uri}, send_cors=True
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
logger.exception(e)
|
||||
respond_with_json(request, e.code, cs_exception(e), send_cors=True)
|
||||
except:
|
||||
logger.exception("Failed to store file")
|
||||
respond_with_json(
|
||||
request,
|
||||
500,
|
||||
{"error": "Internal server error"},
|
||||
send_cors=True
|
||||
headers = request.requestHeaders
|
||||
|
||||
if headers.hasHeader("Content-Type"):
|
||||
media_type = headers.getRawHeaders("Content-Type")[0]
|
||||
else:
|
||||
raise SynapseError(
|
||||
msg="Upload request missing 'Content-Type'",
|
||||
code=400,
|
||||
)
|
||||
|
||||
# if headers.hasHeader("Content-Disposition"):
|
||||
# disposition = headers.getRawHeaders("Content-Disposition")[0]
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.create_content(
|
||||
media_type, None, request.content.read(),
|
||||
content_length, auth_user
|
||||
)
|
||||
|
||||
respond_with_json(
|
||||
request, 200, {"content_uri": content_uri}, send_cors=True
|
||||
)
|
||||
|
||||
@@ -65,6 +65,7 @@ class BaseHomeServer(object):
|
||||
'replication_layer',
|
||||
'datastore',
|
||||
'handlers',
|
||||
'v1auth',
|
||||
'auth',
|
||||
'rest_servlet_factory',
|
||||
'state_handler',
|
||||
@@ -78,8 +79,8 @@ class BaseHomeServer(object):
|
||||
'resource_for_web_client',
|
||||
'resource_for_content_repo',
|
||||
'resource_for_server_key',
|
||||
'resource_for_server_key_v2',
|
||||
'resource_for_media_repository',
|
||||
'resource_for_app_services',
|
||||
'resource_for_metrics',
|
||||
'event_sources',
|
||||
'ratelimiter',
|
||||
@@ -182,6 +183,15 @@ class HomeServer(BaseHomeServer):
|
||||
def build_auth(self):
|
||||
return Auth(self)
|
||||
|
||||
def build_v1auth(self):
|
||||
orf = Auth(self)
|
||||
# Matrix spec makes no reference to what HTTP status code is returned,
|
||||
# but the V1 API uses 403 where it means 401, and the webclient
|
||||
# relies on this behaviour, so V1 gets its own copy of the auth
|
||||
# with backwards compat behaviour.
|
||||
orf.TOKEN_NOT_FOUND_HTTP_STATUS = 403
|
||||
return orf
|
||||
|
||||
def build_state_handler(self):
|
||||
return StateHandler(self)
|
||||
|
||||
|
||||
@@ -86,12 +86,7 @@ class StateHandler(object):
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
"""
|
||||
events = yield self.store.get_latest_events_in_room(room_id)
|
||||
|
||||
event_ids = [
|
||||
e_id
|
||||
for e_id, _, _ in events
|
||||
]
|
||||
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
|
||||
cache = None
|
||||
if self._state_cache is not None:
|
||||
|
||||
@@ -14,13 +14,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
from .appservice import ApplicationServiceStore
|
||||
from .appservice import (
|
||||
ApplicationServiceStore, ApplicationServiceTransactionStore
|
||||
)
|
||||
from ._base import Cache
|
||||
from .directory import DirectoryStore
|
||||
from .feedback import FeedbackStore
|
||||
from .events import EventsStore
|
||||
from .presence import PresenceStore
|
||||
from .profile import ProfileStore
|
||||
from .registration import RegistrationStore
|
||||
@@ -39,11 +38,6 @@ from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
|
||||
import fnmatch
|
||||
import imp
|
||||
@@ -57,20 +51,18 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 14
|
||||
SCHEMA_VERSION = 17
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
pass
|
||||
# Number of msec of granularity to store the user IP 'last seen' time. Smaller
|
||||
# times give more inserts into the database even for readonly API hits
|
||||
# 120 seconds == 2 minutes
|
||||
LAST_SEEN_GRANULARITY = 120*1000
|
||||
|
||||
|
||||
class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore, FeedbackStore,
|
||||
RegistrationStore, StreamStore, ProfileStore,
|
||||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
@@ -79,7 +71,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
RejectionsStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
EventsStore,
|
||||
):
|
||||
|
||||
def __init__(self, hs):
|
||||
@@ -89,474 +83,53 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self.min_token_deferred = self._get_min_token()
|
||||
self.min_token = None
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
event_id (str): The event_id of the event to fetch
|
||||
check_redacted (bool): If True, check if event has been redacted
|
||||
and redact it.
|
||||
get_prev_content (bool): If True and event is a state event,
|
||||
include the previous states content in the unsigned field.
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
allow_none (bool): If True, return None if no event found, if
|
||||
False throw an exception.
|
||||
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
"""
|
||||
event = yield self.runInteraction(
|
||||
"get_event", self._get_event_txn,
|
||||
event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not event and not allow_none:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
self._get_event_cache.pop(event.event_id)
|
||||
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
txn.execute(
|
||||
"DELETE FROM current_state_events WHERE room_id = ?",
|
||||
(event.room_id,)
|
||||
)
|
||||
|
||||
for s in current_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
if event.is_state() and is_new_state:
|
||||
if not backfilled and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for prev_state_id, _ in event.prev_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": prev_state_id,
|
||||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
event_id=event.event_id,
|
||||
prev_events=event.prev_events,
|
||||
room_id=event.room_id,
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
)
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
# more processing.
|
||||
# The processing above must be done on every call to persist event,
|
||||
# since they might not have happened on previous calls. For example,
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier:
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(metadata_json.decode("UTF-8"), event.event_id,)
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE events SET outlier = 0"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(event.event_id,)
|
||||
)
|
||||
return
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == EventTypes.Feedback:
|
||||
self._store_feedback_txn(txn, event)
|
||||
elif event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
event_dict = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json.decode("UTF-8"),
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
content = encode_canonical_json(
|
||||
event.content
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": content,
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
}
|
||||
|
||||
if stream_ordering is not None:
|
||||
vals["stream_ordering"] = stream_ordering
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
"signatures",
|
||||
"hashes",
|
||||
"prev_events",
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_canonical_json(
|
||||
unrec
|
||||
).decode("UTF-8")
|
||||
|
||||
try:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"events",
|
||||
vals,
|
||||
or_replace=(not outlier),
|
||||
or_ignore=bool(outlier),
|
||||
)
|
||||
except:
|
||||
logger.warn(
|
||||
"Failed to persist, probably duplicate: %s",
|
||||
event.event_id,
|
||||
exc_info=True,
|
||||
)
|
||||
raise _RollbackButIsFineException("_persist_event")
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
|
||||
if event.is_state():
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
}
|
||||
|
||||
# TODO: How does this work with backfilling?
|
||||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"state_events",
|
||||
vals,
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": 1,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_event_content_hash_txn(
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
)
|
||||
|
||||
for auth_id, _ in event.auth_events:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
self._get_event_cache.pop(event.redacts)
|
||||
txn.execute(
|
||||
"INSERT OR IGNORE INTO redactions "
|
||||
"(event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_state(self, room_id, event_type=None, state_key=""):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
if event_type and state_key is not None:
|
||||
sql += " AND s.type = ? AND s.state_key = ? "
|
||||
args = (room_id, event_type, state_key)
|
||||
elif event_type:
|
||||
sql += " AND s.type = ?"
|
||||
args = (room_id, event_type)
|
||||
else:
|
||||
args = (room_id, )
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
defer.returnValue(events)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_room_name_and_aliases(self, room_id):
|
||||
del_sql = (
|
||||
"SELECT event_id FROM redactions WHERE redacts = e.event_id "
|
||||
"LIMIT 1"
|
||||
)
|
||||
|
||||
sql = (
|
||||
"SELECT e.*, (%(redacted)s) AS redacted FROM events as e "
|
||||
"INNER JOIN current_state_events as c ON e.event_id = c.event_id "
|
||||
"INNER JOIN state_events as s ON e.event_id = s.event_id "
|
||||
"WHERE c.room_id = ? "
|
||||
) % {
|
||||
"redacted": del_sql,
|
||||
}
|
||||
|
||||
sql += " AND ((s.type = 'm.room.name' AND s.state_key = '')"
|
||||
sql += " OR s.type = 'm.room.aliases')"
|
||||
args = (room_id,)
|
||||
|
||||
results = yield self._execute_and_decode("get_current_state", sql, *args)
|
||||
|
||||
events = yield self._parse_events(results)
|
||||
|
||||
name = None
|
||||
aliases = []
|
||||
|
||||
for e in events:
|
||||
if e.type == 'm.room.name':
|
||||
if 'name' in e.content:
|
||||
name = e.content['name']
|
||||
elif e.type == 'm.room.aliases':
|
||||
if 'aliases' in e.content:
|
||||
aliases.extend(e.content['aliases'])
|
||||
|
||||
defer.returnValue((name, aliases))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_min_token(self):
|
||||
row = yield self._execute(
|
||||
"_get_min_token", None, "SELECT MIN(stream_ordering) FROM events"
|
||||
)
|
||||
|
||||
self.min_token = row[0][0] if row and row[0] and row[0][0] else -1
|
||||
self.min_token = min(self.min_token, -1)
|
||||
|
||||
logger.debug("min_token is: %s", self.min_token)
|
||||
|
||||
defer.returnValue(self.min_token)
|
||||
|
||||
def insert_client_ip(self, user, access_token, device_id, ip, user_agent):
|
||||
return self._simple_insert(
|
||||
now = int(self._clock.time_msec())
|
||||
key = (user.to_string(), access_token, device_id, ip)
|
||||
|
||||
try:
|
||||
last_seen = self.client_ip_last_seen.get(*key)
|
||||
except KeyError:
|
||||
last_seen = None
|
||||
|
||||
# Rate-limited inserts
|
||||
if last_seen is not None and (now - last_seen) < LAST_SEEN_GRANULARITY:
|
||||
defer.returnValue(None)
|
||||
|
||||
self.client_ip_last_seen.prefill(*key + (now,))
|
||||
|
||||
# It's safe not to lock here: a) no unique constraint,
|
||||
# b) LAST_SEEN_GRANULARITY makes concurrent updates incredibly unlikely
|
||||
yield self._simple_upsert(
|
||||
"user_ips",
|
||||
{
|
||||
"user": user.to_string(),
|
||||
keyvalues={
|
||||
"user_id": user.to_string(),
|
||||
"access_token": access_token,
|
||||
"device_id": device_id,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": int(self._clock.time_msec()),
|
||||
}
|
||||
},
|
||||
values={
|
||||
"device_id": device_id,
|
||||
"last_seen": now,
|
||||
},
|
||||
desc="insert_client_ip",
|
||||
lock=False,
|
||||
)
|
||||
|
||||
def get_user_ip_and_agents(self, user):
|
||||
return self._simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user": user.to_string()},
|
||||
keyvalues={"user_id": user.to_string()},
|
||||
retcols=[
|
||||
"device_id", "access_token", "ip", "user_agent", "last_seen"
|
||||
],
|
||||
)
|
||||
|
||||
def have_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Returns:
|
||||
dict: Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps to
|
||||
None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction(
|
||||
"have_events", f,
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
|
||||
|
||||
@@ -580,21 +153,23 @@ class UpgradeDatabaseException(PrepareDatabaseException):
|
||||
pass
|
||||
|
||||
|
||||
def prepare_database(db_conn):
|
||||
def prepare_database(db_conn, database_engine):
|
||||
"""Prepares a database for usage. Will either create all necessary tables
|
||||
or upgrade from an older schema version.
|
||||
"""
|
||||
try:
|
||||
cur = db_conn.cursor()
|
||||
version_info = _get_or_create_schema_state(cur)
|
||||
version_info = _get_or_create_schema_state(cur, database_engine)
|
||||
|
||||
if version_info:
|
||||
user_version, delta_files, upgraded = version_info
|
||||
_upgrade_existing_database(cur, user_version, delta_files, upgraded)
|
||||
_upgrade_existing_database(
|
||||
cur, user_version, delta_files, upgraded, database_engine
|
||||
)
|
||||
else:
|
||||
_setup_new_database(cur)
|
||||
_setup_new_database(cur, database_engine)
|
||||
|
||||
cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
|
||||
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
|
||||
|
||||
cur.close()
|
||||
db_conn.commit()
|
||||
@@ -603,7 +178,7 @@ def prepare_database(db_conn):
|
||||
raise
|
||||
|
||||
|
||||
def _setup_new_database(cur):
|
||||
def _setup_new_database(cur, database_engine):
|
||||
"""Sets up the database by finding a base set of "full schemas" and then
|
||||
applying any necessary deltas.
|
||||
|
||||
@@ -657,31 +232,30 @@ def _setup_new_database(cur):
|
||||
|
||||
directory_entries = os.listdir(sql_dir)
|
||||
|
||||
sql_script = "BEGIN TRANSACTION;\n"
|
||||
for filename in fnmatch.filter(directory_entries, "*.sql"):
|
||||
sql_loc = os.path.join(sql_dir, filename)
|
||||
logger.debug("Applying schema %s", sql_loc)
|
||||
sql_script += read_schema(sql_loc)
|
||||
sql_script += "\n"
|
||||
sql_script += "COMMIT TRANSACTION;"
|
||||
cur.executescript(sql_script)
|
||||
executescript(cur, sql_loc)
|
||||
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)",
|
||||
(max_current_ver, False)
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)"
|
||||
),
|
||||
(max_current_ver, False,)
|
||||
)
|
||||
|
||||
_upgrade_existing_database(
|
||||
cur,
|
||||
current_version=max_current_ver,
|
||||
applied_delta_files=[],
|
||||
upgraded=False
|
||||
upgraded=False,
|
||||
database_engine=database_engine,
|
||||
)
|
||||
|
||||
|
||||
def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
upgraded):
|
||||
upgraded, database_engine):
|
||||
"""Upgrades an existing database.
|
||||
|
||||
Delta files can either be SQL stored in *.sql files, or python modules
|
||||
@@ -737,6 +311,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
if not upgraded:
|
||||
start_ver += 1
|
||||
|
||||
logger.debug("applied_delta_files: %s", applied_delta_files)
|
||||
|
||||
for v in range(start_ver, SCHEMA_VERSION + 1):
|
||||
logger.debug("Upgrading schema to v%d", v)
|
||||
|
||||
@@ -753,6 +329,7 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
directory_entries.sort()
|
||||
for file_name in directory_entries:
|
||||
relative_path = os.path.join(str(v), file_name)
|
||||
logger.debug("Found file: %s", relative_path)
|
||||
if relative_path in applied_delta_files:
|
||||
continue
|
||||
|
||||
@@ -774,9 +351,8 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
module.run_upgrade(cur)
|
||||
elif ext == ".sql":
|
||||
# A plain old .sql file, just read and execute it
|
||||
delta_schema = read_schema(absolute_path)
|
||||
logger.debug("Applying schema %s", relative_path)
|
||||
cur.executescript(delta_schema)
|
||||
executescript(cur, absolute_path)
|
||||
else:
|
||||
# Not a valid delta file.
|
||||
logger.warn(
|
||||
@@ -788,24 +364,83 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
|
||||
|
||||
# Mark as done.
|
||||
cur.execute(
|
||||
"INSERT INTO applied_schema_deltas (version, file)"
|
||||
" VALUES (?,?)",
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO applied_schema_deltas (version, file)"
|
||||
" VALUES (?,?)",
|
||||
),
|
||||
(v, relative_path)
|
||||
)
|
||||
|
||||
cur.execute("DELETE FROM schema_version")
|
||||
cur.execute(
|
||||
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)",
|
||||
database_engine.convert_param_style(
|
||||
"INSERT INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)",
|
||||
),
|
||||
(v, True)
|
||||
)
|
||||
|
||||
|
||||
def _get_or_create_schema_state(txn):
|
||||
def get_statements(f):
|
||||
statement_buffer = ""
|
||||
in_comment = False # If we're in a /* ... */ style comment
|
||||
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
|
||||
if in_comment:
|
||||
# Check if this line contains an end to the comment
|
||||
comments = line.split("*/", 1)
|
||||
if len(comments) == 1:
|
||||
continue
|
||||
line = comments[1]
|
||||
in_comment = False
|
||||
|
||||
# Remove inline block comments
|
||||
line = re.sub(r"/\*.*\*/", " ", line)
|
||||
|
||||
# Does this line start a comment?
|
||||
comments = line.split("/*", 1)
|
||||
if len(comments) > 1:
|
||||
line = comments[0]
|
||||
in_comment = True
|
||||
|
||||
# Deal with line comments
|
||||
line = line.split("--", 1)[0]
|
||||
line = line.split("//", 1)[0]
|
||||
|
||||
# Find *all* semicolons. We need to treat first and last entry
|
||||
# specially.
|
||||
statements = line.split(";")
|
||||
|
||||
# We must prepend statement_buffer to the first statement
|
||||
first_statement = "%s %s" % (
|
||||
statement_buffer.strip(),
|
||||
statements[0].strip()
|
||||
)
|
||||
statements[0] = first_statement
|
||||
|
||||
# Every entry, except the last, is a full statement
|
||||
for statement in statements[:-1]:
|
||||
yield statement.strip()
|
||||
|
||||
# The last entry did *not* end in a semicolon, so we store it for the
|
||||
# next semicolon we find
|
||||
statement_buffer = statements[-1].strip()
|
||||
|
||||
|
||||
def executescript(txn, schema_path):
|
||||
with open(schema_path, 'r') as f:
|
||||
for statement in get_statements(f):
|
||||
txn.execute(statement)
|
||||
|
||||
|
||||
def _get_or_create_schema_state(txn, database_engine):
|
||||
# Bluntly try creating the schema_version tables.
|
||||
schema_path = os.path.join(
|
||||
dir_path, "schema", "schema_version.sql",
|
||||
)
|
||||
create_schema = read_schema(schema_path)
|
||||
txn.executescript(create_schema)
|
||||
executescript(txn, schema_path)
|
||||
|
||||
txn.execute("SELECT version, upgraded FROM schema_version")
|
||||
row = txn.fetchone()
|
||||
@@ -814,10 +449,13 @@ def _get_or_create_schema_state(txn):
|
||||
|
||||
if current_version:
|
||||
txn.execute(
|
||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?",
|
||||
database_engine.convert_param_style(
|
||||
"SELECT file FROM applied_schema_deltas WHERE version >= ?"
|
||||
),
|
||||
(current_version,)
|
||||
)
|
||||
return current_version, txn.fetchall(), upgraded
|
||||
applied_deltas = [d for d, in txn.fetchall()]
|
||||
return current_version, applied_deltas, upgraded
|
||||
|
||||
return None
|
||||
|
||||
@@ -849,7 +487,19 @@ def prepare_sqlite3_database(db_conn):
|
||||
|
||||
if row and row[0]:
|
||||
db_conn.execute(
|
||||
"INSERT OR REPLACE INTO schema_version (version, upgraded)"
|
||||
"REPLACE INTO schema_version (version, upgraded)"
|
||||
" VALUES (?,?)",
|
||||
(row[0], False)
|
||||
)
|
||||
|
||||
|
||||
def are_all_users_on_domain(txn, database_engine, domain):
|
||||
sql = database_engine.convert_param_style(
|
||||
"SELECT COUNT(*) FROM users WHERE name NOT LIKE ?"
|
||||
)
|
||||
pat = "%:" + domain
|
||||
txn.execute(sql, (pat,))
|
||||
num_not_matching = txn.fetchall()[0][0]
|
||||
if num_not_matching == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -22,6 +22,8 @@ from synapse.util.logcontext import PreserveLoggingContext, LoggingContext
|
||||
from synapse.util.lrucache import LruCache
|
||||
import synapse.metrics
|
||||
|
||||
from util.id_generators import IdGenerator, StreamIdGenerator
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple, OrderedDict
|
||||
@@ -35,6 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
perf_logger = logging.getLogger("synapse.storage.TIME")
|
||||
|
||||
|
||||
metrics = synapse.metrics.get_metrics_for("synapse.storage")
|
||||
@@ -53,14 +56,57 @@ cache_counter = metrics.register_cache(
|
||||
)
|
||||
|
||||
|
||||
# TODO(paul):
|
||||
# * more generic key management
|
||||
# * consider other eviction strategies - LRU?
|
||||
def cached(max_entries=1000):
|
||||
class Cache(object):
|
||||
|
||||
def __init__(self, name, max_entries=1000, keylen=1, lru=False):
|
||||
if lru:
|
||||
self.cache = LruCache(max_size=max_entries)
|
||||
self.max_entries = None
|
||||
else:
|
||||
self.cache = OrderedDict()
|
||||
self.max_entries = max_entries
|
||||
|
||||
self.name = name
|
||||
self.keylen = keylen
|
||||
|
||||
caches_by_name[name] = self.cache
|
||||
|
||||
def get(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if keyargs in self.cache:
|
||||
cache_counter.inc_hits(self.name)
|
||||
return self.cache[keyargs]
|
||||
|
||||
cache_counter.inc_misses(self.name)
|
||||
raise KeyError()
|
||||
|
||||
def prefill(self, *args): # because I can't *keyargs, value
|
||||
keyargs = args[:-1]
|
||||
value = args[-1]
|
||||
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
if self.max_entries is not None:
|
||||
while len(self.cache) >= self.max_entries:
|
||||
self.cache.popitem(last=False)
|
||||
|
||||
self.cache[keyargs] = value
|
||||
|
||||
def invalidate(self, *keyargs):
|
||||
if len(keyargs) != self.keylen:
|
||||
raise ValueError("Expected a key to have %d items", self.keylen)
|
||||
|
||||
self.cache.pop(keyargs, None)
|
||||
|
||||
|
||||
def cached(max_entries=1000, num_args=1, lru=False):
|
||||
""" A method decorator that applies a memoizing cache around the function.
|
||||
|
||||
The function is presumed to take one additional argument, which is used as
|
||||
the key for the cache. Cache hits are served directly from the cache;
|
||||
The function is presumed to take zero or more arguments, which are used in
|
||||
a tuple as the key for the cache. Hits are served directly from the cache;
|
||||
misses use the function body to generate the value.
|
||||
|
||||
The wrapped function has an additional member, a callable called
|
||||
@@ -71,34 +117,27 @@ def cached(max_entries=1000):
|
||||
calling the calculation function.
|
||||
"""
|
||||
def wrap(orig):
|
||||
cache = OrderedDict()
|
||||
name = orig.__name__
|
||||
|
||||
caches_by_name[name] = cache
|
||||
|
||||
def prefill(key, value):
|
||||
while len(cache) > max_entries:
|
||||
cache.popitem(last=False)
|
||||
|
||||
cache[key] = value
|
||||
cache = Cache(
|
||||
name=orig.__name__,
|
||||
max_entries=max_entries,
|
||||
keylen=num_args,
|
||||
lru=lru,
|
||||
)
|
||||
|
||||
@functools.wraps(orig)
|
||||
@defer.inlineCallbacks
|
||||
def wrapped(self, key):
|
||||
if key in cache:
|
||||
cache_counter.inc_hits(name)
|
||||
defer.returnValue(cache[key])
|
||||
def wrapped(self, *keyargs):
|
||||
try:
|
||||
defer.returnValue(cache.get(*keyargs))
|
||||
except KeyError:
|
||||
ret = yield orig(self, *keyargs)
|
||||
|
||||
cache_counter.inc_misses(name)
|
||||
ret = yield orig(self, key)
|
||||
prefill(key, ret)
|
||||
defer.returnValue(ret)
|
||||
cache.prefill(*keyargs + (ret,))
|
||||
|
||||
def invalidate(key):
|
||||
cache.pop(key, None)
|
||||
defer.returnValue(ret)
|
||||
|
||||
wrapped.invalidate = invalidate
|
||||
wrapped.prefill = prefill
|
||||
wrapped.invalidate = cache.invalidate
|
||||
wrapped.prefill = cache.prefill
|
||||
return wrapped
|
||||
|
||||
return wrap
|
||||
@@ -108,11 +147,12 @@ class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
__slots__ = ["txn", "name"]
|
||||
__slots__ = ["txn", "name", "database_engine"]
|
||||
|
||||
def __init__(self, txn, name):
|
||||
def __init__(self, txn, name, database_engine):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.txn, name)
|
||||
@@ -124,26 +164,32 @@ class LoggingTransaction(object):
|
||||
# TODO(paul): Maybe use 'info' and 'debug' for values?
|
||||
sql_logger.debug("[SQL] {%s} %s", self.name, sql)
|
||||
|
||||
try:
|
||||
if args and args[0]:
|
||||
values = args[0]
|
||||
sql = self.database_engine.convert_param_style(sql)
|
||||
|
||||
if args and args[0]:
|
||||
args = list(args)
|
||||
args[0] = [
|
||||
self.database_engine.encode_parameter(a) for a in args[0]
|
||||
]
|
||||
try:
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(values)),
|
||||
"[SQL values] {%s} " + ", ".join(("<%r>",) * len(args[0])),
|
||||
self.name,
|
||||
*values
|
||||
*args[0]
|
||||
)
|
||||
except:
|
||||
# Don't let logging failures stop SQL from working
|
||||
pass
|
||||
except:
|
||||
# Don't let logging failures stop SQL from working
|
||||
pass
|
||||
|
||||
start = time.time() * 1000
|
||||
|
||||
try:
|
||||
return self.txn.execute(
|
||||
sql, *args, **kwargs
|
||||
)
|
||||
except:
|
||||
logger.exception("[SQL FAIL] {%s}", self.name)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
finally:
|
||||
msecs = (time.time() * 1000) - start
|
||||
sql_logger.debug("[SQL time] {%s} %f", self.name, msecs)
|
||||
@@ -205,10 +251,16 @@ class SQLBaseStore(object):
|
||||
self._txn_perf_counters = PerformanceCounters()
|
||||
self._get_event_counters = PerformanceCounters()
|
||||
|
||||
self._get_event_cache = LruCache(hs.config.event_cache_size)
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3, lru=True,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
|
||||
# Pretend the getEventCache is just another named cache
|
||||
caches_by_name["*getEvent*"] = self._get_event_cache
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator()
|
||||
self._transaction_id_gen = IdGenerator("sent_transactions", "id", self)
|
||||
self._state_groups_id_gen = IdGenerator("state_groups", "id", self)
|
||||
self._access_tokens_id_gen = IdGenerator("access_tokens", "id", self)
|
||||
self._pushers_id_gen = IdGenerator("pushers", "id", self)
|
||||
|
||||
def start_profiling(self):
|
||||
self._previous_loop_ts = self._clock.time_msec()
|
||||
@@ -232,7 +284,7 @@ class SQLBaseStore(object):
|
||||
time_now - time_then, limit=3
|
||||
)
|
||||
|
||||
logger.info(
|
||||
perf_logger.info(
|
||||
"Total database time: %.3f%% {%s} {%s}",
|
||||
ratio * 100, top_three_counters, top_3_event_counters
|
||||
)
|
||||
@@ -246,8 +298,12 @@ class SQLBaseStore(object):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
def inner_func(txn, *args, **kwargs):
|
||||
def inner_func(conn, *args, **kwargs):
|
||||
with LoggingContext("runInteraction") as context:
|
||||
if self.database_engine.is_connection_closed(conn):
|
||||
logger.debug("Reconnecting closed database connection")
|
||||
conn.reconnect()
|
||||
|
||||
current_context.copy_to(context)
|
||||
start = time.time() * 1000
|
||||
txn_id = self._TXN_ID
|
||||
@@ -261,9 +317,48 @@ class SQLBaseStore(object):
|
||||
sql_scheduling_timer.inc_by(time.time() * 1000 - start_time)
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
try:
|
||||
return func(LoggingTransaction(txn, name), *args, **kwargs)
|
||||
except:
|
||||
logger.exception("[TXN FAIL] {%s}", name)
|
||||
i = 0
|
||||
N = 5
|
||||
while True:
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
return func(
|
||||
LoggingTransaction(txn, name, self.database_engine),
|
||||
*args, **kwargs
|
||||
)
|
||||
except self.database_engine.module.OperationalError as e:
|
||||
# This can happen if the database disappears mid
|
||||
# transaction.
|
||||
logger.warn(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, e, i, N
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
except self.database_engine.module.DatabaseError as e:
|
||||
if self.database_engine.is_deadlock(e):
|
||||
logger.warn("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
|
||||
if i < N:
|
||||
i += 1
|
||||
try:
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warn(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, e1,
|
||||
)
|
||||
continue
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.debug("[TXN FAIL] {%s} %s", name, e)
|
||||
raise
|
||||
finally:
|
||||
end = time.time() * 1000
|
||||
@@ -276,7 +371,7 @@ class SQLBaseStore(object):
|
||||
sql_txn_timer.inc_by(duration, desc)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runInteraction(
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
defer.returnValue(result)
|
||||
@@ -307,11 +402,11 @@ class SQLBaseStore(object):
|
||||
The result of decoder(results)
|
||||
"""
|
||||
def interaction(txn):
|
||||
cursor = txn.execute(query, args)
|
||||
txn.execute(query, args)
|
||||
if decoder:
|
||||
return decoder(cursor)
|
||||
return decoder(txn)
|
||||
else:
|
||||
return cursor.fetchall()
|
||||
return txn.fetchall()
|
||||
|
||||
return self.runInteraction(desc, interaction)
|
||||
|
||||
@@ -321,26 +416,29 @@ class SQLBaseStore(object):
|
||||
# "Simple" SQL API methods that operate on a single table with no JOINs,
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
def _simple_insert(self, table, values, or_replace=False, or_ignore=False):
|
||||
@defer.inlineCallbacks
|
||||
def _simple_insert(self, table, values, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
table : string giving the table name
|
||||
values : dict of new column names and values for them
|
||||
or_replace : bool; if True performs an INSERT OR REPLACE
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_insert",
|
||||
self._simple_insert_txn, table, values, or_replace=or_replace,
|
||||
or_ignore=or_ignore,
|
||||
)
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_insert_txn, table, values,
|
||||
)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
if not or_ignore:
|
||||
raise
|
||||
|
||||
@log_function
|
||||
def _simple_insert_txn(self, txn, table, values, or_replace=False,
|
||||
or_ignore=False):
|
||||
sql = "%s INTO %s (%s) VALUES(%s)" % (
|
||||
("INSERT OR REPLACE" if or_replace else
|
||||
"INSERT OR IGNORE" if or_ignore else "INSERT"),
|
||||
def _simple_insert_txn(self, txn, table, values):
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in values),
|
||||
", ".join("?" for k in values)
|
||||
@@ -352,22 +450,29 @@ class SQLBaseStore(object):
|
||||
)
|
||||
|
||||
txn.execute(sql, values.values())
|
||||
return txn.lastrowid
|
||||
|
||||
def _simple_upsert(self, table, keyvalues, values):
|
||||
def _simple_upsert(self, table, keyvalues, values,
|
||||
insertion_values={}, desc="_simple_upsert", lock=True):
|
||||
"""
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): key/values to use when inserting
|
||||
Returns: A deferred
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_upsert",
|
||||
self._simple_upsert_txn, table, keyvalues, values
|
||||
desc,
|
||||
self._simple_upsert_txn, table, keyvalues, values, insertion_values,
|
||||
lock
|
||||
)
|
||||
|
||||
def _simple_upsert_txn(self, txn, table, keyvalues, values):
|
||||
def _simple_upsert_txn(self, txn, table, keyvalues, values, insertion_values={},
|
||||
lock=True):
|
||||
# We need to lock the table :(, unless we're *really* careful
|
||||
if lock:
|
||||
self.database_engine.lock_table(txn, table)
|
||||
|
||||
# Try to update
|
||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
@@ -386,6 +491,7 @@ class SQLBaseStore(object):
|
||||
allvalues = {}
|
||||
allvalues.update(keyvalues)
|
||||
allvalues.update(values)
|
||||
allvalues.update(insertion_values)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s)" % (
|
||||
table,
|
||||
@@ -399,7 +505,7 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql, allvalues.values())
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@@ -411,12 +517,15 @@ class SQLBaseStore(object):
|
||||
allow_none : If true, return None instead of failing if the SELECT
|
||||
statement returns no rows
|
||||
"""
|
||||
return self._simple_selectupdate_one(
|
||||
table, keyvalues, retcols=retcols, allow_none=allow_none
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it."
|
||||
|
||||
@@ -426,7 +535,7 @@ class SQLBaseStore(object):
|
||||
retcol : string giving the name of the column to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_one_onecol",
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
)
|
||||
@@ -450,8 +559,7 @@ class SQLBaseStore(object):
|
||||
|
||||
def _simple_select_onecol_txn(self, txn, table, keyvalues, retcol):
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s "
|
||||
"ORDER BY rowid asc"
|
||||
"SELECT %(retcol)s FROM %(table)s WHERE %(where)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
@@ -462,7 +570,8 @@ class SQLBaseStore(object):
|
||||
|
||||
return [r[0] for r in txn.fetchall()]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol):
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
@@ -475,12 +584,13 @@ class SQLBaseStore(object):
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_onecol",
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols):
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@@ -491,7 +601,7 @@ class SQLBaseStore(object):
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"_simple_select_list",
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
)
|
||||
@@ -507,14 +617,14 @@ class SQLBaseStore(object):
|
||||
retcols : list of strings giving the names of the columns to return
|
||||
"""
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
)
|
||||
txn.execute(sql, keyvalues.values())
|
||||
else:
|
||||
sql = "SELECT %s FROM %s ORDER BY rowid asc" % (
|
||||
sql = "SELECT %s FROM %s" % (
|
||||
", ".join(retcols),
|
||||
table
|
||||
)
|
||||
@@ -523,7 +633,7 @@ class SQLBaseStore(object):
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
retcols=None):
|
||||
desc="_simple_update_one"):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
@@ -541,56 +651,81 @@ class SQLBaseStore(object):
|
||||
get-and-set. This can be used to implement compare-and-set by putting
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self._simple_selectupdate_one(table, keyvalues, updatevalues,
|
||||
retcols=retcols)
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
def _simple_update_one_txn(self, txn, table, keyvalues, updatevalues):
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
def _simple_select_one_txn(self, txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return dict(zip(retcols, row))
|
||||
|
||||
def _simple_selectupdate_one(self, table, keyvalues, updatevalues=None,
|
||||
retcols=None, allow_none=False):
|
||||
retcols=None, allow_none=False,
|
||||
desc="_simple_selectupdate_one"):
|
||||
""" Combined SELECT then UPDATE."""
|
||||
if retcols:
|
||||
select_sql = "SELECT %s FROM %s WHERE %s ORDER BY rowid asc" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k) for k in keyvalues)
|
||||
)
|
||||
|
||||
if updatevalues:
|
||||
update_sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
)
|
||||
|
||||
def func(txn):
|
||||
ret = None
|
||||
if retcols:
|
||||
txn.execute(select_sql, keyvalues.values())
|
||||
|
||||
row = txn.fetchone()
|
||||
if not row:
|
||||
if allow_none:
|
||||
return None
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
ret = dict(zip(retcols, row))
|
||||
|
||||
if updatevalues:
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
ret = self._simple_select_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcols=retcols,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if updatevalues:
|
||||
self._simple_update_one_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
updatevalues=updatevalues,
|
||||
)
|
||||
|
||||
# if txn.rowcount == 0:
|
||||
# raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
return ret
|
||||
return self.runInteraction("_simple_selectupdate_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete_one(self, table, keyvalues):
|
||||
def _simple_delete_one(self, table, keyvalues, desc="_simple_delete_one"):
|
||||
"""Executes a DELETE query on the named table, expecting to delete a
|
||||
single row.
|
||||
|
||||
@@ -609,9 +744,9 @@ class SQLBaseStore(object):
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
raise StoreError(500, "more than one row matched")
|
||||
return self.runInteraction("_simple_delete_one", func)
|
||||
return self.runInteraction(desc, func)
|
||||
|
||||
def _simple_delete(self, table, keyvalues):
|
||||
def _simple_delete(self, table, keyvalues, desc="_simple_delete"):
|
||||
"""Executes a DELETE query on the named table.
|
||||
|
||||
Args:
|
||||
@@ -619,7 +754,7 @@ class SQLBaseStore(object):
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
|
||||
return self.runInteraction("_simple_delete", self._simple_delete_txn)
|
||||
return self.runInteraction(desc, self._simple_delete_txn)
|
||||
|
||||
def _simple_delete_txn(self, txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
@@ -670,6 +805,12 @@ class SQLBaseStore(object):
|
||||
|
||||
return [e for e in events if e]
|
||||
|
||||
def _invalidate_get_event_cache(self, event_id):
|
||||
for check_redacted in (False, True):
|
||||
for get_prev_content in (False, True):
|
||||
self._get_event_cache.invalidate(event_id, check_redacted,
|
||||
get_prev_content)
|
||||
|
||||
def _get_event_txn(self, txn, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False):
|
||||
|
||||
@@ -680,16 +821,14 @@ class SQLBaseStore(object):
|
||||
sql_getevents_timer.inc_by(curr_time - last_time, desc)
|
||||
return curr_time
|
||||
|
||||
cache = self._get_event_cache.setdefault(event_id, {})
|
||||
|
||||
try:
|
||||
# Separate cache entries for each way to invoke _get_event_txn
|
||||
ret = cache[(check_redacted, get_prev_content, allow_rejected)]
|
||||
ret = self._get_event_cache.get(event_id, check_redacted, get_prev_content)
|
||||
|
||||
cache_counter.inc_hits("*getEvent*")
|
||||
return ret
|
||||
if allow_rejected or not ret.rejected_reason:
|
||||
return ret
|
||||
else:
|
||||
return None
|
||||
except KeyError:
|
||||
cache_counter.inc_misses("*getEvent*")
|
||||
pass
|
||||
finally:
|
||||
start_time = update_counter("event_cache", start_time)
|
||||
@@ -714,19 +853,22 @@ class SQLBaseStore(object):
|
||||
|
||||
start_time = update_counter("select_event", start_time)
|
||||
|
||||
result = self._get_event_from_row_txn(
|
||||
txn, internal_metadata, js, redacted,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
self._get_event_cache.prefill(event_id, check_redacted, get_prev_content, result)
|
||||
|
||||
if allow_rejected or not rejected_reason:
|
||||
result = self._get_event_from_row_txn(
|
||||
txn, internal_metadata, js, redacted,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
)
|
||||
cache[(check_redacted, get_prev_content, allow_rejected)] = result
|
||||
return result
|
||||
else:
|
||||
return None
|
||||
|
||||
def _get_event_from_row_txn(self, txn, internal_metadata, js, redacted,
|
||||
check_redacted=True, get_prev_content=False):
|
||||
check_redacted=True, get_prev_content=False,
|
||||
rejected_reason=None):
|
||||
|
||||
start_time = time.time() * 1000
|
||||
|
||||
@@ -741,7 +883,11 @@ class SQLBaseStore(object):
|
||||
internal_metadata = json.loads(internal_metadata)
|
||||
start_time = update_counter("decode_internal", start_time)
|
||||
|
||||
ev = FrozenEvent(d, internal_metadata_dict=internal_metadata)
|
||||
ev = FrozenEvent(
|
||||
d,
|
||||
internal_metadata_dict=internal_metadata,
|
||||
rejected_reason=rejected_reason,
|
||||
)
|
||||
start_time = update_counter("build_frozen_event", start_time)
|
||||
|
||||
if check_redacted and redacted:
|
||||
@@ -788,6 +934,19 @@ class SQLBaseStore(object):
|
||||
result = txn.fetchone()
|
||||
return result[0] if result else None
|
||||
|
||||
def get_next_stream_id(self):
|
||||
with self._next_stream_id_lock:
|
||||
i = self._next_stream_id
|
||||
self._next_stream_id += 1
|
||||
return i
|
||||
|
||||
|
||||
class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class Table(object):
|
||||
""" A base class used to store information about a particular table.
|
||||
@@ -804,7 +963,7 @@ class Table(object):
|
||||
|
||||
_select_where_clause = "SELECT %s FROM %s WHERE %s"
|
||||
_select_clause = "SELECT %s FROM %s"
|
||||
_insert_clause = "INSERT OR REPLACE INTO %s (%s) VALUES (%s)"
|
||||
_insert_clause = "REPLACE INTO %s (%s) VALUES (%s)"
|
||||
|
||||
@classmethod
|
||||
def select_statement(cls, where_clause=None):
|
||||
|
||||
@@ -13,154 +13,35 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import simplejson
|
||||
import urllib
|
||||
import yaml
|
||||
from simplejson import JSONDecodeError
|
||||
import simplejson as json
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import Membership
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.appservice import ApplicationService, AppServiceTransaction
|
||||
from synapse.storage.roommember import RoomsForUser
|
||||
from synapse.types import UserID
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def log_failure(failure):
|
||||
logger.error("Failed to detect application services: %s", failure.value)
|
||||
logger.error(failure.getTraceback())
|
||||
|
||||
|
||||
class ApplicationServiceStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceStore, self).__init__(hs)
|
||||
self.hostname = hs.hostname
|
||||
self.services_cache = []
|
||||
self.cache_defer = self._populate_cache()
|
||||
self.cache_defer.addErrback(log_failure)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def unregister_app_service(self, token):
|
||||
"""Unregisters this service.
|
||||
|
||||
This removes all AS specific regex and the base URL. The token is the
|
||||
only thing preserved for future registration attempts.
|
||||
"""
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
yield self.runInteraction(
|
||||
"unregister_app_service",
|
||||
self._unregister_app_service_txn,
|
||||
token,
|
||||
)
|
||||
# update cache TODO: Should this be in the txn?
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
service.url = None
|
||||
service.namespaces = None
|
||||
service.hs_token = None
|
||||
|
||||
def _unregister_app_service_txn(self, txn, token):
|
||||
# kill the url to prevent pushes
|
||||
txn.execute(
|
||||
"UPDATE application_services SET url=NULL WHERE token=?",
|
||||
(token,)
|
||||
self._populate_appservice_cache(
|
||||
hs.config.app_service_config_files
|
||||
)
|
||||
|
||||
# cleanup regex
|
||||
as_id = self._get_as_id_txn(txn, token)
|
||||
if not as_id:
|
||||
logger.warning(
|
||||
"unregister_app_service_txn: Failed to find as_id for token=",
|
||||
token
|
||||
)
|
||||
return False
|
||||
|
||||
txn.execute(
|
||||
"DELETE FROM application_services_regex WHERE as_id=?",
|
||||
(as_id,)
|
||||
)
|
||||
return True
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_app_service(self, service):
|
||||
"""Update an application service, clobbering what was previously there.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The updated service.
|
||||
"""
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
|
||||
# NB: There is no "insert" since we provide no public-facing API to
|
||||
# allocate new ASes. It relies on the server admin inserting the AS
|
||||
# token into the database manually.
|
||||
|
||||
if not service.token or not service.url:
|
||||
raise StoreError(400, "Token and url must be specified.")
|
||||
|
||||
if not service.hs_token:
|
||||
raise StoreError(500, "No HS token")
|
||||
|
||||
yield self.runInteraction(
|
||||
"update_app_service",
|
||||
self._update_app_service_txn,
|
||||
service
|
||||
)
|
||||
|
||||
# update cache TODO: Should this be in the txn?
|
||||
for (index, cache_service) in enumerate(self.services_cache):
|
||||
if service.token == cache_service.token:
|
||||
self.services_cache[index] = service
|
||||
logger.info("Updated: %s", service)
|
||||
return
|
||||
# new entry
|
||||
self.services_cache.append(service)
|
||||
logger.info("Updated(new): %s", service)
|
||||
|
||||
def _update_app_service_txn(self, txn, service):
|
||||
as_id = self._get_as_id_txn(txn, service.token)
|
||||
if not as_id:
|
||||
logger.warning(
|
||||
"update_app_service_txn: Failed to find as_id for token=",
|
||||
service.token
|
||||
)
|
||||
return False
|
||||
|
||||
txn.execute(
|
||||
"UPDATE application_services SET url=?, hs_token=?, sender=? "
|
||||
"WHERE id=?",
|
||||
(service.url, service.hs_token, service.sender, as_id,)
|
||||
)
|
||||
# cleanup regex
|
||||
txn.execute(
|
||||
"DELETE FROM application_services_regex WHERE as_id=?",
|
||||
(as_id,)
|
||||
)
|
||||
for (ns_int, ns_str) in enumerate(ApplicationService.NS_LIST):
|
||||
if ns_str in service.namespaces:
|
||||
for regex_obj in service.namespaces[ns_str]:
|
||||
txn.execute(
|
||||
"INSERT INTO application_services_regex("
|
||||
"as_id, namespace, regex) values(?,?,?)",
|
||||
(as_id, ns_int, simplejson.dumps(regex_obj))
|
||||
)
|
||||
return True
|
||||
|
||||
def _get_as_id_txn(self, txn, token):
|
||||
cursor = txn.execute(
|
||||
"SELECT id FROM application_services WHERE token=?",
|
||||
(token,)
|
||||
)
|
||||
res = cursor.fetchone()
|
||||
if res:
|
||||
return res[0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_app_services(self):
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
defer.returnValue(self.services_cache)
|
||||
return defer.succeed(self.services_cache)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_app_service_by_user_id(self, user_id):
|
||||
"""Retrieve an application service from their user ID.
|
||||
|
||||
@@ -174,37 +55,23 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
"""
|
||||
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
|
||||
for service in self.services_cache:
|
||||
if service.sender == user_id:
|
||||
defer.returnValue(service)
|
||||
return
|
||||
defer.returnValue(None)
|
||||
return defer.succeed(service)
|
||||
return defer.succeed(None)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_app_service_by_token(self, token, from_cache=True):
|
||||
def get_app_service_by_token(self, token):
|
||||
"""Get the application service with the given appservice token.
|
||||
|
||||
Args:
|
||||
token (str): The application service token.
|
||||
from_cache (bool): True to get this service from the cache, False to
|
||||
check the database.
|
||||
Raises:
|
||||
StoreError if there was a problem retrieving this service.
|
||||
Returns:
|
||||
synapse.appservice.ApplicationService or None.
|
||||
"""
|
||||
yield self.cache_defer # make sure the cache is ready
|
||||
|
||||
if from_cache:
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
defer.returnValue(service)
|
||||
return
|
||||
defer.returnValue(None)
|
||||
|
||||
# TODO: The from_cache=False impl
|
||||
# TODO: This should be JOINed with the application_services_regex table.
|
||||
for service in self.services_cache:
|
||||
if service.token == token:
|
||||
return defer.succeed(service)
|
||||
return defer.succeed(None)
|
||||
|
||||
def get_app_service_rooms(self, service):
|
||||
"""Get a list of RoomsForUser for this application service.
|
||||
@@ -277,12 +144,7 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
|
||||
return rooms_for_user_matching_user_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _populate_cache(self):
|
||||
"""Populates the ApplicationServiceCache from the database."""
|
||||
sql = ("SELECT * FROM application_services LEFT JOIN "
|
||||
"application_services_regex ON application_services.id = "
|
||||
"application_services_regex.as_id")
|
||||
def _parse_services_dict(self, results):
|
||||
# SQL results in the form:
|
||||
# [
|
||||
# {
|
||||
@@ -296,12 +158,14 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
# }
|
||||
# ]
|
||||
services = {}
|
||||
results = yield self._execute_and_decode("_populate_cache", sql)
|
||||
for res in results:
|
||||
as_token = res["token"]
|
||||
if as_token is None:
|
||||
continue
|
||||
if as_token not in services:
|
||||
# add the service
|
||||
services[as_token] = {
|
||||
"id": res["id"],
|
||||
"url": res["url"],
|
||||
"token": as_token,
|
||||
"hs_token": res["hs_token"],
|
||||
@@ -319,20 +183,289 @@ class ApplicationServiceStore(SQLBaseStore):
|
||||
try:
|
||||
services[as_token]["namespaces"][
|
||||
ApplicationService.NS_LIST[ns_int]].append(
|
||||
simplejson.loads(res["regex"])
|
||||
json.loads(res["regex"])
|
||||
)
|
||||
except IndexError:
|
||||
logger.error("Bad namespace enum '%s'. %s", ns_int, res)
|
||||
except JSONDecodeError:
|
||||
logger.error("Bad regex object '%s'", res["regex"])
|
||||
|
||||
# TODO get last successful txn id f.e. service
|
||||
service_list = []
|
||||
for service in services.values():
|
||||
logger.info("Found application service: %s", service)
|
||||
self.services_cache.append(ApplicationService(
|
||||
service_list.append(ApplicationService(
|
||||
token=service["token"],
|
||||
url=service["url"],
|
||||
namespaces=service["namespaces"],
|
||||
hs_token=service["hs_token"],
|
||||
sender=service["sender"]
|
||||
sender=service["sender"],
|
||||
id=service["id"]
|
||||
))
|
||||
return service_list
|
||||
|
||||
def _load_appservice(self, as_info):
|
||||
required_string_fields = [
|
||||
"url", "as_token", "hs_token", "sender_localpart"
|
||||
]
|
||||
for field in required_string_fields:
|
||||
if not isinstance(as_info.get(field), basestring):
|
||||
raise KeyError("Required string field: '%s'", field)
|
||||
|
||||
localpart = as_info["sender_localpart"]
|
||||
if urllib.quote(localpart) != localpart:
|
||||
raise ValueError(
|
||||
"sender_localpart needs characters which are not URL encoded."
|
||||
)
|
||||
user = UserID(localpart, self.hostname)
|
||||
user_id = user.to_string()
|
||||
|
||||
# namespace checks
|
||||
if not isinstance(as_info.get("namespaces"), dict):
|
||||
raise KeyError("Requires 'namespaces' object.")
|
||||
for ns in ApplicationService.NS_LIST:
|
||||
# specific namespaces are optional
|
||||
if ns in as_info["namespaces"]:
|
||||
# expect a list of dicts with exclusive and regex keys
|
||||
for regex_obj in as_info["namespaces"][ns]:
|
||||
if not isinstance(regex_obj, dict):
|
||||
raise ValueError(
|
||||
"Expected namespace entry in %s to be an object,"
|
||||
" but got %s", ns, regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("regex"), basestring):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'regex' key in %s", regex_obj
|
||||
)
|
||||
if not isinstance(regex_obj.get("exclusive"), bool):
|
||||
raise ValueError(
|
||||
"Missing/bad type 'exclusive' key in %s", regex_obj
|
||||
)
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
url=as_info["url"],
|
||||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
sender=user_id,
|
||||
id=as_info["as_token"] # the token is the only unique thing here
|
||||
)
|
||||
|
||||
def _populate_appservice_cache(self, config_files):
|
||||
"""Populates a cache of Application Services from the config files."""
|
||||
if not isinstance(config_files, list):
|
||||
logger.warning(
|
||||
"Expected %s to be a list of AS config files.", config_files
|
||||
)
|
||||
return
|
||||
|
||||
for config_file in config_files:
|
||||
try:
|
||||
with open(config_file, 'r') as f:
|
||||
appservice = self._load_appservice(yaml.load(f))
|
||||
logger.info("Loaded application service: %s", appservice)
|
||||
self.services_cache.append(appservice)
|
||||
except Exception as e:
|
||||
logger.error("Failed to load appservice from '%s'", config_file)
|
||||
logger.exception(e)
|
||||
|
||||
|
||||
class ApplicationServiceTransactionStore(SQLBaseStore):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ApplicationServiceTransactionStore, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_appservices_by_state(self, state):
|
||||
"""Get a list of application services based on their state.
|
||||
|
||||
Args:
|
||||
state(ApplicationServiceState): The state to filter on.
|
||||
Returns:
|
||||
A Deferred which resolves to a list of ApplicationServices, which
|
||||
may be empty.
|
||||
"""
|
||||
results = yield self._simple_select_list(
|
||||
"application_services_state",
|
||||
dict(state=state),
|
||||
["as_id"]
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
as_list = yield self.get_app_services()
|
||||
services = []
|
||||
|
||||
for res in results:
|
||||
for service in as_list:
|
||||
if service.id == res["as_id"]:
|
||||
services.append(service)
|
||||
defer.returnValue(services)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_appservice_state(self, service):
|
||||
"""Get the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
Returns:
|
||||
A Deferred which resolves to ApplicationServiceState.
|
||||
"""
|
||||
result = yield self._simple_select_one(
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
["state"],
|
||||
allow_none=True
|
||||
)
|
||||
if result:
|
||||
defer.returnValue(result.get("state"))
|
||||
return
|
||||
defer.returnValue(None)
|
||||
|
||||
def set_appservice_state(self, service, state):
|
||||
"""Set the application service state.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service whose state to set.
|
||||
state(ApplicationServiceState): The connectivity state to apply.
|
||||
Returns:
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
dict(state=state)
|
||||
)
|
||||
|
||||
def create_appservice_txn(self, service, events):
|
||||
"""Atomically creates a new transaction for this application service
|
||||
with the given list of events.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The service who the transaction is for.
|
||||
events(list<Event>): A list of events to put in the transaction.
|
||||
Returns:
|
||||
AppServiceTransaction: A new transaction.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"create_appservice_txn",
|
||||
self._create_appservice_txn,
|
||||
service, events
|
||||
)
|
||||
|
||||
def _create_appservice_txn(self, txn, service, events):
|
||||
# work out new txn id (highest txn id for this service += 1)
|
||||
# The highest id may be the last one sent (in which case it is last_txn)
|
||||
# or it may be the highest in the txns list (which are waiting to be/are
|
||||
# being sent)
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
|
||||
txn.execute(
|
||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||
(service.id,)
|
||||
)
|
||||
highest_txn_id = txn.fetchone()[0]
|
||||
if highest_txn_id is None:
|
||||
highest_txn_id = 0
|
||||
|
||||
new_txn_id = max(highest_txn_id, last_txn_id) + 1
|
||||
|
||||
# Insert new txn into txn table
|
||||
event_ids = json.dumps([e.event_id for e in events])
|
||||
txn.execute(
|
||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||
"VALUES(?,?,?)",
|
||||
(service.id, new_txn_id, event_ids)
|
||||
)
|
||||
return AppServiceTransaction(
|
||||
service=service, id=new_txn_id, events=events
|
||||
)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
"""Completes an application service transaction.
|
||||
|
||||
Args:
|
||||
txn_id(str): The transaction ID being completed.
|
||||
service(ApplicationService): The application service which was sent
|
||||
this transaction.
|
||||
Returns:
|
||||
A Deferred which resolves if this transaction was stored
|
||||
successfully.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"complete_appservice_txn",
|
||||
self._complete_appservice_txn,
|
||||
txn_id, service
|
||||
)
|
||||
|
||||
def _complete_appservice_txn(self, txn, txn_id, service):
|
||||
txn_id = int(txn_id)
|
||||
|
||||
# Debugging query: Make sure the txn being completed is EXACTLY +1 from
|
||||
# what was there before. If it isn't, we've got problems (e.g. the AS
|
||||
# has probably missed some events), so whine loudly but still continue,
|
||||
# since it shouldn't fail completion of the transaction.
|
||||
last_txn_id = self._get_last_txn(txn, service.id)
|
||||
if (last_txn_id + 1) != txn_id:
|
||||
logger.error(
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
||||
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
|
||||
service.id
|
||||
)
|
||||
|
||||
# Set current txn_id for AS to 'txn_id'
|
||||
self._simple_upsert_txn(
|
||||
txn, "application_services_state", dict(as_id=service.id),
|
||||
dict(last_txn=txn_id)
|
||||
)
|
||||
|
||||
# Delete txn
|
||||
self._simple_delete_txn(
|
||||
txn, "application_services_txns",
|
||||
dict(txn_id=txn_id, as_id=service.id)
|
||||
)
|
||||
|
||||
def get_oldest_unsent_txn(self, service):
|
||||
"""Get the oldest transaction which has not been sent for this
|
||||
service.
|
||||
|
||||
Args:
|
||||
service(ApplicationService): The app service to get the oldest txn.
|
||||
Returns:
|
||||
A Deferred which resolves to an AppServiceTransaction or
|
||||
None.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn",
|
||||
self._get_oldest_unsent_txn,
|
||||
service
|
||||
)
|
||||
|
||||
def _get_oldest_unsent_txn(self, txn, service):
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||
" ORDER BY txn_id ASC LIMIT 1",
|
||||
(service.id,)
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
entry = rows[0]
|
||||
|
||||
event_ids = json.loads(entry["event_ids"])
|
||||
events = self._get_events_txn(txn, event_ids)
|
||||
|
||||
return AppServiceTransaction(
|
||||
service=service, id=entry["txn_id"], events=events
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,)
|
||||
)
|
||||
last_txn_id = txn.fetchone()
|
||||
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
||||
return 0
|
||||
else:
|
||||
return int(last_txn_id[0]) # select 'last_txn' col
|
||||
|
||||
@@ -13,7 +13,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
from ._base import SQLBaseStore, cached
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
|
||||
@@ -21,8 +21,6 @@ from twisted.internet import defer
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlite3
|
||||
|
||||
|
||||
RoomAliasMapping = namedtuple(
|
||||
"RoomAliasMapping",
|
||||
@@ -48,6 +46,7 @@ class DirectoryStore(SQLBaseStore):
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"room_id",
|
||||
allow_none=True,
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not room_id:
|
||||
@@ -58,6 +57,7 @@ class DirectoryStore(SQLBaseStore):
|
||||
"room_alias_servers",
|
||||
{"room_alias": room_alias.to_string()},
|
||||
"server",
|
||||
desc="get_association_from_room_alias",
|
||||
)
|
||||
|
||||
if not servers:
|
||||
@@ -87,8 +87,9 @@ class DirectoryStore(SQLBaseStore):
|
||||
"room_alias": room_alias.to_string(),
|
||||
"room_id": room_id,
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
except sqlite3.IntegrityError:
|
||||
except self.database_engine.module.IntegrityError:
|
||||
raise SynapseError(
|
||||
409, "Room alias %s already exists" % room_alias.to_string()
|
||||
)
|
||||
@@ -100,23 +101,29 @@ class DirectoryStore(SQLBaseStore):
|
||||
{
|
||||
"room_alias": room_alias.to_string(),
|
||||
"server": server,
|
||||
}
|
||||
},
|
||||
desc="create_room_alias_association",
|
||||
)
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_room_alias(self, room_alias):
|
||||
return self.runInteraction(
|
||||
room_id = yield self.runInteraction(
|
||||
"delete_room_alias",
|
||||
self._delete_room_alias_txn,
|
||||
room_alias,
|
||||
)
|
||||
|
||||
self.get_aliases_for_room.invalidate(room_id)
|
||||
defer.returnValue(room_id)
|
||||
|
||||
def _delete_room_alias_txn(self, txn, room_alias):
|
||||
cursor = txn.execute(
|
||||
txn.execute(
|
||||
"SELECT room_id FROM room_aliases WHERE room_alias = ?",
|
||||
(room_alias.to_string(),)
|
||||
)
|
||||
|
||||
res = cursor.fetchone()
|
||||
res = txn.fetchone()
|
||||
if res:
|
||||
room_id = res[0]
|
||||
else:
|
||||
@@ -134,9 +141,11 @@ class DirectoryStore(SQLBaseStore):
|
||||
|
||||
return room_id
|
||||
|
||||
@cached()
|
||||
def get_aliases_for_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
"room_aliases",
|
||||
{"room_id": room_id},
|
||||
"room_alias",
|
||||
desc="get_aliases_for_room",
|
||||
)
|
||||
|
||||
41
synapse/storage/engines/__init__.py
Normal file
41
synapse/storage/engines/__init__.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import IncorrectDatabaseSetup
|
||||
from .postgres import PostgresEngine
|
||||
from .sqlite3 import Sqlite3Engine
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
SUPPORTED_MODULE = {
|
||||
"sqlite3": Sqlite3Engine,
|
||||
"psycopg2": PostgresEngine,
|
||||
}
|
||||
|
||||
|
||||
def create_engine(name):
|
||||
engine_class = SUPPORTED_MODULE.get(name, None)
|
||||
|
||||
if engine_class:
|
||||
module = importlib.import_module(name)
|
||||
return engine_class(module)
|
||||
|
||||
raise RuntimeError(
|
||||
"Unsupported database engine '%s'" % (name,)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["create_engine", "IncorrectDatabaseSetup"]
|
||||
18
synapse/storage/engines/_base.py
Normal file
18
synapse/storage/engines/_base.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class IncorrectDatabaseSetup(RuntimeError):
|
||||
pass
|
||||
59
synapse/storage/engines/postgres.py
Normal file
59
synapse/storage/engines/postgres.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import prepare_database
|
||||
|
||||
from ._base import IncorrectDatabaseSetup
|
||||
|
||||
|
||||
class PostgresEngine(object):
|
||||
def __init__(self, database_module):
|
||||
self.module = database_module
|
||||
self.module.extensions.register_type(self.module.extensions.UNICODE)
|
||||
|
||||
def check_database(self, txn):
|
||||
txn.execute("SHOW SERVER_ENCODING")
|
||||
rows = txn.fetchall()
|
||||
if rows and rows[0][0] != "UTF8":
|
||||
raise IncorrectDatabaseSetup(
|
||||
"Database has incorrect encoding: '%s' instead of 'UTF8'\n"
|
||||
"See docs/postgres.rst for more information."
|
||||
% (rows[0][0],)
|
||||
)
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
return sql.replace("?", "%s")
|
||||
|
||||
def encode_parameter(self, param):
|
||||
return param
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
db_conn.set_isolation_level(
|
||||
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
|
||||
)
|
||||
|
||||
def prepare_database(self, db_conn):
|
||||
prepare_database(db_conn, self)
|
||||
|
||||
def is_deadlock(self, error):
|
||||
if isinstance(error, self.module.DatabaseError):
|
||||
return error.pgcode in ["40001", "40P01"]
|
||||
return False
|
||||
|
||||
def is_connection_closed(self, conn):
|
||||
return bool(conn.closed)
|
||||
|
||||
def lock_table(self, txn, table):
|
||||
txn.execute("LOCK TABLE %s in EXCLUSIVE MODE" % (table,))
|
||||
46
synapse/storage/engines/sqlite3.py
Normal file
46
synapse/storage/engines/sqlite3.py
Normal file
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.storage import prepare_database, prepare_sqlite3_database
|
||||
|
||||
|
||||
class Sqlite3Engine(object):
|
||||
def __init__(self, database_module):
|
||||
self.module = database_module
|
||||
|
||||
def check_database(self, txn):
|
||||
pass
|
||||
|
||||
def convert_param_style(self, sql):
|
||||
return sql
|
||||
|
||||
def encode_parameter(self, param):
|
||||
return param
|
||||
|
||||
def on_new_connection(self, db_conn):
|
||||
self.prepare_database(db_conn)
|
||||
|
||||
def prepare_database(self, db_conn):
|
||||
prepare_sqlite3_database(db_conn)
|
||||
prepare_database(db_conn, self)
|
||||
|
||||
def is_deadlock(self, error):
|
||||
return False
|
||||
|
||||
def is_connection_closed(self, conn):
|
||||
return False
|
||||
|
||||
def lock_table(self, txn, table):
|
||||
return
|
||||
@@ -96,11 +96,22 @@ class EventFederationStore(SQLBaseStore):
|
||||
room_id,
|
||||
)
|
||||
|
||||
def get_latest_event_ids_in_room(self, room_id):
|
||||
return self._simple_select_onecol(
|
||||
table="event_forward_extremities",
|
||||
keyvalues={
|
||||
"room_id": room_id,
|
||||
},
|
||||
retcol="event_id",
|
||||
desc="get_latest_events_in_room",
|
||||
)
|
||||
|
||||
def _get_latest_events_in_room(self, txn, room_id):
|
||||
sql = (
|
||||
"SELECT e.event_id, e.depth FROM events as e "
|
||||
"INNER JOIN event_forward_extremities as f "
|
||||
"ON e.event_id = f.event_id "
|
||||
"AND e.room_id = f.room_id "
|
||||
"WHERE f.room_id = ?"
|
||||
)
|
||||
|
||||
@@ -153,7 +164,7 @@ class EventFederationStore(SQLBaseStore):
|
||||
results = self._get_prev_events_and_state(
|
||||
txn,
|
||||
event_id,
|
||||
is_state=1,
|
||||
is_state=True,
|
||||
)
|
||||
|
||||
return [(e_id, h, ) for e_id, h, _ in results]
|
||||
@@ -164,7 +175,7 @@ class EventFederationStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
if is_state is not None:
|
||||
keyvalues["is_state"] = is_state
|
||||
keyvalues["is_state"] = bool(is_state)
|
||||
|
||||
res = self._simple_select_list_txn(
|
||||
txn,
|
||||
@@ -242,7 +253,6 @@ class EventFederationStore(SQLBaseStore):
|
||||
"room_id": room_id,
|
||||
"min_depth": depth,
|
||||
},
|
||||
or_replace=True,
|
||||
)
|
||||
|
||||
def _handle_prev_events(self, txn, outlier, event_id, prev_events,
|
||||
@@ -260,9 +270,8 @@ class EventFederationStore(SQLBaseStore):
|
||||
"event_id": event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": room_id,
|
||||
"is_state": 0,
|
||||
"is_state": False,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
# Update the extremities table if this is not an outlier.
|
||||
@@ -281,19 +290,19 @@ class EventFederationStore(SQLBaseStore):
|
||||
# We only insert as a forward extremity the new event if there are
|
||||
# no other events that reference it as a prev event
|
||||
query = (
|
||||
"INSERT OR IGNORE INTO %(table)s (event_id, room_id) "
|
||||
"SELECT ?, ? WHERE NOT EXISTS ("
|
||||
"SELECT 1 FROM %(event_edges)s WHERE "
|
||||
"prev_event_id = ? "
|
||||
")"
|
||||
) % {
|
||||
"table": "event_forward_extremities",
|
||||
"event_edges": "event_edges",
|
||||
}
|
||||
"SELECT 1 FROM event_edges WHERE prev_event_id = ?"
|
||||
)
|
||||
|
||||
logger.debug("query: %s", query)
|
||||
txn.execute(query, (event_id,))
|
||||
|
||||
txn.execute(query, (event_id, room_id, event_id))
|
||||
if not txn.fetchone():
|
||||
query = (
|
||||
"INSERT INTO event_forward_extremities"
|
||||
" (event_id, room_id)"
|
||||
" VALUES (?, ?)"
|
||||
)
|
||||
|
||||
txn.execute(query, (event_id, room_id))
|
||||
|
||||
# Insert all the prev_events as a backwards thing, they'll get
|
||||
# deleted in a second if they're incorrect anyway.
|
||||
@@ -306,7 +315,6 @@ class EventFederationStore(SQLBaseStore):
|
||||
"event_id": e_id,
|
||||
"room_id": room_id,
|
||||
},
|
||||
or_ignore=True,
|
||||
)
|
||||
|
||||
# Also delete from the backwards extremities table all ones that
|
||||
@@ -400,7 +408,7 @@ class EventFederationStore(SQLBaseStore):
|
||||
|
||||
query = (
|
||||
"SELECT prev_event_id FROM event_edges "
|
||||
"WHERE room_id = ? AND event_id = ? AND is_state = 0 "
|
||||
"WHERE room_id = ? AND event_id = ? AND is_state = ? "
|
||||
"LIMIT ?"
|
||||
)
|
||||
|
||||
@@ -409,7 +417,7 @@ class EventFederationStore(SQLBaseStore):
|
||||
for event_id in front:
|
||||
txn.execute(
|
||||
query,
|
||||
(room_id, event_id, limit - len(event_results))
|
||||
(room_id, event_id, False, limit - len(event_results))
|
||||
)
|
||||
|
||||
for e_id, in txn.fetchall():
|
||||
|
||||
397
synapse/storage/events.py
Normal file
397
synapse/storage/events.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from _base import SQLBaseStore, _RollbackButIsFineException
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.crypto.event_signing import compute_event_reference_hash
|
||||
|
||||
from syutil.base64util import decode_base64
|
||||
from syutil.jsonutil import encode_canonical_json
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EventsStore(SQLBaseStore):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def persist_event(self, event, context, backfilled=False,
|
||||
is_new_state=True, current_state=None):
|
||||
stream_ordering = None
|
||||
if backfilled:
|
||||
if not self.min_token_deferred.called:
|
||||
yield self.min_token_deferred
|
||||
self.min_token -= 1
|
||||
stream_ordering = self.min_token
|
||||
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
"persist_event",
|
||||
self._persist_event_txn,
|
||||
event=event,
|
||||
context=context,
|
||||
backfilled=backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
except _RollbackButIsFineException:
|
||||
pass
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_event(self, event_id, check_redacted=True,
|
||||
get_prev_content=False, allow_rejected=False,
|
||||
allow_none=False):
|
||||
"""Get an event from the database by event_id.
|
||||
|
||||
Args:
|
||||
event_id (str): The event_id of the event to fetch
|
||||
check_redacted (bool): If True, check if event has been redacted
|
||||
and redact it.
|
||||
get_prev_content (bool): If True and event is a state event,
|
||||
include the previous states content in the unsigned field.
|
||||
allow_rejected (bool): If True return rejected events.
|
||||
allow_none (bool): If True, return None if no event found, if
|
||||
False throw an exception.
|
||||
|
||||
Returns:
|
||||
Deferred : A FrozenEvent.
|
||||
"""
|
||||
event = yield self.runInteraction(
|
||||
"get_event", self._get_event_txn,
|
||||
event_id,
|
||||
check_redacted=check_redacted,
|
||||
get_prev_content=get_prev_content,
|
||||
allow_rejected=allow_rejected,
|
||||
)
|
||||
|
||||
if not event and not allow_none:
|
||||
raise RuntimeError("Could not find event %s" % (event_id,))
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@log_function
|
||||
def _persist_event_txn(self, txn, event, context, backfilled,
|
||||
stream_ordering=None, is_new_state=True,
|
||||
current_state=None):
|
||||
|
||||
# Remove the any existing cache entries for the event_id
|
||||
self._invalidate_get_event_cache(event.event_id)
|
||||
|
||||
if stream_ordering is None:
|
||||
with self._stream_id_gen.get_next_txn(txn) as stream_ordering:
|
||||
return self._persist_event_txn(
|
||||
txn, event, context, backfilled,
|
||||
stream_ordering=stream_ordering,
|
||||
is_new_state=is_new_state,
|
||||
current_state=current_state,
|
||||
)
|
||||
|
||||
# We purposefully do this first since if we include a `current_state`
|
||||
# key, we *want* to update the `current_state_events` table
|
||||
if current_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="current_state_events",
|
||||
keyvalues={"room_id": event.room_id},
|
||||
)
|
||||
|
||||
for s in current_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
{
|
||||
"event_id": s.event_id,
|
||||
"room_id": s.room_id,
|
||||
"type": s.type,
|
||||
"state_key": s.state_key,
|
||||
},
|
||||
)
|
||||
|
||||
if event.is_state() and is_new_state:
|
||||
if not backfilled and not context.rejected:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
)
|
||||
|
||||
for prev_state_id, _ in event.prev_state:
|
||||
self._simple_delete_txn(
|
||||
txn,
|
||||
table="state_forward_extremities",
|
||||
keyvalues={
|
||||
"event_id": prev_state_id,
|
||||
}
|
||||
)
|
||||
|
||||
outlier = event.internal_metadata.is_outlier()
|
||||
|
||||
if not outlier:
|
||||
self._store_state_groups_txn(txn, event, context)
|
||||
|
||||
self._update_min_depth_for_room_txn(
|
||||
txn,
|
||||
event.room_id,
|
||||
event.depth
|
||||
)
|
||||
|
||||
have_persisted = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
keyvalues={"event_id": event.event_id},
|
||||
retcol="event_id",
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
metadata_json = encode_canonical_json(
|
||||
event.internal_metadata.get_dict()
|
||||
).decode("UTF-8")
|
||||
|
||||
# If we have already persisted this event, we don't need to do any
|
||||
# more processing.
|
||||
# The processing above must be done on every call to persist event,
|
||||
# since they might not have happened on previous calls. For example,
|
||||
# if we are persisting an event that we had persisted as an outlier,
|
||||
# but is no longer one.
|
||||
if have_persisted:
|
||||
if not outlier:
|
||||
sql = (
|
||||
"UPDATE event_json SET internal_metadata = ?"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(metadata_json, event.event_id,)
|
||||
)
|
||||
|
||||
sql = (
|
||||
"UPDATE events SET outlier = ?"
|
||||
" WHERE event_id = ?"
|
||||
)
|
||||
txn.execute(
|
||||
sql,
|
||||
(False, event.event_id,)
|
||||
)
|
||||
return
|
||||
|
||||
self._handle_prev_events(
|
||||
txn,
|
||||
outlier=outlier,
|
||||
event_id=event.event_id,
|
||||
prev_events=event.prev_events,
|
||||
room_id=event.room_id,
|
||||
)
|
||||
|
||||
if event.type == EventTypes.Member:
|
||||
self._store_room_member_txn(txn, event)
|
||||
elif event.type == EventTypes.Name:
|
||||
self._store_room_name_txn(txn, event)
|
||||
elif event.type == EventTypes.Topic:
|
||||
self._store_room_topic_txn(txn, event)
|
||||
elif event.type == EventTypes.Redaction:
|
||||
self._store_redaction(txn, event)
|
||||
|
||||
event_dict = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
]
|
||||
}
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_json",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"internal_metadata": metadata_json,
|
||||
"json": encode_canonical_json(event_dict).decode("UTF-8"),
|
||||
},
|
||||
)
|
||||
|
||||
content = encode_canonical_json(
|
||||
event.content
|
||||
).decode("UTF-8")
|
||||
|
||||
vals = {
|
||||
"topological_ordering": event.depth,
|
||||
"event_id": event.event_id,
|
||||
"type": event.type,
|
||||
"room_id": event.room_id,
|
||||
"content": content,
|
||||
"processed": True,
|
||||
"outlier": outlier,
|
||||
"depth": event.depth,
|
||||
}
|
||||
|
||||
unrec = {
|
||||
k: v
|
||||
for k, v in event.get_dict().items()
|
||||
if k not in vals.keys() and k not in [
|
||||
"redacted",
|
||||
"redacted_because",
|
||||
"signatures",
|
||||
"hashes",
|
||||
"prev_events",
|
||||
]
|
||||
}
|
||||
|
||||
vals["unrecognized_keys"] = encode_canonical_json(
|
||||
unrec
|
||||
).decode("UTF-8")
|
||||
|
||||
sql = (
|
||||
"INSERT INTO events"
|
||||
" (stream_ordering, topological_ordering, event_id, type,"
|
||||
" room_id, content, processed, outlier, depth)"
|
||||
" VALUES (?,?,?,?,?,?,?,?,?)"
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
stream_ordering, event.depth, event.event_id, event.type,
|
||||
event.room_id, content, True, outlier, event.depth
|
||||
)
|
||||
)
|
||||
|
||||
if context.rejected:
|
||||
self._store_rejections_txn(txn, event.event_id, context.rejected)
|
||||
|
||||
for hash_alg, hash_base64 in event.hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_event_content_hash_txn(
|
||||
txn, event.event_id, hash_alg, hash_bytes,
|
||||
)
|
||||
|
||||
for prev_event_id, prev_hashes in event.prev_events:
|
||||
for alg, hash_base64 in prev_hashes.items():
|
||||
hash_bytes = decode_base64(hash_base64)
|
||||
self._store_prev_event_hash_txn(
|
||||
txn, event.event_id, prev_event_id, alg, hash_bytes
|
||||
)
|
||||
|
||||
for auth_id, _ in event.auth_events:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_auth",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"auth_id": auth_id,
|
||||
},
|
||||
)
|
||||
|
||||
(ref_alg, ref_hash_bytes) = compute_event_reference_hash(event)
|
||||
self._store_event_reference_hash_txn(
|
||||
txn, event.event_id, ref_alg, ref_hash_bytes
|
||||
)
|
||||
|
||||
if event.is_state():
|
||||
vals = {
|
||||
"event_id": event.event_id,
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
}
|
||||
|
||||
# TODO: How does this work with backfilling?
|
||||
if hasattr(event, "replaces_state"):
|
||||
vals["prev_state"] = event.replaces_state
|
||||
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
"state_events",
|
||||
vals,
|
||||
)
|
||||
|
||||
for e_id, h in event.prev_state:
|
||||
self._simple_insert_txn(
|
||||
txn,
|
||||
table="event_edges",
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
"prev_event_id": e_id,
|
||||
"room_id": event.room_id,
|
||||
"is_state": True,
|
||||
},
|
||||
)
|
||||
|
||||
if is_new_state and not context.rejected:
|
||||
self._simple_upsert_txn(
|
||||
txn,
|
||||
"current_state_events",
|
||||
keyvalues={
|
||||
"room_id": event.room_id,
|
||||
"type": event.type,
|
||||
"state_key": event.state_key,
|
||||
},
|
||||
values={
|
||||
"event_id": event.event_id,
|
||||
}
|
||||
)
|
||||
|
||||
def _store_redaction(self, txn, event):
|
||||
# invalidate the cache for the redacted event
|
||||
self._invalidate_get_event_cache(event.redacts)
|
||||
txn.execute(
|
||||
"INSERT INTO redactions (event_id, redacts) VALUES (?,?)",
|
||||
(event.event_id, event.redacts)
|
||||
)
|
||||
|
||||
def have_events(self, event_ids):
|
||||
"""Given a list of event ids, check if we have already processed them.
|
||||
|
||||
Returns:
|
||||
dict: Has an entry for each event id we already have seen. Maps to
|
||||
the rejected reason string if we rejected the event, else maps to
|
||||
None.
|
||||
"""
|
||||
if not event_ids:
|
||||
return defer.succeed({})
|
||||
|
||||
def f(txn):
|
||||
sql = (
|
||||
"SELECT e.event_id, reason FROM events as e "
|
||||
"LEFT JOIN rejections as r ON e.event_id = r.event_id "
|
||||
"WHERE e.event_id = ?"
|
||||
)
|
||||
|
||||
res = {}
|
||||
for event_id in event_ids:
|
||||
txn.execute(sql, (event_id,))
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
_, rejected = row
|
||||
res[event_id] = rejected
|
||||
|
||||
return res
|
||||
|
||||
return self.runInteraction(
|
||||
"have_events", f,
|
||||
)
|
||||
@@ -1,47 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014, 2015 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
class FeedbackStore(SQLBaseStore):
|
||||
|
||||
def _store_feedback_txn(self, txn, event):
|
||||
self._simple_insert_txn(txn, "feedback", {
|
||||
"event_id": event.event_id,
|
||||
"feedback_type": event.content["type"],
|
||||
"room_id": event.room_id,
|
||||
"target_event_id": event.content["target_event_id"],
|
||||
"sender": event.user_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_feedback_for_event(self, event_id):
|
||||
sql = (
|
||||
"SELECT events.* FROM events INNER JOIN feedback "
|
||||
"ON events.event_id = feedback.event_id "
|
||||
"WHERE feedback.target_event_id = ? "
|
||||
)
|
||||
|
||||
rows = yield self._execute_and_decode("get_feedback_for_event", sql, event_id)
|
||||
|
||||
defer.returnValue(
|
||||
[
|
||||
(yield self._parse_events(r))
|
||||
for r in rows
|
||||
]
|
||||
)
|
||||
@@ -31,6 +31,7 @@ class FilteringStore(SQLBaseStore):
|
||||
},
|
||||
retcol="filter_json",
|
||||
allow_none=False,
|
||||
desc="get_user_filter",
|
||||
)
|
||||
|
||||
defer.returnValue(json.loads(def_json))
|
||||
|
||||
@@ -57,16 +57,18 @@ class KeyStore(SQLBaseStore):
|
||||
OpenSSL.crypto.FILETYPE_ASN1, tls_certificate
|
||||
)
|
||||
fingerprint = hashlib.sha256(tls_certificate_bytes).hexdigest()
|
||||
return self._simple_insert(
|
||||
return self._simple_upsert(
|
||||
table="server_tls_certificates",
|
||||
values={
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"fingerprint": fingerprint,
|
||||
},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
"tls_certificate": buffer(tls_certificate_bytes),
|
||||
},
|
||||
or_ignore=True,
|
||||
desc="store_server_certificate",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -107,14 +109,85 @@ class KeyStore(SQLBaseStore):
|
||||
ts_now_ms (int): The time now in milliseconds
|
||||
verification_key (VerifyKey): The NACL verify key.
|
||||
"""
|
||||
return self._simple_insert(
|
||||
return self._simple_upsert(
|
||||
table="server_signature_keys",
|
||||
values={
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": "%s:%s" % (verify_key.alg, verify_key.version),
|
||||
},
|
||||
values={
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": time_now_ms,
|
||||
"verify_key": buffer(verify_key.encode()),
|
||||
},
|
||||
or_ignore=True,
|
||||
desc="store_server_verify_key",
|
||||
)
|
||||
|
||||
def store_server_keys_json(self, server_name, key_id, from_server,
|
||||
ts_now_ms, ts_expires_ms, key_json_bytes):
|
||||
"""Stores the JSON bytes for a set of keys from a server
|
||||
The JSON should be signed by the originating server, the intermediate
|
||||
server, and by this server. Updates the value for the
|
||||
(server_name, key_id, from_server) triplet if one already existed.
|
||||
Args:
|
||||
server_name (str): The name of the server.
|
||||
key_id (str): The identifer of the key this JSON is for.
|
||||
from_server (str): The server this JSON was fetched from.
|
||||
ts_now_ms (int): The time now in milliseconds.
|
||||
ts_valid_until_ms (int): The time when this json stops being valid.
|
||||
key_json (bytes): The encoded JSON.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
table="server_keys_json",
|
||||
keyvalues={
|
||||
"server_name": server_name,
|
||||
"key_id": key_id,
|
||||
"from_server": from_server,
|
||||
},
|
||||
values={
|
||||
"server_name": server_name,
|
||||
"key_id": key_id,
|
||||
"from_server": from_server,
|
||||
"ts_added_ms": ts_now_ms,
|
||||
"ts_valid_until_ms": ts_expires_ms,
|
||||
"key_json": buffer(key_json_bytes),
|
||||
},
|
||||
)
|
||||
|
||||
def get_server_keys_json(self, server_keys):
|
||||
"""Retrive the key json for a list of server_keys and key ids.
|
||||
If no keys are found for a given server, key_id and source then
|
||||
that server, key_id, and source triplet entry will be an empty list.
|
||||
The JSON is returned as a byte array so that it can be efficiently
|
||||
used in an HTTP response.
|
||||
Args:
|
||||
server_keys (list): List of (server_name, key_id, source) triplets.
|
||||
Returns:
|
||||
Dict mapping (server_name, key_id, source) triplets to dicts with
|
||||
"ts_valid_until_ms" and "key_json" keys.
|
||||
"""
|
||||
def _get_server_keys_json_txn(txn):
|
||||
results = {}
|
||||
for server_name, key_id, from_server in server_keys:
|
||||
keyvalues = {"server_name": server_name}
|
||||
if key_id is not None:
|
||||
keyvalues["key_id"] = key_id
|
||||
if from_server is not None:
|
||||
keyvalues["from_server"] = from_server
|
||||
rows = self._simple_select_list_txn(
|
||||
txn,
|
||||
"server_keys_json",
|
||||
keyvalues=keyvalues,
|
||||
retcols=(
|
||||
"key_id",
|
||||
"from_server",
|
||||
"ts_added_ms",
|
||||
"ts_valid_until_ms",
|
||||
"key_json",
|
||||
),
|
||||
)
|
||||
results[(server_name, key_id, from_server)] = rows
|
||||
return results
|
||||
return self.runInteraction(
|
||||
"get_server_keys_json", _get_server_keys_json_txn
|
||||
)
|
||||
|
||||
@@ -32,6 +32,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
{"media_id": media_id},
|
||||
("media_type", "media_length", "upload_name", "created_ts"),
|
||||
allow_none=True,
|
||||
desc="get_local_media",
|
||||
)
|
||||
|
||||
def store_local_media(self, media_id, media_type, time_now_ms, upload_name,
|
||||
@@ -45,7 +46,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"upload_name": upload_name,
|
||||
"media_length": media_length,
|
||||
"user_id": user_id.to_string(),
|
||||
}
|
||||
},
|
||||
desc="store_local_media",
|
||||
)
|
||||
|
||||
def get_local_media_thumbnails(self, media_id):
|
||||
@@ -55,7 +57,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length",
|
||||
)
|
||||
),
|
||||
desc="get_local_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_local_thumbnail(self, media_id, thumbnail_width,
|
||||
@@ -70,7 +73,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"thumbnail_method": thumbnail_method,
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
}
|
||||
},
|
||||
desc="store_local_thumbnail",
|
||||
)
|
||||
|
||||
def get_cached_remote_media(self, origin, media_id):
|
||||
@@ -82,6 +86,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"filesystem_id",
|
||||
),
|
||||
allow_none=True,
|
||||
desc="get_cached_remote_media",
|
||||
)
|
||||
|
||||
def store_cached_remote_media(self, origin, media_id, media_type,
|
||||
@@ -97,7 +102,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"created_ts": time_now_ms,
|
||||
"upload_name": upload_name,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_cached_remote_media",
|
||||
)
|
||||
|
||||
def get_remote_media_thumbnails(self, origin, media_id):
|
||||
@@ -107,7 +113,8 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
(
|
||||
"thumbnail_width", "thumbnail_height", "thumbnail_method",
|
||||
"thumbnail_type", "thumbnail_length", "filesystem_id",
|
||||
)
|
||||
),
|
||||
desc="get_remote_media_thumbnails",
|
||||
)
|
||||
|
||||
def store_remote_media_thumbnail(self, origin, media_id, filesystem_id,
|
||||
@@ -125,5 +132,6 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"thumbnail_type": thumbnail_type,
|
||||
"thumbnail_length": thumbnail_length,
|
||||
"filesystem_id": filesystem_id,
|
||||
}
|
||||
},
|
||||
desc="store_remote_media_thumbnail",
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user