
Compare commits


275 Commits

Author SHA1 Message Date
Mark Haines
3a676b8ee3 More merging 2016-04-21 16:28:05 +01:00
Mark Haines
0d5622b088 Merge branch 'markjh/slave_event_push_actions' into markjh/split_pusher 2016-04-21 16:24:38 +01:00
Mark Haines
712030aeef Merge branch 'develop' into markjh/split_pusher 2016-04-21 16:21:49 +01:00
Mark Haines
9f53491cab Merge branch 'develop' into markjh/slave_event_push_actions 2016-04-21 16:00:43 +01:00
Mark Haines
02a27a6c4f pip install new python dependencies in jenkins.sh 2016-04-21 16:00:28 +01:00
Mark Haines
a611c968cc Merge branch 'develop' into markjh/slave_event_push_actions 2016-04-21 15:37:01 +01:00
Erik Johnston
59698906eb Make jenkins install lxml 2016-04-21 15:36:13 +01:00
Mark Haines
c0d8e0eb63 Replicate push actions 2016-04-21 15:25:58 +01:00
Erik Johnston
68ebb81e86 Merge pull request #740 from matrix-org/erikj/state_cache
Always use state cache entry if it exists
2016-04-20 13:52:59 +01:00
Erik Johnston
5bbc321588 Always use state cache entry if it exists
Also check if the resolved state matches an existing state group.
2016-04-20 11:49:10 +01:00
Erik Johnston
4cf4320593 Add some logging to state resolve_events 2016-04-20 11:06:02 +01:00
Erik Johnston
eab47ea1e5 Merge pull request #739 from matrix-org/erikj/cache_get_state_groups_for_groups
Add cache to _get_state_groups_from_groups
2016-04-19 17:37:19 +01:00
Mark Haines
f52dd35ac3 Merge pull request #738 from matrix-org/markjh/slaved_receipts
Add a slaved receipts store
2016-04-19 17:31:59 +01:00
Erik Johnston
61c7edfd34 Add cache to _get_state_groups_from_groups 2016-04-19 17:22:03 +01:00
Mark Haines
5bbd424ee0 Add a slaved receipts store 2016-04-19 17:14:08 +01:00
Erik Johnston
6ac40f7b65 Merge pull request #737 from matrix-org/erikj/spider_ssl_factory
Use tls_server_context_factory for SpiderEndpoint
2016-04-19 16:22:05 +01:00
Erik Johnston
f505575f69 Make InsecureInterceptableContextFactory work with SpiderEndpoint 2016-04-19 16:08:14 +01:00
Mark Haines
4084c58aa1 Merge pull request #736 from matrix-org/markjh/slave_invited_rooms_for_user
Replicate get_invited_rooms_for_user
2016-04-19 15:46:00 +01:00
Mark Haines
e99365f601 Replicate get_invited_rooms_for_user 2016-04-19 15:22:14 +01:00
Erik Johnston
e8884e5e9c Add self.media_repo to PreviewUrlResource 2016-04-19 14:51:34 +01:00
Erik Johnston
a7001c311b _make_dirs was moved to MediaRepository 2016-04-19 14:49:31 +01:00
Erik Johnston
9181e2f4c7 Add store to PreviewUrlResource 2016-04-19 14:48:24 +01:00
Erik Johnston
fb76a81ff7 Reorder imports 2016-04-19 14:45:05 +01:00
Mark Haines
0b282d33af Add an HTTP API for removing rejected pushers.
When a push is rejected by the push gateway then synapse needs to
remove the pusher from the database. However we probably don't want
to do that directly from the slave, so we add an HTTP API to synapse
to remove the pusher from the database.
2016-04-19 14:43:47 +01:00
Erik Johnston
48af68ba8e Merge pull request #735 from matrix-org/erikj/media_resource_cleanup
Split out BaseMediaResource into MediaRepository
2016-04-19 13:56:59 +01:00
Erik Johnston
0c93df89b6 Move MediaRepository to media_repository module 2016-04-19 11:31:43 +01:00
Erik Johnston
43f0941e8f Split out BaseMediaResource into MediaRepository
This is so that a single MediaRepository can be shared across all
resources, rather than having a "copy" per resource.

In particular this allows us to guard against both the thumbnail and
download resource triggering a download of remote content at the same
time.
2016-04-19 11:24:59 +01:00
Erik Johnston
481119f7d6 Merge pull request #734 from matrix-org/erikj/measure
Create log context in Measure if one doesn't exist
2016-04-18 19:03:57 +01:00
Erik Johnston
eb8619e256 Create log context in Measure if one doesn't exist 2016-04-18 16:08:32 +01:00
Erik Johnston
4ef7a25c10 Merge pull request #733 from matrix-org/erikj/make_member_timeout
Lower timeout for make_membership_event
2016-04-18 15:08:05 +01:00
Erik Johnston
3727a15764 Merge pull request #732 from matrix-org/erikj/login
Simplify _check_password
2016-04-18 15:07:57 +01:00
Matthew Hodgson
aaabbd3e9e explicitly pass in the charset from Content-Type to lxml to fix cyrillic woes better 2016-04-15 14:32:25 +01:00
Matthew Hodgson
84f9cac4d0 fix cyrillic URL previews by hardcoding all page decoding to UTF-8 for now, rather than relying on lxml's heuristics which seem to get it wrong 2016-04-15 13:20:08 +01:00
Erik Johnston
914f1eafac Lower timeout for make_membership_event
Calls to make_membership_event are done in response to client requests,
and so should not be retried over long timeframes.
2016-04-15 11:22:23 +01:00
Erik Johnston
6fd2f685fe Simplify _check_password 2016-04-15 11:17:18 +01:00
Erik Johnston
737aee9295 Merge pull request #731 from matrix-org/erikj/timed_otu
Use SynapseError 504 for Timeout errors
2016-04-15 10:31:17 +01:00
Erik Johnston
cb9c465707 Use SynapseError 504 for Timeout errors 2016-04-15 10:21:32 +01:00
Mark Haines
3c79bdd7a0 Fix check_password rather than inverting the meaning of _check_local_password (#730) 2016-04-14 19:00:21 +01:00
David Baker
a4c56bf67b Merge pull request #729 from matrix-org/dbkr/fix_login_nonexistent_user
Fix login to error for nonexistent users
2016-04-14 18:46:45 +01:00
David Baker
4c1b32d7e2 Fix login to error for nonexistent users
Fixes SYN-680
2016-04-14 18:28:42 +01:00
Mark Haines
03c8df54f0 Invalidate the receipt cache correctly 2016-04-14 17:25:27 +01:00
Mark Haines
c214d3e36e Merge branch 'develop' into markjh/split_pusher 2016-04-14 17:00:40 +01:00
Mark Haines
1c1b2de975 Poke the slaved pushers on new receipts 2016-04-14 16:59:56 +01:00
Matthew Hodgson
f78b479118 fix urlparse import thinko breaking tiny URLs 2016-04-14 15:23:55 +01:00
Kegsay
4802f9cdb6 Merge pull request #727 from matrix-org/kegan/fix-asapi-reg
Make v2_alpha reg follow the AS API specification
2016-04-14 15:08:54 +01:00
Kegan Dougal
83776d6219 Make v2_alpha reg follow the AS API specification
The spec is clear the key should be 'user' not 'username' and this is indeed
the case for v1. This is not true for v2_alpha though, which is what this
commit is fixing.
2016-04-14 14:52:26 +01:00
Matthew Hodgson
bd77216d06 comment out 2c838f6459 due to risk of https://en.wikipedia.org/wiki/Billion_laughs attacks - thanks @torhve 2016-04-14 14:39:24 +01:00
Erik Johnston
5a578ea4c7 Merge pull request #726 from matrix-org/erikj/push_metric
Measure push action generator
2016-04-14 13:49:09 +01:00
Erik Johnston
9ae64c9910 Measure push action generator 2016-04-14 13:42:22 +01:00
Mark Haines
f41b1a8723 Make push sort of work 2016-04-14 13:30:57 +01:00
Erik Johnston
b42ad359e9 Merge pull request #725 from matrix-org/dbkr/push_only_joined
Don't push for everyone who ever sent an RR to the room
2016-04-14 12:05:13 +01:00
David Baker
757e2c79b4 Don't push for everyone who ever sent an RR to the room 2016-04-14 12:02:50 +01:00
Erik Johnston
86e9bbc74e Add missing yield 2016-04-14 11:56:52 +01:00
Erik Johnston
e40f25ebe1 Rename log context 2016-04-14 11:54:14 +01:00
Erik Johnston
ff1d333a02 Merge pull request #724 from matrix-org/erikj/push_measure
Add push index. Add extra Measure
2016-04-14 11:46:46 +01:00
Erik Johnston
2ae91a9e2f Make send_badge private 2016-04-14 11:37:50 +01:00
Erik Johnston
d213d69fe3 Add desc arg 2016-04-14 11:36:23 +01:00
Erik Johnston
56da835eaf Add necessary logging contexts 2016-04-14 11:33:50 +01:00
Erik Johnston
96bcfb29c7 Add index 2016-04-14 11:26:33 +01:00
Erik Johnston
7be1065b8f Add extra Measure 2016-04-14 11:26:15 +01:00
Mark Haines
1209d3174e Optionally split out the pusher into a separate process 2016-04-14 11:20:48 +01:00
Erik Johnston
a2546b9082 Fix query for get_unread_push_actions_for_user_in_range 2016-04-14 11:08:31 +01:00
Erik Johnston
ceeb5b909f Merge pull request #721 from matrix-org/erikj/spider
Sanitize the optional dependencies for spider API
2016-04-14 09:59:29 +01:00
David Baker
43a89cca8e Merge pull request #722 from matrix-org/dbkr/only_unread_event_actions
Only return unread notifications
2016-04-13 14:54:26 +01:00
Erik Johnston
f338bf9257 Give install requirements 2016-04-13 14:33:48 +01:00
David Baker
767fc0b739 pep8 2016-04-13 14:23:27 +01:00
David Baker
54d08c8868 Only return unread notifications
Make get_unread_push_actions_for_user_in_range only return unread event actions, being more true to its name. Done in two separate sql queries to get actions after a read receipt and those in a room wiht no receipt at all. SQL queries by Erik.
2016-04-13 14:16:45 +01:00
Erik Johnston
5880bc5417 Merge pull request #718 from matrix-org/erikj/public_room_list
Don't return empty public rooms
2016-04-13 14:07:26 +01:00
Erik Johnston
f613a3e332 Merge pull request #720 from matrix-org/erikj/auth_chec
Don't auto log failed auth checks
2016-04-13 14:07:23 +01:00
Erik Johnston
bfe586843f Add back in helpful description for missing url_preview_ip_range_blacklist 2016-04-13 13:52:57 +01:00
Erik Johnston
d0633e6dbe Sanitize the optional dependencies for spider API 2016-04-13 13:38:09 +01:00
Erik Johnston
0f2ca8cde1 Measure Auth.check 2016-04-13 11:15:59 +01:00
Erik Johnston
c53f9d561e Don't auto log failed auth checks 2016-04-13 11:11:46 +01:00
David Baker
65141161f6 Unused member variable 2016-04-12 16:25:26 +01:00
Erik Johnston
72f454b752 Don't return empty public rooms 2016-04-12 16:06:18 +01:00
Mark Haines
10ebbaea2e Update replication.rst 2016-04-12 15:53:45 +01:00
Mark Haines
aa5ce4d450 Add some design documentation for replication 2016-04-12 15:14:10 +01:00
David Baker
d33d623f0d Merge pull request #716 from matrix-org/dbkr/get_pushers
Add get endpoint for pushers
2016-04-12 14:40:37 +01:00
David Baker
7984ffdc6a Unneccessarywhitespaceisunnecessary 2016-04-12 13:55:57 +01:00
David Baker
c1267d04c5 Oops, forgot the desc. 2016-04-12 13:55:32 +01:00
David Baker
a04c076b7f Make the /set part mandatory 2016-04-12 13:54:41 +01:00
David Baker
44891b4a0a Tidy up get_pusher functions
Decodes pushers rows on the main thread rather than the db thread and uses _simple_select_list. Also do the same to the function I copied and factor out the duplication into a helper function.
2016-04-12 13:47:17 +01:00
David Baker
7b39bcdaae Mis-named function 2016-04-12 13:35:08 +01:00
David Baker
d937f342bb Split into separate servlet classes 2016-04-12 13:33:30 +01:00
Erik Johnston
318cb1f207 Merge pull request #717 from matrix-org/erikj/backfill_state
Check if we've already backfilled events
2016-04-12 13:30:30 +01:00
Erik Johnston
c48465dbaa More comments 2016-04-12 12:48:30 +01:00
Erik Johnston
8be1a37909 More comments 2016-04-12 12:04:19 +01:00
Erik Johnston
d3d0be4167 Don't append to unused list 2016-04-12 11:59:00 +01:00
Erik Johnston
762ada1e07 Add back backfilled parameter that was removed 2016-04-12 11:58:04 +01:00
Erik Johnston
0d3da210f0 Add comment 2016-04-12 11:54:41 +01:00
Erik Johnston
cccf86dd05 Check if we've already backfilled events 2016-04-12 11:19:32 +01:00
David Baker
8a76094965 Add get endpoint for pushers
As per https://github.com/matrix-org/matrix-doc/pull/308
2016-04-11 18:00:03 +01:00
Mark Haines
790f5848b2 Fix the rule_id for .m.rule.invite_for_me (#715) 2016-04-11 16:10:39 +01:00
Mark Haines
82d7eea7e3 Move the versionstring code out of app.homeserver into util 2016-04-11 14:57:09 +01:00
David Baker
2547dffccc Merge pull request #705 from matrix-org/dbkr/pushers_use_event_actions
Change pushers to use the event_actions table
2016-04-11 12:58:55 +01:00
David Baker
9bb041791c Run unsafe proces in a loop until we've caught up
and wrap unsafe process in a try block
2016-04-11 12:48:30 +01:00
Erik Johnston
17515bae14 PEP8 2016-04-11 11:02:50 +01:00
Matthew Hodgson
4bd3d25218 Merge pull request #688 from matrix-org/matthew/preview_urls
URL previewing support
2016-04-11 10:40:29 +01:00
Matthew Hodgson
5ffacc5e84 fix typos and needless try/except from PR review 2016-04-11 10:39:16 +01:00
Matthew Hodgson
83b2f83da0 actually throw meaningful errors 2016-04-08 21:36:59 +01:00
Mark Haines
b36270b5e1 Fix pep8 warning 2016-04-08 19:52:23 +01:00
Matthew Hodgson
6ff7a79308 move local_media_repository_url_cache.sql to schema v31 2016-04-08 19:09:02 +01:00
Matthew Hodgson
af582b66bb fix typo 2016-04-08 19:08:47 +01:00
Matthew Hodgson
2460d904bd fix error checking for new SQL 2016-04-08 19:04:29 +01:00
Matthew Hodgson
1ccabe2965 more PR feedback 2016-04-08 18:58:08 +01:00
Matthew Hodgson
fb83f6a1fc fix SQL based on PR feedback 2016-04-08 18:55:38 +01:00
Matthew Hodgson
b04f81284a Add more doc 2016-04-08 18:55:27 +01:00
Matthew Hodgson
ec9331f851 Add doc 2016-04-08 18:54:18 +01:00
Matthew Hodgson
dafef5a688 Add url_preview_enabled config option to turn on/off preview_url endpoint. defaults to off.
Add url_preview_ip_range_blacklist to let admins specify internal IP ranges that must not be spidered.
Add url_preview_url_blacklist to let admins specify URL patterns that must not be spidered.
Implement a custom SpiderEndpoint and associated support classes to implement url_preview_ip_range_blacklist
Add commentary and generally address PR feedback
2016-04-08 18:37:15 +01:00
David Baker
d96a070a3a Actually check if we;re processing 2016-04-08 16:49:39 +01:00
David Baker
ed3979df5f Fix invite pushes
* If the event is an invite event, add the invitee to list of user we run push rules for (if they have a pusher etc)
 * Move invite_for_me to be higher prio than member events otherwise member events matches them
 * Spell override right
2016-04-08 15:29:59 +01:00
Erik Johnston
79fc4ff6f9 Merge pull request #677 from matrix-org/erikj/dns_cache
Read from DNS cache if within TTL
2016-04-08 14:09:56 +01:00
David Baker
7b6d519482 Make sure max stream ordering only increases 2016-04-08 14:08:16 +01:00
David Baker
52d1008661 Unsafe process should call itself if the max has changed 2016-04-08 14:06:54 +01:00
Erik Johnston
96bd8ff57c Merge pull request #707 from matrix-org/markjh/remove_changed_presencelike_data
changed_presencelike_data isn't observed anywhere so can be removed
2016-04-08 14:04:54 +01:00
David Baker
ce3fe52498 Comment why unsafe process is unsafe 2016-04-08 14:02:38 +01:00
Mark Haines
7e2f971c08 Remove some unused functions (#711)
* Remove some unused functions

* get_room_events_stream is only used in tests

* is_exclusive_room might actually be something we want
2016-04-08 14:01:56 +01:00
Mark Haines
d63b49137a Merge pull request #710 from matrix-org/markjh/move_fire
Move all the wrapper functions for distributor.fire
2016-04-08 11:39:34 +01:00
Mark Haines
b9ee5650b0 Move all the wrapper functions for distributor.fire
Move the functions inside the distributor and import them
where needed. This reduces duplication and makes it possible
for flake8 to detect when the functions aren't used in a
given file.
2016-04-08 11:01:38 +01:00
Mark Haines
caef337587 changed_presencelike_data isn't observed anywhere in synapse so can be removed 2016-04-08 10:37:19 +01:00
Mark Haines
b4a5002a6e Merge pull request #708 from matrix-org/markjh/remove_collect_presencelike_data
Call profile handler get_displayname directly
2016-04-08 09:51:36 +01:00
Mark Haines
86be915cce Call profile handler get_displayname directly rather than using collect_presencelike_data 2016-04-07 18:11:49 +01:00
David Baker
d9f38561c8 Literally a dictionary 2016-04-07 17:45:01 +01:00
David Baker
4836864f56 generate id in the main thread 2016-04-07 17:38:48 +01:00
David Baker
a4a31fa8dc Only pass in what we need 2016-04-07 17:37:19 +01:00
Erik Johnston
f942980c0b Merge pull request #701 from DoubleMalt/ldap-auth
Add LDAP authentication
2016-04-07 17:35:28 +01:00
David Baker
3fb35cbd6f Oops, inequality fail 2016-04-07 17:33:37 +01:00
David Baker
15e0f1696f Wrap process in a flag so we don't process whist already processing. 2016-04-07 17:31:08 +01:00
Mark Haines
da84fa3d74 Merge pull request #706 from matrix-org/markjh/slaveIV
Add tests for redactions
2016-04-07 17:30:37 +01:00
Matthew Hodgson
d6e7333ae4 Merge branch 'develop' into matthew/preview_urls 2016-04-07 17:26:44 +01:00
David Baker
6ec02e9ecf indenting 2016-04-07 17:24:05 +01:00
David Baker
25cd5bb697 defer.gatherResults rather than doing all the pokes in series 2016-04-07 17:22:14 +01:00
David Baker
fa129ce5b5 Add measure blocks 2016-04-07 17:12:29 +01:00
David Baker
e1e042f2a1 Add comments on min_stream_id
saying that the min stream id won't be completely accurate all the time
2016-04-07 17:09:36 +01:00
Mark Haines
ceb599e789 Add tests for redactions 2016-04-07 16:52:07 +01:00
Mark Haines
8c82b06904 Merge pull request #704 from matrix-org/markh/slaveIII
Add tests for get_latest_event_ids_in_room and get_current_state
2016-04-07 16:49:34 +01:00
David Baker
05d044aac3 pep8 2016-04-07 16:45:38 +01:00
David Baker
2d5c693fd3 Fix port script for changes merged from develop 2016-04-07 16:43:54 +01:00
Mark Haines
57fa1801c3 Add sensible __eq__ operators inside the tests.
Rather than adding them globally. This limits the changes to only
affect the tests.
2016-04-07 16:41:37 +01:00
Erik Johnston
a294b04bf0 Merge pull request #700 from matrix-org/erikj/deduplicate_joins
Deduplicate membership changes
2016-04-07 16:35:40 +01:00
David Baker
9c99ab4572 Merge remote-tracking branch 'origin/develop' into dbkr/pushers_use_event_actions 2016-04-07 16:35:22 +01:00
David Baker
d549fdfa22 Remove code that's now been obsoleted or moved elsewhere 2016-04-07 16:31:38 +01:00
Erik Johnston
95ac3078da Rename things 2016-04-07 16:07:16 +01:00
David Baker
92e3071623 Send badge count pushes.
Also fix bugs with retrying.
2016-04-07 15:39:53 +01:00
Erik Johnston
ee5aef6c72 Log contexts and squash things together 2016-04-07 15:34:21 +01:00
Erik Johnston
639cd07d6d Add comment 2016-04-07 14:24:12 +01:00
Erik Johnston
af03ecf352 Deduplicate joins 2016-04-07 14:19:02 +01:00
Mark Haines
60ec9793fb Add tests for get_latest_event_ids_in_room and get_current_state 2016-04-07 13:17:56 +01:00
Christoph Witzany
674379e673 Add myself to AUTHORS.rst
Signed-off-by: Christoph Witzany <christoph@web.crofting.com>
2016-04-07 13:01:09 +02:00
Erik Johnston
a28d066732 Merge branch 'develop' of github.com:matrix-org/synapse into erikj/dns_cache 2016-04-07 11:11:17 +01:00
Erik Johnston
8495b6d365 Merge pull request #703 from matrix-org/erikj/member
Set profile information when joining rooms remotely
2016-04-07 11:08:15 +01:00
Erik Johnston
1ef0365670 Set profile information when joining rooms remotely 2016-04-07 09:42:52 +01:00
Richard van der Hoff
87a30890a3 Merge pull request #699 from matrix-org/rav/show_own_leave_event
Let users see their own leave events
2016-04-06 17:57:06 +01:00
Christoph Witzany
ed4d18f516 fix check for failed authentication 2016-04-06 18:30:11 +02:00
Christoph Witzany
9c62fcdb68 remove line 2016-04-06 18:23:46 +02:00
Christoph Witzany
27a0c21c38 make tests for ldap more specific to not be fooled by Mocks 2016-04-06 18:23:46 +02:00
Christoph Witzany
3555a659ec output ldap version for info and to pacify pep8 2016-04-06 18:23:46 +02:00
Christoph Witzany
4c5e8adf8b conditionally import ldap 2016-04-06 18:23:46 +02:00
Christoph Witzany
875ed05bdc fix pep8 2016-04-06 18:23:46 +02:00
Christoph Witzany
67f3a50e9a fix exception handling 2016-04-06 18:23:46 +02:00
Christoph Witzany
afff321e9a code style 2016-04-06 18:23:46 +02:00
Christoph Witzany
8f0e47fae8 cleanup 2016-04-06 18:23:45 +02:00
Christoph Witzany
823b8be4b7 add tls property and twist my head around twisted 2016-04-06 18:23:45 +02:00
Christoph Witzany
92767dd703 add tls property 2016-04-06 18:23:45 +02:00
Christoph Witzany
7b9319b1c8 move LDAP authentication to AuthenticationHandler 2016-04-06 18:23:45 +02:00
Christoph Witzany
3d95405e5f Introduce LDAP authentication 2016-04-06 18:23:45 +02:00
Mark Haines
8d2bca1a90 Merge pull request #702 from matrix-org/markjh/slaveII
Test that room membership is replicated
2016-04-06 16:52:39 +01:00
David Baker
0fd1cd2400 pep8 2016-04-06 16:50:47 +01:00
Mark Haines
6bfec56796 Test that room membership is replicated 2016-04-06 16:20:13 +01:00
Mark Haines
e815763b7f Merge pull request #697 from matrix-org/markjh/slaveI
Add a slaved events store class
2016-04-06 16:19:25 +01:00
David Baker
7e2c89a37f Make pushers use the event_push_actions table instead of listening on an event stream & running the rules again. Sytest passes, but remaining to do:
* Make badges work again
 * Remove old, unused code
2016-04-06 15:42:15 +01:00
Richard van der Hoff
1e05637e37 Let users see their own leave events
... otherwise clients get confused.

Fixes https://matrix.org/jira/browse/SYN-662,
https://github.com/vector-im/vector-web/issues/368
2016-04-06 15:36:19 +01:00
Erik Johnston
b713934b2e Merge pull request #698 from matrix-org/erikj/port_script_fix
Don't require config to create database
2016-04-06 14:32:45 +01:00
Mark Haines
75fb9ac1be Add a slaved events store class
Add a test to check that get_room_names_and_aliases does the same
thing on both the master and on the slave data store.
2016-04-06 14:18:35 +01:00
Erik Johnston
8aab9d87fa Don't require config to create database 2016-04-06 14:15:45 +01:00
Mark Haines
7d11f825aa Merge pull request #694 from matrix-org/markjh/caches
Move _get_cache_dict into the SQLBaseStore
2016-04-06 13:21:25 +01:00
Mark Haines
196ebaf662 Merge pull request #695 from matrix-org/markjh/cachesII
Make the cache objects be per instance rather than being global
2016-04-06 13:21:19 +01:00
Mark Haines
87f2dec8d4 Make the cache objects be per instance rather than being global 2016-04-06 13:08:05 +01:00
Mark Haines
a1e0d316ea Move _get_cache_dict into the SQLBaseStore 2016-04-06 13:05:19 +01:00
Erik Johnston
11860637e1 Tests 2016-04-06 10:12:30 +01:00
Mark Haines
2e308a3a38 Merge pull request #692 from matrix-org/markjh/replicate_reshuffle
Separate generating the replication response...
2016-04-05 13:23:36 +01:00
Erik Johnston
c2b429ab24 Merge pull request #693 from matrix-org/erikj/backfill_self
Don't backfill from self
2016-04-05 13:04:36 +01:00
Erik Johnston
6222ae51ce Don't backfill from self 2016-04-05 12:56:29 +01:00
Erik Johnston
b29f98377d Merge pull request #691 from matrix-org/erikj/member
Fix stuck invites
2016-04-05 12:44:39 +01:00
Mark Haines
1d4deff25a Separate generating the replication response...
from doing the http request parsing to make it easier
to write unit tests for replication.
2016-04-05 11:23:57 +01:00
Erik Johnston
df727f2126 Fix stuck invites
If rejecting a remote invite fails with an error response don't fail
the entire request; instead mark the invite as locally rejected.

This fixes the bug where users can get stuck invites which they can
neither accept nor reject.
2016-04-05 11:13:24 +01:00
Erik Johnston
7a77f8b6d5 Merge pull request #690 from matrix-org/erikj/member
Store invites in a separate table.
2016-04-05 09:12:27 +01:00
Erik Johnston
0c53d750e7 Docs and indents 2016-04-04 18:02:48 +01:00
Erik Johnston
92ab45a330 Add upgrade path, rename table 2016-04-04 17:07:43 +01:00
Erik Johnston
3d76b7cb2b Store invites in a separate table. 2016-04-04 16:30:15 +01:00
Erik Johnston
bf14883a04 Merge pull request #689 from matrix-org/erikj/member
Do checks for memberships before creating events
2016-04-04 11:56:40 +01:00
Matthew Hodgson
9f7dc2bef7 Merge branch 'develop' into matthew/preview_urls 2016-04-04 00:38:21 +01:00
Matthew Hodgson
cf51c4120e report image size (bytewise) in OG meta 2016-04-03 23:57:05 +01:00
Matthew Hodgson
0834b152fb char encoding 2016-04-03 12:59:27 +01:00
Matthew Hodgson
8b98a7e8c3 pep8 2016-04-03 12:56:29 +01:00
Matthew Hodgson
eab4d462f8 fix etag typing error. fix timestamp typing error 2016-04-03 02:02:46 +01:00
Matthew Hodgson
c3916462f6 rebase all image URLs 2016-04-03 01:33:12 +01:00
Matthew Hodgson
110780b18b remove stale todo 2016-04-03 00:48:31 +01:00
Matthew Hodgson
b09e29a03c Ensure only one download for a given URL is active at a time 2016-04-03 00:47:40 +01:00
Matthew Hodgson
7426c86eb8 add a persistent cache of URL lookups, and fix up the in-memory one to work 2016-04-03 00:31:57 +01:00
Matthew Hodgson
d1b154a10f support gzip compression, and don't pass through error msgs 2016-04-02 03:06:39 +01:00
Matthew Hodgson
9377157961 how was _respond_default_thumbnail ever meant to work? 2016-04-02 02:31:45 +01:00
Matthew Hodgson
2c838f6459 pass back SVGs as their own thumbnails 2016-04-02 02:30:07 +01:00
Matthew Hodgson
5037ee0d37 handle missing dimensions without crashing 2016-04-02 02:29:57 +01:00
Matthew Hodgson
b26e8604f1 make meta comparisons case insensitive 2016-04-02 01:35:44 +01:00
Matthew Hodgson
5fd07da764 refactor calc_og; spider image URLs; fix xpath; add a (broken) expiringcache; loads of other fixes 2016-04-02 00:35:49 +01:00
Erik Johnston
d76d89323c Use computed prev event ids 2016-04-01 17:39:32 +01:00
Erik Johnston
aa82cb38e9 Remove state hack from _create_new_client_event 2016-04-01 16:36:54 +01:00
Mark Haines
89e6839a48 Merge pull request #686 from matrix-org/markjh/doc_strings
Use google style doc strings.
2016-04-01 16:20:09 +01:00
Erik Johnston
c906f30661 Do checks for memberships before creating events 2016-04-01 16:17:32 +01:00
Mark Haines
2a37467fa1 Use google style doc strings.
pycharm supports them so there is no need to use the other format.

Might as well convert the existing strings to reduce the risk of
people accidentally cargo culting the wrong doc string format.
2016-04-01 16:12:07 +01:00
Mark Haines
f2b916534b Merge pull request #684 from matrix-org/markjh/backfill_id_gen
Use a stream id generator for backfilled ids
2016-04-01 15:13:14 +01:00
Mark Haines
9bc5b4c663 Assert that the step != 0 2016-04-01 15:08:20 +01:00
Mark Haines
35b5c4ba1b use google style doc strings 2016-04-01 15:07:01 +01:00
Erik Johnston
a853cdec5b Merge pull request #685 from matrix-org/erikj/sync_leave
Add concurrently_execute function
2016-04-01 15:02:59 +01:00
Erik Johnston
3f4eb4c924 Comment 2016-04-01 14:15:27 +01:00
Erik Johnston
8d73cd502b Add concurrently_execute function 2016-04-01 14:06:00 +01:00
Mark Haines
a2866e2e6a Rename direction to step, apply checks consistently 2016-04-01 13:50:54 +01:00
Mark Haines
e36bfbab38 Use a stream id generator for backfilled ids 2016-04-01 13:29:05 +01:00
Erik Johnston
35bb465b86 Filter rooms list before chunking 2016-04-01 13:14:53 +01:00
Mark Haines
c42f46ab7d Merge pull request #682 from matrix-org/markjh/fix_invalidate
Fix the invalidation of the names and aliases cache
2016-04-01 10:52:29 +01:00
Mark Haines
7753fc6570 Fix the invalidation of the names and aliases cache 2016-04-01 10:34:51 +01:00
Matthew Hodgson
c60b751694 fix assorted redirect, unicode and screenscraping bugs 2016-04-01 02:17:48 +01:00
Matthew Hodgson
683e564815 handle spidered relative images correctly 2016-03-31 23:52:58 +01:00
Mark Haines
431aa8ada9 Merge pull request #681 from matrix-org/markjh/remove_outlier
Remove outlier parameter from compute_event_context
2016-03-31 15:44:37 +01:00
Mark Haines
dc4c1579d4 Remove outlier parameter from compute_event_context
Use event.internal_metadata.is_outlier instead.
2016-03-31 15:32:24 +01:00
Mark Haines
03e406eefc Merge pull request #680 from matrix-org/markjh/remove_is_new_state
Remove the is_new_state argument to persist event.
2016-03-31 15:14:48 +01:00
Matthew Hodgson
72550c3803 prevent choking on invalid utf-8, and handle image thumbnailing smarter 2016-03-31 15:14:14 +01:00
Mark Haines
5d06929169 Move the check for backfilled outside the for loop 2016-03-31 15:09:09 +01:00
Mark Haines
76503f95ed Remove the is_new_state argument to persist event.
Move the checks for whether an event is new state inside persist
event itself.

This was harder than expected because there wasn't enough information
passed to persist event to correctly handle invites from remote servers
for new rooms.
2016-03-31 15:00:42 +01:00
Erik Johnston
fe95943305 Merge pull request #679 from matrix-org/erikj/member
Split out RoomMemberHandler
2016-03-31 14:45:57 +01:00
Matthew Hodgson
bb9a2ca87c synthesise basig OG metadata from pages lacking it 2016-03-31 14:15:09 +01:00
Erik Johnston
d35780eda0 Split out RoomMemberHandler 2016-03-31 13:08:45 +01:00
Matthew Hodgson
0d3d7de6fc sync in changes from matrixfederationclient 2016-03-31 12:42:27 +01:00
Mark Haines
62e395f0e3 Merge pull request #676 from matrix-org/markjh/replicate_stateIII
Add replication streams for ex outliers and current state resets
2016-03-31 11:20:57 +01:00
Erik Johnston
5260db7663 Line length 2016-03-31 10:49:27 +01:00
Mark Haines
2ec5426035 Use a namedtuple rather than tuple unpacking 2016-03-31 10:33:02 +01:00
David Baker
c9500a9c1d Merge pull request #678 from matrix-org/dbkr/push_obey_enable
Don't ignore the obey overlay if the rule has an enabled attribute of False
2016-03-31 10:26:22 +01:00
Erik Johnston
f9d3665c88 Allow clock to be passed in to func 2016-03-31 10:23:48 +01:00
David Baker
c27c51484a Don't ignore the obey overlay if the rule has an enabled attribute of False
Fixes https://github.com/vector-im/vector-web/issues/1244
2016-03-31 10:12:31 +01:00
Erik Johnston
f699b8f997 Read from DNS cache if within TTL 2016-03-31 10:04:28 +01:00
Matthew Hodgson
a8a5dd3b44 handle requests with missing content-length headers (e.g. YouTube) 2016-03-31 01:55:21 +01:00
Matthew Hodgson
a68c1b15aa spell out more packages 2016-03-30 17:29:42 +01:00
Matthew Hodgson
9113316b0e typo 2016-03-30 17:29:42 +01:00
Matthew Hodgson
7178ab7da0 spell out more packages 2016-03-30 17:29:22 +01:00
Mark Haines
1fbb094c6f Add replication streams for ex outliers and current state resets 2016-03-30 17:19:56 +01:00
Mark Haines
98c460cecd Merge pull request #675 from matrix-org/markjh/replicate_stateII
Add a replication stream for state groups
2016-03-30 16:40:17 +01:00
Mark Haines
8b8052909f return the state_group for backfill 2016-03-30 16:20:07 +01:00
Mark Haines
61407986b4 Add a entry to current_state_resets table when the current state is reset 2016-03-30 16:18:46 +01:00
Mark Haines
31a9eceda5 Add a replication stream for state groups 2016-03-30 16:01:58 +01:00
Mark Haines
fc66df1e60 Merge pull request #674 from matrix-org/markjh/replicate_state
Use a stream id generator to assign state group ids
2016-03-30 15:58:49 +01:00
Erik Johnston
178c9fb200 Merge pull request #673 from matrix-org/erikj/forget
Require user to have left room to forget room
2016-03-30 15:55:24 +01:00
Erik Johnston
73b6bf4629 Only forget room if you were in the room 2016-03-30 15:09:18 +01:00
Erik Johnston
08a8514b7a Remove spurious comment 2016-03-30 15:05:33 +01:00
Erik Johnston
d24662b88a Merge branch 'master' of github.com:matrix-org/synapse into develop 2016-03-30 14:41:31 +01:00
Mark Haines
1e25f62ee6 Use a stream id generator to assign state group ids 2016-03-30 12:55:02 +01:00
Erik Johnston
fddb6fddc1 Require user to have left room to forget room
This dramatically simplifies the forget API code - in particular it no
longer generates a leave event.
2016-03-30 11:03:00 +01:00
Erik Johnston
f5bf45a2e5 Merge pull request #671 from nikriek/jwt-support
Support login using Javascript Web Tokens (JWT)
2016-03-29 16:31:42 +01:00
Niklas Riekenbrauck
3f9948a069 Add JWT support 2016-03-29 14:36:36 +02:00
Matthew Hodgson
ae5831d303 fix bugs 2016-03-29 03:32:55 +01:00
Matthew Hodgson
721b2bfa85 implement redirects 2016-03-29 03:32:52 +01:00
Matthew Hodgson
19038582d3 debug 2016-03-29 03:14:16 +01:00
Matthew Hodgson
64b4aead15 make it work 2016-03-29 03:13:25 +01:00
Matthew Hodgson
dd4287ca5d make it build 2016-03-29 02:07:57 +01:00
Matthew Hodgson
e0c2490a14 Merge branch 'develop' into matthew/preview_urls 2016-03-29 01:20:25 +01:00
Matthew Hodgson
ec0cf996c9 typo 2016-03-29 01:20:14 +01:00
Matthew Hodgson
d9d48aad2d Merge branch 'develop' into matthew/preview_urls 2016-03-27 22:54:42 +01:00
Matthew Hodgson
adafa24b0a typo 2016-03-25 23:38:19 +00:00
Mark Haines
3e8bb99a2b Merge pull request #668 from matrix-org/markjh/deduplicate
Deduplicate identical /sync requests
2016-03-24 18:07:30 +00:00
Mark Haines
77cba688ed Fix typo 2016-03-24 18:02:37 +00:00
Mark Haines
54a546091a Add a response cache for getting the public room list 2016-03-24 18:02:10 +00:00
Mark Haines
191c7bef6b Deduplicate identical /sync requests 2016-03-24 17:47:31 +00:00
David Baker
31e6f8636f Merge pull request #667 from matrix-org/dbkr/never_notify_member_events
Never notify for member events.
2016-03-24 13:48:02 +00:00
David Baker
3b554bda26 Never notify for member events. This fixes https://github.com/vector-im/vector-web/issues/828 2016-03-24 13:19:39 +00:00
Matthew Hodgson
7dd0c1730a initial WIP of a tentative preview_url endpoint - incomplete, untested, experimental, etc. just putting it here for safekeeping for now 2016-01-24 18:47:27 -05:00
126 changed files with 6341 additions and 3281 deletions


@@ -57,3 +57,6 @@ Florent Violleau <floviolleau at gmail dot com>
Niklas Riekenbrauck <nikriek at gmail dot.com>
* Add JWT support for registration and login
Christoph Witzany <christoph at web.crofting.com>
* Add LDAP support for authentication


@@ -104,7 +104,7 @@ Installing prerequisites on Ubuntu or Debian::
sudo apt-get install build-essential python2.7-dev libffi-dev \
python-pip python-setuptools sqlite3 \
libssl-dev python-virtualenv libjpeg-dev
libssl-dev python-virtualenv libjpeg-dev libxslt1-dev
Installing prerequisites on ArchLinux::
@@ -118,7 +118,6 @@ Installing prerequisites on CentOS 7::
python-virtualenv libffi-devel openssl-devel
sudo yum groupinstall "Development Tools"
Installing prerequisites on Mac OS X::
xcode-select --install
@@ -150,12 +149,7 @@ In case of problems, please see the _Troubleshooting section below.
Alternatively, Silvio Fricke has contributed a Dockerfile to automate the
above in Docker at https://registry.hub.docker.com/u/silviof/docker-matrix/.
Another alternative is to install via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with
https://github.com/matrix-org/matrix-js-sdk/).
Finally, Martin Giess has created an auto-deployment process with vagrant/ansible,
Also, Martin Giess has created an auto-deployment process with vagrant/ansible,
tested with VirtualBox/AWS/DigitalOcean - see https://github.com/EMnify/matrix-synapse-auto-deploy
for details.
@@ -229,6 +223,19 @@ For information on how to install and use PostgreSQL, please see
Platform Specific Instructions
==============================
Debian
------
Matrix provides official Debian packages via apt from http://matrix.org/packages/debian/.
Note that these packages do not include a client - choose one from
https://matrix.org/blog/try-matrix-now/ (or build your own with one of our SDKs :)
Fedora
------
Oleg Girko provides Fedora RPMs at
https://obs.infoserver.lv/project/monitor/matrix-synapse
ArchLinux
---------
@@ -270,11 +277,17 @@ During setup of Synapse you need to call python2.7 directly again::
FreeBSD
-------
Synapse can be installed via FreeBSD Ports or Packages:
Synapse can be installed via FreeBSD Ports or Packages contributed by Brendan Molloy from:
- Ports: ``cd /usr/ports/net/py-matrix-synapse && make install clean``
- Packages: ``pkg install py27-matrix-synapse``
NixOS
-----
Robin Lambertz has packaged Synapse for NixOS at:
https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/misc/matrix-synapse.nix
Windows Install
---------------
Synapse can be installed on Cygwin. It requires the following Cygwin packages:
@@ -544,6 +557,23 @@ as the primary means of identity and E2E encryption is not complete. As such,
we are running a single identity server (https://matrix.org) at the current
time.
URL Previews
============
Synapse 0.15.0 introduces an experimental new API for previewing URLs at
/_matrix/media/r0/preview_url. This is disabled by default. To turn it on, you must
enable the `url_preview_enabled: True` config parameter and explicitly
specify the IP ranges that Synapse is not allowed to spider for previewing in
the `url_preview_ip_range_blacklist` configuration parameter. This is critical
from a security perspective to stop arbitrary Matrix users spidering 'internal'
URLs on your network. At the very least we recommend that your loopback and
RFC1918 IP addresses are blacklisted.
This also requires the optional lxml and netaddr python dependencies to be
installed.
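For illustration only (this is a sketch, not Synapse's actual implementation), the
kind of address check that ``url_preview_ip_range_blacklist`` implies can be written
with the optional netaddr dependency, using the loopback and RFC1918 ranges
recommended above as a minimum::

    # Hypothetical sketch of an IP-range blacklist check for URL previewing.
    from netaddr import IPAddress, IPSet

    url_preview_ip_range_blacklist = IPSet([
        "127.0.0.0/8",      # loopback
        "10.0.0.0/8",       # RFC1918
        "172.16.0.0/12",    # RFC1918
        "192.168.0.0/16",   # RFC1918
    ])

    def may_spider(resolved_ip):
        """Return True only if the resolved address is outside the blacklist."""
        return IPAddress(resolved_ip) not in url_preview_ip_range_blacklist

    print(may_spider("10.1.2.3"))        # False: internal address, never previewed
    print(may_spider("93.184.216.34"))   # True: public address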
Password reset
==============


@@ -30,6 +30,14 @@ running:
python synapse/python_dependencies.py | xargs -n1 pip install
Upgrading to v0.15.0
====================
If you want to use the new URL previewing API (/_matrix/media/r0/preview_url)
then you have to explicitly enable it in the config and update your
dependencies. See README.rst for details.
Upgrading to v0.11.0
====================

docs/replication.rst (new file, 58 lines)

@@ -0,0 +1,58 @@
Replication Architecture
========================
Motivation
----------
We'd like to be able to split some of the work that synapse does into multiple
python processes. In theory multiple synapse processes could share a single
postgresql database and we'd scale up by running more synapse processes.
However, much of synapse assumes that only one process is interacting with the
database, whether for assigning unique identifiers when inserting into tables,
for notifying components about new updates, or for invalidating its caches.
So running multiple copies of the current code isn't an option. One way to
run multiple processes would be to have a single writer process and multiple
reader processes connected to the same database. In order to do this we'd need
a way for the reader process to invalidate its in-memory caches when an update
happens on the writer. One way to do this is for the writer to present an
append-only log of updates which the readers can consume to invalidate their
caches and to push updates to listening clients or pushers.
Synapse already stores much of its data as an append-only log so that it can
correctly respond to /sync requests, so the amount of code change needed to
expose that log to the readers should be fairly minimal.
Architecture
------------
The Replication API
~~~~~~~~~~~~~~~~~~~
Synapse will optionally expose a long poll HTTP API for extracting updates. The
API will have a similar shape to /sync in that clients provide tokens
indicating where in the log they have reached and a timeout. The synapse server
then either responds with updates immediately if it already has updates or it
waits until the timeout for more updates. If the timeout expires and nothing
happened then the server returns an empty response.
However, unlike the /sync API, this replication API returns synapse-specific
data rather than trying to implement a matrix specification. The replication
results are returned as arrays of rows where the rows are mostly lifted
directly from the database. This avoids unnecessary JSON parsing on the server
and hopefully avoids an impedance mismatch between the data returned and the
required updates to the datastore.
This does not replicate all the database tables as many of the database tables
are indexes that can be recovered from the contents of other tables.
The format and parameters for the api are documented in
``synapse/replication/resource.py``.
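As a rough sketch of the long-poll pattern described above (illustrative only: the
exact stream names, query parameters and row layouts are those documented in
``synapse/replication/resource.py``, and the ``requests`` library used here is an
assumption rather than a synapse dependency), a reader's client loop might look
like::

    # Hypothetical reader: long-poll the replication API, advance the per-stream
    # tokens we have reached, and react to any rows that come back.
    import requests

    replication_url = "https://localhost:8008/_synapse/replication"  # example URL
    positions = {}  # stream name -> token the reader has reached

    def invalidate_caches(stream_name, rows):
        # A real reader would use the rows to invalidate its in-memory caches
        # and to push updates to listening clients or pushers.
        print("got %d rows for %s" % (len(rows), stream_name))

    while True:
        args = dict(positions)
        args["timeout"] = 30000  # wait up to 30s for more updates
        result = requests.get(replication_url, params=args).json()
        for stream_name, stream in result.items():
            invalidate_caches(stream_name, stream.get("rows", []))
            positions[stream_name] = stream["position"]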
The Slaved DataStore
~~~~~~~~~~~~~~~~~~~~
There are read-only versions of the synapse storage layer in
``synapse/replication/slave/storage`` that use the response of the replication
API to invalidate their caches.

docs/url_previews.rst (new file, 74 lines)

@@ -0,0 +1,74 @@
URL Previews
============
Design notes on a URL previewing service for Matrix:
Options are:
1. Have an AS which listens for URLs, downloads them, and inserts an event that describes their metadata.
* Pros:
* Decouples the implementation entirely from Synapse.
* Uses existing Matrix events & content repo to store the metadata.
* Cons:
* Which AS should provide this service for a room, and why should you trust it?
* Doesn't work well with E2E; you'd have to cut the AS into every room
* the AS would end up subscribing to every room anyway.
2. Have a generic preview API (nothing to do with Matrix) that provides a previewing service:
* Pros:
* Simple and flexible; can be used by any clients at any point
* Cons:
* If each HS provides one of these independently, all the HSes in a room may needlessly DoS the target URI
* We need somewhere to store the URL metadata rather than just using Matrix itself
* We can't piggyback on matrix to distribute the metadata between HSes.
3. Make the synapse of the sending user responsible for spidering the URL and inserting an event asynchronously which describes the metadata.
* Pros:
* Works transparently for all clients
* Piggy-backs nicely on using Matrix for distributing the metadata.
* No confusion as to which AS should provide the service
* Cons:
* Doesn't work with E2E
* We might want to decouple the implementation of the spider from the HS, given spider behaviour can be quite complicated and evolve much more rapidly than the HS. It's more like a bot than a core part of the server.
4. Make the sending client use the preview API and insert the event itself when successful.
* Pros:
* Works well with E2E
* No custom server functionality
* Lets the client customise the preview that they send (like on FB)
* Cons:
* Entirely specific to the sending client, whereas it'd be nice if /any/ URL was correctly previewed if clients support it.
5. Have the option of specifying a shared (centralised) previewing service used by a room, to avoid all the different HSes in the room DoSing the target.
Best solution is probably a combination of both 2 and 4.
* Sending clients do their best to create and send a preview at the point of sending the message, perhaps delaying the message until the preview is computed? (This also lets the user validate the preview before sending)
* Receiving clients have the option of going and creating their own preview if one doesn't arrive soon enough (or if the original sender didn't create one)
This is a bit magical though in that the preview could come from two entirely different sources - the sending HS or your local one. However, this can always be exposed to users: "Generate your own URL previews if none are available?"
This is also tantamount to senders calculating their own thumbnails for sending in advance of the main content - we are trusting the sender not to lie about the content in the thumbnail, whereas currently thumbnails are calculated by the receiving homeserver to avoid this attack.
However, this kind of phishing attack does exist whether we let senders pick their thumbnails or not, in that a malicious sender can send normal text messages around the attachment claiming it to be legitimate. We could rely on (future) reputation/abuse management to punish users who phish (be it with bogus metadata or bogus descriptions). Bogus metadata is particularly bad though, especially if it's avoidable.
As a first cut, let's do #2 and have the receiver hit the API to calculate its own previews (as it does currently for image thumbnails). We can then extend/optimise this to option 4 as a special extra if needed.
API
---
GET /_matrix/media/r0/preview_url?url=http://wherever.com
200 OK
{
  "og:type" : "article",
  "og:url" : "https://twitter.com/matrixdotorg/status/684074366691356672",
  "og:title" : "Matrix on Twitter",
  "og:image" : "https://pbs.twimg.com/profile_images/500400952029888512/yI0qtFi7_400x400.png",
  "og:description" : "“Synapse 0.12 is out! Lots of polishing, performance &amp;amp; bugfixes: /sync API, /r0 prefix, fulltext search, 3PID invites https://t.co/5alhXLLEGP”",
  "og:site_name" : "Twitter"
}
* Downloads the URL
* If HTML, just stores it in RAM and parses it for OG meta tags
* Download any media OG meta tags to the media repo, and refer to them in the OG via mxc:// URIs.
* If a media filetype we know we can thumbnail: store it on disk, and hand it to the thumbnailer. Generate OG meta tags from the thumbnailer contents.
* Otherwise, don't bother downloading further.
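A usage sketch of the call above (assumptions: the ``requests`` library is
installed, the homeserver name is an example, and depending on server
configuration an ``access_token`` query parameter may also be needed)::

    # Illustrative only: fetch OpenGraph-style metadata for a URL via preview_url.
    import requests

    resp = requests.get(
        "https://matrix.example.com/_matrix/media/r0/preview_url",
        params={
            "url": "https://twitter.com/matrixdotorg/status/684074366691356672",
            # "access_token": "...",  # if the server requires authentication
        },
    )
    og = resp.json()
    print(og.get("og:title"), og.get("og:image"))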


@@ -25,7 +25,9 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install psycopg2
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}


@@ -24,6 +24,8 @@ rm .coverage* || echo "No coverage files to remove"
tox --notest -e py27
TOX_BIN=$WORKSPACE/.tox/py27/bin
python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install
$TOX_BIN/pip install lxml
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}


@@ -1,86 +0,0 @@
#!/bin/bash
set -eux
: ${WORKSPACE:="$(pwd)"}
export PYTHONDONTWRITEBYTECODE=yep
export SYNAPSE_CACHE_FACTOR=1
# Output test results as junit xml
export TRIAL_FLAGS="--reporter=subunit"
export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml"
# Write coverage reports to a separate file for each process
export COVERAGE_OPTS="-p"
export DUMP_COVERAGE_COMMAND="coverage help"
# Output flake8 violations to violations.flake8.log
# Don't exit with non-0 status code on Jenkins,
# so that the build steps continue and a later step can decided whether to
# UNSTABLE or FAILURE this build.
export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?"
rm .coverage* || echo "No coverage files to remove"
tox
: ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"}
TOX_BIN=$WORKSPACE/.tox/py27/bin
if [[ ! -e .sytest-base ]]; then
git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror
else
(cd .sytest-base; git fetch -p)
fi
rm -rf sytest
git clone .sytest-base sytest --shared
cd sytest
git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop)
: ${PERL5LIB:=$WORKSPACE/perl5/lib/perl5}
: ${PERL_MB_OPT:=--install_base=$WORKSPACE/perl5}
: ${PERL_MM_OPT:=INSTALL_BASE=$WORKSPACE/perl5}
export PERL5LIB PERL_MB_OPT PERL_MM_OPT
./install-deps.pl
: ${PORT_BASE:=8000}
echo >&2 "Running sytest with SQLite3";
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-sqlite3.tap
RUN_POSTGRES=""
for port in $(($PORT_BASE + 1)) $(($PORT_BASE + 2)); do
if psql synapse_jenkins_$port <<< ""; then
RUN_POSTGRES="$RUN_POSTGRES:$port"
cat > localhost-$port/database.yaml << EOF
name: psycopg2
args:
database: synapse_jenkins_$port
EOF
fi
done
# Run if both postgresql databases exist
if test "$RUN_POSTGRES" = ":$(($PORT_BASE + 1)):$(($PORT_BASE + 2))"; then
echo >&2 "Running sytest with PostgreSQL";
$TOX_BIN/pip install psycopg2
./run-tests.pl --coverage -O tap --synapse-directory $WORKSPACE \
--python $TOX_BIN/python --all --port-base $PORT_BASE > results-postgresql.tap
else
echo >&2 "Skipping running sytest with PostgreSQL, $RUN_POSTGRES"
fi
cd ..
cp sytest/.coverage.* .
# Combine the coverage reports
echo "Combining:" .coverage.*
$TOX_BIN/python -m coverage combine
# Output coverage to coverage.xml
$TOX_BIN/coverage xml -o coverage.xml


@@ -19,6 +19,7 @@ from twisted.enterprise import adbapi
from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine
from synapse.storage.prepare_database import prepare_database
import argparse
import curses
@@ -37,6 +38,7 @@ BOOLEAN_COLUMNS = {
"rooms": ["is_public"],
"event_edges": ["is_state"],
"presence_list": ["accepted"],
"presence_stream": ["currently_active"],
}
@@ -292,7 +294,7 @@ class Porter(object):
}
)
database_engine.prepare_database(db_conn)
prepare_database(db_conn, database_engine, config=None)
db_conn.commit()
@@ -309,8 +311,8 @@ class Porter(object):
**self.postgres_config["args"]
)
sqlite_engine = create_engine(FakeConfig(sqlite_config))
postgres_engine = create_engine(FakeConfig(postgres_config))
sqlite_engine = create_engine(sqlite_config)
postgres_engine = create_engine(postgres_config)
self.sqlite_store = Store(sqlite_db_pool, sqlite_engine)
self.postgres_store = Store(postgres_db_pool, postgres_engine)
@@ -792,8 +794,3 @@ if __name__ == "__main__":
if end_error_exec_info:
exc_type, exc_value, exc_traceback = end_error_exec_info
traceback.print_exception(exc_type, exc_value, exc_traceback)
class FakeConfig:
def __init__(self, database_config):
self.database_config = database_config


@@ -17,3 +17,6 @@ ignore =
[flake8]
max-line-length = 90
ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it.
[pep8]
max-line-length = 90


@@ -25,6 +25,7 @@ from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import Requester, RoomID, UserID, EventID
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.metrics import Measure
from unpaddedbase64 import decode_base64
import logging
@@ -44,6 +45,7 @@ class Auth(object):
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
self.state = hs.get_state_handler()
self.TOKEN_NOT_FOUND_HTTP_STATUS = 401
@@ -66,9 +68,9 @@ class Auth(object):
Returns:
True if the auth checks pass.
"""
self.check_size_limits(event)
with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
try:
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if auth_events is None:
@@ -127,13 +129,6 @@ class Auth(object):
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
except AuthError as e:
logger.info(
"Event auth check failed on event %s with msg: %s",
event, e.msg
)
logger.info("Denying! %s", event)
raise
def check_size_limits(self, event):
def too_big(field):


@@ -20,8 +20,6 @@ import contextlib
import logging
import os
import re
import resource
import subprocess
import sys
import time
from synapse.config._base import ConfigError
@@ -33,7 +31,7 @@ from synapse.python_dependencies import (
from synapse.rest import ClientRestResource
from synapse.storage.engines import create_engine, IncorrectDatabaseSetup
from synapse.storage import are_all_users_on_domain
from synapse.storage.prepare_database import UpgradeDatabaseException
from synapse.storage.prepare_database import UpgradeDatabaseException, prepare_database
from synapse.server import HomeServer
@@ -66,6 +64,9 @@ from synapse.metrics.resource import MetricsResource, METRICS_PREFIX
from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX
from synapse.federation.transport.server import TransportLayerServer
from synapse.util.rlimit import change_resource_limit
from synapse.util.versionstring import get_version_string
from synapse import events
from daemonize import Daemonize
@@ -245,7 +246,7 @@ class SynapseHomeServer(HomeServer):
except IncorrectDatabaseSetup as e:
quit_with_error(e.message)
def get_db_conn(self):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
@@ -254,7 +255,8 @@ class SynapseHomeServer(HomeServer):
}
db_conn = self.database_engine.module.connect(**db_params)
self.database_engine.on_new_connection(db_conn)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
@@ -268,86 +270,6 @@ def quit_with_error(error_string):
sys.exit(1)
def get_version_string():
try:
null = open(os.devnull, 'w')
cwd = os.path.dirname(os.path.abspath(__file__))
try:
git_branch = subprocess.check_output(
['git', 'rev-parse', '--abbrev-ref', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
git_branch = "b=" + git_branch
except subprocess.CalledProcessError:
git_branch = ""
try:
git_tag = subprocess.check_output(
['git', 'describe', '--exact-match'],
stderr=null,
cwd=cwd,
).strip()
git_tag = "t=" + git_tag
except subprocess.CalledProcessError:
git_tag = ""
try:
git_commit = subprocess.check_output(
['git', 'rev-parse', '--short', 'HEAD'],
stderr=null,
cwd=cwd,
).strip()
except subprocess.CalledProcessError:
git_commit = ""
try:
dirty_string = "-this_is_a_dirty_checkout"
is_dirty = subprocess.check_output(
['git', 'describe', '--dirty=' + dirty_string],
stderr=null,
cwd=cwd,
).strip().endswith(dirty_string)
git_dirty = "dirty" if is_dirty else ""
except subprocess.CalledProcessError:
git_dirty = ""
if git_branch or git_tag or git_commit or git_dirty:
git_version = ",".join(
s for s in
(git_branch, git_tag, git_commit, git_dirty,)
if s
)
return (
"Synapse/%s (%s)" % (
synapse.__version__, git_version,
)
).encode("ascii")
except Exception as e:
logger.info("Failed to check for git repository: %s", e)
return ("Synapse/%s" % (synapse.__version__,)).encode("ascii")
def change_resource_limit(soft_file_no):
try:
soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
if not soft_file_no:
soft_file_no = hard
resource.setrlimit(resource.RLIMIT_NOFILE, (soft_file_no, hard))
logger.info("Set file limit to: %d", soft_file_no)
resource.setrlimit(
resource.RLIMIT_CORE, (resource.RLIM_INFINITY, resource.RLIM_INFINITY)
)
except (ValueError, resource.error) as e:
logger.warn("Failed to set file or core limit: %s", e)
def setup(config_options):
"""
Args:
@@ -377,7 +299,7 @@ def setup(config_options):
# check any extra requirements we have now we have a config
check_requirements(config)
version_string = get_version_string()
version_string = get_version_string("Synapse", synapse)
logger.info("Server hostname: %s", config.server_name)
logger.info("Server version: %s", version_string)
@@ -386,7 +308,7 @@ def setup(config_options):
tls_server_context_factory = context_factory.ServerContextFactory(config)
database_engine = create_engine(config)
database_engine = create_engine(config.database_config)
config.database_config["args"]["cp_openfun"] = database_engine.on_new_connection
hs = SynapseHomeServer(
@@ -402,8 +324,10 @@ def setup(config_options):
logger.info("Preparing database: %s...", config.database_config['name'])
try:
db_conn = hs.get_db_conn()
database_engine.prepare_database(db_conn)
db_conn = hs.get_db_conn(run_new_connection=False)
prepare_database(db_conn, database_engine, config=config)
database_engine.on_new_connection(db_conn)
hs.run_startup_checks(db_conn, database_engine)
db_conn.commit()

synapse/app/pusher.py (new file, 206 lines)

@@ -0,0 +1,206 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import synapse
from synapse.server import HomeServer
from synapse.util.versionstring import get_version_string
from synapse.config._base import ConfigError
from synapse.config.database import DatabaseConfig
from synapse.config.logger import LoggingConfig
from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.pushers import SlavedPusherStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.storage.engines import create_engine
from synapse.storage import DataStore
from synapse.util.async import sleep
from synapse.util.logcontext import (LoggingContext, preserve_fn)
from twisted.internet import reactor, defer
import sys
import logging
logger = logging.getLogger("synapse.app.pusher")
class SlaveConfig(DatabaseConfig):
def read_config(self, config):
self.replication_url = config["replication_url"]
self.server_name = config["server_name"]
self.use_insecure_ssl_client_just_for_testing_do_not_use = True
self.user_agent_suffix = None
self.start_pushers = True
def default_config(self, **kwargs):
return """\
## Slave ##
#replication_url: https://localhost:{replication_port}/_synapse/replication
report_stats: False
"""
class PusherSlaveConfig(SlaveConfig, LoggingConfig):
pass
class PusherSlaveStore(
SlavedEventStore, SlavedPusherStore, SlavedReceiptsStore
):
update_pusher_last_stream_ordering_and_success = (
DataStore.update_pusher_last_stream_ordering_and_success.__func__
)
class PusherServer(HomeServer):
def get_db_conn(self, run_new_connection=True):
# Any param beginning with cp_ is a parameter for adbapi, and should
# not be passed to the database engine.
db_params = {
k: v for k, v in self.db_config.get("args", {}).items()
if not k.startswith("cp_")
}
db_conn = self.database_engine.module.connect(**db_params)
if run_new_connection:
self.database_engine.on_new_connection(db_conn)
return db_conn
def setup(self):
logger.info("Setting up.")
self.datastore = PusherSlaveStore(self.get_db_conn(), self)
logger.info("Finished setting up.")
def remove_pusher(self, app_id, push_key, user_id):
http_client = self.get_simple_http_client()
replication_url = self.config.replication_url
url = replication_url + "/remove_pushers"
return http_client.post_json_get_json(url, {
"remove": [{
"app_id": app_id,
"push_key": push_key,
"user_id": user_id,
}]
})
@defer.inlineCallbacks
def replicate(self):
http_client = self.get_simple_http_client()
store = self.get_datastore()
replication_url = self.config.replication_url
pusher_pool = self.get_pusherpool()
def stop_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
pushers_for_user = pusher_pool.pushers.get(user_id, {})
pusher = pushers_for_user.pop(key, None)
if pusher is None:
return
logger.info("Stopping pusher %r / %r", user_id, key)
pusher.on_stop()
def start_pusher(user_id, app_id, pushkey):
key = "%s:%s" % (app_id, pushkey)
logger.info("Starting pusher %r / %r", user_id, key)
return pusher_pool._refresh_pusher(app_id, pushkey, user_id)
@defer.inlineCallbacks
def poke_pushers(results):
pushers_rows = set(
map(tuple, results.get("pushers", {}).get("rows", []))
)
deleted_pushers_rows = set(
map(tuple, results.get("deleted_pushers", {}).get("rows", []))
)
for row in sorted(pushers_rows | deleted_pushers_rows):
if row in deleted_pushers_rows:
user_id, app_id, pushkey = row[1:4]
stop_pusher(user_id, app_id, pushkey)
elif row in pushers_rows:
user_id = row[1]
app_id = row[5]
pushkey = row[8]
yield start_pusher(user_id, app_id, pushkey)
stream = results.get("events")
if stream:
min_stream_id = stream["rows"][0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_notifications)(
min_stream_id, max_stream_id
)
stream = results.get("receipts")
if stream:
rows = stream["rows"]
affected_room_ids = set(row[1] for row in rows)
min_stream_id = rows[0][0]
max_stream_id = stream["position"]
preserve_fn(pusher_pool.on_new_receipts)(
min_stream_id, max_stream_id, affected_room_ids
)
while True:
try:
args = store.stream_positions()
args["timeout"] = 30000
result = yield http_client.get_json(replication_url, args=args)
yield store.process_replication(result)
poke_pushers(result)
except:
logger.exception("Error replicating from %r", replication_url)
sleep(30)
def setup(config_options):
try:
config = PusherSlaveConfig.load_config(
"Synapse pusher", config_options
)
except ConfigError as e:
sys.stderr.write("\n" + e.message + "\n")
sys.exit(1)
config.setup_logging()
database_engine = create_engine(config.database_config)
ps = PusherServer(
config.server_name,
db_config=config.database_config,
config=config,
version_string=get_version_string("Synapse", synapse),
database_engine=database_engine,
)
ps.setup()
def start():
ps.replicate()
ps.get_pusherpool().start()
ps.get_datastore().start_profiling()
reactor.callWhenRunning(start)
return ps
if __name__ == '__main__':
with LoggingContext("main"):
ps = setup(sys.argv[1:])
reactor.run()
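The remove_pusher() helper above is the worker's side of the HTTP API for removing rejected pushers: when a push gateway rejects a pusher, the worker asks the main synapse process to delete it instead of touching the database itself. A rough sketch of the request it issues (the endpoint path and field names come from the code above; the concrete values are invented for illustration):

# Illustrative only: what PusherServer.remove_pusher() above POSTs to the
# main process at <replication_url>/remove_pushers. Values are examples.
example_remove_pushers_request = {
    "remove": [{
        "app_id": "com.example.app.ios",
        "push_key": "abcdef0123456789",
        "user_id": "@alice:example.com",
    }]
}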


@@ -100,11 +100,6 @@ class ApplicationServiceApi(SimpleHttpClient):
logger.warning("push_bulk to %s threw exception %s", uri, ex)
defer.returnValue(False)
@defer.inlineCallbacks
def push(self, service, event, txn_id=None):
response = yield self.push_bulk(service, [event], txn_id)
defer.returnValue(response)
def _serialize(self, events):
time_now = self.clock.time_msec()
return [


@@ -29,13 +29,15 @@ from .key import KeyConfig
from .saml2 import SAML2Config
from .cas import CasConfig
from .password import PasswordConfig
from .jwt import JWTConfig
from .ldap import LDAPConfig
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
RatelimitConfig, ContentRepositoryConfig, CaptchaConfig,
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
PasswordConfig,):
JWTConfig, LDAPConfig, PasswordConfig,):
pass

synapse/config/jwt.py

@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class JWTConfig(Config):
def read_config(self, config):
jwt_config = config.get("jwt_config", None)
if jwt_config:
self.jwt_enabled = jwt_config.get("enabled", False)
self.jwt_secret = jwt_config["secret"]
self.jwt_algorithm = jwt_config["algorithm"]
else:
self.jwt_enabled = False
self.jwt_secret = None
self.jwt_algorithm = None
def default_config(self, **kwargs):
return """\
# jwt_config:
# enabled: true
# secret: "a secret"
# algorithm: "HS256"
"""

synapse/config/ldap.py

@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2015 Niklas Riekenbrauck
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
class LDAPConfig(Config):
def read_config(self, config):
ldap_config = config.get("ldap_config", None)
if ldap_config:
self.ldap_enabled = ldap_config.get("enabled", False)
self.ldap_server = ldap_config["server"]
self.ldap_port = ldap_config["port"]
self.ldap_tls = ldap_config.get("tls", False)
self.ldap_search_base = ldap_config["search_base"]
self.ldap_search_property = ldap_config["search_property"]
self.ldap_email_property = ldap_config["email_property"]
self.ldap_full_name_property = ldap_config["full_name_property"]
else:
self.ldap_enabled = False
self.ldap_server = None
self.ldap_port = None
self.ldap_tls = False
self.ldap_search_base = None
self.ldap_search_property = None
self.ldap_email_property = None
self.ldap_full_name_property = None
def default_config(self, **kwargs):
return """\
# ldap_config:
# enabled: true
# server: "ldap://localhost"
# port: 389
# tls: false
# search_base: "ou=Users,dc=example,dc=com"
# search_property: "cn"
# email_property: "email"
# full_name_property: "givenName"
"""


@@ -13,9 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import Config
from ._base import Config, ConfigError
from collections import namedtuple
MISSING_NETADDR = (
"Missing netaddr library. This is required for URL preview API."
)
MISSING_LXML = (
"""Missing lxml library. This is required for URL preview API.
Install by running:
pip install lxml
Requires libxslt1-dev system package.
"""
)
ThumbnailRequirement = namedtuple(
"ThumbnailRequirement", ["width", "height", "method", "media_type"]
)
@@ -23,7 +39,7 @@ ThumbnailRequirement = namedtuple(
def parse_thumbnail_requirements(thumbnail_sizes):
""" Takes a list of dictionaries with "width", "height", and "method" keys
and creates a map from image media types to the thumbnail size, thumnailing
and creates a map from image media types to the thumbnail size, thumbnailing
method, and thumbnail media type to precalculate
Args:
@@ -53,12 +69,39 @@ class ContentRepositoryConfig(Config):
def read_config(self, config):
self.max_upload_size = self.parse_size(config["max_upload_size"])
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
self.max_spider_size = self.parse_size(config["max_spider_size"])
self.media_store_path = self.ensure_directory(config["media_store_path"])
self.uploads_path = self.ensure_directory(config["uploads_path"])
self.dynamic_thumbnails = config["dynamic_thumbnails"]
self.thumbnail_requirements = parse_thumbnail_requirements(
config["thumbnail_sizes"]
)
self.url_preview_enabled = config.get("url_preview_enabled", False)
if self.url_preview_enabled:
try:
import lxml
lxml # To stop unused lint.
except ImportError:
raise ConfigError(MISSING_LXML)
try:
from netaddr import IPSet
except ImportError:
raise ConfigError(MISSING_NETADDR)
if "url_preview_ip_range_blacklist" in config:
self.url_preview_ip_range_blacklist = IPSet(
config["url_preview_ip_range_blacklist"]
)
else:
raise ConfigError(
"For security, you must specify an explicit target IP address "
"blacklist in url_preview_ip_range_blacklist for url previewing "
"to work"
)
if "url_preview_url_blacklist" in config:
self.url_preview_url_blacklist = config["url_preview_url_blacklist"]
def default_config(self, **kwargs):
media_store = self.default_path("media_store")
@@ -80,7 +123,7 @@ class ContentRepositoryConfig(Config):
# the resolution requested by the client. If true then whenever
# a new resolution is requested by the client the server will
# generate a new thumbnail. If false the server will pick a thumbnail
# from a precalcualted list.
# from a precalculated list.
dynamic_thumbnails: false
# List of thumbnail to precalculate when an image is uploaded.
@@ -100,4 +143,62 @@ class ContentRepositoryConfig(Config):
- width: 800
height: 600
method: scale
# Is the preview URL API enabled? If enabled, you *must* specify
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
# denied from accessing.
url_preview_enabled: False
# List of IP address CIDR ranges that the URL preview spider is denied
# from accessing. There are no defaults: you must explicitly
# specify a list for URL previewing to work. You should specify any
# internal services in your network that you do not want synapse to try
# to connect to, otherwise anyone in any Matrix room could cause your
# synapse to issue arbitrary GET requests to your internal services,
# causing serious security issues.
#
# url_preview_ip_range_blacklist:
# - '127.0.0.0/8'
# - '10.0.0.0/8'
# - '172.16.0.0/12'
# - '192.168.0.0/16'
# Optional list of URL matches that the URL preview spider is
# denied from accessing. You should use url_preview_ip_range_blacklist
# in preference to this, otherwise someone could define a public DNS
# entry that points to a private IP address and circumvent the blacklist.
# This is more useful if you know there is an entire shape of URL that
# you will never want synapse to try to spider.
#
# Each list entry is a dictionary of url component attributes as returned
# by urlparse.urlsplit as applied to the absolute form of the URL. See
# https://docs.python.org/2/library/urlparse.html#urlparse.urlsplit
# The values of the dictionary are treated as a filename match pattern
# applied to that component of URLs, unless they start with a ^ in which
# case they are treated as a regular expression match. If all the
# specified component matches for a given list item succeed, the URL is
# blacklisted.
#
# url_preview_url_blacklist:
# # blacklist any URL with a username in its URI
# - username: '*'
#
# # blacklist all *.google.com URLs
# - netloc: 'google.com'
# - netloc: '*.google.com'
#
# # blacklist all plain HTTP URLs
# - scheme: 'http'
#
# # blacklist http(s)://www.acme.com/foo
# - netloc: 'www.acme.com'
# path: '/foo'
#
# # blacklist any URL with a literal IPv4 address
# - netloc: '^[0-9]+\.[0-9]+\.[0-9]+\.[0-9]+$'
# The largest allowed URL preview spidering size in bytes
max_spider_size: "10M"
""" % locals()


@@ -28,6 +28,7 @@ class ServerConfig(Config):
self.print_pidfile = config.get("print_pidfile")
self.user_agent_suffix = config.get("user_agent_suffix")
self.use_frozen_dicts = config.get("use_frozen_dicts", True)
self.start_pushers = config.get("start_pushers", True)
self.listeners = config.get("listeners", [])


@@ -31,7 +31,10 @@ class _EventInternalMetadata(object):
return dict(self.__dict__)
def is_outlier(self):
return hasattr(self, "outlier") and self.outlier
return getattr(self, "outlier", False)
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def _event_dict_property(key):


@@ -179,7 +179,8 @@ class TransportLayerClient(object):
content = yield self.client.get_json(
destination=destination,
path=path,
retry_on_dns_fail=True,
retry_on_dns_fail=False,
timeout=20000,
)
defer.returnValue(content)


@@ -17,8 +17,9 @@ from synapse.appservice.scheduler import AppServiceScheduler
from synapse.appservice.api import ApplicationServiceApi
from .register import RegistrationHandler
from .room import (
RoomCreationHandler, RoomMemberHandler, RoomListHandler, RoomContextHandler,
RoomCreationHandler, RoomListHandler, RoomContextHandler,
)
from .room_member import RoomMemberHandler
from .message import MessageHandler
from .events import EventStreamHandler, EventHandler
from .federation import FederationHandler


@@ -21,7 +21,7 @@ from synapse.api.constants import Membership, EventTypes
from synapse.types import UserID, RoomAlias, Requester
from synapse.push.action_generator import ActionGenerator
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
import logging
@@ -37,12 +37,22 @@ VISIBILITY_PRIORITY = (
)
MEMBERSHIP_PRIORITY = (
Membership.JOIN,
Membership.INVITE,
Membership.KNOCK,
Membership.LEAVE,
Membership.BAN,
)
class BaseHandler(object):
"""
Common base class for the event handlers.
:type store: synapse.storage.events.StateStore
:type state_handler: synapse.state.StateHandler
Attributes:
store (synapse.storage.events.StateStore):
state_handler (synapse.state.StateHandler):
"""
def __init__(self, hs):
@@ -65,11 +75,13 @@ class BaseHandler(object):
""" Returns dict of user_id -> list of events that user is allowed to
see.
:param (str, bool) user_tuples: (user id, is_peeking) for each
user to be checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
Args:
user_tuples (str, bool): (user id, is_peeking) for each user to be
checked. is_peeking should be true if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the
given events
events ([synapse.events.EventBase]): list of events to filter
"""
forgotten = yield defer.gatherResults([
self.store.who_forgot_in_room(
@@ -84,6 +96,12 @@ class BaseHandler(object):
)
def allowed(event, user_id, is_peeking):
"""
Args:
event (synapse.events.EventBase): event to check
user_id (str)
is_peeking (bool)
"""
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
@@ -115,17 +133,30 @@ class BaseHandler(object):
if old_priority < new_priority:
visibility = prev_visibility
# get the user's membership at the time of the event. (or rather,
# just *after* the event. Which means that people can see their
# own join events, but not (currently) their own leave events.)
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id in event_id_forgotten:
membership = None
else:
membership = membership_event.membership
else:
membership = None
# likewise, if the event is the user's own membership event, use
# the 'most joined' membership
membership = None
if event.type == EventTypes.Member and event.state_key == user_id:
membership = event.content.get("membership", None)
if membership not in MEMBERSHIP_PRIORITY:
membership = "leave"
prev_content = event.unsigned.get("prev_content", {})
prev_membership = prev_content.get("membership", None)
if prev_membership not in MEMBERSHIP_PRIORITY:
prev_membership = "leave"
new_priority = MEMBERSHIP_PRIORITY.index(membership)
old_priority = MEMBERSHIP_PRIORITY.index(prev_membership)
if old_priority < new_priority:
membership = prev_membership
# otherwise, get the user's membership at the time of the event.
if membership is None:
membership_event = state.get((EventTypes.Member, user_id), None)
if membership_event:
if membership_event.event_id not in event_id_forgotten:
membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
@@ -165,13 +196,16 @@ class BaseHandler(object):
"""
Check which events a user is allowed to see
:param str user_id: user id to be checked
:param [synapse.events.EventBase] events: list of events to be checked
:param bool is_peeking should be True if:
Args:
user_id(str): user id to be checked
events([synapse.events.EventBase]): list of events to be checked
is_peeking(bool): should be True if:
* the user is not currently a member of the room, and:
* the user has not been a member of the room since the given
events
:rtype [synapse.events.EventBase]
Returns:
[synapse.events.EventBase]
"""
types = (
(EventTypes.RoomHistoryVisibility, ""),
@@ -199,20 +233,25 @@ class BaseHandler(object):
)
@defer.inlineCallbacks
def _create_new_client_event(self, builder):
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
def _create_new_client_event(self, builder, prev_event_ids=None):
if prev_event_ids:
prev_events = yield self.store.add_event_hashes(prev_event_ids)
prev_max_depth = yield self.store.get_max_depth_of_events(prev_event_ids)
depth = prev_max_depth + 1
else:
depth = 1
latest_ret = yield self.store.get_latest_event_ids_and_hashes_in_room(
builder.room_id,
)
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
if latest_ret:
depth = max([d for _, _, d in latest_ret]) + 1
else:
depth = 1
prev_events = [
(event_id, prev_hashes)
for event_id, prev_hashes, _ in latest_ret
]
builder.prev_events = prev_events
builder.depth = depth
@@ -221,50 +260,6 @@ class BaseHandler(object):
context = yield state_handler.compute_event_context(builder)
# If we've received an invite over federation, there are no latest
# events in the room, because we don't know enough about the graph
# fragment we received to treat it like a graph, so the above returned
# no relevant events. It may have returned some events (if we have
# joined and left the room), but not useful ones, like the invite.
if (
not self.is_host_in_room(context.current_state) and
builder.type == EventTypes.Member
):
prev_member_event = yield self.store.get_room_member(
builder.sender, builder.room_id
)
# The prev_member_event may already be in context.current_state,
# despite us not being present in the room; in particular, if
# inviting user, and all other local users, have already left.
#
# In that case, we have all the information we need, and we don't
# want to drop "context" - not least because we may need to handle
# the invite locally, which will require us to have the whole
# context (not just prev_member_event) to auth it.
#
context_event_ids = (
e.event_id for e in context.current_state.values()
)
if (
prev_member_event and
prev_member_event.event_id not in context_event_ids
):
# The prev_member_event is missing from context, so it must
# have arrived over federation and is an outlier. We forcibly
# set our context to the invite we received over federation
builder.prev_events = (
prev_member_event.event_id,
prev_member_event.prev_events
)
context = yield state_handler.compute_event_context(
builder,
old_state=(prev_member_event,),
outlier=True
)
if builder.is_state():
builder.prev_state = yield self.store.add_event_hashes(
context.prev_state_events
@@ -321,7 +316,11 @@ class BaseHandler(object):
if ratelimit:
self.ratelimit(requester)
self.auth.check(event, auth_events=context.current_state)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as err:
logger.warn("Denying new event %r because %s", event, err)
raise err
yield self.maybe_kick_guest_users(event, context.current_state.values())
@@ -411,6 +410,12 @@ class BaseHandler(object):
event, context=context
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
destinations = set()
for k, s in context.current_state.items():
try:

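The MEMBERSHIP_PRIORITY handling added to handlers/_base.py above boils down to "the most joined membership wins" when comparing the event's own membership with the one in prev_content. A small standalone illustration (an aside, using the literal strings behind the Membership constants):

# Illustration only, not part of the diff.
MEMBERSHIP_PRIORITY = ("join", "invite", "knock", "leave", "ban")

def effective_membership(new, prev):
    # Unknown values are treated as "leave", then whichever membership
    # appears earlier in MEMBERSHIP_PRIORITY is used.
    new = new if new in MEMBERSHIP_PRIORITY else "leave"
    prev = prev if prev in MEMBERSHIP_PRIORITY else "leave"
    return min(new, prev, key=MEMBERSHIP_PRIORITY.index)

# effective_membership("leave", "join") == "join", so a user's own leave
# event is filtered as if they were still joined.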

@@ -49,6 +49,21 @@ class AuthHandler(BaseHandler):
self.sessions = {}
self.INVALID_TOKEN_HTTP_STATUS = 401
self.ldap_enabled = hs.config.ldap_enabled
self.ldap_server = hs.config.ldap_server
self.ldap_port = hs.config.ldap_port
self.ldap_tls = hs.config.ldap_tls
self.ldap_search_base = hs.config.ldap_search_base
self.ldap_search_property = hs.config.ldap_search_property
self.ldap_email_property = hs.config.ldap_email_property
self.ldap_full_name_property = hs.config.ldap_full_name_property
if self.ldap_enabled is True:
import ldap
logger.info("Import ldap version: %s", ldap.__version__)
self.hs = hs # FIXME better possibility to access registrationHandler later?
@defer.inlineCallbacks
def check_auth(self, flows, clientdict, clientip):
"""
@@ -163,9 +178,13 @@ class AuthHandler(BaseHandler):
def get_session_id(self, clientdict):
"""
Gets the session ID for a client given the client dictionary
:param clientdict: The dictionary sent by the client in the request
:return: The string session ID the client sent. If the client did not
send a session ID, returns None.
Args:
clientdict: The dictionary sent by the client in the request
Returns:
str|None: The string session ID the client sent. If the client did
not send a session ID, returns None.
"""
sid = None
if clientdict and 'auth' in clientdict:
@@ -179,9 +198,11 @@ class AuthHandler(BaseHandler):
Store a key-value pair into the sessions data associated with this
request. This data is stored server-side and cannot be modified by
the client.
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param value: (any) The data to store
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
value (any): The data to store
"""
sess = self._get_session_info(session_id)
sess.setdefault('serverdict', {})[key] = value
@@ -190,9 +211,11 @@ class AuthHandler(BaseHandler):
def get_session_data(self, session_id, key, default=None):
"""
Retrieve data stored with set_session_data
:param session_id: (string) The ID of this session as returned from check_auth
:param key: (string) The key to store the data under
:param default: (any) Value to return if the key has not been set
Args:
session_id (string): The ID of this session as returned from check_auth
key (string): The key to store the data under
default (any): Value to return if the key has not been set
"""
sess = self._get_session_info(session_id)
return sess.setdefault('serverdict', {}).get(key, default)
@@ -207,8 +230,10 @@ class AuthHandler(BaseHandler):
if not user_id.startswith('@'):
user_id = UserID.create(user_id, self.hs.hostname).to_string()
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
defer.returnValue(user_id)
@defer.inlineCallbacks
@@ -332,8 +357,10 @@ class AuthHandler(BaseHandler):
StoreError if there was a problem storing the token.
LoginError if there was an authentication problem.
"""
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
self._check_password(user_id, password, password_hash)
if not (yield self._check_password(user_id, password)):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
logger.info("Logging in user %s", user_id)
access_token = yield self.issue_access_token(user_id)
@@ -399,11 +426,67 @@ class AuthHandler(BaseHandler):
else:
defer.returnValue(user_infos.popitem())
def _check_password(self, user_id, password, stored_hash):
"""Checks that user_id has passed password, raises LoginError if not."""
if not self.validate_hash(password, stored_hash):
logger.warn("Failed password login for user %s", user_id)
raise LoginError(403, "", errcode=Codes.FORBIDDEN)
@defer.inlineCallbacks
def _check_password(self, user_id, password):
"""
Returns:
True if the user_id successfully authenticated
"""
valid_ldap = yield self._check_ldap_password(user_id, password)
if valid_ldap:
defer.returnValue(True)
valid_local_password = yield self._check_local_password(user_id, password)
if valid_local_password:
defer.returnValue(True)
defer.returnValue(False)
@defer.inlineCallbacks
def _check_local_password(self, user_id, password):
try:
user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id)
defer.returnValue(self.validate_hash(password, password_hash))
except LoginError:
defer.returnValue(False)
@defer.inlineCallbacks
def _check_ldap_password(self, user_id, password):
if not self.ldap_enabled:
logger.debug("LDAP not configured")
defer.returnValue(False)
import ldap
logger.info("Authenticating %s with LDAP" % user_id)
try:
ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port)
logger.debug("Connecting LDAP server at %s" % ldap_url)
l = ldap.initialize(ldap_url)
if self.ldap_tls:
logger.debug("Initiating TLS")
l.start_tls_s()
local_name = UserID.from_string(user_id).localpart
dn = "%s=%s, %s" % (
self.ldap_search_property,
local_name,
self.ldap_search_base)
logger.debug("DN for LDAP authentication: %s" % dn)
l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8'))
if not (yield self.does_user_exist(user_id)):
handler = self.hs.get_handlers().registration_handler
user_id, access_token = (
yield handler.register(localpart=local_name)
)
defer.returnValue(True)
except ldap.LDAPError, e:
logger.warn("LDAP error: %s", e)
defer.returnValue(False)
@defer.inlineCallbacks
def issue_access_token(self, user_id):


@@ -26,7 +26,7 @@ from synapse.api.errors import (
from synapse.api.constants import EventTypes, Membership, RejectedReason
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.logcontext import PreserveLoggingContext
from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
from synapse.util.logutils import log_function
from synapse.util.async import run_on_reactor
from synapse.util.frozenutils import unfreeze
@@ -40,6 +40,7 @@ from synapse.events.utils import prune_event
from synapse.util.retryutils import NotRetryingDestination
from synapse.push.action_generator import ActionGenerator
from synapse.util.distributor import user_joined_room
from twisted.internet import defer
@@ -49,10 +50,6 @@ import logging
logger = logging.getLogger(__name__)
def user_joined_room(distributor, user, room_id):
return distributor.fire("user_joined_room", user, room_id)
class FederationHandler(BaseHandler):
"""Handles events that originated from federation.
Responsible for:
@@ -102,8 +99,7 @@ class FederationHandler(BaseHandler):
@log_function
@defer.inlineCallbacks
def on_receive_pdu(self, origin, pdu, state=None,
auth_chain=None):
def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
""" Called by the ReplicationLayer when we have a new pdu. We need to
do auth checks and put it through the StateHandler.
"""
@@ -174,11 +170,7 @@ class FederationHandler(BaseHandler):
})
seen_ids.add(e.event_id)
yield self._handle_new_events(
origin,
event_infos,
outliers=True
)
yield self._handle_new_events(origin, event_infos)
try:
context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -288,7 +280,14 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def backfill(self, dest, room_id, limit, extremities=[]):
""" Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. This may
be successful and still return no events if the other side has no new
events to offer.
"""
if dest == self.server_name:
raise SynapseError(400, "Can't backfill from self.")
if not extremities:
extremities = yield self.store.get_oldest_events_in_room(room_id)
@@ -299,6 +298,16 @@ class FederationHandler(BaseHandler):
extremities=extremities,
)
# Don't bother processing events we already have.
seen_events = yield self.store.have_events_in_timeline(
set(e.event_id for e in events)
)
events = [e for e in events if e.event_id not in seen_events]
if not events:
defer.returnValue([])
event_map = {e.event_id: e for e in events}
event_ids = set(e.event_id for e in events)
@@ -358,6 +367,7 @@ class FederationHandler(BaseHandler):
for a in auth_events.values():
if a.event_id in seen_events:
continue
a.internal_metadata.outlier = True
ev_infos.append({
"event": a,
"auth_events": {
@@ -378,20 +388,23 @@ class FederationHandler(BaseHandler):
}
})
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
events.sort(key=lambda e: e.depth)
for event in events:
if event in events_to_state:
continue
ev_infos.append({
"event": event,
})
yield self._handle_new_events(
dest, ev_infos,
backfilled=True,
)
# We store these one at a time since each event depends on the
# previous to work out the state.
# TODO: We can probably do something more clever here.
yield self._handle_new_event(
dest, event
)
defer.returnValue(events)
@@ -455,7 +468,7 @@ class FederationHandler(BaseHandler):
likely_domains = [
domain for domain, depth in curr_domains
if domain is not self.server_name
if domain != self.server_name
]
@defer.inlineCallbacks
@@ -463,11 +476,15 @@ class FederationHandler(BaseHandler):
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
events = yield self.backfill(
yield self.backfill(
dom, room_id,
limit=100,
extremities=[e for e in extremities.keys()]
)
# If this succeeded then we probably already have the
# appropriate stuff.
# TODO: We can probably do something more intelligent here.
defer.returnValue(True)
except SynapseError as e:
logger.info(
"Failed to backfill from %s because %s",
@@ -493,8 +510,6 @@ class FederationHandler(BaseHandler):
)
continue
if events:
defer.returnValue(True)
defer.returnValue(False)
success = yield try_backfill(likely_domains)
@@ -666,9 +681,13 @@ class FederationHandler(BaseHandler):
"state_key": user_id,
})
event, context = yield self._create_new_client_event(
builder=builder,
)
try:
event, context = yield self._create_new_client_event(
builder=builder,
)
except AuthError as e:
logger.warn("Failed to create join %r because %s", event, e)
raise e
self.auth.check(event, auth_events=context.current_state)
@@ -761,6 +780,7 @@ class FederationHandler(BaseHandler):
event = pdu
event.internal_metadata.outlier = True
event.internal_metadata.invite_from_remote = True
event.signatures.update(
compute_event_signature(
@@ -788,13 +808,19 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave"
)
signed_event = self._sign_event(event)
try:
origin, event = yield self._make_and_verify_event(
target_hosts,
room_id,
user_id,
"leave"
)
signed_event = self._sign_event(event)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
# Try the host we successfully got a response to /make_join/
# request first.
@@ -804,10 +830,16 @@ class FederationHandler(BaseHandler):
except ValueError:
pass
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
try:
yield self.replication_layer.send_leave(
target_hosts,
signed_event
)
except SynapseError:
raise
except CodeMessageException as e:
logger.warn("Failed to reject invite: %s", e)
raise SynapseError(500, "Failed to reject invite")
context = yield self.state_handler.compute_event_context(event)
@@ -887,7 +919,11 @@ class FederationHandler(BaseHandler):
builder=builder,
)
self.auth.check(event, auth_events=context.current_state)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Failed to create new leave %r because %s", event, e)
raise e
defer.returnValue(event)
@@ -1068,10 +1104,8 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
@log_function
def _handle_new_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
def _handle_new_event(self, origin, event, state=None, auth_events=None,
backfilled=False):
context = yield self._prep_event(
origin, event,
state=state,
@@ -1087,14 +1121,24 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event(
event,
context=context,
is_new_state=not outlier,
backfilled=backfilled,
)
# this intentionally does not yield: we don't care about the result
# and don't need to wait for it.
preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
event_stream_id, max_stream_id
)
defer.returnValue((context, event_stream_id, max_stream_id))
@defer.inlineCallbacks
def _handle_new_events(self, origin, event_infos, backfilled=False,
outliers=False):
def _handle_new_events(self, origin, event_infos, backfilled=False):
"""Creates the appropriate contexts and persists events. The events
should not depend on one another, e.g. this should be used to persist
a bunch of outliers, but not a chunk of individual events that depend
on each other for state calculations.
"""
contexts = yield defer.gatherResults(
[
self._prep_event(
@@ -1113,7 +1157,6 @@ class FederationHandler(BaseHandler):
for ev_info, context in itertools.izip(event_infos, contexts)
],
backfilled=backfilled,
is_new_state=(not outliers and not backfilled),
)
@defer.inlineCallbacks
@@ -1128,11 +1171,9 @@ class FederationHandler(BaseHandler):
"""
events_to_context = {}
for e in itertools.chain(auth_events, state):
ctx = yield self.state_handler.compute_event_context(
e, outlier=True,
)
events_to_context[e.event_id] = ctx
e.internal_metadata.outlier = True
ctx = yield self.state_handler.compute_event_context(e)
events_to_context[e.event_id] = ctx
event_map = {
e.event_id: e
@@ -1176,16 +1217,14 @@ class FederationHandler(BaseHandler):
(e, events_to_context[e.event_id])
for e in itertools.chain(auth_events, state)
],
is_new_state=False,
)
new_event_context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=False,
event, old_state=state
)
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
is_new_state=True,
current_state=state,
)
@@ -1193,10 +1232,9 @@ class FederationHandler(BaseHandler):
@defer.inlineCallbacks
def _prep_event(self, origin, event, state=None, auth_events=None):
outlier = event.internal_metadata.is_outlier()
context = yield self.state_handler.compute_event_context(
event, old_state=state, outlier=outlier,
event, old_state=state,
)
if not auth_events:
@@ -1482,8 +1520,9 @@ class FederationHandler(BaseHandler):
try:
self.auth.check(event, auth_events=auth_events)
except AuthError:
raise
except AuthError as e:
logger.warn("Failed auth resolution for %r because %s", event, e)
raise e
@defer.inlineCallbacks
def construct_auth_difference(self, local_auth, remote_auth):
@@ -1659,7 +1698,12 @@ class FederationHandler(BaseHandler):
event_dict, event, context
)
self.auth.check(event, context.current_state)
try:
self.auth.check(event, context.current_state)
except AuthError as e:
logger.warn("Denying new third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(None, event, context)
@@ -1684,7 +1728,11 @@ class FederationHandler(BaseHandler):
event_dict, event, context
)
self.auth.check(event, auth_events=context.current_state)
try:
self.auth.check(event, auth_events=context.current_state)
except AuthError as e:
logger.warn("Denying third party invite %r because %s", event, e)
raise e
yield self._check_signature(event, auth_events=context.current_state)
returned_invite = yield self.send_invite(origin, event)
@@ -1718,13 +1766,15 @@ class FederationHandler(BaseHandler):
def _check_signature(self, event, auth_events):
"""
Checks that the signature in the event is consistent with its invite.
:param event (Event): The m.room.member event to check
:param auth_events (dict<(event type, state_key), event>)
:raises
AuthError if signature didn't match any keys, or key has been
Args:
event (Event): The m.room.member event to check
auth_events (dict<(event type, state_key), event>):
Raises:
AuthError: if signature didn't match any keys, or key has been
revoked,
SynapseError if a transient error meant a key couldn't be checked
SynapseError: if a transient error meant a key couldn't be checked
for revocation.
"""
signed = event.content["third_party_invite"]["signed"]
@@ -1766,12 +1816,13 @@ class FederationHandler(BaseHandler):
"""
Checks whether public_key has been revoked.
:param public_key (str): base-64 encoded public key.
:param url (str): Key revocation URL.
Args:
public_key (str): base-64 encoded public key.
url (str): Key revocation URL.
:raises
AuthError if they key has been revoked.
SynapseError if a transient error meant a key couldn't be checked
Raises:
AuthError: if the key has been revoked.
SynapseError: if a transient error meant a key couldn't be checked
for revocation.
"""
try:


@@ -21,6 +21,7 @@ from synapse.streams.config import PaginationConfig
from synapse.events.utils import serialize_event
from synapse.events.validator import EventValidator
from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.snapshot_cache import SnapshotCache
from synapse.types import UserID, RoomStreamToken, StreamToken
@@ -33,10 +34,6 @@ import logging
logger = logging.getLogger(__name__)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class MessageHandler(BaseHandler):
def __init__(self, hs):
@@ -47,35 +44,6 @@ class MessageHandler(BaseHandler):
self.validator = EventValidator()
self.snapshot_cache = SnapshotCache()
@defer.inlineCallbacks
def get_message(self, msg_id=None, room_id=None, sender_id=None,
user_id=None):
""" Retrieve a message.
Args:
msg_id (str): The message ID to obtain.
room_id (str): The room where the message resides.
sender_id (str): The user ID of the user who sent the message.
user_id (str): The user ID of the user making this request.
Returns:
The message, or None if no message exists.
Raises:
SynapseError if something went wrong.
"""
yield self.auth.check_joined_room(room_id, user_id)
# Pull out the message from the db
# msg = yield self.store.get_message(
# room_id=room_id,
# msg_id=msg_id,
# user_id=sender_id
# )
# TODO (erikj): Once we work out the correct c-s api we need to think
# on how to do this.
defer.returnValue(None)
@defer.inlineCallbacks
def get_messages(self, requester, room_id=None, pagin_config=None,
as_client_event=True):
@@ -175,7 +143,7 @@ class MessageHandler(BaseHandler):
defer.returnValue(chunk)
@defer.inlineCallbacks
def create_event(self, event_dict, token_id=None, txn_id=None):
def create_event(self, event_dict, token_id=None, txn_id=None, prev_event_ids=None):
"""
Given a dict from a client, create a new event.
@@ -186,6 +154,9 @@ class MessageHandler(BaseHandler):
Args:
event_dict (dict): An entire event
token_id (str)
txn_id (str)
prev_event_ids (list): The prev event ids to use when creating the event
Returns:
Tuple of created event (FrozenEvent), Context
@@ -198,12 +169,8 @@ class MessageHandler(BaseHandler):
membership = builder.content.get("membership", None)
target = UserID.from_string(builder.state_key)
if membership == Membership.JOIN:
if membership in {Membership.JOIN, Membership.INVITE}:
# If event doesn't include a display name, add one.
yield collect_presencelike_data(
self.distributor, target, builder.content
)
elif membership == Membership.INVITE:
profile = self.hs.get_handlers().profile_handler
content = builder.content
@@ -224,6 +191,7 @@ class MessageHandler(BaseHandler):
event, context = yield self._create_new_client_event(
builder=builder,
prev_event_ids=prev_event_ids,
)
defer.returnValue((event, context))
@@ -556,14 +524,7 @@ class MessageHandler(BaseHandler):
except:
logger.exception("Failed to get snapshot")
# Only do N rooms at once
n = 5
d_list = [handle_room(e) for e in room_list]
for i in range(0, len(d_list), n):
yield defer.gatherResults(
d_list[i:i + n],
consumeErrors=True
).addErrback(unwrapFirstError)
yield concurrently_execute(handle_room, room_list, 10)
account_data_events = []
for account_data_type, content in account_data.items():

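The old "only do N rooms at once" batching loop above is replaced by concurrently_execute(handle_room, room_list, 10). One plausible shape for such a helper, sketched here as an assumption rather than synapse's actual implementation, is a fixed number of workers draining a shared iterator:

# Sketch only; an assumed implementation of a "run at most `limit` at once"
# helper with the same call signature as used above.
from twisted.internet import defer

@defer.inlineCallbacks
def concurrently_execute_sketch(func, args, limit):
    it = iter(args)

    @defer.inlineCallbacks
    def _worker():
        for arg in it:  # workers share one iterator, so each arg runs once
            yield func(arg)

    yield defer.gatherResults(
        [_worker() for _ in range(limit)],
        consumeErrors=True,
    )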

@@ -17,7 +17,6 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
from synapse.types import UserID, Requester
from synapse.util import unwrapFirstError
from ._base import BaseHandler
@@ -27,14 +26,6 @@ import logging
logger = logging.getLogger(__name__)
def changed_presencelike_data(distributor, user, state):
return distributor.fire("changed_presencelike_data", user, state)
def collect_presencelike_data(distributor, user, content):
return distributor.fire("collect_presencelike_data", user, content)
class ProfileHandler(BaseHandler):
def __init__(self, hs):
@@ -46,17 +37,9 @@ class ProfileHandler(BaseHandler):
)
distributor = hs.get_distributor()
self.distributor = distributor
distributor.declare("collect_presencelike_data")
distributor.declare("changed_presencelike_data")
distributor.observe("registered_user", self.registered_user)
distributor.observe(
"collect_presencelike_data", self.collect_presencelike_data
)
def registered_user(self, user):
return self.store.create_profile(user.localpart)
@@ -105,10 +88,6 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_displayname
)
yield changed_presencelike_data(self.distributor, target_user, {
"displayname": new_displayname,
})
yield self._update_join_states(requester)
@defer.inlineCallbacks
@@ -152,30 +131,8 @@ class ProfileHandler(BaseHandler):
target_user.localpart, new_avatar_url
)
yield changed_presencelike_data(self.distributor, target_user, {
"avatar_url": new_avatar_url,
})
yield self._update_join_states(requester)
@defer.inlineCallbacks
def collect_presencelike_data(self, user, state):
if not self.hs.is_mine(user):
defer.returnValue(None)
(displayname, avatar_url) = yield defer.gatherResults(
[
self.store.get_profile_displayname(user.localpart),
self.store.get_profile_avatar_url(user.localpart),
],
consumeErrors=True
).addErrback(unwrapFirstError)
state["displayname"] = displayname
state["avatar_url"] = avatar_url
defer.returnValue(None)
@defer.inlineCallbacks
def on_profile_query(self, args):
user = UserID.from_string(args["user_id"])


@@ -80,6 +80,9 @@ class ReceiptsHandler(BaseHandler):
def _handle_new_receipts(self, receipts):
"""Takes a list of receipts, stores them and informs the notifier.
"""
min_batch_id = None
max_batch_id = None
for receipt in receipts:
room_id = receipt["room_id"]
receipt_type = receipt["receipt_type"]
@@ -97,10 +100,21 @@ class ReceiptsHandler(BaseHandler):
stream_id, max_persisted_id = res
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_persisted_id, rooms=[room_id]
)
if min_batch_id is None or stream_id < min_batch_id:
min_batch_id = stream_id
if max_batch_id is None or max_persisted_id > max_batch_id:
max_batch_id = max_persisted_id
affected_room_ids = list(set([r["room_id"] for r in receipts]))
with PreserveLoggingContext():
self.notifier.on_new_event(
"receipt_key", max_batch_id, rooms=affected_room_ids
)
# Note that the min here shouldn't be relied upon to be accurate.
self.hs.get_pusherpool().on_new_receipts(
min_batch_id, max_batch_id, affected_room_ids
)
defer.returnValue(True)


@@ -23,6 +23,7 @@ from synapse.api.errors import (
from ._base import BaseHandler
from synapse.util.async import run_on_reactor
from synapse.http.client import CaptchaServerHttpClient
from synapse.util.distributor import registered_user
import logging
import urllib
@@ -30,10 +31,6 @@ import urllib
logger = logging.getLogger(__name__)
def registered_user(distributor, user):
return distributor.fire("registered_user", user)
class RegistrationHandler(BaseHandler):
def __init__(self, hs):


@@ -18,19 +18,16 @@ from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken, Requester
from synapse.types import UserID, RoomAlias, RoomID, RoomStreamToken
from synapse.api.constants import (
EventTypes, Membership, JoinRules, RoomCreationPreset,
EventTypes, JoinRules, RoomCreationPreset,
)
from synapse.api.errors import AuthError, StoreError, SynapseError, Codes
from synapse.util import stringutils, unwrapFirstError
from synapse.util.logcontext import preserve_context_over_fn
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from synapse.api.errors import AuthError, StoreError, SynapseError
from synapse.util import stringutils
from synapse.util.async import concurrently_execute
from synapse.util.caches.response_cache import ResponseCache
from collections import OrderedDict
from unpaddedbase64 import decode_base64
import logging
import math
@@ -41,20 +38,6 @@ logger = logging.getLogger(__name__)
id_server_scheme = "https://"
def user_left_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_left_room", user=user, room_id=room_id
)
def user_joined_room(distributor, user, room_id):
return preserve_context_over_fn(
distributor.fire,
"user_joined_room", user=user, room_id=room_id
)
class RoomCreationHandler(BaseHandler):
PRESETS_DICT = {
@@ -356,598 +339,25 @@ class RoomCreationHandler(BaseHandler):
)
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; its result is performed by the
side-effect on the two passed sets. This allows easy accumulation of
member lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
elif action == "forget":
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": effective_membership_state}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": effective_membership_state,
},
token_id=requester.access_token_id,
txn_id=txn_id,
)
old_state = context.current_state.get((EventTypes.Member, event.state_key))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
member_handler = self.hs.get_handlers().room_member_handler
yield member_handler.send_membership_event(
requester,
event,
context,
ratelimit=ratelimit,
remote_room_hosts=remote_room_hosts,
)
if action == "forget":
yield self.forget(requester.user, room_id)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
is_guest (bool): Whether the sender is a guest.
room_hosts ([str]): Homeservers which are likely to already be in
the room, and could be danced with in order to join this
homeserver for the first time.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
action = "send"
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
do_remote_join_dance, remote_room_hosts = self._should_do_dance(
context,
(self.get_inviter(event.state_key, context.current_state)),
remote_room_hosts,
)
if do_remote_join_dance:
action = "remote_join"
elif event.membership == Membership.LEAVE:
is_host_in_room = self.is_host_in_room(context.current_state)
if not is_host_in_room:
# perhaps we've been invited
inviter = self.get_inviter(target_user.to_string(), context.current_state)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
action = "remote_reject"
federation_handler = self.hs.get_handlers().federation_handler
if action == "remote_join":
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield federation_handler.do_invite_join(
remote_room_hosts,
event.room_id,
event.user_id,
event.content,
)
elif action == "remote_reject":
yield federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
event.user_id
)
else:
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
def _should_do_dance(self, context, inviter, room_hosts=None):
# TODO: Shouldn't this be remote_room_host?
room_hosts = room_hosts or []
is_host_in_room = self.is_host_in_room(context.current_state)
if is_host_in_room:
return False, room_hosts
if inviter and not self.hs.is_mine(inviter):
room_hosts.append(inviter.domain)
return True, room_hosts
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
def get_inviter(self, user_id, current_state):
prev_state = current_state.get((EventTypes.Member, user_id))
if prev_state and prev_state.membership == Membership.INVITE:
return UserID.from_string(prev_state.user_id)
return None
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
(str) the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
:param id_server (str): hostname + optional port for the identity server.
:param medium (str): The literal string "email".
:param address (str): The third party address being invited.
:param room_id (str): The ID of the room to which the user is invited.
:param inviter_user_id (str): The user ID of the inviter.
:param room_alias (str): An alias for the room, for cosmetic
notifications.
:param room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
:param room_join_rules (str): The join rules of the room
(e.g. "public").
:param room_name (str): The m.room.name of the room.
:param inviter_display_name (str): The current display name of the
inviter.
:param inviter_avatar_url (str): The URL of the inviter's avatar.
:return: A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
def forget(self, user, room_id):
return self.store.forget(user.to_string(), room_id)
class RoomListHandler(BaseHandler):
def __init__(self, hs):
super(RoomListHandler, self).__init__(hs)
self.response_cache = ResponseCache()
def get_public_room_list(self):
result = self.response_cache.get(())
if not result:
result = self.response_cache.set((), self._get_public_room_list())
return result
@defer.inlineCallbacks
def get_public_room_list(self):
def _get_public_room_list(self):
room_ids = yield self.store.get_public_room_ids()
results = []
@defer.inlineCallbacks
def handle_room(room_id):
aliases = yield self.store.get_aliases_for_room(room_id)
# We pull each bit of state out individually to avoid pulling the
# full state into memory. Due to how the caching works this should
# be fairly quick, even if not originally in the cache.
@@ -962,6 +372,14 @@ class RoomListHandler(BaseHandler):
defer.returnValue(None)
result = {"room_id": room_id}
joined_users = yield self.store.get_users_in_room(room_id)
if len(joined_users) == 0:
return
result["num_joined_members"] = len(joined_users)
aliases = yield self.store.get_aliases_for_room(room_id)
if aliases:
result["aliases"] = aliases
@@ -1001,21 +419,12 @@ class RoomListHandler(BaseHandler):
if avatar_url:
result["avatar_url"] = avatar_url
joined_users = yield self.store.get_users_in_room(room_id)
result["num_joined_members"] = len(joined_users)
results.append(result)
defer.returnValue(result)
result = []
for chunk in (room_ids[i:i + 10] for i in xrange(0, len(room_ids), 10)):
chunk_result = yield defer.gatherResults([
handle_room(room_id)
for room_id in chunk
], consumeErrors=True).addErrback(unwrapFirstError)
result.extend(v for v in chunk_result if v)
yield concurrently_execute(handle_room, room_ids, 10)
# FIXME (erikj): START is no longer a valid value
defer.returnValue({"start": "START", "end": "END", "chunk": result})
defer.returnValue({"start": "START", "end": "END", "chunk": results})
class RoomContextHandler(BaseHandler):

View File

@@ -0,0 +1,722 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from ._base import BaseHandler
from synapse.types import UserID, RoomID, Requester
from synapse.api.constants import (
EventTypes, Membership,
)
from synapse.api.errors import AuthError, SynapseError, Codes
from synapse.util.async import Linearizer
from synapse.util.distributor import user_left_room, user_joined_room
from signedjson.sign import verify_signed_json
from signedjson.key import decode_verify_key_bytes
from unpaddedbase64 import decode_base64
import logging
logger = logging.getLogger(__name__)
id_server_scheme = "https://"
class RoomMemberHandler(BaseHandler):
# TODO(paul): This handler currently contains a messy conflation of
# low-level API that works on UserID objects and so on, and REST-level
# API that takes ID strings and returns pagination chunks. These concerns
# ought to be separated out a lot better.
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.clock = hs.get_clock()
self.distributor = hs.get_distributor()
self.distributor.declare("user_joined_room")
self.distributor.declare("user_left_room")
@defer.inlineCallbacks
def get_room_members(self, room_id):
users = yield self.store.get_users_in_room(room_id)
defer.returnValue([UserID.from_string(u) for u in users])
@defer.inlineCallbacks
def fetch_room_distributions_into(self, room_id, localusers=None,
remotedomains=None, ignore_user=None):
"""Fetch the distribution of a room, adding elements to either
'localusers' or 'remotedomains', which should be a set() if supplied.
If ignore_user is set, ignore that user.
This function returns nothing; it reports its result via side-effects
on the two passed sets. This allows easy accumulation of the member
lists of multiple rooms at once if required.
"""
members = yield self.get_room_members(room_id)
for member in members:
if ignore_user is not None and member == ignore_user:
continue
if self.hs.is_mine(member):
if localusers is not None:
localusers.add(member)
else:
if remotedomains is not None:
remotedomains.add(member.domain)
@defer.inlineCallbacks
def _local_membership_update(
self, requester, target, room_id, membership,
prev_event_ids,
txn_id=None,
ratelimit=True,
):
msg_handler = self.hs.get_handlers().message_handler
content = {"membership": membership}
if requester.is_guest:
content["kind"] = "guest"
event, context = yield msg_handler.create_event(
{
"type": EventTypes.Member,
"content": content,
"room_id": room_id,
"sender": requester.user.to_string(),
"state_key": target.to_string(),
# For backwards compatibility:
"membership": membership,
},
token_id=requester.access_token_id,
txn_id=txn_id,
prev_event_ids=prev_event_ids,
)
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
raise SynapseError(404, "No known servers")
# We don't do an auth check if we are doing an invite
# join dance for now, since we're kinda implicitly checking
# that we are allowed to join when we decide whether or not we
# need to do the invite/join dance.
yield self.hs.get_handlers().federation_handler.do_invite_join(
remote_room_hosts,
room_id,
user.to_string(),
content,
)
yield user_joined_room(self.distributor, user, room_id)
def reject_remote_invite(self, user_id, room_id, remote_room_hosts):
return self.hs.get_handlers().federation_handler.do_remotely_reject_invite(
remote_room_hosts,
room_id,
user_id
)
@defer.inlineCallbacks
def update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
key = (target, room_id,)
with (yield self.member_linearizer.queue(key)):
result = yield self._update_membership(
requester,
target,
room_id,
action,
txn_id=txn_id,
remote_room_hosts=remote_room_hosts,
third_party_signed=third_party_signed,
ratelimit=ratelimit,
)
defer.returnValue(result)
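# A minimal standalone approximation (not synapse's actual Linearizer) of the
# per-key serialisation used above: update_membership queues its work behind a
# lock keyed on (target, room_id) so that two concurrent membership changes for
# the same user and room cannot interleave.  All names below are illustrative.
from collections import defaultdict
from twisted.internet import defer

class PerKeyLock(object):
    def __init__(self):
        self._locks = defaultdict(defer.DeferredLock)

    @defer.inlineCallbacks
    def run(self, key, func, *args, **kwargs):
        lock = self._locks[key]
        yield lock.acquire()
        try:
            result = yield defer.maybeDeferred(func, *args, **kwargs)
        finally:
            lock.release()
        defer.returnValue(result)

# Usage: member_lock.run((target_user_id, room_id), do_membership_update)
# plays the same role as "with (yield self.member_linearizer.queue(key)):" above.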
@defer.inlineCallbacks
def _update_membership(
self,
requester,
target,
room_id,
action,
txn_id=None,
remote_room_hosts=None,
third_party_signed=None,
ratelimit=True,
):
effective_membership_state = action
if action in ["kick", "unban"]:
effective_membership_state = "leave"
if third_party_signed is not None:
replication = self.hs.get_replication_layer()
yield replication.exchange_third_party_invite(
third_party_signed["sender"],
target.to_string(),
room_id,
third_party_signed,
)
if not remote_room_hosts:
remote_room_hosts = []
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
current_state = yield self.state_handler.get_current_state(
room_id, latest_event_ids=latest_event_ids,
)
old_state = current_state.get((EventTypes.Member, target.to_string()))
old_membership = old_state.content.get("membership") if old_state else None
if action == "unban" and old_membership != "ban":
raise SynapseError(
403,
"Cannot unban user who was not banned (membership=%s)" % old_membership,
errcode=Codes.BAD_STATE
)
if old_membership == "ban" and action != "unban":
raise SynapseError(
403,
"Cannot %s user who was is banned" % (action,),
errcode=Codes.BAD_STATE
)
is_host_in_room = self.is_host_in_room(current_state)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
if inviter and not self.hs.is_mine(inviter):
remote_room_hosts.append(inviter.domain)
content = {"membership": Membership.JOIN}
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
ret = yield self.remote_join(
remote_room_hosts, room_id, target, content
)
defer.returnValue(ret)
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
# perhaps we've been invited
inviter = yield self.get_inviter(target.to_string(), room_id)
if not inviter:
raise SynapseError(404, "Not a known room")
if self.hs.is_mine(inviter):
# the inviter was on our server, but has now left. Carry on
# with the normal rejection codepath.
#
# This is a bit of a hack, because the room might still be
# active on other servers.
pass
else:
# send the rejection to the inviter's HS.
remote_room_hosts = remote_room_hosts + [inviter.domain]
try:
ret = yield self.reject_remote_invite(
target.to_string(), room_id, remote_room_hosts
)
defer.returnValue(ret)
except SynapseError as e:
logger.warn("Failed to reject invite: %s", e)
yield self.store.locally_reject_invite(
target.to_string(), room_id
)
defer.returnValue({})
yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
membership=effective_membership_state,
txn_id=txn_id,
ratelimit=ratelimit,
prev_event_ids=latest_event_ids,
)
@defer.inlineCallbacks
def send_membership_event(
self,
requester,
event,
context,
remote_room_hosts=None,
ratelimit=True,
):
"""
Change the membership status of a user in a room.
Args:
requester (Requester): The local user who requested the membership
event. If None, certain checks, like whether this homeserver can
act as the sender, will be skipped.
event (SynapseEvent): The membership event.
context: The context of the event.
remote_room_hosts ([str]): Homeservers which are likely to already be
in the room, and which this homeserver can join through if it is
not yet a member of the room.
ratelimit (bool): Whether to rate limit this request.
Raises:
SynapseError if there was a problem changing the membership.
"""
remote_room_hosts = remote_room_hosts or []
target_user = UserID.from_string(event.state_key)
room_id = event.room_id
if requester is not None:
sender = UserID.from_string(event.sender)
assert sender == requester.user, (
"Sender (%s) must be same as requester (%s)" %
(sender, requester.user)
)
assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,)
else:
requester = Requester(target_user, None, False)
message_handler = self.hs.get_handlers().message_handler
prev_event = message_handler.deduplicate_state_event(event, context)
if prev_event is not None:
return
if event.membership == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(context.current_state):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
yield self.handle_new_client_event(
requester,
event,
context,
extra_users=[target_user],
ratelimit=ratelimit,
)
prev_member_event = context.current_state.get(
(EventTypes.Member, target_user.to_string()),
None
)
if event.membership == Membership.JOIN:
if not prev_member_event or prev_member_event.membership != Membership.JOIN:
# Only fire user_joined_room if the user has actually joined the
# room. Don't bother if the user is just changing their profile
# info.
yield user_joined_room(self.distributor, target_user, room_id)
elif event.membership == Membership.LEAVE:
if prev_member_event and prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target_user, room_id)
def _can_guest_join(self, current_state):
"""
Returns whether a guest can join a room based on its current state.
"""
guest_access = current_state.get((EventTypes.GuestAccess, ""), None)
return (
guest_access
and guest_access.content
and "guest_access" in guest_access.content
and guest_access.content["guest_access"] == "can_join"
)
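# A tiny self-contained sketch of the guest-access check above, using a plain
# dict of (event_type, state_key) -> event-like objects.  The sample state and
# the _FakeEvent helper are illustrative only.
class _FakeEvent(object):
    def __init__(self, content):
        self.content = content

def can_guest_join(current_state):
    guest_access = current_state.get(("m.room.guest_access", ""), None)
    return bool(
        guest_access
        and guest_access.content.get("guest_access") == "can_join"
    )

# can_guest_join({("m.room.guest_access", ""): _FakeEvent({"guest_access": "can_join"})})
# returns True; an absent event or any other guest_access value returns False.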
@defer.inlineCallbacks
def lookup_room_alias(self, room_alias):
"""
Get the room ID associated with a room alias.
Args:
room_alias (RoomAlias): The alias to look up.
Returns:
A tuple of:
The room ID as a RoomID object.
Hosts likely to be participating in the room ([str]).
Raises:
SynapseError if room alias could not be found.
"""
directory_handler = self.hs.get_handlers().directory_handler
mapping = yield directory_handler.get_association(room_alias)
if not mapping:
raise SynapseError(404, "No such room alias")
room_id = mapping["room_id"]
servers = mapping["servers"]
defer.returnValue((RoomID.from_string(room_id), servers))
@defer.inlineCallbacks
def get_inviter(self, user_id, room_id):
invite = yield self.store.get_invite_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if invite:
defer.returnValue(UserID.from_string(invite.sender))
@defer.inlineCallbacks
def get_joined_rooms_for_user(self, user):
"""Returns a list of roomids that the user has any of the given
membership states in."""
rooms = yield self.store.get_rooms_for_user(
user.to_string(),
)
# For some reason the list of events contains duplicates
# TODO(paul): work out why because I really don't think it should
room_ids = set(r.room_id for r in rooms)
defer.returnValue(room_ids)
@defer.inlineCallbacks
def do_3pid_invite(
self,
room_id,
inviter,
medium,
address,
id_server,
requester,
txn_id
):
invitee = yield self._lookup_3pid(
id_server, medium, address
)
if invitee:
handler = self.hs.get_handlers().room_member_handler
yield handler.update_membership(
requester,
UserID.from_string(invitee),
room_id,
"invite",
txn_id=txn_id,
)
else:
yield self._make_and_store_3pid_invite(
requester,
id_server,
medium,
address,
room_id,
inviter,
txn_id=txn_id
)
@defer.inlineCallbacks
def _lookup_3pid(self, id_server, medium, address):
"""Looks up a 3pid in the passed identity server.
Args:
id_server (str): The server name (including port, if required)
of the identity server to use.
medium (str): The type of the third party identifier (e.g. "email").
address (str): The third party identifier (e.g. "foo@example.com").
Returns:
str: the matrix ID of the 3pid, or None if it is not recognized.
"""
try:
data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
{
"medium": medium,
"address": address,
}
)
if "mxid" in data:
if "signatures" not in data:
raise AuthError(401, "No signatures on 3pid binding")
self.verify_any_signature(data, id_server)
defer.returnValue(data["mxid"])
except IOError as e:
logger.warn("Error from identity server lookup: %s" % (e,))
defer.returnValue(None)
@defer.inlineCallbacks
def verify_any_signature(self, data, server_hostname):
if server_hostname not in data["signatures"]:
raise AuthError(401, "No signature from server %s" % (server_hostname,))
for key_name, signature in data["signatures"][server_hostname].items():
key_data = yield self.hs.get_simple_http_client().get_json(
"%s%s/_matrix/identity/api/v1/pubkey/%s" %
(id_server_scheme, server_hostname, key_name,),
)
if "public_key" not in key_data:
raise AuthError(401, "No public key named %s from %s" %
(key_name, server_hostname,))
verify_signed_json(
data,
server_hostname,
decode_verify_key_bytes(key_name, decode_base64(key_data["public_key"]))
)
return
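# A self-contained sketch of the signedjson round trip that verify_any_signature
# relies on: sign a JSON object with an ed25519 key, then verify it with the
# matching base64-encoded public key.  The key version and server name below
# are illustrative only.
import signedjson.key
import signedjson.sign
from unpaddedbase64 import decode_base64

signing_key = signedjson.key.generate_signing_key("0")   # key id "ed25519:0"
signed = signedjson.sign.sign_json(
    {"mxid": "@alice:example.com"}, "id.example.com", signing_key
)

verify_key = signedjson.key.get_verify_key(signing_key)
public_key_b64 = signedjson.key.encode_verify_key_base64(verify_key)

# This mirrors the lookup above: the identity server hands back a base64 public
# key, which is decoded and used to check the signature on the JSON blob.
signedjson.sign.verify_signed_json(
    signed,
    "id.example.com",
    signedjson.key.decode_verify_key_bytes("ed25519:0", decode_base64(public_key_b64)),
)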
@defer.inlineCallbacks
def _make_and_store_3pid_invite(
self,
requester,
id_server,
medium,
address,
room_id,
user,
txn_id
):
room_state = yield self.hs.get_state_handler().get_current_state(room_id)
inviter_display_name = ""
inviter_avatar_url = ""
member_event = room_state.get((EventTypes.Member, user.to_string()))
if member_event:
inviter_display_name = member_event.content.get("displayname", "")
inviter_avatar_url = member_event.content.get("avatar_url", "")
canonical_room_alias = ""
canonical_alias_event = room_state.get((EventTypes.CanonicalAlias, ""))
if canonical_alias_event:
canonical_room_alias = canonical_alias_event.content.get("alias", "")
room_name = ""
room_name_event = room_state.get((EventTypes.Name, ""))
if room_name_event:
room_name = room_name_event.content.get("name", "")
room_join_rules = ""
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
if join_rules_event:
room_join_rules = join_rules_event.content.get("join_rule", "")
room_avatar_url = ""
room_avatar_event = room_state.get((EventTypes.RoomAvatar, ""))
if room_avatar_event:
room_avatar_url = room_avatar_event.content.get("url", "")
token, public_keys, fallback_public_key, display_name = (
yield self._ask_id_server_for_third_party_invite(
id_server=id_server,
medium=medium,
address=address,
room_id=room_id,
inviter_user_id=user.to_string(),
room_alias=canonical_room_alias,
room_avatar_url=room_avatar_url,
room_join_rules=room_join_rules,
room_name=room_name,
inviter_display_name=inviter_display_name,
inviter_avatar_url=inviter_avatar_url
)
)
msg_handler = self.hs.get_handlers().message_handler
yield msg_handler.create_and_send_nonmember_event(
requester,
{
"type": EventTypes.ThirdPartyInvite,
"content": {
"display_name": display_name,
"public_keys": public_keys,
# For backwards compatibility:
"key_validity_url": fallback_public_key["key_validity_url"],
"public_key": fallback_public_key["public_key"],
},
"room_id": room_id,
"sender": user.to_string(),
"state_key": token,
},
txn_id=txn_id,
)
@defer.inlineCallbacks
def _ask_id_server_for_third_party_invite(
self,
id_server,
medium,
address,
room_id,
inviter_user_id,
room_alias,
room_avatar_url,
room_join_rules,
room_name,
inviter_display_name,
inviter_avatar_url
):
"""
Asks an identity server for a third party invite.
Args:
id_server (str): hostname + optional port for the identity server.
medium (str): The literal string "email".
address (str): The third party address being invited.
room_id (str): The ID of the room to which the user is invited.
inviter_user_id (str): The user ID of the inviter.
room_alias (str): An alias for the room, for cosmetic notifications.
room_avatar_url (str): The URL of the room's avatar, for cosmetic
notifications.
room_join_rules (str): The join rules of the room (e.g. "public").
room_name (str): The m.room.name of the room.
inviter_display_name (str): The current display name of the
inviter.
inviter_avatar_url (str): The URL of the inviter's avatar.
Returns:
A deferred tuple containing:
token (str): The token which must be signed to prove authenticity.
public_keys ([{"public_key": str, "key_validity_url": str}]):
public_key is a base64-encoded ed25519 public key.
fallback_public_key: One element from public_keys.
display_name (str): A user-friendly name to represent the invited
user.
"""
is_url = "%s%s/_matrix/identity/api/v1/store-invite" % (
id_server_scheme, id_server,
)
invite_config = {
"medium": medium,
"address": address,
"room_id": room_id,
"room_alias": room_alias,
"room_avatar_url": room_avatar_url,
"room_join_rules": room_join_rules,
"room_name": room_name,
"sender": inviter_user_id,
"sender_display_name": inviter_display_name,
"sender_avatar_url": inviter_avatar_url,
}
if self.hs.config.invite_3pid_guest:
registration_handler = self.hs.get_handlers().registration_handler
guest_access_token = yield registration_handler.guest_access_token_for(
medium=medium,
address=address,
inviter_user_id=inviter_user_id,
)
guest_user_info = yield self.hs.get_auth().get_user_by_access_token(
guest_access_token
)
invite_config.update({
"guest_access_token": guest_access_token,
"guest_user_id": guest_user_info["user"].to_string(),
})
data = yield self.hs.get_simple_http_client().post_urlencoded_get_json(
is_url,
invite_config
)
# TODO: Check for success
token = data["token"]
public_keys = data.get("public_keys", [])
if "public_key" in data:
fallback_public_key = {
"public_key": data["public_key"],
"key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
id_server_scheme, id_server,
),
}
else:
fallback_public_key = public_keys[0]
if not public_keys:
public_keys.append(fallback_public_key)
display_name = data["display_name"]
defer.returnValue((token, public_keys, fallback_public_key, display_name))
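# A pure-python restatement, for readability, of how the store-invite response
# above is normalised into (token, public_keys, fallback_public_key,
# display_name).  The field names are taken from the handling above; the
# function itself is only an illustration.
def parse_store_invite_response(data, id_server, id_server_scheme="https://"):
    token = data["token"]
    public_keys = data.get("public_keys", [])
    if "public_key" in data:
        # Older identity servers return a single top-level public_key, for
        # which a key_validity_url has to be synthesised.
        fallback_public_key = {
            "public_key": data["public_key"],
            "key_validity_url": "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
                id_server_scheme, id_server,
            ),
        }
    else:
        fallback_public_key = public_keys[0]
    if not public_keys:
        public_keys.append(fallback_public_key)
    return token, public_keys, fallback_public_key, data["display_name"]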
@defer.inlineCallbacks
def forget(self, user, room_id):
user_id = user.to_string()
member = yield self.state_handler.get_current_state(
room_id=room_id,
event_type=EventTypes.Member,
state_key=user_id
)
membership = member.membership if member else None
if membership is not None and membership != Membership.LEAVE:
raise SynapseError(400, "User %s in room %s" % (
user_id, room_id
))
if membership:
yield self.store.forget(user_id, room_id)

View File

@@ -17,9 +17,10 @@ from ._base import BaseHandler
from synapse.streams.config import PaginationConfig
from synapse.api.constants import Membership, EventTypes
from synapse.util import unwrapFirstError
from synapse.util.logcontext import LoggingContext, preserve_fn
from synapse.util.async import concurrently_execute
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
from synapse.util.caches.response_cache import ResponseCache
from synapse.push.clientformat import format_push_rules_for_user
from twisted.internet import defer
@@ -35,6 +36,7 @@ SyncConfig = collections.namedtuple("SyncConfig", [
"user",
"filter_collection",
"is_guest",
"request_key",
])
@@ -136,8 +138,8 @@ class SyncHandler(BaseHandler):
super(SyncHandler, self).__init__(hs)
self.event_sources = hs.get_event_sources()
self.clock = hs.get_clock()
self.response_cache = ResponseCache()
@defer.inlineCallbacks
def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0,
full_state=False):
"""Get the sync for a client if we have new data for it now. Otherwise
@@ -146,7 +148,19 @@ class SyncHandler(BaseHandler):
Returns:
A Deferred SyncResult.
"""
result = self.response_cache.get(sync_config.request_key)
if not result:
result = self.response_cache.set(
sync_config.request_key,
self._wait_for_sync_for_user(
sync_config, since_token, timeout, full_state
)
)
return result
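# A rough standalone approximation of the ResponseCache pattern used above (not
# synapse's actual implementation): the first caller for a given key kicks off
# the work, and any request arriving while it is still in flight shares the
# same Deferred result instead of recomputing it.
from twisted.internet import defer

class InFlightCache(object):
    def __init__(self):
        self._pending = {}

    def get(self, key):
        return self._pending.get(key)

    def set(self, key, deferred):
        self._pending[key] = deferred

        def _cleanup(result):
            self._pending.pop(key, None)
            return result

        deferred.addBoth(_cleanup)
        return deferred

# Usage mirrors wait_for_sync_for_user above:
#   result = cache.get(request_key)
#   if not result:
#       result = cache.set(request_key, do_expensive_sync(...))
#   return result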
@defer.inlineCallbacks
def _wait_for_sync_for_user(self, sync_config, since_token, timeout,
full_state):
context = LoggingContext.current_context()
if context:
if since_token is None:
@@ -236,58 +250,50 @@ class SyncHandler(BaseHandler):
joined = []
invited = []
archived = []
deferreds = []
room_list_chunks = [room_list[i:i + 10] for i in xrange(0, len(room_list), 10)]
for room_list_chunk in room_list_chunks:
for event in room_list_chunk:
if event.membership == Membership.JOIN:
room_sync_deferred = preserve_fn(
self.full_state_sync_for_joined_room
)(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(joined.append)
deferreds.append(room_sync_deferred)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if sync_config.user.to_string() == event.sender:
continue
user_id = sync_config.user.to_string()
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_sync_deferred = preserve_fn(
self.full_state_sync_for_archived_room
)(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
room_sync_deferred.addCallback(archived.append)
deferreds.append(room_sync_deferred)
@defer.inlineCallbacks
def _generate_room_entry(event):
if event.membership == Membership.JOIN:
room_result = yield self.full_state_sync_for_joined_room(
room_id=event.room_id,
sync_config=sync_config,
now_token=now_token,
timeline_since_token=timeline_since_token,
ephemeral_by_room=ephemeral_by_room,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
joined.append(room_result)
elif event.membership == Membership.INVITE:
invite = yield self.store.get_event(event.event_id)
invited.append(InvitedSyncResult(
room_id=event.room_id,
invite=invite,
))
elif event.membership in (Membership.LEAVE, Membership.BAN):
# Always send down rooms we were banned or kicked from.
if not sync_config.filter_collection.include_leave:
if event.membership == Membership.LEAVE:
if user_id == event.sender:
return
yield defer.gatherResults(
deferreds, consumeErrors=True
).addErrback(unwrapFirstError)
leave_token = now_token.copy_and_replace(
"room_key", "s%d" % (event.stream_ordering,)
)
room_result = yield self.full_state_sync_for_archived_room(
sync_config=sync_config,
room_id=event.room_id,
leave_event_id=event.event_id,
leave_token=leave_token,
timeline_since_token=timeline_since_token,
tags_by_room=tags_by_room,
account_data_by_room=account_data_by_room,
)
archived.append(room_result)
yield concurrently_execute(_generate_room_entry, room_list, 10)
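# A hedged, standalone sketch of what a concurrently_execute(func, args, limit)
# helper can look like: run func over every item in args, with at most `limit`
# invocations in flight at once.  This is an approximation for illustration,
# not necessarily synapse's own implementation.
from twisted.internet import defer

@defer.inlineCallbacks
def concurrently_execute_sketch(func, args, limit):
    remaining = iter(list(args))

    @defer.inlineCallbacks
    def _worker():
        # Workers share one iterator; the reactor is single-threaded, so each
        # item is handed to exactly one worker.
        for item in remaining:
            yield defer.maybeDeferred(func, item)

    yield defer.gatherResults(
        [_worker() for _ in range(limit)],
        consumeErrors=True,
    )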
account_data_for_user = sync_config.filter_collection.filter_account_data(
self.account_data_for_user(account_data)
@@ -657,7 +663,8 @@ class SyncHandler(BaseHandler):
def load_filtered_recents(self, room_id, sync_config, now_token,
since_token=None, recents=None, newly_joined_room=False):
"""
:returns a Deferred TimelineBatch
Returns:
a Deferred TimelineBatch
"""
with Measure(self.clock, "load_filtered_recents"):
filtering_factor = 2
@@ -824,8 +831,11 @@ class SyncHandler(BaseHandler):
"""
Get the room state after the given event
:param synapse.events.EventBase event: event of interest
:return: A Deferred map from ((type, state_key)->Event)
Args:
event(synapse.events.EventBase): event of interest
Returns:
A Deferred map from ((type, state_key)->Event)
"""
state = yield self.store.get_state_for_event(event.event_id)
if event.is_state():
@@ -836,9 +846,13 @@ class SyncHandler(BaseHandler):
@defer.inlineCallbacks
def get_state_at(self, room_id, stream_position):
""" Get the room state at a particular stream position
:param str room_id: room for which to get state
:param StreamToken stream_position: point at which to get state
:returns: A Deferred map from ((type, state_key)->Event)
Args:
room_id(str): room for which to get state
stream_position(StreamToken): point at which to get state
Returns:
A Deferred map from ((type, state_key)->Event)
"""
last_events, token = yield self.store.get_recent_events_for_room(
room_id, end_token=stream_position.room_key, limit=1,
@@ -859,15 +873,18 @@ class SyncHandler(BaseHandler):
""" Works out the differnce in state between the start of the timeline
and the previous sync.
:param str room_id
:param TimelineBatch batch: The timeline batch for the room that will
be sent to the user.
:param sync_config
:param str since_token: Token of the end of the previous batch. May be None.
:param str now_token: Token of the end of the current batch.
:param bool full_state: Whether to force returning the full state.
Args:
room_id(str):
batch(synapse.handlers.sync.TimelineBatch): The timeline batch for
the room that will be sent to the user.
sync_config(synapse.handlers.sync.SyncConfig):
since_token(str|None): Token of the end of the previous batch. May
be None.
now_token(str): Token of the end of the current batch.
full_state(bool): Whether to force returning the full state.
:returns A new event dictionary
Returns:
A deferred new event dictionary
"""
# TODO(mjark) Check if the state events were received by the server
# after the previous sync, since we need to include those state
@@ -939,11 +956,13 @@ class SyncHandler(BaseHandler):
Check if the user has just joined the given room (so should
be given the full state)
:param sync_config:
:param dict[(str,str), synapse.events.FrozenEvent] state_delta: the
difference in state since the last sync
Args:
sync_config(synapse.handlers.sync.SyncConfig):
state_delta(dict[(str,str), synapse.events.FrozenEvent]): the
difference in state since the last sync
:returns A deferred Tuple (state_delta, limited)
Returns:
A deferred Tuple (state_delta, limited)
"""
join_event = state_delta.get((
EventTypes.Member, sync_config.user.to_string()), None)

View File

@@ -15,17 +15,24 @@
from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from synapse.api.errors import CodeMessageException
from synapse.api.errors import (
CodeMessageException, SynapseError, Codes,
)
from synapse.util.logcontext import preserve_context_over_fn
import synapse.metrics
from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl
from twisted.internet import defer, reactor, ssl, protocol
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.web.client import (
Agent, readBody, FileBodyProducer, PartialDownloadError,
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, FileBodyProducer, PartialDownloadError,
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web._newclient import ResponseDone
from StringIO import StringIO
@@ -238,6 +245,107 @@ class SimpleHttpClient(object):
else:
raise CodeMessageException(response.code, body)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
@defer.inlineCallbacks
def get_file(self, url, output_stream, max_size=None):
"""GETs a file from a given URL
Args:
url (str): The URL to GET
output_stream (file): File to write the response body to.
Returns:
A (int,dict,string,int) tuple of the file length, dict of the response
headers, absolute URI of the response and HTTP response code.
"""
response = yield self.request(
"GET",
url.encode("ascii"),
headers=Headers({
b"User-Agent": [self.user_agent],
})
)
headers = dict(response.headers.getAllRawHeaders())
if 'Content-Length' in headers and headers['Content-Length'] > max_size:
logger.warn("Requested URL is too large > %r bytes" % (self.max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
if response.code > 299:
logger.warn("Got %d when downloading %s" % (response.code, url))
raise SynapseError(
502,
"Got error %d" % (response.code,),
Codes.UNKNOWN,
)
# TODO: if our Content-Type is HTML or something, just read the first
# N bytes into RAM rather than saving it all to disk only to read it
# straight back in again
try:
length = yield preserve_context_over_fn(
_readBodyToFile,
response, output_stream, max_size
)
except Exception as e:
logger.exception("Failed to download body")
raise SynapseError(
502,
("Failed to download remote body: %s" % e),
Codes.UNKNOWN,
)
defer.returnValue((length, headers, response.request.absoluteURI, response.code))
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size
def dataReceived(self, data):
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
))
self.deferred = defer.Deferred()
self.transport.loseConnection()
def connectionLost(self, reason):
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
# stolen from https://github.com/twisted/treq/pull/49/files
# http://twistedmatrix.com/trac/ticket/4840
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)
# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d
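# A small self-contained sketch of the "stream the body, abort past a size
# limit" idea implemented by _ReadBodyToFileProtocol above, driven directly
# with in-memory chunks rather than a real HTTP response.  Purely illustrative.
from StringIO import StringIO

def write_with_limit(chunks, output_stream, max_size):
    length = 0
    for chunk in chunks:
        output_stream.write(chunk)
        length += len(chunk)
        if max_size is not None and length >= max_size:
            raise ValueError("Requested file is too large > %r bytes" % (max_size,))
    return length

# write_with_limit(["aa", "bb"], StringIO(), max_size=10) -> 4
# write_with_limit(["aa", "bb"], StringIO(), max_size=3)  -> raises ValueError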
class CaptchaServerHttpClient(SimpleHttpClient):
"""
@@ -269,6 +377,59 @@ class CaptchaServerHttpClient(SimpleHttpClient):
defer.returnValue(e.response)
class SpiderEndpointFactory(object):
def __init__(self, hs):
self.blacklist = hs.config.url_preview_ip_range_blacklist
self.policyForHTTPS = hs.get_http_client_context_factory()
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
},
)
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,
'timeout': 15
},
)
else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
class SpiderHttpClient(SimpleHttpClient):
"""
Separate HTTP client for spidering arbitrary URLs.
Special in that it follows redirects and has a UA that looks
like a browser.
Used by the preview_url endpoint in the content repo.
"""
def __init__(self, hs):
SimpleHttpClient.__init__(self, hs)
# clobber the base class's agent and UA:
self.agent = ContentDecoderAgent(
BrowserLikeRedirectAgent(
Agent.usingEndpointFactory(
reactor,
SpiderEndpointFactory(hs)
)
), [('gzip', GzipDecoder)]
)
# We could look like Chrome:
# self.user_agent = ("Mozilla/5.0 (%s) (KHTML, like Gecko)
# Chrome Safari" % hs.version_string)
def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
@@ -301,5 +462,8 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
self._context = SSL.Context(SSL.SSLv23_METHOD)
self._context.set_verify(VERIFY_NONE, lambda *_: None)
def getContext(self, hostname, port):
def getContext(self, hostname=None, port=None):
return self._context
def creatorForNetloc(self, hostname, port):
return self

View File

@@ -22,6 +22,7 @@ from twisted.names.error import DNSNameError, DomainError
import collections
import logging
import random
import time
logger = logging.getLogger(__name__)
@@ -31,7 +32,7 @@ SERVER_CACHE = {}
_Server = collections.namedtuple(
"_Server", "priority weight host port"
"_Server", "priority weight host port expires"
)
@@ -74,6 +75,37 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
return transport_endpoint(reactor, domain, port, **endpoint_kw_args)
class SpiderEndpoint(object):
"""An endpoint which refuses to connect to blacklisted IP addresses
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
self.blacklist = blacklist
self.endpoint = endpoint
self.endpoint_kw_args = endpoint_kw_args
@defer.inlineCallbacks
def connect(self, protocolFactory):
address = yield self.reactor.resolve(self.host)
from netaddr import IPAddress
if IPAddress(address) in self.blacklist:
raise ConnectError(
"Refusing to spider blacklisted IP address %s" % address
)
logger.info("Connecting to %s:%s", address, self.port)
endpoint = self.endpoint(
self.reactor, address, self.port, **self.endpoint_kw_args
)
connection = yield endpoint.connect(protocolFactory)
defer.returnValue(connection)
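# A minimal sketch of the blacklist check performed in SpiderEndpoint.connect
# above, using netaddr directly.  The ranges below are illustrative; the real
# list comes from the url_preview_ip_range_blacklist config option.
from netaddr import IPAddress, IPSet

blacklist = IPSet(["127.0.0.0/8", "10.0.0.0/8", "192.168.0.0/16"])

def is_blacklisted(resolved_ip):
    return IPAddress(resolved_ip) in blacklist

# is_blacklisted("127.0.0.1") -> True; is_blacklisted("93.184.216.34") -> False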
class SRVClientEndpoint(object):
"""An endpoint which looks up SRV records for a service.
Cycles through the list of servers starting with each call to connect
@@ -92,7 +124,8 @@ class SRVClientEndpoint(object):
host=domain,
port=default_port,
priority=0,
weight=0
weight=0,
expires=0,
)
else:
self.default_server = None
@@ -118,7 +151,7 @@ class SRVClientEndpoint(object):
return self.default_server
else:
raise ConnectError(
"Not server available for %s", self.service_name
"Not server available for %s" % self.service_name
)
min_priority = self.servers[0].priority
@@ -153,7 +186,13 @@ class SRVClientEndpoint(object):
@defer.inlineCallbacks
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE, clock=time):
cache_entry = cache.get(service_name, None)
if cache_entry:
if all(s.expires > int(clock.time()) for s in cache_entry):
servers = list(cache_entry)
defer.returnValue(servers)
servers = []
try:
@@ -166,34 +205,33 @@ def resolve_service(service_name, dns_client=client, cache=SERVER_CACHE):
and answers[0].type == dns.SRV
and answers[0].payload
and answers[0].payload.target == dns.Name('.')):
raise ConnectError("Service %s unavailable", service_name)
raise ConnectError("Service %s unavailable" % service_name)
for answer in answers:
if answer.type != dns.SRV or not answer.payload:
continue
payload = answer.payload
host = str(payload.target)
srv_ttl = answer.ttl
try:
answers, _, _ = yield dns_client.lookupAddress(host)
except DNSNameError:
continue
ips = [
answer.payload.dottedQuad()
for answer in answers
if answer.type == dns.A and answer.payload
]
for answer in answers:
if answer.type == dns.A and answer.payload:
ip = answer.payload.dottedQuad()
host_ttl = min(srv_ttl, answer.ttl)
for ip in ips:
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight)
))
servers.append(_Server(
host=ip,
port=int(payload.port),
priority=int(payload.priority),
weight=int(payload.weight),
expires=int(clock.time()) + host_ttl,
))
servers.sort()
cache[service_name] = list(servers)
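# A standalone sketch of the cache-expiry rule added above: cached SRV entries
# now carry an `expires` timestamp (now + TTL), and the cached list is only
# reused if every entry is still fresh.  The sample data is illustrative.
import collections
import time

Server = collections.namedtuple("Server", "priority weight host port expires")

def fresh_servers(cache_entry, clock=time):
    now = int(clock.time())
    if cache_entry and all(s.expires > now for s in cache_entry):
        return list(cache_entry)
    return None  # caller falls through to a real DNS lookup

# fresh_servers([Server(0, 0, "10.0.0.1", 8448, int(time.time()) + 300)])
# returns the cached list; any expired entry makes it return None.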

View File

@@ -26,14 +26,19 @@ logger = logging.getLogger(__name__)
def parse_integer(request, name, default=None, required=False):
"""Parse an integer parameter from the request string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: An int value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (int|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
int|None: An int value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not an integer.
"""
if name in request.args:
@@ -53,14 +58,19 @@ def parse_integer(request, name, default=None, required=False):
def parse_boolean(request, name, default=None, required=False):
"""Parse a boolean parameter from the request query string
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:return: A bool value or the default.
:raises
SynapseError if the parameter is absent and required, or if the
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (bool|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
Returns:
bool|None: A bool value or the default.
Raises:
SynapseError: if the parameter is absent and required, or if the
parameter is present and not one of "true" or "false".
"""
@@ -88,15 +98,20 @@ def parse_string(request, name, default=None, required=False,
allowed_values=None, param_type="string"):
"""Parse a string parameter from the request query string.
:param request: the twisted HTTP request.
:param name (str): the name of the query parameter.
:param default: value to use if the parameter is absent, defaults to None.
:param required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
:param allowed_values (list): List of allowed values for the string,
or None if any value is allowed, defaults to None
:return: A string value or the default.
:raises
Args:
request: the twisted HTTP request.
name (str): the name of the query parameter.
default (str|None): value to use if the parameter is absent, defaults
to None.
required (bool): whether to raise a 400 SynapseError if the
parameter is absent, defaults to False.
allowed_values (list[str]): List of allowed values for the string,
or None if any value is allowed, defaults to None
Returns:
str|None: A string value or the default.
Raises:
SynapseError if the parameter is absent and required, or if the
parameter is present, must be one of a list of allowed values and
is not one of those allowed values.
@@ -122,9 +137,13 @@ def parse_string(request, name, default=None, required=False,
def parse_json_value_from_request(request):
"""Parse a JSON value from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:returns: The JSON value.
:raises
Args:
request: the twisted HTTP request.
Returns:
The JSON value.
Raises:
SynapseError if the request body couldn't be decoded as JSON.
"""
try:
@@ -143,8 +162,10 @@ def parse_json_value_from_request(request):
def parse_json_object_from_request(request):
"""Parse a JSON object from the body of a twisted HTTP request.
:param request: the twisted HTTP request.
:raises
Args:
request: the twisted HTTP request.
Raises:
SynapseError if the request body couldn't be decoded as JSON or
if it wasn't a JSON object.
"""

View File

@@ -503,13 +503,14 @@ class Notifier(object):
def wait_for_replication(self, callback, timeout):
"""Wait for an event to happen.
:param callback:
Gets called whenever an event happens. If this returns a truthy
value then ``wait_for_replication`` returns, otherwise it waits
for another event.
:param int timeout:
How many milliseconds to wait for callback return a truthy value.
:returns:
Args:
callback: Gets called whenever an event happens. If this returns a
truthy value then ``wait_for_replication`` returns, otherwise
it waits for another event.
timeout: How many milliseconds to wait for callback return a truthy
value.
Returns:
A deferred that resolves with the value returned by the callback.
"""
listener = _NotificationListener(None)

View File

@@ -13,333 +13,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from synapse.streams.config import PaginationConfig
from synapse.types import StreamToken
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
import synapse.util.async
from .push_rule_evaluator import evaluator_for_user_id
import logging
import random
logger = logging.getLogger(__name__)
_NEXT_ID = 1
def _get_next_id():
global _NEXT_ID
_id = _NEXT_ID
_NEXT_ID += 1
return _id
# Pushers could now be moved to pull out of the event_push_actions table instead
# of listening on the event stream: this would avoid them having to run the
# rules again.
class Pusher(object):
INITIAL_BACKOFF = 1000
MAX_BACKOFF = 60 * 60 * 1000
GIVE_UP_AFTER = 24 * 60 * 60 * 1000
def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
self.hs = _hs
self.evStreamHandler = self.hs.get_handlers().event_stream_handler
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.user_id = user_id
self.app_id = app_id
self.app_display_name = app_display_name
self.device_display_name = device_display_name
self.pushkey = pushkey
self.pushkey_ts = pushkey_ts
self.data = data
self.last_token = last_token
self.last_success = last_success # not actually used
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.failing_since = failing_since
self.alive = True
self.badge = None
self.name = "Pusher-%d" % (_get_next_id(),)
# The last value of last_active_time that we saw
self.last_last_active_time = 0
self.has_unread = True
@defer.inlineCallbacks
def get_context_for_event(self, ev):
name_aliases = yield self.store.get_room_name_and_aliases(
ev['room_id']
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield self.store.get_current_state(
room_id=ev['room_id'],
event_type='m.room.member',
state_key=ev['user_id']
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)
@defer.inlineCallbacks
def start(self):
with LoggingContext(self.name):
if not self.last_token:
# First-time setup: get a token to start from (we can't
# just start from no token, ie. 'now'
# because we need the result to be reproducible in case
# we fail to dispatch the push)
config = PaginationConfig(from_token=None, limit='1')
chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=0, affect_presence=False
)
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id, self.pushkey, self.user_id, self.last_token
)
logger.info("New pusher %s for user %s starting from token %s",
self.pushkey, self.user_id, self.last_token)
else:
logger.info(
"Old pusher %s for user %s starting",
self.pushkey, self.user_id,
)
wait = 0
while self.alive:
try:
if wait > 0:
yield synapse.util.async.sleep(wait)
with Measure(self.clock, "push"):
yield self.get_and_dispatch()
wait = 0
except:
if wait == 0:
wait = 1
else:
wait = min(wait * 2, 1800)
logger.exception(
"Exception in pusher loop for pushkey %s. Pausing for %ds",
self.pushkey, wait
)
@defer.inlineCallbacks
def get_and_dispatch(self):
from_tok = StreamToken.from_string(self.last_token)
config = PaginationConfig(from_token=from_tok, limit='1')
timeout = (300 + random.randint(-60, 60)) * 1000
chunk = yield self.evStreamHandler.get_stream(
self.user_id, config, timeout=timeout, affect_presence=False,
only_keys=("room", "receipt",),
)
# limiting to 1 may get 1 event plus 1 presence event, so
# pick out the actual event
single_event = None
read_receipt = None
for c in chunk['chunk']:
if 'event_id' in c: # Hmmm...
single_event = c
elif c['type'] == 'm.receipt':
read_receipt = c
have_updated_badge = False
if read_receipt:
for receipt_part in read_receipt['content'].values():
if 'm.read' in receipt_part:
if self.user_id in receipt_part['m.read'].keys():
have_updated_badge = True
if not single_event:
if have_updated_badge:
yield self.update_badge()
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_id,
self.last_token
)
return
if not self.alive:
return
processed = False
rule_evaluator = yield \
evaluator_for_user_id(
self.user_id, single_event['room_id'], self.store
)
actions = yield rule_evaluator.actions_for_event(single_event)
tweaks = rule_evaluator.tweaks_for_actions(actions)
if 'notify' in actions:
self.badge = yield self._get_badge_count()
rejected = yield self.dispatch_push(single_event, tweaks, self.badge)
self.has_unread = True
if isinstance(rejected, list) or isinstance(rejected, tuple):
processed = True
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.get_pusherpool().remove_pusher(
self.app_id, pk, self.user_id
)
else:
if have_updated_badge:
yield self.update_badge()
processed = True
if not self.alive:
return
if processed:
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token_and_success(
self.app_id,
self.pushkey,
self.user_id,
self.last_token,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
if (self.failing_since and
self.failing_since <
self.clock.time_msec() - Pusher.GIVE_UP_AFTER):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_id, self.pushkey)
self.backoff_delay = Pusher.INITIAL_BACKOFF
self.last_token = chunk['end']
yield self.store.update_pusher_last_token(
self.app_id,
self.pushkey,
self.user_id,
self.last_token
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
else:
logger.warn("Failed to dispatch push for user %s "
"(failing for %dms)."
"Trying again in %dms",
self.user_id,
self.clock.time_msec() - self.failing_since,
self.backoff_delay)
yield synapse.util.async.sleep(self.backoff_delay / 1000.0)
self.backoff_delay *= 2
if self.backoff_delay > Pusher.MAX_BACKOFF:
self.backoff_delay = Pusher.MAX_BACKOFF
def stop(self):
self.alive = False
def dispatch_push(self, p, tweaks, badge):
"""
Overridden by implementing classes to actually deliver the notification
Args:
p: The event to notify for as a single event from the event stream
Returns: If the notification was delivered, an array containing any
pushkeys that were rejected by the push gateway.
False if the notification could not be delivered (ie.
should be retried).
"""
pass
@defer.inlineCallbacks
def update_badge(self):
new_badge = yield self._get_badge_count()
if self.badge != new_badge:
self.badge = new_badge
yield self.send_badge(self.badge)
def send_badge(self, badge):
"""
Overridden by implementing classes to send an updated badge count
"""
pass
@defer.inlineCallbacks
def _get_badge_count(self):
invites, joins = yield defer.gatherResults([
self.store.get_invited_rooms_for_user(self.user_id),
self.store.get_rooms_for_user(self.user_id),
], consumeErrors=True)
my_receipts_by_room = yield self.store.get_receipts_for_user(
self.user_id,
"m.read",
)
badge = len(invites)
for r in joins:
if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id]
notifs = yield (
self.store.get_unread_event_push_actions_by_room_for_user(
r.room_id, self.user_id, last_unread_event_id
)
)
badge += notifs["notify_count"]
defer.returnValue(badge)
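# A pure-python sketch of the badge arithmetic in _get_badge_count above: one
# badge per pending invite, plus the unread notification count for every joined
# room in which the user has a read receipt.  The sample values are illustrative.
def badge_count(num_invites, joined_room_ids, read_receipts, notify_counts):
    badge = num_invites
    for room_id in joined_room_ids:
        if room_id in read_receipts:
            badge += notify_counts.get(room_id, 0)
    return badge

# badge_count(2, ["!a:hs", "!b:hs"], {"!a:hs": "$ev"}, {"!a:hs": 3}) -> 5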
class PusherConfigException(Exception):
def __init__(self, msg):

View File

@@ -15,7 +15,9 @@
from twisted.internet import defer
from .bulk_push_rule_evaluator import evaluator_for_room_id
from .bulk_push_rule_evaluator import evaluator_for_event
from synapse.util.metrics import Measure
import logging
@@ -25,6 +27,7 @@ logger = logging.getLogger(__name__)
class ActionGenerator:
def __init__(self, hs):
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastore()
# really we want to get all user ids and all profile tags too,
# since we want the actions for each profile tag for every user and
@@ -35,14 +38,15 @@ class ActionGenerator:
@defer.inlineCallbacks
def handle_push_actions_for_event(self, event, context, handler):
bulk_evaluator = yield evaluator_for_room_id(
event.room_id, self.hs, self.store
)
with Measure(self.clock, "handle_push_actions_for_event"):
bulk_evaluator = yield evaluator_for_event(
event, self.hs, self.store
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state
)
actions_by_user = yield bulk_evaluator.action_for_event_by_user(
event, handler, context.current_state
)
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items()
]
context.push_actions = [
(uid, actions) for uid, actions in actions_by_user.items()
]


@@ -19,9 +19,11 @@ import copy
def list_with_base_rules(rawrules):
"""Combine the list of rules set by the user with the default push rules
:param list rawrules: The rules the user has modified or set.
:returns: A new list with the rules set by the user combined with the
defaults.
Args:
rawrules(list): The rules the user has modified or set.
Returns:
A new list with the rules set by the user combined with the defaults.
"""
ruleslist = []
@@ -77,7 +79,7 @@ def make_base_append_rules(kind, modified_base_rules):
rules = []
if kind == 'override':
rules = BASE_APPEND_OVRRIDE_RULES
rules = BASE_APPEND_OVERRIDE_RULES
elif kind == 'underride':
rules = BASE_APPEND_UNDERRIDE_RULES
elif kind == 'content':
@@ -146,7 +148,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
]
BASE_APPEND_OVRRIDE_RULES = [
BASE_APPEND_OVERRIDE_RULES = [
{
'rule_id': 'global/override/.m.rule.suppress_notices',
'conditions': [
@@ -160,7 +162,61 @@ BASE_APPEND_OVRRIDE_RULES = [
'actions': [
'dont_notify',
]
}
},
# NB. .m.rule.invite_for_me must be higher prio than .m.rule.member_event
# otherwise invites will be matched by .m.rule.member_event
{
'rule_id': 'global/override/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
'_id': '_invite_member',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern_type': 'user_id'
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
# Will we sometimes want to know about people joining and leaving?
# Perhaps: if so, this could be expanded upon. Seems the most usual case
# is that we don't though. We add this override rule so that even if
# the room rule is set to notify, we don't get notifications about
# join/leave/avatar/displayname events.
# See also: https://matrix.org/jira/browse/SYN-607
{
'rule_id': 'global/override/.m.rule.member_event',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
}
],
'actions': [
'dont_notify'
]
},
]
@@ -229,57 +285,6 @@ BASE_APPEND_UNDERRIDE_RULES = [
}
]
},
{
'rule_id': 'global/underride/.m.rule.invite_for_me',
'conditions': [
{
'kind': 'event_match',
'key': 'type',
'pattern': 'm.room.member',
'_id': '_member',
},
{
'kind': 'event_match',
'key': 'content.membership',
'pattern': 'invite',
'_id': '_invite_member',
},
{
'kind': 'event_match',
'key': 'state_key',
'pattern_type': 'user_id'
},
],
'actions': [
'notify',
{
'set_tweak': 'sound',
'value': 'default'
}, {
'set_tweak': 'highlight',
'value': False
}
]
},
# This is too simple: https://matrix.org/jira/browse/SYN-607
# Removing for now
# {
# 'rule_id': 'global/underride/.m.rule.member_event',
# 'conditions': [
# {
# 'kind': 'event_match',
# 'key': 'type',
# 'pattern': 'm.room.member',
# '_id': '_member',
# }
# ],
# 'actions': [
# 'notify', {
# 'set_tweak': 'highlight',
# 'value': False
# }
# ]
# },
{
'rule_id': 'global/underride/.m.rule.message',
'conditions': [
@@ -312,7 +317,7 @@ for r in BASE_PREPEND_OVERRIDE_RULES:
r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])
for r in BASE_APPEND_OVRRIDE_RULES:
for r in BASE_APPEND_OVERRIDE_RULES:
r['priority_class'] = PRIORITY_CLASS_MAP['override']
r['default'] = True
BASE_RULE_IDS.add(r['rule_id'])


@@ -69,12 +69,40 @@ def _get_rules(room_id, user_ids, store):
@defer.inlineCallbacks
def evaluator_for_room_id(room_id, hs, store):
results = yield store.get_receipts_for_room(room_id, "m.read")
user_ids = [
row["user_id"] for row in results
if hs.is_mine_id(row["user_id"])
]
def evaluator_for_event(event, hs, store):
room_id = event.room_id
# users in the room who have pushers need to get push rules run because
# that's how their pushers work
users_with_pushers = yield store.get_users_with_pushers_in_room(room_id)
# We also will want to generate notifs for other people in the room so
# their unread counts are correct in the event stream, but to avoid
# generating them for bot / AS users etc, we only do so for people who've
# sent a read receipt into the room.
all_in_room = yield store.get_users_in_room(room_id)
all_in_room = set(all_in_room)
receipts = yield store.get_receipts_for_room(room_id, "m.read")
# any users with pushers must be ours: they have pushers
user_ids = set(users_with_pushers)
for r in receipts:
if hs.is_mine_id(r['user_id']) and r['user_id'] in all_in_room:
user_ids.add(r['user_id'])
# if this event is an invite event, we may need to run rules for the user
# who's been invited, otherwise they won't get told they've been invited
if event.type == 'm.room.member' and event.content['membership'] == 'invite':
invited_user = event.state_key
if invited_user and hs.is_mine_id(invited_user):
has_pusher = yield store.user_has_pusher(invited_user)
if has_pusher:
user_ids.add(invited_user)
user_ids = list(user_ids)
rules_by_user = yield _get_rules(room_id, user_ids, store)
defer.returnValue(BulkPushRuleEvaluator(
@@ -101,10 +129,15 @@ class BulkPushRuleEvaluator:
def action_for_event_by_user(self, event, handler, current_state):
actions_by_user = {}
users_dict = yield self.store.are_guests(self.rules_by_user.keys())
# None of these users can be peeking since this list of users comes
# from the set of users in the room, so we know for sure they're all
# actually in the room.
user_tuples = [
(u, False) for u in self.rules_by_user.keys()
]
filtered_by_user = yield handler.filter_events_for_clients(
users_dict.items(), [event], {event.event_id: current_state}
user_tuples, [event], {event.event_id: current_state}
)
room_members = yield self.store.get_users_in_room(self.room_id)


@@ -13,60 +13,239 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.push import Pusher, PusherConfigException
from synapse.push import PusherConfigException
from twisted.internet import defer
from twisted.internet import defer, reactor
import logging
import push_rule_evaluator
import push_tools
from synapse.util.logcontext import LoggingContext
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
class HttpPusher(Pusher):
def __init__(self, _hs, user_id, app_id,
app_display_name, device_display_name, pushkey, pushkey_ts,
data, last_token, last_success, failing_since):
super(HttpPusher, self).__init__(
_hs,
user_id,
app_id,
app_display_name,
device_display_name,
pushkey,
pushkey_ts,
data,
last_token,
last_success,
failing_since
class HttpPusher(object):
INITIAL_BACKOFF_SEC = 1 # in seconds because that's what Twisted takes
MAX_BACKOFF_SEC = 60 * 60
# This one's in ms because we compare it against the clock
GIVE_UP_AFTER_MS = 24 * 60 * 60 * 1000
def __init__(self, hs, pusherdict):
self.hs = hs
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.user_id = pusherdict['user_name']
self.app_id = pusherdict['app_id']
self.app_display_name = pusherdict['app_display_name']
self.device_display_name = pusherdict['device_display_name']
self.pushkey = pusherdict['pushkey']
self.pushkey_ts = pusherdict['ts']
self.data = pusherdict['data']
self.last_stream_ordering = pusherdict['last_stream_ordering']
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.failing_since = pusherdict['failing_since']
self.timed_call = None
self.processing = False
# This is the highest stream ordering we know it's safe to process.
# When new events arrive, we'll be given a window of new events: we
# should honour this rather than just looking for anything higher
# because of potential out-of-order event serialisation. This starts
# off as None though as we don't know any better.
self.max_stream_ordering = None
if 'data' not in pusherdict:
raise PusherConfigException(
"No 'data' key for HTTP pusher"
)
self.data = pusherdict['data']
self.name = "%s/%s/%s" % (
pusherdict['user_name'],
pusherdict['app_id'],
pusherdict['pushkey'],
)
if 'url' not in data:
if 'url' not in self.data:
raise PusherConfigException(
"'url' required in data for HTTP pusher"
)
self.url = data['url']
self.http_client = _hs.get_simple_http_client()
self.url = self.data['url']
self.http_client = hs.get_simple_http_client()
self.data_minus_url = {}
self.data_minus_url.update(self.data)
del self.data_minus_url['url']
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
# we probably do not want to push for every presence update
# (we may want to be able to set up notifications when specific
# people sign in, but we'd want to only deliver the pertinent ones)
# Actually, presence events will not get this far now because we
# need to filter them out in the main Pusher code.
if 'event_id' not in event:
defer.returnValue(None)
def on_started(self):
yield self._process()
ctx = yield self.get_context_for_event(event)
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering)
yield self._process()
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id):
# Note that the min here shouldn't be relied upon to be accurate.
# We could check the receipts are actually m.read receipts here,
# but currently that's the only type of receipt anyway...
with LoggingContext("push.on_new_receipts"):
with Measure(self.clock, "push.on_new_receipts"):
badge = yield push_tools.get_badge_count(
self.hs.get_datastore(), self.user_id
)
yield self._send_badge(badge)
@defer.inlineCallbacks
def on_timer(self):
yield self._process()
def on_stop(self):
if self.timed_call:
self.timed_call.cancel()
@defer.inlineCallbacks
def _process(self):
if self.processing:
return
with LoggingContext("push._process"):
with Measure(self.clock, "push._process"):
try:
self.processing = True
# if the max ordering changes while we're running _unsafe_process,
# call it again, and so on until we've caught up.
while True:
starting_max_ordering = self.max_stream_ordering
try:
yield self._unsafe_process()
except:
logger.exception("Exception processing notifs")
if self.max_stream_ordering == starting_max_ordering:
break
finally:
self.processing = False
@defer.inlineCallbacks
def _unsafe_process(self):
"""
Looks for unsent notifications and dispatches them, in order
Never call this directly: use _process which will only allow this to
run once per pusher.
"""
unprocessed = yield self.store.get_unread_push_actions_for_user_in_range(
self.user_id, self.last_stream_ordering, self.max_stream_ordering
)
for push_action in unprocessed:
processed = yield self._process_one(push_action)
if processed:
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering_and_success(
self.app_id, self.pushkey, self.user_id,
self.last_stream_ordering,
self.clock.time_msec()
)
if self.failing_since:
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id,
self.failing_since
)
else:
if not self.failing_since:
self.failing_since = self.clock.time_msec()
yield self.store.update_pusher_failing_since(
self.app_id, self.pushkey, self.user_id,
self.failing_since
)
if (
self.failing_since and
self.failing_since <
self.clock.time_msec() - HttpPusher.GIVE_UP_AFTER_MS
):
# we really only give up so that if the URL gets
# fixed, we don't suddenly deliver a load
# of old notifications.
logger.warn("Giving up on a notification to user %s, "
"pushkey %s",
self.user_id, self.pushkey)
self.backoff_delay = HttpPusher.INITIAL_BACKOFF_SEC
self.last_stream_ordering = push_action['stream_ordering']
yield self.store.update_pusher_last_stream_ordering(
self.app_id,
self.pushkey,
self.user_id,
self.last_stream_ordering
)
self.failing_since = None
yield self.store.update_pusher_failing_since(
self.app_id,
self.pushkey,
self.user_id,
self.failing_since
)
else:
logger.info("Push failed: delaying for %ds", self.backoff_delay)
self.timed_call = reactor.callLater(self.backoff_delay, self.on_timer)
self.backoff_delay = min(self.backoff_delay * 2, self.MAX_BACKOFF_SEC)
break
@defer.inlineCallbacks
def _process_one(self, push_action):
if 'notify' not in push_action['actions']:
defer.returnValue(True)
tweaks = push_rule_evaluator.tweaks_for_actions(push_action['actions'])
badge = yield push_tools.get_badge_count(self.hs.get_datastore(), self.user_id)
event = yield self.store.get_event(push_action['event_id'], allow_none=True)
if event is None:
defer.returnValue(True) # It's been redacted
rejected = yield self.dispatch_push(event, tweaks, badge)
if rejected is False:
defer.returnValue(False)
if isinstance(rejected, list) or isinstance(rejected, tuple):
for pk in rejected:
if pk != self.pushkey:
# for sanity, we only remove the pushkey if it
# was the one we actually sent...
logger.warn(
("Ignoring rejected pushkey %s because we"
" didn't send it"), pk
)
else:
logger.info(
"Pushkey %s was rejected: removing",
pk
)
yield self.hs.remove_pusher(
self.app_id, pk, self.user_id
)
defer.returnValue(True)
@defer.inlineCallbacks
def _build_notification_dict(self, event, tweaks, badge):
ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event)
d = {
'notification': {
'id': event['event_id'],
'room_id': event['room_id'],
'type': event['type'],
'sender': event['user_id'],
'id': event.event_id, # deprecated: remove soon
'event_id': event.event_id,
'room_id': event.room_id,
'type': event.type,
'sender': event.user_id,
'counts': { # -- we don't mark messages as read yet so
# we have no way of knowing
# Just set the badge to 1 until we have read receipts
@@ -84,11 +263,11 @@ class HttpPusher(Pusher):
]
}
}
if event['type'] == 'm.room.member':
d['notification']['membership'] = event['content']['membership']
d['notification']['user_is_target'] = event['state_key'] == self.user_id
if event.type == 'm.room.member':
d['notification']['membership'] = event.content['membership']
d['notification']['user_is_target'] = event.state_key == self.user_id
if 'content' in event:
d['notification']['content'] = event['content']
d['notification']['content'] = event.content
if len(ctx['aliases']):
d['notification']['room_alias'] = ctx['aliases'][0]
@@ -115,7 +294,7 @@ class HttpPusher(Pusher):
defer.returnValue(rejected)
@defer.inlineCallbacks
def send_badge(self, badge):
def _send_badge(self, badge):
logger.info("Sending updated badge count %d to %r", badge, self.user_id)
d = {
'notification': {


@@ -13,12 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
from .baserules import list_with_base_rules
import logging
import simplejson as json
import re
from synapse.types import UserID
@@ -32,22 +27,6 @@ IS_GLOB = re.compile(r'[\?\*\[\]]')
INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
@defer.inlineCallbacks
def evaluator_for_user_id(user_id, room_id, store):
rawrules = yield store.get_push_rules_for_user(user_id)
enabled_map = yield store.get_push_rules_enabled_for_user(user_id)
our_member_event = yield store.get_current_state(
room_id=room_id,
event_type='m.room.member',
state_key=user_id,
)
defer.returnValue(PushRuleEvaluator(
user_id, rawrules, enabled_map,
room_id, our_member_event, store
))
def _room_member_count(ev, condition, room_member_count):
if 'is' not in condition:
return False
@@ -74,110 +53,14 @@ def _room_member_count(ev, condition, room_member_count):
return False
class PushRuleEvaluator:
DEFAULT_ACTIONS = []
def __init__(self, user_id, raw_rules, enabled_map, room_id,
our_member_event, store):
self.user_id = user_id
self.room_id = room_id
self.our_member_event = our_member_event
self.store = store
rules = []
for raw_rule in raw_rules:
rule = dict(raw_rule)
rule['conditions'] = json.loads(raw_rule['conditions'])
rule['actions'] = json.loads(raw_rule['actions'])
rules.append(rule)
self.rules = list_with_base_rules(rules)
self.enabled_map = enabled_map
@staticmethod
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
@defer.inlineCallbacks
def actions_for_event(self, ev):
"""
This should take into account notification settings that the user
has configured both globally and per-room when we have the ability
to do such things.
"""
if ev['user_id'] == self.user_id:
# let's assume you probably know about messages you sent yourself
defer.returnValue([])
room_id = ev['room_id']
# get *our* member event for display name matching
my_display_name = None
if self.our_member_event:
my_display_name = self.our_member_event[0].content.get("displayname")
room_members = yield self.store.get_users_in_room(room_id)
room_member_count = len(room_members)
evaluator = PushRuleEvaluatorForEvent(ev, room_member_count)
for r in self.rules:
enabled = self.enabled_map.get(r['rule_id'], None)
if enabled is not None and not enabled:
continue
if not r.get("enabled", True):
continue
conditions = r['conditions']
actions = r['actions']
# ignore rules with no actions (we have an explicit 'dont_notify')
if len(actions) == 0:
logger.warn(
"Ignoring rule id %s with no actions for user %s",
r['rule_id'], self.user_id
)
continue
matches = True
for c in conditions:
matches = evaluator.matches(
c, self.user_id, my_display_name
)
if not matches:
break
logger.debug(
"Rule %s %s",
r['rule_id'], "matches" if matches else "doesn't match"
)
if matches:
logger.debug(
"%s matches for user %s, event %s",
r['rule_id'], self.user_id, ev['event_id']
)
# filter out dont_notify as we treat an empty actions list
# as dont_notify, and this doesn't take up a row in our database
actions = [x for x in actions if x != 'dont_notify']
defer.returnValue(actions)
logger.debug(
"No rules match for user %s, event %s",
self.user_id, ev['event_id']
)
defer.returnValue(PushRuleEvaluator.DEFAULT_ACTIONS)
def tweaks_for_actions(actions):
tweaks = {}
for a in actions:
if not isinstance(a, dict):
continue
if 'set_tweak' in a and 'value' in a:
tweaks[a['set_tweak']] = a['value']
return tweaks
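For reference, a standalone sketch of what the module-level tweaks_for_actions above computes, shown with an illustrative actions list (the values are made up, not taken from this diff):

    # Editorial sketch, not part of the diff: same logic as tweaks_for_actions.
    def _tweaks_for_actions(actions):
        tweaks = {}
        for a in actions:
            # string actions like 'notify' are skipped; only dicts carrying a
            # 'set_tweak'/'value' pair contribute to the tweaks mapping
            if isinstance(a, dict) and 'set_tweak' in a and 'value' in a:
                tweaks[a['set_tweak']] = a['value']
        return tweaks

    print(_tweaks_for_actions([
        'notify',
        {'set_tweak': 'sound', 'value': 'default'},
        {'set_tweak': 'highlight', 'value': False},
    ]))
    # -> {'sound': 'default', 'highlight': False}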
class PushRuleEvaluatorForEvent(object):


@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer
@defer.inlineCallbacks
def get_badge_count(store, user_id):
invites, joins = yield defer.gatherResults([
store.get_invited_rooms_for_user(user_id),
store.get_rooms_for_user(user_id),
], consumeErrors=True)
my_receipts_by_room = yield store.get_receipts_for_user(
user_id, "m.read",
)
badge = len(invites)
for r in joins:
if r.room_id in my_receipts_by_room:
last_unread_event_id = my_receipts_by_room[r.room_id]
notifs = yield (
store.get_unread_event_push_actions_by_room_for_user(
r.room_id, user_id, last_unread_event_id
)
)
badge += notifs["notify_count"]
defer.returnValue(badge)
@defer.inlineCallbacks
def get_context_for_event(store, ev):
name_aliases = yield store.get_room_name_and_aliases(
ev.room_id
)
ctx = {'aliases': name_aliases[1]}
if name_aliases[0] is not None:
ctx['name'] = name_aliases[0]
their_member_events_for_room = yield store.get_current_state(
room_id=ev.room_id,
event_type='m.room.member',
state_key=ev.user_id
)
for mev in their_member_events_for_room:
if mev.content['membership'] == 'join' and 'displayname' in mev.content:
dn = mev.content['displayname']
if dn is not None:
ctx['sender_display_name'] = dn
defer.returnValue(ctx)

synapse/push/pusher.py (new file)

@@ -0,0 +1,10 @@
from httppusher import HttpPusher
PUSHER_TYPES = {
'http': HttpPusher
}
def create_pusher(hs, pusherdict):
if pusherdict['kind'] in PUSHER_TYPES:
return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)
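As an aside, a minimal standalone sketch of the registry pattern this new module uses: a 'kind' string selects a pusher class, and unknown kinds fall through to None. FakePusher and the example dicts below are stand-ins, not synapse classes.

    # Editorial sketch, not part of the diff.
    class FakePusher(object):
        def __init__(self, hs, pusherdict):
            self.hs = hs
            self.pusherdict = pusherdict

    PUSHER_TYPES = {'http': FakePusher}

    def create_pusher(hs, pusherdict):
        # look the 'kind' up in the registry; anything unknown is ignored
        if pusherdict['kind'] in PUSHER_TYPES:
            return PUSHER_TYPES[pusherdict['kind']](hs, pusherdict)

    print(create_pusher(None, {'kind': 'http'}))   # a FakePusher instance
    print(create_pusher(None, {'kind': 'email'}))  # None: unknown kinds fall through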


@@ -16,9 +16,10 @@
from twisted.internet import defer
from .httppusher import HttpPusher
import pusher
from synapse.push import PusherConfigException
from synapse.util.logcontext import preserve_fn
from synapse.util.async import run_on_reactor
import logging
@@ -28,10 +29,10 @@ logger = logging.getLogger(__name__)
class PusherPool:
def __init__(self, _hs):
self.hs = _hs
self.start_pushers = _hs.config.start_pushers
self.store = self.hs.get_datastore()
self.clock = self.hs.get_clock()
self.pushers = {}
self.last_pusher_started = -1
@defer.inlineCallbacks
def start(self):
@@ -48,7 +49,7 @@ class PusherPool:
# will then get pulled out of the database,
# recreated, added and started: this means we have only one
# code path adding pushers.
self._create_pusher({
pusher.create_pusher(self.hs, {
"user_name": user_id,
"kind": kind,
"app_id": app_id,
@@ -58,10 +59,18 @@ class PusherPool:
"ts": time_now_msec,
"lang": lang,
"data": data,
"last_token": None,
"last_stream_ordering": None,
"last_success": None,
"failing_since": None
})
# create the pusher setting last_stream_ordering to the current maximum
# stream ordering in event_push_actions, so it will process
# pushes from this point onwards.
last_stream_ordering = (
yield self.store.get_latest_push_action_stream_ordering()
)
yield self.store.add_pusher(
user_id=user_id,
access_token=access_token,
@@ -73,6 +82,7 @@ class PusherPool:
pushkey_ts=time_now_msec,
lang=lang,
data=data,
last_stream_ordering=last_stream_ordering,
profile_tag=profile_tag,
)
yield self._refresh_pusher(app_id, pushkey, user_id)
@@ -106,26 +116,51 @@ class PusherPool:
)
yield self.remove_pusher(p['app_id'], p['pushkey'], p['user_name'])
def _create_pusher(self, pusherdict):
if pusherdict['kind'] == 'http':
return HttpPusher(
self.hs,
user_id=pusherdict['user_name'],
app_id=pusherdict['app_id'],
app_display_name=pusherdict['app_display_name'],
device_display_name=pusherdict['device_display_name'],
pushkey=pusherdict['pushkey'],
pushkey_ts=pusherdict['ts'],
data=pusherdict['data'],
last_token=pusherdict['last_token'],
last_success=pusherdict['last_success'],
failing_since=pusherdict['failing_since']
@defer.inlineCallbacks
def on_new_notifications(self, min_stream_id, max_stream_id):
yield run_on_reactor()
try:
users_affected = yield self.store.get_push_action_users_in_range(
min_stream_id, max_stream_id
)
else:
raise PusherConfigException(
"Unknown pusher type '%s' for user %s" %
(pusherdict['kind'], pusherdict['user_name'])
deferreds = []
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_notifications(min_stream_id, max_stream_id)
)
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_notifications")
@defer.inlineCallbacks
def on_new_receipts(self, min_stream_id, max_stream_id, affected_room_ids):
yield run_on_reactor()
try:
# Need to subtract 1 from the minimum because the lower bound here
# is not inclusive
updated_receipts = yield self.store.get_all_updated_receipts(
min_stream_id - 1, max_stream_id
)
# This returns a tuple, user_id is at index 3
users_affected = set([r[3] for r in updated_receipts])
deferreds = []
for u in users_affected:
if u in self.pushers:
for p in self.pushers[u].values():
deferreds.append(
p.on_new_receipts(min_stream_id, max_stream_id)
)
yield defer.gatherResults(deferreds)
except:
logger.exception("Exception in pusher on_new_receipts")
@defer.inlineCallbacks
def _refresh_pusher(self, app_id, pushkey, user_id):
@@ -143,33 +178,40 @@ class PusherPool:
self._start_pushers([p])
def _start_pushers(self, pushers):
if not self.start_pushers:
logger.info("Not starting pushers because they are disabled in the config")
return
logger.info("Starting %d pushers", len(pushers))
for pusherdict in pushers:
try:
p = self._create_pusher(pusherdict)
p = pusher.create_pusher(self.hs, pusherdict)
except PusherConfigException:
logger.exception("Couldn't start a pusher: caught PusherConfigException")
continue
if p:
fullid = "%s:%s:%s" % (
appid_pushkey = "%s:%s" % (
pusherdict['app_id'],
pusherdict['pushkey'],
pusherdict['user_name']
)
if fullid in self.pushers:
self.pushers[fullid].stop()
self.pushers[fullid] = p
preserve_fn(p.start)()
byuser = self.pushers.setdefault(pusherdict['user_name'], {})
if appid_pushkey in byuser:
byuser[appid_pushkey].on_stop()
byuser[appid_pushkey] = p
preserve_fn(p.on_started)()
logger.info("Started pushers")
@defer.inlineCallbacks
def remove_pusher(self, app_id, pushkey, user_id):
fullid = "%s:%s:%s" % (app_id, pushkey, user_id)
if fullid in self.pushers:
logger.info("Stopping pusher %s", fullid)
self.pushers[fullid].stop()
del self.pushers[fullid]
appid_pushkey = "%s:%s" % (app_id, pushkey)
byuser = self.pushers.get(user_id, {})
if appid_pushkey in byuser:
logger.info("Stopping pusher %s / %s", user_id, appid_pushkey)
byuser[appid_pushkey].on_stop()
del byuser[appid_pushkey]
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
app_id, pushkey, user_id
)


@@ -36,11 +36,15 @@ REQUIREMENTS = {
"blist": ["blist"],
"pysaml2>=3.0.0,<4.0.0": ["saml2>=3.0.0,<4.0.0"],
"pymacaroons-pynacl": ["pymacaroons"],
"pyjwt": ["jwt"],
}
CONDITIONAL_REQUIREMENTS = {
"web_client": {
"matrix_angular_sdk>=0.6.8": ["syweb>=0.6.8"],
}
},
"preview_url": {
"netaddr>=0.7.18": ["netaddr"],
},
}


@@ -0,0 +1,53 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json_bytes, request_handler
from synapse.http.servlet import parse_json_object_from_request
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
class PusherResource(Resource):
"""
HTTP endpoint for deleting rejected pushers
"""
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
self.version_string = hs.version_string
self.store = hs.get_datastore()
self.notifier = hs.get_notifier()
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
content = parse_json_object_from_request(request)
for remove in content["remove"]:
yield self.store.delete_pusher_by_app_id_pushkey_user_id(
remove["app_id"],
remove["push_key"],
remove["user_id"],
)
self.notifier.on_new_replication_data()
respond_with_json_bytes(request, 200, "{}")
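For illustration, a hedged sketch of the JSON body a caller would POST to this resource, with keys mirroring what _async_render_POST reads out of content["remove"]; the concrete values and the exact URL are assumptions.

    # Editorial sketch, not part of the diff. The resource is mounted as
    # "remove_pushers" under the replication resource (see below), so the
    # request path would be something like <replication>/remove_pushers.
    import json

    body = json.dumps({
        "remove": [
            {
                "app_id": "com.example.app",       # assumed example value
                "push_key": "abc123",              # assumed example value
                "user_id": "@alice:example.com",   # assumed example value
            },
        ]
    })
    print(body)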


@@ -15,6 +15,7 @@
from synapse.http.servlet import parse_integer, parse_string
from synapse.http.server import request_handler, finish_request
from synapse.replication.pusher_resource import PusherResource
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
@@ -38,6 +39,7 @@ STREAM_NAMES = (
("backfill",),
("push_rules",),
("pushers",),
("state",),
)
@@ -76,7 +78,7 @@ class ReplicationResource(Resource):
The response is a JSON object with keys for each stream with updates. Under
each key is a JSON object with:
* "postion": The current position of the stream.
* "position": The current position of the stream.
* "field_names": The names of the fields in each row.
* "rows": The updates as an array of arrays.
@@ -101,8 +103,6 @@ class ReplicationResource(Resource):
long-polling this replication API for new data on those streams.
"""
isLeaf = True
def __init__(self, hs):
Resource.__init__(self) # Resource is old-style, so no super()
@@ -113,6 +113,8 @@ class ReplicationResource(Resource):
self.typing_handler = hs.get_handlers().typing_notification_handler
self.notifier = hs.notifier
self.putChild("remove_pushers", PusherResource(hs))
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@@ -123,6 +125,7 @@ class ReplicationResource(Resource):
backfill_token = yield self.store.get_current_backfill_token()
push_rules_token, room_stream_token = self.store.get_push_rules_stream_token()
pushers_token = self.store.get_pushers_stream_token()
state_token = self.store.get_state_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -133,6 +136,7 @@ class ReplicationResource(Resource):
backfill_token,
push_rules_token,
pushers_token,
state_token,
))
@request_handler
@@ -142,31 +146,43 @@ class ReplicationResource(Resource):
timeout = parse_integer(request, "timeout", 10 * 1000)
request.setHeader(b"Content-Type", b"application/json")
writer = _Writer(request)
@defer.inlineCallbacks
request_streams = {
name: parse_integer(request, name)
for names in STREAM_NAMES for name in names
}
request_streams["streams"] = parse_string(request, "streams")
def replicate():
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
return self.replicate(request_streams, limit)
yield self.account_data(writer, current_token, limit)
yield self.events(writer, current_token, limit)
yield self.presence(writer, current_token) # TODO: implement limit
yield self.typing(writer, current_token) # TODO: implement limit
yield self.receipts(writer, current_token, limit)
yield self.push_rules(writer, current_token, limit)
yield self.pushers(writer, current_token, limit)
self.streams(writer, current_token)
result = yield self.notifier.wait_for_replication(replicate, timeout)
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.total)
request.write(json.dumps(result, ensure_ascii=False))
finish_request(request)
yield self.notifier.wait_for_replication(replicate, timeout)
@defer.inlineCallbacks
def replicate(self, request_streams, limit):
writer = _Writer()
current_token = yield self.current_replication_token()
logger.info("Replicating up to %r", current_token)
writer.finish()
yield self.account_data(writer, current_token, limit, request_streams)
yield self.events(writer, current_token, limit, request_streams)
# TODO: implement limit
yield self.presence(writer, current_token, request_streams)
yield self.typing(writer, current_token, request_streams)
yield self.receipts(writer, current_token, limit, request_streams)
yield self.push_rules(writer, current_token, limit, request_streams)
yield self.pushers(writer, current_token, limit, request_streams)
yield self.state(writer, current_token, limit, request_streams)
self.streams(writer, current_token, request_streams)
def streams(self, writer, current_token):
request_token = parse_string(writer.request, "streams")
logger.info("Replicated %d rows", writer.total)
defer.returnValue(writer.finish())
def streams(self, writer, current_token, request_streams):
request_token = request_streams.get("streams")
streams = []
@@ -191,32 +207,43 @@ class ReplicationResource(Resource):
)
@defer.inlineCallbacks
def events(self, writer, current_token, limit):
request_events = parse_integer(writer.request, "events")
request_backfill = parse_integer(writer.request, "backfill")
def events(self, writer, current_token, limit, request_streams):
request_events = request_streams.get("events")
request_backfill = request_streams.get("backfill")
if request_events is not None or request_backfill is not None:
if request_events is None:
request_events = current_token.events
if request_backfill is None:
request_backfill = current_token.backfill
events_rows, backfill_rows = yield self.store.get_all_new_events(
res = yield self.store.get_all_new_events(
request_backfill, request_events,
current_token.backfill, current_token.events,
limit
)
writer.write_header_and_rows("events", res.new_forward_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows("backfill", res.new_backfill_events, (
"position", "internal", "json", "state_group"
))
writer.write_header_and_rows(
"events", events_rows, ("position", "internal", "json")
"forward_ex_outliers", res.forward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"backfill", backfill_rows, ("position", "internal", "json")
"backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group")
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",)
)
@defer.inlineCallbacks
def presence(self, writer, current_token):
def presence(self, writer, current_token, request_streams):
current_position = current_token.presence
request_presence = parse_integer(writer.request, "presence")
request_presence = request_streams.get("presence")
if request_presence is not None:
presence_rows = yield self.presence_handler.get_all_presence_updates(
@@ -229,10 +256,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def typing(self, writer, current_token):
def typing(self, writer, current_token, request_streams):
current_position = current_token.presence
request_typing = parse_integer(writer.request, "typing")
request_typing = request_streams.get("typing")
if request_typing is not None:
typing_rows = yield self.typing_handler.get_all_typing_updates(
@@ -243,10 +270,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def receipts(self, writer, current_token, limit):
def receipts(self, writer, current_token, limit, request_streams):
current_position = current_token.receipts
request_receipts = parse_integer(writer.request, "receipts")
request_receipts = request_streams.get("receipts")
if request_receipts is not None:
receipts_rows = yield self.store.get_all_updated_receipts(
@@ -257,12 +284,12 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def account_data(self, writer, current_token, limit):
def account_data(self, writer, current_token, limit, request_streams):
current_position = current_token.account_data
user_account_data = parse_integer(writer.request, "user_account_data")
room_account_data = parse_integer(writer.request, "room_account_data")
tag_account_data = parse_integer(writer.request, "tag_account_data")
user_account_data = request_streams.get("user_account_data")
room_account_data = request_streams.get("room_account_data")
tag_account_data = request_streams.get("tag_account_data")
if user_account_data is not None or room_account_data is not None:
if user_account_data is None:
@@ -288,10 +315,10 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def push_rules(self, writer, current_token, limit):
def push_rules(self, writer, current_token, limit, request_streams):
current_position = current_token.push_rules
push_rules = parse_integer(writer.request, "push_rules")
push_rules = request_streams.get("push_rules")
if push_rules is not None:
rows = yield self.store.get_all_push_rule_updates(
@@ -303,10 +330,11 @@ class ReplicationResource(Resource):
))
@defer.inlineCallbacks
def pushers(self, writer, current_token, limit):
def pushers(self, writer, current_token, limit, request_streams):
current_position = current_token.pushers
pushers = parse_integer(writer.request, "pushers")
pushers = request_streams.get("pushers")
if pushers is not None:
updated, deleted = yield self.store.get_all_updated_pushers(
pushers, current_position, limit
@@ -316,16 +344,34 @@ class ReplicationResource(Resource):
"app_id", "app_display_name", "device_display_name", "pushkey",
"ts", "lang", "data"
))
writer.write_header_and_rows("deleted", deleted, (
writer.write_header_and_rows("deleted_pushers", deleted, (
"position", "user_id", "app_id", "pushkey"
))
@defer.inlineCallbacks
def state(self, writer, current_token, limit, request_streams):
current_position = current_token.state
state = request_streams.get("state")
if state is not None:
state_groups, state_group_state = (
yield self.store.get_all_new_state_groups(
state, current_position, limit
)
)
writer.write_header_and_rows("state_groups", state_groups, (
"position", "room_id", "event_id"
))
writer.write_header_and_rows("state_group_state", state_group_state, (
"position", "type", "state_key", "event_id"
))
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
def __init__(self, request):
def __init__(self):
self.streams = {}
self.request = request
self.total = 0
def write_header_and_rows(self, name, rows, fields, position=None):
@@ -344,13 +390,12 @@ class _Writer(object):
self.total += len(rows)
def finish(self):
self.request.write(json.dumps(self.streams, ensure_ascii=False))
finish_request(self.request)
return self.streams
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers"
"push_rules", "pushers", "state"
))):
__slots__ = []


@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


@@ -0,0 +1,28 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage._base import SQLBaseStore
from twisted.internet import defer
class BaseSlavedStore(SQLBaseStore):
def __init__(self, db_conn, hs):
super(BaseSlavedStore, self).__init__(hs)
def stream_positions(self):
return {}
def process_replication(self, result):
return defer.succeed(None)


@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.util.id_generators import _load_current_id
class SlavedIdTracker(object):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
self.step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self.advance(_load_current_id(db_conn, table, column))
def advance(self, new_id):
self._current = (max if self.step > 0 else min)(self._current, new_id)
def get_current_token(self):
return self._current


@@ -0,0 +1,223 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.api.constants import EventTypes
from synapse.events import FrozenEvent
from synapse.storage import DataStore
from synapse.storage.room import RoomStore
from synapse.storage.roommember import RoomMemberStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.event_push_actions import EventPushActionsStore
from synapse.storage.state import StateStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
import ujson as json
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we're going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
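(Editorial aside: a minimal illustration of this borrowing technique, independent of the synapse classes that follow. Under Python 2, which this codebase targets, Store.read is an unbound method whose plain function is reached via __func__; cached methods live as descriptors in the class __dict__, which is why the code also reaches into __dict__.)

    # Editorial sketch, not part of the file.
    class Store(object):
        def read(self):
            return "read from %s" % type(self).__name__

    class SlavedStore(object):
        # Borrow the underlying function object instead of re-implementing it.
        # On Python 2 the same thing can be spelled Store.read.__func__;
        # reaching into __dict__ works on both Python 2 and 3.
        read = Store.__dict__["read"]

    print(SlavedStore().read())  # -> "read from SlavedStore"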
class SlavedEventStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedEventStore, self).__init__(db_conn, hs)
self._stream_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering",
)
self._backfill_id_gen = SlavedIdTracker(
db_conn, "events", "stream_ordering", step=-1
)
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
stream_column="stream_ordering",
max_value=events_max,
)
self._events_stream_cache = StreamChangeCache(
"EventsRoomStreamChangeCache", min_event_val,
prefilled_cache=event_cache_prefill,
)
# Cached functions can't be accessed through a class instance so we need
# to reach inside the __dict__ to extract them.
get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"]
get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"]
get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"]
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
get_unread_event_push_actions_by_room_for_user = (
EventPushActionsStore.__dict__["get_unread_event_push_actions_by_room_for_user"]
)
get_unread_push_actions_for_user_in_range = (
DataStore.get_unread_push_actions_for_user_in_range.__func__
)
get_push_action_users_in_range = (
DataStore.get_push_action_users_in_range.__func__
)
get_event = DataStore.get_event.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
get_membership_changes_for_user = (
DataStore.get_membership_changes_for_user.__func__
)
get_room_events_max_id = DataStore.get_room_events_max_id.__func__
get_room_events_stream_for_room = (
DataStore.get_room_events_stream_for_room.__func__
)
_set_before_and_after = DataStore._set_before_and_after
_get_events = DataStore._get_events.__func__
_get_events_from_cache = DataStore._get_events_from_cache.__func__
_invalidate_get_event_cache = DataStore._invalidate_get_event_cache.__func__
_parse_events_txn = DataStore._parse_events_txn.__func__
_get_events_txn = DataStore._get_events_txn.__func__
_enqueue_events = DataStore._enqueue_events.__func__
_do_fetch = DataStore._do_fetch.__func__
_fetch_events_txn = DataStore._fetch_events_txn.__func__
_fetch_event_rows = DataStore._fetch_event_rows.__func__
_get_event_from_row = DataStore._get_event_from_row.__func__
_get_event_from_row_txn = DataStore._get_event_from_row_txn.__func__
_get_rooms_for_user_where_membership_is_txn = (
DataStore._get_rooms_for_user_where_membership_is_txn.__func__
)
_get_members_rows_txn = DataStore._get_members_rows_txn.__func__
def stream_positions(self):
result = super(SlavedEventStore, self).stream_positions()
result["events"] = self._stream_id_gen.get_current_token()
result["backfill"] = self._backfill_id_gen.get_current_token()
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
)
stream = result.get("backfill")
if stream:
self._backfill_id_gen.advance(stream["position"])
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
)
stream = result.get("forward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
stream = result.get("backward_ex_outliers")
if stream:
for row in stream["rows"]:
event_id = row[1]
self._invalidate_get_event_cache(event_id)
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
)
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_room_name_and_aliases.invalidate((event.room_id,))
self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
(event.room_id,)
)
if not backfilled:
self._events_stream_cache.entity_has_changed(
event.room_id, event.internal_metadata.stream_ordering
)
# self.get_unread_event_push_actions_by_room_for_user.invalidate_many(
# (event.room_id,)
# )
if event.type == EventTypes.Redaction:
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
# self.get_joined_hosts_for_room.invalidate((event.room_id,))
self.get_users_in_room.invalidate((event.room_id,))
# self._membership_stream_cache.entity_has_changed(
# event.state_key, event.internal_metadata.stream_ordering
# )
self.get_invited_rooms_for_user.invalidate((event.state_key,))
if not event.is_state():
return
if backfilled:
return
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))
if event.type in [EventTypes.Name, EventTypes.Aliases]:
self.get_room_name_and_aliases.invalidate(
(event.room_id,)
)
pass


@@ -0,0 +1,52 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
class SlavedPusherStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedPusherStore, self).__init__(db_conn, hs)
self._pushers_id_gen = SlavedIdTracker(
db_conn, "pushers", "id",
extra_tables=[("deleted_pushers", "stream_id")],
)
get_all_pushers = DataStore.get_all_pushers.__func__
get_pushers_by = DataStore.get_pushers_by.__func__
get_pushers_by_app_id_and_pushkey = (
DataStore.get_pushers_by_app_id_and_pushkey.__func__
)
_decode_pushers_rows = DataStore._decode_pushers_rows.__func__
def stream_positions(self):
result = super(SlavedPusherStore, self).stream_positions()
result["pushers"] = self._pushers_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("pushers")
if stream:
self._pushers_id_gen.advance(stream["position"])
stream = result.get("deleted_pushers")
if stream:
self._pushers_id_gen.advance(stream["position"])
return super(SlavedPusherStore, self).process_replication(result)


@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.storage.receipts import ReceiptsStore
# So, um, we want to borrow a load of functions intended for reading from
# a DataStore, but we don't want to take functions that either write to the
# DataStore or are cached and don't have cache invalidation logic.
#
# Rather than write duplicate versions of those functions, or lift them to
# a common base class, we're going to grab the underlying __func__ object from
# the method descriptor on the DataStore and chuck them into our class.
class SlavedReceiptsStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedReceiptsStore, self).__init__(db_conn, hs)
self._receipts_id_gen = SlavedIdTracker(
db_conn, "receipts_linearized", "stream_id"
)
get_receipts_for_user = ReceiptsStore.__dict__["get_receipts_for_user"]
get_max_receipt_stream_id = DataStore.get_max_receipt_stream_id.__func__
get_all_updated_receipts = DataStore.get_all_updated_receipts.__func__
def stream_positions(self):
result = super(SlavedReceiptsStore, self).stream_positions()
result["receipts"] = self._receipts_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("receipts")
if stream:
self._receipts_id_gen.advance(stream["position"])
for row in stream["rows"]:
room_id, receipt_type, user_id = row[1:4]
self.invalidate_caches_for_receipt(room_id, receipt_type, user_id)
return super(SlavedReceiptsStore, self).process_replication(result)
def invalidate_caches_for_receipt(self, room_id, receipt_type, user_id):
self.get_receipts_for_user.invalidate((user_id, receipt_type))


@@ -33,6 +33,9 @@ from saml2.client import Saml2Client
import xml.etree.ElementTree as ET
import jwt
from jwt.exceptions import InvalidTokenError
logger = logging.getLogger(__name__)
@@ -43,12 +46,16 @@ class LoginRestServlet(ClientV1RestServlet):
SAML2_TYPE = "m.login.saml2"
CAS_TYPE = "m.login.cas"
TOKEN_TYPE = "m.login.token"
JWT_TYPE = "m.login.jwt"
def __init__(self, hs):
super(LoginRestServlet, self).__init__(hs)
self.idp_redirect_url = hs.config.saml2_idp_redirect_url
self.password_enabled = hs.config.password_enabled
self.saml2_enabled = hs.config.saml2_enabled
self.jwt_enabled = hs.config.jwt_enabled
self.jwt_secret = hs.config.jwt_secret
self.jwt_algorithm = hs.config.jwt_algorithm
self.cas_enabled = hs.config.cas_enabled
self.cas_server_url = hs.config.cas_server_url
self.cas_required_attributes = hs.config.cas_required_attributes
@@ -57,6 +64,8 @@ class LoginRestServlet(ClientV1RestServlet):
def on_GET(self, request):
flows = []
if self.jwt_enabled:
flows.append({"type": LoginRestServlet.JWT_TYPE})
if self.saml2_enabled:
flows.append({"type": LoginRestServlet.SAML2_TYPE})
if self.cas_enabled:
@@ -98,6 +107,10 @@ class LoginRestServlet(ClientV1RestServlet):
"uri": "%s%s" % (self.idp_redirect_url, relay_state)
}
defer.returnValue((200, result))
elif self.jwt_enabled and (login_submission["type"] ==
LoginRestServlet.JWT_TYPE):
result = yield self.do_jwt_login(login_submission)
defer.returnValue(result)
# TODO Delete this after all CAS clients switch to token login instead
elif self.cas_enabled and (login_submission["type"] ==
LoginRestServlet.CAS_TYPE):
@@ -209,6 +222,46 @@ class LoginRestServlet(ClientV1RestServlet):
defer.returnValue((200, result))
@defer.inlineCallbacks
def do_jwt_login(self, login_submission):
token = login_submission['token']
if token is None:
raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
try:
payload = jwt.decode(token, self.jwt_secret, algorithms=[self.jwt_algorithm])
except InvalidTokenError:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user = payload['user']
if user is None:
raise LoginError(401, "Invalid JWT", errcode=Codes.UNAUTHORIZED)
user_id = UserID.create(user, self.hs.hostname).to_string()
auth_handler = self.handlers.auth_handler
user_exists = yield auth_handler.does_user_exist(user_id)
if user_exists:
user_id, access_token, refresh_token = (
yield auth_handler.get_login_tuple_for_user_id(user_id)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"refresh_token": refresh_token,
"home_server": self.hs.hostname,
}
else:
user_id, access_token = (
yield self.handlers.registration_handler.register(localpart=user)
)
result = {
"user_id": user_id, # may have changed
"access_token": access_token,
"home_server": self.hs.hostname,
}
defer.returnValue((200, result))
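For context, a hedged sketch of how a token accepted by do_jwt_login could be minted with the same pyjwt library, assuming the shared secret and an HS256 algorithm match the homeserver's jwt_secret / jwt_algorithm config; the payload just needs the 'user' claim read above.

    # Editorial sketch, not part of the diff. Requires pyjwt (imported as jwt above).
    import jwt

    secret = "shared-secret-from-the-homeserver-config"  # assumption: matches jwt_secret
    token = jwt.encode({"user": "alice"}, secret, algorithm="HS256")
    # The server side then does, as in do_jwt_login above:
    payload = jwt.decode(token, secret, algorithms=["HS256"])
    print(payload["user"])  # -> "alice"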
# TODO Delete this after all CAS clients switch to token login instead
def parse_cas_response(self, cas_response_body):
root = ET.fromstring(cas_response_body)


@@ -26,11 +26,48 @@ import logging
logger = logging.getLogger(__name__)
class PusherRestServlet(ClientV1RestServlet):
class PushersRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers$")
def __init__(self, hs):
super(PushersRestServlet, self).__init__(hs)
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request)
user = requester.user
pushers = yield self.hs.get_datastore().get_pushers_by_user_id(
user.to_string()
)
allowed_keys = [
"app_display_name",
"app_id",
"data",
"device_display_name",
"kind",
"lang",
"profile_tag",
"pushkey",
]
for p in pushers:
for k, v in p.items():
if k not in allowed_keys:
del p[k]
defer.returnValue((200, {"pushers": pushers}))
def on_OPTIONS(self, _):
return 200, {}
class PushersSetRestServlet(ClientV1RestServlet):
PATTERNS = client_path_patterns("/pushers/set$")
def __init__(self, hs):
super(PusherRestServlet, self).__init__(hs)
super(PushersSetRestServlet, self).__init__(hs)
self.notifier = hs.get_notifier()
@defer.inlineCallbacks
@@ -100,4 +137,5 @@ class PusherRestServlet(ClientV1RestServlet):
def register_servlets(hs, http_server):
PusherRestServlet(hs).register(http_server)
PushersRestServlet(hs).register(http_server)
PushersSetRestServlet(hs).register(http_server)
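
The on_GET handler above whitelists a fixed set of pusher fields before returning them to the client. A small illustration of that filtering, using a made-up pusher row; the extra keys stand in for internal columns the servlet strips:

# Illustration of the on_GET whitelist above; the pusher dict and its
# internal fields are made up for the example.
ALLOWED_KEYS = [
    "app_display_name", "app_id", "data", "device_display_name",
    "kind", "lang", "profile_tag", "pushkey",
]

pusher = {
    "app_id": "com.example.app.ios",
    "pushkey": "push-token-placeholder",
    "kind": "http",
    "app_display_name": "Example App",
    "device_display_name": "Alice's phone",
    "ts": 1461252000000,        # internal bookkeeping, not returned
    "failing_since": None,      # internal bookkeeping, not returned
}

filtered = {k: v for k, v in pusher.items() if k in ALLOWED_KEYS}
print({"pushers": [filtered]})  # shape of the 200 response body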

View File

@@ -405,6 +405,42 @@ class RoomEventContext(ClientV1RestServlet):
defer.returnValue((200, results))
class RoomForgetRestServlet(ClientV1RestServlet):
def register(self, http_server):
PATTERNS = ("/rooms/(?P<room_id>[^/]*)/forget")
register_txn_path(self, PATTERNS, http_server)
@defer.inlineCallbacks
def on_POST(self, request, room_id, txn_id=None):
requester = yield self.auth.get_user_by_req(
request,
allow_guest=False,
)
yield self.handlers.room_member_handler.forget(
user=requester.user,
room_id=room_id,
)
defer.returnValue((200, {}))
@defer.inlineCallbacks
def on_PUT(self, request, room_id, txn_id):
try:
defer.returnValue(
self.txns.get_client_transaction(request, txn_id)
)
except KeyError:
pass
response = yield self.on_POST(
request, room_id, txn_id
)
self.txns.store_client_transaction(request, txn_id, response)
defer.returnValue(response)
# TODO: Needs unit testing
class RoomMembershipRestServlet(ClientV1RestServlet):
@@ -624,6 +660,7 @@ def register_servlets(hs, http_server):
RoomMemberListRestServlet(hs).register(http_server)
RoomMessageListRestServlet(hs).register(http_server)
JoinRoomAliasServlet(hs).register(http_server)
RoomForgetRestServlet(hs).register(http_server)
RoomMembershipRestServlet(hs).register(http_server)
RoomSendEventRestServlet(hs).register(http_server)
PublicRoomListRestServlet(hs).register(http_server)
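
RoomForgetRestServlet registers both a plain POST path and a transaction-id PUT variant via register_txn_path. A hedged sketch of calling the new endpoint, with placeholder homeserver URL, room id and access token; the /api/v1 prefix is an assumption here:

# Hedged example of hitting the new forget endpoint; all values are
# placeholders. Requires requests.
import requests

HS = "https://example.com"
ROOM_ID = "!abc123:example.com"
TOKEN = "access-token-placeholder"

resp = requests.post(
    "%s/_matrix/client/api/v1/rooms/%s/forget" % (HS, ROOM_ID),
    params={"access_token": TOKEN},
    json={},
)
resp.raise_for_status()  # the handler returns 200 with an empty JSON object

Note that the handler passes allow_guest=False, so guest accounts cannot forget rooms.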

View File

@@ -100,6 +100,11 @@ class RegisterRestServlet(RestServlet):
# == Application Service Registration ==
if appservice:
# Set the desired user according to the AS API (which uses the
# 'user' key not 'username'). Since this is a new addition, we'll
# fall back to 'username' if they gave one.
if isinstance(body.get("user"), basestring):
desired_username = body["user"]
result = yield self._do_appservice_registration(
desired_username, request.args["access_token"][0]
)
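
Given the fallback above, an application service can now supply the AS-API spelling 'user' (or the older 'username') for the desired localpart. An illustrative request, with placeholder AS token and homeserver; the /r0/register path is an assumption, not part of this hunk:

# Illustrative appservice registration call; the token, URL and localpart
# are placeholders.
import requests

resp = requests.post(
    "https://example.com/_matrix/client/r0/register",
    params={"access_token": "appservice-token-placeholder"},
    json={"user": "_bridge_alice"},
)
print(resp.json())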

View File

@@ -115,6 +115,8 @@ class SyncRestServlet(RestServlet):
)
)
request_key = (user, timeout, since, filter_id, full_state)
if filter_id:
if filter_id.startswith('{'):
try:
@@ -134,6 +136,7 @@ class SyncRestServlet(RestServlet):
user=user,
filter_collection=filter,
is_guest=requester.is_guest,
request_key=request_key,
)
if since is not None:
@@ -196,15 +199,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the joined rooms in a sync result
:param list[synapse.handlers.sync.JoinedSyncResult] rooms: list of sync
results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
Args:
rooms(list[synapse.handlers.sync.JoinedSyncResult]): list of sync
results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the joined rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: the joined rooms list, in our
response format
"""
joined = {}
for room in rooms:
@@ -218,15 +223,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the invited rooms in a sync result
:param list[synapse.handlers.sync.InvitedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
Args:
rooms(list[synapse.handlers.sync.InvitedSyncResult]): list of
sync results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: the invited rooms list, in our
response format
"""
invited = {}
for room in rooms:
@@ -248,15 +255,17 @@ class SyncRestServlet(RestServlet):
"""
Encode the archived rooms in a sync result
:param list[synapse.handlers.sync.ArchivedSyncResult] rooms: list of
sync results for rooms this user is joined to
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
Args:
rooms (list[synapse.handlers.sync.ArchivedSyncResult]): list of
sync results for rooms this user is joined to
time_now(int): current time - used as a baseline for age
calculations
token_id(int): ID of the user's auth token - used for namespacing
of transaction IDs
:return: the invited rooms list, in our response format
:rtype: dict[str, dict[str, object]]
Returns:
dict[str, dict[str, object]]: The invited rooms list, in our
response format
"""
joined = {}
for room in rooms:
@@ -269,17 +278,18 @@ class SyncRestServlet(RestServlet):
@staticmethod
def encode_room(room, time_now, token_id, joined=True):
"""
:param JoinedSyncResult|ArchivedSyncResult room: sync result for a
single room
:param int time_now: current time - used as a baseline for age
calculations
:param int token_id: ID of the user's auth token - used for namespacing
of transaction IDs
:param joined: True if the user is joined to this room - will mean
we handle ephemeral events
Args:
room (JoinedSyncResult|ArchivedSyncResult): sync result for a
single room
time_now (int): current time - used as a baseline for age
calculations
token_id (int): ID of the user's auth token - used for namespacing
of transaction IDs
joined (bool): True if the user is joined to this room - will mean
we handle ephemeral events
:return: the room, encoded in our response format
:rtype: dict[str, object]
Returns:
dict[str, object]: the room, encoded in our response format
"""
def serialize(event):
# TODO(mjark): Respect formatting requirements in the filter.
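
The startswith('{') check added above distinguishes an inline JSON filter from the id of a previously uploaded filter. A small sketch of the two forms; the filter definition is illustrative:

# Sketch of the two shapes the sync 'filter' parameter can take, mirroring
# the startswith('{') check above; values are illustrative.
import json

def interpret_filter(filter_id):
    if filter_id and filter_id.startswith('{'):
        return "inline", json.loads(filter_id)
    return "stored", filter_id

print(interpret_filter("3"))
print(interpret_filter('{"room": {"timeline": {"limit": 10}}}'))

The request_key built from (user, timeout, since, filter_id, full_state) then gives the sync handler a hashable identity for the request, presumably so repeated identical long-polls can be recognised and collapsed.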

View File

@@ -0,0 +1,110 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.http.server import respond_with_json, finish_request
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer
from twisted.protocols.basic import FileSender
from synapse.util.stringutils import is_ascii
import os
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
file_name = None
if len(request.postpath) > 2:
try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
def respond_404(request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@defer.inlineCallbacks
def respond_with_file(request, media_type, file_path,
file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
finish_request(request)
else:
respond_404(request)
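
respond_with_file above picks between the plain filename= form and the RFC 5987 filename*=utf-8'' form of Content-Disposition depending on whether the upload name is ASCII. A self-contained illustration of the two header values it produces (Python 2, like the module it accompanies; the filenames are made up):

# Illustration of the two Content-Disposition forms used above (Python 2).
import urllib

def content_disposition(upload_name):
    quoted = urllib.quote(upload_name.encode("utf-8"))
    if all(ord(c) < 128 for c in upload_name):
        return "inline; filename=%s" % (quoted,)
    # extended notation for non-ASCII names
    return "inline; filename*=utf-8''%s" % (quoted,)

print(content_disposition(u"cat.png"))        # inline; filename=cat.png
print(content_disposition(u"ch\u00e2t.png"))  # inline; filename*=utf-8''ch%C3%A2t.png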

View File

@@ -1,459 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.http.server import respond_with_json, finish_request
from synapse.util.stringutils import random_string
from synapse.api.errors import (
cs_error, Codes, SynapseError
)
from twisted.internet import defer, threads
from twisted.web.resource import Resource
from twisted.protocols.basic import FileSender
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
import cgi
import logging
import urllib
import urlparse
logger = logging.getLogger(__name__)
def parse_media_id(request):
try:
# This allows users to append e.g. /test.png to the URL. Useful for
# clients that parse the URL to see content type.
server_name, media_id = request.postpath[:2]
file_name = None
if len(request.postpath) > 2:
try:
file_name = urlparse.unquote(request.postpath[-1]).decode("utf-8")
except UnicodeDecodeError:
pass
return server_name, media_id, file_name
except:
raise SynapseError(
404,
"Invalid media id token %r" % (request.postpath,),
Codes.UNKNOWN,
)
class BaseMediaResource(Resource):
isLeaf = True
def __init__(self, hs, filepaths):
Resource.__init__(self)
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.version_string = hs.version_string
self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
def _respond_404(self, request):
respond_with_json(
request, 404,
cs_error(
"Not found %r" % (request.postpath,),
code=Codes.NOT_FOUND,
),
send_cors=True
)
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
def _get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _respond_with_file(self, request, media_type, file_path,
file_size=None, upload_name=None):
logger.debug("Responding with %r", file_path)
if os.path.isfile(file_path):
request.setHeader(b"Content-Type", media_type.encode("UTF-8"))
if upload_name:
if is_ascii(upload_name):
request.setHeader(
b"Content-Disposition",
b"inline; filename=%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
else:
request.setHeader(
b"Content-Disposition",
b"inline; filename*=utf-8''%s" % (
urllib.quote(upload_name.encode("utf-8")),
),
)
# cache for at least a day.
# XXX: we might want to turn this off for data we don't want to
# recommend caching as it's sensitive or private - or at least
# select private. don't bother setting Expires as all our
# clients are smart enough to be happy with Cache-Control
request.setHeader(
b"Cache-Control", b"public,max-age=86400,s-maxage=86400"
)
if file_size is None:
stat = os.stat(file_path)
file_size = stat.st_size
request.setHeader(
b"Content-Length", b"%d" % (file_size,)
)
with open(file_path, "rb") as f:
yield FileSender().beginFileTransfer(f, request)
finish_request(request)
else:
self._respond_404(request)
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
t_method, t_type):
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
t_len = None
return t_len
@defer.inlineCallbacks
def _generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
"height": m_height,
})

View File

@@ -13,7 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .base_resource import BaseMediaResource, parse_media_id
from ._base import parse_media_id, respond_with_file, respond_404
from twisted.web.resource import Resource
from synapse.http.server import request_handler
from twisted.web.server import NOT_DONE_YET
@@ -24,7 +25,18 @@ import logging
logger = logging.getLogger(__name__)
class DownloadResource(BaseMediaResource):
class DownloadResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@@ -44,7 +56,7 @@ class DownloadResource(BaseMediaResource):
def _respond_local_file(self, request, media_id, name):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
respond_404(request)
return
media_type = media_info["media_type"]
@@ -52,14 +64,14 @@ class DownloadResource(BaseMediaResource):
upload_name = name if name else media_info["upload_name"]
file_path = self.filepaths.local_media_filepath(media_id)
yield self._respond_with_file(
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)
@defer.inlineCallbacks
def _respond_remote_file(self, request, server_name, media_id, name):
media_info = yield self._get_remote_media(server_name, media_id)
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
media_type = media_info["media_type"]
media_length = media_info["media_length"]
@@ -70,7 +82,7 @@ class DownloadResource(BaseMediaResource):
server_name, filesystem_id
)
yield self._respond_with_file(
yield respond_with_file(
request, media_type, file_path, media_length,
upload_name=upload_name,
)

View File

@@ -17,15 +17,400 @@ from .upload_resource import UploadResource
from .download_resource import DownloadResource
from .thumbnail_resource import ThumbnailResource
from .identicon_resource import IdenticonResource
from .preview_url_resource import PreviewUrlResource
from .filepath import MediaFilePaths
from twisted.web.resource import Resource
from .thumbnailer import Thumbnailer
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
from synapse.util.stringutils import random_string
from twisted.internet import defer, threads
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
from synapse.util.logcontext import preserve_context_over_fn
import os
import cgi
import logging
import urlparse
logger = logging.getLogger(__name__)
class MediaRepository(object):
def __init__(self, hs, filepaths):
self.auth = hs.get_auth()
self.client = MatrixFederationHttpClient(hs)
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.max_upload_size = hs.config.max_upload_size
self.max_image_pixels = hs.config.max_image_pixels
self.filepaths = filepaths
self.downloads = {}
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
@staticmethod
def _makedirs(filepath):
dirname = os.path.dirname(filepath)
if not os.path.exists(dirname):
os.makedirs(dirname)
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
def get_remote_media(self, server_name, media_id):
key = (server_name, media_id)
download = self.downloads.get(key)
if download is None:
download = self._get_remote_media_impl(server_name, media_id)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[key] = download
@download.addBoth
def callback(media_info):
del self.downloads[key]
return media_info
return download.observe()
@defer.inlineCallbacks
def _get_remote_media_impl(self, server_name, media_id):
media_info = yield self.store.get_cached_remote_media(
server_name, media_id
)
if not media_info:
media_info = yield self._download_remote_file(
server_name, media_id
)
defer.returnValue(media_info)
@defer.inlineCallbacks
def _download_remote_file(self, server_name, media_id):
file_id = random_string(24)
fname = self.filepaths.remote_media_filepath(
server_name, file_id
)
self._makedirs(fname)
try:
with open(fname, "wb") as f:
request_path = "/".join((
"/_matrix/media/v1/download", server_name, media_id,
))
length, headers = yield self.client.get_file(
server_name, request_path, output_stream=f,
max_size=self.max_upload_size,
)
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
upload_name = None
# First check if there is a valid UTF-8 filename
upload_name_utf8 = params.get("filename*", None)
if upload_name_utf8:
if upload_name_utf8.lower().startswith("utf-8''"):
upload_name = upload_name_utf8[7:]
# If there isn't check for an ascii name.
if not upload_name:
upload_name_ascii = params.get("filename", None)
if upload_name_ascii and is_ascii(upload_name_ascii):
upload_name = upload_name_ascii
if upload_name:
upload_name = urlparse.unquote(upload_name)
try:
upload_name = upload_name.decode("utf-8")
except UnicodeDecodeError:
upload_name = None
else:
upload_name = None
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=length,
filesystem_id=file_id,
)
except:
os.remove(fname)
raise
media_info = {
"media_type": media_type,
"media_length": length,
"upload_name": upload_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
}
yield self._generate_remote_thumbnails(
server_name, media_id, media_info
)
defer.returnValue(media_info)
def _get_thumbnail_requirements(self, media_type):
return self.thumbnail_requirements.get(media_type, ())
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
t_method, t_type):
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
if t_method == "crop":
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
elif t_method == "scale":
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
else:
t_len = None
return t_len
@defer.inlineCallbacks
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
t_method, t_type):
input_path = self.filepaths.local_media_filepath(media_id)
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
t_width, t_height, t_method, t_type):
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = yield preserve_context_over_fn(
threads.deferToThread,
self._generate_thumbnail,
input_path, t_path, t_width, t_height, t_method, t_type
)
if t_len:
yield self.store.store_remote_media_thumbnail(
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
)
defer.returnValue(t_path)
@defer.inlineCallbacks
def _generate_local_thumbnails(self, media_id, media_info):
media_type = media_info["media_type"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
input_path = self.filepaths.local_media_filepath(media_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
local_thumbnails = []
def generate_thumbnails():
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
local_thumbnails.append((
media_id, t_width, t_height, t_type, t_method, t_len
))
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for l in local_thumbnails:
yield self.store.store_local_thumbnail(*l)
defer.returnValue({
"width": m_width,
"height": m_height,
})
@defer.inlineCallbacks
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
media_type = media_info["media_type"]
file_id = media_info["filesystem_id"]
requirements = self._get_thumbnail_requirements(media_type)
if not requirements:
return
remote_thumbnails = []
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
thumbnailer = Thumbnailer(input_path)
m_width = thumbnailer.width
m_height = thumbnailer.height
def generate_thumbnails():
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width, m_height, self.max_image_pixels
)
return
scales = set()
crops = set()
for r_width, r_height, r_method, r_type in requirements:
if r_method == "scale":
t_width, t_height = thumbnailer.aspect(r_width, r_height)
scales.add((
min(m_width, t_width), min(m_height, t_height), r_type,
))
elif r_method == "crop":
crops.add((r_width, r_height, r_type))
for t_width, t_height, t_type in scales:
t_method = "scale"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
for t_width, t_height, t_type in crops:
if (t_width, t_height, t_type) in scales:
# If the aspect ratio of the cropped thumbnail matches a purely
# scaled one then there is no point in calculating a separate
# thumbnail.
continue
t_method = "crop"
t_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method
)
self._makedirs(t_path)
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
remote_thumbnails.append([
server_name, media_id, file_id,
t_width, t_height, t_type, t_method, t_len
])
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
for r in remote_thumbnails:
yield self.store.store_remote_media_thumbnail(*r)
defer.returnValue({
"width": m_width,
"height": m_height,
})
class MediaRepositoryResource(Resource):
"""File uploading and downloading.
@@ -74,7 +459,12 @@ class MediaRepositoryResource(Resource):
def __init__(self, hs):
Resource.__init__(self)
filepaths = MediaFilePaths(hs.config.media_store_path)
self.putChild("upload", UploadResource(hs, filepaths))
self.putChild("download", DownloadResource(hs, filepaths))
self.putChild("thumbnail", ThumbnailResource(hs, filepaths))
media_repo = MediaRepository(hs, filepaths)
self.putChild("upload", UploadResource(hs, media_repo))
self.putChild("download", DownloadResource(hs, media_repo))
self.putChild("thumbnail", ThumbnailResource(hs, media_repo))
self.putChild("identicon", IdenticonResource())
if hs.config.url_preview_enabled:
self.putChild("preview_url", PreviewUrlResource(hs, media_repo))

View File

@@ -0,0 +1,454 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from twisted.web.resource import Resource
from synapse.api.errors import (
SynapseError, Codes,
)
from synapse.util.stringutils import random_string
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.http.client import SpiderHttpClient
from synapse.http.server import (
request_handler, respond_with_json_bytes
)
from synapse.util.async import ObservableDeferred
from synapse.util.stringutils import is_ascii
import os
import re
import fnmatch
import cgi
import ujson as json
import urlparse
import logging
logger = logging.getLogger(__name__)
class PreviewUrlResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.version_string = hs.version_string
self.filepaths = media_repo.filepaths
self.max_spider_size = hs.config.max_spider_size
self.server_name = hs.hostname
self.store = hs.get_datastore()
self.client = SpiderHttpClient(hs)
self.media_repo = media_repo
if hasattr(hs.config, "url_preview_url_blacklist"):
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
# simple memory cache mapping urls to OG metadata
self.cache = ExpiringCache(
cache_name="url_previews",
clock=self.clock,
# don't spider URLs more often than once an hour
expiry_ms=60 * 60 * 1000,
)
self.cache.start()
self.downloads = {}
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@request_handler
@defer.inlineCallbacks
def _async_render_GET(self, request):
# XXX: if get_user_by_req fails, what should we do in an async render?
requester = yield self.auth.get_user_by_req(request)
url = request.args.get("url")[0]
if "ts" in request.args:
ts = int(request.args.get("ts")[0])
else:
ts = self.clock.time_msec()
# impose the URL pattern blacklist
if hasattr(self, "url_preview_url_blacklist"):
url_tuple = urlparse.urlsplit(url)
for entry in self.url_preview_url_blacklist:
match = True
for attrib in entry:
pattern = entry[attrib]
value = getattr(url_tuple, attrib)
logger.debug((
"Matching attrib '%s' with value '%s' against"
" pattern '%s'"
) % (attrib, value, pattern))
if value is None:
match = False
continue
if pattern.startswith('^'):
if not re.match(pattern, getattr(url_tuple, attrib)):
match = False
continue
else:
if not fnmatch.fnmatch(getattr(url_tuple, attrib), pattern):
match = False
continue
if match:
logger.warn(
"URL %s blocked by url_blacklist entry %s", url, entry
)
raise SynapseError(
403, "URL blocked by url pattern blacklist entry",
Codes.UNKNOWN
)
# first check the memory cache - good to handle all the clients on this
# HS thundering away to preview the same URL at the same time.
og = self.cache.get(url)
if og:
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
return
# then check the URL cache in the DB (which will also provide us with
# historical previews, if we have any)
cache_result = yield self.store.get_url_cache(url, ts)
if (
cache_result and
cache_result["download_ts"] + cache_result["expires"] > ts and
cache_result["response_code"] / 100 == 2
):
respond_with_json_bytes(
request, 200, cache_result["og"].encode('utf-8'),
send_cors=True
)
return
# Ensure only one download for a given URL is active at a time
download = self.downloads.get(url)
if download is None:
download = self._download_url(url, requester.user)
download = ObservableDeferred(
download,
consumeErrors=True
)
self.downloads[url] = download
@download.addBoth
def callback(media_info):
del self.downloads[url]
return media_info
media_info = yield download.observe()
# FIXME: we should probably update our cache now anyway, so that
# even if the OG calculation raises, we don't keep hammering on the
# remote server. For now, leave it uncached to aid debugging OG
# calculation problems
logger.debug("got media_info of '%s'" % media_info)
if self._is_media(media_info['media_type']):
dims = yield self.media_repo._generate_local_thumbnails(
media_info['filesystem_id'], media_info
)
og = {
"og:description": media_info['download_name'],
"og:image": "mxc://%s/%s" % (
self.server_name, media_info['filesystem_id']
),
"og:image:type": media_info['media_type'],
"matrix:image:size": media_info['media_length'],
}
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % url)
# define our OG response for this media
elif self._is_html(media_info['media_type']):
# TODO: somehow stop a big HTML tree from exploding synapse's RAM
from lxml import etree
file = open(media_info['filename'])
body = file.read()
file.close()
# clobber the encoding from the content-type, or default to utf-8
# XXX: this overrides any <meta/> or XML charset headers in the body
# which may pose problems, but so far seems to work okay.
match = re.match(r'.*; *charset=(.*?)(;|$)', media_info['media_type'], re.I)
encoding = match.group(1) if match else "utf-8"
try:
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body, parser)
og = yield self._calc_og(tree, media_info, requester)
except UnicodeDecodeError:
# blindly try decoding the body as utf-8, which seems to fix
# the charset mismatches on https://google.com
parser = etree.HTMLParser(recover=True, encoding=encoding)
tree = etree.fromstring(body.decode('utf-8', 'ignore'), parser)
og = yield self._calc_og(tree, media_info, requester)
else:
logger.warn("Failed to find any OG data in %s", url)
og = {}
logger.debug("Calculated OG for %s as %s" % (url, og))
# store OG in ephemeral in-memory cache
self.cache[url] = og
# store OG in history-aware DB cache
yield self.store.store_url_cache(
url,
media_info["response_code"],
media_info["etag"],
media_info["expires"],
json.dumps(og),
media_info["filesystem_id"],
media_info["created_ts"],
)
respond_with_json_bytes(request, 200, json.dumps(og), send_cors=True)
@defer.inlineCallbacks
def _calc_og(self, tree, media_info, requester):
# suck our tree into lxml and define our OG response.
# if we see any image URLs in the OG response, then spider them
# (although the client could choose to do this by asking for previews of those
# URLs to avoid DoSing the server)
# "og:type" : "video",
# "og:url" : "https://www.youtube.com/watch?v=LXDBoHyjmtw",
# "og:site_name" : "YouTube",
# "og:video:type" : "application/x-shockwave-flash",
# "og:description" : "Fun stuff happening here",
# "og:title" : "RemoteJam - Matrix team hack for Disrupt Europe Hackathon",
# "og:image" : "https://i.ytimg.com/vi/LXDBoHyjmtw/maxresdefault.jpg",
# "og:video:url" : "http://www.youtube.com/v/LXDBoHyjmtw?version=3&autohide=1",
# "og:video:width" : "1280"
# "og:video:height" : "720",
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
og[tag.attrib['property']] = tag.attrib['content']
# TODO: grab article: meta tags too, e.g.:
# "article:publisher" : "https://www.facebook.com/thethudonline" />
# "article:author" content="https://www.facebook.com/thethudonline" />
# "article:tag" content="baby" />
# "article:section" content="Breaking News" />
# "article:published_time" content="2016-03-31T19:58:24+00:00" />
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if 'og:title' not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
og['og:title'] = title[0].text.strip() if title else None
if 'og:image' not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
)
if meta_image:
og['og:image'] = self._rebase_url(meta_image[0], media_info['uri'])
else:
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(images, key=lambda i: (
-1 * int(i.attrib['width']) * int(i.attrib['height'])
))
if not images:
images = tree.xpath("//img[@src]")
if images:
og['og:image'] = images[0].attrib['src']
# pre-cache the image for posterity
# FIXME: it might be cleaner to use the same flow as the main /preview_url request
# itself and benefit from the same caching etc. But for now we just rely on the
# caching on the master request to speed things up.
if 'og:image' in og and og['og:image']:
image_info = yield self._download_url(
self._rebase_url(og['og:image'], media_info['uri']), requester.user
)
if self._is_media(image_info['media_type']):
# TODO: make sure we don't choke on white-on-transparent images
dims = yield self.media_repo._generate_local_thumbnails(
image_info['filesystem_id'], image_info
)
if dims:
og["og:image:width"] = dims['width']
og["og:image:height"] = dims['height']
else:
logger.warn("Couldn't get dims for %s" % og["og:image"])
og["og:image"] = "mxc://%s/%s" % (
self.server_name, image_info['filesystem_id']
)
og["og:image:type"] = image_info['media_type']
og["matrix:image:size"] = image_info['media_length']
else:
del og["og:image"]
if 'og:description' not in og:
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content")
if meta_description:
og['og:description'] = meta_description[0]
else:
# grab any text nodes which are inside the <body/> tag...
# unless they are within an HTML5 semantic markup tag...
# <header/>, <nav/>, <aside/>, <footer/>
# ...or if they are within a <script/> or <style/> tag.
# This is a very very very coarse approximation to a plain text
# render of the page.
text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | "
"ancestor::aside | ancestor::footer | "
"ancestor::script | ancestor::style)]" +
"[ancestor::body]")
text = ''
for text_node in text_nodes:
if len(text) < 500:
text += text_node + ' '
else:
break
text = re.sub(r'[\t ]+', ' ', text)
text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text)
text = text.strip()[:500]
og['og:description'] = text if text else None
# TODO: delete the url downloads to stop diskfilling,
# as we only ever cared about its OG
defer.returnValue(og)
def _rebase_url(self, url, base):
base = list(urlparse.urlparse(base))
url = list(urlparse.urlparse(url))
if not url[0]: # fix up schema
url[0] = base[0] or "http"
if not url[1]: # fix up hostname
url[1] = base[1]
if not url[2].startswith('/'):
url[2] = re.sub(r'/[^/]+$', '/', base[2]) + url[2]
return urlparse.urlunparse(url)
@defer.inlineCallbacks
def _download_url(self, url, user):
# TODO: we should probably honour robots.txt... except in practice
# we're most likely being explicitly triggered by a human rather than a
# bot, so are we really a robot?
# XXX: horrible duplication with base_resource's _download_remote_file()
file_id = random_string(24)
fname = self.filepaths.local_media_filepath(file_id)
self.media_repo._makedirs(fname)
try:
with open(fname, "wb") as f:
logger.debug("Trying to get url '%s'" % url)
length, headers, uri, code = yield self.client.get_file(
url, output_stream=f, max_size=self.max_spider_size,
)
# FIXME: pass through 404s and other error messages nicely
media_type = headers["Content-Type"][0]
time_now_ms = self.clock.time_msec()
content_disposition = headers.get("Content-Disposition", None)
if content_disposition:
_, params = cgi.parse_header(content_disposition[0],)
download_name = None
# First check if there is a valid UTF-8 filename
download_name_utf8 = params.get("filename*", None)
if download_name_utf8:
if download_name_utf8.lower().startswith("utf-8''"):
download_name = download_name_utf8[7:]
# If there isn't check for an ascii name.
if not download_name:
download_name_ascii = params.get("filename", None)
if download_name_ascii and is_ascii(download_name_ascii):
download_name = download_name_ascii
if download_name:
download_name = urlparse.unquote(download_name)
try:
download_name = download_name.decode("utf-8")
except UnicodeDecodeError:
download_name = None
else:
download_name = None
yield self.store.store_local_media(
media_id=file_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=download_name,
media_length=length,
user_id=user,
)
except Exception as e:
os.remove(fname)
raise SynapseError(
500, ("Failed to download content: %s" % e),
Codes.UNKNOWN
)
defer.returnValue({
"media_type": media_type,
"media_length": length,
"download_name": download_name,
"created_ts": time_now_ms,
"filesystem_id": file_id,
"filename": fname,
"uri": uri,
"response_code": code,
# FIXME: we should calculate a proper expiration based on the
# Cache-Control and Expire headers. But for now, assume 1 hour.
"expires": 60 * 60 * 1000,
"etag": headers["ETag"][0] if "ETag" in headers else None,
})
def _is_media(self, content_type):
if content_type.lower().startswith("image/"):
return True
def _is_html(self, content_type):
content_type = content_type.lower()
if (
content_type.startswith("text/html") or
content_type.startswith("application/xhtml")
):
return True
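
The blacklist loop near the top of _async_render_GET matches each entry's attributes (scheme, netloc, path and so on, as returned by urlparse.urlsplit) against fnmatch globs, or against a regex when the pattern starts with '^'. Illustrative entries showing both forms, written here as Python literals; in a real deployment the equivalent would live in the homeserver config, and the hosts and paths are made up:

# Illustrative url_preview_url_blacklist entries matching the logic above.
url_preview_url_blacklist = [
    # glob match on the host (any subdomain of intranet.example.com)
    {"netloc": "*.intranet.example.com"},
    # regex match (leading '^') on bare-IP hosts, with optional port
    {"netloc": r"^\d+\.\d+\.\d+\.\d+(:\d+)?$"},
    # several attributes in one entry; all must match for the URL to be blocked
    {"scheme": "http", "netloc": "example.com", "path": "/private/*"},
]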

View File

@@ -14,7 +14,8 @@
# limitations under the License.
from .base_resource import BaseMediaResource, parse_media_id
from ._base import parse_media_id, respond_404, respond_with_file
from twisted.web.resource import Resource
from synapse.http.servlet import parse_string, parse_integer
from synapse.http.server import request_handler
@@ -26,9 +27,19 @@ import logging
logger = logging.getLogger(__name__)
class ThumbnailResource(BaseMediaResource):
class ThumbnailResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.store = hs.get_datastore()
self.filepaths = media_repo.filepaths
self.media_repo = media_repo
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.server_name = hs.hostname
self.version_string = hs.version_string
def render_GET(self, request):
self._async_render_GET(request)
return NOT_DONE_YET
@@ -69,9 +80,14 @@ class ThumbnailResource(BaseMediaResource):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
respond_404(request)
return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
if thumbnail_infos:
@@ -86,7 +102,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.local_media_thumbnail(
media_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path)
yield respond_with_file(request, t_type, file_path)
else:
yield self._respond_default_thumbnail(
@@ -100,9 +116,14 @@ class ThumbnailResource(BaseMediaResource):
media_info = yield self.store.get_local_media(media_id)
if not media_info:
self._respond_404(request)
respond_404(request)
return
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.local_media_filepath(media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_local_media_thumbnails(media_id)
for info in thumbnail_infos:
t_w = info["thumbnail_width"] == desired_width
@@ -114,18 +135,18 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.local_media_thumbnail(
media_id, desired_width, desired_height, desired_type, desired_method,
)
yield self._respond_with_file(request, desired_type, file_path)
yield respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self._generate_local_exact_thumbnail(
file_path = yield self.media_repo.generate_local_exact_thumbnail(
media_id, desired_width, desired_height, desired_method, desired_type
)
if file_path:
yield self._respond_with_file(request, desired_type, file_path)
yield respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
@@ -136,7 +157,12 @@ class ThumbnailResource(BaseMediaResource):
def _select_or_generate_remote_thumbnail(self, request, server_name, media_id,
desired_width, desired_height,
desired_method, desired_type):
media_info = yield self._get_remote_media(server_name, media_id)
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -155,19 +181,19 @@ class ThumbnailResource(BaseMediaResource):
server_name, file_id, desired_width, desired_height,
desired_type, desired_method,
)
yield self._respond_with_file(request, desired_type, file_path)
yield respond_with_file(request, desired_type, file_path)
return
logger.debug("We don't have a local thumbnail of that size. Generating")
# Okay, so we generate one.
file_path = yield self._generate_remote_exact_thumbnail(
file_path = yield self.media_repo.generate_remote_exact_thumbnail(
server_name, file_id, media_id, desired_width,
desired_height, desired_method, desired_type
)
if file_path:
yield self._respond_with_file(request, desired_type, file_path)
yield respond_with_file(request, desired_type, file_path)
else:
yield self._respond_default_thumbnail(
request, media_info, desired_width, desired_height,
@@ -179,7 +205,12 @@ class ThumbnailResource(BaseMediaResource):
height, method, m_type):
# TODO: Don't download the whole remote file
# We should proxy the thumbnail from the remote server instead.
media_info = yield self._get_remote_media(server_name, media_id)
media_info = yield self.media_repo.get_remote_media(server_name, media_id)
# if media_info["media_type"] == "image/svg+xml":
# file_path = self.filepaths.remote_media_filepath(server_name, media_id)
# yield respond_with_file(request, media_info["media_type"], file_path)
# return
thumbnail_infos = yield self.store.get_remote_media_thumbnails(
server_name, media_id,
@@ -199,7 +230,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.remote_media_thumbnail(
server_name, file_id, t_width, t_height, t_type, t_method,
)
yield self._respond_with_file(request, t_type, file_path, t_length)
yield respond_with_file(request, t_type, file_path, t_length)
else:
yield self._respond_default_thumbnail(
request, media_info, width, height, method, m_type,
@@ -208,6 +239,8 @@ class ThumbnailResource(BaseMediaResource):
@defer.inlineCallbacks
def _respond_default_thumbnail(self, request, media_info, width, height,
method, m_type):
# XXX: how is this meant to work? store.get_default_thumbnails
# appears to always return [] so won't this always 404?
media_type = media_info["media_type"]
top_level_type = media_type.split("/")[0]
sub_type = media_type.split("/")[-1].split(";")[0]
@@ -223,7 +256,7 @@ class ThumbnailResource(BaseMediaResource):
"_default", "_default",
)
if not thumbnail_infos:
self._respond_404(request)
respond_404(request)
return
thumbnail_info = self._select_thumbnail(
@@ -239,7 +272,7 @@ class ThumbnailResource(BaseMediaResource):
file_path = self.filepaths.default_thumbnail(
top_level_type, sub_type, t_width, t_height, t_type, t_method,
)
yield self.respond_with_file(request, t_type, file_path, t_length)
yield respond_with_file(request, t_type, file_path, t_length)
def _select_thumbnail(self, desired_width, desired_height, desired_method,
desired_type, thumbnail_infos):
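
For reference, this resource serves thumbnails under the same /_matrix/media/v1 prefix as the download path seen earlier, taking width, height and method (crop or scale) parameters; with dynamic_thumbnails enabled it generates an exact match on demand via the MediaRepository methods shown above. An illustrative request, with placeholder host and media id:

# Illustrative thumbnail fetch against the endpoint served by this resource;
# the host and media id are placeholders.
import requests

resp = requests.get(
    "https://example.com/_matrix/media/v1/thumbnail/example.com/SomeMediaId",
    params={"width": 320, "height": 240, "method": "scale"},
)
with open("thumb.png", "wb") as f:  # content type depends on the stored thumbnail
    f.write(resp.content)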

View File

@@ -15,20 +15,33 @@
from synapse.http.server import respond_with_json, request_handler
from synapse.util.stringutils import random_string
from synapse.api.errors import SynapseError
from twisted.web.server import NOT_DONE_YET
from twisted.internet import defer
from .base_resource import BaseMediaResource
from twisted.web.resource import Resource
import logging
logger = logging.getLogger(__name__)
class UploadResource(BaseMediaResource):
class UploadResource(Resource):
isLeaf = True
def __init__(self, hs, media_repo):
Resource.__init__(self)
self.media_repo = media_repo
self.filepaths = media_repo.filepaths
self.store = hs.get_datastore()
self.clock = hs.get_clock()
self.server_name = hs.hostname
self.auth = hs.get_auth()
self.max_upload_size = hs.config.max_upload_size
self.version_string = hs.version_string
def render_POST(self, request):
self._async_render_POST(request)
return NOT_DONE_YET
@@ -37,36 +50,6 @@ class UploadResource(BaseMediaResource):
respond_with_json(request, 200, {}, send_cors=True)
return NOT_DONE_YET
@defer.inlineCallbacks
def create_content(self, media_type, upload_name, content, content_length,
auth_user):
media_id = random_string(24)
fname = self.filepaths.local_media_filepath(media_id)
self._makedirs(fname)
# This shouldn't block for very long because the content will have
# already been uploaded at this point.
with open(fname, "wb") as f:
f.write(content)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
time_now_ms=self.clock.time_msec(),
upload_name=upload_name,
media_length=content_length,
user_id=auth_user,
)
media_info = {
"media_type": media_type,
"media_length": content_length,
}
yield self._generate_local_thumbnails(media_id, media_info)
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
@request_handler
@defer.inlineCallbacks
def _async_render_POST(self, request):
@@ -108,7 +91,7 @@ class UploadResource(BaseMediaResource):
# disposition = headers.getRawHeaders("Content-Disposition")[0]
# TODO(markjh): parse content-disposition
content_uri = yield self.create_content(
content_uri = yield self.media_repo.create_content(
media_type, upload_name, request.content.read(),
content_length, requester.user
)
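A tiny illustration (made-up values; request handling omitted) of the URI shape that self.media_repo.create_content ultimately resolves to, per the implementation that moved out of this class above:

server_name = "example.org"            # corresponds to hs.hostname
media_id = "GCmhgzMPRjqgpODLsNQzVuHZ"  # stand-in for random_string(24)
content_uri = "mxc://%s/%s" % (server_name, media_id)
print(content_uri)  # mxc://example.org/GCmhgzMPRjqgpODLsNQzVuHZ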

View File

@@ -193,6 +193,9 @@ class HomeServer(object):
**self.db_config.get("args", {})
)
def remove_pusher(self, app_id, push_key, user_id):
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
def _make_dependency_method(depname):
def _get(hs):

View File

@@ -75,7 +75,8 @@ class StateHandler(object):
self._state_cache.start()
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
def get_current_state(self, room_id, event_type=None, state_key="",
latest_event_ids=None):
""" Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
@@ -86,11 +87,13 @@ class StateHandler(object):
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
:returns map from (type, state_key) to event
Returns:
map from (type, state_key) to event
"""
event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
res = yield self.resolve_state_groups(room_id, event_ids)
res = yield self.resolve_state_groups(room_id, latest_event_ids)
state = res[1]
if event_type:
@@ -100,7 +103,7 @@ class StateHandler(object):
defer.returnValue(state)
@defer.inlineCallbacks
def compute_event_context(self, event, old_state=None, outlier=False):
def compute_event_context(self, event, old_state=None):
""" Fills out the context with the `current state` of the graph. The
`current state` here is defined to be the state of the event graph
just before the event - i.e. it never includes `event`
@@ -115,7 +118,7 @@ class StateHandler(object):
"""
context = EventContext()
if outlier:
if event.internal_metadata.is_outlier():
# If this is an outlier, then we know it shouldn't have any current
# state. Certainly store.get_current_state won't return any, and
# persisting the event won't store the state group.
@@ -176,10 +179,11 @@ class StateHandler(object):
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
:returns a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
Returns:
a Deferred tuple of (`state_group`, `state`, `prev_state`).
`state_group` is the name of a state group if one and only one is
involved. `state` is a map from (type, state_key) to event, and
`prev_state` is a list of event ids.
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
@@ -210,7 +214,7 @@ class StateHandler(object):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache and cache.state_group:
if cache:
cache.ts = self.clock.time_msec()
event_dict = yield self.store.get_events(cache.state.values())
@@ -226,22 +230,34 @@ class StateHandler(object):
(cache.state_group, state, prev_states)
)
logger.info("Resolving state for %s with %d groups", room_id, len(state_groups))
new_state, prev_states = self._resolve_events(
state_groups.values(), event_type, state_key
)
state_group = None
new_state_event_ids = frozenset(e.event_id for e in new_state.values())
for sg, events in state_groups.items():
if new_state_event_ids == frozenset(e.event_id for e in events):
state_group = sg
break
if self._state_cache is not None:
cache = _StateCacheEntry(
state={key: event.event_id for key, event in new_state.items()},
state_group=None,
state_group=state_group,
ts=self.clock.time_msec()
)
self._state_cache[group_names] = cache
defer.returnValue((None, new_state, prev_states))
defer.returnValue((state_group, new_state, prev_states))
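A minimal, dependency-free sketch of the state-group reuse check added above (names are illustrative; the real code operates on FrozenEvent objects): the resolved state only inherits an existing state_group when its set of event_ids exactly matches one of the input groups, otherwise it stays None.

def pick_existing_state_group(state_groups, new_state):
    # state_groups: dict of group_id -> list of event-like objects
    # new_state: dict of (type, state_key) -> event-like object
    # Both only need an .event_id attribute for this check.
    new_state_event_ids = frozenset(e.event_id for e in new_state.values())
    for group_id, events in state_groups.items():
        if new_state_event_ids == frozenset(e.event_id for e in events):
            return group_id
    return None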
def resolve_events(self, state_sets, event):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
@@ -251,9 +267,10 @@ class StateHandler(object):
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
:returns a tuple (new_state, prev_states). new_state is a map
from (type, state_key) to event. prev_states is a list of event_ids.
:rtype: (dict[(str, str), synapse.events.FrozenEvent], list[str])
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"):
state = {}

View File

@@ -88,22 +88,17 @@ class DataStore(RoomMemberStore, RoomStore,
self.hs = hs
self.database_engine = hs.database_engine
cur = db_conn.cursor()
try:
cur.execute("SELECT MIN(stream_ordering) FROM events",)
rows = cur.fetchall()
self.min_stream_token = rows[0][0] if rows and rows[0] and rows[0][0] else -1
self.min_stream_token = min(self.min_stream_token, -1)
finally:
cur.close()
self.client_ip_last_seen = Cache(
name="client_ip_last_seen",
keylen=4,
)
self._stream_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering"
db_conn, "events", "stream_ordering",
extra_tables=[("local_invites", "stream_id")]
)
self._backfill_id_gen = StreamIdGenerator(
db_conn, "events", "stream_ordering", step=-1
)
self._receipts_id_gen = StreamIdGenerator(
db_conn, "receipts_linearized", "stream_id"
@@ -116,7 +111,7 @@ class DataStore(RoomMemberStore, RoomStore,
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
self._state_groups_id_gen = StreamIdGenerator(db_conn, "state_groups", "id")
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
self._refresh_tokens_id_gen = IdGenerator(db_conn, "refresh_tokens", "id")
self._push_rule_id_gen = IdGenerator(db_conn, "push_rules", "id")
@@ -129,7 +124,7 @@ class DataStore(RoomMemberStore, RoomStore,
extra_tables=[("deleted_pushers", "stream_id")],
)
events_max = self._stream_id_gen.get_max_token()
events_max = self._stream_id_gen.get_current_token()
event_cache_prefill, min_event_val = self._get_cache_dict(
db_conn, "events",
entity_column="room_id",
@@ -145,7 +140,7 @@ class DataStore(RoomMemberStore, RoomStore,
"MembershipStreamChangeCache", events_max,
)
account_max = self._account_data_id_gen.get_max_token()
account_max = self._account_data_id_gen.get_current_token()
self._account_data_stream_cache = StreamChangeCache(
"AccountDataAndTagsChangeCache", account_max,
)
@@ -156,7 +151,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "presence_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._presence_id_gen.get_max_token(),
max_value=self._presence_id_gen.get_current_token(),
)
self.presence_stream_cache = StreamChangeCache(
"PresenceStreamChangeCache", min_presence_val,
@@ -167,7 +162,7 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "push_rules_stream",
entity_column="user_id",
stream_column="stream_id",
max_value=self._push_rules_stream_id_gen.get_max_token()[0],
max_value=self._push_rules_stream_id_gen.get_current_token()[0],
)
self.push_rules_stream_cache = StreamChangeCache(
@@ -182,39 +177,6 @@ class DataStore(RoomMemberStore, RoomStore,
self.__presence_on_startup = None
return active_on_startup
def _get_cache_dict(self, db_conn, table, entity_column, stream_column, max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
def _get_active_presence(self, db_conn):
"""Fetch non-offline presence from the database so that we can register
the appropriate time outs.

View File

@@ -810,11 +810,39 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
def get_next_stream_id(self):
with self._next_stream_id_lock:
i = self._next_stream_id
self._next_stream_id += 1
return i
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
}
sql = self.database_engine.convert_param_style(sql)
txn = db_conn.cursor()
txn.execute(sql, (int(max_value),))
rows = txn.fetchall()
txn.close()
cache = {
row[0]: int(row[1])
for row in rows
}
if cache:
min_val = min(cache.values())
else:
min_val = max_value
return cache, min_val
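As a standalone illustration of what the relocated _get_cache_dict query produces, here is a sketch against an in-memory sqlite3 database with made-up rows (plain sqlite3 rather than Synapse's engine wrapper): a map of entity -> most recent stream position plus the minimum value seen, which is what the StreamChangeCache is primed with.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE events (room_id TEXT, stream_ordering INTEGER)")
conn.executemany(
    "INSERT INTO events VALUES (?, ?)",
    [("!a:hs", 10), ("!a:hs", 12), ("!b:hs", 7)],
)

max_value = 12  # stands in for e.g. self._stream_id_gen.get_current_token()
cur = conn.cursor()
cur.execute(
    "SELECT room_id, MAX(stream_ordering) FROM events"
    " WHERE stream_ordering > ? - 100000"
    " GROUP BY room_id",
    (max_value,),
)
cache = {row[0]: int(row[1]) for row in cur.fetchall()}
min_val = min(cache.values()) if cache else max_value
print(cache, min_val)  # e.g. {'!a:hs': 12, '!b:hs': 7} 7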
class _RollbackButIsFineException(Exception):

View File

@@ -200,7 +200,7 @@ class AccountDataStore(SQLBaseStore):
"add_room_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -239,7 +239,7 @@ class AccountDataStore(SQLBaseStore):
"add_user_account_data", add_account_data_txn, next_id
)
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_max_stream_id(self, txn, next_id):

View File

@@ -26,13 +26,13 @@ SUPPORTED_MODULE = {
}
def create_engine(config):
name = config.database_config["name"]
def create_engine(database_config):
name = database_config["name"]
engine_class = SUPPORTED_MODULE.get(name, None)
if engine_class:
module = importlib.import_module(name)
return engine_class(module, config=config)
return engine_class(module)
raise RuntimeError(
"Unsupported database engine '%s'" % (name,)

View File

@@ -13,18 +13,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import prepare_database
from ._base import IncorrectDatabaseSetup
class PostgresEngine(object):
single_threaded = False
def __init__(self, database_module, config):
def __init__(self, database_module):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
self.config = config
def check_database(self, txn):
txn.execute("SHOW SERVER_ENCODING")
@@ -44,9 +41,6 @@ class PostgresEngine(object):
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
def prepare_database(self, db_conn):
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
return error.pgcode in ["40001", "40P01"]

View File

@@ -13,9 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from synapse.storage.prepare_database import (
prepare_database, prepare_sqlite3_database
)
from synapse.storage.prepare_database import prepare_database
import struct
@@ -23,9 +21,8 @@ import struct
class Sqlite3Engine(object):
single_threaded = True
def __init__(self, database_module, config):
def __init__(self, database_module):
self.module = database_module
self.config = config
def check_database(self, txn):
pass
@@ -34,13 +31,9 @@ class Sqlite3Engine(object):
return sql
def on_new_connection(self, db_conn):
self.prepare_database(db_conn)
prepare_database(db_conn, self, config=None)
db_conn.create_function("rank", 1, _rank)
def prepare_database(self, db_conn):
prepare_sqlite3_database(db_conn)
prepare_database(db_conn, self, config=self.config)
def is_deadlock(self, error):
return False

View File

@@ -163,6 +163,22 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
@defer.inlineCallbacks
def get_max_depth_of_events(self, event_ids):
sql = (
"SELECT MAX(depth) FROM events WHERE event_id IN (%s)"
) % (",".join(["?"] * len(event_ids)),)
rows = yield self._execute(
"get_max_depth_of_events", None,
sql, *event_ids
)
if rows:
defer.returnValue(rows[0][0])
else:
defer.returnValue(1)
def _get_min_depth_interaction(self, txn, room_id):
min_depth = self._simple_select_one_onecol_txn(
txn,

View File

@@ -26,8 +26,9 @@ logger = logging.getLogger(__name__)
class EventPushActionsStore(SQLBaseStore):
def _set_push_actions_for_event_and_users_txn(self, txn, event, tuples):
"""
:param event: the event set actions for
:param tuples: list of tuples of (user_id, actions)
Args:
event: the event set actions for
tuples: list of tuples of (user_id, actions)
"""
values = []
for uid, actions in tuples:
@@ -99,6 +100,96 @@ class EventPushActionsStore(SQLBaseStore):
)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_push_action_users_in_range(self, min_stream_ordering, max_stream_ordering):
def f(txn):
sql = (
"SELECT DISTINCT(user_id) FROM event_push_actions WHERE"
" stream_ordering >= ? AND stream_ordering <= ?"
)
txn.execute(sql, (min_stream_ordering, max_stream_ordering))
return [r[0] for r in txn.fetchall()]
ret = yield self.runInteraction("get_push_action_users_in_range", f)
defer.returnValue(ret)
@defer.inlineCallbacks
def get_unread_push_actions_for_user_in_range(self, user_id,
min_stream_ordering,
max_stream_ordering=None):
def get_after_receipt(txn):
sql = (
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
"FROM event_push_actions AS ep, ("
" SELECT room_id, user_id,"
" max(topological_ordering) as topological_ordering,"
" max(stream_ordering) as stream_ordering"
" FROM events"
" NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'"
" GROUP BY room_id, user_id"
") AS rl "
"WHERE"
" ep.room_id = rl.room_id"
" AND ("
" ep.topological_ordering > rl.topological_ordering"
" OR ("
" ep.topological_ordering = rl.topological_ordering"
" AND ep.stream_ordering > rl.stream_ordering"
" )"
" )"
" AND ep.stream_ordering > ?"
" AND ep.user_id = ?"
" AND ep.user_id = rl.user_id"
)
args = [min_stream_ordering, user_id]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC"
txn.execute(sql, args)
return txn.fetchall()
after_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_after_receipt
)
def get_no_receipt(txn):
sql = (
"SELECT ep.event_id, ep.stream_ordering, ep.actions "
"FROM event_push_actions AS ep "
"WHERE ep.room_id not in ("
" SELECT room_id FROM events NATURAL JOIN receipts_linearized"
" WHERE receipt_type = 'm.read' AND user_id = ? "
" GROUP BY room_id"
") AND ep.user_id = ? AND ep.stream_ordering > ?"
)
args = [user_id, user_id, min_stream_ordering]
if max_stream_ordering is not None:
sql += " AND ep.stream_ordering <= ?"
args.append(max_stream_ordering)
sql += " ORDER BY ep.stream_ordering ASC"
txn.execute(sql, args)
return txn.fetchall()
no_read_receipt = yield self.runInteraction(
"get_unread_push_actions_for_user_in_range", get_no_receipt
)
defer.returnValue([
{
"event_id": row[0],
"stream_ordering": row[1],
"actions": json.loads(row[2]),
} for row in after_read_receipt + no_read_receipt
])
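The two queries above are combined into a single list of action dicts; a toy, dependency-free sketch of that merge step (made-up rows, JSON decoding as in the real code):

import json

# (event_id, stream_ordering, actions-json) rows, as returned by the two txns
after_read_receipt = [("$ev1:hs", 101, json.dumps(["notify"]))]
no_read_receipt = [("$ev2:hs", 102, json.dumps(["notify", {"set_tweak": "highlight"}]))]

unread = [
    {
        "event_id": row[0],
        "stream_ordering": row[1],
        "actions": json.loads(row[2]),
    }
    for row in after_read_receipt + no_read_receipt
]
print(unread)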
@defer.inlineCallbacks
def get_latest_push_action_stream_ordering(self):
def f(txn):
txn.execute("SELECT MAX(stream_ordering) FROM event_push_actions")
return txn.fetchone()
result = yield self.runInteraction(
"get_latest_push_action_stream_ordering", f
)
defer.returnValue(result[0] or 0)
def _remove_push_actions_for_event_id_txn(self, txn, room_id, event_id):
# Sad that we have to blow away the cache for the whole room here
txn.call_after(

View File

@@ -24,7 +24,7 @@ from synapse.util.logutils import log_function
from synapse.api.constants import EventTypes
from canonicaljson import encode_canonical_json
from contextlib import contextmanager
from collections import namedtuple
import logging
import math
@@ -60,64 +60,83 @@ class EventsStore(SQLBaseStore):
)
@defer.inlineCallbacks
def persist_events(self, events_and_contexts, backfilled=False,
is_new_state=True):
def persist_events(self, events_and_contexts, backfilled=False):
"""
Write events to the database
Args:
events_and_contexts: list of tuples of (event, context)
backfilled: ?
Returns: Tuple of stream_orderings where the first is the minimum and
last is the maximum stream ordering assigned to the events when
persisting.
"""
if not events_and_contexts:
return
if backfilled:
start = self.min_stream_token - 1
self.min_stream_token -= len(events_and_contexts) + 1
stream_orderings = range(start, self.min_stream_token, -1)
@contextmanager
def stream_ordering_manager():
yield stream_orderings
stream_ordering_manager = stream_ordering_manager()
stream_ordering_manager = self._backfill_id_gen.get_next_mult(
len(events_and_contexts)
)
else:
stream_ordering_manager = self._stream_id_gen.get_next_mult(
len(events_and_contexts)
)
state_group_id_manager = self._state_groups_id_gen.get_next_mult(
len(events_and_contexts)
)
with stream_ordering_manager as stream_orderings:
for (event, _), stream in zip(events_and_contexts, stream_orderings):
event.internal_metadata.stream_ordering = stream
with state_group_id_manager as state_group_ids:
for (event, context), stream, state_group_id in zip(
events_and_contexts, stream_orderings, state_group_ids
):
event.internal_metadata.stream_ordering = stream
# Assign a state group_id in case a new id is needed for
# this context. In theory we only need to assign this
# for contexts that have current_state and aren't outliers
# but that makes the code more complicated. Assigning an ID
# per event only causes the state_group_ids to grow as fast
# as the stream_ordering so in practice this shouldn't be a problem.
context.new_state_group_id = state_group_id
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
chunks = [
events_and_contexts[x:x + 100]
for x in xrange(0, len(events_and_contexts), 100)
]
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
is_new_state=is_new_state,
)
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
)
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context,
is_new_state=True, current_state=None):
def persist_event(self, event, context, current_state=None, backfilled=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
is_new_state=is_new_state,
current_state=current_state,
)
with self._state_groups_id_gen.get_next() as state_group_id:
event.internal_metadata.stream_ordering = stream_ordering
context.new_state_group_id = state_group_id
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
)
except _RollbackButIsFineException:
pass
max_persisted_id = yield self._stream_id_gen.get_max_token()
max_persisted_id = yield self._stream_id_gen.get_current_token()
defer.returnValue((stream_ordering, max_persisted_id))
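A toy id-allocator (not Synapse's StreamIdGenerator) illustrating the pattern now used for both forward and backfilled persists above: ids are handed out inside context managers via get_next / get_next_mult, and the backfill generator simply counts downwards.

from contextlib import contextmanager

class ToyIdGenerator:
    def __init__(self, start=0, step=1):
        self._current = start
        self._step = step

    @contextmanager
    def get_next(self):
        self._current += self._step
        yield self._current

    def get_next_mult(self, n):
        @contextmanager
        def manager():
            first = self._current + self._step
            self._current += self._step * n
            yield [first + i * self._step for i in range(n)]
        return manager()

    def get_current_token(self):
        return self._current

forward = ToyIdGenerator()
backfill = ToyIdGenerator(step=-1)
with forward.get_next() as stream_ordering:
    print("forward event gets", stream_ordering)       # 1
with backfill.get_next_mult(3) as orderings:
    print("backfilled chunk gets", orderings)           # [-1, -2, -3]
print("current token:", forward.get_current_token())    # 1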
@defer.inlineCallbacks
@@ -177,16 +196,27 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context,
is_new_state=True, current_state=None):
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
)
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_room_name_and_aliases, event.room_id)
txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
@@ -209,13 +239,11 @@ class EventsStore(SQLBaseStore):
return self._persist_events_txn(
txn,
[(event, context)],
backfilled=False,
is_new_state=is_new_state,
backfilled=backfilled,
)
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
is_new_state=True):
def _persist_events_txn(self, txn, events_and_contexts, backfilled):
depth_updates = {}
for event, context in events_and_contexts:
# Remove the any existing cache entries for the event_ids
@@ -282,9 +310,7 @@ class EventsStore(SQLBaseStore):
outlier_persisted = have_persisted[event.event_id]
if not event.internal_metadata.is_outlier() and outlier_persisted:
self._store_state_groups_txn(
txn, event, context,
)
self._store_mult_state_groups_txn(txn, ((event, context),))
metadata_json = encode_json(
event.internal_metadata.get_dict()
@@ -299,6 +325,18 @@ class EventsStore(SQLBaseStore):
(metadata_json, event.event_id,)
)
stream_order = event.internal_metadata.stream_ordering
state_group_id = context.state_group or context.new_state_group_id
self._simple_insert_txn(
txn,
table="ex_outlier_stream",
values={
"event_stream_ordering": stream_order,
"event_id": event.event_id,
"state_group": state_group_id,
}
)
sql = (
"UPDATE events SET outlier = ?"
" WHERE event_id = ?"
@@ -310,19 +348,14 @@ class EventsStore(SQLBaseStore):
self._update_extremeties(txn, [event])
events_and_contexts = filter(
lambda ec: ec[0] not in to_remove,
events_and_contexts
)
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
]
if not events_and_contexts:
return
self._store_mult_state_groups_txn(txn, [
(event, context)
for event, context in events_and_contexts
if not event.internal_metadata.is_outlier()
])
self._store_mult_state_groups_txn(txn, events_and_contexts)
self._handle_mult_prev_events(
txn,
@@ -349,7 +382,8 @@ class EventsStore(SQLBaseStore):
event
for event, _ in events_and_contexts
if event.type == EventTypes.Member
]
],
backfilled=backfilled,
)
def event_dict(event):
@@ -421,10 +455,9 @@ class EventsStore(SQLBaseStore):
txn, [event for event, _ in events_and_contexts]
)
state_events_and_contexts = filter(
lambda i: i[0].is_state(),
events_and_contexts,
)
state_events_and_contexts = [
ec for ec in events_and_contexts if ec[0].is_state()
]
state_values = []
for event, context in state_events_and_contexts:
@@ -462,32 +495,44 @@ class EventsStore(SQLBaseStore):
],
)
if is_new_state:
for event, _ in state_events_and_contexts:
if not context.rejected:
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
if backfilled:
# Backfilled events come before the current state so we don't need
# to update the current state table
return
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
if context.rejected:
# If the event failed its auth checks then it shouldn't
# clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
if event.type in [EventTypes.Name, EventTypes.Aliases]:
txn.call_after(
self.get_room_name_and_aliases.invalidate,
(event.room_id,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return
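The rewritten loop above boils down to three rules before current_state_events is touched; a compact restatement in plain Python (hypothetical helper, not the Synapse code itself):

def events_that_update_current_state(state_events_and_contexts, backfilled):
    # Backfilled batches never update current state; within a batch,
    # outliers and rejected events are skipped.
    if backfilled:
        return []
    return [
        (event, context)
        for event, context in state_events_and_contexts
        if not event.internal_metadata.is_outlier() and not context.rejected
    ]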
@@ -499,6 +544,22 @@ class EventsStore(SQLBaseStore):
(event.event_id, event.redacts)
)
@defer.inlineCallbacks
def have_events_in_timeline(self, event_ids):
"""Given a list of event ids, check if we have already processed and
stored them as non outliers.
"""
rows = yield self._simple_select_many_batch(
table="events",
retcols=("event_id",),
column="event_id",
iterable=list(event_ids),
keyvalues={"outlier": False},
desc="have_events_in_timeline",
)
defer.returnValue(set(r["event_id"] for r in rows))
def have_events(self, event_ids):
"""Given a list of event ids, check if we have already processed them.
@@ -1076,10 +1137,7 @@ class EventsStore(SQLBaseStore):
def get_current_backfill_token(self):
"""The current minimum token that backfilled events have reached"""
# TODO: Fix race with the persit_event txn by using one of the
# stream id managers
return -self.min_stream_token
return -self._backfill_id_gen.get_current_token()
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
@@ -1087,10 +1145,12 @@ class EventsStore(SQLBaseStore):
new events or as backfilled events"""
def get_all_new_events_txn(txn):
sql = (
"SELECT e.stream_ordering, ej.internal_metadata, ej.json"
"SELECT e.stream_ordering, ej.internal_metadata, ej.json, eg.state_group"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" LEFT JOIN event_to_state_groups as eg"
" ON e.event_id = eg.event_id"
" WHERE ? < e.stream_ordering AND e.stream_ordering <= ?"
" ORDER BY e.stream_ordering ASC"
" LIMIT ?"
@@ -1098,14 +1158,43 @@ class EventsStore(SQLBaseStore):
if last_forward_id != current_forward_id:
txn.execute(sql, (last_forward_id, current_forward_id, limit))
new_forward_events = txn.fetchall()
if len(new_forward_events) == limit:
upper_bound = new_forward_events[-1][0]
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (last_forward_id, upper_bound))
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
state_resets = []
forward_ex_outliers = []
sql = (
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json"
"SELECT -e.stream_ordering, ej.internal_metadata, ej.json,"
" eg.state_group"
" FROM events as e"
" JOIN event_json as ej"
" ON e.event_id = ej.event_id AND e.room_id = ej.room_id"
" LEFT JOIN event_to_state_groups as eg"
" ON e.event_id = eg.event_id"
" WHERE ? > e.stream_ordering AND e.stream_ordering >= ?"
" ORDER BY e.stream_ordering DESC"
" LIMIT ?"
@@ -1113,8 +1202,35 @@ class EventsStore(SQLBaseStore):
if last_backfill_id != current_backfill_id:
txn.execute(sql, (-last_backfill_id, -current_backfill_id, limit))
new_backfill_events = txn.fetchall()
if len(new_backfill_events) == limit:
upper_bound = new_backfill_events[-1][0]
else:
upper_bound = current_backfill_id
sql = (
"SELECT -event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
" WHERE ? > event_stream_ordering"
" AND event_stream_ordering >= ?"
" ORDER BY event_stream_ordering DESC"
)
txn.execute(sql, (-last_backfill_id, -upper_bound))
backward_ex_outliers = txn.fetchall()
else:
new_backfill_events = []
backward_ex_outliers = []
return (new_forward_events, new_backfill_events)
return AllNewEventsResult(
new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers,
state_resets,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers",
"state_resets"
])
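A small usage sketch (made-up values) of the AllNewEventsResult namedtuple defined just above, as a replication consumer might unpack it; the per-row layouts match the SELECTs above.

from collections import namedtuple

AllNewEventsResult = namedtuple("AllNewEventsResult", [
    "new_forward_events", "new_backfill_events",
    "forward_ex_outliers", "backward_ex_outliers",
    "state_resets",
])

result = AllNewEventsResult(
    new_forward_events=[(105, "{}", "{}", 42)],  # stream_ordering, internal_metadata, json, state_group
    new_backfill_events=[],
    forward_ex_outliers=[(104, "$ev:hs", 42)],   # event_stream_ordering, event_id, state_group
    backward_ex_outliers=[],
    state_resets=[(103,)],                       # event_stream_ordering only
)
print(result.state_resets, result[4])  # fields are accessible by name or position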

View File

@@ -25,7 +25,7 @@ class MediaRepositoryStore(SQLBaseStore):
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:
None if the meia_id doesn't exist.
None if the media_id doesn't exist.
"""
return self._simple_select_one(
"local_media_repository",
@@ -50,6 +50,61 @@ class MediaRepositoryStore(SQLBaseStore):
desc="store_local_media",
)
def get_url_cache(self, url, ts):
"""Get the media_id and ts for a cached URL as of the given timestamp
Returns:
None if the URL isn't cached.
"""
def get_url_cache_txn(txn):
# get the most recently cached result (relative to the given ts)
sql = (
"SELECT response_code, etag, expires, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts <= ?"
" ORDER BY download_ts DESC LIMIT 1"
)
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
# ...or if we've requested a timestamp older than the oldest
# copy in the cache, return the oldest copy (if any)
sql = (
"SELECT response_code, etag, expires, og, media_id, download_ts"
" FROM local_media_repository_url_cache"
" WHERE url = ? AND download_ts > ?"
" ORDER BY download_ts ASC LIMIT 1"
)
txn.execute(sql, (url, ts))
row = txn.fetchone()
if not row:
return None
return dict(zip((
'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
), row))
return self.runInteraction(
"get_url_cache", get_url_cache_txn
)
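To make the fallback order in get_url_cache concrete, here is a runnable toy against an in-memory sqlite3 table with made-up rows: the newest cached copy at or before the requested ts wins, otherwise the oldest copy available is returned.

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE local_media_repository_url_cache("
    " url TEXT, response_code INTEGER, etag TEXT, expires INTEGER,"
    " og TEXT, media_id TEXT, download_ts BIGINT)"
)
conn.executemany(
    "INSERT INTO local_media_repository_url_cache VALUES (?,?,?,?,?,?,?)",
    [
        ("http://example.com", 200, None, 3600000, "{}", "mediaA", 1000),
        ("http://example.com", 200, None, 3600000, "{}", "mediaB", 2000),
    ],
)
COLS = ("response_code", "etag", "expires", "og", "media_id", "download_ts")

def get_url_cache(url, ts):
    row = conn.execute(
        "SELECT response_code, etag, expires, og, media_id, download_ts"
        " FROM local_media_repository_url_cache"
        " WHERE url = ? AND download_ts <= ?"
        " ORDER BY download_ts DESC LIMIT 1", (url, ts)).fetchone()
    if not row:
        # requested a timestamp older than anything cached: take the oldest copy
        row = conn.execute(
            "SELECT response_code, etag, expires, og, media_id, download_ts"
            " FROM local_media_repository_url_cache"
            " WHERE url = ? AND download_ts > ?"
            " ORDER BY download_ts ASC LIMIT 1", (url, ts)).fetchone()
    return dict(zip(COLS, row)) if row else None

print(get_url_cache("http://example.com", 1500)["media_id"])  # mediaA
print(get_url_cache("http://example.com", 500)["media_id"])   # mediaA (oldest copy)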
def store_url_cache(self, url, response_code, etag, expires, og, media_id,
download_ts):
return self._simple_insert(
"local_media_repository_url_cache",
{
"url": url,
"response_code": response_code,
"etag": etag,
"expires": expires,
"og": og,
"media_id": media_id,
"download_ts": download_ts,
},
desc="store_url_cache",
)
def get_local_media_thumbnails(self, media_id):
return self._simple_select_list(
"local_media_repository_thumbnails",

View File

@@ -25,23 +25,11 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 30
SCHEMA_VERSION = 31
dir_path = os.path.abspath(os.path.dirname(__file__))
def read_schema(path):
""" Read the named database schema.
Args:
path: Path of the database schema.
Returns:
A string containing the database schema.
"""
with open(path) as schema_file:
return schema_file.read()
class PrepareDatabaseException(Exception):
pass
@@ -53,6 +41,9 @@ class UpgradeDatabaseException(PrepareDatabaseException):
def prepare_database(db_conn, database_engine, config):
"""Prepares a database for usage. Will either create all necessary tables
or upgrade from an older schema version.
If `config` is None then prepare_database will assert that no upgrade is
necessary, *or* will create a fresh database if the database is empty.
"""
try:
cur = db_conn.cursor()
@@ -60,13 +51,18 @@ def prepare_database(db_conn, database_engine, config):
if version_info:
user_version, delta_files, upgraded = version_info
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine, config)
# cur.execute("PRAGMA user_version = %d" % (SCHEMA_VERSION,))
if config is None:
if user_version != SCHEMA_VERSION:
# If we don't pass in a config file then we are expecting to
# have already upgraded the DB.
raise UpgradeDatabaseException("Database needs to be upgraded")
else:
_upgrade_existing_database(
cur, user_version, delta_files, upgraded, database_engine, config
)
else:
_setup_new_database(cur, database_engine)
cur.close()
db_conn.commit()
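The new config=None behaviour of prepare_database can be summarised as a small decision function; a simplified, standalone sketch (names and return strings are illustrative only):

SCHEMA_VERSION = 31

class UpgradeDatabaseException(Exception):
    pass

def decide_action(user_version, config):
    # user_version is None for an empty database (no schema_version row).
    if user_version is None:
        return "setup_new_database"
    if config is None:
        if user_version != SCHEMA_VERSION:
            # Without a config we cannot run upgrades, so refuse.
            raise UpgradeDatabaseException("Database needs to be upgraded")
        return "nothing_to_do"
    return "upgrade_existing_database"

print(decide_action(None, config=None))          # setup_new_database
print(decide_action(31, config=None))            # nothing_to_do
print(decide_action(30, config="dummy-config"))  # upgrade_existing_database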
@@ -75,7 +71,7 @@ def prepare_database(db_conn, database_engine, config):
raise
def _setup_new_database(cur, database_engine, config):
def _setup_new_database(cur, database_engine):
"""Sets up the database by finding a base set of "full schemas" and then
applying any necessary deltas.
@@ -148,12 +144,13 @@ def _setup_new_database(cur, database_engine, config):
applied_delta_files=[],
upgraded=False,
database_engine=database_engine,
config=config,
config=None,
is_empty=True,
)
def _upgrade_existing_database(cur, current_version, applied_delta_files,
upgraded, database_engine, config):
upgraded, database_engine, config, is_empty=False):
"""Upgrades an existing database.
Delta files can either be SQL stored in *.sql files, or python modules
@@ -246,7 +243,9 @@ def _upgrade_existing_database(cur, current_version, applied_delta_files,
module_name, absolute_path, python_file
)
logger.debug("Running script %s", relative_path)
module.run_upgrade(cur, database_engine, config=config)
module.run_create(cur, database_engine)
if not is_empty:
module.run_upgrade(cur, database_engine, config=config)
elif ext == ".pyc":
# Sometimes .pyc files turn up anyway even though we've
# disabled their generation; e.g. from distribution package
@@ -361,36 +360,3 @@ def _get_or_create_schema_state(txn, database_engine):
return current_version, applied_deltas, upgraded
return None
def prepare_sqlite3_database(db_conn):
"""This function should be called before `prepare_database` on sqlite3
databases.
Since we changed the way we store the current schema version and handle
updates to schemas, we need a way to upgrade from the old method to the
new. This only affects sqlite databases since they were the only ones
supported at the time.
"""
with db_conn:
schema_path = os.path.join(
dir_path, "schema", "schema_version.sql",
)
create_schema = read_schema(schema_path)
db_conn.executescript(create_schema)
c = db_conn.execute("SELECT * FROM schema_version")
rows = c.fetchall()
c.close()
if not rows:
c = db_conn.execute("PRAGMA user_version")
row = c.fetchone()
c.close()
if row and row[0]:
db_conn.execute(
"REPLACE INTO schema_version (version, upgraded)"
" VALUES (?,?)",
(row[0], False)
)

View File

@@ -68,7 +68,9 @@ class PresenceStore(SQLBaseStore):
self._update_presence_txn, stream_orderings, presence_states,
)
defer.returnValue((stream_orderings[-1], self._presence_id_gen.get_max_token()))
defer.returnValue((
stream_orderings[-1], self._presence_id_gen.get_current_token()
))
def _update_presence_txn(self, txn, stream_orderings, presence_states):
for stream_id, state in zip(stream_orderings, presence_states):
@@ -155,7 +157,7 @@ class PresenceStore(SQLBaseStore):
defer.returnValue([UserPresenceState(**row) for row in rows])
def get_current_presence_token(self):
return self._presence_id_gen.get_max_token()
return self._presence_id_gen.get_current_token()
def allow_presence_visible(self, observed_localpart, observer_userid):
return self._simple_insert(
@@ -174,16 +176,6 @@ class PresenceStore(SQLBaseStore):
desc="disallow_presence_visible",
)
def is_presence_visible(self, observed_localpart, observer_userid):
return self._simple_select_one(
table="presence_allow_inbound",
keyvalues={"observed_user_id": observed_localpart,
"observer_user_id": observer_userid},
retcols=["observed_user_id"],
allow_none=True,
desc="is_presence_visible",
)
def add_presence_list_pending(self, observer_localpart, observed_userid):
return self._simple_insert(
table="presence_list",

View File

@@ -392,7 +392,7 @@ class PushRuleStore(SQLBaseStore):
"""Get the position of the push rules stream.
Returns a pair of a stream id for the push_rules stream and the
room stream ordering it corresponds to."""
return self._push_rules_stream_id_gen.get_max_token()
return self._push_rules_stream_id_gen.get_current_token()
def have_push_rules_changed_for_user(self, user_id, last_id):
if not self.push_rules_stream_cache.has_entity_changed(user_id, last_id):

View File

@@ -18,6 +18,8 @@ from twisted.internet import defer
from canonicaljson import encode_canonical_json
from synapse.util.caches.descriptors import cachedInlineCallbacks
import logging
import simplejson as json
import types
@@ -48,23 +50,46 @@ class PusherStore(SQLBaseStore):
return rows
@defer.inlineCallbacks
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
def r(txn):
sql = (
"SELECT * FROM pushers"
" WHERE app_id = ? AND pushkey = ?"
)
txn.execute(sql, (app_id, pushkey,))
rows = self.cursor_to_dict(txn)
return self._decode_pushers_rows(rows)
rows = yield self.runInteraction(
"get_pushers_by_app_id_and_pushkey", r
def user_has_pusher(self, user_id):
ret = yield self._simple_select_one_onecol(
"pushers", {"user_name": user_id}, "id", allow_none=True
)
defer.returnValue(ret is not None)
defer.returnValue(rows)
def get_pushers_by_app_id_and_pushkey(self, app_id, pushkey):
return self.get_pushers_by({
"app_id": app_id,
"pushkey": pushkey,
})
def get_pushers_by_user_id(self, user_id):
return self.get_pushers_by({
"user_name": user_id,
})
@defer.inlineCallbacks
def get_pushers_by(self, keyvalues):
ret = yield self._simple_select_list(
"pushers", keyvalues,
[
"id",
"user_name",
"access_token",
"profile_tag",
"kind",
"app_id",
"app_display_name",
"device_display_name",
"pushkey",
"ts",
"lang",
"data",
"last_stream_ordering",
"last_success",
"failing_since",
], desc="get_pushers_by"
)
defer.returnValue(self._decode_pushers_rows(ret))
@defer.inlineCallbacks
def get_all_pushers(self):
@@ -78,7 +103,7 @@ class PusherStore(SQLBaseStore):
defer.returnValue(rows)
def get_pushers_stream_token(self):
return self._pushers_id_gen.get_max_token()
return self._pushers_id_gen.get_current_token()
def get_all_updated_pushers(self, last_id, current_id, limit):
def get_all_updated_pushers_txn(txn):
@@ -107,31 +132,50 @@ class PusherStore(SQLBaseStore):
"get_all_updated_pushers", get_all_updated_pushers_txn
)
@cachedInlineCallbacks(num_args=1)
def get_users_with_pushers_in_room(self, room_id):
users = yield self.get_users_in_room(room_id)
result = yield self._simple_select_many_batch(
table='pushers',
column='user_name',
iterable=users,
retcols=['user_name'],
desc='get_users_with_pushers_in_room'
)
defer.returnValue([r['user_name'] for r in result])
@defer.inlineCallbacks
def add_pusher(self, user_id, access_token, kind, app_id,
app_display_name, device_display_name,
pushkey, pushkey_ts, lang, data, profile_tag=""):
pushkey, pushkey_ts, lang, data, last_stream_ordering,
profile_tag=""):
with self._pushers_id_gen.get_next() as stream_id:
yield self._simple_upsert(
"pushers",
dict(
app_id=app_id,
pushkey=pushkey,
user_name=user_id,
),
dict(
access_token=access_token,
kind=kind,
app_display_name=app_display_name,
device_display_name=device_display_name,
ts=pushkey_ts,
lang=lang,
data=encode_canonical_json(data),
profile_tag=profile_tag,
id=stream_id,
),
desc="add_pusher",
)
def f(txn):
txn.call_after(self.get_users_with_pushers_in_room.invalidate_all)
return self._simple_upsert_txn(
txn,
"pushers",
{
"app_id": app_id,
"pushkey": pushkey,
"user_name": user_id,
},
{
"access_token": access_token,
"kind": kind,
"app_display_name": app_display_name,
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": encode_canonical_json(data),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,
},
)
defer.returnValue((yield self.runInteraction("add_pusher", f)))
@defer.inlineCallbacks
def delete_pusher_by_app_id_pushkey_user_id(self, app_id, pushkey, user_id):
@@ -153,22 +197,28 @@ class PusherStore(SQLBaseStore):
)
@defer.inlineCallbacks
def update_pusher_last_token(self, app_id, pushkey, user_id, last_token):
def update_pusher_last_stream_ordering(self, app_id, pushkey, user_id,
last_stream_ordering):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token},
desc="update_pusher_last_token",
{'last_stream_ordering': last_stream_ordering},
desc="update_pusher_last_stream_ordering",
)
@defer.inlineCallbacks
def update_pusher_last_token_and_success(self, app_id, pushkey, user_id,
last_token, last_success):
def update_pusher_last_stream_ordering_and_success(self, app_id, pushkey,
user_id,
last_stream_ordering,
last_success):
yield self._simple_update_one(
"pushers",
{'app_id': app_id, 'pushkey': pushkey, 'user_name': user_id},
{'last_token': last_token, 'last_success': last_success},
desc="update_pusher_last_token_and_success",
{
'last_stream_ordering': last_stream_ordering,
'last_success': last_success
},
desc="update_pusher_last_stream_ordering_and_success",
)
@defer.inlineCallbacks

View File

@@ -31,7 +31,7 @@ class ReceiptsStore(SQLBaseStore):
super(ReceiptsStore, self).__init__(hs)
self._receipts_stream_cache = StreamChangeCache(
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_max_token()
"ReceiptsRoomChangeCache", self._receipts_id_gen.get_current_token()
)
@cached(num_args=2)
@@ -160,8 +160,8 @@ class ReceiptsStore(SQLBaseStore):
"content": content,
}])
@cachedList(cache=get_linearized_receipts_for_room.cache, list_name="room_ids",
num_args=3, inlineCallbacks=True)
@cachedList(cached_method_name="get_linearized_receipts_for_room",
list_name="room_ids", num_args=3, inlineCallbacks=True)
def _get_linearized_receipts_for_rooms(self, room_ids, to_key, from_key=None):
if not room_ids:
defer.returnValue({})
@@ -221,7 +221,7 @@ class ReceiptsStore(SQLBaseStore):
defer.returnValue(results)
def get_max_receipt_stream_id(self):
return self._receipts_id_gen.get_max_token()
return self._receipts_id_gen.get_current_token()
def insert_linearized_receipt_txn(self, txn, room_id, receipt_type,
user_id, event_id, data, stream_id):
@@ -346,7 +346,7 @@ class ReceiptsStore(SQLBaseStore):
room_id, receipt_type, user_id, event_ids, data
)
max_persisted_id = self._stream_id_gen.get_max_token()
max_persisted_id = self._stream_id_gen.get_current_token()
defer.returnValue((stream_id, max_persisted_id))
@@ -390,16 +390,19 @@ class ReceiptsStore(SQLBaseStore):
}
)
def get_all_updated_receipts(self, last_id, current_id, limit):
def get_all_updated_receipts(self, last_id, current_id, limit=None):
def get_all_updated_receipts_txn(txn):
sql = (
"SELECT stream_id, room_id, receipt_type, user_id, event_id, data"
" FROM receipts_linearized"
" WHERE ? < stream_id AND stream_id <= ?"
" ORDER BY stream_id ASC"
" LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
args = [last_id, current_id]
if limit is not None:
sql += " LIMIT ?"
args.append(limit)
txn.execute(sql, args)
return txn.fetchall()
return self.runInteraction(

View File

@@ -20,7 +20,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError, Codes
from ._base import SQLBaseStore
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
class RegistrationStore(SQLBaseStore):
@@ -319,26 +319,6 @@ class RegistrationStore(SQLBaseStore):
defer.returnValue(res if res else False)
@cachedList(cache=is_guest.cache, list_name="user_ids", num_args=1,
inlineCallbacks=True)
def are_guests(self, user_ids):
sql = "SELECT name, is_guest FROM users WHERE name IN (%s)" % (
",".join("?" for _ in user_ids),
)
rows = yield self._execute(
"are_guests", self.cursor_to_dict, sql, *user_ids
)
result = {user_id: False for user_id in user_ids}
result.update({
row["name"]: bool(row["is_guest"])
for row in rows
})
defer.returnValue(result)
def _query_for_auth(self, txn, token):
sql = (
"SELECT users.name, users.is_guest, access_tokens.id as token_id"
@@ -458,12 +438,15 @@ class RegistrationStore(SQLBaseStore):
"""
Gets the 3pid's guest access token if exists, else saves access_token.
:param medium (str): Medium of the 3pid. Must be "email".
:param address (str): 3pid address.
:param access_token (str): The access token to persist if none is
already persisted.
:param inviter_user_id (str): User ID of the inviter.
:return (deferred str): Whichever access token is persisted at the end
Args:
medium (str): Medium of the 3pid. Must be "email".
address (str): 3pid address.
access_token (str): The access token to persist if none is
already persisted.
inviter_user_id (str): User ID of the inviter.
Returns:
deferred str: Whichever access token is persisted at the end
of this function call.
"""
def insert(txn):

View File

@@ -36,7 +36,7 @@ RoomsForUser = namedtuple(
class RoomMemberStore(SQLBaseStore):
def _store_room_members_txn(self, txn, events):
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
@@ -58,31 +58,72 @@ class RoomMemberStore(SQLBaseStore):
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self.get_users_with_pushers_in_room.invalidate, (event.room_id,)
)
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
def get_room_member(self, user_id, room_id):
"""Retrieve the current state of a room member.
# We update the local_invites table only if the event is "current",
# i.e., it's something that has just happened.
# The only current event that can also be an outlier is if it's an
# invite that has come in across federation.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_invite_from_remote()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
}
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
Args:
user_id (str): The member's user ID.
room_id (str): The room the member is in.
Returns:
Deferred: Results in a MembershipEvent or None.
"""
return self.runInteraction(
"get_room_member",
self._get_members_events_txn,
room_id,
user_id=user_id,
).addCallback(
self._get_events
).addCallback(
lambda events: events[0] if events else None
txn.execute(sql, (
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
))
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (
stream_ordering,
True,
room_id,
user_id,
))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000)
def get_users_in_room(self, room_id):
def f(txn):
@@ -127,18 +168,23 @@ class RoomMemberStore(SQLBaseStore):
user_id, [Membership.INVITE]
)
def get_leave_and_ban_events_for_user(self, user_id):
""" Get all the leave events for a user
@defer.inlineCallbacks
def get_invite_for_user_in_room(self, user_id, room_id):
"""Gets the invite for the given user and room
Args:
user_id (str): The user ID.
user_id (str)
room_id (str)
Returns:
A deferred list of event objects.
Deferred: Resolves to either a RoomsForUser or None if no invite was
found.
"""
return self.get_rooms_for_user_where_membership_is(
user_id, (Membership.LEAVE, Membership.BAN)
).addCallback(lambda leaves: self._get_events([
leave.event_id for leave in leaves
]))
invites = yield self.get_invited_rooms_for_user(user_id)
for invite in invites:
if invite.room_id == room_id:
defer.returnValue(invite)
defer.returnValue(None)
def get_rooms_for_user_where_membership_is(self, user_id, membership_list):
""" Get all the rooms for this user where the membership for this user
@@ -163,29 +209,55 @@ class RoomMemberStore(SQLBaseStore):
def _get_rooms_for_user_where_membership_is_txn(self, txn, user_id,
membership_list):
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
args = [user_id]
args.extend(membership_list)
do_invite = Membership.INVITE in membership_list
membership_list = [m for m in membership_list if m != Membership.INVITE]
sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
results = []
if membership_list:
where_clause = "user_id = ? AND (%s) AND forgotten = 0" % (
" OR ".join(["membership = ?" for _ in membership_list]),
)
txn.execute(sql, args)
return [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
args = [user_id]
args.extend(membership_list)
sql = (
"SELECT m.room_id, m.sender, m.membership, m.event_id, e.stream_ordering"
" FROM current_state_events as c"
" INNER JOIN room_memberships as m"
" ON m.event_id = c.event_id"
" INNER JOIN events as e"
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
) % (where_clause,)
txn.execute(sql, args)
results = [
RoomsForUser(**r) for r in self.cursor_to_dict(txn)
]
if do_invite:
sql = (
"SELECT i.room_id, inviter, i.event_id, e.stream_ordering"
" FROM local_invites as i"
" INNER JOIN events as e USING (event_id)"
" WHERE invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(sql, (user_id,))
results.extend(RoomsForUser(
room_id=r["room_id"],
sender=r["inviter"],
event_id=r["event_id"],
stream_ordering=r["stream_ordering"],
membership=Membership.INVITE,
) for r in self.cursor_to_dict(txn))
return results
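A toy sketch (the namedtuple's field order is assumed; values are invented) of how a local_invites row from the do_invite branch above is surfaced as a RoomsForUser entry with membership "invite":

from collections import namedtuple

RoomsForUser = namedtuple(
    "RoomsForUser", ("room_id", "sender", "membership", "event_id", "stream_ordering")
)

invite_row = {   # as selected from local_invites joined to events
    "room_id": "!room:hs",
    "inviter": "@alice:hs",
    "event_id": "$invite:hs",
    "stream_ordering": 123,
}

entry = RoomsForUser(
    room_id=invite_row["room_id"],
    sender=invite_row["inviter"],
    event_id=invite_row["event_id"],
    stream_ordering=invite_row["stream_ordering"],
    membership="invite",   # stand-in for Membership.INVITE
)
print(entry)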
@cached(max_entries=5000)
def get_joined_hosts_for_room(self, room_id):

View File

@@ -18,7 +18,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, *args, **kwargs):
def run_create(cur, *args, **kwargs):
cur.execute("SELECT id, regex FROM application_services_regex")
for row in cur.fetchall():
try:
@@ -35,3 +35,7 @@ def run_upgrade(cur, *args, **kwargs):
"UPDATE application_services_regex SET regex=? WHERE id=?",
(new_regex, row[0])
)
def run_upgrade(*args, **kwargs):
pass
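This delta (and the ones below) now follows the run_create / run_upgrade split used by prepare_database above: run_create makes the schema changes and always runs, while run_upgrade migrates existing data and is skipped on a brand-new database. A hypothetical skeleton for a future delta under that convention:

def run_create(cur, database_engine, *args, **kwargs):
    # Schema changes: always executed, even on an empty database.
    cur.execute(
        "CREATE TABLE IF NOT EXISTS example_table ("
        " id BIGINT PRIMARY KEY,"
        " value TEXT"
        ")"
    )

def run_upgrade(cur, database_engine, *args, **kwargs):
    # Data migration: only executed when upgrading an existing database;
    # nothing to do for a table that starts out empty.
    pass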

View File

@@ -27,7 +27,7 @@ import logging
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
logger.info("Porting pushers table...")
cur.execute("""
CREATE TABLE IF NOT EXISTS pushers2 (
@@ -74,3 +74,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
cur.execute("DROP TABLE pushers")
cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
logger.info("Moved %d pushers to new table", count)
def run_upgrade(*args, **kwargs):
pass

View File

@@ -43,7 +43,7 @@ SQLITE_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
for statement in get_statements(POSTGRES_TABLE.splitlines()):
cur.execute(statement)
@@ -76,3 +76,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_search", progress_json))
def run_upgrade(*args, **kwargs):
pass

View File

@@ -27,7 +27,7 @@ ALTER_TABLE = (
)
def run_upgrade(cur, database_engine, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
for statement in get_statements(ALTER_TABLE.splitlines()):
cur.execute(statement)
@@ -55,3 +55,7 @@ def run_upgrade(cur, database_engine, *args, **kwargs):
sql = database_engine.convert_param_style(sql)
cur.execute(sql, ("event_origin_server_ts", progress_json))
def run_upgrade(*args, **kwargs):
pass

View File

@@ -18,7 +18,7 @@ from synapse.storage.appservice import ApplicationServiceStore
logger = logging.getLogger(__name__)
def run_upgrade(cur, database_engine, config, *args, **kwargs):
def run_create(cur, database_engine, *args, **kwargs):
# NULL indicates user was not registered by an appservice.
try:
cur.execute("ALTER TABLE users ADD COLUMN appservice_id TEXT")
@@ -26,6 +26,8 @@ def run_upgrade(cur, database_engine, config, *args, **kwargs):
# Maybe we already added the column? Hope so...
pass
def run_upgrade(cur, database_engine, config, *args, **kwargs):
cur.execute("SELECT name FROM users")
rows = cur.fetchall()

View File

@@ -0,0 +1,38 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* The positions in the event stream_ordering when the current_state was
* replaced by the state at the event.
*/
CREATE TABLE IF NOT EXISTS current_state_resets(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL
);
/* The outlier events that have acquired a state group, typically through
 * backfill. This is tracked separately to the events table, as assigning a
 * state group changes the position of the existing event in the stream
* ordering.
* However since a stream_ordering is assigned in persist_event for the
* (event, state) pair, we can use that stream_ordering to identify when
* the new state was assigned for the event.
*/
CREATE TABLE IF NOT EXISTS ex_outlier_stream(
event_stream_ordering BIGINT PRIMARY KEY NOT NULL,
event_id TEXT NOT NULL,
state_group BIGINT NOT NULL
);

View File

@@ -0,0 +1,42 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE local_invites(
stream_id BIGINT NOT NULL,
inviter TEXT NOT NULL,
invitee TEXT NOT NULL,
event_id TEXT NOT NULL,
room_id TEXT NOT NULL,
locally_rejected TEXT,
replaced_by TEXT
);
-- Insert all invites for local users into new `invites` table
INSERT INTO local_invites SELECT
stream_ordering as stream_id,
sender as inviter,
state_key as invitee,
event_id,
room_id,
NULL as locally_rejected,
NULL as replaced_by
FROM events
NATURAL JOIN current_state_events
NATURAL JOIN room_memberships
WHERE membership = 'invite' AND state_key IN (SELECT name FROM users);
CREATE INDEX local_invites_id ON local_invites(stream_id);
CREATE INDEX local_invites_for_user_idx ON local_invites(invitee, locally_rejected, replaced_by, room_id);

View File

@@ -0,0 +1,27 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE TABLE local_media_repository_url_cache(
    url TEXT,              -- the URL being cached
    response_code INTEGER, -- the HTTP response code of this download attempt
    etag TEXT,             -- the etag header of this response
    expires INTEGER,       -- the number of ms this response was valid for
    og TEXT,               -- cache of the OG metadata of this URL as JSON
    media_id TEXT,         -- the media_id, if any, of the URL's content in the repo
    download_ts BIGINT     -- the timestamp of this download attempt
);
CREATE INDEX local_media_repository_url_cache_by_url_download_ts
ON local_media_repository_url_cache(url, download_ts);
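A sketch of how a URL previewer might consult this cache before re-downloading. It assumes download_ts is in milliseconds so it can be compared against expires (the schema only states that expires is in ms); the helper name and data are illustrative only:

import json
import sqlite3
import time

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE local_media_repository_url_cache(
        url TEXT, response_code INTEGER, etag TEXT, expires INTEGER,
        og TEXT, media_id TEXT, download_ts BIGINT
    )
""")
now_ms = int(time.time() * 1000)
conn.execute(
    "INSERT INTO local_media_repository_url_cache VALUES (?, ?, ?, ?, ?, ?, ?)",
    ("https://example.com/", 200, None, 60 * 60 * 1000,
     json.dumps({"og:title": "Example"}), "abc123", now_ms),
)

def cached_preview(url):
    # Most recent attempt for this URL, served by the (url, download_ts) index.
    row = conn.execute(
        "SELECT og, media_id, download_ts, expires"
        " FROM local_media_repository_url_cache"
        " WHERE url = ? ORDER BY download_ts DESC LIMIT 1",
        (url,),
    ).fetchone()
    if row is None:
        return None
    og, media_id, download_ts, expires = row
    if download_ts + expires < int(time.time() * 1000):
        return None  # stale; caller should re-download
    return json.loads(og)

print(cached_preview("https://example.com/"))  # {'og:title': 'Example'}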

View File

@@ -0,0 +1,79 @@
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Change the last_token to last_stream_ordering now that pushers no longer
# listen on an event stream but instead select out of the event_push_actions
# table.
import logging

logger = logging.getLogger(__name__)


def token_to_stream_ordering(token):
    return int(token[1:].split('_')[0])


def run_create(cur, database_engine, *args, **kwargs):
    logger.info("Porting pushers table, delta 31...")
    cur.execute("""
        CREATE TABLE IF NOT EXISTS pushers2 (
            id BIGINT PRIMARY KEY,
            user_name TEXT NOT NULL,
            access_token BIGINT DEFAULT NULL,
            profile_tag VARCHAR(32) NOT NULL,
            kind VARCHAR(8) NOT NULL,
            app_id VARCHAR(64) NOT NULL,
            app_display_name VARCHAR(64) NOT NULL,
            device_display_name VARCHAR(128) NOT NULL,
            pushkey TEXT NOT NULL,
            ts BIGINT NOT NULL,
            lang VARCHAR(8),
            data TEXT,
            last_stream_ordering INTEGER,
            last_success BIGINT,
            failing_since BIGINT,
            UNIQUE (app_id, pushkey, user_name)
        )
    """)
    cur.execute("""SELECT
        id, user_name, access_token, profile_tag, kind,
        app_id, app_display_name, device_display_name,
        pushkey, ts, lang, data, last_token, last_success,
        failing_since
        FROM pushers
    """)
    count = 0
    for row in cur.fetchall():
        row = list(row)
        row[12] = token_to_stream_ordering(row[12])
        cur.execute(database_engine.convert_param_style("""
            INSERT into pushers2 (
                id, user_name, access_token, profile_tag, kind,
                app_id, app_display_name, device_display_name,
                pushkey, ts, lang, data, last_stream_ordering, last_success,
                failing_since
            ) values (%s)""" % (','.join(['?' for _ in range(len(row))]))),
            row
        )
        count += 1
    cur.execute("DROP TABLE pushers")
    cur.execute("ALTER TABLE pushers2 RENAME TO pushers")
    logger.info("Moved %d pushers to new table", count)


def run_upgrade(cur, database_engine, *args, **kwargs):
    pass
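For reference, token_to_stream_ordering above drops the leading character of last_token and keeps everything before the first underscore. Assuming stream tokens of the form "s<stream_ordering>" or "s<stream_ordering>_<rest>" (an assumption of this note, not spelled out by the delta itself), the conversion behaves like this:

def token_to_stream_ordering(token):
    return int(token[1:].split('_')[0])

assert token_to_stream_ordering("s1234") == 1234
assert token_to_stream_ordering("s1234_56") == 1234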

View File

@@ -0,0 +1,18 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
CREATE INDEX event_push_actions_stream_ordering on event_push_actions(
stream_ordering, user_id
);
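This index supports the new pusher flow described in the delta-31 port above: instead of following the event stream, a pusher selects the push actions for its user beyond its last_stream_ordering. A rough sqlite3 sketch of that shape; the event_id and actions columns here are assumptions for illustration, only stream_ordering and user_id come from the index definition itself:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("""
    CREATE TABLE event_push_actions(
        user_id TEXT, event_id TEXT, stream_ordering BIGINT, actions TEXT
    )
""")
conn.execute("CREATE INDEX event_push_actions_stream_ordering ON"
             " event_push_actions(stream_ordering, user_id)")
conn.executemany(
    "INSERT INTO event_push_actions VALUES (?, ?, ?, ?)",
    [("@alice:example.com", "$a:example.com", 10, '["notify"]'),
     ("@alice:example.com", "$b:example.com", 12, '["notify"]')],
)

# What a pusher that last delivered up to stream_ordering 10 would fetch next:
last_stream_ordering = 10
rows = conn.execute(
    "SELECT event_id, stream_ordering, actions FROM event_push_actions"
    " WHERE stream_ordering > ? AND user_id = ?"
    " ORDER BY stream_ordering ASC",
    (last_stream_ordering, "@alice:example.com"),
).fetchall()
print(rows)  # [('$b:example.com', 12, '["notify"]')]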

View File

@@ -64,12 +64,12 @@ class StateStore(SQLBaseStore):
for group, state_map in group_to_state.items()
})
def _store_state_groups_txn(self, txn, event, context):
return self._store_mult_state_groups_txn(txn, [(event, context)])
def _store_mult_state_groups_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
if context.current_state is None:
continue
@@ -82,7 +82,8 @@ class StateStore(SQLBaseStore):
if event.is_state():
state_events[(event.type, event.state_key)] = event
state_group = self._state_groups_id_gen.get_next()
state_group = context.new_state_group_id
self._simple_insert_txn(
txn,
table="state_groups",
@@ -114,11 +115,10 @@ class StateStore(SQLBaseStore):
table="event_to_state_groups",
values=[
{
"state_group": state_groups[event.event_id],
"event_id": event.event_id,
"state_group": state_group_id,
"event_id": event_id,
}
for event, context in events_and_contexts
if context.current_state is not None
for event_id, state_group_id in state_groups.items()
],
)
@@ -174,6 +174,12 @@ class StateStore(SQLBaseStore):
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, lru=True, max_entries=1000)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@cachedList(cached_method_name="_get_state_group_from_group",
list_name="groups", num_args=2, inlineCallbacks=True)
def _get_state_groups_from_groups(self, groups, types):
"""Returns dictionary state_group -> (dict of (type, state_key) -> event id)
"""
@@ -201,18 +207,23 @@ class StateStore(SQLBaseStore):
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
results = {}
results = {group: {} for group in groups}
for row in rows:
key = (row["type"], row["state_key"])
results.setdefault(row["state_group"], {})[key] = row["event_id"]
results[row["state_group"]][key] = row["event_id"]
return results
results = {}
chunks = [groups[i:i + 100] for i in xrange(0, len(groups), 100)]
for chunk in chunks:
return self.runInteraction(
res = yield self.runInteraction(
"_get_state_groups_from_groups",
f, chunk
)
results.update(res)
defer.returnValue(results)
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, types):
@@ -249,11 +260,14 @@ class StateStore(SQLBaseStore):
"""
Get the state dict corresponding to a particular event
:param str event_id: event whose state should be returned
:param list[(str, str)]|None types: List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
:return: a deferred dict from (type, state_key) -> state_event
Args:
event_id(str): event whose state should be returned
types(list[(str, str)]|None): List of (type, state_key) tuples
which are used to filter the state fetched. May be None, which
matches any key
Returns:
A deferred dict from (type, state_key) -> state_event
"""
state_map = yield self.get_state_for_events([event_id], types)
defer.returnValue(state_map[event_id])
@@ -270,8 +284,8 @@ class StateStore(SQLBaseStore):
desc="_get_state_group_for_event",
)
@cachedList(cache=_get_state_group_for_event.cache, list_name="event_ids",
num_args=1, inlineCallbacks=True)
@cachedList(cached_method_name="_get_state_group_for_event",
list_name="event_ids", num_args=1, inlineCallbacks=True)
def _get_state_group_for_events(self, event_ids):
"""Returns mapping event_id -> state_group
"""
@@ -356,6 +370,8 @@ class StateStore(SQLBaseStore):
a `state_key` of None matches all state_keys. If `types` is None then
all events are returned.
"""
if types:
types = frozenset(types)
results = {}
missing_groups = []
if types is not None:
@@ -429,3 +445,33 @@ class StateStore(SQLBaseStore):
}
defer.returnValue(results)
def get_all_new_state_groups(self, last_id, current_id, limit):
def get_all_new_state_groups_txn(txn):
sql = (
"SELECT id, room_id, event_id FROM state_groups"
" WHERE ? < id AND id <= ? ORDER BY id LIMIT ?"
)
txn.execute(sql, (last_id, current_id, limit))
groups = txn.fetchall()
if not groups:
return ([], [])
lower_bound = groups[0][0]
upper_bound = groups[-1][0]
sql = (
"SELECT state_group, type, state_key, event_id"
" FROM state_groups_state"
" WHERE ? <= state_group AND state_group <= ?"
)
txn.execute(sql, (lower_bound, upper_bound))
state_group_state = txn.fetchall()
return (groups, state_group_state)
return self.runInteraction(
"get_all_new_state_groups", get_all_new_state_groups_txn
)
def get_state_stream_token(self):
return self._state_groups_id_gen.get_current_token()
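get_state_stream_token and get_all_new_state_groups together give a replication consumer a simple polling loop: remember the highest state group id already handled, then fetch everything up to the current token. A hedged Twisted-style sketch; store, last_seen_id and the batch limit are placeholders, and error handling is omitted:

from twisted.internet import defer

@defer.inlineCallbacks
def poll_state_groups(store, last_seen_id, limit=100):
    # Upper bound: every state group with an id <= this has been persisted.
    current_id = store.get_state_stream_token()
    if current_id == last_seen_id:
        defer.returnValue((last_seen_id, [], []))

    # groups: (id, room_id, event_id) rows for the newly created state groups.
    # group_state: (state_group, type, state_key, event_id) rows for their state.
    groups, group_state = yield store.get_all_new_state_groups(
        last_seen_id, current_id, limit
    )
    # If the limit was hit, only advance as far as the rows actually returned.
    new_last_seen = groups[-1][0] if groups else current_id
    defer.returnValue((new_last_seen, groups, group_state))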

View File

@@ -303,96 +303,6 @@ class StreamStore(SQLBaseStore):
defer.returnValue(ret)
def get_room_events_stream(
self,
user_id,
from_key,
to_key,
limit=0,
is_guest=False,
room_ids=None
):
room_ids = room_ids or []
room_ids = [r for r in room_ids]
if is_guest:
current_room_membership_sql = (
"SELECT c.room_id FROM history_visibility AS h"
" INNER JOIN current_state_events AS c"
" ON h.event_id = c.event_id"
" WHERE c.room_id IN (%s)"
" AND h.history_visibility = 'world_readable'" % (
",".join(map(lambda _: "?", room_ids))
)
)
current_room_membership_args = room_ids
else:
current_room_membership_sql = (
"SELECT m.room_id FROM room_memberships as m "
" INNER JOIN current_state_events as c"
" ON m.event_id = c.event_id AND c.state_key = m.user_id"
" WHERE m.user_id = ? AND m.membership = 'join'"
)
current_room_membership_args = [user_id]
# We also want to get any membership events about that user, e.g.
# invites or leave notifications.
membership_sql = (
"SELECT m.event_id FROM room_memberships as m "
"INNER JOIN current_state_events as c ON m.event_id = c.event_id "
"WHERE m.user_id = ? "
)
membership_args = [user_id]
if limit:
limit = max(limit, MAX_STREAM_SIZE)
else:
limit = MAX_STREAM_SIZE
# From and to keys should be integers from ordering.
from_id = RoomStreamToken.parse_stream_token(from_key)
to_id = RoomStreamToken.parse_stream_token(to_key)
if from_key == to_key:
return defer.succeed(([], to_key))
sql = (
"SELECT e.event_id, e.stream_ordering FROM events AS e WHERE "
"(e.outlier = ? AND (room_id IN (%(current)s)) OR "
"(event_id IN (%(invites)s))) "
"AND e.stream_ordering > ? AND e.stream_ordering <= ? "
"ORDER BY stream_ordering ASC LIMIT %(limit)d "
) % {
"current": current_room_membership_sql,
"invites": membership_sql,
"limit": limit
}
def f(txn):
args = ([False] + current_room_membership_args + membership_args +
[from_id.stream, to_id.stream])
txn.execute(sql, args)
rows = self.cursor_to_dict(txn)
ret = self._get_events_txn(
txn,
[r["event_id"] for r in rows],
get_prev_content=True
)
self._set_before_and_after(ret, rows)
if rows:
key = "s%d" % max(r["stream_ordering"] for r in rows)
else:
# Assume we didn't get anything because there was nothing to
# get.
key = to_key
return ret, key
return self.runInteraction("get_room_events_stream", f)
@defer.inlineCallbacks
def paginate_room_events(self, room_id, from_key, to_key=None,
direction='b', limit=-1):
@@ -539,7 +449,7 @@ class StreamStore(SQLBaseStore):
@defer.inlineCallbacks
def get_room_events_max_id(self, direction='f'):
token = yield self._stream_id_gen.get_max_token()
token = yield self._stream_id_gen.get_current_token()
if direction != 'b':
defer.returnValue("s%d" % (token,))
else:

View File

@@ -30,7 +30,7 @@ class TagsStore(SQLBaseStore):
Returns:
A deferred int.
"""
return self._account_data_id_gen.get_max_token()
return self._account_data_id_gen.get_current_token()
@cached()
def get_tags_for_user(self, user_id):
@@ -200,7 +200,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
@defer.inlineCallbacks
@@ -222,7 +222,7 @@ class TagsStore(SQLBaseStore):
self.get_tags_for_user.invalidate((user_id,))
result = self._account_data_id_gen.get_max_token()
result = self._account_data_id_gen.get_current_token()
defer.returnValue(result)
def _update_revision_txn(self, txn, user_id, room_id, next_id):

View File

@@ -21,7 +21,7 @@ import threading
class IdGenerator(object):
def __init__(self, db_conn, table, column):
self._lock = threading.Lock()
self._next_id = _load_max_id(db_conn, table, column)
self._next_id = _load_current_id(db_conn, table, column)
def get_next(self):
with self._lock:
@@ -29,12 +29,16 @@ class IdGenerator(object):
return self._next_id
def _load_max_id(db_conn, table, column):
def _load_current_id(db_conn, table, column, step=1):
cur = db_conn.cursor()
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
if step == 1:
cur.execute("SELECT MAX(%s) FROM %s" % (column, table,))
else:
cur.execute("SELECT MIN(%s) FROM %s" % (column, table,))
val, = cur.fetchone()
cur.close()
return int(val) if val else 1
current_id = int(val) if val else step
return (max if step > 0 else min)(current_id, step)
class StreamIdGenerator(object):
@@ -45,17 +49,32 @@ class StreamIdGenerator(object):
all ids less than or equal to it have completed. This handles the fact that
persistence of events can complete out of order.
Args:
db_conn(connection): A database connection to use to fetch the
initial value of the generator from.
table(str): A database table to read the initial value of the id
generator from.
column(str): The column of the database table to read the initial
value of the id generator from.
extra_tables(list): List of (table, column) pairs to additionally
source the initial value of the generator from. The value with the
largest magnitude is used.
step(int): which direction the stream ids grow in. +1 to grow
upwards, -1 to grow downwards.
Usage:
with stream_id_gen.get_next() as stream_id:
# ... persist event ...
"""
def __init__(self, db_conn, table, column, extra_tables=[]):
def __init__(self, db_conn, table, column, extra_tables=[], step=1):
assert step != 0
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._step = step
self._current = _load_current_id(db_conn, table, column, step)
for table, column in extra_tables:
self._current_max = max(
self._current_max,
_load_max_id(db_conn, table, column)
self._current = (max if step > 0 else min)(
self._current,
_load_current_id(db_conn, table, column, step)
)
self._unfinished_ids = deque()
@@ -66,8 +85,8 @@ class StreamIdGenerator(object):
# ... persist event ...
"""
with self._lock:
self._current_max += 1
next_id = self._current_max
self._current += self._step
next_id = self._current
self._unfinished_ids.append(next_id)
@@ -88,8 +107,12 @@ class StreamIdGenerator(object):
# ... persist events ...
"""
with self._lock:
next_ids = range(self._current_max + 1, self._current_max + n + 1)
self._current_max += n
next_ids = range(
self._current + self._step,
self._current + self._step * (n + 1),
self._step
)
self._current += n * self._step
for next_id in next_ids:
self._unfinished_ids.append(next_id)
@@ -105,15 +128,15 @@ class StreamIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
with self._lock:
if self._unfinished_ids:
return self._unfinished_ids[0] - 1
return self._unfinished_ids[0] - self._step
return self._current_max
return self._current
class ChainedIdGenerator(object):
@@ -125,7 +148,7 @@ class ChainedIdGenerator(object):
def __init__(self, chained_generator, db_conn, table, column):
self.chained_generator = chained_generator
self._lock = threading.Lock()
self._current_max = _load_max_id(db_conn, table, column)
self._current_max = _load_current_id(db_conn, table, column)
self._unfinished_ids = deque()
def get_next(self):
@@ -137,7 +160,7 @@ class ChainedIdGenerator(object):
with self._lock:
self._current_max += 1
next_id = self._current_max
chained_id = self.chained_generator.get_max_token()
chained_id = self.chained_generator.get_current_token()
self._unfinished_ids.append((next_id, chained_id))
@@ -151,7 +174,7 @@ class ChainedIdGenerator(object):
return manager()
def get_max_token(self):
def get_current_token(self):
"""Returns the maximum stream id such that all stream ids less than or
equal to it have been successfully persisted.
"""
@@ -160,4 +183,4 @@ class ChainedIdGenerator(object):
stream_id, chained_id = self._unfinished_ids[0]
return (stream_id - 1, chained_id)
return (self._current_max, self.chained_generator.get_max_token())
return (self._current_max, self.chained_generator.get_current_token())
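To make the step handling concrete, here is a small worked example of the downward-growing case (step=-1). The class below only mirrors the get_next_mult()/get_current_token() arithmetic from the diff above with invented numbers; it is not the real StreamIdGenerator:

class MiniStreamIdGen(object):
    """Mirrors the get_next_mult()/get_current_token() arithmetic only."""

    def __init__(self, current, step):
        assert step != 0
        self._current = current
        self._step = step
        self._unfinished_ids = []

    def get_next_mult(self, n):
        next_ids = range(
            self._current + self._step,
            self._current + self._step * (n + 1),
            self._step,
        )
        self._current += n * self._step
        self._unfinished_ids.extend(next_ids)
        return list(next_ids)

    def get_current_token(self):
        if self._unfinished_ids:
            return self._unfinished_ids[0] - self._step
        return self._current

gen = MiniStreamIdGen(current=-3, step=-1)
print(gen.get_next_mult(2))     # [-4, -5]
print(gen.get_current_token())  # -3: nothing "persisted" yet, token stays put
gen._unfinished_ids = []        # pretend both ids finished persisting
print(gen.get_current_token())  # -5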

Some files were not shown because too many files have changed in this diff.