1
0

Compare commits

...

226 Commits

Author SHA1 Message Date
Erik Johnston
85f28744e8 Fix /keys/changes TypeError
We also use the new cache of users who share rooms with cache.
2017-02-02 13:07:52 +00:00
Erik Johnston
54a79c1d37 Make presence.get_new_events a bit faster
We do this by caching the set of users a user shares rooms with.
2017-02-02 13:07:18 +00:00
Erik Johnston
14d5e22700 Merge pull request #1872 from matrix-org/erikj/key_changes
Include newly joined users in /keys/changes API
2017-02-01 18:10:33 +00:00
Erik Johnston
fbfe44bb4d Doc args 2017-02-01 17:52:57 +00:00
Erik Johnston
d61a04583e Comment 2017-02-01 17:35:23 +00:00
Erik Johnston
7e919bdbd0 Include newly joined users in /keys/changes API 2017-02-01 17:33:16 +00:00
Erik Johnston
96355d2f2f Merge pull request #1871 from matrix-org/erikj/ratelimit_401
Correctly raise exceptions for ratelimitng. Ratelimit on 401
2017-02-01 15:56:16 +00:00
Erik Johnston
df4ecff5a9 Correctly raise exceptions for ratelimitng. Ratelimit on 401 2017-02-01 15:42:19 +00:00
Erik Johnston
6d6591880e Wake sync up for device changes 2017-02-01 15:15:16 +00:00
Erik Johnston
bd84387ac6 Merge pull request #1869 from matrix-org/erikj/device_list_stream
Implement /keys/changes
2017-02-01 13:25:26 +00:00
Erik Johnston
ebfaff84c9 Merge pull request #1870 from matrix-org/erikj/cache_get_all_new_events
Add a small cache get_all_new_events
2017-02-01 13:22:02 +00:00
Erik Johnston
73d676dc8b Comment 2017-02-01 13:17:17 +00:00
Erik Johnston
62f6b86ba7 Merge pull request #1868 from matrix-org/erikj/replication_cache
Only invalidate membership caches based on the cache stream
2017-02-01 13:12:30 +00:00
Erik Johnston
f6124311fd Add m.room.member type to query 2017-02-01 11:59:17 +00:00
Erik Johnston
88a4d54883 Merge pull request #1867 from matrix-org/erikj/member_index
Add an index to make membership queries faster
2017-02-01 11:44:27 +00:00
Erik Johnston
368c88c487 Add a small cache get_all_new_events 2017-02-01 10:50:44 +00:00
Erik Johnston
5deaf9e30b Up get_latest_event_ids_in_room cache 2017-02-01 10:39:41 +00:00
Erik Johnston
acb501c46d Comment 2017-02-01 10:32:49 +00:00
Erik Johnston
97479d0c54 Implement /keys/changes 2017-02-01 10:30:03 +00:00
Erik Johnston
06567ec513 Merge pull request #1866 from matrix-org/erikj/device_list_fixes
Better handle 404 response for federation /send/
2017-02-01 09:44:14 +00:00
Erik Johnston
692daf6f54 Remote membership tests for replication
This is because it now relies of the caches stream, which only works on
postgres. We are trying to test with sqlite.
2017-01-31 16:10:16 +00:00
Erik Johnston
458b6f4733 Only invalidate membership caches based on the cache stream
Before we completely invalidated get_users_in_room whenever we updated
any current_state_events table. This was way too aggressive.
2017-01-31 16:09:03 +00:00
Erik Johnston
fe08db2713 Remove explicit < 400 check as apparently this is confusing 2017-01-31 15:21:32 +00:00
Erik Johnston
21b7375778 Add an index to make membership queries faster 2017-01-31 15:15:57 +00:00
Erik Johnston
4c0ec15bdc Comment 2017-01-31 13:53:46 +00:00
Erik Johnston
85c590105f Comment 2017-01-31 13:46:38 +00:00
Erik Johnston
ae7a132f38 Better handle 404 response for federation /send/ 2017-01-31 13:40:09 +00:00
Erik Johnston
ac001dabdc Merge pull request #1864 from matrix-org/erikj/device_list_fixes
Fix clearing out old device list outbound pokes
2017-01-31 13:35:35 +00:00
Erik Johnston
bfb3d255b1 Merge pull request #1862 from matrix-org/erikj/presence_update
Use DB cache of joined users for presence
2017-01-31 13:23:24 +00:00
Erik Johnston
ab55794b6f Fix deletion of old sent devices correctly 2017-01-31 13:22:41 +00:00
Erik Johnston
d3169e8d28 Only fetch with row ts and count > 1 2017-01-31 11:20:03 +00:00
Erik Johnston
05b9f48ee5 Fix clearing out old device list outbound pokes 2017-01-31 10:08:55 +00:00
Erik Johnston
4c9812f5da Merge pull request #1861 from matrix-org/erikj/device_list_fixes
Device List fixes
2017-01-30 17:56:19 +00:00
Erik Johnston
4b3403ca9b Stream cache invalidations for room membership storage functions 2017-01-30 17:28:22 +00:00
Erik Johnston
1c13c9f6b6 Don't have such a large cache 2017-01-30 17:12:14 +00:00
Erik Johnston
c7a26b7c32 Fix unit tests 2017-01-30 17:11:24 +00:00
Erik Johnston
fd1c18c088 Use DB cache of joined users for presence 2017-01-30 17:00:24 +00:00
Erik Johnston
c2c9a78db9 Noop device key changes if they're the same 2017-01-30 16:55:04 +00:00
Erik Johnston
e75a779d9e Fix query 2017-01-30 16:38:20 +00:00
Erik Johnston
828db669ec Use get_users_in_room and declare it iterable 2017-01-30 16:37:22 +00:00
Erik Johnston
9636b2407d Merge pull request #1857 from matrix-org/erikj/device_list_stream
Implement device lists updates over federation
2017-01-30 14:35:21 +00:00
Erik Johnston
3670025e64 Rename func 2017-01-30 14:11:31 +00:00
Erik Johnston
4ac363a168 Remove debug logging 2017-01-30 14:10:12 +00:00
Erik Johnston
d360c97ae1 Clear out old destination pokes. 2017-01-30 10:14:37 +00:00
Erik Johnston
76100203ab Always use the latest stream_id, sent or unsent 2017-01-30 10:14:25 +00:00
Erik Johnston
d1e1fd6210 Add ts column to device_lists_outbound_pokes 2017-01-27 15:23:48 +00:00
Erik Johnston
252b503fc8 Hook device list updates to replication 2017-01-27 14:31:35 +00:00
Erik Johnston
84a35f32c7 Comment 2017-01-27 10:35:12 +00:00
Erik Johnston
c517a19c2d Comment 2017-01-27 10:33:26 +00:00
Erik Johnston
738a2867c8 SQL param ordering 2017-01-27 10:31:29 +00:00
Erik Johnston
755adff0e4 User if rather than for 2017-01-27 10:31:06 +00:00
Erik Johnston
888c59c955 Better name 2017-01-27 10:29:47 +00:00
Erik Johnston
f25a4a4692 Remove unused param 2017-01-27 10:27:39 +00:00
Erik Johnston
b3e1f2aa7a Fix unit tests 2017-01-26 17:16:24 +00:00
Erik Johnston
31aca5589c Fix on sqlite: use left rather than outer join 2017-01-26 16:55:50 +00:00
Erik Johnston
76d40f4904 Handle users leaving rooms 2017-01-26 16:39:33 +00:00
Erik Johnston
fbfad76c03 Add comments 2017-01-26 16:33:21 +00:00
Erik Johnston
c974116f19 Implement device key caching over federation 2017-01-26 16:07:24 +00:00
Paul Evans
e978247fe5 Merge pull request #1852 from matrix-org/paul/issue-1382
Don't clobber a displayname or avatar_url if provided by an m.room.member event
2017-01-25 18:15:19 +00:00
Erik Johnston
51e9fe36e4 Fix up sending of m.device_list_update edus 2017-01-25 16:55:21 +00:00
Erik Johnston
2367c5568c Add basic implementation of local device list changes 2017-01-25 14:27:27 +00:00
Paul "LeoNerd" Evans
10e48d8310 Don't clobber a displayname or avatar_url if provided by an m.room.member event 2017-01-24 18:06:07 +00:00
Erik Johnston
ba8e144554 Merge branch 'erikj/current_state_fix' into develop 2017-01-23 16:15:10 +00:00
Erik Johnston
f5b46482f4 Merge pull request #1840 from matrix-org/erikj/current_state_fix
Insert delta of current_state_events to be more efficient
2017-01-23 16:14:34 +00:00
Erik Johnston
fdf2a31a51 Typo 2017-01-23 16:14:14 +00:00
Erik Johnston
c77b24c092 Refactor to calculate state delta outside transaction 2017-01-23 14:51:33 +00:00
Erik Johnston
5d2134d485 Comments 2017-01-20 17:13:24 +00:00
Erik Johnston
a55fa2047f Insert delta of current_state_events to be more efficient 2017-01-20 17:10:18 +00:00
Erik Johnston
3d9d48fffb Merge pull request #1836 from matrix-org/erikj/current_state_fix
Derive current_state_events from state groups
2017-01-20 15:14:05 +00:00
Richard van der Hoff
a0d03f2e15 Merge pull request #1837 from matrix-org/rav/fix_purge_media_doc
fix doc for purge_media_cache
2017-01-20 15:12:34 +00:00
Erik Johnston
d0897dead5 Spelling 2017-01-20 15:05:11 +00:00
Erik Johnston
567aa35b67 Update all call sites after rename 2017-01-20 14:40:31 +00:00
Erik Johnston
f2f40e64a9 Comments 2017-01-20 14:38:13 +00:00
Erik Johnston
4c6a31cd6e Calculate the forward extremeties once 2017-01-20 14:28:53 +00:00
Richard van der Hoff
83333498a5 fix doc for purge_media_cache
purge_media_cache takes its arg from a query-param, not the POST body, for some
reason.
2017-01-20 12:15:50 +00:00
Erik Johnston
86063d4321 Merge pull request #1835 from matrix-org/erikj/fix_workers
Make worker listener config backwards compat
2017-01-20 11:55:56 +00:00
Erik Johnston
09eb08f910 Derive current_state_events from state groups 2017-01-20 11:52:51 +00:00
Erik Johnston
97efe99ae9 Make worker listener config backwards compat 2017-01-20 11:45:29 +00:00
David Baker
691c8198b7 Merge pull request #1832 from xsteadfastx/xsteadfastx/turn-username-password
Added username and password for turn server
2017-01-19 14:28:31 +00:00
Marvin Steadfast
86e6165687 Added default config for turn username and password 2017-01-19 14:35:55 +01:00
Marvin Steadfast
1e38be3a7a Added username and password for turn server
It makes it possible to use a turn server that needs a username and
password instead of a token.
2017-01-19 14:08:20 +01:00
Erik Johnston
841c228533 Merge pull request #1828 from matrix-org/erikj/iterable_cache_size
Update LruCache size estimate on clear
2017-01-18 14:57:54 +00:00
Erik Johnston
c430111d0e Update LruCache size estimate on clear 2017-01-18 14:55:23 +00:00
David Baker
97d3918377 Merge pull request #1811 from aperezdc/unhardcode-riot-urls
Allow configuring the Riot URL used in notification emails
2017-01-18 14:38:49 +00:00
David Baker
6f6bf2a1eb Merge pull request #1827 from matrix-org/dbkr/email_case_insensitive
Lowercase all email addresses before querying db
2017-01-18 13:39:47 +00:00
David Baker
8c5009b628 Lowercase all email addresses before querying db
Since we store all emails in the DB in lowercase
(https://github.com/matrix-org/synapse/pull/1170)
2017-01-18 13:25:56 +00:00
Erik Johnston
ae7b4da4cc Merge pull request #1823 from matrix-org/erikj/load_events_logs
Remove loading events logs
2017-01-18 11:07:58 +00:00
Erik Johnston
fc7cae8aa3 Merge pull request #1824 from matrix-org/erikj/retry_host_log
Lower the not retrying host log line to debug
2017-01-18 11:07:51 +00:00
Erik Johnston
f9058ca785 Merge pull request #1822 from matrix-org/erikj/statE_logging
Change resolve_state_groups call site logging to DEBUG
2017-01-18 11:02:03 +00:00
Erik Johnston
f648313f98 Merge pull request #1821 from matrix-org/erikj/cache_metrics_string_intern
Measure metrics of string_cache
2017-01-18 10:57:39 +00:00
Erik Johnston
15f012032c Merge pull request #1818 from matrix-org/erikj/state_auth_splitout_split
Optimise state resolution
2017-01-18 10:53:00 +00:00
Erik Johnston
4ec1cf49e2 Lower loading events log to DEBUG 2017-01-17 17:28:32 +00:00
Erik Johnston
f878f64f43 Lower the not retrying host log line to debug 2017-01-17 17:20:39 +00:00
Erik Johnston
5f027d1fc5 Change resolve_state_groups call site logging to DEBUG 2017-01-17 17:07:15 +00:00
Erik Johnston
380dba1020 Measure metrics of string_cache 2017-01-17 17:04:46 +00:00
Erik Johnston
ed4d176152 PEP8 2017-01-17 15:27:28 +00:00
Mark Haines
c6064a7ba6 Only construct sets when necessary 2017-01-17 15:23:07 +00:00
Erik Johnston
a8594fd19f Use better names 2017-01-17 14:59:03 +00:00
Erik Johnston
7fae460402 Merge pull request #1820 from matrix-org/erikj/push_tools
Get state at event rather than for room in push
2017-01-17 14:58:26 +00:00
Erik Johnston
37b4c7d8a9 Fix typo in return type 2017-01-17 14:43:32 +00:00
Erik Johnston
e5d2df9c34 Use better variable name 2017-01-17 14:32:53 +00:00
Erik Johnston
04006bb7f0 Get state at event rather than for room in push 2017-01-17 14:31:21 +00:00
Erik Johnston
ce59a2faad Correctly handle case of rejected events in state res 2017-01-17 14:18:53 +00:00
Erik Johnston
633f97151c Check event is in state_map 2017-01-17 13:33:54 +00:00
Erik Johnston
e6153e1bd1 Fix couple of federation state bugs 2017-01-17 13:22:34 +00:00
Erik Johnston
5d6bad1b3c Optimise state resolution 2017-01-17 13:22:19 +00:00
Erik Johnston
e8ecbb6f20 Merge pull request #1812 from matrix-org/erikj/state_auth_splitout_split
Split out static state methods from StateHandler
2017-01-17 11:55:18 +00:00
Erik Johnston
d11d7cdf87 Merge pull request #1815 from matrix-org/erikj/iter_cache_size
Optionally measure size of cache by sum of length of values
2017-01-17 11:51:09 +00:00
Erik Johnston
9e8e236d98 Tidy up test 2017-01-17 11:50:18 +00:00
Erik Johnston
d6c75cb7c2 Rename and comment tree_to_leaves_iterator 2017-01-17 11:47:03 +00:00
Erik Johnston
1ccd5676e3 Remove needless call to evict() 2017-01-17 11:42:26 +00:00
Erik Johnston
d906206049 Increase state_group_cache_size 2017-01-17 11:31:08 +00:00
Erik Johnston
f85b6ca494 Speed up cache size calculation
Instead of calculating the size of the cache repeatedly, which can take
a long time now that it can use a callback, instead cache the size and
update that on insertion and deletion.

This requires changing the cache descriptors to have two caches, one for
pending deferreds and the other for the actual values. There's no reason
to evict from the pending deferreds as they won't take up any more
memory.
2017-01-17 11:18:13 +00:00
Erik Johnston
f2f179dce2 Add ExpiringCache tests 2017-01-16 15:33:34 +00:00
Erik Johnston
6d00213e80 Use OrderedDict in ExpiringCache 2017-01-16 15:33:22 +00:00
Erik Johnston
897f8752da Up cache max entries for state 2017-01-16 15:08:17 +00:00
Erik Johnston
beda469bc6 Put staticmethods at module level 2017-01-16 15:05:24 +00:00
Erik Johnston
46aebbbcbf Add support for 'iterable' to ExpiringCache 2017-01-16 14:57:23 +00:00
Erik Johnston
01521299c7 Increase cache size limit 2017-01-16 11:56:51 +00:00
Erik Johnston
2fae34bd2c Optionally measure size of cache by sum of length of values 2017-01-13 17:46:17 +00:00
Erik Johnston
95a22ae194 Merge pull request #1810 from matrix-org/erikj/state_auth_splitout_split
Split out static auth methods from Auth object
2017-01-13 16:32:27 +00:00
Erik Johnston
ec0a523ac3 Split out static state methods from StateHandler 2017-01-13 15:25:06 +00:00
Erik Johnston
e178feca3f Remove unused function 2017-01-13 15:16:45 +00:00
Erik Johnston
f0325a9ccc Merge pull request #1793 from matrix-org/erikj/change_device_inbox_index
Change device_inbox stream index to include user
2017-01-13 15:14:51 +00:00
Erik Johnston
c050f493dd Add comment 2017-01-13 15:14:41 +00:00
Adrian Perez de Castro
a3e4a198e3 Allow configuring the Riot URL used in notification emails
The URLs used for notification emails were hardcoded to use either matrix.to
or vector.im; but for self-hosted setups where Riot is also self-hosted it
may be desirable to allow configuring an alternative Riot URL.

Fixes #1809.

Signed-off-by: Adrian Perez de Castro <aperez@igalia.com>
2017-01-13 17:12:04 +02:00
Erik Johnston
8b2fa38256 Split event auth code into seperate module 2017-01-13 15:07:32 +00:00
Erik Johnston
641ccdbb14 Merge pull request #1795 from matrix-org/erikj/port_defaults
Restore default bind address
2017-01-13 13:02:59 +00:00
Richard van der Hoff
6f5e41e420 README.rst: fix formatting
Fix formatting blooper introduced in https://github.com/matrix-org/synapse/pull/1672 :/
2017-01-13 12:52:11 +00:00
Erik Johnston
0d37a7bf83 Merge pull request #1803 from matrix-org/erikj/swallow_errors
Fix spurious Unhandled Error log lines
2017-01-13 10:52:41 +00:00
Erik Johnston
ebf94aff8d Fix spurious Unhandled Error log lines 2017-01-12 17:19:47 +00:00
Erik Johnston
7a13fe16f7 Merge pull request #1802 from matrix-org/erikj/remove_debug_deferreds
Remove full_twisted_stacktraces option
2017-01-12 14:25:51 +00:00
Erik Johnston
bf5c9706d9 Remove full_twisted_stacktraces option
The debug 'full_twisted_stacktraces' flag caused synapse to rewrite
twisted deferreds to always fire the callback on the next reactor tick.
This was to force the deferred to always store the stacktraces on
exceptions, and thus be more likely to have a full stacktrace when it
reaches the final error handlers and gets printed to the logs.

Dynamically rewriting things is generally bad, and in particular this
change violates assumptions of various bits of Twisted. This wouldn't
necessarily be so bad, but it turns out this option has been turned on
on some production servers.

Turning the option can cause e.g. #1778.

For now, lets just entirely nuke this option.
2017-01-12 10:32:52 +00:00
Erik Johnston
7b62d0bc70 Add missing None check 2017-01-11 10:57:03 +00:00
Erik Johnston
7e6c2937c3 Split out static auth methods from Auth object 2017-01-10 18:16:54 +00:00
Erik Johnston
b1dfd20292 Pop bind_address 2017-01-10 17:23:18 +00:00
Erik Johnston
edd6cdfc9a Restore default bind address 2017-01-10 17:21:41 +00:00
Matthew Hodgson
3cb1799347 credit patrik properly 2017-01-10 16:50:35 +00:00
Erik Johnston
8a0fddfd73 Remove spurious for..else.. 2017-01-10 16:30:53 +00:00
Erik Johnston
d524bc9110 Merge pull request #1792 from matrix-org/erikj/limit_cache_prefill_device
Limit number of entries to prefill from cache
2017-01-10 15:42:00 +00:00
Erik Johnston
d2b00d0866 Merge pull request #1790 from matrix-org/erikj/linearizer
Add paranoia exception catch in Linearizer
2017-01-10 15:38:30 +00:00
Erik Johnston
ab655dca33 Explicitly close the cursor 2017-01-10 15:15:25 +00:00
Erik Johnston
5a32e9273e Don't disable autocommit 2017-01-10 15:11:27 +00:00
Erik Johnston
caddadfc5a Change device_inbox stream index to include user
This makes fetching the nost recently changed users much tricker, and
brings it in line with e.g. presence_stream indices.
2017-01-10 15:04:57 +00:00
Erik Johnston
dd52d4de4c Limit number of entries to prefill from cache
Some tables, like device_inbox, take a long time to query at startup for
the stream change cache prefills. This is likely because they are slower
growing streams and so are more fragmented on disk. For now, lets pull
fewer entries out to make startup quicker.

In future, we should add a better index to make it even faster.
2017-01-10 14:34:50 +00:00
Mark Haines
024eb98524 Merge pull request #1791 from matrix-org/markjh/file_logging
Log which files we saved attachments to in the media_repository
2017-01-10 14:27:55 +00:00
Mark Haines
32019c9897 Log which files we saved attachments to in the media_repository 2017-01-10 14:19:50 +00:00
Erik Johnston
657488113e Merge pull request #1789 from matrix-org/erikj/decouple_presence
Don't block messages sending on bumping presence
2017-01-10 14:06:05 +00:00
Erik Johnston
3b4de17d2b Comment 2017-01-10 14:05:53 +00:00
Erik Johnston
7d0981b312 Merge pull request #1787 from matrix-org/erikj/linearize_member
Linearize updates to membership via PUT /state/
2017-01-10 14:04:54 +00:00
Erik Johnston
07c3c08fad Merge pull request #1786 from matrix-org/erikj/linearizer_name
Name linearizer's for better logs
2017-01-10 14:04:45 +00:00
Erik Johnston
f477370c0c Add paranoia exception catch in Linearizer 2017-01-10 14:04:13 +00:00
Erik Johnston
586f474a44 Don't block messages sending on bumping presence 2017-01-10 12:46:00 +00:00
Erik Johnston
6823fe5241 Linearize updates to membership via PUT /state/ 2017-01-09 18:25:13 +00:00
Erik Johnston
f7085ac84f Name linearizer's for better logs 2017-01-09 17:17:10 +00:00
Erik Johnston
9898bbd9dc Merge branch 'master' of github.com:matrix-org/synapse into develop 2017-01-09 14:51:17 +00:00
Erik Johnston
9a8ae6f1bf Bump version and changelog 2017-01-09 14:47:56 +00:00
Matthew Hodgson
2f4b2f4783 gah, fix mangled merge of 0.18.7 into develop 2017-01-07 04:00:42 +00:00
Matthew
6d363cea9d Merge branch 'release-v0.18.7' into develop 2017-01-07 03:46:16 +00:00
Matthew
f0e4bac64e bump changelog & version 2017-01-07 03:45:38 +00:00
Matthew
4304e7e593 do the discard check in the right place to avoid grabbing dependent events 2017-01-07 03:44:18 +00:00
Matthew Hodgson
6515b9c0d4 changelog 2017-01-07 02:52:37 +00:00
Matthew Hodgson
8c48971b51 Merge branch 'release-v0.18.7' into develop 2017-01-07 02:23:37 +00:00
Matthew
e10c527930 Discard PDUs from invalid origins due to #1753 in 0.18.[56] 2017-01-07 02:13:14 +00:00
Matthew Hodgson
2f5be2d8dc oops, this should have been rc1 2017-01-07 01:11:56 +00:00
Matthew Hodgson
4086026524 move logging to right place 2017-01-07 00:41:46 +00:00
Matthew Hodgson
9d914454c8 Merge branch 'release-v0.18.6' into develop 2017-01-07 00:40:30 +00:00
Matthew
19e2fb4386 bump version 2017-01-06 23:38:22 +00:00
Matthew
189fd15564 update changelog 2017-01-06 23:33:28 +00:00
Matthew
8404f132c3 Revert "fix typo breaking the fix to #1753"
This reverts commit b2850e62db.
2017-01-06 23:28:46 +00:00
Matthew
b2850e62db fix typo breaking the fix to #1753 2017-01-06 23:23:37 +00:00
Mark Haines
06c00bd19b Merge branch 'release-v0.18.6' into develop 2017-01-06 14:46:27 +00:00
Mark Haines
b42a972b71 Bump version and changelog 2017-01-06 14:44:28 +00:00
Mark Haines
2c8ac84a26 Merge pull request #1772 from matrix-org/markjh/fix_guest_access_check
handlers/room_member: fix guest access check when joining rooms
2017-01-06 14:41:52 +00:00
Patrik Oldsberg
1ef6084b75 handlers/room_member: fix guest access check when joining rooms
Signed-off-by: Patrik Oldsberg <patrik.oldsberg@ericsson.com>
2017-01-06 14:36:56 +00:00
Matthew Hodgson
bd85434cb3 Merge branch 'release-v0.18.6' into develop 2017-01-05 13:58:19 +00:00
Mark Haines
c18f7fc410 Fix flake8 and update changelog 2017-01-05 13:50:22 +00:00
Matthew Hodgson
dafd50d178 Merge pull request #1767 from matrix-org/matthew/resolve_state_group_logging
log call paths for resolve_state_group
2017-01-05 13:47:42 +00:00
Matthew Hodgson
883ff92a7f Fix case 2017-01-05 13:45:02 +00:00
Matthew Hodgson
d79d165761 add logging for all the places we call resolve_state_groups. my kingdom for a backtrace that actually works. 2017-01-05 13:40:39 +00:00
Matthew Hodgson
8cfc0165e9 fix annoying typos 2017-01-05 13:39:43 +00:00
Mark Haines
62451800e7 Bump version and changelog to v0.18.6-rc3 2017-01-05 13:36:10 +00:00
Matthew Hodgson
b31ed22738 Merge branch 'release-v0.18.6' into develop 2017-01-05 13:03:02 +00:00
Matthew Hodgson
7738329672 Merge pull request #1766 from matrix-org/markjh/linear_logging
More logging for the linearizer and for get_events
2017-01-05 13:01:31 +00:00
Mark Haines
dd3df11c55 More logging for the linearizer and for get_events 2017-01-05 12:32:47 +00:00
Mark Haines
e1c5463efc Merge pull request #1765 from matrix-org/markjh/timeout_get_missing_events
cherrypick #1744: limit total timeout for get_missing_events to 10s
2017-01-05 12:02:23 +00:00
Matthew Hodgson
468749c9fc fix comment 2017-01-05 12:00:11 +00:00
Matthew Hodgson
eedf400d05 limit total timeout for get_missing_events to 10s 2017-01-05 11:58:15 +00:00
Mark Haines
5175094707 Merge pull request #1744 from matrix-org/matthew/timeout_get_missing_events
limit total timeout for get_missing_events to 10s
2017-01-05 11:53:15 +00:00
Matthew Hodgson
8e82611f37 fix comment 2017-01-05 11:44:44 +00:00
Mark Haines
6028718b1a Merge pull request #1764 from matrix-org/markjh/fix_send_pdu
Only send events that originate on this server.
2017-01-05 11:41:22 +00:00
Mark Haines
f784980d2b Only send events that originate on this server.
Or events that are sent via the federation "send_join" API.

This should match the behaviour from before v0.18.5 and #1635 landed.
2017-01-05 11:26:30 +00:00
Mark Haines
0d766c8ccf Merge pull request #1758 from matrix-org/markjh/fix_ban_propagation
Fix propagation of bans to remote servers.
2017-01-04 15:39:31 +00:00
Mark Haines
e02bdaf08b Get the destinations from the state from before the event
Rather than the state after then event.
2017-01-04 15:17:15 +00:00
Mark Haines
b6b67715ed Send ALL membership events to the server that was affected.
Send all membership changes to the server that was affected.
This ensures that if the last member of a room on a server
was kicked or banned they get told about it.
2017-01-04 13:56:20 +00:00
Matthew Hodgson
555d702e34 limit total timeout for get_missing_events to 10s 2016-12-31 15:21:37 +00:00
Matthew Hodgson
899a3a1268 Merge branch 'release-v0.18.6' into develop 2016-12-31 02:38:26 +00:00
Mark Haines
f3de4f8cb7 Bump version and changelog 2016-12-30 20:21:04 +00:00
Mark Haines
321d5b73d8 Merge pull request #1736 from matrix-org/markjh/linearizer_logging
Add more useful logging when we block fetching events
2016-12-30 20:05:12 +00:00
Mark Haines
62ce3034f3 s/aquire/acquire/g 2016-12-30 20:04:44 +00:00
Mark Haines
0aff09f6c9 Add more useful logging when we block fetching events 2016-12-30 20:00:44 +00:00
Mark Haines
48c3b7dc19 Merge pull request #1734 from matrix-org/markjh/fix_get_missing
Remove fallback from get_missing_events.
2016-12-30 19:42:11 +00:00
Mark Haines
cc50b1ae53 Remove fallback from get_missing_events.
get_missing_events used to fallback to fetching the missing events
individually requesting from every server in the room, one by one.e

This could be unacceptably slow, possibly causing #1732
2016-12-30 18:13:15 +00:00
Mark Haines
f576c34594 Merge remote-tracking branch 'origin/release-v0.18.6' into develop 2016-12-30 15:13:49 +00:00
Mark Haines
0eac4fa525 Merge pull request #1731 from matrix-org/markjh/logging-memleak
Use the new twisted logging framework.
2016-12-30 12:52:50 +00:00
Mark Haines
822cb39dfa Use the new twisted logging framework.
Hopefully adding an observer to the new framework will avoid a memory
leak https://twistedmatrix.com/trac/ticket/8164
2016-12-30 11:09:24 +00:00
Mark Haines
342fb8dae9 Merge branch 'release-v0.18.6' into develop 2016-12-29 17:33:46 +00:00
David Baker
84cf00c645 Fix another comment typo 2016-12-21 09:51:43 +00:00
David Baker
bea15fb599 Merge pull request #1714 from matrix-org/dbkr/delete_threepid
Add /account/3pid/delete endpoint
2016-12-21 09:51:04 +00:00
David Baker
0c88ab1844 Add /account/3pid/delete endpoint
Also fix a typo in a comment
2016-12-20 18:27:30 +00:00
Matthew Hodgson
b7f4f902fa Merge pull request #1712 from kyrias/fix-bind-address-none
Fix check for bind_address
2016-12-20 00:41:42 +00:00
Johannes Löthberg
702c020e58 Fix check for bind_address
The empty string is a valid setting for the bind_address option, so
explicitly check for None here instead.

Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-20 01:37:50 +01:00
Matthew Hodgson
09f15918be Merge pull request #1711 from matrix-org/matthew/utf8-password-change
fix ability to change password to a non-ascii one
2016-12-20 00:02:13 +00:00
Matthew Hodgson
da2c8f3c94 Merge pull request #1709 from kyrias/bind_addresses
Add support for specifying multiple bind addresses
2016-12-19 23:49:34 +00:00
Matthew Hodgson
a58e4e0d48 Merge pull request #1696 from kyrias/ipv6
IPv6 support
2016-12-19 23:49:07 +00:00
Matthew Hodgson
f2a5aebf98 fix ability to change password to a non-ascii one
https://github.com/vector-im/riot-web/issues/2658
2016-12-18 22:25:21 +00:00
Johannes Löthberg
a9c1b419a9 Bump twisted dependency
At least 16.0.0 is needed for wrapClientTLS support.

Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-18 23:16:43 +01:00
Johannes Löthberg
f5cd5ebd7b Add IPv6 comment to default config
Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-18 23:14:32 +01:00
Johannes Löthberg
1859af9b2a Update README to use bind_addresses
Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-18 22:01:34 +01:00
Johannes Löthberg
c95e9fff99 Make default homeserver config use bind_addresses
Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-18 21:51:56 +01:00
Johannes Löthberg
7dfd70fc83 Add support for specifying multiple bind addresses
Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-18 21:51:56 +01:00
Erik Johnston
b2f8642d3d Cache network room list queries. 2016-12-16 16:11:43 +00:00
Johannes Löthberg
0648e76979 Remove spurious newline
Apparently I just removed the spaces instead...

Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-12 18:41:30 +01:00
Johannes Löthberg
d3bd94805f Fixup for #1689 and #1690
Signed-off-by: Johannes Löthberg <johannes@kyriasis.com>
2016-12-12 16:32:47 +01:00
Glyph
9f07f4c559 IPv6 support for endpoint.py
Similar to https://github.com/matrix-org/synapse/pull/1689, but for endpoint.py
2016-12-11 11:10:32 +01:00
Glyph
6e18805ac2 IPv6 support for client.py
This is an (untested) general sketch of how to use wrapClientTLS to implement TLS over IPv6, as well as faster connections over IPv4.
2016-12-11 11:10:32 +01:00
97 changed files with 3643 additions and 1806 deletions

View File

@@ -1,3 +1,64 @@
Changes in synapse v0.18.7 (2017-01-09)
=======================================
No changes from v0.18.7-rc2
Changes in synapse v0.18.7-rc2 (2017-01-07)
===========================================
Bug fixes:
* Fix error in rc1's discarding invalid inbound traffic logic that was
incorrectly discarding missing events
Changes in synapse v0.18.7-rc1 (2017-01-06)
===========================================
Bug fixes:
* Fix error in #PR 1764 to actually fix the nightmare #1753 bug.
* Improve deadlock logging further
* Discard inbound federation traffic from invalid domains, to immunise
against #1753
Changes in synapse v0.18.6 (2017-01-06)
=======================================
Bug fixes:
* Fix bug when checking if a guest user is allowed to join a room (PR #1772)
Thanks to Patrik Oldsberg for diagnosing and the fix!
Changes in synapse v0.18.6-rc3 (2017-01-05)
===========================================
Bug fixes:
* Fix bug where we failed to send ban events to the banned server (PR #1758)
* Fix bug where we sent event that didn't originate on this server to
other servers (PR #1764)
* Fix bug where processing an event from a remote server took a long time
because we were making long HTTP requests (PR #1765, PR #1744)
Changes:
* Improve logging for debugging deadlocks (PR #1766, PR #1767)
Changes in synapse v0.18.6-rc2 (2016-12-30)
===========================================
Bug fixes:
* Fix memory leak in twisted by initialising logging correctly (PR #1731)
* Fix bug where fetching missing events took an unacceptable amount of time in
large rooms (PR #1734)
Changes in synapse v0.18.6-rc1 (2016-12-29)
===========================================
@@ -5,6 +66,7 @@ Bug fixes:
* Make sure that outbound connections are closed (PR #1725)
Changes in synapse v0.18.5 (2016-12-16)
=======================================

View File

@@ -138,6 +138,7 @@ Installing prerequisites on openSUSE::
python-devel libffi-devel libopenssl-devel libjpeg62-devel
Installing prerequisites on OpenBSD::
doas pkg_add python libffi py-pip py-setuptools sqlite3 py-virtualenv \
libxslt
@@ -658,7 +659,7 @@ configuration might look like::
}
}
You will also want to set ``bind_address: 127.0.0.1`` and ``x_forwarded: true``
You will also want to set ``bind_addresses: ['127.0.0.1']`` and ``x_forwarded: true``
for port 8008 in ``homeserver.yaml`` to ensure that client IP addresses are
recorded correctly.

View File

@@ -2,15 +2,13 @@ Purge Remote Media API
======================
The purge remote media API allows server admins to purge old cached remote
media.
media.
The API is::
POST /_matrix/client/r0/admin/purge_media_cache
POST /_matrix/client/r0/admin/purge_media_cache?before_ts=<unix_timestamp_in_ms>&access_token=<access_token>
{
"before_ts": <unix_timestamp_in_ms>
}
{}
Which will remove all cached media that was last accessed before
``<unix_timestamp_in_ms>``.

View File

@@ -16,4 +16,4 @@
""" This is a reference implementation of a Matrix home server.
"""
__version__ = "0.18.6-rc1"
__version__ = "0.18.7"

View File

@@ -16,18 +16,14 @@
import logging
import pymacaroons
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from twisted.internet import defer
from unpaddedbase64 import decode_base64
import synapse.types
from synapse import event_auth
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
from synapse.api.errors import AuthError, Codes
from synapse.types import UserID
from synapse.util.logcontext import preserve_context_over_fn
from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
logger = logging.getLogger(__name__)
@@ -78,147 +74,7 @@ class Auth(object):
True if the auth checks pass.
"""
with Measure(self.clock, "auth.check"):
self.check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = self.is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
self.check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = self._get_user_power_level(event.user_id, auth_events)
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
self._can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
self._check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
self.check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def check_size_limits(self, event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
event_auth.check(event, auth_events, do_sig_check=do_sig_check)
@defer.inlineCallbacks
def check_joined_room(self, room_id, user_id, current_state=None):
@@ -290,6 +146,7 @@ class Auth(object):
with Measure(self.clock, "check_host_in_room"):
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from check_host_in_room")
entry = yield self.state.resolve_state_groups(
room_id, latest_event_ids
)
@@ -299,16 +156,6 @@ class Auth(object):
)
defer.returnValue(ret)
def check_event_sender_in_room(self, event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return self._check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(self, member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
@@ -320,267 +167,8 @@ class Auth(object):
return creation_event.content.get("m.federate", True) is True
@log_function
def is_membership_change_allowed(self, event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not self.can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = self._get_user_power_level(event.user_id, auth_events)
target_level = self._get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = self._get_named_level(auth_events, "ban", 50)
logger.debug(
"is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not self._verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = self._get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = self._get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _verify_third_party_invite(self, event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in self.get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(self, invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def _get_power_level_event(self, auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def _get_user_power_level(self, user_id, auth_events):
power_level_event = self._get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(self, auth_events, name, default):
power_level_event = self._get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
return event_auth.get_public_keys(invite_event)
@defer.inlineCallbacks
def get_user_by_req(self, request, allow_guest=False, rights="access"):
@@ -973,56 +561,6 @@ class Auth(object):
defer.returnValue(auth_ids)
def _get_send_level(self, etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
@log_function
def _can_send_event(self, event, auth_events):
send_level = self._get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = self._get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(self, event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
@@ -1036,107 +574,7 @@ class Auth(object):
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = self._get_user_power_level(event.user_id, auth_events)
redact_level = self._get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(self, event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = self._get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
return event_auth.check_redaction(event, auth_events)
@defer.inlineCallbacks
def check_can_change_room_list(self, room_id, user):
@@ -1166,10 +604,10 @@ class Auth(object):
if power_level_event:
auth_events[(EventTypes.PowerLevels, "")] = power_level_event
send_level = self._get_send_level(
send_level = event_auth.get_send_level(
EventTypes.Aliases, "", auth_events
)
user_level = self._get_user_power_level(user_id, auth_events)
user_level = event_auth.get_user_power_level(user_id, auth_events)
if user_level < send_level:
raise AuthError(

View File

@@ -76,7 +76,7 @@ class AppserviceServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -85,16 +85,19 @@ class AppserviceServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse appservice now listening on port %d", port)
def start_listening(self, listeners):
@@ -102,15 +105,18 @@ class AppserviceServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -90,7 +90,7 @@ class ClientReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -108,16 +108,19 @@ class ClientReaderServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse client reader now listening on port %d", port)
def start_listening(self, listeners):
@@ -125,15 +128,18 @@ class ClientReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -86,7 +86,7 @@ class FederationReaderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -99,16 +99,19 @@ class FederationReaderServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse federation reader now listening on port %d", port)
def start_listening(self, listeners):
@@ -116,15 +119,18 @@ class FederationReaderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -30,6 +30,7 @@ from synapse.replication.slave.storage.events import SlavedEventStore
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
from synapse.replication.slave.storage.transactions import TransactionStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.storage.engines import create_engine
from synapse.storage.presence import UserPresenceState
from synapse.util.async import sleep
@@ -56,7 +57,7 @@ logger = logging.getLogger("synapse.app.appservice")
class FederationSenderSlaveStore(
SlavedDeviceInboxStore, TransactionStore, SlavedReceiptsStore, SlavedEventStore,
SlavedRegistrationStore,
SlavedRegistrationStore, SlavedDeviceStore,
):
pass
@@ -82,7 +83,7 @@ class FederationSenderServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -91,16 +92,19 @@ class FederationSenderServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse federation_sender now listening on port %d", port)
def start_listening(self, listeners):
@@ -108,15 +112,18 @@ class FederationSenderServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -107,7 +107,7 @@ def build_resource_for_web_client(hs):
class SynapseHomeServer(HomeServer):
def _listener_http(self, config, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
tls = listener_config.get("tls", False)
site_tag = listener_config.get("tag", port)
@@ -173,29 +173,32 @@ class SynapseHomeServer(HomeServer):
root_resource = Resource()
root_resource = create_resource_tree(resources, root_resource)
if tls:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=bind_address
)
for address in bind_addresses:
reactor.listenSSL(
port,
SynapseSite(
"synapse.access.https.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
self.tls_server_context_factory,
interface=address
)
else:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse now listening on port %d", port)
def start_listening(self):
@@ -205,15 +208,18 @@ class SynapseHomeServer(HomeServer):
if listener["type"] == "http":
self._listener_http(config, listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -87,7 +87,7 @@ class MediaRepositoryServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -105,16 +105,19 @@ class MediaRepositoryServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse media repository now listening on port %d", port)
def start_listening(self, listeners):
@@ -122,15 +125,18 @@ class MediaRepositoryServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -121,7 +121,7 @@ class PusherServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -130,16 +130,19 @@ class PusherServer(HomeServer):
resources[METRICS_PREFIX] = MetricsResource(self)
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse pusher now listening on port %d", port)
def start_listening(self, listeners):
@@ -147,15 +150,18 @@ class PusherServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])

View File

@@ -39,6 +39,7 @@ from synapse.replication.slave.storage.filtering import SlavedFilteringStore
from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
from synapse.replication.slave.storage.presence import SlavedPresenceStore
from synapse.replication.slave.storage.deviceinbox import SlavedDeviceInboxStore
from synapse.replication.slave.storage.devices import SlavedDeviceStore
from synapse.replication.slave.storage.room import RoomStore
from synapse.server import HomeServer
from synapse.storage.client_ips import ClientIpStore
@@ -77,6 +78,7 @@ class SynchrotronSlavedStore(
SlavedFilteringStore,
SlavedPresenceStore,
SlavedDeviceInboxStore,
SlavedDeviceStore,
RoomStore,
BaseSlavedStore,
ClientIpStore, # After BaseSlavedStore because the constructor is different
@@ -289,7 +291,7 @@ class SynchrotronServer(HomeServer):
def _listen_http(self, listener_config):
port = listener_config["port"]
bind_address = listener_config.get("bind_address", "")
bind_addresses = listener_config["bind_addresses"]
site_tag = listener_config.get("tag", port)
resources = {}
for res in listener_config["resources"]:
@@ -310,16 +312,19 @@ class SynchrotronServer(HomeServer):
})
root_resource = create_resource_tree(resources, Resource())
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=bind_address
)
for address in bind_addresses:
reactor.listenTCP(
port,
SynapseSite(
"synapse.access.http.%s" % (site_tag,),
site_tag,
listener_config,
root_resource,
),
interface=address
)
logger.info("Synapse synchrotron now listening on port %d", port)
def start_listening(self, listeners):
@@ -327,15 +332,18 @@ class SynchrotronServer(HomeServer):
if listener["type"] == "http":
self._listen_http(listener)
elif listener["type"] == "manhole":
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=listener.get("bind_address", '127.0.0.1')
)
bind_addresses = listener["bind_addresses"]
for address in bind_addresses:
reactor.listenTCP(
listener["port"],
manhole(
username="matrix",
password="rabbithole",
globals={"hs": self},
),
interface=address
)
else:
logger.warn("Unrecognized listener type: %s", listener["type"])
@@ -374,6 +382,27 @@ class SynchrotronServer(HomeServer):
stream_key, position, users=users, rooms=rooms
)
@defer.inlineCallbacks
def notify_device_list_update(result):
stream = result.get("device_lists")
if not stream:
return
position_index = stream["field_names"].index("position")
user_index = stream["field_names"].index("user_id")
for row in stream["rows"]:
position = row[position_index]
user_id = row[user_index]
rooms = yield store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
@defer.inlineCallbacks
def notify(result):
stream = result.get("events")
if stream:
@@ -411,6 +440,7 @@ class SynchrotronServer(HomeServer):
notify_from_stream(
result, "to_device", "to_device_key", user="user_id"
)
yield notify_device_list_update(result)
while True:
try:
@@ -421,7 +451,7 @@ class SynchrotronServer(HomeServer):
yield store.process_replication(result)
typing_handler.process_replication(result)
yield presence_handler.process_replication(result)
notify(result)
yield notify(result)
except:
logger.exception("Error replicating from %r", replication_url)
yield sleep(5)

View File

@@ -68,6 +68,9 @@ class EmailConfig(Config):
self.email_notif_for_new_users = email_config.get(
"notif_for_new_users", True
)
self.email_riot_base_url = email_config.get(
"riot_base_url", None
)
if "app_name" in email_config:
self.email_app_name = email_config["app_name"]
else:
@@ -85,6 +88,9 @@ class EmailConfig(Config):
def default_config(self, config_dir_path, server_name, **kwargs):
return """
# Enable sending emails for notification events
# Defining a custom URL for Riot is only needed if email notifications
# should contain links to a self-hosted installation of Riot; when set
# the "app_name" setting is ignored.
#email:
# enable_notifs: false
# smtp_host: "localhost"
@@ -95,4 +101,5 @@ class EmailConfig(Config):
# notif_template_html: notif_mail.html
# notif_template_text: notif_mail.txt
# notif_for_new_users: True
# riot_base_url: "http://localhost/riot"
"""

View File

@@ -15,14 +15,13 @@
from ._base import Config
from synapse.util.logcontext import LoggingContextFilter
from twisted.python.log import PythonLoggingObserver
from twisted.logger import globalLogBeginner, STDLibLogObserver
import logging
import logging.config
import yaml
from string import Template
import os
import signal
from synapse.util.debug import debug_deferreds
DEFAULT_LOG_CONFIG = Template("""
@@ -71,8 +70,6 @@ class LoggingConfig(Config):
self.verbosity = config.get("verbose", 0)
self.log_config = self.abspath(config.get("log_config"))
self.log_file = self.abspath(config.get("log_file"))
if config.get("full_twisted_stacktraces"):
debug_deferreds()
def default_config(self, config_dir_path, server_name, **kwargs):
log_file = self.abspath("homeserver.log")
@@ -88,11 +85,6 @@ class LoggingConfig(Config):
# A yaml python logging config file
log_config: "%(log_config)s"
# Stop twisted from discarding the stack traces of exceptions in
# deferreds by waiting a reactor tick before running a deferred's
# callbacks.
# full_twisted_stacktraces: true
""" % locals()
def read_arguments(self, args):
@@ -180,5 +172,15 @@ def setup_logging(log_config=None, log_file=None, verbosity=None):
with open(log_config, 'r') as f:
logging.config.dictConfig(yaml.load(f))
observer = PythonLoggingObserver()
observer.start()
# It's critical to point twisted's internal logging somewhere, otherwise it
# stacks up and leaks kup to 64K object;
# see: https://twistedmatrix.com/trac/ticket/8164
#
# Routing to the python logging framework could be a performance problem if
# the handlers blocked for a long time as python.logging is a blocking API
# see https://twistedmatrix.com/documents/current/core/howto/logger.html
# filed as https://github.com/matrix-org/synapse/issues/1727
#
# However this may not be too much of a problem if we are just writing to a file.
observer = STDLibLogObserver()
globalLogBeginner.beginLoggingTo([observer])

View File

@@ -42,6 +42,15 @@ class ServerConfig(Config):
self.listeners = config.get("listeners", [])
for listener in self.listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')
self.gc_thresholds = read_gc_thresholds(config.get("gc_thresholds", None))
bind_port = config.get("bind_port")
@@ -54,7 +63,7 @@ class ServerConfig(Config):
self.listeners.append({
"port": bind_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": True,
"type": "http",
"resources": [
@@ -73,7 +82,7 @@ class ServerConfig(Config):
if unsecure_port:
self.listeners.append({
"port": unsecure_port,
"bind_address": bind_host,
"bind_addresses": [bind_host],
"tls": False,
"type": "http",
"resources": [
@@ -92,7 +101,7 @@ class ServerConfig(Config):
if manhole:
self.listeners.append({
"port": manhole,
"bind_address": "127.0.0.1",
"bind_addresses": ["127.0.0.1"],
"type": "manhole",
})
@@ -100,7 +109,7 @@ class ServerConfig(Config):
if metrics_port:
self.listeners.append({
"port": metrics_port,
"bind_address": config.get("metrics_bind_host", "127.0.0.1"),
"bind_addresses": [config.get("metrics_bind_host", "127.0.0.1")],
"tls": False,
"type": "http",
"resources": [
@@ -155,9 +164,14 @@ class ServerConfig(Config):
# The port to listen for HTTPS requests on.
port: %(bind_port)s
# Local interface to listen on.
# The empty string will cause synapse to listen on all interfaces.
bind_address: ''
# Local addresses to listen on.
# This will listen on all IPv4 addresses by default.
bind_addresses:
- '0.0.0.0'
# Uncomment to listen on all IPv6 interfaces
# N.B: On at least Linux this will also listen on all IPv4
# addresses, so you will need to comment out the line above.
# - '::'
# This is a 'http' listener, allows us to specify 'resources'.
type: http
@@ -188,7 +202,7 @@ class ServerConfig(Config):
# For when matrix traffic passes through loadbalancer that unwraps TLS.
- port: %(unsecure_port)s
tls: false
bind_address: ''
bind_addresses: ['0.0.0.0']
type: http
x_forwarded: false

View File

@@ -19,7 +19,9 @@ class VoipConfig(Config):
def read_config(self, config):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config["turn_shared_secret"]
self.turn_shared_secret = config.get("turn_shared_secret")
self.turn_username = config.get("turn_username")
self.turn_password = config.get("turn_password")
self.turn_user_lifetime = self.parse_duration(config["turn_user_lifetime"])
def default_config(self, **kwargs):
@@ -32,6 +34,11 @@ class VoipConfig(Config):
# The shared secret used to compute passwords for the TURN server
turn_shared_secret: "YOUR_SHARED_SECRET"
# The Username and password if the TURN server needs them and
# does not use a token
#turn_username: "TURNSERVER_USERNAME"
#turn_password: "TURNSERVER_PASSWORD"
# How long generated TURN credentials last
turn_user_lifetime: "1h"
"""

View File

@@ -29,3 +29,13 @@ class WorkerConfig(Config):
self.worker_log_file = config.get("worker_log_file")
self.worker_log_config = config.get("worker_log_config")
self.worker_replication_url = config.get("worker_replication_url")
if self.worker_listeners:
for listener in self.worker_listeners:
bind_address = listener.pop("bind_address", None)
bind_addresses = listener.setdefault("bind_addresses", [])
if bind_address:
bind_addresses.append(bind_address)
elif not bind_addresses:
bind_addresses.append('')

678
synapse/event_auth.py Normal file
View File

@@ -0,0 +1,678 @@
# -*- coding: utf-8 -*-
# Copyright 2014 - 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from canonicaljson import encode_canonical_json
from signedjson.key import decode_verify_key_bytes
from signedjson.sign import verify_signed_json, SignatureVerifyException
from unpaddedbase64 import decode_base64
from synapse.api.constants import EventTypes, Membership, JoinRules
from synapse.api.errors import AuthError, SynapseError, EventSizeError
from synapse.types import UserID, get_domain_from_id
logger = logging.getLogger(__name__)
def check(event, auth_events, do_sig_check=True, do_size_check=True):
""" Checks if this event is correctly authed.
Args:
event: the event being checked.
auth_events (dict: event-key -> event): the existing room state.
Returns:
True if the auth checks pass.
"""
if do_size_check:
_check_size_limits(event)
if not hasattr(event, "room_id"):
raise AuthError(500, "Event has no room_id: %s" % event)
if do_sig_check:
sender_domain = get_domain_from_id(event.sender)
event_id_domain = get_domain_from_id(event.event_id)
is_invite_via_3pid = (
event.type == EventTypes.Member
and event.membership == Membership.INVITE
and "third_party_invite" in event.content
)
# Check the sender's domain has signed the event
if not event.signatures.get(sender_domain):
# We allow invites via 3pid to have a sender from a different
# HS, as the sender must match the sender of the original
# 3pid invite. This is checked further down with the
# other dedicated membership checks.
if not is_invite_via_3pid:
raise AuthError(403, "Event not signed by sender's server")
# Check the event_id's domain has signed the event
if not event.signatures.get(event_id_domain):
raise AuthError(403, "Event not signed by sending server")
if auth_events is None:
# Oh, we don't know what the state of the room was, so we
# are trusting that this is allowed (at least for now)
logger.warn("Trusting event: %s", event.event_id)
return True
if event.type == EventTypes.Create:
room_id_domain = get_domain_from_id(event.room_id)
if room_id_domain != sender_domain:
raise AuthError(
403,
"Creation event's room_id domain does not match sender's"
)
# FIXME
return True
creation_event = auth_events.get((EventTypes.Create, ""), None)
if not creation_event:
raise SynapseError(
403,
"Room %r does not exist" % (event.room_id,)
)
creating_domain = get_domain_from_id(event.room_id)
originating_domain = get_domain_from_id(event.sender)
if creating_domain != originating_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# FIXME: Temp hack
if event.type == EventTypes.Aliases:
if not event.is_state():
raise AuthError(
403,
"Alias event must be a state event",
)
if not event.state_key:
raise AuthError(
403,
"Alias event must have non-empty state_key"
)
sender_domain = get_domain_from_id(event.sender)
if event.state_key != sender_domain:
raise AuthError(
403,
"Alias event's state_key does not match sender's domain"
)
return True
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
"Auth events: %s",
[a.event_id for a in auth_events.values()]
)
if event.type == EventTypes.Member:
allowed = _is_membership_change_allowed(
event, auth_events
)
if allowed:
logger.debug("Allowing! %s", event)
else:
logger.debug("Denying! %s", event)
return allowed
_check_event_sender_in_room(event, auth_events)
# Special case to allow m.room.third_party_invite events wherever
# a user is allowed to issue invites. Fixes
# https://github.com/vector-im/vector-web/issues/1208 hopefully
if event.type == EventTypes.ThirdPartyInvite:
user_level = get_user_power_level(event.user_id, auth_events)
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, (
"You cannot issue a third party invite for %s." %
(event.content.display_name,)
)
)
else:
return True
_can_send_event(event, auth_events)
if event.type == EventTypes.PowerLevels:
_check_power_levels(event, auth_events)
if event.type == EventTypes.Redaction:
check_redaction(event, auth_events)
logger.debug("Allowing! %s", event)
def _check_size_limits(event):
def too_big(field):
raise EventSizeError("%s too large" % (field,))
if len(event.user_id) > 255:
too_big("user_id")
if len(event.room_id) > 255:
too_big("room_id")
if event.is_state() and len(event.state_key) > 255:
too_big("state_key")
if len(event.type) > 255:
too_big("type")
if len(event.event_id) > 255:
too_big("event_id")
if len(encode_canonical_json(event.get_pdu_json())) > 65536:
too_big("event")
def _can_federate(event, auth_events):
creation_event = auth_events.get((EventTypes.Create, ""))
return creation_event.content.get("m.federate", True) is True
def _is_membership_change_allowed(event, auth_events):
membership = event.content["membership"]
# Check if this is the room creator joining:
if len(event.prev_events) == 1 and Membership.JOIN == membership:
# Get room creation event:
key = (EventTypes.Create, "", )
create = auth_events.get(key)
if create and event.prev_events[0][0] == create.event_id:
if create.content["creator"] == event.state_key:
return True
target_user_id = event.state_key
creating_domain = get_domain_from_id(event.room_id)
target_domain = get_domain_from_id(target_user_id)
if creating_domain != target_domain:
if not _can_federate(event, auth_events):
raise AuthError(
403,
"This room has been marked as unfederatable."
)
# get info about the caller
key = (EventTypes.Member, event.user_id, )
caller = auth_events.get(key)
caller_in_room = caller and caller.membership == Membership.JOIN
caller_invited = caller and caller.membership == Membership.INVITE
# get info about the target
key = (EventTypes.Member, target_user_id, )
target = auth_events.get(key)
target_in_room = target and target.membership == Membership.JOIN
target_banned = target and target.membership == Membership.BAN
key = (EventTypes.JoinRules, "", )
join_rule_event = auth_events.get(key)
if join_rule_event:
join_rule = join_rule_event.content.get(
"join_rule", JoinRules.INVITE
)
else:
join_rule = JoinRules.INVITE
user_level = get_user_power_level(event.user_id, auth_events)
target_level = get_user_power_level(
target_user_id, auth_events
)
# FIXME (erikj): What should we do here as the default?
ban_level = _get_named_level(auth_events, "ban", 50)
logger.debug(
"_is_membership_change_allowed: %s",
{
"caller_in_room": caller_in_room,
"caller_invited": caller_invited,
"target_banned": target_banned,
"target_in_room": target_in_room,
"membership": membership,
"join_rule": join_rule,
"target_user_id": target_user_id,
"event.user_id": event.user_id,
}
)
if Membership.INVITE == membership and "third_party_invite" in event.content:
if not _verify_third_party_invite(event, auth_events):
raise AuthError(403, "You are not invited to this room.")
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
return True
if Membership.JOIN != membership:
if (caller_invited
and Membership.LEAVE == membership
and target_user_id == event.user_id):
return True
if not caller_in_room: # caller isn't joined
raise AuthError(
403,
"%s not in room %s." % (event.user_id, event.room_id,)
)
if Membership.INVITE == membership:
# TODO (erikj): We should probably handle this more intelligently
# PRIVATE join rules.
# Invites are valid iff caller is in the room and target isn't.
if target_banned:
raise AuthError(
403, "%s is banned from the room" % (target_user_id,)
)
elif target_in_room: # the target is already in the room.
raise AuthError(403, "%s is already in the room." %
target_user_id)
else:
invite_level = _get_named_level(auth_events, "invite", 0)
if user_level < invite_level:
raise AuthError(
403, "You cannot invite user %s." % target_user_id
)
elif Membership.JOIN == membership:
# Joins are valid iff caller == target and they were:
# invited: They are accepting the invitation
# joined: It's a NOOP
if event.user_id != target_user_id:
raise AuthError(403, "Cannot force another user to join.")
elif target_banned:
raise AuthError(403, "You are banned from this room")
elif join_rule == JoinRules.PUBLIC:
pass
elif join_rule == JoinRules.INVITE:
if not caller_in_room and not caller_invited:
raise AuthError(403, "You are not invited to this room.")
else:
# TODO (erikj): may_join list
# TODO (erikj): private rooms
raise AuthError(403, "You are not allowed to join this room")
elif Membership.LEAVE == membership:
# TODO (erikj): Implement kicks.
if target_banned and user_level < ban_level:
raise AuthError(
403, "You cannot unban user &s." % (target_user_id,)
)
elif target_user_id != event.user_id:
kick_level = _get_named_level(auth_events, "kick", 50)
if user_level < kick_level or user_level <= target_level:
raise AuthError(
403, "You cannot kick user %s." % target_user_id
)
elif Membership.BAN == membership:
if user_level < ban_level or user_level <= target_level:
raise AuthError(403, "You don't have permission to ban")
else:
raise AuthError(500, "Unknown membership %s" % membership)
return True
def _check_event_sender_in_room(event, auth_events):
key = (EventTypes.Member, event.user_id, )
member_event = auth_events.get(key)
return _check_joined_room(
member_event,
event.user_id,
event.room_id
)
def _check_joined_room(member, user_id, room_id):
if not member or member.membership != Membership.JOIN:
raise AuthError(403, "User %s not in room %s (%s)" % (
user_id, room_id, repr(member)
))
def get_send_level(etype, state_key, auth_events):
key = (EventTypes.PowerLevels, "", )
send_level_event = auth_events.get(key)
send_level = None
if send_level_event:
send_level = send_level_event.content.get("events", {}).get(
etype
)
if send_level is None:
if state_key is not None:
send_level = send_level_event.content.get(
"state_default", 50
)
else:
send_level = send_level_event.content.get(
"events_default", 0
)
if send_level:
send_level = int(send_level)
else:
send_level = 0
return send_level
def _can_send_event(event, auth_events):
send_level = get_send_level(
event.type, event.get("state_key", None), auth_events
)
user_level = get_user_power_level(event.user_id, auth_events)
if user_level < send_level:
raise AuthError(
403,
"You don't have permission to post that to the room. " +
"user_level (%d) < send_level (%d)" % (user_level, send_level)
)
# Check state_key
if hasattr(event, "state_key"):
if event.state_key.startswith("@"):
if event.state_key != event.user_id:
raise AuthError(
403,
"You are not allowed to set others state"
)
return True
def check_redaction(event, auth_events):
"""Check whether the event sender is allowed to redact the target event.
Returns:
True if the the sender is allowed to redact the target event if the
target event was created by them.
False if the sender is allowed to redact the target event with no
further checks.
Raises:
AuthError if the event sender is definitely not allowed to redact
the target event.
"""
user_level = get_user_power_level(event.user_id, auth_events)
redact_level = _get_named_level(auth_events, "redact", 50)
if user_level >= redact_level:
return False
redacter_domain = get_domain_from_id(event.event_id)
redactee_domain = get_domain_from_id(event.redacts)
if redacter_domain == redactee_domain:
return True
raise AuthError(
403,
"You don't have permission to redact events"
)
def _check_power_levels(event, auth_events):
user_list = event.content.get("users", {})
# Validate users
for k, v in user_list.items():
try:
UserID.from_string(k)
except:
raise SynapseError(400, "Not a valid user_id: %s" % (k,))
try:
int(v)
except:
raise SynapseError(400, "Not a valid power level: %s" % (v,))
key = (event.type, event.state_key, )
current_state = auth_events.get(key)
if not current_state:
return
user_level = get_user_power_level(event.user_id, auth_events)
# Check other levels:
levels_to_check = [
("users_default", None),
("events_default", None),
("state_default", None),
("ban", None),
("redact", None),
("kick", None),
("invite", None),
]
old_list = current_state.content.get("users")
for user in set(old_list.keys() + user_list.keys()):
levels_to_check.append(
(user, "users")
)
old_list = current_state.content.get("events")
new_list = event.content.get("events")
for ev_id in set(old_list.keys() + new_list.keys()):
levels_to_check.append(
(ev_id, "events")
)
old_state = current_state.content
new_state = event.content
for level_to_check, dir in levels_to_check:
old_loc = old_state
new_loc = new_state
if dir:
old_loc = old_loc.get(dir, {})
new_loc = new_loc.get(dir, {})
if level_to_check in old_loc:
old_level = int(old_loc[level_to_check])
else:
old_level = None
if level_to_check in new_loc:
new_level = int(new_loc[level_to_check])
else:
new_level = None
if new_level is not None and old_level is not None:
if new_level == old_level:
continue
if dir == "users" and level_to_check != event.user_id:
if old_level == user_level:
raise AuthError(
403,
"You don't have permission to remove ops level equal "
"to your own"
)
if old_level > user_level or new_level > user_level:
raise AuthError(
403,
"You don't have permission to add ops level greater "
"than your own"
)
def _get_power_level_event(auth_events):
key = (EventTypes.PowerLevels, "", )
return auth_events.get(key)
def get_user_power_level(user_id, auth_events):
power_level_event = _get_power_level_event(auth_events)
if power_level_event:
level = power_level_event.content.get("users", {}).get(user_id)
if not level:
level = power_level_event.content.get("users_default", 0)
if level is None:
return 0
else:
return int(level)
else:
key = (EventTypes.Create, "", )
create_event = auth_events.get(key)
if (create_event is not None and
create_event.content["creator"] == user_id):
return 100
else:
return 0
def _get_named_level(auth_events, name, default):
power_level_event = _get_power_level_event(auth_events)
if not power_level_event:
return default
level = power_level_event.content.get(name, None)
if level is not None:
return int(level)
else:
return default
def _verify_third_party_invite(event, auth_events):
"""
Validates that the invite event is authorized by a previous third-party invite.
Checks that the public key, and keyserver, match those in the third party invite,
and that the invite event has a signature issued using that public key.
Args:
event: The m.room.member join event being validated.
auth_events: All relevant previous context events which may be used
for authorization decisions.
Return:
True if the event fulfills the expectations of a previous third party
invite event.
"""
if "third_party_invite" not in event.content:
return False
if "signed" not in event.content["third_party_invite"]:
return False
signed = event.content["third_party_invite"]["signed"]
for key in {"mxid", "token"}:
if key not in signed:
return False
token = signed["token"]
invite_event = auth_events.get(
(EventTypes.ThirdPartyInvite, token,)
)
if not invite_event:
return False
if invite_event.sender != event.sender:
return False
if event.user_id != invite_event.user_id:
return False
if signed["mxid"] != event.state_key:
return False
if signed["token"] != token:
return False
for public_key_object in get_public_keys(invite_event):
public_key = public_key_object["public_key"]
try:
for server, signature_block in signed["signatures"].items():
for key_name, encoded_signature in signature_block.items():
if not key_name.startswith("ed25519:"):
continue
verify_key = decode_verify_key_bytes(
key_name,
decode_base64(public_key)
)
verify_signed_json(signed, server, verify_key)
# We got the public key from the invite, so we know that the
# correct server signed the signed bundle.
# The caller is responsible for checking that the signing
# server has not revoked that public key.
return True
except (KeyError, SignatureVerifyException,):
continue
return False
def get_public_keys(invite_event):
public_keys = []
if "public_key" in invite_event.content:
o = {
"public_key": invite_event.content["public_key"],
}
if "key_validity_url" in invite_event.content:
o["key_validity_url"] = invite_event.content["key_validity_url"]
public_keys.append(o)
public_keys.extend(invite_event.content.get("public_keys", []))
return public_keys
def auth_types_for_event(event):
"""Given an event, return a list of (EventType, StateKey) that may be
needed to auth the event. The returned list may be a superset of what
would actually be required depending on the full state of the room.
Used to limit the number of events to fetch from the database to
actually auth the event.
"""
if event.type == EventTypes.Create:
return []
auth_types = []
auth_types.append((EventTypes.PowerLevels, "", ))
auth_types.append((EventTypes.Member, event.user_id, ))
auth_types.append((EventTypes.Create, "", ))
if event.type == EventTypes.Member:
membership = event.content["membership"]
if membership in [Membership.JOIN, Membership.INVITE]:
auth_types.append((EventTypes.JoinRules, "", ))
auth_types.append((EventTypes.Member, event.state_key, ))
if membership == Membership.INVITE:
if "third_party_invite" in event.content:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"]
)
auth_types.append(key)
return auth_types

View File

@@ -36,6 +36,15 @@ class _EventInternalMetadata(object):
def is_invite_from_remote(self):
return getattr(self, "invite_from_remote", False)
def get_send_on_behalf_of(self):
"""Whether this server should send the event on behalf of another server.
This is used by the federation "send_join" API to forward the initial join
event for a server in the room.
returns a str with the name of the server this event is sent on behalf of.
"""
return getattr(self, "send_on_behalf_of", None)
def _event_dict_property(key):
def getter(self):
@@ -70,7 +79,6 @@ class EventBase(object):
auth_events = _event_dict_property("auth_events")
depth = _event_dict_property("depth")
content = _event_dict_property("content")
event_id = _event_dict_property("event_id")
hashes = _event_dict_property("hashes")
origin = _event_dict_property("origin")
origin_server_ts = _event_dict_property("origin_server_ts")
@@ -79,8 +87,6 @@ class EventBase(object):
redacts = _event_dict_property("redacts")
room_id = _event_dict_property("room_id")
sender = _event_dict_property("sender")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
user_id = _event_dict_property("sender")
@property
@@ -153,6 +159,11 @@ class FrozenEvent(EventBase):
else:
frozen_dict = event_dict
self.event_id = event_dict["event_id"]
self.type = event_dict["type"]
if "state_key" in event_dict:
self.state_key = event_dict["state_key"]
super(FrozenEvent, self).__init__(
frozen_dict,
signatures=signatures,

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import EventBase, FrozenEvent
from . import EventBase, FrozenEvent, _event_dict_property
from synapse.types import EventID
@@ -34,6 +34,10 @@ class EventBuilder(EventBase):
internal_metadata_dict=internal_metadata_dict,
)
event_id = _event_dict_property("event_id")
state_key = _event_dict_property("state_key")
type = _event_dict_property("type")
def build(self):
return FrozenEvent.from_event(self)

View File

@@ -26,8 +26,7 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
from synapse.events import FrozenEvent, builder
import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
@@ -127,6 +126,16 @@ class FederationClient(FederationBase):
destination, content, timeout
)
@log_function
def query_user_devices(self, destination, user_id, timeout=30000):
"""Query the device keys for a list of user ids hosted on a remote
server.
"""
sent_queries_counter.inc("user_devices")
return self.transport_layer.query_user_devices(
destination, user_id, timeout
)
@log_function
def claim_client_keys(self, destination, content, timeout):
"""Claims one-time keys for a device hosted on a remote server.
@@ -500,8 +509,10 @@ class FederationClient(FederationBase):
if "prev_state" not in pdu_dict:
pdu_dict["prev_state"] = []
ev = builder.EventBuilder(pdu_dict)
defer.returnValue(
(destination, self.event_from_pdu_json(pdu_dict))
(destination, ev)
)
break
except CodeMessageException as e:
@@ -708,7 +719,7 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def get_missing_events(self, destination, room_id, earliest_events_ids,
latest_events, limit, min_depth):
latest_events, limit, min_depth, timeout):
"""Tries to fetch events we are missing. This is called when we receive
an event without having received all of its ancestors.
@@ -722,6 +733,7 @@ class FederationClient(FederationBase):
have all previous events for.
limit (int): Maximum number of events to return.
min_depth (int): Minimum depth of events tor return.
timeout (int): Max time to wait in ms
"""
try:
content = yield self.transport_layer.get_missing_events(
@@ -731,6 +743,7 @@ class FederationClient(FederationBase):
latest_events=[e.event_id for e in latest_events],
limit=limit,
min_depth=min_depth,
timeout=timeout,
)
events = [
@@ -741,8 +754,6 @@ class FederationClient(FederationBase):
signed_events = yield self._check_sigs_and_hash_and_fetch(
destination, events, outlier=False
)
have_gotten_all_from_destination = True
except HttpResponseException as e:
if not e.code == 400:
raise
@@ -750,72 +761,6 @@ class FederationClient(FederationBase):
# We are probably hitting an old server that doesn't support
# get_missing_events
signed_events = []
have_gotten_all_from_destination = False
if len(signed_events) >= limit:
defer.returnValue(signed_events)
users = yield self.state.get_current_user_in_room(room_id)
servers = set(get_domain_from_id(u) for u in users)
servers = set(servers)
servers.discard(self.server_name)
failed_to_fetch = set()
while len(signed_events) < limit:
# Are we missing any?
seen_events = set(earliest_events_ids)
seen_events.update(e.event_id for e in signed_events if e)
missing_events = {}
for e in itertools.chain(latest_events, signed_events):
if e.depth > min_depth:
missing_events.update({
e_id: e.depth for e_id, _ in e.prev_events
if e_id not in seen_events
and e_id not in failed_to_fetch
})
if not missing_events:
break
have_seen = yield self.store.have_events(missing_events)
for k in have_seen:
missing_events.pop(k, None)
if not missing_events:
break
# Okay, we haven't gotten everything yet. Lets get them.
ordered_missing = sorted(missing_events.items(), key=lambda x: x[0])
if have_gotten_all_from_destination:
servers.discard(destination)
def random_server_list():
srvs = list(servers)
random.shuffle(srvs)
return srvs
deferreds = [
preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
res = yield preserve_context_over_deferred(
defer.DeferredList(deferreds, consumeErrors=True)
)
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)
else:
failed_to_fetch.add(e_id)
defer.returnValue(signed_events)

View File

@@ -23,6 +23,7 @@ from synapse.util.async import Linearizer
from synapse.util.logutils import log_function
from synapse.util.caches.response_cache import ResponseCache
from synapse.events import FrozenEvent
from synapse.types import get_domain_from_id
import synapse.metrics
from synapse.api.errors import AuthError, FederationError, SynapseError
@@ -51,8 +52,8 @@ class FederationServer(FederationBase):
self.auth = hs.get_auth()
self._room_pdu_linearizer = Linearizer()
self._server_linearizer = Linearizer()
self._room_pdu_linearizer = Linearizer("fed_room_pdu")
self._server_linearizer = Linearizer("fed_server")
# We cache responses to state queries, as they take a while and often
# come in waves.
@@ -132,7 +133,7 @@ class FederationServer(FederationBase):
if response:
logger.debug(
"[%s] We've already responed to this request",
"[%s] We've already responded to this request",
transaction.transaction_id
)
defer.returnValue(response)
@@ -143,6 +144,26 @@ class FederationServer(FederationBase):
results = []
for pdu in pdu_list:
# check that it's actually being sent from a valid destination to
# workaround bug #1753 in 0.18.5 and 0.18.6
if transaction.origin != get_domain_from_id(pdu.event_id):
if not (
pdu.type == 'm.room.member' and
pdu.content and
pdu.content.get("membership", None) == 'join' and
self.hs.is_mine_id(pdu.state_key)
):
logger.info(
"Discarding PDU %s from invalid origin %s",
pdu.event_id, transaction.origin
)
continue
else:
logger.info(
"Accepting join PDU %s from %s",
pdu.event_id, transaction.origin
)
try:
yield self._handle_new_pdu(transaction.origin, pdu)
results.append({})
@@ -395,6 +416,9 @@ class FederationServer(FederationBase):
def on_query_client_keys(self, origin, content):
return self.on_query_request("client_keys", content)
def on_query_user_devices(self, origin, user_id):
return self.on_query_request("user_devices", user_id)
@defer.inlineCallbacks
@log_function
def on_claim_client_keys(self, origin, content):
@@ -425,6 +449,7 @@ class FederationServer(FederationBase):
" limit: %d, min_depth: %d",
earliest_events, latest_events, limit, min_depth
)
missing_events = yield self.handler.on_get_missing_events(
origin, room_id, earliest_events, latest_events, limit, min_depth
)
@@ -474,6 +499,7 @@ class FederationServer(FederationBase):
@defer.inlineCallbacks
@log_function
def _handle_new_pdu(self, origin, pdu, get_missing=True):
# We reprocess pdus when we have seen them only as outliers
existing = yield self._get_persisted_pdu(
origin, pdu.event_id, do_auth=False
@@ -538,7 +564,16 @@ class FederationServer(FederationBase):
if get_missing and prevs - seen:
# If we're missing stuff, ensure we only fetch stuff one
# at a time.
logger.info(
"Acquiring lock for room %r to fetch %d missing events: %r...",
pdu.room_id, len(prevs - seen), list(prevs - seen)[:5],
)
with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
logger.info(
"Acquired lock for room %r to fetch %d missing events",
pdu.room_id, len(prevs - seen),
)
# We recalculate seen, since it may have changed.
have_seen = yield self.store.have_events(prevs)
seen = set(have_seen.keys())
@@ -558,6 +593,25 @@ class FederationServer(FederationBase):
len(prevs - seen), pdu.room_id, list(prevs - seen)[:5]
)
# XXX: we set timeout to 10s to help workaround
# https://github.com/matrix-org/synapse/issues/1733.
# The reason is to avoid holding the linearizer lock
# whilst processing inbound /send transactions, causing
# FDs to stack up and block other inbound transactions
# which empirically can currently take up to 30 minutes.
#
# N.B. this explicitly disables retry attempts.
#
# N.B. this also increases our chances of falling back to
# fetching fresh state for the room if the missing event
# can't be found, which slightly reduces our security.
# it may also increase our DAG extremity count for the room,
# causing additional state resolution? See #1760.
# However, fetching state doesn't hold the linearizer lock
# apparently.
#
# see https://github.com/matrix-org/synapse/pull/1744
missing_events = yield self.get_missing_events(
origin,
pdu.room_id,
@@ -565,6 +619,7 @@ class FederationServer(FederationBase):
latest_events=[pdu],
limit=10,
min_depth=min_depth,
timeout=10000,
)
# We want to sort these by depth so we process them and

View File

@@ -19,7 +19,6 @@ from twisted.internet import defer
from .persistence import TransactionActions
from .units import Transaction, Edu
from synapse.api.constants import EventTypes, Membership
from synapse.api.errors import HttpResponseException
from synapse.util.async import run_on_reactor
from synapse.util.logcontext import preserve_context_over_fn
@@ -62,6 +61,7 @@ class TransactionQueue(object):
self.transport_layer = hs.get_federation_transport_client()
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
# Is a mapping from destinations -> deferreds. Used to keep track
# of which destinations have transactions in flight and when they are
@@ -100,6 +100,7 @@ class TransactionQueue(object):
self.pending_failures_by_dest = {}
self.last_device_stream_id_by_dest = {}
self.last_device_list_stream_id_by_dest = {}
# HACK to get unique tx id
self._next_txn_id = int(self.clock.time_msec())
@@ -153,17 +154,32 @@ class TransactionQueue(object):
break
for event in events:
# Only send events for this server.
send_on_behalf_of = event.internal_metadata.get_send_on_behalf_of()
is_mine = self.is_mine_id(event.event_id)
if not is_mine and send_on_behalf_of is None:
continue
# Get the state from before the event.
# We need to make sure that this is the state from before
# the event and not from after it.
# Otherwise if the last member on a server in a room is
# banned then it won't receive the event because it won't
# be in the room after the ban.
users_in_room = yield self.state.get_current_user_in_room(
event.room_id, latest_event_ids=[event.event_id],
event.room_id, latest_event_ids=[
prev_id for prev_id, _ in event.prev_events
],
)
destinations = set(
get_domain_from_id(user_id) for user_id in users_in_room
)
if event.type == EventTypes.Member:
if event.content["membership"] == Membership.JOIN:
destinations.add(get_domain_from_id(event.state_key))
if send_on_behalf_of is not None:
# If we are sending the event on behalf of another server
# then it already has the event and there is no reason to
# send the event to it.
destinations.discard(send_on_behalf_of)
logger.debug("Sending %s to %r", event, destinations)
@@ -290,64 +306,77 @@ class TransactionQueue(object):
yield run_on_reactor()
while True:
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
pending_edus = self.pending_edus_by_dest.pop(destination, [])
pending_presence = self.pending_presence_by_dest.pop(destination, {})
pending_failures = self.pending_failures_by_dest.pop(destination, [])
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
pending_edus.extend(
self.pending_edus_keyed_by_dest.pop(destination, {}).values()
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
backoff_on_404=True, # If we get a 404 the other side has gone
)
device_message_edus, device_stream_id, dev_list_id = (
yield self._get_new_device_messages(destination)
)
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
)
limiter = yield get_retry_limiter(
destination,
self.clock,
self.store,
)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
device_message_edus, device_stream_id = (
yield self._get_new_device_messages(destination)
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
pending_edus.extend(device_message_edus)
if pending_presence:
pending_edus.append(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.presence",
content={
"push": [
format_user_presence_state(
presence, self.clock.time_msec()
)
for presence in pending_presence.values()
]
},
)
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
limiter=limiter,
)
if success:
# Remove the acknowledged device messages from the database
# Only bother if we actually sent some device messages
if device_message_edus:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
logger.info("Marking as sent %r %r", destination, dev_list_id)
yield self.store.mark_as_sent_devices_by_remote(
destination, dev_list_id
)
if pending_pdus:
logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
destination, len(pending_pdus))
if not pending_pdus and not pending_edus and not pending_failures:
logger.debug("TX [%s] Nothing to send", destination)
self.last_device_stream_id_by_dest[destination] = (
device_stream_id
)
return
success = yield self._send_new_transaction(
destination, pending_pdus, pending_edus, pending_failures,
device_stream_id,
should_delete_from_device_stream=bool(device_message_edus),
limiter=limiter,
)
if not success:
break
self.last_device_stream_id_by_dest[destination] = device_stream_id
self.last_device_list_stream_id_by_dest[destination] = dev_list_id
else:
break
except NotRetryingDestination:
logger.info(
logger.debug(
"TX [%s] not ready for retry yet - "
"dropping transaction for now",
destination,
@@ -372,13 +401,26 @@ class TransactionQueue(object):
)
for content in contents
]
defer.returnValue((edus, stream_id))
last_device_list = self.last_device_list_stream_id_by_dest.get(destination, 0)
now_stream_id, results = yield self.store.get_devices_by_remote(
destination, last_device_list
)
edus.extend(
Edu(
origin=self.server_name,
destination=destination,
edu_type="m.device_list_update",
content=content,
)
for content in results
)
defer.returnValue((edus, stream_id, now_stream_id))
@measure_func("_send_new_transaction")
@defer.inlineCallbacks
def _send_new_transaction(self, destination, pending_pdus, pending_edus,
pending_failures, device_stream_id,
should_delete_from_device_stream, limiter):
pending_failures, limiter):
# Sort based on the order field
pending_pdus.sort(key=lambda t: t[1])
@@ -462,7 +504,7 @@ class TransactionQueue(object):
code = e.code
response = e.response
if e.code == 429 or 500 <= e.code:
if e.code in (401, 404, 429) or 500 <= e.code:
logger.info(
"TX [%s] {%s} got %d response",
destination, txn_id, code
@@ -489,13 +531,6 @@ class TransactionQueue(object):
"Failed to send event %s to %s", p.event_id, destination
)
success = False
else:
# Remove the acknowledged device messages from the database
if should_delete_from_device_stream:
yield self.store.delete_device_msgs_for_remote(
destination, device_stream_id
)
self.last_device_stream_id_by_dest[destination] = device_stream_id
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.

View File

@@ -346,6 +346,32 @@ class TransportLayerClient(object):
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def query_user_devices(self, destination, user_id, timeout):
"""Query the devices for a user id hosted on a remote server.
Response:
{
"stream_id": "...",
"devices": [ { ... } ]
}
Args:
destination(str): The server to query.
query_content(dict): The user ids to query.
Returns:
A dict containg the device keys.
"""
path = PREFIX + "/user/devices/" + user_id
content = yield self.client.get_json(
destination=destination,
path=path,
timeout=timeout,
)
defer.returnValue(content)
@defer.inlineCallbacks
@log_function
def claim_client_keys(self, destination, query_content, timeout):
@@ -386,7 +412,7 @@ class TransportLayerClient(object):
@defer.inlineCallbacks
@log_function
def get_missing_events(self, destination, room_id, earliest_events,
latest_events, limit, min_depth):
latest_events, limit, min_depth, timeout):
path = PREFIX + "/get_missing_events/%s" % (room_id,)
content = yield self.client.post_json(
@@ -397,7 +423,8 @@ class TransportLayerClient(object):
"min_depth": int(min_depth),
"earliest_events": earliest_events,
"latest_events": latest_events,
}
},
timeout=timeout,
)
defer.returnValue(content)

View File

@@ -409,6 +409,13 @@ class FederationClientKeysQueryServlet(BaseFederationServlet):
return self.handler.on_query_client_keys(origin, content)
class FederationUserDevicesQueryServlet(BaseFederationServlet):
PATH = "/user/devices/(?P<user_id>[^/]*)"
def on_GET(self, origin, content, query, user_id):
return self.handler.on_query_user_devices(origin, user_id)
class FederationClientKeysClaimServlet(BaseFederationServlet):
PATH = "/user/keys/claim"
@@ -613,6 +620,7 @@ SERVLET_CLASSES = (
FederationGetMissingEventsServlet,
FederationEventAuthServlet,
FederationClientKeysQueryServlet,
FederationUserDevicesQueryServlet,
FederationClientKeysClaimServlet,
FederationThirdPartyInviteExchangeServlet,
On3pidBindServlet,

View File

@@ -88,9 +88,13 @@ class BaseHandler(object):
current_state = yield self.store.get_events(
context.current_state_ids.values()
)
current_state = current_state.values()
else:
current_state = yield self.store.get_current_state(event.room_id)
current_state = yield self.state_handler.get_current_state(
event.room_id
)
current_state = current_state.values()
logger.info("maybe_kick_guest_users %r", current_state)
yield self.kick_guest_users(current_state)

View File

@@ -607,7 +607,7 @@ class AuthHandler(BaseHandler):
# types (mediums) of threepid. For now, we still use the existing
# infrastructure, but this is the start of synapse gaining knowledge
# of specific types of threepid (and fixes the fact that checking
# for the presenc eof an email address during password reset was
# for the presence of an email address during password reset was
# case sensitive).
if medium == 'email':
address = address.lower()
@@ -617,6 +617,17 @@ class AuthHandler(BaseHandler):
self.hs.get_clock().time_msec()
)
@defer.inlineCallbacks
def delete_threepid(self, user_id, medium, address):
# 'Canonicalise' email addresses as per above
if medium == 'email':
address = address.lower()
ret = yield self.store.user_delete_threepid(
user_id, medium, address,
)
defer.returnValue(ret)
def _save_session(self, session):
# TODO: Persistent storage
logger.debug("Saving session %s", session)
@@ -656,8 +667,8 @@ class AuthHandler(BaseHandler):
Whether self.hash(password) == stored_hash (bool).
"""
if stored_hash:
return bcrypt.hashpw(password + self.hs.config.password_pepper,
stored_hash.encode('utf-8')) == stored_hash
return bcrypt.hashpw(password.encode('utf8') + self.hs.config.password_pepper,
stored_hash.encode('utf8')) == stored_hash
else:
return False

View File

@@ -14,7 +14,10 @@
# limitations under the License.
from synapse.api import errors
from synapse.api.constants import EventTypes
from synapse.util import stringutils
from synapse.util.async import Linearizer
from synapse.types import get_domain_from_id
from twisted.internet import defer
from ._base import BaseHandler
@@ -27,6 +30,21 @@ class DeviceHandler(BaseHandler):
def __init__(self, hs):
super(DeviceHandler, self).__init__(hs)
self.hs = hs
self.state = hs.get_state_handler()
self.federation_sender = hs.get_federation_sender()
self.federation = hs.get_replication_layer()
self._remote_edue_linearizer = Linearizer(name="remote_device_list")
self.federation.register_edu_handler(
"m.device_list_update", self._incoming_device_list_update,
)
self.federation.register_query_handler(
"user_devices", self.on_federation_query_user_devices,
)
hs.get_distributor().observe("user_left_room", self.user_left_room)
@defer.inlineCallbacks
def check_device_registered(self, user_id, device_id,
initial_device_display_name=None):
@@ -45,29 +63,29 @@ class DeviceHandler(BaseHandler):
str: device id (generated if none was supplied)
"""
if device_id is not None:
yield self.store.store_device(
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=True,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
# if the device id is not specified, we'll autogen one, but loop a few
# times in case of a clash.
attempts = 0
while attempts < 5:
try:
device_id = stringutils.random_string(10).upper()
yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
ignore_if_known=False,
)
device_id = stringutils.random_string(10).upper()
new_device = yield self.store.store_device(
user_id=user_id,
device_id=device_id,
initial_device_display_name=initial_device_display_name,
)
if new_device:
yield self.notify_device_update(user_id, [device_id])
defer.returnValue(device_id)
except errors.StoreError:
attempts += 1
attempts += 1
raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -147,6 +165,8 @@ class DeviceHandler(BaseHandler):
user_id=user_id, device_id=device_id
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def update_device(self, user_id, device_id, content):
""" Update the given device
@@ -166,12 +186,164 @@ class DeviceHandler(BaseHandler):
device_id,
new_display_name=content.get("display_name")
)
yield self.notify_device_update(user_id, [device_id])
except errors.StoreError, e:
if e.code == 404:
raise errors.NotFoundError()
else:
raise
@defer.inlineCallbacks
def notify_device_update(self, user_id, device_ids):
"""Notify that a user's device(s) has changed. Pokes the notifier, and
remote servers if the user is local.
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = [r.room_id for r in rooms]
hosts = set()
if self.hs.is_mine_id(user_id):
for room_id in room_ids:
users = yield self.store.get_users_in_room(room_id)
hosts.update(get_domain_from_id(u) for u in users)
hosts.discard(self.server_name)
position = yield self.store.add_device_change_to_streams(
user_id, device_ids, list(hosts)
)
yield self.notifier.on_new_event(
"device_list_key", position, rooms=room_ids,
)
if hosts:
logger.info("Sending device list update notif to: %r", hosts)
for host in hosts:
self.federation_sender.send_device_messages(host)
@defer.inlineCallbacks
def get_user_ids_changed(self, user_id, from_token):
"""Get list of users that have had the devices updated, or have newly
joined a room, that `user_id` may be interested in.
Args:
user_id (str)
from_token (StreamToken)
"""
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
# First we check if any devices have changed
changed = yield self.store.get_user_whose_devices_changed(
from_token.device_list_key
)
# Then work out if any users have since joined
rooms_changed = self.store.get_rooms_that_changed(room_ids, from_token.room_key)
possibly_changed = set(changed)
for room_id in rooms_changed:
# Fetch (an approximation) of the current state at the time.
event_rows, token = yield self.store.get_recent_event_ids_for_room(
room_id, end_token=from_token.room_key, limit=1,
)
if event_rows:
last_event_id = event_rows[-1]["event_id"]
prev_state_ids = yield self.store.get_state_ids_for_event(last_event_id)
else:
prev_state_ids = {}
current_state_ids = yield self.state.get_current_state_ids(room_id)
# If there has been any change in membership, include them in the
# possibly changed list. We'll check if they are joined below,
# and we're not toooo worried about spuriously adding users.
for key, event_id in current_state_ids.iteritems():
etype, state_key = key
if etype == EventTypes.Member:
prev_event_id = prev_state_ids.get(key, None)
if not prev_event_id or prev_event_id != event_id:
possibly_changed.add(state_key)
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
# We return the intersection of users whose devices have changed (or
# membership has changeD) and the users who share a room with the
# requester
defer.returnValue(users_who_share_room & users_who_share_room)
@defer.inlineCallbacks
def _incoming_device_list_update(self, origin, edu_content):
user_id = edu_content["user_id"]
device_id = edu_content["device_id"]
stream_id = edu_content["stream_id"]
prev_ids = edu_content.get("prev_id", [])
if get_domain_from_id(user_id) != origin:
# TODO: Raise?
logger.warning("Got device list update edu for %r from %r", user_id, origin)
return
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We don't share any rooms with this user. Ignore update, as we
# probably won't get any further updates.
return
with (yield self._remote_edue_linearizer.queue(user_id)):
# If the prev id matches whats in our cache table, then we don't need
# to resync the users device list, otherwise we do.
resync = True
if len(prev_ids) == 1:
extremity = yield self.store.get_device_list_last_stream_id_for_remote(
user_id
)
logger.info("Extrem: %r, prev_ids: %r", extremity, prev_ids)
if str(extremity) == str(prev_ids[0]):
resync = False
if resync:
# Fetch all devices for the user.
result = yield self.federation.query_user_devices(origin, user_id)
stream_id = result["stream_id"]
devices = result["devices"]
yield self.store.update_remote_device_list_cache(
user_id, devices, stream_id,
)
device_ids = [device["device_id"] for device in devices]
yield self.notify_device_update(user_id, device_ids)
else:
# Simply update the single device, since we know that is the only
# change (becuase of the single prev_id matching the current cache)
content = dict(edu_content)
for key in ("user_id", "device_id", "stream_id", "prev_ids"):
content.pop(key, None)
yield self.store.update_remote_device_list_cache_entry(
user_id, device_id, content, stream_id,
)
yield self.notify_device_update(user_id, [device_id])
@defer.inlineCallbacks
def on_federation_query_user_devices(self, user_id):
stream_id, devices = yield self.store.get_devices_with_keys_by_user(user_id)
defer.returnValue({
"user_id": user_id,
"stream_id": stream_id,
"devices": devices,
})
@defer.inlineCallbacks
def user_left_room(self, user, room_id):
user_id = user.to_string()
rooms = yield self.store.get_rooms_for_user(user_id)
if not rooms:
# We no longer share rooms with this user, so we'll no longer
# receive device updates. Mark this in DB.
yield self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
def _update_device_from_client_ips(device, client_ips):
ip = client_ips.get((device["user_id"], device["device_id"]), {})

View File

@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
domain = get_domain_from_id(user_id)
remote_queries.setdefault(domain, {})[user_id] = device_ids
remote_queries[user_id] = device_ids
# do the queries
# Firt get local devices.
failures = {}
results = {}
if local_query:
@@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query:
results[user_id] = keys
# Now attempt to get any remote devices from our local cache.
remote_queries_not_in_cache = {}
if remote_queries:
query_list = []
for user_id, device_ids in remote_queries.iteritems():
if device_ids:
query_list.extend((user_id, device_id) for device_id in device_ids)
else:
query_list.append((user_id, None))
user_ids_not_in_cache, remote_results = (
yield self.store.get_user_devices_from_cache(
query_list
)
)
for user_id, devices in remote_results.iteritems():
user_devices = results.setdefault(user_id, {})
for device_id, device in devices.iteritems():
keys = device.get("keys", None)
device_display_name = device.get("device_display_name", None)
if keys:
result = dict(keys)
unsigned = result.setdefault("unsigned", {})
if device_display_name:
unsigned["device_display_name"] = device_display_name
user_devices[device_id] = result
for user_id in user_ids_not_in_cache:
domain = get_domain_from_id(user_id)
r = remote_queries_not_in_cache.setdefault(domain, {})
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks
def do_remote_query(destination):
destination_query = remote_queries[destination]
destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
@@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries
for destination in remote_queries_not_in_cache
]))
defer.returnValue({
@@ -162,7 +194,7 @@ class E2eKeysHandler(object):
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
r = json.loads(device_info["key_json"])
r = dict(device_info["keys"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
@@ -255,10 +287,12 @@ class E2eKeysHandler(object):
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
yield self.store.set_e2e_device_keys(
user_id, device_id, time_now,
encode_canonical_json(device_keys)
changed = yield self.store.set_e2e_device_keys(
user_id, device_id, time_now, device_keys,
)
if changed:
# Only notify about device updates *if* the keys actually changed
yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:

View File

@@ -591,11 +591,12 @@ class FederationHandler(BaseHandler):
event_ids = list(extremities.keys())
logger.debug("calling resolve_state_groups in _maybe_backfill")
states = yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
for e in event_ids
]))
states = dict(zip(event_ids, [s[1] for s in states]))
states = dict(zip(event_ids, [s.state for s in states]))
state_map = yield self.store.get_events(
[e_id for ids in states.values() for e_id in ids],
@@ -790,6 +791,10 @@ class FederationHandler(BaseHandler):
)
event.internal_metadata.outlier = False
# Send this event on behalf of the origin server since they may not
# have an up to data view of the state of the room at this event so
# will not know which servers to send the event to.
event.internal_metadata.send_on_behalf_of = origin
context, event_stream_id, max_stream_id = yield self._handle_new_event(
origin, event
@@ -1314,7 +1319,6 @@ class FederationHandler(BaseHandler):
event_stream_id, max_stream_id = yield self.store.persist_event(
event, new_event_context,
current_state=state,
)
defer.returnValue((event_stream_id, max_stream_id))
@@ -1525,7 +1529,7 @@ class FederationHandler(BaseHandler):
(d.type, d.state_key): d for d in different_events if d
})
new_state, prev_state = self.state_handler.resolve_events(
new_state = self.state_handler.resolve_events(
[local_view.values(), remote_view.values()],
event
)

View File

@@ -208,8 +208,10 @@ class MessageHandler(BaseHandler):
content = builder.content
try:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if "displayname" not in content:
content["displayname"] = yield profile.get_displayname(target)
if "avatar_url" not in content:
content["avatar_url"] = yield profile.get_avatar_url(target)
except Exception as e:
logger.info(
"Failed to get profile information for %r: %s",
@@ -279,7 +281,9 @@ class MessageHandler(BaseHandler):
if event.type == EventTypes.Message:
presence = self.hs.get_presence_handler()
yield presence.bump_presence_active_time(user)
# We don't want to block sending messages on any presence code. This
# matters as sometimes presence code can take a while.
preserve_fn(presence.bump_presence_active_time)(user)
@defer.inlineCallbacks
def deduplicate_state_event(self, event, context):

View File

@@ -574,7 +574,7 @@ class PresenceHandler(object):
if not local_states:
continue
users = yield self.state.get_current_user_in_room(room_id)
users = yield self.store.get_users_in_room(room_id)
hosts = set(get_domain_from_id(u) for u in users)
for host in hosts:
@@ -766,7 +766,7 @@ class PresenceHandler(object):
# don't need to send to local clients here, as that is done as part
# of the event stream/sync.
# TODO: Only send to servers not already in the room.
user_ids = yield self.state.get_current_user_in_room(room_id)
user_ids = yield self.store.get_users_in_room(room_id)
if self.is_mine(user):
state = yield self.current_state_for_user(user.to_string())
@@ -1011,7 +1011,7 @@ class PresenceEventSource(object):
@defer.inlineCallbacks
@log_function
def get_new_events(self, user, from_key, room_ids=None, include_offline=True,
**kwargs):
explicit_room_id=None, **kwargs):
# The process for getting presence events are:
# 1. Get the rooms the user is in.
# 2. Get the list of user in the rooms.
@@ -1028,22 +1028,24 @@ class PresenceEventSource(object):
user_id = user.to_string()
if from_key is not None:
from_key = int(from_key)
room_ids = room_ids or []
presence = self.get_presence_handler()
stream_change_cache = self.store.presence_stream_cache
if not room_ids:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(e.room_id for e in rooms)
else:
room_ids = set(room_ids)
max_token = self.store.get_current_presence_token()
plist = yield self.store.get_presence_list_accepted(user.localpart)
friends = set(row["observed_user_id"] for row in plist)
friends.add(user_id) # So that we receive our own presence
users_interested_in = set(row["observed_user_id"] for row in plist)
users_interested_in.add(user_id) # So that we receive our own presence
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
user_id
)
users_interested_in.update(users_who_share_room)
if explicit_room_id:
user_ids = yield self.store.get_users_in_room(explicit_room_id)
users_interested_in.update(user_ids)
user_ids_changed = set()
changed = None
@@ -1055,35 +1057,19 @@ class PresenceEventSource(object):
# work out if we share a room or they're in our presence list
get_updates_counter.inc("stream")
for other_user_id in changed:
if other_user_id in friends:
if other_user_id in users_interested_in:
user_ids_changed.add(other_user_id)
continue
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
continue
else:
# Too many possible updates. Find all users we can see and check
# if any of them have changed.
get_updates_counter.inc("full")
user_ids_to_check = set()
for room_id in room_ids:
users = yield self.state.get_current_user_in_room(room_id)
user_ids_to_check.update(users)
user_ids_to_check.update(friends)
# Always include yourself. Only really matters for when the user is
# not in any rooms, but still.
user_ids_to_check.add(user_id)
if from_key:
user_ids_changed = stream_change_cache.get_entities_changed(
user_ids_to_check, from_key,
users_interested_in, from_key,
)
else:
user_ids_changed = user_ids_to_check
user_ids_changed = users_interested_in
updates = yield presence.current_state_for_users(user_ids_changed)

View File

@@ -437,6 +437,7 @@ class RoomEventSource(object):
limit,
room_ids,
is_guest,
explicit_room_id=None,
):
# We just ignore the key for now.

View File

@@ -62,17 +62,18 @@ class RoomListHandler(BaseHandler):
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
"""
if search_filter or (network_tuple and network_tuple.appservice_id is not None):
if search_filter:
# We explicitly don't bother caching searches or requests for
# appservice specific lists.
return self._get_public_room_list(
limit, since_token, search_filter, network_tuple=network_tuple,
)
result = self.response_cache.get((limit, since_token))
key = (limit, since_token, network_tuple)
result = self.response_cache.get(key)
if not result:
result = self.response_cache.set(
(limit, since_token),
key,
self._get_public_room_list(
limit, since_token, network_tuple=network_tuple
)

View File

@@ -45,7 +45,7 @@ class RoomMemberHandler(BaseHandler):
def __init__(self, hs):
super(RoomMemberHandler, self).__init__(hs)
self.member_linearizer = Linearizer()
self.member_linearizer = Linearizer(name="member")
self.clock = hs.get_clock()
@@ -89,7 +89,7 @@ class RoomMemberHandler(BaseHandler):
duplicate = yield msg_handler.deduplicate_state_event(event, context)
if duplicate is not None:
# Discard the new event since this membership change is a no-op.
return
defer.returnValue(duplicate)
yield msg_handler.handle_new_client_event(
requester,
@@ -120,6 +120,8 @@ class RoomMemberHandler(BaseHandler):
if prev_member_event.membership == Membership.JOIN:
user_left_room(self.distributor, target, room_id)
defer.returnValue(event)
@defer.inlineCallbacks
def remote_join(self, remote_room_hosts, room_id, user, content):
if len(remote_room_hosts) == 0:
@@ -187,6 +189,7 @@ class RoomMemberHandler(BaseHandler):
ratelimit=True,
content=None,
):
content_specified = bool(content)
if content is None:
content = {}
@@ -229,13 +232,22 @@ class RoomMemberHandler(BaseHandler):
errcode=Codes.BAD_STATE
)
if old_state:
same_content = content == old_state.content
same_membership = old_membership == effective_membership_state
same_sender = requester.user.to_string() == old_state.sender
if same_sender and same_membership and same_content:
defer.returnValue(old_state)
is_host_in_room = yield self._is_host_in_room(current_state_ids)
if effective_membership_state == Membership.JOIN:
if requester.is_guest and not self._can_guest_join(current_state_ids):
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if requester.is_guest:
guest_can_join = yield self._can_guest_join(current_state_ids)
if not guest_can_join:
# This should be an auth check, but guests are a local concept,
# so don't really fit into the general auth process.
raise AuthError(403, "Guest access not allowed")
if not is_host_in_room:
inviter = yield self.get_inviter(target.to_string(), room_id)
@@ -245,8 +257,9 @@ class RoomMemberHandler(BaseHandler):
content["membership"] = Membership.JOIN
profile = self.hs.get_handlers().profile_handler
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if not content_specified:
content["displayname"] = yield profile.get_displayname(target)
content["avatar_url"] = yield profile.get_avatar_url(target)
if requester.is_guest:
content["kind"] = "guest"
@@ -288,7 +301,7 @@ class RoomMemberHandler(BaseHandler):
defer.returnValue({})
yield self._local_membership_update(
res = yield self._local_membership_update(
requester=requester,
target=target,
room_id=room_id,
@@ -298,6 +311,7 @@ class RoomMemberHandler(BaseHandler):
prev_event_ids=latest_event_ids,
content=content,
)
defer.returnValue(res)
@defer.inlineCallbacks
def send_membership_event(

View File

@@ -115,6 +115,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
"invited", # InvitedSyncResult for each invited room.
"archived", # ArchivedSyncResult for each archived room.
"to_device", # List of direct messages for the device.
"device_lists", # List of user_ids whose devices have chanegd
])):
__slots__ = []
@@ -129,7 +130,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
self.invited or
self.archived or
self.account_data or
self.to_device
self.to_device or
self.device_lists
)
@@ -544,6 +546,10 @@ class SyncHandler(object):
yield self._generate_sync_entry_for_to_device(sync_result_builder)
device_lists = yield self._generate_sync_entry_for_device_list(
sync_result_builder
)
defer.returnValue(SyncResult(
presence=sync_result_builder.presence,
account_data=sync_result_builder.account_data,
@@ -551,9 +557,32 @@ class SyncHandler(object):
invited=sync_result_builder.invited,
archived=sync_result_builder.archived,
to_device=sync_result_builder.to_device,
device_lists=device_lists,
next_batch=sync_result_builder.now_token,
))
@defer.inlineCallbacks
def _generate_sync_entry_for_device_list(self, sync_result_builder):
user_id = sync_result_builder.sync_config.user.to_string()
since_token = sync_result_builder.since_token
if since_token and since_token.device_list_key:
rooms = yield self.store.get_rooms_for_user(user_id)
room_ids = set(r.room_id for r in rooms)
user_ids_changed = set()
changed = yield self.store.get_user_whose_devices_changed(
since_token.device_list_key
)
for other_user_id in changed:
other_rooms = yield self.store.get_rooms_for_user(other_user_id)
if room_ids.intersection(e.room_id for e in other_rooms):
user_ids_changed.add(other_user_id)
defer.returnValue(user_ids_changed)
else:
defer.returnValue([])
@defer.inlineCallbacks
def _generate_sync_entry_for_to_device(self, sync_result_builder):
"""Generates the portion of the sync response. Populates

View File

@@ -25,7 +25,7 @@ from synapse.http.endpoint import SpiderEndpoint
from canonicaljson import encode_canonical_json
from twisted.internet import defer, reactor, ssl, protocol, task
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.web.client import (
BrowserLikeRedirectAgent, ContentDecoderAgent, GzipDecoder, Agent,
readBody, PartialDownloadError,
@@ -386,26 +386,23 @@ class SpiderEndpointFactory(object):
def endpointForURI(self, uri):
logger.info("Getting endpoint for %s", uri.toBytes())
if uri.scheme == "http":
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=TCP4ClientEndpoint,
endpoint_kw_args={
'timeout': 15
},
)
endpoint_factory = HostnameEndpoint
elif uri.scheme == "https":
tlsPolicy = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=SSL4ClientEndpoint,
endpoint_kw_args={
'sslContextFactory': tlsPolicy,
'timeout': 15
},
)
tlsCreator = self.policyForHTTPS.creatorForNetloc(uri.host, uri.port)
def endpoint_factory(reactor, host, port, **kw):
return wrapClientTLS(
tlsCreator,
HostnameEndpoint(reactor, host, port, **kw))
else:
logger.warn("Can't get endpoint for unrecognised scheme %s", uri.scheme)
return None
return SpiderEndpoint(
reactor, uri.host, uri.port, self.blacklist, self.whitelist,
endpoint=endpoint_factory, endpoint_kw_args=dict(timeout=15),
)
class SpiderHttpClient(SimpleHttpClient):

View File

@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet.endpoints import SSL4ClientEndpoint, TCP4ClientEndpoint
from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS
from twisted.internet import defer, reactor
from twisted.internet.error import ConnectError
from twisted.names import client, dns
@@ -58,11 +58,13 @@ def matrix_federation_endpoint(reactor, destination, ssl_context_factory=None,
endpoint_kw_args.update(timeout=timeout)
if ssl_context_factory is None:
transport_endpoint = TCP4ClientEndpoint
transport_endpoint = HostnameEndpoint
default_port = 8008
else:
transport_endpoint = SSL4ClientEndpoint
endpoint_kw_args.update(sslContextFactory=ssl_context_factory)
def transport_endpoint(reactor, host, port, timeout):
return wrapClientTLS(
ssl_context_factory,
HostnameEndpoint(reactor, host, port, timeout=timeout))
default_port = 8448
if port is None:
@@ -142,7 +144,7 @@ class SpiderEndpoint(object):
Implements twisted.internet.interfaces.IStreamClientEndpoint.
"""
def __init__(self, reactor, host, port, blacklist, whitelist,
endpoint=TCP4ClientEndpoint, endpoint_kw_args={}):
endpoint=HostnameEndpoint, endpoint_kw_args={}):
self.reactor = reactor
self.host = host
self.port = port
@@ -180,7 +182,7 @@ class SRVClientEndpoint(object):
"""
def __init__(self, reactor, service, domain, protocol="tcp",
default_port=None, endpoint=TCP4ClientEndpoint,
default_port=None, endpoint=HostnameEndpoint,
endpoint_kw_args={}):
self.reactor = reactor
self.service_name = "_%s._%s.%s" % (service, protocol, domain)

View File

@@ -378,6 +378,7 @@ class Notifier(object):
limit=limit,
is_guest=is_peeking,
room_ids=room_ids,
explicit_room_id=explicit_room_id,
)
if name == "room":

View File

@@ -439,15 +439,23 @@ class Mailer(object):
})
def make_room_link(self, room_id):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
return "https://vector.im/beta/#/room/%s" % (room_id,)
if self.hs.config.email_riot_base_url:
base_url = self.hs.config.email_riot_base_url
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
base_url = "https://vector.im/beta/#/room"
else:
return "https://matrix.to/#/%s" % (room_id,)
base_url = "https://matrix.to/#"
return "%s/%s" % (base_url, room_id)
def make_notif_link(self, notif):
# need /beta for Universal Links to work on iOS
if self.app_name == "Vector":
if self.hs.config.email_riot_base_url:
return "%s/#/room/%s/%s" % (
self.hs.config.email_riot_base_url,
notif['room_id'], notif['event_id']
)
elif self.app_name == "Vector":
# need /beta for Universal Links to work on iOS
return "https://vector.im/beta/#/room/%s/%s" % (
notif['room_id'], notif['event_id']
)

View File

@@ -52,7 +52,7 @@ def get_badge_count(store, user_id):
def get_context_for_event(store, state_handler, ev, user_id):
ctx = {}
room_state_ids = yield state_handler.get_current_state_ids(ev.room_id)
room_state_ids = yield store.get_state_ids_for_event(ev.event_id)
# we no longer bother setting room_alias, and make room_name the
# human-readable name instead, be that m.room.name, an alias or

View File

@@ -24,7 +24,7 @@ REQUIREMENTS = {
"signedjson>=1.0.0": ["signedjson>=1.0.0"],
"pynacl==0.3.0": ["nacl==0.3.0", "nacl.bindings"],
"service_identity>=1.0.0": ["service_identity>=1.0.0"],
"Twisted>=15.1.0": ["twisted>=15.1.0"],
"Twisted>=16.0.0": ["twisted>=16.0.0"],
"pyopenssl>=0.14": ["OpenSSL>=0.14"],
"pyyaml": ["yaml"],
"pyasn1": ["pyasn1"],

View File

@@ -46,6 +46,7 @@ STREAM_NAMES = (
("to_device",),
("public_rooms",),
("federation",),
("device_lists",),
)
@@ -140,6 +141,7 @@ class ReplicationResource(Resource):
caches_token = self.store.get_cache_stream_token()
public_rooms_token = self.store.get_current_public_room_stream_id()
federation_token = self.federation_sender.get_current_token()
device_list_token = self.store.get_device_stream_token()
defer.returnValue(_ReplicationToken(
room_stream_token,
@@ -155,6 +157,7 @@ class ReplicationResource(Resource):
int(stream_token.to_device_key),
int(public_rooms_token),
int(federation_token),
int(device_list_token),
))
@request_handler()
@@ -214,6 +217,7 @@ class ReplicationResource(Resource):
yield self.caches(writer, current_token, limit, request_streams)
yield self.to_device(writer, current_token, limit, request_streams)
yield self.public_rooms(writer, current_token, limit, request_streams)
yield self.device_lists(writer, current_token, limit, request_streams)
self.federation(writer, current_token, limit, request_streams, federation_ack)
self.streams(writer, current_token, request_streams)
@@ -295,9 +299,6 @@ class ReplicationResource(Resource):
"backward_ex_outliers", res.backward_ex_outliers,
("position", "event_id", "state_group"),
)
writer.write_header_and_rows(
"state_resets", res.state_resets, ("position",),
)
@defer.inlineCallbacks
def presence(self, writer, current_token, request_streams):
@@ -495,6 +496,20 @@ class ReplicationResource(Resource):
"position", "type", "content",
), position=upto_token)
@defer.inlineCallbacks
def device_lists(self, writer, current_token, limit, request_streams):
current_position = current_token.device_lists
device_lists = request_streams.get("device_lists")
if device_lists is not None and device_lists != current_position:
changes = yield self.store.get_all_device_list_changes_for_remotes(
device_lists,
)
writer.write_header_and_rows("device_lists", changes, (
"position", "user_id", "destination",
), position=current_position)
class _Writer(object):
"""Writes the streams as a JSON object as the response to the request"""
@@ -527,7 +542,7 @@ class _Writer(object):
class _ReplicationToken(collections.namedtuple("_ReplicationToken", (
"events", "presence", "typing", "receipts", "account_data", "backfill",
"push_rules", "pushers", "state", "caches", "to_device", "public_rooms",
"federation",
"federation", "device_lists",
))):
__slots__ = []

View File

@@ -0,0 +1,72 @@
# -*- coding: utf-8 -*-
# Copyright 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import BaseSlavedStore
from ._slaved_id_tracker import SlavedIdTracker
from synapse.storage import DataStore
from synapse.util.caches.stream_change_cache import StreamChangeCache
class SlavedDeviceStore(BaseSlavedStore):
def __init__(self, db_conn, hs):
super(SlavedDeviceStore, self).__init__(db_conn, hs)
self.hs = hs
self._device_list_id_gen = SlavedIdTracker(
db_conn, "device_lists_stream", "stream_id",
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
get_device_stream_token = DataStore.get_device_stream_token.__func__
get_user_whose_devices_changed = DataStore.get_user_whose_devices_changed.__func__
get_devices_by_remote = DataStore.get_devices_by_remote.__func__
_get_devices_by_remote_txn = DataStore._get_devices_by_remote_txn.__func__
_get_e2e_device_keys_txn = DataStore._get_e2e_device_keys_txn.__func__
mark_as_sent_devices_by_remote = DataStore.mark_as_sent_devices_by_remote.__func__
_mark_as_sent_devices_by_remote_txn = (
DataStore._mark_as_sent_devices_by_remote_txn.__func__
)
def stream_positions(self):
result = super(SlavedDeviceStore, self).stream_positions()
result["device_lists"] = self._device_list_id_gen.get_current_token()
return result
def process_replication(self, result):
stream = result.get("device_lists")
if stream:
self._device_list_id_gen.advance(int(stream["position"]))
for row in stream["rows"]:
stream_id = row[0]
user_id = row[1]
destination = row[2]
self._device_list_stream_cache.entity_has_changed(
user_id, stream_id
)
if destination:
self._device_list_federation_stream_cache.entity_has_changed(
destination, stream_id
)
return super(SlavedDeviceStore, self).process_replication(result)

View File

@@ -76,9 +76,6 @@ class SlavedEventStore(BaseSlavedStore):
get_latest_event_ids_in_room = EventFederationStore.__dict__[
"get_latest_event_ids_in_room"
]
_get_current_state_for_key = StateStore.__dict__[
"_get_current_state_for_key"
]
get_invited_rooms_for_user = RoomMemberStore.__dict__[
"get_invited_rooms_for_user"
]
@@ -115,8 +112,6 @@ class SlavedEventStore(BaseSlavedStore):
)
get_event = DataStore.get_event.__func__
get_events = DataStore.get_events.__func__
get_current_state = DataStore.get_current_state.__func__
get_current_state_for_key = DataStore.get_current_state_for_key.__func__
get_rooms_for_user_where_membership_is = (
DataStore.get_rooms_for_user_where_membership_is.__func__
)
@@ -197,10 +192,6 @@ class SlavedEventStore(BaseSlavedStore):
return result
def process_replication(self, result):
state_resets = set(
r[0] for r in result.get("state_resets", {"rows": []})["rows"]
)
stream = result.get("events")
if stream:
self._stream_id_gen.advance(int(stream["position"]))
@@ -210,7 +201,7 @@ class SlavedEventStore(BaseSlavedStore):
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=False, state_resets=state_resets
row, backfilled=False,
)
stream = result.get("backfill")
@@ -218,7 +209,7 @@ class SlavedEventStore(BaseSlavedStore):
self._backfill_id_gen.advance(-int(stream["position"]))
for row in stream["rows"]:
self._process_replication_row(
row, backfilled=True, state_resets=state_resets
row, backfilled=True,
)
stream = result.get("forward_ex_outliers")
@@ -237,21 +228,15 @@ class SlavedEventStore(BaseSlavedStore):
return super(SlavedEventStore, self).process_replication(result)
def _process_replication_row(self, row, backfilled, state_resets):
position = row[0]
def _process_replication_row(self, row, backfilled):
internal = json.loads(row[1])
event_json = json.loads(row[2])
event = FrozenEvent(event_json, internal_metadata_dict=internal)
self.invalidate_caches_for_event(
event, backfilled, reset_state=position in state_resets
event, backfilled,
)
def invalidate_caches_for_event(self, event, backfilled, reset_state):
if reset_state:
self._get_current_state_for_key.invalidate_all()
self.get_rooms_for_user.invalidate_all()
self.get_users_in_room.invalidate((event.room_id,))
def invalidate_caches_for_event(self, event, backfilled):
self._invalidate_get_event_cache(event.event_id)
self.get_latest_event_ids_in_room.invalidate((event.room_id,))
@@ -273,8 +258,6 @@ class SlavedEventStore(BaseSlavedStore):
self._invalidate_get_event_cache(event.redacts)
if event.type == EventTypes.Member:
self.get_rooms_for_user.invalidate((event.state_key,))
self.get_users_in_room.invalidate((event.room_id,))
self._membership_stream_cache.entity_has_changed(
event.state_key, event.internal_metadata.stream_ordering
)
@@ -289,7 +272,3 @@ class SlavedEventStore(BaseSlavedStore):
if (not event.internal_metadata.is_invite_from_remote()
and event.internal_metadata.is_outlier()):
return
self._get_current_state_for_key.invalidate((
event.room_id, event.type, event.state_key
))

View File

@@ -86,7 +86,11 @@ class HttpTransactionCache(object):
pass # execute the function instead.
deferred = fn(*args, **kwargs)
observable = ObservableDeferred(deferred)
# We don't add an errback to the raw deferred, so we ask ObservableDeferred
# to swallow the error. This is fine as the error will still be reported
# to the observers.
observable = ObservableDeferred(deferred, consumeErrors=True)
self.transactions[txn_key] = (observable, self.clock.time_msec())
return observable.observe()

View File

@@ -118,8 +118,14 @@ class LoginRestServlet(ClientV1RestServlet):
@defer.inlineCallbacks
def do_password_login(self, login_submission):
if 'medium' in login_submission and 'address' in login_submission:
address = login_submission['address']
if login_submission['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
address = address.lower()
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
login_submission['medium'], login_submission['address']
login_submission['medium'], address
)
if not user_id:
raise LoginError(403, "", errcode=Codes.FORBIDDEN)

View File

@@ -152,23 +152,29 @@ class RoomStateEventRestServlet(ClientV1RestServlet):
if state_key is not None:
event_dict["state_key"] = state_key
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
if event_type == EventTypes.Member:
yield self.handlers.room_member_handler.send_membership_event(
membership = content.get("membership", None)
event = yield self.handlers.room_member_handler.update_membership(
requester,
event,
context,
target=UserID.from_string(state_key),
room_id=room_id,
action=membership,
content=content,
)
else:
msg_handler = self.handlers.message_handler
event, context = yield msg_handler.create_event(
event_dict,
token_id=requester.access_token_id,
txn_id=txn_id,
)
yield msg_handler.send_nonmember_event(requester, event, context)
defer.returnValue((200, {"event_id": event.event_id}))
ret = {}
if event:
ret = {"event_id": event.event_id}
defer.returnValue((200, ret))
# TODO: Needs unit testing for generic events + feedback

View File

@@ -32,19 +32,27 @@ class VoipRestServlet(ClientV1RestServlet):
turnUris = self.hs.config.turn_uris
turnSecret = self.hs.config.turn_shared_secret
turnUsername = self.hs.config.turn_username
turnPassword = self.hs.config.turn_password
userLifetime = self.hs.config.turn_user_lifetime
if not turnUris or not turnSecret or not userLifetime:
if turnUris and turnSecret and userLifetime:
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
elif turnUris and turnUsername and turnPassword and userLifetime:
username = turnUsername
password = turnPassword
else:
defer.returnValue((200, {}))
expiry = (self.hs.get_clock().time_msec() + userLifetime) / 1000
username = "%d:%s" % (expiry, requester.user.to_string())
mac = hmac.new(turnSecret, msg=username, digestmod=hashlib.sha1)
# We need to use standard padded base64 encoding here
# encode_base64 because we need to add the standard padding to get the
# same result as the TURN server.
password = base64.b64encode(mac.digest())
defer.returnValue((200, {
'username': username,
'password': password,

View File

@@ -96,6 +96,11 @@ class PasswordRestServlet(RestServlet):
threepid = result[LoginType.EMAIL_IDENTITY]
if 'medium' not in threepid or 'address' not in threepid:
raise SynapseError(500, "Malformed threepid")
if threepid['medium'] == 'email':
# For emails, transform the address to lowercase.
# We store all email addreses as lowercase in the DB.
# (See add_threepid in synapse/handlers/auth.py)
threepid['address'] = threepid['address'].lower()
# if using email, we must know about the email they're authing with!
threepid_user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
threepid['medium'], threepid['address']
@@ -241,7 +246,7 @@ class ThreepidRestServlet(RestServlet):
for reqd in ['medium', 'address', 'validated_at']:
if reqd not in threepid:
logger.warn("Couldn't add 3pid: invalid response from ID sevrer")
logger.warn("Couldn't add 3pid: invalid response from ID server")
raise SynapseError(500, "Invalid response from ID Server")
yield self.auth_handler.add_threepid(
@@ -263,9 +268,43 @@ class ThreepidRestServlet(RestServlet):
defer.returnValue((200, {}))
class ThreepidDeleteRestServlet(RestServlet):
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
def __init__(self, hs):
super(ThreepidDeleteRestServlet, self).__init__()
self.auth = hs.get_auth()
self.auth_handler = hs.get_auth_handler()
@defer.inlineCallbacks
def on_POST(self, request):
yield run_on_reactor()
body = parse_json_object_from_request(request)
required = ['medium', 'address']
absent = []
for k in required:
if k not in body:
absent.append(k)
if absent:
raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
requester = yield self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
yield self.auth_handler.delete_threepid(
user_id, body['medium'], body['address']
)
defer.returnValue((200, {}))
def register_servlets(hs, http_server):
PasswordRequestTokenRestServlet(hs).register(http_server)
PasswordRestServlet(hs).register(http_server)
DeactivateAccountRestServlet(hs).register(http_server)
ThreepidRequestTokenRestServlet(hs).register(http_server)
ThreepidRestServlet(hs).register(http_server)
ThreepidDeleteRestServlet(hs).register(http_server)

View File

@@ -21,6 +21,8 @@ from synapse.api.errors import SynapseError
from synapse.http.servlet import (
RestServlet, parse_json_object_from_request, parse_integer
)
from synapse.http.servlet import parse_string
from synapse.types import StreamToken
from ._base import client_v2_patterns
logger = logging.getLogger(__name__)
@@ -149,6 +151,52 @@ class KeyQueryServlet(RestServlet):
defer.returnValue((200, result))
class KeyChangesServlet(RestServlet):
"""Returns the list of changes of keys between two stream tokens (may return
spurious extra results, since we currently ignore the `to` param).
GET /keys/changes?from=...&to=...
200 OK
{ "changed": ["@foo:example.com"] }
"""
PATTERNS = client_v2_patterns(
"/keys/changes$",
releases=()
)
def __init__(self, hs):
"""
Args:
hs (synapse.server.HomeServer):
"""
super(KeyChangesServlet, self).__init__()
self.auth = hs.get_auth()
self.device_handler = hs.get_device_handler()
@defer.inlineCallbacks
def on_GET(self, request):
requester = yield self.auth.get_user_by_req(request, allow_guest=True)
from_token_string = parse_string(request, "from")
# We want to enforce they do pass us one, but we ignore it and return
# changes after the "to" as well as before.
parse_string(request, "to")
from_token = StreamToken.from_string(from_token_string)
user_id = requester.user.to_string()
changed = yield self.device_handler.get_user_ids_changed(
user_id, from_token,
)
defer.returnValue((200, {
"changed": list(changed)
}))
class OneTimeKeyServlet(RestServlet):
"""
POST /keys/claim HTTP/1.1
@@ -192,4 +240,5 @@ class OneTimeKeyServlet(RestServlet):
def register_servlets(hs, http_server):
KeyUploadServlet(hs).register(http_server)
KeyQueryServlet(hs).register(http_server)
KeyChangesServlet(hs).register(http_server)
OneTimeKeyServlet(hs).register(http_server)

View File

@@ -170,12 +170,16 @@ class SyncRestServlet(RestServlet):
)
archived = self.encode_archived(
sync_result.archived, time_now, requester.access_token_id, filter.event_fields
sync_result.archived, time_now, requester.access_token_id,
filter.event_fields,
)
response_content = {
"account_data": {"events": sync_result.account_data},
"to_device": {"events": sync_result.to_device},
"device_lists": {
"changed": list(sync_result.device_lists),
},
"presence": self.encode_presence(
sync_result.presence, time_now
),

View File

@@ -61,7 +61,7 @@ class MediaRepository(object):
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
self.thumbnail_requirements = hs.config.thumbnail_requirements
self.remote_media_linearizer = Linearizer()
self.remote_media_linearizer = Linearizer(name="media_remote")
self.recently_accessed_remotes = set()
@@ -98,6 +98,8 @@ class MediaRepository(object):
with open(fname, "wb") as f:
f.write(content)
logger.info("Stored local media in file %r", fname)
yield self.store.store_local_media(
media_id=media_id,
media_type=media_type,
@@ -190,6 +192,8 @@ class MediaRepository(object):
else:
upload_name = None
logger.info("Stored remote media in file %r", fname)
yield self.store.store_cached_remote_media(
origin=server_name,
media_id=media_id,

View File

@@ -16,6 +16,10 @@
import PIL.Image as Image
from io import BytesIO
import logging
logger = logging.getLogger(__name__)
class Thumbnailer(object):
@@ -86,4 +90,5 @@ class Thumbnailer(object):
output_bytes = output_bytes_io.getvalue()
with open(output_path, "wb") as output_file:
output_file.write(output_bytes)
logger.info("Stored thumbnail in file %r", output_path)
return len(output_bytes)

View File

@@ -97,6 +97,8 @@ class UploadResource(Resource):
content_length, requester.user
)
logger.info("Uploaded content with URI %r", content_uri)
respond_with_json(
request, 200, {"content_uri": content_uri}, send_cors=True
)

View File

@@ -16,12 +16,12 @@
from twisted.internet import defer
from synapse import event_auth
from synapse.util.logutils import log_function
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.auth import AuthEventTypes
from synapse.events.snapshot import EventContext
from synapse.util.async import Linearizer
@@ -41,12 +41,14 @@ KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
SIZE_OF_CACHE = int(1000 * CACHE_SIZE_FACTOR)
SIZE_OF_CACHE = int(100000 * CACHE_SIZE_FACTOR)
EVICTION_TIMEOUT_SECONDS = 60 * 60
_NEXT_STATE_ID = 1
POWER_KEY = (EventTypes.PowerLevels, "")
def _gen_state_id():
global _NEXT_STATE_ID
@@ -77,6 +79,9 @@ class _StateCacheEntry(object):
else:
self.state_id = _gen_state_id()
def __len__(self):
return len(self.state)
class StateHandler(object):
""" Responsible for doing state conflict resolution.
@@ -89,7 +94,7 @@ class StateHandler(object):
# dict of set of event_ids -> _StateCacheEntry.
self._state_cache = None
self.resolve_linearizer = Linearizer()
self.resolve_linearizer = Linearizer(name="state_resolve_lock")
def start_caching(self):
logger.debug("start_caching")
@@ -99,6 +104,7 @@ class StateHandler(object):
clock=self.clock,
max_len=SIZE_OF_CACHE,
expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000,
iterable=True,
reset_expiry_on_get=True,
)
@@ -123,6 +129,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@@ -147,6 +154,7 @@ class StateHandler(object):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = yield self.resolve_state_groups(room_id, latest_event_ids)
state = ret.state
@@ -160,6 +168,7 @@ class StateHandler(object):
def get_current_user_in_room(self, room_id, latest_event_ids=None):
if not latest_event_ids:
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
logger.debug("calling resolve_state_groups from get_current_user_in_room")
entry = yield self.resolve_state_groups(room_id, latest_event_ids)
joined_users = yield self.store.get_joined_users_from_state(
room_id, entry.state_id, entry.state
@@ -223,6 +232,7 @@ class StateHandler(object):
context.prev_state_events = []
defer.returnValue(context)
logger.debug("calling resolve_state_groups from compute_event_context")
if event.is_state():
entry = yield self.resolve_state_groups(
event.room_id, [e for e, _ in event.prev_events],
@@ -323,20 +333,13 @@ class StateHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
state_map = yield self.store.get_events(
[e_id for st in state_groups_ids.values() for e_id in st.values()],
get_prev_content=False
)
state_sets = [
[state_map[e_id] for key, e_id in st.items() if e_id in state_map]
for st in state_groups_ids.values()
]
new_state, _ = self._resolve_events(
state_sets, event_type, state_key
)
new_state = {
key: e.event_id for key, e in new_state.items()
}
with Measure(self.clock, "state._resolve_events"):
new_state = yield resolve_events(
state_groups_ids.values(),
state_map_factory=lambda ev_ids: self.store.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
new_state = {
key: e_ids.pop() for key, e_ids in state.items()
@@ -384,152 +387,267 @@ class StateHandler(object):
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
if event.is_state():
return self._resolve_events(
state_sets, event.type, event.state_key
)
else:
return self._resolve_events(state_sets)
state_set_ids = [{
(ev.type, ev.state_key): ev.event_id
for ev in st
} for st in state_sets]
state_map = {
ev.event_id: ev
for st in state_sets
for ev in st
}
def _resolve_events(self, state_sets, event_type=None, state_key=""):
"""
Returns
(dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
(new_state, prev_states). new_state is a map from (type, state_key)
to event. prev_states is a list of event_ids.
"""
with Measure(self.clock, "state._resolve_events"):
state = {}
for st in state_sets:
for e in st:
state.setdefault(
(e.type, e.state_key),
{}
)[e.event_id] = e
new_state = resolve_events(state_set_ids, state_map)
unconflicted_state = {
k: v.values()[0] for k, v in state.items()
if len(v.values()) == 1
}
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.items()
}
conflicted_state = {
k: v.values()
for k, v in state.items()
if len(v.values()) > 1
}
return new_state
if event_type:
prev_states_events = conflicted_state.get(
(event_type, state_key), []
)
prev_states = [s.event_id for s in prev_states_events]
else:
prev_states = []
auth_events = {
k: e for k, e in unconflicted_state.items()
if k[0] in AuthEventTypes
}
def _ordered_events(events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
try:
resolved_state = self._resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
return sorted(events, key=key_func)
new_state = unconflicted_state
new_state.update(resolved_state)
return new_state, prev_states
def resolve_events(state_sets, state_map_factory):
"""
Args:
state_sets(list): List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
state_map_factory(dict|callable): If callable, then will be called
with a list of event_ids that are needed, and should return with
a Deferred of dict of event_id to event. Otherwise, should be
a dict from event_id to event of all events in state_sets.
@log_function
def _resolve_state_events(self, conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
Returns
dict[(str, str), synapse.events.FrozenEvent] is a map from
(type, state_key) to event.
"""
if len(state_sets) == 1:
return state_sets[0]
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
power_key = (EventTypes.PowerLevels, "")
if power_key in conflicted_state:
events = conflicted_state[power_key]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[power_key] = self._resolve_auth_events(
events, auth_events)
unconflicted_state, conflicted_state = _seperate(
state_sets,
)
auth_events.update(resolved_state)
if callable(state_map_factory):
return _resolve_with_state_fac(
unconflicted_state, conflicted_state, state_map_factory
)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
state_map = state_map_factory
auth_events.update(resolved_state)
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = self._resolve_auth_events(
events,
auth_events
)
return _resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = self._resolve_normal_events(
events, auth_events
)
def _seperate(state_sets):
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
"""
unconflicted_state = dict(state_sets[0])
conflicted_state = {}
return resolved_state
for state_set in state_sets[1:]:
for key, value in state_set.iteritems():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
# There isn't an unconflicted entry so check if there is a
# conflicted entry.
ls = conflicted_state.get(key)
if ls is None:
# There wasn't a conflicted entry so haven't seen this key before.
# Therefore it isn't conflicted yet.
unconflicted_state[key] = value
else:
# This key is already conflicted, add our value to the conflict set.
ls.add(value)
elif unconflicted_value != value:
# If the unconflicted value is not the same as our value then we
# have a new conflict. So move the key from the unconflicted_state
# to the conflicted state.
conflicted_state[key] = {value, unconflicted_value}
unconflicted_state.pop(key, None)
def _resolve_auth_events(self, events, auth_events):
reverse = [i for i in reversed(self._ordered_events(events))]
return unconflicted_state, conflicted_state
auth_events = dict(auth_events)
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
prev_event = event
except AuthError:
return prev_event
@defer.inlineCallbacks
def _resolve_with_state_fac(unconflicted_state, conflicted_state,
state_map_factory):
needed_events = set(
event_id
for event_ids in conflicted_state.itervalues()
for event_id in event_ids
)
return event
logger.info("Asking for %d conflicted events", len(needed_events))
def _resolve_normal_events(self, events, auth_events):
for event in self._ordered_events(events):
try:
# FIXME: hs.get_auth() is bad style, but we need to do it to
# get around circular deps.
# The signatures have already been checked at this point
self.hs.get_auth().check(event, auth_events, do_sig_check=False)
return event
except AuthError:
pass
state_map = yield state_map_factory(needed_events)
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
def _ordered_events(self, events):
def key_func(e):
return -int(e.depth), hashlib.sha1(e.event_id).hexdigest()
new_needed_events = set(auth_events.itervalues())
new_needed_events -= needed_events
return sorted(events, key=key_func)
logger.info("Asking for %d auth events", len(new_needed_events))
state_map_new = yield state_map_factory(new_needed_events)
state_map.update(state_map_new)
defer.returnValue(_resolve_with_state(
unconflicted_state, conflicted_state, auth_events, state_map
))
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
event_id = unconflicted_state.get(key, None)
if event_id:
auth_events[key] = event_id
return auth_events
def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ids,
state_map):
conflicted_state = {}
for key, event_ids in conflicted_state_ds.iteritems():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
elif len(events) == 1:
unconflicted_state_ids[key] = events[0].event_id
auth_events = {
key: state_map[ev_id]
for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
try:
resolved_state = _resolve_state_events(
conflicted_state, auth_events
)
except:
logger.exception("Failed to resolve state")
raise
new_state = unconflicted_state_ids
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
return new_state
def _resolve_state_events(conflicted_state, auth_events):
""" This is where we actually decide which of the conflicted state to
use.
We resolve conflicts in the following order:
1. power levels
2. join rules
3. memberships
4. other events.
"""
resolved_state = {}
if POWER_KEY in conflicted_state:
events = conflicted_state[POWER_KEY]
logger.debug("Resolving conflicted power levels %r", events)
resolved_state[POWER_KEY] = _resolve_auth_events(
events, auth_events)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(
events,
auth_events
)
auth_events.update(resolved_state)
for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(
events, auth_events
)
return resolved_state
def _resolve_auth_events(events, auth_events):
reverse = [i for i in reversed(_ordered_events(events))]
auth_keys = set(
key
for event in events
for key in event_auth.auth_types_for_event(event)
)
new_auth_events = {}
for key in auth_keys:
auth_event = auth_events.get(key, None)
if auth_event:
new_auth_events[key] = auth_event
auth_events = new_auth_events
prev_event = reverse[0]
for event in reverse[1:]:
auth_events[(prev_event.type, prev_event.state_key)] = prev_event
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
prev_event = event
except AuthError:
return prev_event
return event
def _resolve_normal_events(events, auth_events):
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
event_auth.check(event, auth_events, do_sig_check=False, do_size_check=False)
return event
except AuthError:
pass
# Use the last event (the one with the least depth) if they all fail
# the auth check.
return event

View File

@@ -116,6 +116,9 @@ class DataStore(RoomMemberStore, RoomStore,
self._public_room_id_gen = StreamIdGenerator(
db_conn, "public_room_list_stream", "stream_id"
)
self._device_list_id_gen = StreamIdGenerator(
db_conn, "device_lists_stream", "stream_id",
)
self._transaction_id_gen = IdGenerator(db_conn, "sent_transactions", "id")
self._state_groups_id_gen = IdGenerator(db_conn, "state_groups", "id")
@@ -189,7 +192,8 @@ class DataStore(RoomMemberStore, RoomStore,
db_conn, "device_inbox",
entity_column="user_id",
stream_column="stream_id",
max_value=max_device_inbox_id
max_value=max_device_inbox_id,
limit=1000,
)
self._device_inbox_stream_cache = StreamChangeCache(
"DeviceInboxStreamChangeCache", min_device_inbox_id,
@@ -202,12 +206,21 @@ class DataStore(RoomMemberStore, RoomStore,
entity_column="destination",
stream_column="stream_id",
max_value=max_device_inbox_id,
limit=1000,
)
self._device_federation_outbox_stream_cache = StreamChangeCache(
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
prefilled_cache=device_outbox_prefill,
)
device_list_max = self._device_list_id_gen.get_current_token()
self._device_list_stream_cache = StreamChangeCache(
"DeviceListStreamChangeCache", device_list_max,
)
self._device_list_federation_stream_cache = StreamChangeCache(
"DeviceListFederationStreamChangeCache", device_list_max,
)
cur = LoggingTransaction(
db_conn.cursor(),
name="_find_stream_orderings_for_times_txn",

View File

@@ -169,7 +169,7 @@ class SQLBaseStore(object):
max_entries=hs.config.event_cache_size)
self._state_group_cache = DictionaryCache(
"*stateGroupCache*", 2000 * CACHE_SIZE_FACTOR
"*stateGroupCache*", 100000 * CACHE_SIZE_FACTOR
)
self._event_fetch_lock = threading.Condition()
@@ -387,6 +387,10 @@ class SQLBaseStore(object):
Args:
table : string giving the table name
values : dict of new column names and values for them
Returns:
bool: Whether the row was inserted or not. Only useful when
`or_ignore` is True
"""
try:
yield self.runInteraction(
@@ -398,6 +402,8 @@ class SQLBaseStore(object):
# a cursor after we receive an error from the db.
if not or_ignore:
raise
defer.returnValue(False)
defer.returnValue(True)
@staticmethod
def _simple_insert_txn(txn, table, values):
@@ -838,18 +844,19 @@ class SQLBaseStore(object):
return txn.execute(sql, keyvalues.values())
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
max_value):
max_value, limit=100000):
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
# It doesn't really matter how many we get, the StreamChangeCache will
# do the right thing to ensure it respects the max size of cache.
sql = (
"SELECT %(entity)s, MAX(%(stream)s) FROM %(table)s"
" WHERE %(stream)s > ? - 100000"
" WHERE %(stream)s > ? - %(limit)s"
" GROUP BY %(entity)s"
) % {
"table": table,
"entity": entity_column,
"stream": stream_column,
"limit": limit,
}
sql = self.database_engine.convert_param_style(sql)

View File

@@ -18,13 +18,29 @@ import ujson
from twisted.internet import defer
from ._base import SQLBaseStore
from .background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
class DeviceInboxStore(SQLBaseStore):
class DeviceInboxStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, hs):
super(DeviceInboxStore, self).__init__(hs)
self.register_background_index_update(
"device_inbox_stream_index",
index_name="device_inbox_stream_id_user_id",
table="device_inbox",
columns=["stream_id", "user_id"],
)
self.register_background_update_handler(
self.DEVICE_INBOX_STREAM_ID,
self._background_drop_index_device_inbox,
)
@defer.inlineCallbacks
def add_messages_to_device_inbox(self, local_messages_by_user_then_device,
@@ -368,3 +384,18 @@ class DeviceInboxStore(SQLBaseStore):
"delete_device_msgs_for_remote",
delete_messages_for_remote_destination_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute(
"DROP INDEX IF EXISTS device_inbox_stream_id"
)
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
defer.returnValue(1)

View File

@@ -13,6 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import ujson as json
from twisted.internet import defer
@@ -23,27 +24,29 @@ logger = logging.getLogger(__name__)
class DeviceStore(SQLBaseStore):
def __init__(self, hs):
super(DeviceStore, self).__init__(hs)
self._clock.looping_call(
self._prune_old_outbound_device_pokes, 60 * 60 * 1000
)
@defer.inlineCallbacks
def store_device(self, user_id, device_id,
initial_device_display_name,
ignore_if_known=True):
initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
Args:
user_id (str): id of user associated with the device
device_id (str): id of device
initial_device_display_name (str): initial displayname of the
device
ignore_if_known (bool): ignore integrity errors which mean the
device is already known
device. Ignored if device exists.
Returns:
defer.Deferred
Raises:
StoreError: if ignore_if_known is False and the device was already
known
defer.Deferred: boolean whether the device was inserted or an
existing device existed with that ID.
"""
try:
yield self._simple_insert(
inserted = yield self._simple_insert(
"devices",
values={
"user_id": user_id,
@@ -51,8 +54,9 @@ class DeviceStore(SQLBaseStore):
"display_name": initial_device_display_name
},
desc="store_device",
or_ignore=ignore_if_known,
or_ignore=True,
)
defer.returnValue(inserted)
except Exception as e:
logger.error("store_device with device_id=%s(%r) user_id=%s(%r)"
" display_name=%s(%r) failed: %s",
@@ -139,3 +143,451 @@ class DeviceStore(SQLBaseStore):
)
defer.returnValue({d["device_id"]: d for d in devices})
def get_device_list_last_stream_id_for_remote(self, user_id):
"""Get the last stream_id we got for a user. May be None if we haven't
got any information for them.
"""
return self._simple_select_one_onecol(
table="device_lists_remote_extremeties",
keyvalues={"user_id": user_id},
retcol="stream_id",
desc="get_device_list_remote_extremity",
allow_none=True,
)
def mark_remote_user_device_list_as_unsubscribed(self, user_id):
"""Mark that we no longer track device lists for remote user.
"""
return self._simple_delete(
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
desc="mark_remote_user_device_list_as_unsubscribed",
)
def update_remote_device_list_cache_entry(self, user_id, device_id, content,
stream_id):
"""Updates a single user's device in the cache.
"""
return self.runInteraction(
"update_remote_device_list_cache_entry",
self._update_remote_device_list_cache_entry_txn,
user_id, device_id, content, stream_id,
)
def _update_remote_device_list_cache_entry_txn(self, txn, user_id, device_id,
content, stream_id):
self._simple_upsert_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"content": json.dumps(content),
}
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def update_remote_device_list_cache(self, user_id, devices, stream_id):
"""Replace the cache of the remote user's devices.
"""
return self.runInteraction(
"update_remote_device_list_cache",
self._update_remote_device_list_cache_txn,
user_id, devices, stream_id,
)
def _update_remote_device_list_cache_txn(self, txn, user_id, devices,
stream_id):
self._simple_delete_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
)
self._simple_insert_many_txn(
txn,
table="device_lists_remote_cache",
values=[
{
"user_id": user_id,
"device_id": content["device_id"],
"content": json.dumps(content),
}
for content in devices
]
)
self._simple_upsert_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
values={
"stream_id": stream_id,
}
)
def get_devices_by_remote(self, destination, from_stream_id):
"""Get stream of updates to send to remote servers
Returns:
(now_stream_id, [ { updates }, .. ])
"""
now_stream_id = self._device_list_id_gen.get_current_token()
has_changed = self._device_list_federation_stream_cache.has_entity_changed(
destination, int(from_stream_id)
)
if not has_changed:
return (now_stream_id, [])
return self.runInteraction(
"get_devices_by_remote", self._get_devices_by_remote_txn,
destination, from_stream_id, now_stream_id,
)
def _get_devices_by_remote_txn(self, txn, destination, from_stream_id,
now_stream_id):
sql = """
SELECT user_id, device_id, max(stream_id) FROM device_lists_outbound_pokes
WHERE destination = ? AND ? < stream_id AND stream_id <= ? AND sent = ?
GROUP BY user_id, device_id
"""
txn.execute(
sql, (destination, from_stream_id, now_stream_id, False)
)
rows = txn.fetchall()
if not rows:
return (now_stream_id, [])
# maps (user_id, device_id) -> stream_id
query_map = {(r[0], r[1]): r[2] for r in rows}
devices = self._get_e2e_device_keys_txn(
txn, query_map.keys(), include_all_devices=True
)
prev_sent_id_sql = """
SELECT coalesce(max(stream_id), 0) as stream_id
FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id <= ?
"""
results = []
for user_id, user_devices in devices.iteritems():
# The prev_id for the first row is always the last row before
# `from_stream_id`
txn.execute(prev_sent_id_sql, (destination, user_id, from_stream_id))
rows = txn.fetchall()
prev_id = rows[0][0]
for device_id, device in user_devices.iteritems():
stream_id = query_map[(user_id, device_id)]
result = {
"user_id": user_id,
"device_id": device_id,
"prev_id": [prev_id] if prev_id else [],
"stream_id": stream_id,
}
prev_id = stream_id
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return (now_stream_id, results)
def get_user_devices_from_cache(self, query_list):
"""Get the devices (and keys if any) for remote users from the cache.
Args:
query_list(list): List of (user_id, device_ids), if device_ids is
falsey then return all device ids for that user.
Returns:
(user_ids_not_in_cache, results_map), where user_ids_not_in_cache is
a set of user_ids and results_map is a mapping of
user_id -> device_id -> device_info
"""
return self.runInteraction(
"get_user_devices_from_cache", self._get_user_devices_from_cache_txn,
query_list,
)
def _get_user_devices_from_cache_txn(self, txn, query_list):
user_ids = {user_id for user_id, _ in query_list}
user_ids_in_cache = set()
for user_id in user_ids:
stream_ids = self._simple_select_onecol_txn(
txn,
table="device_lists_remote_extremeties",
keyvalues={
"user_id": user_id,
},
retcol="stream_id",
)
if stream_ids:
user_ids_in_cache.add(user_id)
user_ids_not_in_cache = user_ids - user_ids_in_cache
results = {}
for user_id, device_id in query_list:
if user_id not in user_ids_in_cache:
continue
if device_id:
content = self._simple_select_one_onecol_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="content",
)
results.setdefault(user_id, {})[device_id] = json.loads(content)
else:
devices = self._simple_select_list_txn(
txn,
table="device_lists_remote_cache",
keyvalues={
"user_id": user_id,
},
retcols=("device_id", "content"),
)
results[user_id] = {
device["device_id"]: json.loads(device["content"])
for device in devices
}
user_ids_in_cache.discard(user_id)
return user_ids_not_in_cache, results
def get_devices_with_keys_by_user(self, user_id):
"""Get all devices (with any device keys) for a user
Returns:
(stream_id, devices)
"""
return self.runInteraction(
"get_devices_with_keys_by_user",
self._get_devices_with_keys_by_user_txn, user_id,
)
def _get_devices_with_keys_by_user_txn(self, txn, user_id):
now_stream_id = self._device_list_id_gen.get_current_token()
devices = self._get_e2e_device_keys_txn(
txn, [(user_id, None)], include_all_devices=True
)
if devices:
user_devices = devices[user_id]
results = []
for device_id, device in user_devices.iteritems():
result = {
"device_id": device_id,
}
key_json = device.get("key_json", None)
if key_json:
result["keys"] = json.loads(key_json)
device_display_name = device.get("device_display_name", None)
if device_display_name:
result["device_display_name"] = device_display_name
results.append(result)
return now_stream_id, results
return now_stream_id, []
def mark_as_sent_devices_by_remote(self, destination, stream_id):
"""Mark that updates have successfully been sent to the destination.
"""
return self.runInteraction(
"mark_as_sent_devices_by_remote", self._mark_as_sent_devices_by_remote_txn,
destination, stream_id,
)
def _mark_as_sent_devices_by_remote_txn(self, txn, destination, stream_id):
# First we DELETE all rows such that only the latest row for each
# (destination, user_id is left. We do this by selecting first and
# deleting.
sql = """
SELECT user_id, coalesce(max(stream_id), 0) FROM device_lists_outbound_pokes
WHERE destination = ? AND stream_id <= ?
GROUP BY user_id
HAVING count(*) > 1
"""
txn.execute(sql, (destination, stream_id,))
rows = txn.fetchall()
sql = """
DELETE FROM device_lists_outbound_pokes
WHERE destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
sql, ((destination, row[0], row[1],) for row in rows)
)
# Mark everything that is left as sent
sql = """
UPDATE device_lists_outbound_pokes SET sent = ?
WHERE destination = ? AND stream_id <= ?
"""
txn.execute(sql, (True, destination, stream_id,))
@defer.inlineCallbacks
def get_user_whose_devices_changed(self, from_key):
"""Get set of users whose devices have changed since `from_key`.
"""
from_key = int(from_key)
changed = self._device_list_stream_cache.get_all_entities_changed(from_key)
if changed is not None:
defer.returnValue(set(changed))
sql = """
SELECT user_id FROM device_lists_stream WHERE stream_id > ?
"""
rows = yield self._execute("get_user_whose_devices_changed", None, sql, from_key)
defer.returnValue(set(row[0] for row in rows))
def get_all_device_list_changes_for_remotes(self, from_key):
"""Return a list of `(stream_id, user_id, destination)` which is the
combined list of changes to devices, and which destinations need to be
poked. `destination` may be None if no destinations need to be poked.
"""
sql = """
SELECT stream_id, user_id, destination FROM device_lists_stream
LEFT JOIN device_lists_outbound_pokes USING (stream_id, user_id, device_id)
WHERE stream_id > ?
"""
return self._execute(
"get_users_and_hosts_device_list", None,
sql, from_key,
)
@defer.inlineCallbacks
def add_device_change_to_streams(self, user_id, device_ids, hosts):
"""Persist that a user's devices have been updated, and which hosts
(if any) should be poked.
"""
with self._device_list_id_gen.get_next() as stream_id:
yield self.runInteraction(
"add_device_change_to_streams", self._add_device_change_txn,
user_id, device_ids, hosts, stream_id,
)
defer.returnValue(stream_id)
def _add_device_change_txn(self, txn, user_id, device_ids, hosts, stream_id):
now = self._clock.time_msec()
txn.call_after(
self._device_list_stream_cache.entity_has_changed,
user_id, stream_id,
)
for host in hosts:
txn.call_after(
self._device_list_federation_stream_cache.entity_has_changed,
host, stream_id,
)
self._simple_insert_many_txn(
txn,
table="device_lists_stream",
values=[
{
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
}
for device_id in device_ids
]
)
self._simple_insert_many_txn(
txn,
table="device_lists_outbound_pokes",
values=[
{
"destination": destination,
"stream_id": stream_id,
"user_id": user_id,
"device_id": device_id,
"sent": False,
"ts": now,
}
for destination in hosts
for device_id in device_ids
]
)
def get_device_stream_token(self):
return self._device_list_id_gen.get_current_token()
def _prune_old_outbound_device_pokes(self):
"""Delete old entries out of the device_lists_outbound_pokes to ensure
that we don't fill up due to dead servers. We keep one entry per
(destination, user_id) tuple to ensure that the prev_ids remain correct
if the server does come back.
"""
yesterday = self._clock.time_msec() - 24 * 60 * 60 * 1000
def _prune_txn(txn):
select_sql = """
SELECT destination, user_id, max(stream_id) as stream_id
FROM device_lists_outbound_pokes
GROUP BY destination, user_id
HAVING min(ts) < ? AND count(*) > 1
"""
txn.execute(select_sql, (yesterday,))
rows = txn.fetchall()
if not rows:
return
delete_sql = """
DELETE FROM device_lists_outbound_pokes
WHERE ts < ? AND destination = ? AND user_id = ? AND stream_id < ?
"""
txn.executemany(
delete_sql,
(
(yesterday, row[0], row[1], row[2])
for row in rows
)
)
logger.info("Pruned %d device list outbound pokes", txn.rowcount)
return self.runInteraction(
"_prune_old_outbound_device_pokes", _prune_txn
)

View File

@@ -12,74 +12,111 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
from twisted.internet import defer
import twisted.internet.defer
from canonicaljson import encode_canonical_json
import ujson as json
from ._base import SQLBaseStore
class EndToEndKeyStore(SQLBaseStore):
def set_e2e_device_keys(self, user_id, device_id, time_now, json_bytes):
return self._simple_upsert(
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": json_bytes,
}
def set_e2e_device_keys(self, user_id, device_id, time_now, device_keys):
"""Stores device keys for a device. Returns whether there was a change
or the keys were already in the database.
"""
def _set_e2e_device_keys_txn(txn):
old_key_json = self._simple_select_one_onecol_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
retcol="key_json",
allow_none=True,
)
new_key_json = encode_canonical_json(device_keys)
if old_key_json == new_key_json:
return False
self._simple_upsert_txn(
txn,
table="e2e_device_keys_json",
keyvalues={
"user_id": user_id,
"device_id": device_id,
},
values={
"ts_added_ms": time_now,
"key_json": new_key_json,
}
)
return True
return self.runInteraction(
"set_e2e_device_keys", _set_e2e_device_keys_txn
)
def get_e2e_device_keys(self, query_list):
@defer.inlineCallbacks
def get_e2e_device_keys(self, query_list, include_all_devices=False):
"""Fetch a list of device keys.
Args:
query_list(list): List of pairs of user_ids and device_ids.
include_all_devices (bool): whether to include entries for devices
that don't have device keys
Returns:
Dict mapping from user-id to dict mapping from device_id to
dict containing "key_json", "device_display_name".
"""
if not query_list:
return {}
defer.returnValue({})
return self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list
results = yield self.runInteraction(
"get_e2e_device_keys", self._get_e2e_device_keys_txn,
query_list, include_all_devices,
)
def _get_e2e_device_keys_txn(self, txn, query_list):
for user_id, device_keys in results.iteritems():
for device_id, device_info in device_keys.iteritems():
device_info["keys"] = json.loads(device_info.pop("key_json"))
defer.returnValue(results)
def _get_e2e_device_keys_txn(self, txn, query_list, include_all_devices):
query_clauses = []
query_params = []
for (user_id, device_id) in query_list:
query_clause = "k.user_id = ?"
query_clause = "user_id = ?"
query_params.append(user_id)
if device_id:
query_clause += " AND k.device_id = ?"
query_clause += " AND device_id = ?"
query_params.append(device_id)
query_clauses.append(query_clause)
sql = (
"SELECT k.user_id, k.device_id, "
"SELECT user_id, device_id, "
" d.display_name AS device_display_name, "
" k.key_json"
" FROM e2e_device_keys_json k"
" LEFT JOIN devices d ON d.user_id = k.user_id"
" AND d.device_id = k.device_id"
" FROM devices d"
" %s JOIN e2e_device_keys_json k USING (user_id, device_id)"
" WHERE %s"
) % (
"LEFT" if include_all_devices else "INNER",
" OR ".join("(" + q + ")" for q in query_clauses)
)
txn.execute(sql, query_params)
rows = self.cursor_to_dict(txn)
result = collections.defaultdict(dict)
result = {}
for row in rows:
result[row["user_id"]][row["device_id"]] = row
result.setdefault(row["user_id"], {})[row["device_id"]] = row
return result
@@ -152,7 +189,7 @@ class EndToEndKeyStore(SQLBaseStore):
"claim_e2e_one_time_keys", _claim_e2e_one_time_keys
)
@twisted.internet.defer.inlineCallbacks
@defer.inlineCallbacks
def delete_e2e_keys_by_device(self, user_id, device_id):
yield self._simple_delete(
table="e2e_device_keys_json",

View File

@@ -129,7 +129,7 @@ class EventFederationStore(SQLBaseStore):
room_id,
)
@cached()
@cached(max_entries=5000, iterable=True)
def get_latest_event_ids_in_room(self, room_id):
return self._simple_select_onecol(
table="event_forward_extremities",
@@ -235,80 +235,21 @@ class EventFederationStore(SQLBaseStore):
],
)
self._update_extremeties(txn, events)
self._update_backward_extremeties(txn, events)
def _update_extremeties(self, txn, events):
"""Updates the event_*_extremities tables based on the new/updated
def _update_backward_extremeties(self, txn, events):
"""Updates the event_backward_extremities tables based on the new/updated
events being persisted.
This is called for new events *and* for events that were outliers, but
are are now being persisted as non-outliers.
are now being persisted as non-outliers.
Forward extremities are handled when we first start persisting the events.
"""
events_by_room = {}
for ev in events:
events_by_room.setdefault(ev.room_id, []).append(ev)
for room_id, room_events in events_by_room.items():
prevs = [
e_id for ev in room_events for e_id, _ in ev.prev_events
if not ev.internal_metadata.is_outlier()
]
if prevs:
txn.execute(
"DELETE FROM event_forward_extremities"
" WHERE room_id = ?"
" AND event_id in (%s)" % (
",".join(["?"] * len(prevs)),
),
[room_id] + prevs,
)
query = (
"INSERT INTO event_forward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
" SELECT 1 FROM event_edges WHERE prev_event_id = ?"
" )"
)
txn.executemany(
query,
[
(ev.event_id, ev.room_id, ev.event_id) for ev in events
if not ev.internal_metadata.is_outlier()
]
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
max_stream_ord = max(
ev.internal_metadata.stream_ordering for ev in events
)
new_extrem = {}
for room_id in events_by_room:
event_ids = self._simple_select_onecol_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
retcol="event_id",
)
new_extrem[room_id] = event_ids
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_ord,
}
for room_id, extrem_evs in new_extrem.items()
for event_id in extrem_evs
]
)
query = (
"INSERT INTO event_backward_extremities (event_id, room_id)"
" SELECT ?, ? WHERE NOT EXISTS ("
@@ -339,11 +280,6 @@ class EventFederationStore(SQLBaseStore):
]
)
for room_id in events_by_room:
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
def get_forward_extremeties_for_room(self, room_id, stream_ordering):
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.

View File

@@ -12,7 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import SQLBaseStore, _RollbackButIsFineException
from ._base import SQLBaseStore
from twisted.internet import defer, reactor
@@ -27,6 +27,8 @@ from synapse.util.logutils import log_function
from synapse.util.metrics import Measure
from synapse.api.constants import EventTypes
from synapse.api.errors import SynapseError
from synapse.state import resolve_events
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
from collections import deque, namedtuple, OrderedDict
@@ -71,22 +73,19 @@ class _EventPeristenceQueue(object):
"""
_EventPersistQueueItem = namedtuple("_EventPersistQueueItem", (
"events_and_contexts", "current_state", "backfilled", "deferred",
"events_and_contexts", "backfilled", "deferred",
))
def __init__(self):
self._event_persist_queues = {}
self._currently_persisting_rooms = set()
def add_to_queue(self, room_id, events_and_contexts, backfilled, current_state):
def add_to_queue(self, room_id, events_and_contexts, backfilled):
"""Add events to the queue, with the given persist_event options.
"""
queue = self._event_persist_queues.setdefault(room_id, deque())
if queue:
end_item = queue[-1]
if end_item.current_state or current_state:
# We perist events with current_state set to True one at a time
pass
if end_item.backfilled == backfilled:
end_item.events_and_contexts.extend(events_and_contexts)
return end_item.deferred.observe()
@@ -96,7 +95,6 @@ class _EventPeristenceQueue(object):
queue.append(self._EventPersistQueueItem(
events_and_contexts=events_and_contexts,
backfilled=backfilled,
current_state=current_state,
deferred=deferred,
))
@@ -216,7 +214,6 @@ class EventsStore(SQLBaseStore):
d = preserve_fn(self._event_persist_queue.add_to_queue)(
room_id, evs_ctxs,
backfilled=backfilled,
current_state=None,
)
deferreds.append(d)
@@ -229,11 +226,10 @@ class EventsStore(SQLBaseStore):
@defer.inlineCallbacks
@log_function
def persist_event(self, event, context, current_state=None, backfilled=False):
def persist_event(self, event, context, backfilled=False):
deferred = self._event_persist_queue.add_to_queue(
event.room_id, [(event, context)],
backfilled=backfilled,
current_state=current_state,
)
self._maybe_start_persisting(event.room_id)
@@ -246,21 +242,10 @@ class EventsStore(SQLBaseStore):
def _maybe_start_persisting(self, room_id):
@defer.inlineCallbacks
def persisting_queue(item):
if item.current_state:
for event, context in item.events_and_contexts:
# There should only ever be one item in
# events_and_contexts when current_state is
# not None
yield self._persist_event(
event, context,
current_state=item.current_state,
backfilled=item.backfilled,
)
else:
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
yield self._persist_events(
item.events_and_contexts,
backfilled=item.backfilled,
)
self._event_persist_queue.handle_queue(room_id, persisting_queue)
@@ -294,35 +279,183 @@ class EventsStore(SQLBaseStore):
for chunk in chunks:
# We can't easily parallelize these since different chunks
# might contain the same event. :(
# NB: Assumes that we are only persisting events for one room
# at a time.
new_forward_extremeties = {}
current_state_for_room = {}
if not backfilled:
with Measure(self._clock, "_calculate_state_and_extrem"):
# Work out the new "current state" for each room.
# We do this by working out what the new extremities are and then
# calculating the state from that.
events_by_room = {}
for event, context in chunk:
events_by_room.setdefault(event.room_id, []).append(
(event, context)
)
for room_id, ev_ctx_rm in events_by_room.items():
# Work out new extremities by recursively adding and removing
# the new events.
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = yield self._calculate_new_extremeties(
room_id, [ev for ev, _ in ev_ctx_rm]
)
if new_latest_event_ids == set(latest_event_ids):
# No change in extremities, so no change in state
continue
new_forward_extremeties[room_id] = new_latest_event_ids
state = yield self._calculate_state_delta(
room_id, ev_ctx_rm, new_latest_event_ids
)
if state:
current_state_for_room[room_id] = state
yield self.runInteraction(
"persist_events",
self._persist_events_txn,
events_and_contexts=chunk,
backfilled=backfilled,
delete_existing=delete_existing,
current_state_for_room=current_state_for_room,
new_forward_extremeties=new_forward_extremeties,
)
persist_event_counter.inc_by(len(chunk))
@_retry_on_integrity_error
@defer.inlineCallbacks
@log_function
def _persist_event(self, event, context, current_state=None, backfilled=False,
delete_existing=False):
try:
with self._stream_id_gen.get_next() as stream_ordering:
event.internal_metadata.stream_ordering = stream_ordering
yield self.runInteraction(
"persist_event",
self._persist_event_txn,
event=event,
context=context,
current_state=current_state,
backfilled=backfilled,
delete_existing=delete_existing,
)
persist_event_counter.inc()
except _RollbackButIsFineException:
pass
def _calculate_new_extremeties(self, room_id, events):
"""Calculates the new forward extremeties for a room given events to
persist.
Assumes that we are only persisting events for one room at a time.
"""
latest_event_ids = yield self.get_latest_event_ids_in_room(
room_id
)
new_latest_event_ids = set(latest_event_ids)
# First, add all the new events to the list
new_latest_event_ids.update(
event.event_id for event in events
if not event.internal_metadata.is_outlier()
)
# Now remove all events that are referenced by the to-be-added events
new_latest_event_ids.difference_update(
e_id
for event in events
for e_id, _ in event.prev_events
if not event.internal_metadata.is_outlier()
)
# And finally remove any events that are referenced by previously added
# events.
rows = yield self._simple_select_many_batch(
table="event_edges",
column="prev_event_id",
iterable=list(new_latest_event_ids),
retcols=["prev_event_id"],
keyvalues={
"room_id": room_id,
"is_state": False,
},
desc="_calculate_new_extremeties",
)
new_latest_event_ids.difference_update(
row["prev_event_id"] for row in rows
)
defer.returnValue(new_latest_event_ids)
@defer.inlineCallbacks
def _calculate_state_delta(self, room_id, events_context, new_latest_event_ids):
"""Calculate the new state deltas for a room.
Assumes that we are only persisting events for one room at a time.
Returns:
2-tuple (to_delete, to_insert) where both are state dicts, i.e.
(type, state_key) -> event_id. `to_delete` are the entries to
first be deleted from current_state_events, `to_insert` are entries
to insert.
May return None if there are no changes to be applied.
"""
# Now we need to work out the different state sets for
# each state extremities
state_sets = []
missing_event_ids = []
was_updated = False
for event_id in new_latest_event_ids:
# First search in the list of new events we're adding,
# and then use the current state from that
for ev, ctx in events_context:
if event_id == ev.event_id:
if ctx.current_state_ids is None:
raise Exception("Unknown current state")
state_sets.append(ctx.current_state_ids)
if ctx.delta_ids or hasattr(ev, "state_key"):
was_updated = True
break
else:
# If we couldn't find it, then we'll need to pull
# the state from the database
was_updated = True
missing_event_ids.append(event_id)
if missing_event_ids:
# Now pull out the state for any missing events from DB
event_to_groups = yield self._get_state_group_for_events(
missing_event_ids,
)
groups = set(event_to_groups.values())
group_to_state = yield self._get_state_for_groups(groups)
state_sets.extend(group_to_state.values())
if not new_latest_event_ids:
current_state = {}
elif was_updated:
current_state = yield resolve_events(
state_sets,
state_map_factory=lambda ev_ids: self.get_events(
ev_ids, get_prev_content=False, check_redacted=False,
),
)
else:
return
existing_state_rows = yield self._simple_select_list(
table="current_state_events",
keyvalues={"room_id": room_id},
retcols=["event_id", "type", "state_key"],
desc="_calculate_state_delta",
)
existing_events = set(row["event_id"] for row in existing_state_rows)
new_events = set(ev_id for ev_id in current_state.itervalues())
changed_events = existing_events ^ new_events
if not changed_events:
return
to_delete = {
(row["type"], row["state_key"]): row["event_id"]
for row in existing_state_rows
if row["event_id"] in changed_events
}
events_to_insert = (new_events - existing_events)
to_insert = {
key: ev_id for key, ev_id in current_state.iteritems()
if ev_id in events_to_insert
}
defer.returnValue((to_delete, to_insert))
@defer.inlineCallbacks
def get_event(self, event_id, check_redacted=True,
@@ -380,53 +513,10 @@ class EventsStore(SQLBaseStore):
defer.returnValue({e.event_id: e for e in events})
@log_function
def _persist_event_txn(self, txn, event, context, current_state, backfilled=False,
delete_existing=False):
# We purposefully do this first since if we include a `current_state`
# key, we *want* to update the `current_state_events` table
if current_state:
txn.call_after(self._get_current_state_for_key.invalidate_all)
txn.call_after(self.get_rooms_for_user.invalidate_all)
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
# Add an entry to the current_state_resets table to record the point
# where we clobbered the current state
stream_order = event.internal_metadata.stream_ordering
self._simple_insert_txn(
txn,
table="current_state_resets",
values={"event_stream_ordering": stream_order}
)
self._simple_delete_txn(
txn,
table="current_state_events",
keyvalues={"room_id": event.room_id},
)
for s in current_state:
self._simple_insert_txn(
txn,
"current_state_events",
{
"event_id": s.event_id,
"room_id": s.room_id,
"type": s.type,
"state_key": s.state_key,
}
)
return self._persist_events_txn(
txn,
[(event, context)],
backfilled=backfilled,
delete_existing=delete_existing,
)
@log_function
def _persist_events_txn(self, txn, events_and_contexts, backfilled,
delete_existing=False):
delete_existing=False, current_state_for_room={},
new_forward_extremeties={}):
"""Insert some number of room events into the necessary database tables.
Rejected events are only inserted into the events table, the events_json table,
@@ -436,6 +526,93 @@ class EventsStore(SQLBaseStore):
If delete_existing is True then existing events will be purged from the
database before insertion. This is useful when retrying due to IntegrityError.
"""
max_stream_order = events_and_contexts[-1][0].internal_metadata.stream_ordering
for room_id, current_state_tuple in current_state_for_room.iteritems():
to_delete, to_insert = current_state_tuple
txn.executemany(
"DELETE FROM current_state_events WHERE event_id = ?",
[(ev_id,) for ev_id in to_delete.itervalues()],
)
self._simple_insert_many_txn(
txn,
table="current_state_events",
values=[
{
"event_id": ev_id,
"room_id": room_id,
"type": key[0],
"state_key": key[1],
}
for key, ev_id in to_insert.iteritems()
],
)
# Invalidate the various caches
# Figure out the changes of membership to invalidate the
# `get_rooms_for_user` cache.
# We find out which membership events we may have deleted
# and which we have added, then we invlidate the caches for all
# those users.
members_changed = set(
state_key for ev_type, state_key in to_delete.iterkeys()
if ev_type == EventTypes.Member
)
members_changed.update(
state_key for ev_type, state_key in to_insert.iterkeys()
if ev_type == EventTypes.Member
)
for member in members_changed:
self._invalidate_cache_and_stream(
txn, self.get_rooms_for_user, (member,)
)
self._invalidate_cache_and_stream(
txn, self.get_users_in_room, (room_id,)
)
for room_id, new_extrem in new_forward_extremeties.items():
self._simple_delete_txn(
txn,
table="event_forward_extremities",
keyvalues={"room_id": room_id},
)
txn.call_after(
self.get_latest_event_ids_in_room.invalidate, (room_id,)
)
self._simple_insert_many_txn(
txn,
table="event_forward_extremities",
values=[
{
"event_id": ev_id,
"room_id": room_id,
}
for room_id, new_extrem in new_forward_extremeties.items()
for ev_id in new_extrem
],
)
# We now insert into stream_ordering_to_exterm a mapping from room_id,
# new stream_ordering to new forward extremeties in the room.
# This allows us to later efficiently look up the forward extremeties
# for a room before a given stream_ordering
self._simple_insert_many_txn(
txn,
table="stream_ordering_to_exterm",
values=[
{
"room_id": room_id,
"event_id": event_id,
"stream_ordering": max_stream_order,
}
for room_id, new_extrem in new_forward_extremeties.items()
for event_id in new_extrem
]
)
# Ensure that we don't have the same event twice.
# Pick the earliest non-outlier if there is one, else the earliest one.
new_events_and_contexts = OrderedDict()
@@ -550,7 +727,7 @@ class EventsStore(SQLBaseStore):
# Update the event_backward_extremities table now that this
# event isn't an outlier any more.
self._update_extremeties(txn, [event])
self._update_backward_extremeties(txn, [event])
events_and_contexts = [
ec for ec in events_and_contexts if ec[0] not in to_remove
@@ -798,29 +975,6 @@ class EventsStore(SQLBaseStore):
# to update the current state table
return
for event, _ in state_events_and_contexts:
if event.internal_metadata.is_outlier():
# Outlier events shouldn't clobber the current state.
continue
txn.call_after(
self._get_current_state_for_key.invalidate,
(event.room_id, event.type, event.state_key,)
)
self._simple_upsert_txn(
txn,
"current_state_events",
keyvalues={
"room_id": event.room_id,
"type": event.type,
"state_key": event.state_key,
},
values={
"event_id": event.event_id,
}
)
return
def _add_to_cache(self, txn, events_and_contexts):
@@ -1084,8 +1238,10 @@ class EventsStore(SQLBaseStore):
self._do_fetch
)
logger.debug("Loading %d events", len(events))
with PreserveLoggingContext():
rows = yield events_d
logger.debug("Loaded %d events (%d rows)", len(events), len(rows))
if not allow_rejected:
rows[:] = [r for r in rows if not r["rejects"]]
@@ -1416,6 +1572,7 @@ class EventsStore(SQLBaseStore):
"""The current minimum token that backfilled events have reached"""
return -self._backfill_id_gen.get_current_token()
@cached(num_args=5, max_entries=10)
def get_all_new_events(self, last_backfill_id, last_forward_id,
current_backfill_id, current_forward_id, limit):
"""Get all the new events that have arrived at the server either as
@@ -1447,15 +1604,6 @@ class EventsStore(SQLBaseStore):
else:
upper_bound = current_forward_id
sql = (
"SELECT event_stream_ordering FROM current_state_resets"
" WHERE ? < event_stream_ordering"
" AND event_stream_ordering <= ?"
" ORDER BY event_stream_ordering ASC"
)
txn.execute(sql, (last_forward_id, upper_bound))
state_resets = txn.fetchall()
sql = (
"SELECT event_stream_ordering, event_id, state_group"
" FROM ex_outlier_stream"
@@ -1467,7 +1615,6 @@ class EventsStore(SQLBaseStore):
forward_ex_outliers = txn.fetchall()
else:
new_forward_events = []
state_resets = []
forward_ex_outliers = []
sql = (
@@ -1507,7 +1654,6 @@ class EventsStore(SQLBaseStore):
return AllNewEventsResult(
new_forward_events, new_backfill_events,
forward_ex_outliers, backward_ex_outliers,
state_resets,
)
return self.runInteraction("get_all_new_events", get_all_new_events_txn)
@@ -1733,5 +1879,4 @@ class EventsStore(SQLBaseStore):
AllNewEventsResult = namedtuple("AllNewEventsResult", [
"new_forward_events", "new_backfill_events",
"forward_ex_outliers", "backward_ex_outliers",
"state_resets"
])

View File

@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
# Remember to update this number every time a change is made to database
# schema files, so the users will be informed on server restarts.
SCHEMA_VERSION = 39
SCHEMA_VERSION = 40
dir_path = os.path.abspath(os.path.dirname(__file__))

View File

@@ -413,6 +413,17 @@ class RegistrationStore(background_updates.BackgroundUpdateStore):
desc="user_delete_threepids",
)
def user_delete_threepid(self, user_id, medium, address):
return self._simple_delete(
"user_threepids",
keyvalues={
"user_id": user_id,
"medium": medium,
"address": address,
},
desc="user_delete_threepids",
)
@defer.inlineCallbacks
def count_all_users(self):
"""Counts all users registered on the homeserver."""

View File

@@ -66,8 +66,6 @@ class RoomMemberStore(SQLBaseStore):
)
for event in events:
txn.call_after(self.get_rooms_for_user.invalidate, (event.state_key,))
txn.call_after(self.get_users_in_room.invalidate, (event.room_id,))
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key, event.internal_metadata.stream_ordering
@@ -131,7 +129,7 @@ class RoomMemberStore(SQLBaseStore):
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
@cached(max_entries=5000)
@cached(max_entries=100000, iterable=True)
def get_users_in_room(self, room_id):
def f(txn):
@@ -220,7 +218,7 @@ class RoomMemberStore(SQLBaseStore):
" ON e.event_id = c.event_id"
" AND m.room_id = c.room_id"
" AND m.user_id = c.state_key"
" WHERE %s"
" WHERE c.type = 'm.room.member' AND %s"
) % (where_clause,)
txn.execute(sql, args)
@@ -266,7 +264,7 @@ class RoomMemberStore(SQLBaseStore):
" ON m.event_id = c.event_id "
" AND m.room_id = c.room_id "
" AND m.user_id = c.state_key"
" WHERE %(where)s"
" WHERE c.type = 'm.room.member' AND %(where)s"
) % {
"where": where_clause,
}
@@ -282,6 +280,22 @@ class RoomMemberStore(SQLBaseStore):
user_id, membership_list=[Membership.JOIN],
)
@cachedInlineCallbacks(max_entries=50000, cache_context=True, iterable=True)
def get_users_who_share_room_with_user(self, user_id, cache_context):
rooms = yield self.get_rooms_for_user(
user_id, on_invalidate=cache_context.invalidate,
)
user_who_share_room = set()
for room in rooms:
user_ids = yield self.get_users_in_room(
room.room_id, on_invalidate=cache_context.invalidate,
)
logger.info("Users in room: %r %r", room.room_id, user_ids)
user_who_share_room.update(user_ids)
defer.returnValue(user_who_share_room)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
@@ -390,11 +404,12 @@ class RoomMemberStore(SQLBaseStore):
room_id, state_group, state_ids,
)
@cachedInlineCallbacks(num_args=2, cache_context=True)
@cachedInlineCallbacks(num_args=2, cache_context=True, iterable=True,
max_entries=100000)
def _get_joined_users_from_context(self, room_id, state_group, current_state_ids,
cache_context, event=None):
# We don't use `state_group`, its there so that we can cache based
# on it. However, its important that its never None, since two current_state's
# We don't use `state_group`, it's there so that we can cache based
# on it. However, it's important that it's never None, since two current_states
# with a state_group of None are likely to be different.
# See bulk_get_push_rules_for_room for how we work around this.
assert state_group is not None

View File

@@ -0,0 +1,17 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
INSERT INTO background_updates (update_name, progress_json) VALUES
('current_state_members_idx', '{}');

View File

@@ -0,0 +1,21 @@
/* Copyright 2016 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- turn the pre-fill startup query into a index-only scan on postgresql.
INSERT into background_updates (update_name, progress_json)
VALUES ('device_inbox_stream_index', '{}');
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('device_inbox_stream_drop', '{}', 'device_inbox_stream_index');

View File

@@ -0,0 +1,59 @@
/* Copyright 2017 OpenMarket Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- Cache of remote devices.
CREATE TABLE device_lists_remote_cache (
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
content TEXT NOT NULL
);
CREATE INDEX device_lists_remote_cache_id ON device_lists_remote_cache(user_id, device_id);
-- The last update we got for a user. Empty if we're not receiving updates for
-- that user.
CREATE TABLE device_lists_remote_extremeties (
user_id TEXT NOT NULL,
stream_id TEXT NOT NULL
);
CREATE INDEX device_lists_remote_extremeties_id ON device_lists_remote_extremeties(user_id, stream_id);
-- Stream of device lists updates. Includes both local and remotes
CREATE TABLE device_lists_stream (
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL
);
CREATE INDEX device_lists_stream_id ON device_lists_stream(stream_id, user_id);
-- The stream of updates to send to other servers. We keep at least one row
-- per user that was sent so that the prev_id for any new updates can be
-- calculated
CREATE TABLE device_lists_outbound_pokes (
destination TEXT NOT NULL,
stream_id BIGINT NOT NULL,
user_id TEXT NOT NULL,
device_id TEXT NOT NULL,
sent BOOLEAN NOT NULL,
ts BIGINT NOT NULL -- So that in future we can clear out pokes to dead servers
);
CREATE INDEX device_lists_outbound_pokes_id ON device_lists_outbound_pokes(destination, stream_id);
CREATE INDEX device_lists_outbound_pokes_user ON device_lists_outbound_pokes(destination, user_id);

View File

@@ -49,6 +49,7 @@ class StateStore(SQLBaseStore):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
CURRENT_STATE_INDEX_UPDATE_NAME = "current_state_members_idx"
def __init__(self, hs):
super(StateStore, self).__init__(hs)
@@ -60,6 +61,13 @@ class StateStore(SQLBaseStore):
self.STATE_GROUP_INDEX_UPDATE_NAME,
self._background_index_state,
)
self.register_background_index_update(
self.CURRENT_STATE_INDEX_UPDATE_NAME,
index_name="current_state_events_member_index",
table="current_state_events",
columns=["state_key"],
where_clause="type='m.room.member'",
)
@defer.inlineCallbacks
def get_state_groups_ids(self, room_id, event_ids):
@@ -232,59 +240,7 @@ class StateStore(SQLBaseStore):
return count
@defer.inlineCallbacks
def get_current_state(self, room_id, event_type=None, state_key=""):
if event_type and state_key is not None:
result = yield self.get_current_state_for_key(
room_id, event_type, state_key
)
defer.returnValue(result)
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? "
)
if event_type and state_key is not None:
sql += " AND type = ? AND state_key = ? "
args = (room_id, event_type, state_key)
elif event_type:
sql += " AND type = ?"
args = (room_id, event_type)
else:
args = (room_id, )
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
event_ids = yield self.runInteraction("get_current_state", f)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@defer.inlineCallbacks
def get_current_state_for_key(self, room_id, event_type, state_key):
event_ids = yield self._get_current_state_for_key(room_id, event_type, state_key)
events = yield self._get_events(event_ids, get_prev_content=False)
defer.returnValue(events)
@cached(num_args=3)
def _get_current_state_for_key(self, room_id, event_type, state_key):
def f(txn):
sql = (
"SELECT event_id FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?"
)
args = (room_id, event_type, state_key)
txn.execute(sql, args)
results = txn.fetchall()
return [r[0] for r in results]
return self.runInteraction("get_current_state_for_key", f)
@cached(num_args=2, max_entries=1000)
@cached(num_args=2, max_entries=100000, iterable=True)
def _get_state_group_from_group(self, group, types):
raise NotImplementedError()
@@ -384,7 +340,7 @@ class StateStore(SQLBaseStore):
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indicies (which we can't add until
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
if types:

View File

@@ -244,6 +244,20 @@ class StreamStore(SQLBaseStore):
defer.returnValue(results)
def get_rooms_that_changed(self, room_ids, from_key):
"""Given a list of rooms and a token, return rooms where there may have
been changes.
Args:
room_ids (list)
from_key (str): The room_key portion of a StreamToken
"""
from_key = RoomStreamToken.parse_stream_token(from_key).stream
return set(
room_id for room_id in room_ids
if self._events_stream_cache.has_entity_changed(room_id, from_key)
)
@defer.inlineCallbacks
def get_room_events_stream_for_room(self, room_id, from_key, to_key, limit=0,
order='DESC'):

View File

@@ -44,6 +44,7 @@ class EventSources(object):
def get_current_token(self):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken(
room_key=(
@@ -63,6 +64,7 @@ class EventSources(object):
),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
)
defer.returnValue(token)
@@ -70,6 +72,7 @@ class EventSources(object):
def get_current_token_for_room(self, room_id):
push_rules_key, _ = self.store.get_push_rules_stream_token()
to_device_key = self.store.get_to_device_stream_token()
device_list_key = self.store.get_device_stream_token()
token = StreamToken(
room_key=(
@@ -89,5 +92,6 @@ class EventSources(object):
),
push_rules_key=push_rules_key,
to_device_key=to_device_key,
device_list_key=device_list_key,
)
defer.returnValue(token)

View File

@@ -158,6 +158,7 @@ class StreamToken(
"account_data_key",
"push_rules_key",
"to_device_key",
"device_list_key",
))
):
_SEPARATOR = "_"
@@ -195,6 +196,7 @@ class StreamToken(
or (int(other.account_data_key) < int(self.account_data_key))
or (int(other.push_rules_key) < int(self.push_rules_key))
or (int(other.to_device_key) < int(self.to_device_key))
or (int(other.device_list_key) < int(self.device_list_key))
)
def copy_and_advance(self, key, new_value):

View File

@@ -23,6 +23,10 @@ from synapse.util import unwrapFirstError
from contextlib import contextmanager
import logging
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
def sleep(seconds):
@@ -162,7 +166,11 @@ class Linearizer(object):
# do some work.
"""
def __init__(self):
def __init__(self, name=None):
if name is None:
self.name = id(self)
else:
self.name = name
self.key_to_defer = {}
@defer.inlineCallbacks
@@ -181,14 +189,23 @@ class Linearizer(object):
self.key_to_defer[key] = new_defer
if current_defer:
with PreserveLoggingContext():
yield current_defer
logger.info(
"Waiting to acquire linearizer lock %r for key %r", self.name, key
)
try:
with PreserveLoggingContext():
yield current_defer
except:
logger.exception("Unexpected exception in Linearizer")
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
@contextmanager
def _ctx_manager():
try:
yield
finally:
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
new_defer.callback(None)
current_d = self.key_to_defer.get(key)
if current_d is new_defer:

View File

@@ -40,8 +40,8 @@ def register_cache(name, cache):
)
_string_cache = LruCache(int(5000 * CACHE_SIZE_FACTOR))
caches_by_name["string_cache"] = _string_cache
_string_cache = LruCache(int(100000 * CACHE_SIZE_FACTOR))
_stirng_cache_metrics = register_cache("string_cache", _string_cache)
KNOWN_KEYS = {
@@ -69,7 +69,12 @@ KNOWN_KEYS = {
def intern_string(string):
"""Takes a (potentially) unicode string and interns using custom cache
"""
return _string_cache.setdefault(string, string)
new_str = _string_cache.setdefault(string, string)
if new_str is string:
_stirng_cache_metrics.inc_hits()
else:
_stirng_cache_metrics.inc_misses()
return new_str
def intern_dict(dictionary):

View File

@@ -17,7 +17,7 @@ import logging
from synapse.util.async import ObservableDeferred
from synapse.util import unwrapFirstError
from synapse.util.caches.lrucache import LruCache
from synapse.util.caches.treecache import TreeCache
from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry
from synapse.util.logcontext import (
PreserveLoggingContext, preserve_context_over_deferred, preserve_context_over_fn
)
@@ -42,6 +42,25 @@ _CacheSentinel = object()
CACHE_SIZE_FACTOR = float(os.environ.get("SYNAPSE_CACHE_FACTOR", 0.1))
class CacheEntry(object):
__slots__ = [
"deferred", "sequence", "callbacks", "invalidated"
]
def __init__(self, deferred, sequence, callbacks):
self.deferred = deferred
self.sequence = sequence
self.callbacks = set(callbacks)
self.invalidated = False
def invalidate(self):
if not self.invalidated:
self.invalidated = True
for callback in self.callbacks:
callback()
self.callbacks.clear()
class Cache(object):
__slots__ = (
"cache",
@@ -51,12 +70,16 @@ class Cache(object):
"sequence",
"thread",
"metrics",
"_pending_deferred_cache",
)
def __init__(self, name, max_entries=1000, keylen=1, tree=False):
def __init__(self, name, max_entries=1000, keylen=1, tree=False, iterable=False):
cache_type = TreeCache if tree else dict
self._pending_deferred_cache = cache_type()
self.cache = LruCache(
max_size=max_entries, keylen=keylen, cache_type=cache_type
max_size=max_entries, keylen=keylen, cache_type=cache_type,
size_callback=(lambda d: len(d.result)) if iterable else None,
)
self.name = name
@@ -76,7 +99,15 @@ class Cache(object):
)
def get(self, key, default=_CacheSentinel, callback=None):
val = self.cache.get(key, _CacheSentinel, callback=callback)
callbacks = [callback] if callback else []
val = self._pending_deferred_cache.get(key, _CacheSentinel)
if val is not _CacheSentinel:
if val.sequence == self.sequence:
val.callbacks.update(callbacks)
self.metrics.inc_hits()
return val.deferred
val = self.cache.get(key, _CacheSentinel, callbacks=callbacks)
if val is not _CacheSentinel:
self.metrics.inc_hits()
return val
@@ -88,15 +119,39 @@ class Cache(object):
else:
return default
def update(self, sequence, key, value, callback=None):
def set(self, key, value, callback=None):
callbacks = [callback] if callback else []
self.check_thread()
if self.sequence == sequence:
# Only update the cache if the caches sequence number matches the
# number that the cache had before the SELECT was started (SYN-369)
self.prefill(key, value, callback=callback)
entry = CacheEntry(
deferred=value,
sequence=self.sequence,
callbacks=callbacks,
)
entry.callbacks.update(callbacks)
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry:
existing_entry.invalidate()
self._pending_deferred_cache[key] = entry
def shuffle(result):
if self.sequence == entry.sequence:
existing_entry = self._pending_deferred_cache.pop(key, None)
if existing_entry is entry:
self.cache.set(key, entry.deferred, entry.callbacks)
else:
entry.invalidate()
else:
entry.invalidate()
return result
entry.deferred.addCallback(shuffle)
def prefill(self, key, value, callback=None):
self.cache.set(key, value, callback=callback)
callbacks = [callback] if callback else []
self.cache.set(key, value, callbacks=callbacks)
def invalidate(self, key):
self.check_thread()
@@ -108,6 +163,10 @@ class Cache(object):
# Increment the sequence number so that any SELECT statements that
# raced with the INSERT don't update the cache (SYN-369)
self.sequence += 1
entry = self._pending_deferred_cache.pop(key, None)
if entry:
entry.invalidate()
self.cache.pop(key, None)
def invalidate_many(self, key):
@@ -119,6 +178,11 @@ class Cache(object):
self.sequence += 1
self.cache.del_multi(key)
entry_dict = self._pending_deferred_cache.pop(key, None)
if entry_dict is not None:
for entry in iterate_tree_cache_entry(entry_dict):
entry.invalidate()
def invalidate_all(self):
self.check_thread()
self.sequence += 1
@@ -155,7 +219,7 @@ class CacheDescriptor(object):
"""
def __init__(self, orig, max_entries=1000, num_args=1, tree=False,
inlineCallbacks=False, cache_context=False):
inlineCallbacks=False, cache_context=False, iterable=False):
max_entries = int(max_entries * CACHE_SIZE_FACTOR)
self.orig = orig
@@ -169,6 +233,8 @@ class CacheDescriptor(object):
self.num_args = num_args
self.tree = tree
self.iterable = iterable
all_args = inspect.getargspec(orig)
self.arg_names = all_args.args[1:num_args + 1]
@@ -203,6 +269,7 @@ class CacheDescriptor(object):
max_entries=self.max_entries,
keylen=self.num_args,
tree=self.tree,
iterable=self.iterable,
)
@functools.wraps(self.orig)
@@ -243,11 +310,6 @@ class CacheDescriptor(object):
return preserve_context_over_deferred(observer)
except KeyError:
# Get the sequence number of the cache before reading from the
# database so that we can tell if the cache is invalidated
# while the SELECT is executing (SYN-369)
sequence = cache.sequence
ret = defer.maybeDeferred(
preserve_context_over_fn,
self.function_to_call,
@@ -261,7 +323,7 @@ class CacheDescriptor(object):
ret.addErrback(onErr)
ret = ObservableDeferred(ret, consumeErrors=True)
cache.update(sequence, cache_key, ret, callback=invalidate_callback)
cache.set(cache_key, ret, callback=invalidate_callback)
return preserve_context_over_deferred(ret.observe())
@@ -359,7 +421,6 @@ class CacheListDescriptor(object):
missing.append(arg)
if missing:
sequence = cache.sequence
args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
@@ -382,8 +443,8 @@ class CacheListDescriptor(object):
key = list(keyargs)
key[self.list_pos] = arg
cache.update(
sequence, tuple(key), observer,
cache.set(
tuple(key), observer,
callback=invalidate_callback
)
@@ -421,17 +482,20 @@ class _CacheContext(namedtuple("_CacheContext", ("cache", "key"))):
self.cache.invalidate(self.key)
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cached(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
num_args=num_args,
tree=tree,
cache_context=cache_context,
iterable=iterable,
)
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False):
def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_context=False,
iterable=False):
return lambda orig: CacheDescriptor(
orig,
max_entries=max_entries,
@@ -439,6 +503,7 @@ def cachedInlineCallbacks(max_entries=1000, num_args=1, tree=False, cache_contex
tree=tree,
inlineCallbacks=True,
cache_context=cache_context,
iterable=iterable,
)

View File

@@ -23,7 +23,9 @@ import logging
logger = logging.getLogger(__name__)
DictionaryEntry = namedtuple("DictionaryEntry", ("full", "value"))
class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
def __len__(self):
return len(self.value)
class DictionaryCache(object):
@@ -32,7 +34,7 @@ class DictionaryCache(object):
"""
def __init__(self, name, max_entries=1000):
self.cache = LruCache(max_size=max_entries)
self.cache = LruCache(max_size=max_entries, size_callback=len)
self.name = name
self.sequence = 0

View File

@@ -15,6 +15,7 @@
from synapse.util.caches import register_cache
from collections import OrderedDict
import logging
@@ -23,7 +24,7 @@ logger = logging.getLogger(__name__)
class ExpiringCache(object):
def __init__(self, cache_name, clock, max_len=0, expiry_ms=0,
reset_expiry_on_get=False):
reset_expiry_on_get=False, iterable=False):
"""
Args:
cache_name (str): Name of this cache, used for logging.
@@ -36,6 +37,8 @@ class ExpiringCache(object):
evicted based on time.
reset_expiry_on_get (bool): If true, will reset the expiry time for
an item on access. Defaults to False.
iterable (bool): If true, the size is calculated by summing the
sizes of all entries, rather than the number of entries.
"""
self._cache_name = cache_name
@@ -47,9 +50,13 @@ class ExpiringCache(object):
self._reset_expiry_on_get = reset_expiry_on_get
self._cache = {}
self._cache = OrderedDict()
self.metrics = register_cache(cache_name, self._cache)
self.metrics = register_cache(cache_name, self)
self.iterable = iterable
self._size_estimate = 0
def start(self):
if not self._expiry_ms:
@@ -65,15 +72,14 @@ class ExpiringCache(object):
now = self._clock.time_msec()
self._cache[key] = _CacheEntry(now, value)
# Evict if there are now too many items
if self._max_len and len(self._cache.keys()) > self._max_len:
sorted_entries = sorted(
self._cache.items(),
key=lambda item: item[1].time,
)
if self.iterable:
self._size_estimate += len(value)
for k, _ in sorted_entries[self._max_len:]:
self._cache.pop(k)
# Evict if there are now too many items
while self._max_len and len(self) > self._max_len:
_key, value = self._cache.popitem(last=False)
if self.iterable:
self._size_estimate -= len(value.value)
def __getitem__(self, key):
try:
@@ -99,7 +105,7 @@ class ExpiringCache(object):
# zero expiry time means don't expire. This should never get called
# since we have this check in start too.
return
begin_length = len(self._cache)
begin_length = len(self)
now = self._clock.time_msec()
@@ -110,15 +116,20 @@ class ExpiringCache(object):
keys_to_delete.add(key)
for k in keys_to_delete:
self._cache.pop(k)
value = self._cache.pop(k)
if self.iterable:
self._size_estimate -= len(value.value)
logger.debug(
"[%s] _prune_cache before: %d, after len: %d",
self._cache_name, begin_length, len(self._cache)
self._cache_name, begin_length, len(self)
)
def __len__(self):
return len(self._cache)
if self.iterable:
return self._size_estimate
else:
return len(self._cache)
class _CacheEntry(object):

View File

@@ -49,7 +49,7 @@ class LruCache(object):
Can also set callbacks on objects when getting/setting which are fired
when that key gets invalidated/evicted.
"""
def __init__(self, max_size, keylen=1, cache_type=dict):
def __init__(self, max_size, keylen=1, cache_type=dict, size_callback=None):
cache = cache_type()
self.cache = cache # Used for introspection.
list_root = _Node(None, None, None, None)
@@ -58,6 +58,12 @@ class LruCache(object):
lock = threading.Lock()
def evict():
while cache_len() > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
def synchronized(f):
@wraps(f)
def inner(*args, **kwargs):
@@ -66,6 +72,16 @@ class LruCache(object):
return inner
cached_cache_len = [0]
if size_callback is not None:
def cache_len():
return cached_cache_len[0]
else:
def cache_len():
return len(cache)
self.len = synchronized(cache_len)
def add_node(key, value, callbacks=set()):
prev_node = list_root
next_node = prev_node.next_node
@@ -74,6 +90,9 @@ class LruCache(object):
next_node.prev_node = node
cache[key] = node
if size_callback:
cached_cache_len[0] += size_callback(node.value)
def move_node_to_front(node):
prev_node = node.prev_node
next_node = node.next_node
@@ -92,23 +111,25 @@ class LruCache(object):
prev_node.next_node = next_node
next_node.prev_node = prev_node
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
for cb in node.callbacks:
cb()
node.callbacks.clear()
@synchronized
def cache_get(key, default=None, callback=None):
def cache_get(key, default=None, callbacks=[]):
node = cache.get(key, None)
if node is not None:
move_node_to_front(node)
if callback:
node.callbacks.add(callback)
node.callbacks.update(callbacks)
return node.value
else:
return default
@synchronized
def cache_set(key, value, callback=None):
def cache_set(key, value, callbacks=[]):
node = cache.get(key, None)
if node is not None:
if value != node.value:
@@ -116,21 +137,18 @@ class LruCache(object):
cb()
node.callbacks.clear()
if callback:
node.callbacks.add(callback)
if size_callback:
cached_cache_len[0] -= size_callback(node.value)
cached_cache_len[0] += size_callback(value)
node.callbacks.update(callbacks)
move_node_to_front(node)
node.value = value
else:
if callback:
callbacks = set([callback])
else:
callbacks = set()
add_node(key, value, callbacks)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
add_node(key, value, set(callbacks))
evict()
@synchronized
def cache_set_default(key, value):
@@ -139,10 +157,7 @@ class LruCache(object):
return node.value
else:
add_node(key, value)
if len(cache) > max_size:
todelete = list_root.prev_node
delete_node(todelete)
cache.pop(todelete.key, None)
evict()
return value
@synchronized
@@ -174,10 +189,8 @@ class LruCache(object):
for cb in node.callbacks:
cb()
cache.clear()
@synchronized
def cache_len():
return len(cache)
if size_callback:
cached_cache_len[0] = 0
@synchronized
def cache_contains(key):
@@ -190,7 +203,7 @@ class LruCache(object):
self.pop = cache_pop
if cache_type is TreeCache:
self.del_multi = cache_del_multi
self.len = cache_len
self.len = synchronized(cache_len)
self.contains = cache_contains
self.clear = cache_clear

View File

@@ -65,12 +65,27 @@ class TreeCache(object):
return popped
def values(self):
return [e.value for e in self.root.values()]
return list(iterate_tree_cache_entry(self.root))
def __len__(self):
return self.size
def iterate_tree_cache_entry(d):
"""Helper function to iterate over the leaves of a tree, i.e. a dict of that
can contain dicts.
"""
if isinstance(d, dict):
for value_d in d.itervalues():
for value in iterate_tree_cache_entry(value_d):
yield value
else:
if isinstance(d, _Entry):
yield d.value
else:
yield d
class _Entry(object):
__slots__ = ["value"]

View File

@@ -1,71 +0,0 @@
# -*- coding: utf-8 -*-
# Copyright 2015, 2016 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from twisted.internet import defer, reactor
from functools import wraps
from synapse.util.logcontext import LoggingContext, PreserveLoggingContext
def debug_deferreds():
"""Cause all deferreds to wait for a reactor tick before running their
callbacks. This increases the chance of getting a stack trace out of
a defer.inlineCallback since the code waiting on the deferred will get
a chance to add an errback before the deferred runs."""
# Helper method for retrieving and restoring the current logging context
# around a callback.
def with_logging_context(fn):
context = LoggingContext.current_context()
def restore_context_callback(x):
with PreserveLoggingContext(context):
return fn(x)
return restore_context_callback
# We are going to modify the __init__ method of defer.Deferred so we
# need to get a copy of the old method so we can still call it.
old__init__ = defer.Deferred.__init__
# We need to create a deferred to bounce the callbacks through the reactor
# but we don't want to add a callback when we create that deferred so we
# we create a new type of deferred that uses the old __init__ method.
# This is safe as long as the old __init__ method doesn't invoke an
# __init__ using super.
class Bouncer(defer.Deferred):
__init__ = old__init__
# We'll add this as a callback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the callbacks of the
# original deferred.
def bounce_callback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.callback), x)
return bouncer
# We'll add this as an errback to all Deferreds. Twisted will wait until
# the bouncer deferred resolves before calling the errbacks of the
# original deferred.
def bounce_errback(x):
bouncer = Bouncer()
reactor.callLater(0, with_logging_context(bouncer.errback), x)
return bouncer
@wraps(old__init__)
def new__init__(self, *args, **kargs):
old__init__(self, *args, **kargs)
self.addCallbacks(bounce_callback, bounce_errback)
defer.Deferred.__init__ = new__init__

View File

@@ -88,7 +88,7 @@ class RetryDestinationLimiter(object):
def __init__(self, destination, clock, store, retry_interval,
min_retry_interval=10 * 60 * 1000,
max_retry_interval=24 * 60 * 60 * 1000,
multiplier_retry_interval=5,):
multiplier_retry_interval=5, backoff_on_404=False):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@@ -107,6 +107,7 @@ class RetryDestinationLimiter(object):
a failed request, in milliseconds.
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
"""
self.clock = clock
self.store = store
@@ -116,6 +117,7 @@ class RetryDestinationLimiter(object):
self.min_retry_interval = min_retry_interval
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
def __enter__(self):
pass
@@ -123,7 +125,22 @@ class RetryDestinationLimiter(object):
def __exit__(self, exc_type, exc_val, exc_tb):
valid_err_code = False
if exc_type is not None and issubclass(exc_type, CodeMessageException):
valid_err_code = exc_val.code != 429 and 0 <= exc_val.code < 500
# Some error codes are perfectly fine for some APIs, whereas other
# APIs may expect to never received e.g. a 404. It's important to
# handle 404 as some remote servers will return a 404 when the HS
# has been decommissioned.
# If we get a 401, then we should probably back off since they
# won't accept our requests for at least a while.
# 429 is us being aggresively rate limited, so lets rate limit
# ourselves.
if exc_val.code == 404 and self.backoff_on_404:
valid_err_code = False
elif exc_val.code in (401, 429):
valid_err_code = False
elif exc_val.code < 500:
valid_err_code = True
else:
valid_err_code = False
if exc_type is None or valid_err_code:
# We connected successfully.

View File

@@ -25,10 +25,13 @@ from synapse.api.filtering import Filter
from synapse.events import FrozenEvent
user_localpart = "test_user"
# MockEvent = namedtuple("MockEvent", "sender type room_id")
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)

View File

@@ -21,6 +21,10 @@ from synapse.events.utils import prune_event, serialize_event
def MockEvent(**kwargs):
if "event_id" not in kwargs:
kwargs["event_id"] = "fake_event_id"
if "type" not in kwargs:
kwargs["type"] = "fake_type"
return FrozenEvent(kwargs)
@@ -35,9 +39,13 @@ class PruneEventTestCase(unittest.TestCase):
def test_minimal(self):
self.run_test(
{'type': 'A'},
{
'type': 'A',
'event_id': '$test:domain',
},
{
'type': 'A',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@@ -69,10 +77,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'age_ts': 20},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {'age_ts': 20},
@@ -82,10 +92,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'B',
'event_id': '$test:domain',
'unsigned': {'other_key': 'here'},
},
{
'type': 'B',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@@ -96,10 +108,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'C',
'event_id': '$test:domain',
'content': {'things': 'here'},
},
{
'type': 'C',
'event_id': '$test:domain',
'content': {},
'signatures': {},
'unsigned': {},
@@ -109,10 +123,12 @@ class PruneEventTestCase(unittest.TestCase):
self.run_test(
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain', 'other_field': 'here'},
},
{
'type': 'm.room.create',
'event_id': '$test:domain',
'content': {'creator': '@2:domain'},
'signatures': {},
'unsigned': {},
@@ -255,6 +271,8 @@ class SerializeEventTestCase(unittest.TestCase):
self.assertEquals(
self.serialize(
MockEvent(
type="foo",
event_id="test",
room_id="!foo:bar",
content={
"foo": "bar",
@@ -263,6 +281,8 @@ class SerializeEventTestCase(unittest.TestCase):
[]
),
{
"type": "foo",
"event_id": "test",
"room_id": "!foo:bar",
"content": {
"foo": "bar",

View File

@@ -35,51 +35,51 @@ class DeviceTestCase(unittest.TestCase):
@defer.inlineCallbacks
def setUp(self):
hs = yield utils.setup_test_homeserver(handlers=None)
self.handler = synapse.handlers.device.DeviceHandler(hs)
hs = yield utils.setup_test_homeserver()
self.handler = hs.get_device_handler()
self.store = hs.get_datastore()
self.clock = hs.get_clock()
@defer.inlineCallbacks
def test_device_is_created_if_doesnt_exist(self):
res = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_is_preserved_if_exists(self):
res1 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="display name"
)
self.assertEqual(res1, "fco")
res2 = yield self.handler.check_device_registered(
user_id="boris",
user_id="@boris:foo",
device_id="fco",
initial_device_display_name="new display name"
)
self.assertEqual(res2, "fco")
dev = yield self.handler.store.get_device("boris", "fco")
dev = yield self.handler.store.get_device("@boris:foo", "fco")
self.assertEqual(dev["display_name"], "display name")
@defer.inlineCallbacks
def test_device_id_is_made_up_if_unspecified(self):
device_id = yield self.handler.check_device_registered(
user_id="theresa",
user_id="@theresa:foo",
device_id=None,
initial_device_display_name="display"
)
dev = yield self.handler.store.get_device("theresa", device_id)
dev = yield self.handler.store.get_device("@theresa:foo", device_id)
self.assertEqual(dev["display_name"], "display")
@defer.inlineCallbacks

View File

@@ -37,6 +37,7 @@ class DirectoryTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View File

@@ -39,6 +39,7 @@ class ProfileTestCase(unittest.TestCase):
def setUp(self):
self.mock_federation = Mock(spec=[
"make_query",
"register_edu_handler",
])
self.query_handlers = {}

View File

@@ -75,6 +75,7 @@ class TypingNotificationsTestCase(unittest.TestCase):
"get_received_txn_response",
"set_received_txn_response",
"get_destination_retry_timings",
"get_devices_by_remote",
]),
state_handler=self.state_handler,
handlers=None,
@@ -99,6 +100,8 @@ class TypingNotificationsTestCase(unittest.TestCase):
defer.succeed(retry_timings_res)
)
self.datastore.get_devices_by_remote.return_value = (0, [])
def get_received_txn_response(*args):
return defer.succeed(None)
self.datastore.get_received_txn_response = get_received_txn_response

View File

@@ -58,53 +58,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
def tearDown(self):
[unpatch() for unpatch in self.unpatches]
@defer.inlineCallbacks
def test_room_members(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Join the room.
join = yield self.persist(type="m.room.member", key=USER_ID, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
# Leave the room.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID,), [])
yield self.check("get_users_in_room", (ROOM_ID,), [])
# Add some other user to the room.
join = yield self.persist(type="m.room.member", key=USER_ID_2, membership="join")
yield self.replicate()
yield self.check("get_rooms_for_user", (USER_ID_2,), [RoomsForUser(
room_id=ROOM_ID,
sender=USER_ID,
membership="join",
event_id=join.event_id,
stream_ordering=join.internal_metadata.stream_ordering,
)])
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID_2])
# Join the room clobbering the state.
# This should remove any evidence of the other user being in the room.
yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check("get_users_in_room", (ROOM_ID,), [USER_ID])
yield self.check("get_rooms_for_user", (USER_ID_2,), [])
@defer.inlineCallbacks
def test_get_latest_event_ids_in_room(self):
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -122,51 +75,6 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
"get_latest_event_ids_in_room", (ROOM_ID,), [join.event_id]
)
@defer.inlineCallbacks
def test_get_current_state(self):
# Create the room.
create = yield self.persist(type="m.room.create", key="", creator=USER_ID)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID), []
)
# Join the room.
join1 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join1]
)
# Add some other user to the room.
join2 = yield self.persist(
type="m.room.member", key=USER_ID_2, membership="join",
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[join2]
)
# Leave the room, then rejoin the room clobbering state.
yield self.persist(type="m.room.member", key=USER_ID, membership="leave")
join3 = yield self.persist(
type="m.room.member", key=USER_ID, membership="join",
reset_state=[create]
)
yield self.replicate()
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID_2),
[]
)
yield self.check(
"get_current_state_for_key", (ROOM_ID, "m.room.member", USER_ID),
[join3]
)
@defer.inlineCallbacks
def test_redactions(self):
yield self.persist(type="m.room.create", key="", creator=USER_ID)
@@ -283,6 +191,12 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
if depth is None:
depth = self.event_id
if not prev_events:
latest_event_ids = yield self.master_store.get_latest_event_ids_in_room(
room_id
)
prev_events = [(ev_id, {}) for ev_id in latest_event_ids]
event_dict = {
"sender": sender,
"type": type,
@@ -309,12 +223,15 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_ids = {
key: e.event_id for key, e in state.items()
}
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
elif not backfill:
state_handler = self.hs.get_state_handler()
context = yield state_handler.compute_event_context(event)
else:
state_ids = None
context = EventContext()
context = EventContext()
context.current_state_ids = state_ids
context.prev_state_ids = state_ids
context.push_actions = push_actions
ordering = None
@@ -324,7 +241,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
)
else:
ordering, _ = yield self.master_store.persist_event(
event, context, current_state=reset_state
event, context,
)
if ordering:

View File

@@ -259,8 +259,8 @@ class RoomPermissionsTestCase(RestTestCase):
# set [invite/join/left] of self, set [invite/join/left] of other,
# expect all 404s because room doesn't exist on any server
for usr in [self.user_id, self.rmcreator_id]:
yield self.join(room=room, user=usr, expect_code=403)
yield self.leave(room=room, user=usr, expect_code=403)
yield self.join(room=room, user=usr, expect_code=404)
yield self.leave(room=room, user=usr, expect_code=404)
@defer.inlineCallbacks
def test_membership_private_room_perms(self):
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_topo_token_is_accepted(self):
token = "t1-0_0_0_0_0_0_0"
token = "t1-0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
@defer.inlineCallbacks
def test_stream_token_is_accepted_for_fwd_pagianation(self):
token = "s0_0_0_0_0_0_0"
token = "s0_0_0_0_0_0_0_0"
(code, response) = yield self.mock_resource.trigger_get(
"/rooms/%s/messages?access_token=x&from=%s" %
(self.room_id, token))

View File

@@ -87,7 +87,10 @@ class RestTestCase(unittest.TestCase):
(code, response) = yield self.mock_resource.trigger(
"PUT", path, json.dumps(data)
)
self.assertEquals(expect_code, code, msg=str(response))
self.assertEquals(
expect_code, code,
msg="Expected: %d, got: %d, resp: %r" % (expect_code, code, response)
)
self.auth_user_id = temp_id

View File

@@ -241,7 +241,7 @@ class CacheDecoratorTestCase(unittest.TestCase):
callcount2 = [0]
class A(object):
@cached(max_entries=2)
@cached(max_entries=20) # HACK: This makes it 2 due to cache factor
def func(self, key):
callcount[0] += 1
return key
@@ -258,6 +258,10 @@ class CacheDecoratorTestCase(unittest.TestCase):
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func2("foo")
self.assertEquals(callcount[0], 2)
self.assertEquals(callcount2[0], 2)
yield a.func("foo3")
self.assertEquals(callcount[0], 3)

View File

@@ -39,7 +39,11 @@ class ApplicationServiceStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.as_token = "token1"
self.as_url = "some_url"
@@ -112,7 +116,11 @@ class ApplicationServiceTransactionStoreTestCase(unittest.TestCase):
event_cache_size=1,
password_providers=[],
)
hs = yield setup_test_homeserver(config=config, federation_sender=Mock())
hs = yield setup_test_homeserver(
config=config,
federation_sender=Mock(),
replication_layer=Mock(),
)
self.db_pool = hs.get_db_pool()
self.as_list = [
@@ -446,7 +454,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
ApplicationServiceStore(hs)
@@ -463,7 +472,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:
@@ -486,7 +496,8 @@ class ApplicationServiceStoreConfigTestCase(unittest.TestCase):
hs = yield setup_test_homeserver(
config=config,
datastore=Mock(),
federation_sender=Mock()
federation_sender=Mock(),
replication_layer=Mock(),
)
with self.assertRaises(ConfigError) as cm:

View File

@@ -33,7 +33,11 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
@defer.inlineCallbacks
def test_key_without_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
json = {"key": "value"}
yield self.store.store_device(
"user", "device", None
)
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
@@ -43,14 +47,14 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"keys": json,
"device_display_name": None,
}, dev)
@defer.inlineCallbacks
def test_get_key_with_device_name(self):
now = 1470174257070
json = '{ "key": "value" }'
json = {"key": "value"}
yield self.store.set_e2e_device_keys(
"user", "device", now, json)
@@ -63,7 +67,7 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
self.assertIn("device", res["user"])
dev = res["user"]["device"]
self.assertDictContainsSubset({
"key_json": json,
"keys": json,
"device_display_name": "display_name",
}, dev)
@@ -71,6 +75,19 @@ class EndToEndKeyStoreTestCase(tests.unittest.TestCase):
def test_multiple_devices(self):
now = 1470174257070
yield self.store.store_device(
"user1", "device1", None
)
yield self.store.store_device(
"user1", "device2", None
)
yield self.store.store_device(
"user2", "device1", None
)
yield self.store.store_device(
"user2", "device2", None
)
yield self.store.set_e2e_device_keys(
"user1", "device1", now, 'json11')
yield self.store.set_e2e_device_keys(

View File

@@ -0,0 +1,84 @@
# -*- coding: utf-8 -*-
# Copyright 2017 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .. import unittest
from synapse.util.caches.expiringcache import ExpiringCache
from tests.utils import MockClock
class ExpiringCacheTestCase(unittest.TestCase):
def test_get_set(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=1)
cache["key"] = "value"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache["key"], "value")
def test_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=2)
cache["key"] = "value"
cache["key2"] = "value2"
self.assertEquals(cache.get("key"), "value")
self.assertEquals(cache.get("key2"), "value2")
cache["key3"] = "value3"
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), "value2")
self.assertEquals(cache.get("key3"), "value3")
def test_iterable_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, max_len=5, iterable=True)
cache["key"] = [1]
cache["key2"] = [2, 3]
cache["key3"] = [4, 5]
self.assertEquals(cache.get("key"), [1])
self.assertEquals(cache.get("key2"), [2, 3])
self.assertEquals(cache.get("key3"), [4, 5])
cache["key4"] = [6, 7]
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache.get("key3"), [4, 5])
self.assertEquals(cache.get("key4"), [6, 7])
def test_time_eviction(self):
clock = MockClock()
cache = ExpiringCache("test", clock, expiry_ms=1000)
cache.start()
cache["key"] = 1
clock.advance_time(0.5)
cache["key2"] = 2
self.assertEquals(cache.get("key"), 1)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(0.9)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), 2)
clock.advance_time(1)
self.assertEquals(cache.get("key"), None)
self.assertEquals(cache.get("key2"), None)

View File

@@ -93,7 +93,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", "value")
@@ -112,10 +112,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
cache.set("key", "value")
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.get("key", callback=m)
cache.get("key", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value2")
@@ -128,7 +128,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.set("key", "value")
@@ -144,7 +144,7 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m = Mock()
cache = LruCache(1)
cache.set("key", "value", m)
cache.set("key", "value", callbacks=[m])
self.assertFalse(m.called)
cache.pop("key")
@@ -163,10 +163,10 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m4 = Mock()
cache = LruCache(4, 2, cache_type=TreeCache)
cache.set(("a", "1"), "value", m1)
cache.set(("a", "2"), "value", m2)
cache.set(("b", "1"), "value", m3)
cache.set(("b", "2"), "value", m4)
cache.set(("a", "1"), "value", callbacks=[m1])
cache.set(("a", "2"), "value", callbacks=[m2])
cache.set(("b", "1"), "value", callbacks=[m3])
cache.set(("b", "2"), "value", callbacks=[m4])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@@ -185,8 +185,8 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m2 = Mock()
cache = LruCache(5)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
@@ -202,14 +202,14 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
m3 = Mock(name="m3")
cache = LruCache(2)
cache.set("key1", "value", m1)
cache.set("key2", "value", m2)
cache.set("key1", "value", callbacks=[m1])
cache.set("key2", "value", callbacks=[m2])
self.assertEquals(m1.call_count, 0)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key3", "value", m3)
cache.set("key3", "value", callbacks=[m3])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
@@ -227,8 +227,33 @@ class LruCacheCallbacksTestCase(unittest.TestCase):
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 0)
cache.set("key1", "value", m1)
cache.set("key1", "value", callbacks=[m1])
self.assertEquals(m1.call_count, 1)
self.assertEquals(m2.call_count, 0)
self.assertEquals(m3.call_count, 1)
class LruCacheSizedTestCase(unittest.TestCase):
def test_evict(self):
cache = LruCache(5, size_callback=len)
cache["key1"] = [0]
cache["key2"] = [1, 2]
cache["key3"] = [3]
cache["key4"] = [4]
self.assertEquals(cache["key1"], [0])
self.assertEquals(cache["key2"], [1, 2])
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(len(cache), 5)
cache["key5"] = [5, 6]
self.assertEquals(len(cache), 4)
self.assertEquals(cache.get("key1"), None)
self.assertEquals(cache.get("key2"), None)
self.assertEquals(cache["key3"], [3])
self.assertEquals(cache["key4"], [4])
self.assertEquals(cache["key5"], [5, 6])