Compare commits
291 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ffba978077 | ||
|
|
13e16cf302 | ||
|
|
6070647774 | ||
|
|
d6237859f6 | ||
|
|
0ef0aeceac | ||
|
|
b4a6b7f720 | ||
|
|
c7d46510d7 | ||
|
|
ffd3f1a783 | ||
|
|
29bafe2f7e | ||
|
|
287dd1ee2c | ||
|
|
513c23bfd9 | ||
|
|
011d03a0f6 | ||
|
|
9ab859f27b | ||
|
|
f4f65ef93e | ||
|
|
bd5718d0ad | ||
|
|
161a862ffb | ||
|
|
69994c385a | ||
|
|
b5dbbac308 | ||
|
|
582bd19ee9 | ||
|
|
74f99f227c | ||
|
|
c2bd177ea0 | ||
|
|
fe6e9f580b | ||
|
|
7216c76654 | ||
|
|
dbdfd8967d | ||
|
|
b8e40d146f | ||
|
|
4cc8bb0767 | ||
|
|
4e242b3e20 | ||
|
|
a6245478c8 | ||
|
|
2e9f5ea31a | ||
|
|
a6ad8148b9 | ||
|
|
5b5f35ccc0 | ||
|
|
9b714abf35 | ||
|
|
33122c5a1b | ||
|
|
a9c2e930ac | ||
|
|
c05e6015cc | ||
|
|
e0a75e0c25 | ||
|
|
85f5674e44 | ||
|
|
c43e8a9736 | ||
|
|
a3ac4f6b0a | ||
|
|
5dfd0350c7 | ||
|
|
ca96d609e4 | ||
|
|
2c5972f87f | ||
|
|
6079d0027a | ||
|
|
99a6c9dbf2 | ||
|
|
9342bcfce0 | ||
|
|
e504816977 | ||
|
|
b2e02084b8 | ||
|
|
db3d84f46c | ||
|
|
1b6b0b1e66 | ||
|
|
6b725cf56a | ||
|
|
64665b57d0 | ||
|
|
2b24416e90 | ||
|
|
b92a8e6e4a | ||
|
|
931fc43cc8 | ||
|
|
31aa7bd8d1 | ||
|
|
ad1911bbf4 | ||
|
|
c021c39cbd | ||
|
|
1f43d22397 | ||
|
|
a675bd08bd | ||
|
|
4d7e1dde70 | ||
|
|
ae5d18617a | ||
|
|
9732ec6797 | ||
|
|
0e28281a02 | ||
|
|
505371414f | ||
|
|
e3428d26ca | ||
|
|
35332298ef | ||
|
|
64db043a71 | ||
|
|
b60859d6cc | ||
|
|
d76621a47b | ||
|
|
4ae85ae121 | ||
|
|
cc505b4b5e | ||
|
|
1259a76047 | ||
|
|
802ca12d05 | ||
|
|
e283b555b1 | ||
|
|
b77a13812c | ||
|
|
6dfde6d485 | ||
|
|
c8eeef6947 | ||
|
|
67cb89fbdf | ||
|
|
bf4fb1fb40 | ||
|
|
f807f7f804 | ||
|
|
b8d8ed1ba9 | ||
|
|
cc794d60e7 | ||
|
|
8dd0c85ac5 | ||
|
|
76fa695241 | ||
|
|
f30c4ed2bc | ||
|
|
b752507b48 | ||
|
|
af94ba9d02 | ||
|
|
818b08d0e4 | ||
|
|
ea18996f54 | ||
|
|
68fd82e840 | ||
|
|
4fad8efbfb | ||
|
|
b78bae2d51 | ||
|
|
271f5601f3 | ||
|
|
c3b7a45e84 | ||
|
|
c3e190ce67 | ||
|
|
b75d443caf | ||
|
|
27e727a146 | ||
|
|
4ce4379235 | ||
|
|
c2c47550f9 | ||
|
|
535cc49f27 | ||
|
|
dfbf73408c | ||
|
|
bc7f3eb32f | ||
|
|
ec954f47fb | ||
|
|
81a5e0073c | ||
|
|
ab1bc9bf5f | ||
|
|
0f1eb3e914 | ||
|
|
84e27a592d | ||
|
|
c9f034b4ac | ||
|
|
a9f9d68631 | ||
|
|
707374d5dc | ||
|
|
89fa00ddff | ||
|
|
79bea15830 | ||
|
|
426f8b0f66 | ||
|
|
6a6cc27aee | ||
|
|
4c7c4d4061 | ||
|
|
4d24becf7f | ||
|
|
ba5b9b80a5 | ||
|
|
c7b0678356 | ||
|
|
a6e3222fe5 | ||
|
|
3cc852d339 | ||
|
|
0eeaa25694 | ||
|
|
aa3fac8057 | ||
|
|
c1c81ee2a4 | ||
|
|
e8496efe84 | ||
|
|
01bbacf3c4 | ||
|
|
148428ce76 | ||
|
|
c8f568ddf9 | ||
|
|
3ddda939d3 | ||
|
|
5de926d66f | ||
|
|
f878e6f8af | ||
|
|
269af961e9 | ||
|
|
ed80c6b6cc | ||
|
|
e433393c4f | ||
|
|
985ce80375 | ||
|
|
b9b9714fd5 | ||
|
|
fa969cfdde | ||
|
|
44f8e383f3 | ||
|
|
0c8da8b519 | ||
|
|
eaaa837e00 | ||
|
|
cbe3c3fdd4 | ||
|
|
6748f0a579 | ||
|
|
93b0cf7a99 | ||
|
|
d8ce68b09b | ||
|
|
78d4ced829 | ||
|
|
197c14dbcf | ||
|
|
5f20a91fa1 | ||
|
|
1e2ac54351 | ||
|
|
1e375468de | ||
|
|
c2c188b699 | ||
|
|
c46a0d7eb4 | ||
|
|
bd769a81e1 | ||
|
|
537088e7dc | ||
|
|
41fd9989a2 | ||
|
|
11d62f43c9 | ||
|
|
e4ab96021e | ||
|
|
2a7ed700d5 | ||
|
|
84716d267c | ||
|
|
e4779be97a | ||
|
|
f2da6df568 | ||
|
|
30848c0fcd | ||
|
|
e585c83209 | ||
|
|
6c1bb1601e | ||
|
|
ea87cb1ba5 | ||
|
|
3fed5bb25f | ||
|
|
27955056e0 | ||
|
|
90d70af269 | ||
|
|
7fc1aad195 | ||
|
|
cafb8de132 | ||
|
|
d5325d7ef1 | ||
|
|
d5694ac5fa | ||
|
|
e43de3ae4b | ||
|
|
75e67b9ee4 | ||
|
|
768f00dedb | ||
|
|
4dc07e93a8 | ||
|
|
7cc483aa0e | ||
|
|
e1e7d76cf1 | ||
|
|
93247a424a | ||
|
|
5f501ec7e2 | ||
|
|
761d255fdf | ||
|
|
ace8079086 | ||
|
|
7a44c01d89 | ||
|
|
c9bc4b7031 | ||
|
|
ae79764fe5 | ||
|
|
77f1d24de3 | ||
|
|
9ccb4226ba | ||
|
|
bf86a41ef1 | ||
|
|
8090fd4664 | ||
|
|
3a743f649c | ||
|
|
adec03395d | ||
|
|
74e494b010 | ||
|
|
ef3a5ae787 | ||
|
|
8c06dd6071 | ||
|
|
60c78666ab | ||
|
|
1786b0e768 | ||
|
|
8ad5f34908 | ||
|
|
6cd5fcd536 | ||
|
|
ccc67d445b | ||
|
|
9fd086e506 | ||
|
|
0b03a97708 | ||
|
|
4824a33c31 | ||
|
|
1e5fcfd14a | ||
|
|
17b8e2bd02 | ||
|
|
a8e2a3df32 | ||
|
|
0d7c7fd907 | ||
|
|
95298783bb | ||
|
|
e1dec2f1a7 | ||
|
|
bb746a9de1 | ||
|
|
ae8d4bb0f0 | ||
|
|
197d82dc07 | ||
|
|
069ae2df12 | ||
|
|
47d9848dc4 | ||
|
|
93e504d04e | ||
|
|
b5feaa5a49 | ||
|
|
7f0d0ba3bc | ||
|
|
4a9b1cf253 | ||
|
|
6d8799af1a | ||
|
|
258409ef61 | ||
|
|
bf81f3cf2c | ||
|
|
27ebc5c8f2 | ||
|
|
97c544f91f | ||
|
|
2800983f3e | ||
|
|
8b50fe5330 | ||
|
|
73b4e18c62 | ||
|
|
175a01f56c | ||
|
|
ba3ff7918b | ||
|
|
ef8e578677 | ||
|
|
b880ff190a | ||
|
|
05e21285aa | ||
|
|
a1e67bcb97 | ||
|
|
ebbaae5526 | ||
|
|
966a70f1fa | ||
|
|
629cdfb124 | ||
|
|
ed666d3969 | ||
|
|
b76ef6ccb8 | ||
|
|
851aeae7c7 | ||
|
|
d5e32c843f | ||
|
|
96917d5552 | ||
|
|
0401604222 | ||
|
|
b238cf7f6b | ||
|
|
960dae3340 | ||
|
|
2cc998fed8 | ||
|
|
139fe30f47 | ||
|
|
4d793626ff | ||
|
|
c544188ee3 | ||
|
|
0ab153d201 | ||
|
|
8209b5f033 | ||
|
|
b3bf6a1218 | ||
|
|
57826d645b | ||
|
|
6f443a74cf | ||
|
|
14a34f12d7 | ||
|
|
3431ec55dc | ||
|
|
6027b1992f | ||
|
|
e884ff31d8 | ||
|
|
05c13f6c22 | ||
|
|
94ecd871a0 | ||
|
|
12ed4ee48e | ||
|
|
332839f6ea | ||
|
|
e5ea6dd021 | ||
|
|
cccfcfa7b9 | ||
|
|
68f34e85ce | ||
|
|
3e703eb04e | ||
|
|
508460f240 | ||
|
|
6e9f147faa | ||
|
|
4540730111 | ||
|
|
e96ee95a7e | ||
|
|
2f9eafdd36 | ||
|
|
b3de67234e | ||
|
|
cb3aee8219 | ||
|
|
85fda57208 | ||
|
|
4b203bdba5 | ||
|
|
3b0470dba5 | ||
|
|
8575e3160f | ||
|
|
a78cda4baf | ||
|
|
7a39da8cc6 | ||
|
|
5bbb53580a | ||
|
|
26451a09eb | ||
|
|
8d55877c9e | ||
|
|
a62406aaa5 | ||
|
|
28e8c46f29 | ||
|
|
6d586dc05c | ||
|
|
410b4e14a1 | ||
|
|
fe4e885f54 | ||
|
|
bbb739d24a | ||
|
|
26752df503 | ||
|
|
e52c391cd4 | ||
|
|
0aac30d53b | ||
|
|
6322fbbd41 | ||
|
|
8ba89f1050 | ||
|
|
429925a5e9 | ||
|
|
83936293eb | ||
|
|
b8ca494ee9 |
50
CHANGES.rst
50
CHANGES.rst
@@ -1,3 +1,53 @@
|
||||
Changes in synapse v0.24.0 (2017-10-23)
|
||||
=======================================
|
||||
|
||||
No changes since v0.24.0-rc1
|
||||
|
||||
|
||||
Changes in synapse v0.24.0-rc1 (2017-10-19)
|
||||
===========================================
|
||||
|
||||
Features:
|
||||
|
||||
* Add Group Server (PR #2352, #2363, #2374, #2377, #2378, #2382, #2410, #2426,
|
||||
#2430, #2454, #2471, #2472, #2544)
|
||||
* Add support for channel notifications (PR #2501)
|
||||
* Add basic implementation of backup media store (PR #2538)
|
||||
* Add config option to auto-join new users to rooms (PR #2545)
|
||||
|
||||
|
||||
Changes:
|
||||
|
||||
* Make the spam checker a module (PR #2474)
|
||||
* Delete expired url cache data (PR #2478)
|
||||
* Ignore incoming events for rooms that we have left (PR #2490)
|
||||
* Allow spam checker to reject invites too (PR #2492)
|
||||
* Add room creation checks to spam checker (PR #2495)
|
||||
* Spam checking: add the invitee to user_may_invite (PR #2502)
|
||||
* Process events from federation for different rooms in parallel (PR #2520)
|
||||
* Allow error strings from spam checker (PR #2531)
|
||||
* Improve error handling for missing files in config (PR #2551)
|
||||
|
||||
|
||||
Bug fixes:
|
||||
|
||||
* Fix handling SERVFAILs when doing AAAA lookups for federation (PR #2477)
|
||||
* Fix incompatibility with newer versions of ujson (PR #2483) Thanks to
|
||||
@jeremycline!
|
||||
* Fix notification keywords that start/end with non-word chars (PR #2500)
|
||||
* Fix stack overflow and logcontexts from linearizer (PR #2532)
|
||||
* Fix 500 error when fields missing from power_levels event (PR #2552)
|
||||
* Fix 500 error when we get an error handling a PDU (PR #2553)
|
||||
|
||||
|
||||
Changes in synapse v0.23.1 (2017-10-02)
|
||||
=======================================
|
||||
|
||||
Changes:
|
||||
|
||||
* Make 'affinity' package optional, as it is not supported on some platforms
|
||||
|
||||
|
||||
Changes in synapse v0.23.0 (2017-10-02)
|
||||
=======================================
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ Purge History API
|
||||
The purge history API allows server admins to purge historic events from their
|
||||
database, reclaiming disk space.
|
||||
|
||||
**NB!** This will not delete local events (locally sent messages content etc) from the database, but will remove lots of the metadata about them and does dramatically reduce the on disk space usage
|
||||
|
||||
Depending on the amount of history being purged a call to the API may take
|
||||
several minutes or longer. During this period users will not be able to
|
||||
paginate further back in the room from the point being purged from.
|
||||
|
||||
@@ -50,7 +50,7 @@ master_doc = 'index'
|
||||
|
||||
# General information about the project.
|
||||
project = u'Synapse'
|
||||
copyright = u'2014, TNG'
|
||||
copyright = u'Copyright 2014-2017 OpenMarket Ltd, 2017 Vector Creations Ltd, 2017 New Vector Ltd'
|
||||
|
||||
# The version info for the project you're documenting, acts as replacement for
|
||||
# |version| and |release|, also used in various other places throughout the
|
||||
|
||||
@@ -376,10 +376,13 @@ class Porter(object):
|
||||
" VALUES (?,?,?,?,to_tsvector('english', ?),?,?)"
|
||||
)
|
||||
|
||||
rows_dict = [
|
||||
dict(zip(headers, row))
|
||||
for row in rows
|
||||
]
|
||||
rows_dict = []
|
||||
for row in rows:
|
||||
d = dict(zip(headers, row))
|
||||
if "\0" in d['value']:
|
||||
logger.warn('dropping search row %s', d)
|
||||
else:
|
||||
rows_dict.append(d)
|
||||
|
||||
txn.executemany(sql, [
|
||||
(
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
""" This is a reference implementation of a Matrix home server.
|
||||
"""
|
||||
|
||||
__version__ = "0.23.0"
|
||||
__version__ = "0.24.0"
|
||||
|
||||
@@ -12,10 +12,16 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import sys
|
||||
|
||||
try:
|
||||
import affinity
|
||||
except:
|
||||
affinity = None
|
||||
|
||||
import affinity
|
||||
from daemonize import Daemonize
|
||||
from synapse.util import PreserveLoggingContext
|
||||
from synapse.util.rlimit import change_resource_limit
|
||||
@@ -78,6 +84,13 @@ def start_reactor(
|
||||
with PreserveLoggingContext():
|
||||
logger.info("Running")
|
||||
if cpu_affinity is not None:
|
||||
if not affinity:
|
||||
quit_with_error(
|
||||
"Missing package 'affinity' required for cpu_affinity\n"
|
||||
"option\n\n"
|
||||
"Install by running:\n\n"
|
||||
" pip install affinity\n\n"
|
||||
)
|
||||
logger.info("Setting CPU affinity to %s" % cpu_affinity)
|
||||
affinity.set_process_affinity_mask(0, cpu_affinity)
|
||||
change_resource_limit(soft_file_limit)
|
||||
@@ -97,3 +110,13 @@ def start_reactor(
|
||||
daemon.start()
|
||||
else:
|
||||
run()
|
||||
|
||||
|
||||
def quit_with_error(error_string):
|
||||
message_lines = error_string.split("\n")
|
||||
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
for line in message_lines:
|
||||
sys.stderr.write(" %s\n" % (line.rstrip(),))
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
sys.exit(1)
|
||||
|
||||
@@ -25,6 +25,7 @@ from synapse.api.urls import CONTENT_REPO_PREFIX, FEDERATION_PREFIX, \
|
||||
LEGACY_MEDIA_PREFIX, MEDIA_PREFIX, SERVER_KEY_PREFIX, SERVER_KEY_V2_PREFIX, \
|
||||
STATIC_PREFIX, WEB_CLIENT_PREFIX
|
||||
from synapse.app import _base
|
||||
from synapse.app._base import quit_with_error
|
||||
from synapse.config._base import ConfigError
|
||||
from synapse.config.homeserver import HomeServerConfig
|
||||
from synapse.crypto import context_factory
|
||||
@@ -249,16 +250,6 @@ class SynapseHomeServer(HomeServer):
|
||||
return db_conn
|
||||
|
||||
|
||||
def quit_with_error(error_string):
|
||||
message_lines = error_string.split("\n")
|
||||
line_length = max([len(l) for l in message_lines if len(l) < 80]) + 2
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
for line in message_lines:
|
||||
sys.stderr.write(" %s\n" % (line.rstrip(),))
|
||||
sys.stderr.write("*" * line_length + '\n')
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def setup(config_options):
|
||||
"""
|
||||
Args:
|
||||
|
||||
@@ -40,6 +40,7 @@ from synapse.replication.slave.storage.push_rule import SlavedPushRuleStore
|
||||
from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.slave.storage.groups import SlavedGroupServerStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||
@@ -69,6 +70,7 @@ class SynchrotronSlavedStore(
|
||||
SlavedRegistrationStore,
|
||||
SlavedFilteringStore,
|
||||
SlavedPresenceStore,
|
||||
SlavedGroupServerStore,
|
||||
SlavedDeviceInboxStore,
|
||||
SlavedDeviceStore,
|
||||
SlavedClientIpStore,
|
||||
@@ -403,6 +405,10 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
||||
)
|
||||
elif stream_name == "presence":
|
||||
yield self.presence_handler.process_replication_rows(token, rows)
|
||||
elif stream_name == "receipts":
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[row.user_id for row in rows],
|
||||
)
|
||||
|
||||
|
||||
def start(config_options):
|
||||
|
||||
@@ -81,22 +81,38 @@ class Config(object):
|
||||
def abspath(file_path):
|
||||
return os.path.abspath(file_path) if file_path else file_path
|
||||
|
||||
@classmethod
|
||||
def path_exists(cls, file_path):
|
||||
"""Check if a file exists
|
||||
|
||||
Unlike os.path.exists, this throws an exception if there is an error
|
||||
checking if the file exists (for example, if there is a perms error on
|
||||
the parent dir).
|
||||
|
||||
Returns:
|
||||
bool: True if the file exists; False if not.
|
||||
"""
|
||||
try:
|
||||
os.stat(file_path)
|
||||
return True
|
||||
except OSError as e:
|
||||
if e.errno != errno.ENOENT:
|
||||
raise e
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def check_file(cls, file_path, config_name):
|
||||
if file_path is None:
|
||||
raise ConfigError(
|
||||
"Missing config for %s."
|
||||
" You must specify a path for the config file. You can "
|
||||
"do this with the -c or --config-path option. "
|
||||
"Adding --generate-config along with --server-name "
|
||||
"<server name> will generate a config file at the given path."
|
||||
% (config_name,)
|
||||
)
|
||||
if not os.path.exists(file_path):
|
||||
try:
|
||||
os.stat(file_path)
|
||||
except OSError as e:
|
||||
raise ConfigError(
|
||||
"File %s config for %s doesn't exist."
|
||||
" Try running again with --generate-config"
|
||||
% (file_path, config_name,)
|
||||
"Error accessing file '%s' (config for %s): %s"
|
||||
% (file_path, config_name, e.strerror)
|
||||
)
|
||||
return cls.abspath(file_path)
|
||||
|
||||
@@ -248,7 +264,7 @@ class Config(object):
|
||||
" -c CONFIG-FILE\""
|
||||
)
|
||||
(config_path,) = config_files
|
||||
if not os.path.exists(config_path):
|
||||
if not cls.path_exists(config_path):
|
||||
if config_args.keys_directory:
|
||||
config_dir_path = config_args.keys_directory
|
||||
else:
|
||||
@@ -261,7 +277,7 @@ class Config(object):
|
||||
"Must specify a server_name to a generate config for."
|
||||
" Pass -H server.name."
|
||||
)
|
||||
if not os.path.exists(config_dir_path):
|
||||
if not cls.path_exists(config_dir_path):
|
||||
os.makedirs(config_dir_path)
|
||||
with open(config_path, "wb") as config_file:
|
||||
config_bytes, config = obj.generate_config(
|
||||
|
||||
32
synapse/config/groups.py
Normal file
32
synapse/config/groups.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class GroupsConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.enable_group_creation = config.get("enable_group_creation", False)
|
||||
self.group_creation_prefix = config.get("group_creation_prefix", "")
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# Whether to allow non server admins to create groups on this server
|
||||
enable_group_creation: false
|
||||
|
||||
# If enabled, non server admins can only create groups with local parts
|
||||
# starting with this prefix
|
||||
# group_creation_prefix: "unofficial/"
|
||||
"""
|
||||
@@ -34,6 +34,8 @@ from .password_auth_providers import PasswordAuthProviderConfig
|
||||
from .emailconfig import EmailConfig
|
||||
from .workers import WorkerConfig
|
||||
from .push import PushConfig
|
||||
from .spam_checker import SpamCheckerConfig
|
||||
from .groups import GroupsConfig
|
||||
|
||||
|
||||
class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
@@ -41,7 +43,8 @@ class HomeServerConfig(TlsConfig, ServerConfig, DatabaseConfig, LoggingConfig,
|
||||
VoipConfig, RegistrationConfig, MetricsConfig, ApiConfig,
|
||||
AppServiceConfig, KeyConfig, SAML2Config, CasConfig,
|
||||
JWTConfig, PasswordConfig, EmailConfig,
|
||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,):
|
||||
WorkerConfig, PasswordAuthProviderConfig, PushConfig,
|
||||
SpamCheckerConfig, GroupsConfig,):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -118,10 +118,9 @@ class KeyConfig(Config):
|
||||
signing_keys = self.read_file(signing_key_path, "signing_key")
|
||||
try:
|
||||
return read_signing_keys(signing_keys.splitlines(True))
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Error reading signing_key."
|
||||
" Try running again with --generate-config"
|
||||
"Error reading signing_key: %s" % (str(e))
|
||||
)
|
||||
|
||||
def read_old_signing_keys(self, old_signing_keys):
|
||||
@@ -141,7 +140,8 @@ class KeyConfig(Config):
|
||||
|
||||
def generate_files(self, config):
|
||||
signing_key_path = config["signing_key_path"]
|
||||
if not os.path.exists(signing_key_path):
|
||||
|
||||
if not self.path_exists(signing_key_path):
|
||||
with open(signing_key_path, "w") as signing_key_file:
|
||||
key_id = "a_" + random_string(4)
|
||||
write_signing_keys(
|
||||
|
||||
@@ -15,13 +15,15 @@
|
||||
|
||||
from ._base import Config, ConfigError
|
||||
|
||||
import importlib
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
|
||||
class PasswordAuthProviderConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.password_providers = []
|
||||
|
||||
provider_config = None
|
||||
|
||||
# We want to be backwards compatible with the old `ldap_config`
|
||||
# param.
|
||||
ldap_config = config.get("ldap_config", {})
|
||||
@@ -38,19 +40,15 @@ class PasswordAuthProviderConfig(Config):
|
||||
if provider['module'] == "synapse.util.ldap_auth_provider.LdapAuthProvider":
|
||||
from ldap_auth_provider import LdapAuthProvider
|
||||
provider_class = LdapAuthProvider
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
else:
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
(provider_class, provider_config) = load_module(provider)
|
||||
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
self.password_providers.append((provider_class, provider_config))
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
|
||||
@@ -41,6 +41,8 @@ class RegistrationConfig(Config):
|
||||
self.allow_guest_access and config.get("invite_3pid_guest", False)
|
||||
)
|
||||
|
||||
self.auto_join_rooms = config.get("auto_join_rooms", [])
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
registration_shared_secret = random_string_with_symbols(50)
|
||||
|
||||
@@ -70,6 +72,11 @@ class RegistrationConfig(Config):
|
||||
- matrix.org
|
||||
- vector.im
|
||||
- riot.im
|
||||
|
||||
# Users who register on this homeserver will automatically be joined
|
||||
# to these rooms
|
||||
#auto_join_rooms:
|
||||
# - "#example:example.com"
|
||||
""" % locals()
|
||||
|
||||
def add_arguments(self, parser):
|
||||
|
||||
@@ -70,7 +70,19 @@ class ContentRepositoryConfig(Config):
|
||||
self.max_upload_size = self.parse_size(config["max_upload_size"])
|
||||
self.max_image_pixels = self.parse_size(config["max_image_pixels"])
|
||||
self.max_spider_size = self.parse_size(config["max_spider_size"])
|
||||
|
||||
self.media_store_path = self.ensure_directory(config["media_store_path"])
|
||||
|
||||
self.backup_media_store_path = config.get("backup_media_store_path")
|
||||
if self.backup_media_store_path:
|
||||
self.backup_media_store_path = self.ensure_directory(
|
||||
self.backup_media_store_path
|
||||
)
|
||||
|
||||
self.synchronous_backup_media_store = config.get(
|
||||
"synchronous_backup_media_store", False
|
||||
)
|
||||
|
||||
self.uploads_path = self.ensure_directory(config["uploads_path"])
|
||||
self.dynamic_thumbnails = config["dynamic_thumbnails"]
|
||||
self.thumbnail_requirements = parse_thumbnail_requirements(
|
||||
@@ -115,6 +127,14 @@ class ContentRepositoryConfig(Config):
|
||||
# Directory where uploaded images and attachments are stored.
|
||||
media_store_path: "%(media_store)s"
|
||||
|
||||
# A secondary directory where uploaded images and attachments are
|
||||
# stored as a backup.
|
||||
# backup_media_store_path: "%(media_store)s"
|
||||
|
||||
# Whether to wait for successful write to backup media store before
|
||||
# returning successfully.
|
||||
# synchronous_backup_media_store: false
|
||||
|
||||
# Directory where in-progress uploads are stored.
|
||||
uploads_path: "%(uploads_path)s"
|
||||
|
||||
|
||||
35
synapse/config/spam_checker.py
Normal file
35
synapse/config/spam_checker.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from synapse.util.module_loader import load_module
|
||||
|
||||
from ._base import Config
|
||||
|
||||
|
||||
class SpamCheckerConfig(Config):
|
||||
def read_config(self, config):
|
||||
self.spam_checker = None
|
||||
|
||||
provider = config.get("spam_checker", None)
|
||||
if provider is not None:
|
||||
self.spam_checker = load_module(provider)
|
||||
|
||||
def default_config(self, **kwargs):
|
||||
return """\
|
||||
# spam_checker:
|
||||
# module: "my_custom_project.SuperSpamChecker"
|
||||
# config:
|
||||
# example_option: 'things'
|
||||
"""
|
||||
@@ -126,7 +126,7 @@ class TlsConfig(Config):
|
||||
tls_private_key_path = config["tls_private_key_path"]
|
||||
tls_dh_params_path = config["tls_dh_params_path"]
|
||||
|
||||
if not os.path.exists(tls_private_key_path):
|
||||
if not self.path_exists(tls_private_key_path):
|
||||
with open(tls_private_key_path, "w") as private_key_file:
|
||||
tls_private_key = crypto.PKey()
|
||||
tls_private_key.generate_key(crypto.TYPE_RSA, 2048)
|
||||
@@ -141,7 +141,7 @@ class TlsConfig(Config):
|
||||
crypto.FILETYPE_PEM, private_key_pem
|
||||
)
|
||||
|
||||
if not os.path.exists(tls_certificate_path):
|
||||
if not self.path_exists(tls_certificate_path):
|
||||
with open(tls_certificate_path, "w") as certificate_file:
|
||||
cert = crypto.X509()
|
||||
subject = cert.get_subject()
|
||||
@@ -159,7 +159,7 @@ class TlsConfig(Config):
|
||||
|
||||
certificate_file.write(cert_pem)
|
||||
|
||||
if not os.path.exists(tls_dh_params_path):
|
||||
if not self.path_exists(tls_dh_params_path):
|
||||
if GENERATE_DH_PARAMS:
|
||||
subprocess.check_call([
|
||||
"openssl", "dhparam",
|
||||
|
||||
@@ -470,14 +470,14 @@ def _check_power_levels(event, auth_events):
|
||||
("invite", None),
|
||||
]
|
||||
|
||||
old_list = current_state.content.get("users")
|
||||
old_list = current_state.content.get("users", {})
|
||||
for user in set(old_list.keys() + user_list.keys()):
|
||||
levels_to_check.append(
|
||||
(user, "users")
|
||||
)
|
||||
|
||||
old_list = current_state.content.get("events")
|
||||
new_list = event.content.get("events")
|
||||
old_list = current_state.content.get("events", {})
|
||||
new_list = event.content.get("events", {})
|
||||
for ev_id in set(old_list.keys() + new_list.keys()):
|
||||
levels_to_check.append(
|
||||
(ev_id, "events")
|
||||
|
||||
@@ -14,25 +14,100 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
def check_event_for_spam(event):
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
class SpamChecker(object):
|
||||
def __init__(self, hs):
|
||||
self.spam_checker = None
|
||||
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
sent by a local user. If it is sent by a user on another server, then
|
||||
users receive a blank event.
|
||||
module = None
|
||||
config = None
|
||||
try:
|
||||
module, config = hs.config.spam_checker
|
||||
except:
|
||||
pass
|
||||
|
||||
Args:
|
||||
event (synapse.events.EventBase): the event to be checked
|
||||
if module is not None:
|
||||
self.spam_checker = module(config=config)
|
||||
|
||||
Returns:
|
||||
bool: True if the event is spammy.
|
||||
"""
|
||||
if not hasattr(event, "content") or "body" not in event.content:
|
||||
return False
|
||||
def check_event_for_spam(self, event):
|
||||
"""Checks if a given event is considered "spammy" by this server.
|
||||
|
||||
# for example:
|
||||
#
|
||||
# if "the third flower is green" in event.content["body"]:
|
||||
# return True
|
||||
If the server considers an event spammy, then it will be rejected if
|
||||
sent by a local user. If it is sent by a user on another server, then
|
||||
users receive a blank event.
|
||||
|
||||
return False
|
||||
Args:
|
||||
event (synapse.events.EventBase): the event to be checked
|
||||
|
||||
Returns:
|
||||
bool: True if the event is spammy.
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return False
|
||||
|
||||
return self.spam_checker.check_event_for_spam(event)
|
||||
|
||||
def user_may_invite(self, inviter_userid, invitee_userid, room_id):
|
||||
"""Checks if a given user may send an invite
|
||||
|
||||
If this method returns false, the invite will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user may send an invite, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
|
||||
|
||||
def user_may_create_room(self, userid):
|
||||
"""Checks if a given user may create a room
|
||||
|
||||
If this method returns false, the creation request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
|
||||
Returns:
|
||||
bool: True if the user may create a room, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_create_room(userid)
|
||||
|
||||
def user_may_create_room_alias(self, userid, room_alias):
|
||||
"""Checks if a given user may create a room alias
|
||||
|
||||
If this method returns false, the association request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
room_alias (string): The alias to be created
|
||||
|
||||
Returns:
|
||||
bool: True if the user may create a room alias, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_create_room_alias(userid, room_alias)
|
||||
|
||||
def user_may_publish_room(self, userid, room_id):
|
||||
"""Checks if a given user may publish a room to the directory
|
||||
|
||||
If this method returns false, the publish request will be rejected.
|
||||
|
||||
Args:
|
||||
userid (string): The sender's user ID
|
||||
room_id (string): The ID of the room that would be published
|
||||
|
||||
Returns:
|
||||
bool: True if the user may publish the room, otherwise False
|
||||
"""
|
||||
if self.spam_checker is None:
|
||||
return True
|
||||
|
||||
return self.spam_checker.user_may_publish_room(userid, room_id)
|
||||
|
||||
@@ -16,7 +16,6 @@ import logging
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import spamcheck
|
||||
from synapse.events.utils import prune_event
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from twisted.internet import defer
|
||||
@@ -26,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
class FederationBase(object):
|
||||
def __init__(self, hs):
|
||||
pass
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
|
||||
@@ -144,7 +143,7 @@ class FederationBase(object):
|
||||
)
|
||||
return redacted
|
||||
|
||||
if spamcheck.check_event_for_spam(pdu):
|
||||
if self.spam_checker.check_event_for_spam(pdu):
|
||||
logger.warn(
|
||||
"Event contains spam, redacting %s: %s",
|
||||
pdu.event_id, pdu.get_pdu_json()
|
||||
|
||||
@@ -12,14 +12,12 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from .federation_base import FederationBase
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util import async
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.caches.response_cache import ResponseCache
|
||||
from synapse.events import FrozenEvent
|
||||
@@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
|
||||
import simplejson as json
|
||||
import logging
|
||||
|
||||
# when processing incoming transactions, we try to handle multiple rooms in
|
||||
# parallel, up to this limit.
|
||||
TRANSACTION_CONCURRENCY_LIMIT = 10
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,7 +53,8 @@ class FederationServer(FederationBase):
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self._server_linearizer = Linearizer("fed_server")
|
||||
self._server_linearizer = async.Linearizer("fed_server")
|
||||
self._transaction_linearizer = async.Linearizer("fed_txn_handler")
|
||||
|
||||
# We cache responses to state queries, as they take a while and often
|
||||
# come in waves.
|
||||
@@ -109,25 +111,41 @@ class FederationServer(FederationBase):
|
||||
@defer.inlineCallbacks
|
||||
@log_function
|
||||
def on_incoming_transaction(self, transaction_data):
|
||||
# keep this as early as possible to make the calculated origin ts as
|
||||
# accurate as possible.
|
||||
request_time = self._clock.time_msec()
|
||||
|
||||
transaction = Transaction(**transaction_data)
|
||||
|
||||
received_pdus_counter.inc_by(len(transaction.pdus))
|
||||
|
||||
for p in transaction.pdus:
|
||||
if "unsigned" in p:
|
||||
unsigned = p["unsigned"]
|
||||
if "age" in unsigned:
|
||||
p["age"] = unsigned["age"]
|
||||
if "age" in p:
|
||||
p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
pdu_list = [
|
||||
self.event_from_pdu_json(p) for p in transaction.pdus
|
||||
]
|
||||
if not transaction.transaction_id:
|
||||
raise Exception("Transaction missing transaction_id")
|
||||
if not transaction.origin:
|
||||
raise Exception("Transaction missing origin")
|
||||
|
||||
logger.debug("[%s] Got transaction", transaction.transaction_id)
|
||||
|
||||
# use a linearizer to ensure that we don't process the same transaction
|
||||
# multiple times in parallel.
|
||||
with (yield self._transaction_linearizer.queue(
|
||||
(transaction.origin, transaction.transaction_id),
|
||||
)):
|
||||
result = yield self._handle_incoming_transaction(
|
||||
transaction, request_time,
|
||||
)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_incoming_transaction(self, transaction, request_time):
|
||||
""" Process an incoming transaction and return the HTTP response
|
||||
|
||||
Args:
|
||||
transaction (Transaction): incoming transaction
|
||||
request_time (int): timestamp that the HTTP request arrived at
|
||||
|
||||
Returns:
|
||||
Deferred[(int, object)]: http response code and body
|
||||
"""
|
||||
response = yield self.transaction_actions.have_responded(transaction)
|
||||
|
||||
if response:
|
||||
@@ -140,42 +158,49 @@ class FederationServer(FederationBase):
|
||||
|
||||
logger.debug("[%s] Transaction is new", transaction.transaction_id)
|
||||
|
||||
results = []
|
||||
received_pdus_counter.inc_by(len(transaction.pdus))
|
||||
|
||||
for pdu in pdu_list:
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if transaction.origin != get_domain_from_id(pdu.event_id):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
pdu.event_id, transaction.origin
|
||||
)
|
||||
continue
|
||||
else:
|
||||
logger.info(
|
||||
"Accepting join PDU %s from %s",
|
||||
pdu.event_id, transaction.origin
|
||||
)
|
||||
pdus_by_room = {}
|
||||
|
||||
try:
|
||||
yield self._handle_received_pdu(transaction.origin, pdu)
|
||||
results.append({})
|
||||
except FederationError as e:
|
||||
self.send_failure(e, transaction.origin)
|
||||
results.append({"error": str(e)})
|
||||
except Exception as e:
|
||||
results.append({"error": str(e)})
|
||||
logger.exception("Failed to handle PDU")
|
||||
for p in transaction.pdus:
|
||||
if "unsigned" in p:
|
||||
unsigned = p["unsigned"]
|
||||
if "age" in unsigned:
|
||||
p["age"] = unsigned["age"]
|
||||
if "age" in p:
|
||||
p["age_ts"] = request_time - int(p["age"])
|
||||
del p["age"]
|
||||
|
||||
event = self.event_from_pdu_json(p)
|
||||
room_id = event.room_id
|
||||
pdus_by_room.setdefault(room_id, []).append(event)
|
||||
|
||||
pdu_results = {}
|
||||
|
||||
# we can process different rooms in parallel (which is useful if they
|
||||
# require callouts to other servers to fetch missing events), but
|
||||
# impose a limit to avoid going too crazy with ram/cpu.
|
||||
@defer.inlineCallbacks
|
||||
def process_pdus_for_room(room_id):
|
||||
logger.debug("Processing PDUs for %s", room_id)
|
||||
for pdu in pdus_by_room[room_id]:
|
||||
event_id = pdu.event_id
|
||||
try:
|
||||
yield self._handle_received_pdu(
|
||||
transaction.origin, pdu
|
||||
)
|
||||
pdu_results[event_id] = {}
|
||||
except FederationError as e:
|
||||
logger.warn("Error handling PDU %s: %s", event_id, e)
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
except Exception as e:
|
||||
pdu_results[event_id] = {"error": str(e)}
|
||||
logger.exception("Failed to handle PDU %s", event_id)
|
||||
|
||||
yield async.concurrently_execute(
|
||||
process_pdus_for_room, pdus_by_room.keys(),
|
||||
TRANSACTION_CONCURRENCY_LIMIT,
|
||||
)
|
||||
|
||||
if hasattr(transaction, "edus"):
|
||||
for edu in (Edu(**x) for x in transaction.edus):
|
||||
@@ -185,17 +210,16 @@ class FederationServer(FederationBase):
|
||||
edu.content
|
||||
)
|
||||
|
||||
for failure in getattr(transaction, "pdu_failures", []):
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
logger.debug("Returning: %s", str(results))
|
||||
pdu_failures = getattr(transaction, "pdu_failures", [])
|
||||
for failure in pdu_failures:
|
||||
logger.info("Got failure %r", failure)
|
||||
|
||||
response = {
|
||||
"pdus": dict(zip(
|
||||
(p.event_id for p in pdu_list), results
|
||||
)),
|
||||
"pdus": pdu_results,
|
||||
}
|
||||
|
||||
logger.debug("Returning: %s", str(response))
|
||||
|
||||
yield self.transaction_actions.set_response(
|
||||
transaction,
|
||||
200, response
|
||||
@@ -520,6 +544,30 @@ class FederationServer(FederationBase):
|
||||
Returns (Deferred): completes with None
|
||||
Raises: FederationError if the signatures / hash do not match
|
||||
"""
|
||||
# check that it's actually being sent from a valid destination to
|
||||
# workaround bug #1753 in 0.18.5 and 0.18.6
|
||||
if origin != get_domain_from_id(pdu.event_id):
|
||||
# We continue to accept join events from any server; this is
|
||||
# necessary for the federation join dance to work correctly.
|
||||
# (When we join over federation, the "helper" server is
|
||||
# responsible for sending out the join event, rather than the
|
||||
# origin. See bug #1893).
|
||||
if not (
|
||||
pdu.type == 'm.room.member' and
|
||||
pdu.content and
|
||||
pdu.content.get("membership", None) == 'join'
|
||||
):
|
||||
logger.info(
|
||||
"Discarding PDU %s from invalid origin %s",
|
||||
pdu.event_id, origin
|
||||
)
|
||||
return
|
||||
else:
|
||||
logger.info(
|
||||
"Accepting join PDU %s from %s",
|
||||
pdu.event_id, origin
|
||||
)
|
||||
|
||||
# Check signature.
|
||||
try:
|
||||
pdu = yield self._check_sigs_and_hash(pdu)
|
||||
|
||||
@@ -20,8 +20,8 @@ from .persistence import TransactionActions
|
||||
from .units import Transaction, Edu
|
||||
|
||||
from synapse.api.errors import HttpResponseException
|
||||
from synapse.util import logcontext
|
||||
from synapse.util.async import run_on_reactor
|
||||
from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
|
||||
from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
|
||||
@@ -231,11 +231,9 @@ class TransactionQueue(object):
|
||||
(pdu, order)
|
||||
)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
@preserve_fn # the caller should not yield on this
|
||||
@logcontext.preserve_fn # the caller should not yield on this
|
||||
@defer.inlineCallbacks
|
||||
def send_presence(self, states):
|
||||
"""Send the new presence states to the appropriate destinations.
|
||||
@@ -299,7 +297,7 @@ class TransactionQueue(object):
|
||||
state.user_id: state for state in states
|
||||
})
|
||||
|
||||
preserve_fn(self._attempt_new_transaction)(destination)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_edu(self, destination, edu_type, content, key=None):
|
||||
edu = Edu(
|
||||
@@ -321,9 +319,7 @@ class TransactionQueue(object):
|
||||
else:
|
||||
self.pending_edus_by_dest.setdefault(destination, []).append(edu)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_failure(self, failure, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
@@ -336,9 +332,7 @@ class TransactionQueue(object):
|
||||
destination, []
|
||||
).append(failure)
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
if destination == self.server_name or destination == "localhost":
|
||||
@@ -347,15 +341,24 @@ class TransactionQueue(object):
|
||||
if not self.can_send_to(destination):
|
||||
return
|
||||
|
||||
preserve_context_over_fn(
|
||||
self._attempt_new_transaction, destination
|
||||
)
|
||||
self._attempt_new_transaction(destination)
|
||||
|
||||
def get_current_token(self):
|
||||
return 0
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _attempt_new_transaction(self, destination):
|
||||
"""Try to start a new transaction to this destination
|
||||
|
||||
If there is already a transaction in progress to this destination,
|
||||
returns immediately. Otherwise kicks off the process of sending a
|
||||
transaction in the background.
|
||||
|
||||
Args:
|
||||
destination (str):
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
# list of (pending_pdu, deferred, order)
|
||||
if destination in self.pending_transactions:
|
||||
# XXX: pending_transactions can get stuck on by a never-ending
|
||||
@@ -368,6 +371,19 @@ class TransactionQueue(object):
|
||||
)
|
||||
return
|
||||
|
||||
logger.debug("TX [%s] Starting transaction loop", destination)
|
||||
|
||||
# Drop the logcontext before starting the transaction. It doesn't
|
||||
# really make sense to log all the outbound transactions against
|
||||
# whatever path led us to this point: that's pretty arbitrary really.
|
||||
#
|
||||
# (this also means we can fire off _perform_transaction without
|
||||
# yielding)
|
||||
with logcontext.PreserveLoggingContext():
|
||||
self._transaction_transmission_loop(destination)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _transaction_transmission_loop(self, destination):
|
||||
pending_pdus = []
|
||||
try:
|
||||
self.pending_transactions[destination] = 1
|
||||
|
||||
@@ -471,3 +471,384 @@ class TransportLayerClient(object):
|
||||
)
|
||||
|
||||
defer.returnValue(content)
|
||||
|
||||
@log_function
|
||||
def get_group_profile(self, destination, group_id, requester_user_id):
|
||||
"""Get a group profile
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/profile" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_summary(self, destination, group_id, requester_user_id):
|
||||
"""Get a group summary
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/summary" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_rooms_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/rooms" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def add_room_to_group(self, destination, group_id, requester_user_id, room_id,
|
||||
content):
|
||||
"""Add a room to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def remove_room_from_group(self, destination, group_id, requester_user_id, room_id):
|
||||
"""Remove a room from a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/room/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_invited_users_in_group(self, destination, group_id, requester_user_id):
|
||||
"""Get users that have been invited to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/invited_users" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def accept_group_invite(self, destination, group_id, user_id, content):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/accept_invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def invite_to_group(self, destination, group_id, user_id, requester_user_id, content):
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def invite_to_group_notification(self, destination, group_id, user_id, content):
|
||||
"""Sent by group server to inform a user's server that they have been
|
||||
invited.
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/local/%s/users/%s/invite" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def remove_user_from_group(self, destination, group_id, requester_user_id,
|
||||
user_id, content):
|
||||
"""Remove a user fron a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/users/%s/remove" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def remove_user_from_group_notification(self, destination, group_id, user_id,
|
||||
content):
|
||||
"""Sent by group server to inform a user's server that they have been
|
||||
kicked from the group.
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/local/%s/users/%s/remove" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def renew_group_attestation(self, destination, group_id, user_id, content):
|
||||
"""Sent by either a group server or a user's server to periodically update
|
||||
the attestations
|
||||
"""
|
||||
|
||||
path = PREFIX + "/groups/%s/renew_attestation/%s" % (group_id, user_id)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||
category_id, content):
|
||||
"""Update a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_summary_room(self, destination, group_id, user_id, room_id,
|
||||
category_id):
|
||||
"""Delete a room entry in a group summary
|
||||
"""
|
||||
if category_id:
|
||||
path = PREFIX + "/groups/%s/summary/categories/%s/rooms/%s" % (
|
||||
group_id, category_id, room_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/rooms/%s" % (group_id, room_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_categories(self, destination, group_id, requester_user_id):
|
||||
"""Get all categories in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_category(self, destination, group_id, requester_user_id, category_id):
|
||||
"""Get category info in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_category(self, destination, group_id, requester_user_id, category_id,
|
||||
content):
|
||||
"""Update a category in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_category(self, destination, group_id, requester_user_id,
|
||||
category_id):
|
||||
"""Delete a category in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/categories/%s" % (group_id, category_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_roles(self, destination, group_id, requester_user_id):
|
||||
"""Get all roles in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles" % (group_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def get_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Get a roles info
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.get_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_role(self, destination, group_id, requester_user_id, role_id,
|
||||
content):
|
||||
"""Update a role in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_role(self, destination, group_id, requester_user_id, role_id):
|
||||
"""Delete a role in a group
|
||||
"""
|
||||
path = PREFIX + "/groups/%s/roles/%s" % (group_id, role_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def update_group_summary_user(self, destination, group_id, requester_user_id,
|
||||
user_id, role_id, content):
|
||||
"""Update a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@log_function
|
||||
def delete_group_summary_user(self, destination, group_id, requester_user_id,
|
||||
user_id, role_id):
|
||||
"""Delete a users entry in a group
|
||||
"""
|
||||
if role_id:
|
||||
path = PREFIX + "/groups/%s/summary/roles/%s/users/%s" % (
|
||||
group_id, role_id, user_id,
|
||||
)
|
||||
else:
|
||||
path = PREFIX + "/groups/%s/summary/users/%s" % (group_id, user_id,)
|
||||
|
||||
return self.client.delete_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
args={"requester_user_id": requester_user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
def bulk_get_publicised_groups(self, destination, user_ids):
|
||||
"""Get the groups a list of users are publicising
|
||||
"""
|
||||
|
||||
path = PREFIX + "/get_groups_publicised"
|
||||
|
||||
content = {"user_ids": user_ids}
|
||||
|
||||
return self.client.post_json(
|
||||
destination=destination,
|
||||
path=path,
|
||||
data=content,
|
||||
ignore_backoff=True,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ from synapse.http.servlet import (
|
||||
from synapse.util.ratelimitutils import FederationRateLimiter
|
||||
from synapse.util.versionstring import get_version_string
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.types import ThirdPartyInstanceID
|
||||
from synapse.types import ThirdPartyInstanceID, get_domain_from_id
|
||||
|
||||
import functools
|
||||
import logging
|
||||
@@ -609,6 +609,493 @@ class FederationVersionServlet(BaseFederationServlet):
|
||||
}))
|
||||
|
||||
|
||||
class FederationGroupsProfileServlet(BaseFederationServlet):
|
||||
"""Get the basic profile of a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/profile$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_group_profile(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsSummaryServlet(BaseFederationServlet):
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/summary$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_group_summary(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.update_group_profile(
|
||||
group_id, requester_user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRoomsServlet(BaseFederationServlet):
|
||||
"""Get the rooms in a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/rooms$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_rooms_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsAddRoomsServlet(BaseFederationServlet):
|
||||
"""Add/remove room from group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/room/(?<room_id>)$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.add_room_to_group(
|
||||
group_id, requester_user_id, room_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.remove_room_from_group(
|
||||
group_id, requester_user_id, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsUsersServlet(BaseFederationServlet):
|
||||
"""Get the users in a group on behalf of a user
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsInvitedUsersServlet(BaseFederationServlet):
|
||||
"""Get the users that have been invited to a group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/invited_users$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.get_invited_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsInviteServlet(BaseFederationServlet):
|
||||
"""Ask a group server to invite someone to the group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.invite_to_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsAcceptInviteServlet(BaseFederationServlet):
|
||||
"""Accept an invitation from the group server
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/accept_invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(user_id) != origin:
|
||||
raise SynapseError(403, "user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.accept_invite(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRemoveUserServlet(BaseFederationServlet):
|
||||
"""Leave or kick a user from the group
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsLocalInviteServlet(BaseFederationServlet):
|
||||
"""A group server has invited a local user
|
||||
"""
|
||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/invite$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(group_id) != origin:
|
||||
raise SynapseError(403, "group_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.on_invite(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRemoveLocalUserServlet(BaseFederationServlet):
|
||||
"""A group server has removed a local user
|
||||
"""
|
||||
PATH = "/groups/local/(?P<group_id>[^/]*)/users/(?P<user_id>[^/]*)/remove$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
if get_domain_from_id(group_id) != origin:
|
||||
raise SynapseError(403, "user_id doesn't match origin")
|
||||
|
||||
new_content = yield self.handler.user_removed_from_group(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsRenewAttestaionServlet(BaseFederationServlet):
|
||||
"""A group or user's server renews their attestation
|
||||
"""
|
||||
PATH = "/groups/(?P<group_id>[^/]*)/renew_attestation/(?P<user_id>[^/]*)$"
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, user_id):
|
||||
# We don't need to check auth here as we check the attestation signatures
|
||||
|
||||
new_content = yield self.handler.on_renew_attestation(
|
||||
group_id, user_id, content
|
||||
)
|
||||
|
||||
defer.returnValue((200, new_content))
|
||||
|
||||
|
||||
class FederationGroupsSummaryRoomsServlet(BaseFederationServlet):
|
||||
"""Add/remove a room from the group summary, with optional category.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/rooms/:room_id
|
||||
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/categories/(?P<category_id>[^/]+))?"
|
||||
"/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, category_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, category_id, room_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_summary_room(
|
||||
group_id, requester_user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsCategoriesServlet(BaseFederationServlet):
|
||||
"""Get all categories for a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_categories(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsCategoryServlet(BaseFederationServlet):
|
||||
"""Add/remove/get a category in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_category(
|
||||
group_id, requester_user_id, category_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.upsert_group_category(
|
||||
group_id, requester_user_id, category_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, category_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if category_id == "":
|
||||
raise SynapseError(400, "category_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_category(
|
||||
group_id, requester_user_id, category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsRolesServlet(BaseFederationServlet):
|
||||
"""Get roles in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_roles(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsRoleServlet(BaseFederationServlet):
|
||||
"""Add/remove/get a role in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
resp = yield self.handler.get_group_role(
|
||||
group_id, requester_user_id, role_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_role(
|
||||
group_id, requester_user_id, role_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, role_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_role(
|
||||
group_id, requester_user_id, role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsSummaryUsersServlet(BaseFederationServlet):
|
||||
"""Add/remove a user from the group summary, with optional role.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/users/:user_id
|
||||
- /groups/:group/summary/roles/:role/users/:user_id
|
||||
"""
|
||||
PATH = (
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/roles/(?P<role_id>[^/]+))?"
|
||||
"/users/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query, group_id, role_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.update_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, origin, content, query, group_id, role_id, user_id):
|
||||
requester_user_id = parse_string_from_args(query, "requester_user_id")
|
||||
if get_domain_from_id(requester_user_id) != origin:
|
||||
raise SynapseError(403, "requester_user_id doesn't match origin")
|
||||
|
||||
if role_id == "":
|
||||
raise SynapseError(400, "role_id cannot be empty string")
|
||||
|
||||
resp = yield self.handler.delete_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class FederationGroupsBulkPublicisedServlet(BaseFederationServlet):
|
||||
"""Get roles in a group
|
||||
"""
|
||||
PATH = (
|
||||
"/get_groups_publicised$"
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, origin, content, query):
|
||||
resp = yield self.handler.bulk_get_publicised_groups(
|
||||
content["user_ids"], proxy=False,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
FEDERATION_SERVLET_CLASSES = (
|
||||
FederationSendServlet,
|
||||
FederationPullServlet,
|
||||
@@ -635,10 +1122,40 @@ FEDERATION_SERVLET_CLASSES = (
|
||||
FederationVersionServlet,
|
||||
)
|
||||
|
||||
|
||||
ROOM_LIST_CLASSES = (
|
||||
PublicRoomList,
|
||||
)
|
||||
|
||||
GROUP_SERVER_SERVLET_CLASSES = (
|
||||
FederationGroupsProfileServlet,
|
||||
FederationGroupsSummaryServlet,
|
||||
FederationGroupsRoomsServlet,
|
||||
FederationGroupsUsersServlet,
|
||||
FederationGroupsInvitedUsersServlet,
|
||||
FederationGroupsInviteServlet,
|
||||
FederationGroupsAcceptInviteServlet,
|
||||
FederationGroupsRemoveUserServlet,
|
||||
FederationGroupsSummaryRoomsServlet,
|
||||
FederationGroupsCategoriesServlet,
|
||||
FederationGroupsCategoryServlet,
|
||||
FederationGroupsRolesServlet,
|
||||
FederationGroupsRoleServlet,
|
||||
FederationGroupsSummaryUsersServlet,
|
||||
)
|
||||
|
||||
|
||||
GROUP_LOCAL_SERVLET_CLASSES = (
|
||||
FederationGroupsLocalInviteServlet,
|
||||
FederationGroupsRemoveLocalUserServlet,
|
||||
FederationGroupsBulkPublicisedServlet,
|
||||
)
|
||||
|
||||
|
||||
GROUP_ATTESTATION_SERVLET_CLASSES = (
|
||||
FederationGroupsRenewAttestaionServlet,
|
||||
)
|
||||
|
||||
|
||||
def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||
for servletclass in FEDERATION_SERVLET_CLASSES:
|
||||
@@ -656,3 +1173,27 @@ def register_servlets(hs, resource, authenticator, ratelimiter):
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_SERVER_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_server_handler(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_LOCAL_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_local_handler(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
for servletclass in GROUP_ATTESTATION_SERVLET_CLASSES:
|
||||
servletclass(
|
||||
handler=hs.get_groups_attestation_renewer(),
|
||||
authenticator=authenticator,
|
||||
ratelimiter=ratelimiter,
|
||||
server_name=hs.hostname,
|
||||
).register(resource)
|
||||
|
||||
0
synapse/groups/__init__.py
Normal file
0
synapse/groups/__init__.py
Normal file
151
synapse/groups/attestations.py
Normal file
151
synapse/groups/attestations.py
Normal file
@@ -0,0 +1,151 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
|
||||
from signedjson.sign import sign_json
|
||||
|
||||
|
||||
# Default validity duration for new attestations we create
|
||||
DEFAULT_ATTESTATION_LENGTH_MS = 3 * 24 * 60 * 60 * 1000
|
||||
|
||||
# Start trying to update our attestations when they come this close to expiring
|
||||
UPDATE_ATTESTATION_TIME_MS = 1 * 24 * 60 * 60 * 1000
|
||||
|
||||
|
||||
class GroupAttestationSigning(object):
|
||||
"""Creates and verifies group attestations.
|
||||
"""
|
||||
def __init__(self, hs):
|
||||
self.keyring = hs.get_keyring()
|
||||
self.clock = hs.get_clock()
|
||||
self.server_name = hs.hostname
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def verify_attestation(self, attestation, group_id, user_id, server_name=None):
|
||||
"""Verifies that the given attestation matches the given parameters.
|
||||
|
||||
An optional server_name can be supplied to explicitly set which server's
|
||||
signature is expected. Otherwise assumes that either the group_id or user_id
|
||||
is local and uses the other's server as the one to check.
|
||||
"""
|
||||
|
||||
if not server_name:
|
||||
if get_domain_from_id(group_id) == self.server_name:
|
||||
server_name = get_domain_from_id(user_id)
|
||||
elif get_domain_from_id(user_id) == self.server_name:
|
||||
server_name = get_domain_from_id(group_id)
|
||||
else:
|
||||
raise Exception("Expected either group_id or user_id to be local")
|
||||
|
||||
if user_id != attestation["user_id"]:
|
||||
raise SynapseError(400, "Attestation has incorrect user_id")
|
||||
|
||||
if group_id != attestation["group_id"]:
|
||||
raise SynapseError(400, "Attestation has incorrect group_id")
|
||||
valid_until_ms = attestation["valid_until_ms"]
|
||||
|
||||
# TODO: We also want to check that *new* attestations that people give
|
||||
# us to store are valid for at least a little while.
|
||||
if valid_until_ms < self.clock.time_msec():
|
||||
raise SynapseError(400, "Attestation expired")
|
||||
|
||||
yield self.keyring.verify_json_for_server(server_name, attestation)
|
||||
|
||||
def create_attestation(self, group_id, user_id):
|
||||
"""Create an attestation for the group_id and user_id with default
|
||||
validity length.
|
||||
"""
|
||||
return sign_json({
|
||||
"group_id": group_id,
|
||||
"user_id": user_id,
|
||||
"valid_until_ms": self.clock.time_msec() + DEFAULT_ATTESTATION_LENGTH_MS,
|
||||
}, self.server_name, self.signing_key)
|
||||
|
||||
|
||||
class GroupAttestionRenewer(object):
|
||||
"""Responsible for sending and receiving attestation updates.
|
||||
"""
|
||||
|
||||
def __init__(self, hs):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.assestations = hs.get_groups_attestation_signing()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
|
||||
self._renew_attestations_loop = self.clock.looping_call(
|
||||
self._renew_attestations, 30 * 60 * 1000,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_renew_attestation(self, group_id, user_id, content):
|
||||
"""When a remote updates an attestation
|
||||
"""
|
||||
attestation = content["attestation"]
|
||||
|
||||
if not self.is_mine_id(group_id) and not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "Neither user not group are on this server")
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
yield self.store.update_remote_attestion(group_id, user_id, attestation)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestations(self):
|
||||
"""Called periodically to check if we need to update any of our attestations
|
||||
"""
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
rows = yield self.store.get_attestations_need_renewals(
|
||||
now + UPDATE_ATTESTATION_TIME_MS
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _renew_attestation(group_id, user_id):
|
||||
attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
destination = get_domain_from_id(user_id)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
|
||||
yield self.transport_client.renew_group_attestation(
|
||||
destination, group_id, user_id,
|
||||
content={"attestation": attestation},
|
||||
)
|
||||
|
||||
yield self.store.update_attestation_renewal(
|
||||
group_id, user_id, attestation
|
||||
)
|
||||
|
||||
for row in rows:
|
||||
group_id = row["group_id"]
|
||||
user_id = row["user_id"]
|
||||
|
||||
preserve_fn(_renew_attestation)(group_id, user_id)
|
||||
803
synapse/groups/groups_server.py
Normal file
803
synapse/groups/groups_server.py
Normal file
@@ -0,0 +1,803 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import UserID, get_domain_from_id, RoomID, GroupID
|
||||
|
||||
|
||||
import logging
|
||||
import urllib
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# TODO: Allow users to "knock" or simpkly join depending on rules
|
||||
# TODO: Federation admin APIs
|
||||
# TODO: is_priveged flag to users and is_public to users and rooms
|
||||
# TODO: Audit log for admins (profile updates, membership changes, users who tried
|
||||
# to join but were rejected, etc)
|
||||
# TODO: Flairs
|
||||
|
||||
|
||||
class GroupsServerHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.room_list_handler = hs.get_room_list_handler()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
# Ensure attestations get renewed
|
||||
hs.get_groups_attestation_renewer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def check_group_is_ours(self, group_id, and_exists=False, and_is_admin=None):
|
||||
"""Check that the group is ours, and optionally if it exists.
|
||||
|
||||
If group does exist then return group.
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
and_exists (bool): whether to also check if group exists
|
||||
and_is_admin (str): whether to also check if given str is a user_id
|
||||
that is an admin
|
||||
"""
|
||||
if not self.is_mine_id(group_id):
|
||||
raise SynapseError(400, "Group not on this server")
|
||||
|
||||
group = yield self.store.get_group(group_id)
|
||||
if and_exists and not group:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
if and_is_admin:
|
||||
is_admin = yield self.store.is_user_admin_in_group(group_id, and_is_admin)
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not admin in group")
|
||||
|
||||
defer.returnValue(group)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_summary(self, group_id, requester_user_id):
|
||||
"""Get the summary for a group as seen by requester_user_id.
|
||||
|
||||
The group summary consists of the profile of the room, and a curated
|
||||
list of users and rooms. These list *may* be organised by role/category.
|
||||
The roles/categories are ordered, and so are the users/rooms within them.
|
||||
|
||||
A user/room may appear in multiple roles/categories.
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
profile = yield self.get_group_profile(group_id, requester_user_id)
|
||||
|
||||
users, roles = yield self.store.get_users_for_summary_by_role(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
# TODO: Add profiles to users
|
||||
|
||||
rooms, categories = yield self.store.get_rooms_for_summary_by_category(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
for room_entry in rooms:
|
||||
room_id = room_entry["room_id"]
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
entry = yield self.room_list_handler.generate_room_entry(
|
||||
room_id, len(joined_users),
|
||||
with_alias=False, allow_private=True,
|
||||
)
|
||||
entry = dict(entry) # so we don't change whats cached
|
||||
entry.pop("room_id", None)
|
||||
|
||||
room_entry["profile"] = entry
|
||||
|
||||
rooms.sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
for entry in users:
|
||||
user_id = entry["user_id"]
|
||||
|
||||
if not self.is_mine_id(requester_user_id):
|
||||
attestation = yield self.store.get_remote_attestation(group_id, user_id)
|
||||
if not attestation:
|
||||
continue
|
||||
|
||||
entry["attestation"] = attestation
|
||||
else:
|
||||
entry["attestation"] = self.attestations.create_attestation(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
user_profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||
entry.update(user_profile)
|
||||
|
||||
users.sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
membership_info = yield self.store.get_users_membership_info_in_group(
|
||||
group_id, requester_user_id,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"profile": profile,
|
||||
"users_section": {
|
||||
"users": users,
|
||||
"roles": roles,
|
||||
"total_user_count_estimate": 0, # TODO
|
||||
},
|
||||
"rooms_section": {
|
||||
"rooms": rooms,
|
||||
"categories": categories,
|
||||
"total_room_count_estimate": 0, # TODO
|
||||
},
|
||||
"user": membership_info,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_summary_room(self, group_id, user_id, room_id, category_id, content):
|
||||
"""Add/update a room to the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
order = content.get("order", None)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_room_to_summary(
|
||||
group_id=group_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
order=order,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_summary_room(self, group_id, user_id, room_id, category_id):
|
||||
"""Remove a room from the summary
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_room_from_summary(
|
||||
group_id=group_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_categories(self, group_id, user_id):
|
||||
"""Get all categories in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
categories = yield self.store.get_group_categories(
|
||||
group_id=group_id,
|
||||
)
|
||||
defer.returnValue({"categories": categories})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_category(self, group_id, user_id, category_id):
|
||||
"""Get a specific category in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_category(self, group_id, user_id, category_id, content):
|
||||
"""Add/Update a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
profile = content.get("profile")
|
||||
|
||||
yield self.store.upsert_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
is_public=is_public,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_category(self, group_id, user_id, category_id):
|
||||
"""Delete a group category
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_group_category(
|
||||
group_id=group_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_roles(self, group_id, user_id):
|
||||
"""Get all roles in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
roles = yield self.store.get_group_roles(
|
||||
group_id=group_id,
|
||||
)
|
||||
defer.returnValue({"roles": roles})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_role(self, group_id, user_id, role_id):
|
||||
"""Get a specific role in a group (as seen by user)
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
res = yield self.store.get_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_role(self, group_id, user_id, role_id, content):
|
||||
"""Add/update a role in a group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
profile = content.get("profile")
|
||||
|
||||
yield self.store.upsert_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
is_public=is_public,
|
||||
profile=profile,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_role(self, group_id, user_id, role_id):
|
||||
"""Remove role from group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True, and_is_admin=user_id)
|
||||
|
||||
yield self.store.remove_group_role(
|
||||
group_id=group_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_summary_user(self, group_id, requester_user_id, user_id, role_id,
|
||||
content):
|
||||
"""Add/update a users entry in the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
order = content.get("order", None)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_user_to_summary(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
order=order,
|
||||
is_public=is_public,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group_summary_user(self, group_id, requester_user_id, user_id, role_id):
|
||||
"""Remove a user from the group summary
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
yield self.store.remove_user_from_summary(
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_profile(self, group_id, requester_user_id):
|
||||
"""Get the group profile as seen by requester_user_id
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id)
|
||||
|
||||
group_description = yield self.store.get_group(group_id)
|
||||
|
||||
if group_description:
|
||||
defer.returnValue(group_description)
|
||||
else:
|
||||
raise SynapseError(404, "Unknown group")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_group_profile(self, group_id, requester_user_id, content):
|
||||
"""Update the group profile
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id,
|
||||
)
|
||||
|
||||
profile = {}
|
||||
for keyname in ("name", "avatar_url", "short_description",
|
||||
"long_description"):
|
||||
if keyname in content:
|
||||
value = content[keyname]
|
||||
if not isinstance(value, basestring):
|
||||
raise SynapseError(400, "%r value is not a string" % (keyname,))
|
||||
profile[keyname] = value
|
||||
|
||||
yield self.store.update_group_profile(group_id, profile)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get the users in group as seen by requester_user_id.
|
||||
|
||||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
user_results = yield self.store.get_users_in_group(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
chunk = []
|
||||
for user_result in user_results:
|
||||
g_user_id = user_result["user_id"]
|
||||
is_public = user_result["is_public"]
|
||||
|
||||
entry = {"user_id": g_user_id}
|
||||
|
||||
profile = yield self.profile_handler.get_profile_from_cache(g_user_id)
|
||||
entry.update(profile)
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
|
||||
if not self.is_mine_id(g_user_id):
|
||||
attestation = yield self.store.get_remote_attestation(group_id, g_user_id)
|
||||
if not attestation:
|
||||
continue
|
||||
|
||||
entry["attestation"] = attestation
|
||||
else:
|
||||
entry["attestation"] = self.attestations.create_attestation(
|
||||
group_id, g_user_id,
|
||||
)
|
||||
|
||||
chunk.append(entry)
|
||||
|
||||
# TODO: If admin add lists of users whose attestations have timed out
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": chunk,
|
||||
"total_user_count_estimate": len(user_results),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_invited_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get the users that have been invited to a group as seen by requester_user_id.
|
||||
|
||||
The ordering is arbitrary at the moment
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
if not is_user_in_group:
|
||||
raise SynapseError(403, "User not in group")
|
||||
|
||||
invited_users = yield self.store.get_invited_users_in_group(group_id)
|
||||
|
||||
user_profiles = []
|
||||
|
||||
for user_id in invited_users:
|
||||
user_profile = {
|
||||
"user_id": user_id
|
||||
}
|
||||
try:
|
||||
profile = yield self.profile_handler.get_profile_from_cache(user_id)
|
||||
user_profile.update(profile)
|
||||
except Exception as e:
|
||||
logger.warn("Error getting profile for %s: %s", user_id, e)
|
||||
user_profiles.append(user_profile)
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": user_profiles,
|
||||
"total_user_count_estimate": len(invited_users),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_rooms_in_group(self, group_id, requester_user_id):
|
||||
"""Get the rooms in group as seen by requester_user_id
|
||||
|
||||
This returns rooms in order of decreasing number of joined users
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_user_in_group = yield self.store.is_user_in_group(requester_user_id, group_id)
|
||||
|
||||
room_results = yield self.store.get_rooms_in_group(
|
||||
group_id, include_private=is_user_in_group,
|
||||
)
|
||||
|
||||
chunk = []
|
||||
for room_result in room_results:
|
||||
room_id = room_result["room_id"]
|
||||
is_public = room_result["is_public"]
|
||||
|
||||
joined_users = yield self.store.get_users_in_room(room_id)
|
||||
entry = yield self.room_list_handler.generate_room_entry(
|
||||
room_id, len(joined_users),
|
||||
with_alias=False, allow_private=True,
|
||||
)
|
||||
|
||||
if not entry:
|
||||
continue
|
||||
|
||||
if not is_public:
|
||||
entry["is_public"] = False
|
||||
|
||||
chunk.append(entry)
|
||||
|
||||
chunk.sort(key=lambda e: -e["num_joined_members"])
|
||||
|
||||
defer.returnValue({
|
||||
"chunk": chunk,
|
||||
"total_room_count_estimate": len(room_results),
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def add_room_to_group(self, group_id, requester_user_id, room_id, content):
|
||||
"""Add room to group
|
||||
"""
|
||||
RoomID.from_string(room_id) # Ensure valid room id
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_room_to_group(group_id, room_id, is_public=is_public)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_room_from_group(self, group_id, requester_user_id, room_id):
|
||||
"""Remove room from group
|
||||
"""
|
||||
yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
yield self.store.remove_room_from_group(group_id, room_id)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite_to_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Invite user to group
|
||||
"""
|
||||
|
||||
group = yield self.check_group_is_ours(
|
||||
group_id, and_exists=True, and_is_admin=requester_user_id
|
||||
)
|
||||
|
||||
# TODO: Check if user knocked
|
||||
# TODO: Check if user is already invited
|
||||
|
||||
content = {
|
||||
"profile": {
|
||||
"name": group["name"],
|
||||
"avatar_url": group["avatar_url"],
|
||||
},
|
||||
"inviter": requester_user_id,
|
||||
}
|
||||
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
res = yield groups_local.on_invite(group_id, user_id, content)
|
||||
local_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content.update({
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
|
||||
res = yield self.transport_client.invite_to_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, content
|
||||
)
|
||||
|
||||
user_profile = res.get("user_profile", {})
|
||||
yield self.store.add_remote_profile_cache(
|
||||
user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
||||
if res["state"] == "join":
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = res["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=False,
|
||||
is_public=False, # TODO
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
elif res["state"] == "invite":
|
||||
yield self.store.add_group_invite(
|
||||
group_id, user_id,
|
||||
)
|
||||
defer.returnValue({
|
||||
"state": "invite"
|
||||
})
|
||||
elif res["state"] == "reject":
|
||||
defer.returnValue({
|
||||
"state": "reject"
|
||||
})
|
||||
else:
|
||||
raise SynapseError(502, "Unknown state returned by HS")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_invite(self, group_id, user_id, content):
|
||||
"""User tries to accept an invite to the group.
|
||||
|
||||
This is different from them asking to join, and so should error if no
|
||||
invite exists (and they're not a member of the group)
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
if not self.store.is_user_invited_to_local_group(group_id, user_id):
|
||||
raise SynapseError(403, "User not invited to group")
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
else:
|
||||
remote_attestation = None
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
|
||||
is_public = _parse_visibility_from_contents(content)
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=False,
|
||||
is_public=is_public,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"state": "join",
|
||||
"attestation": local_attestation,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def knock(self, group_id, user_id, content):
|
||||
"""A user requests becoming a member of the group
|
||||
"""
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_knock(self, group_id, user_id, content):
|
||||
"""Accept a users knock to the room.
|
||||
|
||||
Errors if the user hasn't knocked, rather than inviting them.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
raise NotImplementedError()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Remove a user from the group; either a user is leaving or and admin
|
||||
kicked htem.
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(group_id, and_exists=True)
|
||||
|
||||
is_kick = False
|
||||
if requester_user_id != user_id:
|
||||
is_admin = yield self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not admin in group")
|
||||
|
||||
is_kick = True
|
||||
|
||||
yield self.store.remove_user_from_group(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
if is_kick:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
else:
|
||||
yield self.transport_client.remove_user_from_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, {}
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
group = yield self.check_group_is_ours(group_id)
|
||||
|
||||
_validate_group_id(group_id)
|
||||
|
||||
logger.info("Attempting to create group with ID: %r", group_id)
|
||||
if group:
|
||||
raise SynapseError(400, "Group already exists")
|
||||
|
||||
is_admin = yield self.auth.is_server_admin(UserID.from_string(user_id))
|
||||
if not is_admin:
|
||||
if not self.hs.config.enable_group_creation:
|
||||
raise SynapseError(
|
||||
403, "Only server admin can create group on this server",
|
||||
)
|
||||
localpart = GroupID.from_string(group_id).localpart
|
||||
if not localpart.startswith(self.hs.config.group_creation_prefix):
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Can only create groups with prefix %r on this server" % (
|
||||
self.hs.config.group_creation_prefix,
|
||||
),
|
||||
)
|
||||
|
||||
profile = content.get("profile", {})
|
||||
name = profile.get("name")
|
||||
avatar_url = profile.get("avatar_url")
|
||||
short_description = profile.get("short_description")
|
||||
long_description = profile.get("long_description")
|
||||
user_profile = content.get("user_profile", {})
|
||||
|
||||
yield self.store.create_group(
|
||||
group_id,
|
||||
user_id,
|
||||
name=name,
|
||||
avatar_url=avatar_url,
|
||||
short_description=short_description,
|
||||
long_description=long_description,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
remote_attestation = content["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
user_id=user_id,
|
||||
group_id=group_id,
|
||||
)
|
||||
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
else:
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
|
||||
yield self.store.add_user_to_group(
|
||||
group_id, user_id,
|
||||
is_admin=True,
|
||||
is_public=True, # TODO
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
)
|
||||
|
||||
if not self.hs.is_mine_id(user_id):
|
||||
yield self.store.add_remote_profile_cache(
|
||||
user_id,
|
||||
displayname=user_profile.get("displayname"),
|
||||
avatar_url=user_profile.get("avatar_url"),
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"group_id": group_id,
|
||||
})
|
||||
|
||||
|
||||
def _parse_visibility_from_contents(content):
|
||||
"""Given a content for a request parse out whether the entity should be
|
||||
public or not
|
||||
"""
|
||||
|
||||
visibility = content.get("visibility")
|
||||
if visibility:
|
||||
vis_type = visibility["type"]
|
||||
if vis_type not in ("public", "private"):
|
||||
raise SynapseError(
|
||||
400, "Synapse only supports 'public'/'private' visibility"
|
||||
)
|
||||
is_public = vis_type == "public"
|
||||
else:
|
||||
is_public = True
|
||||
|
||||
return is_public
|
||||
|
||||
|
||||
def _validate_group_id(group_id):
|
||||
"""Validates the group ID is valid for creation on this home server
|
||||
"""
|
||||
localpart = GroupID.from_string(group_id).localpart
|
||||
|
||||
if localpart.lower() != localpart:
|
||||
raise SynapseError(400, "Group ID must be lower case")
|
||||
|
||||
if urllib.quote(localpart.encode('utf-8')) != localpart:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"Group ID can only contain characters a-z, 0-9, or '_-./'",
|
||||
)
|
||||
@@ -20,7 +20,6 @@ from .room import (
|
||||
from .room_member import RoomMemberHandler
|
||||
from .message import MessageHandler
|
||||
from .federation import FederationHandler
|
||||
from .profile import ProfileHandler
|
||||
from .directory import DirectoryHandler
|
||||
from .admin import AdminHandler
|
||||
from .identity import IdentityHandler
|
||||
@@ -52,7 +51,6 @@ class Handlers(object):
|
||||
self.room_creation_handler = RoomCreationHandler(hs)
|
||||
self.room_member_handler = RoomMemberHandler(hs)
|
||||
self.federation_handler = FederationHandler(hs)
|
||||
self.profile_handler = ProfileHandler(hs)
|
||||
self.directory_handler = DirectoryHandler(hs)
|
||||
self.admin_handler = AdminHandler(hs)
|
||||
self.identity_handler = IdentityHandler(hs)
|
||||
|
||||
@@ -40,6 +40,8 @@ class DirectoryHandler(BaseHandler):
|
||||
"directory", self.on_directory_query
|
||||
)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _create_association(self, room_alias, room_id, servers=None, creator=None):
|
||||
# general association creation for both human users and app services
|
||||
@@ -73,6 +75,11 @@ class DirectoryHandler(BaseHandler):
|
||||
# association creation for human users
|
||||
# TODO(erikj): Do user auth.
|
||||
|
||||
if not self.spam_checker.user_may_create_room_alias(user_id, room_alias):
|
||||
raise SynapseError(
|
||||
403, "This user is not permitted to create this alias",
|
||||
)
|
||||
|
||||
can_create = yield self.can_modify_alias(
|
||||
room_alias,
|
||||
user_id=user_id
|
||||
@@ -327,6 +334,14 @@ class DirectoryHandler(BaseHandler):
|
||||
room_id (str)
|
||||
visibility (str): "public" or "private"
|
||||
"""
|
||||
if not self.spam_checker.user_may_publish_room(
|
||||
requester.user.to_string(), room_id
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"This user is not permitted to publish rooms to the room list"
|
||||
)
|
||||
|
||||
if requester.is_guest:
|
||||
raise AuthError(403, "Guests cannot edit the published room list")
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""Contains handlers for federation events."""
|
||||
import synapse.util.logcontext
|
||||
from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
@@ -26,10 +25,7 @@ from synapse.api.errors import (
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.events.validator import EventValidator
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util.logcontext import (
|
||||
preserve_fn, preserve_context_over_deferred
|
||||
)
|
||||
from synapse.util import unwrapFirstError, logcontext
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.async import run_on_reactor, Linearizer
|
||||
@@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
|
||||
self.action_generator = hs.get_action_generator()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.pusher_pool = hs.get_pusherpool()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
self.replication_layer.set_handler(self)
|
||||
|
||||
@@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
|
||||
self.room_queues[pdu.room_id].append((pdu, origin))
|
||||
return
|
||||
|
||||
# If we're no longer in the room just ditch the event entirely. This
|
||||
# is probably an old server that has come back and thinks we're still
|
||||
# in the room (or we've been rejoined to the room by a state reset).
|
||||
#
|
||||
# If we were never in the room then maybe our database got vaped and
|
||||
# we should check if we *are* in fact in the room. If we are then we
|
||||
# can magically rejoin the room.
|
||||
is_in_room = yield self.auth.check_host_in_room(
|
||||
pdu.room_id,
|
||||
self.server_name
|
||||
)
|
||||
if not is_in_room:
|
||||
was_in_room = yield self.store.was_host_joined(
|
||||
pdu.room_id, self.server_name,
|
||||
)
|
||||
if was_in_room:
|
||||
logger.info(
|
||||
"Ignoring PDU %s for room %s from %s as we've left the room!",
|
||||
pdu.event_id, pdu.room_id, origin,
|
||||
)
|
||||
return
|
||||
|
||||
state = None
|
||||
|
||||
auth_chain = []
|
||||
@@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
|
||||
missing_auth - failed_to_fetch
|
||||
)
|
||||
|
||||
results = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.replication_layer.get_pdu)(
|
||||
logcontext.preserve_fn(self.replication_layer.get_pdu)(
|
||||
[dest],
|
||||
event_id,
|
||||
outlier=True,
|
||||
@@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
|
||||
event_ids = list(extremities.keys())
|
||||
|
||||
logger.debug("calling resolve_state_groups in _maybe_backfill")
|
||||
states = yield preserve_context_over_deferred(defer.gatherResults([
|
||||
preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
|
||||
for e in event_ids
|
||||
]))
|
||||
states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
|
||||
room_id, [e]
|
||||
)
|
||||
for e in event_ids
|
||||
], consumeErrors=True,
|
||||
))
|
||||
states = dict(zip(event_ids, [s.state for s in states]))
|
||||
|
||||
state_map = yield self.store.get_events(
|
||||
@@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
|
||||
# lots of requests for missing prev_events which we do actually
|
||||
# have. Hence we fire off the deferred, but don't wait for it.
|
||||
|
||||
synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
|
||||
room_queue
|
||||
)
|
||||
logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
|
||||
"""
|
||||
event = pdu
|
||||
|
||||
if event.state_key is None:
|
||||
raise SynapseError(400, "The invite event did not have a state key")
|
||||
|
||||
is_blocked = yield self.store.is_room_blocked(event.room_id)
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
@@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
|
||||
if self.hs.config.block_non_admin_invites:
|
||||
raise SynapseError(403, "This server does not accept room invites")
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
event.sender, event.state_key, event.room_id,
|
||||
):
|
||||
raise SynapseError(
|
||||
403, "This user is not permitted to send invites to this server/user"
|
||||
)
|
||||
|
||||
membership = event.content.get("membership")
|
||||
if event.type != EventTypes.Member or membership != Membership.INVITE:
|
||||
raise SynapseError(400, "The event was not an m.room.member invite event")
|
||||
@@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
|
||||
if sender_domain != origin:
|
||||
raise SynapseError(400, "The invite event was not from the server sending it")
|
||||
|
||||
if event.state_key is None:
|
||||
raise SynapseError(400, "The invite event did not have a state key")
|
||||
|
||||
if not self.is_mine_id(event.state_key):
|
||||
raise SynapseError(400, "The invite event must be for this server")
|
||||
|
||||
@@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
|
||||
if not backfilled:
|
||||
# this intentionally does not yield: we don't care about the result
|
||||
# and don't need to wait for it.
|
||||
preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||
logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
|
||||
event_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
@@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
|
||||
a bunch of outliers, but not a chunk of individual events that depend
|
||||
on each other for state calculations.
|
||||
"""
|
||||
contexts = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self._prep_event)(
|
||||
logcontext.preserve_fn(self._prep_event)(
|
||||
origin,
|
||||
ev_info["event"],
|
||||
state=ev_info.get("state"),
|
||||
auth_events=ev_info.get("auth_events"),
|
||||
)
|
||||
for ev_info in event_infos
|
||||
]
|
||||
], consumeErrors=True,
|
||||
))
|
||||
|
||||
yield self.store.persist_events(
|
||||
@@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
|
||||
# Do auth conflict res.
|
||||
logger.info("Different auth: %s", different_auth)
|
||||
|
||||
different_events = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self.store.get_event)(
|
||||
different_events = yield logcontext.make_deferred_yieldable(
|
||||
defer.gatherResults([
|
||||
logcontext.preserve_fn(self.store.get_event)(
|
||||
d,
|
||||
allow_none=True,
|
||||
allow_rejected=False,
|
||||
)
|
||||
for d in different_auth
|
||||
if d in have_events and not have_events[d]
|
||||
],
|
||||
consumeErrors=True
|
||||
)).addErrback(unwrapFirstError)
|
||||
], consumeErrors=True)
|
||||
).addErrback(unwrapFirstError)
|
||||
|
||||
if different_events:
|
||||
local_view = dict(auth_events)
|
||||
|
||||
417
synapse/handlers/groups_local.py
Normal file
417
synapse/handlers/groups_local.py
Normal file
@@ -0,0 +1,417 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import get_domain_from_id
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _create_rerouter(func_name):
|
||||
"""Returns a function that looks at the group id and calls the function
|
||||
on federation or the local group server if the group is local
|
||||
"""
|
||||
def f(self, group_id, *args, **kwargs):
|
||||
if self.is_mine_id(group_id):
|
||||
return getattr(self.groups_server_handler, func_name)(
|
||||
group_id, *args, **kwargs
|
||||
)
|
||||
else:
|
||||
destination = get_domain_from_id(group_id)
|
||||
return getattr(self.transport_client, func_name)(
|
||||
destination, group_id, *args, **kwargs
|
||||
)
|
||||
return f
|
||||
|
||||
|
||||
class GroupsLocalHandler(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.room_list_handler = hs.get_room_list_handler()
|
||||
self.groups_server_handler = hs.get_groups_server_handler()
|
||||
self.transport_client = hs.get_federation_transport_client()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.keyring = hs.get_keyring()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.signing_key = hs.config.signing_key[0]
|
||||
self.server_name = hs.hostname
|
||||
self.notifier = hs.get_notifier()
|
||||
self.attestations = hs.get_groups_attestation_signing()
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
# Ensure attestations get renewed
|
||||
hs.get_groups_attestation_renewer()
|
||||
|
||||
# The following functions merely route the query to the local groups server
|
||||
# or federation depending on if the group is local or remote
|
||||
|
||||
get_group_profile = _create_rerouter("get_group_profile")
|
||||
update_group_profile = _create_rerouter("update_group_profile")
|
||||
get_rooms_in_group = _create_rerouter("get_rooms_in_group")
|
||||
|
||||
get_invited_users_in_group = _create_rerouter("get_invited_users_in_group")
|
||||
|
||||
add_room_to_group = _create_rerouter("add_room_to_group")
|
||||
remove_room_from_group = _create_rerouter("remove_room_from_group")
|
||||
|
||||
update_group_summary_room = _create_rerouter("update_group_summary_room")
|
||||
delete_group_summary_room = _create_rerouter("delete_group_summary_room")
|
||||
|
||||
update_group_category = _create_rerouter("update_group_category")
|
||||
delete_group_category = _create_rerouter("delete_group_category")
|
||||
get_group_category = _create_rerouter("get_group_category")
|
||||
get_group_categories = _create_rerouter("get_group_categories")
|
||||
|
||||
update_group_summary_user = _create_rerouter("update_group_summary_user")
|
||||
delete_group_summary_user = _create_rerouter("delete_group_summary_user")
|
||||
|
||||
update_group_role = _create_rerouter("update_group_role")
|
||||
delete_group_role = _create_rerouter("delete_group_role")
|
||||
get_group_role = _create_rerouter("get_group_role")
|
||||
get_group_roles = _create_rerouter("get_group_roles")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_group_summary(self, group_id, requester_user_id):
|
||||
"""Get the group summary for a group.
|
||||
|
||||
If the group is remote we check that the users have valid attestations.
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.get_group_summary(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
else:
|
||||
res = yield self.transport_client.get_group_summary(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
)
|
||||
|
||||
group_server_name = get_domain_from_id(group_id)
|
||||
|
||||
# Loop through the users and validate the attestations.
|
||||
chunk = res["users_section"]["users"]
|
||||
valid_users = []
|
||||
for entry in chunk:
|
||||
g_user_id = entry["user_id"]
|
||||
attestation = entry.pop("attestation", {})
|
||||
try:
|
||||
if get_domain_from_id(g_user_id) != group_server_name:
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
group_id=group_id,
|
||||
user_id=g_user_id,
|
||||
server_name=get_domain_from_id(g_user_id),
|
||||
)
|
||||
valid_users.append(entry)
|
||||
except Exception as e:
|
||||
logger.info("Failed to verify user is in group: %s", e)
|
||||
|
||||
res["users_section"]["users"] = valid_users
|
||||
|
||||
res["users_section"]["users"].sort(key=lambda e: e.get("order", 0))
|
||||
res["rooms_section"]["rooms"].sort(key=lambda e: e.get("order", 0))
|
||||
|
||||
# Add `is_publicised` flag to indicate whether the user has publicised their
|
||||
# membership of the group on their profile
|
||||
result = yield self.store.get_publicised_groups_for_user(requester_user_id)
|
||||
is_publicised = group_id in result
|
||||
|
||||
res.setdefault("user", {})["is_publicised"] = is_publicised
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_group(self, group_id, user_id, content):
|
||||
"""Create a group
|
||||
"""
|
||||
|
||||
logger.info("Asking to create group with ID: %r", group_id)
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.create_group(
|
||||
group_id, user_id, content
|
||||
)
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
content["user_profile"] = yield self.profile_handler.get_profile(user_id)
|
||||
|
||||
res = yield self.transport_client.create_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, content,
|
||||
)
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
server_name=get_domain_from_id(group_id),
|
||||
)
|
||||
|
||||
is_publicised = content.get("publicise", False)
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="join",
|
||||
is_admin=True,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
is_publicised=is_publicised,
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_in_group(self, group_id, requester_user_id):
|
||||
"""Get users in a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.get_users_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
defer.returnValue(res)
|
||||
|
||||
group_server_name = get_domain_from_id(group_id)
|
||||
|
||||
res = yield self.transport_client.get_users_in_group(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
)
|
||||
|
||||
chunk = res["chunk"]
|
||||
valid_entries = []
|
||||
for entry in chunk:
|
||||
g_user_id = entry["user_id"]
|
||||
attestation = entry.pop("attestation", {})
|
||||
try:
|
||||
if get_domain_from_id(g_user_id) != group_server_name:
|
||||
yield self.attestations.verify_attestation(
|
||||
attestation,
|
||||
group_id=group_id,
|
||||
user_id=g_user_id,
|
||||
server_name=get_domain_from_id(g_user_id),
|
||||
)
|
||||
valid_entries.append(entry)
|
||||
except Exception as e:
|
||||
logger.info("Failed to verify user is in group: %s", e)
|
||||
|
||||
res["chunk"] = valid_entries
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def join_group(self, group_id, user_id, content):
|
||||
"""Request to join a group
|
||||
"""
|
||||
raise NotImplementedError() # TODO
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_invite(self, group_id, user_id, content):
|
||||
"""Accept an invite to a group
|
||||
"""
|
||||
if self.is_mine_id(group_id):
|
||||
yield self.groups_server_handler.accept_invite(
|
||||
group_id, user_id, content
|
||||
)
|
||||
local_attestation = None
|
||||
remote_attestation = None
|
||||
else:
|
||||
local_attestation = self.attestations.create_attestation(group_id, user_id)
|
||||
content["attestation"] = local_attestation
|
||||
|
||||
res = yield self.transport_client.accept_group_invite(
|
||||
get_domain_from_id(group_id), group_id, user_id, content,
|
||||
)
|
||||
|
||||
remote_attestation = res["attestation"]
|
||||
|
||||
yield self.attestations.verify_attestation(
|
||||
remote_attestation,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
server_name=get_domain_from_id(group_id),
|
||||
)
|
||||
|
||||
# TODO: Check that the group is public and we're being added publically
|
||||
is_publicised = content.get("publicise", False)
|
||||
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="join",
|
||||
is_admin=False,
|
||||
local_attestation=local_attestation,
|
||||
remote_attestation=remote_attestation,
|
||||
is_publicised=is_publicised,
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
defer.returnValue({})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite(self, group_id, user_id, requester_user_id, config):
|
||||
"""Invite a user to a group
|
||||
"""
|
||||
content = {
|
||||
"requester_user_id": requester_user_id,
|
||||
"config": config,
|
||||
}
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.invite_to_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
else:
|
||||
res = yield self.transport_client.invite_to_group(
|
||||
get_domain_from_id(group_id), group_id, user_id, requester_user_id,
|
||||
content,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_invite(self, group_id, user_id, content):
|
||||
"""One of our users were invited to a group
|
||||
"""
|
||||
# TODO: Support auto join and rejection
|
||||
|
||||
if not self.is_mine_id(user_id):
|
||||
raise SynapseError(400, "User not on this server")
|
||||
|
||||
local_profile = {}
|
||||
if "profile" in content:
|
||||
if "name" in content["profile"]:
|
||||
local_profile["name"] = content["profile"]["name"]
|
||||
if "avatar_url" in content["profile"]:
|
||||
local_profile["avatar_url"] = content["profile"]["avatar_url"]
|
||||
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="invite",
|
||||
content={"profile": local_profile, "inviter": content["inviter"]},
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
try:
|
||||
user_profile = yield self.profile_handler.get_profile(user_id)
|
||||
except Exception as e:
|
||||
logger.warn("No profile for user %s: %s", user_id, e)
|
||||
user_profile = {}
|
||||
|
||||
defer.returnValue({"state": "invite", "user_profile": user_profile})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_user_from_group(self, group_id, user_id, requester_user_id, content):
|
||||
"""Remove a user from a group
|
||||
"""
|
||||
if user_id == requester_user_id:
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="leave",
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
# TODO: Should probably remember that we tried to leave so that we can
|
||||
# retry if the group server is currently down.
|
||||
|
||||
if self.is_mine_id(group_id):
|
||||
res = yield self.groups_server_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
else:
|
||||
content["requester_user_id"] = requester_user_id
|
||||
res = yield self.transport_client.remove_user_from_group(
|
||||
get_domain_from_id(group_id), group_id, requester_user_id,
|
||||
user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue(res)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_removed_from_group(self, group_id, user_id, content):
|
||||
"""One of our users was removed/kicked from a group
|
||||
"""
|
||||
# TODO: Check if user in group
|
||||
token = yield self.store.register_user_group_membership(
|
||||
group_id, user_id,
|
||||
membership="leave",
|
||||
)
|
||||
self.notifier.on_new_event(
|
||||
"groups_key", token, users=[user_id],
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_groups(self, user_id):
|
||||
group_ids = yield self.store.get_joined_groups(user_id)
|
||||
defer.returnValue({"groups": group_ids})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_publicised_groups_for_user(self, user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
result = yield self.store.get_publicised_groups_for_user(user_id)
|
||||
defer.returnValue({"groups": result})
|
||||
else:
|
||||
result = yield self.transport_client.get_publicised_groups_for_user(
|
||||
get_domain_from_id(user_id), user_id
|
||||
)
|
||||
# TODO: Verify attestations
|
||||
defer.returnValue(result)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def bulk_get_publicised_groups(self, user_ids, proxy=True):
|
||||
destinations = {}
|
||||
local_users = set()
|
||||
|
||||
for user_id in user_ids:
|
||||
if self.hs.is_mine_id(user_id):
|
||||
local_users.add(user_id)
|
||||
else:
|
||||
destinations.setdefault(
|
||||
get_domain_from_id(user_id), set()
|
||||
).add(user_id)
|
||||
|
||||
if not proxy and destinations:
|
||||
raise SynapseError(400, "Some user_ids are not local")
|
||||
|
||||
results = {}
|
||||
failed_results = []
|
||||
for destination, dest_user_ids in destinations.iteritems():
|
||||
try:
|
||||
r = yield self.transport_client.bulk_get_publicised_groups(
|
||||
destination, list(dest_user_ids),
|
||||
)
|
||||
results.update(r["users"])
|
||||
except Exception:
|
||||
failed_results.extend(dest_user_ids)
|
||||
|
||||
for uid in local_users:
|
||||
results[uid] = yield self.store.get_publicised_groups_for_user(
|
||||
uid
|
||||
)
|
||||
|
||||
defer.returnValue({"users": results})
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2014 - 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.events import spamcheck
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
@@ -26,6 +26,7 @@ from synapse.types import (
|
||||
from synapse.util.async import run_on_reactor, ReadWriteLock, Limiter
|
||||
from synapse.util.logcontext import preserve_fn
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.util.frozenutils import unfreeze
|
||||
from synapse.visibility import filter_events_for_client
|
||||
|
||||
from ._base import BaseHandler
|
||||
@@ -47,6 +48,7 @@ class MessageHandler(BaseHandler):
|
||||
self.state = hs.get_state_handler()
|
||||
self.clock = hs.get_clock()
|
||||
self.validator = EventValidator()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.pagination_lock = ReadWriteLock()
|
||||
|
||||
@@ -58,6 +60,8 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
self.action_generator = hs.get_action_generator()
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def purge_history(self, room_id, event_id):
|
||||
event = yield self.store.get_event(event_id)
|
||||
@@ -210,7 +214,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
if membership in {Membership.JOIN, Membership.INVITE}:
|
||||
# If event doesn't include a display name, add one.
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
profile = self.profile_handler
|
||||
content = builder.content
|
||||
|
||||
try:
|
||||
@@ -322,9 +326,12 @@ class MessageHandler(BaseHandler):
|
||||
txn_id=txn_id
|
||||
)
|
||||
|
||||
if spamcheck.check_event_for_spam(event):
|
||||
spam_error = self.spam_checker.check_event_for_spam(event)
|
||||
if spam_error:
|
||||
if not isinstance(spam_error, basestring):
|
||||
spam_error = "Spam is not permitted here"
|
||||
raise SynapseError(
|
||||
403, "Spam is not permitted here", Codes.FORBIDDEN
|
||||
403, spam_error, Codes.FORBIDDEN
|
||||
)
|
||||
|
||||
yield self.send_nonmember_event(
|
||||
@@ -418,6 +425,51 @@ class MessageHandler(BaseHandler):
|
||||
[serialize_event(c, now) for c in room_state.values()]
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_joined_members(self, requester, room_id):
|
||||
"""Get all the joined members in the room and their profile information.
|
||||
|
||||
If the user has left the room return the state events from when they left.
|
||||
|
||||
Args:
|
||||
requester(Requester): The user requesting state events.
|
||||
room_id(str): The room ID to get all state events from.
|
||||
Returns:
|
||||
A dict of user_id to profile info
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
if not requester.app_service:
|
||||
# We check AS auth after fetching the room membership, as it
|
||||
# requires us to pull out all joined members anyway.
|
||||
membership, _ = yield self._check_in_room_or_world_readable(
|
||||
room_id, user_id
|
||||
)
|
||||
if membership != Membership.JOIN:
|
||||
raise NotImplementedError(
|
||||
"Getting joined members after leaving is not implemented"
|
||||
)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
|
||||
# If this is an AS, double check that they are allowed to see the members.
|
||||
# This can either be because the AS user is in the room or becuase there
|
||||
# is a user in the room that the AS is "interested in"
|
||||
if requester.app_service and user_id not in users_with_profile:
|
||||
for uid in users_with_profile:
|
||||
if requester.app_service.is_interested_in_user(uid):
|
||||
break
|
||||
else:
|
||||
# Loop fell through, AS has no interested users in room
|
||||
raise AuthError(403, "Appservice not in room")
|
||||
|
||||
defer.returnValue({
|
||||
user_id: {
|
||||
"avatar_url": profile.avatar_url,
|
||||
"display_name": profile.display_name,
|
||||
}
|
||||
for user_id, profile in users_with_profile.iteritems()
|
||||
})
|
||||
|
||||
@measure_func("_create_new_client_event")
|
||||
@defer.inlineCallbacks
|
||||
def _create_new_client_event(self, builder, requester=None, prev_event_ids=None):
|
||||
@@ -509,7 +561,7 @@ class MessageHandler(BaseHandler):
|
||||
|
||||
# Ensure that we can round trip before trying to persist in db
|
||||
try:
|
||||
dump = ujson.dumps(event.content)
|
||||
dump = ujson.dumps(unfreeze(event.content))
|
||||
ujson.loads(dump)
|
||||
except:
|
||||
logger.exception("Failed to encode content: %r", event.content)
|
||||
|
||||
@@ -19,14 +19,15 @@ from twisted.internet import defer
|
||||
|
||||
import synapse.types
|
||||
from synapse.api.errors import SynapseError, AuthError, CodeMessageException
|
||||
from synapse.types import UserID
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from ._base import BaseHandler
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ProfileHandler(BaseHandler):
|
||||
PROFILE_UPDATE_MS = 60 * 1000
|
||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileHandler, self).__init__(hs)
|
||||
@@ -36,6 +37,63 @@ class ProfileHandler(BaseHandler):
|
||||
"profile", self.on_profile_query
|
||||
)
|
||||
|
||||
self.clock.looping_call(self._update_remote_profile_cache, self.PROFILE_UPDATE_MS)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profile(self, user_id):
|
||||
target_user = UserID.from_string(user_id)
|
||||
if self.hs.is_mine(target_user):
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
})
|
||||
else:
|
||||
try:
|
||||
result = yield self.federation.make_query(
|
||||
destination=target_user.domain,
|
||||
query_type="profile",
|
||||
args={
|
||||
"user_id": user_id,
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
defer.returnValue(result)
|
||||
except CodeMessageException as e:
|
||||
if e.code != 404:
|
||||
logger.exception("Failed to get displayname")
|
||||
|
||||
raise
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_profile_from_cache(self, user_id):
|
||||
"""Get the profile information from our local cache. If the user is
|
||||
ours then the profile information will always be corect. Otherwise,
|
||||
it may be out of date/missing.
|
||||
"""
|
||||
target_user = UserID.from_string(user_id)
|
||||
if self.hs.is_mine(target_user):
|
||||
displayname = yield self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = yield self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
})
|
||||
else:
|
||||
profile = yield self.store.get_from_remote_profile_cache(user_id)
|
||||
defer.returnValue(profile or {})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_displayname(self, target_user):
|
||||
if self.hs.is_mine(target_user):
|
||||
@@ -182,3 +240,44 @@ class ProfileHandler(BaseHandler):
|
||||
"Failed to update join event for room %s - %s",
|
||||
room_id, str(e.message)
|
||||
)
|
||||
|
||||
def _update_remote_profile_cache(self):
|
||||
"""Called periodically to check profiles of remote users we haven't
|
||||
checked in a while.
|
||||
"""
|
||||
entries = yield self.store.get_remote_profile_cache_entries_that_expire(
|
||||
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||
)
|
||||
|
||||
for user_id, displayname, avatar_url in entries:
|
||||
is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
|
||||
user_id,
|
||||
)
|
||||
if not is_subscribed:
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
profile = yield self.federation.make_query(
|
||||
destination=get_domain_from_id(user_id),
|
||||
query_type="profile",
|
||||
args={
|
||||
"user_id": user_id,
|
||||
},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except:
|
||||
logger.exception("Failed to get avatar_url")
|
||||
|
||||
yield self.store.update_remote_profile_cache(
|
||||
user_id, displayname, avatar_url
|
||||
)
|
||||
continue
|
||||
|
||||
new_name = profile.get("displayname")
|
||||
new_avatar = profile.get("avatar_url")
|
||||
|
||||
# We always hit update to update the last_check timestamp
|
||||
yield self.store.update_remote_profile_cache(
|
||||
user_id, new_name, new_avatar
|
||||
)
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from synapse.util import logcontext
|
||||
|
||||
from ._base import BaseHandler
|
||||
|
||||
@@ -59,6 +60,8 @@ class ReceiptsHandler(BaseHandler):
|
||||
is_new = yield self._handle_new_receipts([receipt])
|
||||
|
||||
if is_new:
|
||||
# fire off a process in the background to send the receipt to
|
||||
# remote servers
|
||||
self._push_remotes([receipt])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -126,6 +129,7 @@ class ReceiptsHandler(BaseHandler):
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@logcontext.preserve_fn # caller should not yield on this
|
||||
@defer.inlineCallbacks
|
||||
def _push_remotes(self, receipts):
|
||||
"""Given a list of receipts, works out which remote servers should be
|
||||
|
||||
@@ -36,6 +36,7 @@ class RegistrationHandler(BaseHandler):
|
||||
super(RegistrationHandler, self).__init__(hs)
|
||||
|
||||
self.auth = hs.get_auth()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
self.captcha_client = CaptchaServerHttpClient(hs)
|
||||
|
||||
self._next_generated_user_id = None
|
||||
@@ -423,8 +424,7 @@ class RegistrationHandler(BaseHandler):
|
||||
|
||||
if displayname is not None:
|
||||
logger.info("setting user display name: %s -> %s", user_id, displayname)
|
||||
profile_handler = self.hs.get_handlers().profile_handler
|
||||
yield profile_handler.set_displayname(
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, displayname, by_admin=True,
|
||||
)
|
||||
|
||||
|
||||
@@ -60,6 +60,11 @@ class RoomCreationHandler(BaseHandler):
|
||||
},
|
||||
}
|
||||
|
||||
def __init__(self, hs):
|
||||
super(RoomCreationHandler, self).__init__(hs)
|
||||
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, requester, config, ratelimit=True):
|
||||
""" Creates a new room.
|
||||
@@ -75,6 +80,9 @@ class RoomCreationHandler(BaseHandler):
|
||||
"""
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
if not self.spam_checker.user_may_create_room(user_id):
|
||||
raise SynapseError(403, "You are not permitted to create rooms")
|
||||
|
||||
if ratelimit:
|
||||
yield self.ratelimit(requester)
|
||||
|
||||
|
||||
@@ -276,13 +276,14 @@ class RoomListHandler(BaseHandler):
|
||||
# We've already got enough, so lets just drop it.
|
||||
return
|
||||
|
||||
result = yield self._generate_room_entry(room_id, num_joined_users)
|
||||
result = yield self.generate_room_entry(room_id, num_joined_users)
|
||||
|
||||
if result and _matches_room_entry(result, search_filter):
|
||||
chunk.append(result)
|
||||
|
||||
@cachedInlineCallbacks(num_args=1, cache_context=True)
|
||||
def _generate_room_entry(self, room_id, num_joined_users, cache_context):
|
||||
def generate_room_entry(self, room_id, num_joined_users, cache_context,
|
||||
with_alias=True, allow_private=False):
|
||||
"""Returns the entry for a room
|
||||
"""
|
||||
result = {
|
||||
@@ -316,14 +317,15 @@ class RoomListHandler(BaseHandler):
|
||||
join_rules_event = current_state.get((EventTypes.JoinRules, ""))
|
||||
if join_rules_event:
|
||||
join_rule = join_rules_event.content.get("join_rule", None)
|
||||
if join_rule and join_rule != JoinRules.PUBLIC:
|
||||
if not allow_private and join_rule and join_rule != JoinRules.PUBLIC:
|
||||
defer.returnValue(None)
|
||||
|
||||
aliases = yield self.store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
if with_alias:
|
||||
aliases = yield self.store.get_aliases_for_room(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
name_event = yield current_state.get((EventTypes.Name, ""))
|
||||
if name_event:
|
||||
|
||||
@@ -45,9 +45,12 @@ class RoomMemberHandler(BaseHandler):
|
||||
def __init__(self, hs):
|
||||
super(RoomMemberHandler, self).__init__(hs)
|
||||
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
self.member_linearizer = Linearizer(name="member")
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
|
||||
self.distributor = hs.get_distributor()
|
||||
self.distributor.declare("user_joined_room")
|
||||
@@ -210,12 +213,26 @@ class RoomMemberHandler(BaseHandler):
|
||||
if is_blocked:
|
||||
raise SynapseError(403, "This room has been blocked on this server")
|
||||
|
||||
if (effective_membership_state == "invite" and
|
||||
self.hs.config.block_non_admin_invites):
|
||||
if effective_membership_state == "invite":
|
||||
block_invite = False
|
||||
is_requester_admin = yield self.auth.is_server_admin(
|
||||
requester.user,
|
||||
)
|
||||
if not is_requester_admin:
|
||||
if self.hs.config.block_non_admin_invites:
|
||||
logger.info(
|
||||
"Blocking invite: user is not admin and non-admin "
|
||||
"invites disabled"
|
||||
)
|
||||
block_invite = True
|
||||
|
||||
if not self.spam_checker.user_may_invite(
|
||||
requester.user.to_string(), target.to_string(), room_id,
|
||||
):
|
||||
logger.info("Blocking invite due to spam checker")
|
||||
block_invite = True
|
||||
|
||||
if block_invite:
|
||||
raise SynapseError(
|
||||
403, "Invites have been disabled on this server",
|
||||
)
|
||||
@@ -267,7 +284,7 @@ class RoomMemberHandler(BaseHandler):
|
||||
|
||||
content["membership"] = Membership.JOIN
|
||||
|
||||
profile = self.hs.get_handlers().profile_handler
|
||||
profile = self.profile_handler
|
||||
if not content_specified:
|
||||
content["displayname"] = yield profile.get_displayname(target)
|
||||
content["avatar_url"] = yield profile.get_avatar_url(target)
|
||||
|
||||
@@ -108,6 +108,17 @@ class InvitedSyncResult(collections.namedtuple("InvitedSyncResult", [
|
||||
return True
|
||||
|
||||
|
||||
class GroupsSyncResult(collections.namedtuple("GroupsSyncResult", [
|
||||
"join",
|
||||
"invite",
|
||||
"leave",
|
||||
])):
|
||||
__slots__ = []
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self.join or self.invite or self.leave)
|
||||
|
||||
|
||||
class DeviceLists(collections.namedtuple("DeviceLists", [
|
||||
"changed", # list of user_ids whose devices may have changed
|
||||
"left", # list of user_ids whose devices we no longer track
|
||||
@@ -129,6 +140,7 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||
"device_lists", # List of user_ids whose devices have chanegd
|
||||
"device_one_time_keys_count", # Dict of algorithm to count for one time keys
|
||||
# for this device
|
||||
"groups",
|
||||
])):
|
||||
__slots__ = []
|
||||
|
||||
@@ -144,7 +156,8 @@ class SyncResult(collections.namedtuple("SyncResult", [
|
||||
self.archived or
|
||||
self.account_data or
|
||||
self.to_device or
|
||||
self.device_lists
|
||||
self.device_lists or
|
||||
self.groups
|
||||
)
|
||||
|
||||
|
||||
@@ -595,6 +608,8 @@ class SyncHandler(object):
|
||||
user_id, device_id
|
||||
)
|
||||
|
||||
yield self._generate_sync_entry_for_groups(sync_result_builder)
|
||||
|
||||
defer.returnValue(SyncResult(
|
||||
presence=sync_result_builder.presence,
|
||||
account_data=sync_result_builder.account_data,
|
||||
@@ -603,10 +618,57 @@ class SyncHandler(object):
|
||||
archived=sync_result_builder.archived,
|
||||
to_device=sync_result_builder.to_device,
|
||||
device_lists=device_lists,
|
||||
groups=sync_result_builder.groups,
|
||||
device_one_time_keys_count=one_time_key_counts,
|
||||
next_batch=sync_result_builder.now_token,
|
||||
))
|
||||
|
||||
@measure_func("_generate_sync_entry_for_groups")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_groups(self, sync_result_builder):
|
||||
user_id = sync_result_builder.sync_config.user.to_string()
|
||||
since_token = sync_result_builder.since_token
|
||||
now_token = sync_result_builder.now_token
|
||||
|
||||
if since_token and since_token.groups_key:
|
||||
results = yield self.store.get_groups_changes_for_user(
|
||||
user_id, since_token.groups_key, now_token.groups_key,
|
||||
)
|
||||
else:
|
||||
results = yield self.store.get_all_groups_for_user(
|
||||
user_id, now_token.groups_key,
|
||||
)
|
||||
|
||||
invited = {}
|
||||
joined = {}
|
||||
left = {}
|
||||
for result in results:
|
||||
membership = result["membership"]
|
||||
group_id = result["group_id"]
|
||||
gtype = result["type"]
|
||||
content = result["content"]
|
||||
|
||||
if membership == "join":
|
||||
if gtype == "membership":
|
||||
# TODO: Add profile
|
||||
content.pop("membership", None)
|
||||
joined[group_id] = content["content"]
|
||||
else:
|
||||
joined.setdefault(group_id, {})[gtype] = content
|
||||
elif membership == "invite":
|
||||
if gtype == "membership":
|
||||
content.pop("membership", None)
|
||||
invited[group_id] = content["content"]
|
||||
else:
|
||||
if gtype == "membership":
|
||||
left[group_id] = content["content"]
|
||||
|
||||
sync_result_builder.groups = GroupsSyncResult(
|
||||
join=joined,
|
||||
invite=invited,
|
||||
leave=left,
|
||||
)
|
||||
|
||||
@measure_func("_generate_sync_entry_for_device_list")
|
||||
@defer.inlineCallbacks
|
||||
def _generate_sync_entry_for_device_list(self, sync_result_builder,
|
||||
@@ -1368,6 +1430,7 @@ class SyncResultBuilder(object):
|
||||
self.invited = []
|
||||
self.archived = []
|
||||
self.device = []
|
||||
self.groups = None
|
||||
self.to_device = []
|
||||
|
||||
|
||||
|
||||
@@ -354,16 +354,28 @@ def _get_hosts_for_srv_record(dns_client, host):
|
||||
|
||||
return res[0]
|
||||
|
||||
def eb(res):
|
||||
res.trap(DNSNameError)
|
||||
return []
|
||||
def eb(res, record_type):
|
||||
if res.check(DNSNameError):
|
||||
return []
|
||||
logger.warn("Error looking up %s for %s: %s",
|
||||
record_type, host, res, res.value)
|
||||
return res
|
||||
|
||||
# no logcontexts here, so we can safely fire these off and gatherResults
|
||||
d1 = dns_client.lookupAddress(host).addCallbacks(cb, eb)
|
||||
d2 = dns_client.lookupIPV6Address(host).addCallbacks(cb, eb)
|
||||
results = yield defer.gatherResults([d1, d2], consumeErrors=True)
|
||||
results = yield defer.DeferredList(
|
||||
[d1, d2], consumeErrors=True)
|
||||
|
||||
# if all of the lookups failed, raise an exception rather than blowing out
|
||||
# the cache with an empty result.
|
||||
if results and all(s == defer.FAILURE for (s, _) in results):
|
||||
defer.returnValue(results[0][1])
|
||||
|
||||
for (success, result) in results:
|
||||
if success == defer.FAILURE:
|
||||
continue
|
||||
|
||||
for result in results:
|
||||
for answer in result:
|
||||
if not answer.payload:
|
||||
continue
|
||||
|
||||
@@ -204,18 +204,15 @@ class MatrixFederationHttpClient(object):
|
||||
raise
|
||||
|
||||
logger.warn(
|
||||
"{%s} Sending request failed to %s: %s %s: %s - %s",
|
||||
"{%s} Sending request failed to %s: %s %s: %s",
|
||||
txn_id,
|
||||
destination,
|
||||
method,
|
||||
url_bytes,
|
||||
type(e).__name__,
|
||||
_flatten_response_never_received(e),
|
||||
)
|
||||
|
||||
log_result = "%s - %s" % (
|
||||
type(e).__name__, _flatten_response_never_received(e),
|
||||
)
|
||||
log_result = _flatten_response_never_received(e)
|
||||
|
||||
if retries_left and not timeout:
|
||||
if long_retries:
|
||||
@@ -347,7 +344,7 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def post_json(self, destination, path, data={}, long_retries=False,
|
||||
timeout=None, ignore_backoff=False):
|
||||
timeout=None, ignore_backoff=False, args={}):
|
||||
""" Sends the specifed json data using POST
|
||||
|
||||
Args:
|
||||
@@ -383,6 +380,7 @@ class MatrixFederationHttpClient(object):
|
||||
destination,
|
||||
"POST",
|
||||
path,
|
||||
query_bytes=encode_query_args(args),
|
||||
body_callback=body_callback,
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
@@ -427,13 +425,6 @@ class MatrixFederationHttpClient(object):
|
||||
"""
|
||||
logger.debug("get_json args: %s", args)
|
||||
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
logger.debug("Query bytes: %s Retry DNS: %s", args, retry_on_dns_fail)
|
||||
|
||||
def body_callback(method, url_bytes, headers_dict):
|
||||
@@ -444,7 +435,7 @@ class MatrixFederationHttpClient(object):
|
||||
destination,
|
||||
"GET",
|
||||
path,
|
||||
query_bytes=query_bytes,
|
||||
query_bytes=encode_query_args(args),
|
||||
body_callback=body_callback,
|
||||
retry_on_dns_fail=retry_on_dns_fail,
|
||||
timeout=timeout,
|
||||
@@ -460,6 +451,52 @@ class MatrixFederationHttpClient(object):
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_json(self, destination, path, long_retries=False,
|
||||
timeout=None, ignore_backoff=False, args={}):
|
||||
"""Send a DELETE request to the remote expecting some json response
|
||||
|
||||
Args:
|
||||
destination (str): The remote server to send the HTTP request
|
||||
to.
|
||||
path (str): The HTTP path.
|
||||
long_retries (bool): A boolean that indicates whether we should
|
||||
retry for a short or long time.
|
||||
timeout(int): How long to try (in ms) the destination for before
|
||||
giving up. None indicates no timeout.
|
||||
ignore_backoff (bool): true to ignore the historical backoff data and
|
||||
try the request anyway.
|
||||
Returns:
|
||||
Deferred: Succeeds when we get a 2xx HTTP response. The result
|
||||
will be the decoded JSON body.
|
||||
|
||||
Fails with ``HTTPRequestException`` if we get an HTTP response
|
||||
code >= 300.
|
||||
|
||||
Fails with ``NotRetryingDestination`` if we are not yet ready
|
||||
to retry this server.
|
||||
"""
|
||||
|
||||
response = yield self._request(
|
||||
destination,
|
||||
"DELETE",
|
||||
path,
|
||||
query_bytes=encode_query_args(args),
|
||||
headers_dict={"Content-Type": ["application/json"]},
|
||||
long_retries=long_retries,
|
||||
timeout=timeout,
|
||||
ignore_backoff=ignore_backoff,
|
||||
)
|
||||
|
||||
if 200 <= response.code < 300:
|
||||
# We need to update the transactions table to say it was sent?
|
||||
check_content_type_is_json(response.headers)
|
||||
|
||||
with logcontext.PreserveLoggingContext():
|
||||
body = yield readBody(response)
|
||||
|
||||
defer.returnValue(json.loads(body))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_file(self, destination, path, output_stream, args={},
|
||||
retry_on_dns_fail=True, max_size=None,
|
||||
@@ -578,12 +615,14 @@ class _JsonProducer(object):
|
||||
|
||||
def _flatten_response_never_received(e):
|
||||
if hasattr(e, "reasons"):
|
||||
return ", ".join(
|
||||
reasons = ", ".join(
|
||||
_flatten_response_never_received(f.value)
|
||||
for f in e.reasons
|
||||
)
|
||||
|
||||
return "%s:[%s]" % (type(e).__name__, reasons)
|
||||
else:
|
||||
return "%s: %s" % (type(e).__name__, e.message,)
|
||||
return repr(e)
|
||||
|
||||
|
||||
def check_content_type_is_json(headers):
|
||||
@@ -610,3 +649,15 @@ def check_content_type_is_json(headers):
|
||||
raise RuntimeError(
|
||||
"Content-Type not application/json: was '%s'" % c_type
|
||||
)
|
||||
|
||||
|
||||
def encode_query_args(args):
|
||||
encoded_args = {}
|
||||
for k, vs in args.items():
|
||||
if isinstance(vs, basestring):
|
||||
vs = [vs]
|
||||
encoded_args[k] = [v.encode("UTF-8") for v in vs]
|
||||
|
||||
query_bytes = urllib.urlencode(encoded_args, True)
|
||||
|
||||
return query_bytes
|
||||
|
||||
@@ -145,7 +145,9 @@ def wrap_request_handler(request_handler, include_metrics=False):
|
||||
"error": "Internal server error",
|
||||
"errcode": Codes.UNKNOWN,
|
||||
},
|
||||
send_cors=True
|
||||
send_cors=True,
|
||||
pretty_print=_request_user_agent_is_curl(request),
|
||||
version_string=self.version_string,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -238,6 +239,28 @@ BASE_APPEND_OVERRIDE_RULES = [
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
'rule_id': 'global/override/.m.rule.roomnotif',
|
||||
'conditions': [
|
||||
{
|
||||
'kind': 'event_match',
|
||||
'key': 'content.body',
|
||||
'pattern': '@room',
|
||||
'_id': '_roomnotif_content',
|
||||
},
|
||||
{
|
||||
'kind': 'sender_notification_permission',
|
||||
'key': 'room',
|
||||
'_id': '_roomnotif_pl',
|
||||
},
|
||||
],
|
||||
'actions': [
|
||||
'notify', {
|
||||
'set_tweak': 'highlight',
|
||||
'value': True,
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -19,11 +20,13 @@ from twisted.internet import defer
|
||||
|
||||
from .push_rule_evaluator import PushRuleEvaluatorForEvent
|
||||
|
||||
from synapse.event_auth import get_user_power_level
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.metrics import get_metrics_for
|
||||
from synapse.util.caches import metrics as cache_metrics
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.state import POWER_KEY
|
||||
|
||||
from collections import namedtuple
|
||||
|
||||
@@ -59,6 +62,7 @@ class BulkPushRuleEvaluator(object):
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.auth = hs.get_auth()
|
||||
|
||||
self.room_push_rule_cache_metrics = cache_metrics.register_cache(
|
||||
"cache",
|
||||
@@ -108,6 +112,29 @@ class BulkPushRuleEvaluator(object):
|
||||
self.room_push_rule_cache_metrics,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _get_power_levels_and_sender_level(self, event, context):
|
||||
pl_event_id = context.prev_state_ids.get(POWER_KEY)
|
||||
if pl_event_id:
|
||||
# fastpath: if there's a power level event, that's all we need, and
|
||||
# not having a power level event is an extreme edge case
|
||||
pl_event = yield self.store.get_event(pl_event_id)
|
||||
auth_events = {POWER_KEY: pl_event}
|
||||
else:
|
||||
auth_events_ids = yield self.auth.compute_auth_events(
|
||||
event, context.prev_state_ids, for_verification=False,
|
||||
)
|
||||
auth_events = yield self.store.get_events(auth_events_ids)
|
||||
auth_events = {
|
||||
(e.type, e.state_key): e for e in auth_events.itervalues()
|
||||
}
|
||||
|
||||
sender_level = get_user_power_level(event.sender, auth_events)
|
||||
|
||||
pl_event = auth_events.get(POWER_KEY)
|
||||
|
||||
defer.returnValue((pl_event.content if pl_event else {}, sender_level))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def action_for_event_by_user(self, event, context):
|
||||
"""Given an event and context, evaluate the push rules and return
|
||||
@@ -123,7 +150,13 @@ class BulkPushRuleEvaluator(object):
|
||||
event, context
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(event, len(room_members))
|
||||
(power_levels, sender_power_level) = (
|
||||
yield self._get_power_levels_and_sender_level(event, context)
|
||||
)
|
||||
|
||||
evaluator = PushRuleEvaluatorForEvent(
|
||||
event, len(room_members), sender_power_level, power_levels,
|
||||
)
|
||||
|
||||
condition_cache = {}
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -29,6 +30,21 @@ INEQUALITY_EXPR = re.compile("^([=<>]*)([0-9]*)$")
|
||||
|
||||
|
||||
def _room_member_count(ev, condition, room_member_count):
|
||||
return _test_ineq_condition(condition, room_member_count)
|
||||
|
||||
|
||||
def _sender_notification_permission(ev, condition, sender_power_level, power_levels):
|
||||
notif_level_key = condition.get('key')
|
||||
if notif_level_key is None:
|
||||
return False
|
||||
|
||||
notif_levels = power_levels.get('notifications', {})
|
||||
room_notif_level = notif_levels.get(notif_level_key, 50)
|
||||
|
||||
return sender_power_level >= room_notif_level
|
||||
|
||||
|
||||
def _test_ineq_condition(condition, number):
|
||||
if 'is' not in condition:
|
||||
return False
|
||||
m = INEQUALITY_EXPR.match(condition['is'])
|
||||
@@ -41,15 +57,15 @@ def _room_member_count(ev, condition, room_member_count):
|
||||
rhs = int(rhs)
|
||||
|
||||
if ineq == '' or ineq == '==':
|
||||
return room_member_count == rhs
|
||||
return number == rhs
|
||||
elif ineq == '<':
|
||||
return room_member_count < rhs
|
||||
return number < rhs
|
||||
elif ineq == '>':
|
||||
return room_member_count > rhs
|
||||
return number > rhs
|
||||
elif ineq == '>=':
|
||||
return room_member_count >= rhs
|
||||
return number >= rhs
|
||||
elif ineq == '<=':
|
||||
return room_member_count <= rhs
|
||||
return number <= rhs
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -65,9 +81,11 @@ def tweaks_for_actions(actions):
|
||||
|
||||
|
||||
class PushRuleEvaluatorForEvent(object):
|
||||
def __init__(self, event, room_member_count):
|
||||
def __init__(self, event, room_member_count, sender_power_level, power_levels):
|
||||
self._event = event
|
||||
self._room_member_count = room_member_count
|
||||
self._sender_power_level = sender_power_level
|
||||
self._power_levels = power_levels
|
||||
|
||||
# Maps strings of e.g. 'content.body' -> event["content"]["body"]
|
||||
self._value_cache = _flatten_dict(event)
|
||||
@@ -81,6 +99,10 @@ class PushRuleEvaluatorForEvent(object):
|
||||
return _room_member_count(
|
||||
self._event, condition, self._room_member_count
|
||||
)
|
||||
elif condition['kind'] == 'sender_notification_permission':
|
||||
return _sender_notification_permission(
|
||||
self._event, condition, self._sender_power_level, self._power_levels,
|
||||
)
|
||||
else:
|
||||
return True
|
||||
|
||||
@@ -183,7 +205,7 @@ def _glob_to_re(glob, word_boundary):
|
||||
r,
|
||||
)
|
||||
if word_boundary:
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _re_word_boundary(r)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
@@ -192,7 +214,7 @@ def _glob_to_re(glob, word_boundary):
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
elif word_boundary:
|
||||
r = re.escape(glob)
|
||||
r = r"\b%s\b" % (r,)
|
||||
r = _re_word_boundary(r)
|
||||
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
else:
|
||||
@@ -200,6 +222,18 @@ def _glob_to_re(glob, word_boundary):
|
||||
return re.compile(r, flags=re.IGNORECASE)
|
||||
|
||||
|
||||
def _re_word_boundary(r):
|
||||
"""
|
||||
Adds word boundary characters to the start and end of an
|
||||
expression to require that the match occur as a whole word,
|
||||
but do so respecting the fact that strings starting or ending
|
||||
with non-word characters will change word boundaries.
|
||||
"""
|
||||
# we can't use \b as it chokes on unicode. however \W seems to be okay
|
||||
# as shorthand for [^0-9A-Za-z_].
|
||||
return r"(^|\W)%s(\W|$)" % (r,)
|
||||
|
||||
|
||||
def _flatten_dict(d, prefix=[], result=None):
|
||||
if result is None:
|
||||
result = {}
|
||||
|
||||
@@ -40,7 +40,6 @@ REQUIREMENTS = {
|
||||
"pymacaroons-pynacl": ["pymacaroons"],
|
||||
"msgpack-python>=0.3.0": ["msgpack"],
|
||||
"phonenumbers>=8.2.0": ["phonenumbers"],
|
||||
"affinity": ["affinity"],
|
||||
}
|
||||
CONDITIONAL_REQUIREMENTS = {
|
||||
"web_client": {
|
||||
@@ -59,6 +58,9 @@ CONDITIONAL_REQUIREMENTS = {
|
||||
"psutil": {
|
||||
"psutil>=2.0.0": ["psutil>=2.0.0"],
|
||||
},
|
||||
"affinity": {
|
||||
"affinity": ["affinity"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
||||
54
synapse/replication/slave/storage/groups.py
Normal file
54
synapse/replication/slave/storage/groups.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from ._base import BaseSlavedStore
|
||||
from ._slaved_id_tracker import SlavedIdTracker
|
||||
from synapse.storage import DataStore
|
||||
from synapse.util.caches.stream_change_cache import StreamChangeCache
|
||||
|
||||
|
||||
class SlavedGroupServerStore(BaseSlavedStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
super(SlavedGroupServerStore, self).__init__(db_conn, hs)
|
||||
|
||||
self.hs = hs
|
||||
|
||||
self._group_updates_id_gen = SlavedIdTracker(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", self._group_updates_id_gen.get_current_token(),
|
||||
)
|
||||
|
||||
get_groups_changes_for_user = DataStore.get_groups_changes_for_user.__func__
|
||||
get_group_stream_token = DataStore.get_group_stream_token.__func__
|
||||
get_all_groups_for_user = DataStore.get_all_groups_for_user.__func__
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(SlavedGroupServerStore, self).stream_positions()
|
||||
result["groups"] = self._group_updates_id_gen.get_current_token()
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "groups":
|
||||
self._group_updates_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self._group_updates_stream_cache.entity_has_changed(
|
||||
row.user_id, token
|
||||
)
|
||||
|
||||
return super(SlavedGroupServerStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
)
|
||||
@@ -160,7 +160,11 @@ class ReplicationStreamer(object):
|
||||
"Getting stream: %s: %s -> %s",
|
||||
stream.NAME, stream.last_token, stream.upto_token
|
||||
)
|
||||
updates, current_token = yield stream.get_updates()
|
||||
try:
|
||||
updates, current_token = yield stream.get_updates()
|
||||
except:
|
||||
logger.info("Failed to handle stream %s", stream.NAME)
|
||||
raise
|
||||
|
||||
logger.debug(
|
||||
"Sending %d updates to %d connections",
|
||||
|
||||
@@ -118,6 +118,12 @@ CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
|
||||
"state_key", # str
|
||||
"event_id", # str, optional
|
||||
))
|
||||
GroupsStreamRow = namedtuple("GroupsStreamRow", (
|
||||
"group_id", # str
|
||||
"user_id", # str
|
||||
"type", # str
|
||||
"content", # dict
|
||||
))
|
||||
|
||||
|
||||
class Stream(object):
|
||||
@@ -464,6 +470,19 @@ class CurrentStateDeltaStream(Stream):
|
||||
super(CurrentStateDeltaStream, self).__init__(hs)
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
NAME = "groups"
|
||||
ROW_TYPE = GroupsStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_group_stream_token
|
||||
self.update_function = store.get_all_groups_changes
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
|
||||
STREAMS_MAP = {
|
||||
stream.NAME: stream
|
||||
for stream in (
|
||||
@@ -482,5 +501,6 @@ STREAMS_MAP = {
|
||||
TagAccountDataStream,
|
||||
AccountDataStream,
|
||||
CurrentStateDeltaStream,
|
||||
GroupServerStream,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -52,6 +52,7 @@ from synapse.rest.client.v2_alpha import (
|
||||
thirdparty,
|
||||
sendtodevice,
|
||||
user_directory,
|
||||
groups,
|
||||
)
|
||||
|
||||
from synapse.http.server import JsonResource
|
||||
@@ -102,3 +103,4 @@ class ClientRestResource(JsonResource):
|
||||
thirdparty.register_servlets(hs, client_resource)
|
||||
sendtodevice.register_servlets(hs, client_resource)
|
||||
user_directory.register_servlets(hs, client_resource)
|
||||
groups.register_servlets(hs, client_resource)
|
||||
|
||||
@@ -26,13 +26,13 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileDisplaynameRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
displayname = yield self.profile_handler.get_displayname(
|
||||
user,
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ class ProfileDisplaynameRestServlet(ClientV1RestServlet):
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_displayname(
|
||||
yield self.profile_handler.set_displayname(
|
||||
user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
@@ -69,13 +69,13 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileAvatarURLRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||
user,
|
||||
)
|
||||
|
||||
@@ -97,7 +97,7 @@ class ProfileAvatarURLRestServlet(ClientV1RestServlet):
|
||||
except:
|
||||
defer.returnValue((400, "Unable to parse name"))
|
||||
|
||||
yield self.handlers.profile_handler.set_avatar_url(
|
||||
yield self.profile_handler.set_avatar_url(
|
||||
user, requester, new_name, is_admin)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
@@ -111,16 +111,16 @@ class ProfileRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ProfileRestServlet, self).__init__(hs)
|
||||
self.handlers = hs.get_handlers()
|
||||
self.profile_handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
displayname = yield self.handlers.profile_handler.get_displayname(
|
||||
displayname = yield self.profile_handler.get_displayname(
|
||||
user,
|
||||
)
|
||||
avatar_url = yield self.handlers.profile_handler.get_avatar_url(
|
||||
avatar_url = yield self.profile_handler.get_avatar_url(
|
||||
user,
|
||||
)
|
||||
|
||||
|
||||
@@ -398,22 +398,18 @@ class JoinedRoomMemberListRestServlet(ClientV1RestServlet):
|
||||
|
||||
def __init__(self, hs):
|
||||
super(JoinedRoomMemberListRestServlet, self).__init__(hs)
|
||||
self.state = hs.get_state_handler()
|
||||
self.message_handler = hs.get_handlers().message_handler
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, room_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
users_with_profile = yield self.message_handler.get_joined_members(
|
||||
requester, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {
|
||||
"joined": {
|
||||
user_id: {
|
||||
"avatar_url": profile.avatar_url,
|
||||
"display_name": profile.display_name,
|
||||
}
|
||||
for user_id, profile in users_with_profile.iteritems()
|
||||
}
|
||||
"joined": users_with_profile,
|
||||
}))
|
||||
|
||||
|
||||
|
||||
717
synapse/rest/client/v2_alpha/groups.py
Normal file
717
synapse/rest/client/v2_alpha/groups.py
Normal file
@@ -0,0 +1,717 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.http.servlet import RestServlet, parse_json_object_from_request
|
||||
from synapse.types import GroupID
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GroupServlet(RestServlet):
|
||||
"""Get the group profile
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/profile$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
group_description = yield self.groups_handler.get_group_profile(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, group_description))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
yield self.groups_handler.update_group_profile(
|
||||
group_id, user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class GroupSummaryServlet(RestServlet):
|
||||
"""Get the full group summary
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/summary$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
get_group_summary = yield self.groups_handler.get_group_summary(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, get_group_summary))
|
||||
|
||||
|
||||
class GroupSummaryRoomsCatServlet(RestServlet):
|
||||
"""Update/delete a rooms entry in the summary.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/rooms/:room_id
|
||||
- /groups/:group/summary/categories/:category/rooms/:room_id
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/categories/(?P<category_id>[^/]+))?"
|
||||
"/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryRoomsCatServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_room(
|
||||
group_id, user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_room(
|
||||
group_id, user_id,
|
||||
room_id=room_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupCategoryServlet(RestServlet):
|
||||
"""Get/add/update/delete a group category
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/(?P<category_id>[^/]+)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCategoryServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, category_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_category(
|
||||
group_id, user_id,
|
||||
category_id=category_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupCategoriesServlet(RestServlet):
|
||||
"""Get all group categories
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/categories/$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCategoriesServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_categories(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
|
||||
class GroupRoleServlet(RestServlet):
|
||||
"""Get/add/update/delete a group role
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/(?P<role_id>[^/]+)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRoleServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, role_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_role(
|
||||
group_id, user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupRolesServlet(RestServlet):
|
||||
"""Get all group roles
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/roles/$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRolesServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
category = yield self.groups_handler.get_group_roles(
|
||||
group_id, user_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, category))
|
||||
|
||||
|
||||
class GroupSummaryUsersRoleServlet(RestServlet):
|
||||
"""Update/delete a user's entry in the summary.
|
||||
|
||||
Matches both:
|
||||
- /groups/:group/summary/users/:room_id
|
||||
- /groups/:group/summary/roles/:role/users/:user_id
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/summary"
|
||||
"(/roles/(?P<role_id>[^/]+))?"
|
||||
"/users/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSummaryUsersRoleServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, role_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
resp = yield self.groups_handler.update_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
content=content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, role_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
resp = yield self.groups_handler.delete_group_summary_user(
|
||||
group_id, requester_user_id,
|
||||
user_id=user_id,
|
||||
role_id=role_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, resp))
|
||||
|
||||
|
||||
class GroupRoomServlet(RestServlet):
|
||||
"""Get all rooms in a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/rooms$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupRoomServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_rooms_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupUsersServlet(RestServlet):
|
||||
"""Get all users in a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_users_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupInvitedUsersServlet(RestServlet):
|
||||
"""Get users invited to a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/groups/(?P<group_id>[^/]*)/invited_users$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupInvitedUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_invited_users_in_group(group_id, user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupCreateServlet(RestServlet):
|
||||
"""Create a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns("/create_group$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupCreateServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.server_name = hs.hostname
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
# TODO: Create group on remote server
|
||||
content = parse_json_object_from_request(request)
|
||||
localpart = content.pop("localpart")
|
||||
group_id = GroupID.create(localpart, self.server_name).to_string()
|
||||
|
||||
result = yield self.groups_handler.create_group(group_id, user_id, content)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminRoomsServlet(RestServlet):
|
||||
"""Add a room to the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/rooms/(?P<room_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminRoomsServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.add_room_to_group(
|
||||
group_id, user_id, room_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_DELETE(self, request, group_id, room_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.remove_room_from_group(
|
||||
group_id, user_id, room_id,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminUsersInviteServlet(RestServlet):
|
||||
"""Invite a user to the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/invite/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminUsersInviteServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
self.store = hs.get_datastore()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
config = content.get("config", {})
|
||||
result = yield self.groups_handler.invite(
|
||||
group_id, user_id, requester_user_id, config,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupAdminUsersKickServlet(RestServlet):
|
||||
"""Kick a user from the group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/admin/users/remove/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupAdminUsersKickServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfLeaveServlet(RestServlet):
|
||||
"""Leave a joined group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/leave$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfLeaveServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.remove_user_from_group(
|
||||
group_id, requester_user_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfJoinServlet(RestServlet):
|
||||
"""Attempt to join a group, or knock
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/join$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfJoinServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.join_group(
|
||||
group_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfAcceptInviteServlet(RestServlet):
|
||||
"""Accept a group invite
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/accept_invite$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfAcceptInviteServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
result = yield self.groups_handler.accept_invite(
|
||||
group_id, requester_user_id, content,
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupSelfUpdatePublicityServlet(RestServlet):
|
||||
"""Update whether we publicise a users membership of a group
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/groups/(?P<group_id>[^/]*)/self/update_publicity$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupSelfUpdatePublicityServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_PUT(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
requester_user_id = requester.user.to_string()
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
publicise = content["publicise"]
|
||||
yield self.store.update_group_publicity(
|
||||
group_id, requester_user_id, publicise,
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
class PublicisedGroupsForUserServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups/(?P<user_id>[^/]*)$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUserServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
result = yield self.groups_handler.get_publicised_groups_for_user(
|
||||
user_id
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class PublicisedGroupsForUsersServlet(RestServlet):
|
||||
"""Get the list of groups a user is advertising
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/publicised_groups$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PublicisedGroupsForUsersServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request):
|
||||
yield self.auth.get_user_by_req(request)
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
user_ids = content["user_ids"]
|
||||
|
||||
result = yield self.groups_handler.bulk_get_publicised_groups(
|
||||
user_ids
|
||||
)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
class GroupsForUserServlet(RestServlet):
|
||||
"""Get all groups the logged in user is joined to
|
||||
"""
|
||||
PATTERNS = client_v2_patterns(
|
||||
"/joined_groups$"
|
||||
)
|
||||
|
||||
def __init__(self, hs):
|
||||
super(GroupsForUserServlet, self).__init__()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.groups_handler = hs.get_groups_local_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user_id = requester.user.to_string()
|
||||
|
||||
result = yield self.groups_handler.get_joined_groups(user_id)
|
||||
|
||||
defer.returnValue((200, result))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
GroupServlet(hs).register(http_server)
|
||||
GroupSummaryServlet(hs).register(http_server)
|
||||
GroupInvitedUsersServlet(hs).register(http_server)
|
||||
GroupUsersServlet(hs).register(http_server)
|
||||
GroupRoomServlet(hs).register(http_server)
|
||||
GroupCreateServlet(hs).register(http_server)
|
||||
GroupAdminRoomsServlet(hs).register(http_server)
|
||||
GroupAdminUsersInviteServlet(hs).register(http_server)
|
||||
GroupAdminUsersKickServlet(hs).register(http_server)
|
||||
GroupSelfLeaveServlet(hs).register(http_server)
|
||||
GroupSelfJoinServlet(hs).register(http_server)
|
||||
GroupSelfAcceptInviteServlet(hs).register(http_server)
|
||||
GroupsForUserServlet(hs).register(http_server)
|
||||
GroupCategoryServlet(hs).register(http_server)
|
||||
GroupCategoriesServlet(hs).register(http_server)
|
||||
GroupSummaryRoomsCatServlet(hs).register(http_server)
|
||||
GroupRoleServlet(hs).register(http_server)
|
||||
GroupRolesServlet(hs).register(http_server)
|
||||
GroupSelfUpdatePublicityServlet(hs).register(http_server)
|
||||
GroupSummaryUsersRoleServlet(hs).register(http_server)
|
||||
PublicisedGroupsForUserServlet(hs).register(http_server)
|
||||
PublicisedGroupsForUsersServlet(hs).register(http_server)
|
||||
@@ -17,8 +17,10 @@
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
import synapse.types
|
||||
from synapse.api.auth import get_access_token_from_request, has_access_token
|
||||
from synapse.api.constants import LoginType
|
||||
from synapse.types import RoomID, RoomAlias
|
||||
from synapse.api.errors import SynapseError, Codes, UnrecognizedRequestError
|
||||
from synapse.http.servlet import (
|
||||
RestServlet, parse_json_object_from_request, assert_params_in_request, parse_string
|
||||
@@ -170,6 +172,7 @@ class RegisterRestServlet(RestServlet):
|
||||
self.auth_handler = hs.get_auth_handler()
|
||||
self.registration_handler = hs.get_handlers().registration_handler
|
||||
self.identity_handler = hs.get_handlers().identity_handler
|
||||
self.room_member_handler = hs.get_handlers().room_member_handler
|
||||
self.device_handler = hs.get_device_handler()
|
||||
self.macaroon_gen = hs.get_macaroon_generator()
|
||||
|
||||
@@ -340,6 +343,14 @@ class RegisterRestServlet(RestServlet):
|
||||
generate_token=False,
|
||||
)
|
||||
|
||||
# auto-join the user to any rooms we're supposed to dump them into
|
||||
fake_requester = synapse.types.create_requester(registered_user_id)
|
||||
for r in self.hs.config.auto_join_rooms:
|
||||
try:
|
||||
yield self._join_user_to_room(fake_requester, r)
|
||||
except Exception as e:
|
||||
logger.error("Failed to join new user to %r: %r", r, e)
|
||||
|
||||
# remember that we've now registered that user account, and with
|
||||
# what user ID (since the user may not have specified)
|
||||
self.auth_handler.set_session_data(
|
||||
@@ -372,6 +383,29 @@ class RegisterRestServlet(RestServlet):
|
||||
def on_OPTIONS(self, _):
|
||||
return 200, {}
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _join_user_to_room(self, requester, room_identifier):
|
||||
room_id = None
|
||||
if RoomID.is_valid(room_identifier):
|
||||
room_id = room_identifier
|
||||
elif RoomAlias.is_valid(room_identifier):
|
||||
room_alias = RoomAlias.from_string(room_identifier)
|
||||
room_id, remote_room_hosts = (
|
||||
yield self.room_member_handler.lookup_room_alias(room_alias)
|
||||
)
|
||||
room_id = room_id.to_string()
|
||||
else:
|
||||
raise SynapseError(400, "%s was not legal room ID or room alias" % (
|
||||
room_identifier,
|
||||
))
|
||||
|
||||
yield self.room_member_handler.update_membership(
|
||||
requester=requester,
|
||||
target=requester.user,
|
||||
room_id=room_id,
|
||||
action="join",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_appservice_registration(self, username, as_token, body):
|
||||
user_id = yield self.registration_handler.appservice_register(
|
||||
|
||||
@@ -200,6 +200,11 @@ class SyncRestServlet(RestServlet):
|
||||
"invite": invited,
|
||||
"leave": archived,
|
||||
},
|
||||
"groups": {
|
||||
"join": sync_result.groups.join,
|
||||
"invite": sync_result.groups.invite,
|
||||
"leave": sync_result.groups.leave,
|
||||
},
|
||||
"device_one_time_keys_count": sync_result.device_one_time_keys_count,
|
||||
"next_batch": sync_result.next_batch.to_string(),
|
||||
}
|
||||
|
||||
@@ -14,78 +14,200 @@
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
|
||||
NEW_FORMAT_ID_RE = re.compile(r"^\d\d\d\d-\d\d-\d\d")
|
||||
|
||||
|
||||
def _wrap_in_base_path(func):
|
||||
"""Takes a function that returns a relative path and turns it into an
|
||||
absolute path based on the location of the primary media store
|
||||
"""
|
||||
@functools.wraps(func)
|
||||
def _wrapped(self, *args, **kwargs):
|
||||
path = func(self, *args, **kwargs)
|
||||
return os.path.join(self.base_path, path)
|
||||
|
||||
return _wrapped
|
||||
|
||||
|
||||
class MediaFilePaths(object):
|
||||
"""Describes where files are stored on disk.
|
||||
|
||||
def __init__(self, base_path):
|
||||
self.base_path = base_path
|
||||
Most of the functions have a `*_rel` variant which returns a file path that
|
||||
is relative to the base media store path. This is mainly used when we want
|
||||
to write to the backup media store (when one is configured)
|
||||
"""
|
||||
|
||||
def default_thumbnail(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
def __init__(self, primary_base_path):
|
||||
self.base_path = primary_base_path
|
||||
|
||||
def default_thumbnail_rel(self, default_top_level, default_sub_type, width,
|
||||
height, content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "default_thumbnails", default_top_level,
|
||||
"default_thumbnails", default_top_level,
|
||||
default_sub_type, file_name
|
||||
)
|
||||
|
||||
def local_media_filepath(self, media_id):
|
||||
default_thumbnail = _wrap_in_base_path(default_thumbnail_rel)
|
||||
|
||||
def local_media_filepath_rel(self, media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "local_content",
|
||||
"local_content",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
|
||||
def local_media_thumbnail(self, media_id, width, height, content_type,
|
||||
method):
|
||||
local_media_filepath = _wrap_in_base_path(local_media_filepath_rel)
|
||||
|
||||
def local_media_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "local_thumbnails",
|
||||
"local_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
def remote_media_filepath(self, server_name, file_id):
|
||||
local_media_thumbnail = _wrap_in_base_path(local_media_thumbnail_rel)
|
||||
|
||||
def remote_media_filepath_rel(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_content", server_name,
|
||||
"remote_content", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:]
|
||||
)
|
||||
|
||||
def remote_media_thumbnail(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
remote_media_filepath = _wrap_in_base_path(remote_media_filepath_rel)
|
||||
|
||||
def remote_media_thumbnail_rel(self, server_name, file_id, width, height,
|
||||
content_type, method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s" % (width, height, top_level_type, sub_type)
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
"remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
remote_media_thumbnail = _wrap_in_base_path(remote_media_thumbnail_rel)
|
||||
|
||||
def remote_media_thumbnail_dir(self, server_name, file_id):
|
||||
return os.path.join(
|
||||
self.base_path, "remote_thumbnail", server_name,
|
||||
file_id[0:2], file_id[2:4], file_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_filepath(self, media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2], media_id[2:4], media_id[4:]
|
||||
)
|
||||
def url_cache_filepath_rel(self, media_id):
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[:10], media_id[11:]
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
)
|
||||
|
||||
url_cache_filepath = _wrap_in_base_path(url_cache_filepath_rel)
|
||||
|
||||
def url_cache_filepath_dirs_to_delete(self, media_id):
|
||||
"The dirs to try and remove if we delete the media_id file"
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[:10],
|
||||
),
|
||||
]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache",
|
||||
media_id[0:2],
|
||||
),
|
||||
]
|
||||
|
||||
def url_cache_thumbnail_rel(self, media_id, width, height, content_type,
|
||||
method):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
def url_cache_thumbnail(self, media_id, width, height, content_type,
|
||||
method):
|
||||
top_level_type, sub_type = content_type.split("/")
|
||||
file_name = "%i-%i-%s-%s-%s" % (
|
||||
width, height, top_level_type, sub_type, method
|
||||
)
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
file_name
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
"url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
file_name
|
||||
)
|
||||
|
||||
url_cache_thumbnail = _wrap_in_base_path(url_cache_thumbnail_rel)
|
||||
|
||||
def url_cache_thumbnail_directory(self, media_id):
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
)
|
||||
else:
|
||||
return os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
)
|
||||
|
||||
def url_cache_thumbnail_dirs_to_delete(self, media_id):
|
||||
"The dirs to try and remove if we delete the media_id thumbnails"
|
||||
# Media id is of the form <DATE><RANDOM_STRING>
|
||||
# E.g.: 2017-09-28-fsdRDt24DS234dsf
|
||||
if NEW_FORMAT_ID_RE.match(media_id):
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10], media_id[11:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[:10],
|
||||
),
|
||||
]
|
||||
else:
|
||||
return [
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4], media_id[4:],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2], media_id[2:4],
|
||||
),
|
||||
os.path.join(
|
||||
self.base_path, "url_cache_thumbnails",
|
||||
media_id[0:2],
|
||||
),
|
||||
]
|
||||
|
||||
@@ -33,7 +33,7 @@ from synapse.api.errors import SynapseError, HttpResponseException, \
|
||||
|
||||
from synapse.util.async import Linearizer
|
||||
from synapse.util.stringutils import is_ascii
|
||||
from synapse.util.logcontext import preserve_context_over_fn
|
||||
from synapse.util.logcontext import make_deferred_yieldable, preserve_fn
|
||||
from synapse.util.retryutils import NotRetryingDestination
|
||||
|
||||
import os
|
||||
@@ -59,7 +59,14 @@ class MediaRepository(object):
|
||||
self.store = hs.get_datastore()
|
||||
self.max_upload_size = hs.config.max_upload_size
|
||||
self.max_image_pixels = hs.config.max_image_pixels
|
||||
self.filepaths = MediaFilePaths(hs.config.media_store_path)
|
||||
|
||||
self.primary_base_path = hs.config.media_store_path
|
||||
self.filepaths = MediaFilePaths(self.primary_base_path)
|
||||
|
||||
self.backup_base_path = hs.config.backup_media_store_path
|
||||
|
||||
self.synchronous_backup_media_store = hs.config.synchronous_backup_media_store
|
||||
|
||||
self.dynamic_thumbnails = hs.config.dynamic_thumbnails
|
||||
self.thumbnail_requirements = hs.config.thumbnail_requirements
|
||||
|
||||
@@ -87,18 +94,86 @@ class MediaRepository(object):
|
||||
if not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
|
||||
@staticmethod
|
||||
def _write_file_synchronously(source, fname):
|
||||
"""Write `source` to the path `fname` synchronously. Should be called
|
||||
from a thread.
|
||||
|
||||
Args:
|
||||
source: A file like object to be written
|
||||
fname (str): Path to write to
|
||||
"""
|
||||
MediaRepository._makedirs(fname)
|
||||
source.seek(0) # Ensure we read from the start of the file
|
||||
with open(fname, "wb") as f:
|
||||
shutil.copyfileobj(source, f)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def write_to_file_and_backup(self, source, path):
|
||||
"""Write `source` to the on disk media store, and also the backup store
|
||||
if configured.
|
||||
|
||||
Args:
|
||||
source: A file like object that should be written
|
||||
path (str): Relative path to write file to
|
||||
|
||||
Returns:
|
||||
Deferred[str]: the file path written to in the primary media store
|
||||
"""
|
||||
fname = os.path.join(self.primary_base_path, path)
|
||||
|
||||
# Write to the main repository
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._write_file_synchronously, source, fname,
|
||||
))
|
||||
|
||||
# Write to backup repository
|
||||
yield self.copy_to_backup(path)
|
||||
|
||||
defer.returnValue(fname)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def copy_to_backup(self, path):
|
||||
"""Copy a file from the primary to backup media store, if configured.
|
||||
|
||||
Args:
|
||||
path(str): Relative path to write file to
|
||||
"""
|
||||
if self.backup_base_path:
|
||||
primary_fname = os.path.join(self.primary_base_path, path)
|
||||
backup_fname = os.path.join(self.backup_base_path, path)
|
||||
|
||||
# We can either wait for successful writing to the backup repository
|
||||
# or write in the background and immediately return
|
||||
if self.synchronous_backup_media_store:
|
||||
yield make_deferred_yieldable(threads.deferToThread(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
))
|
||||
else:
|
||||
preserve_fn(threads.deferToThread)(
|
||||
shutil.copyfile, primary_fname, backup_fname,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_content(self, media_type, upload_name, content, content_length,
|
||||
auth_user):
|
||||
"""Store uploaded content for a local user and return the mxc URL
|
||||
|
||||
Args:
|
||||
media_type(str): The content type of the file
|
||||
upload_name(str): The name of the file
|
||||
content: A file like object that is the content to store
|
||||
content_length(int): The length of the content
|
||||
auth_user(str): The user_id of the uploader
|
||||
|
||||
Returns:
|
||||
Deferred[str]: The mxc url of the stored content
|
||||
"""
|
||||
media_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.local_media_filepath(media_id)
|
||||
self._makedirs(fname)
|
||||
|
||||
# This shouldn't block for very long because the content will have
|
||||
# already been uploaded at this point.
|
||||
with open(fname, "wb") as f:
|
||||
f.write(content)
|
||||
fname = yield self.write_to_file_and_backup(
|
||||
content, self.filepaths.local_media_filepath_rel(media_id)
|
||||
)
|
||||
|
||||
logger.info("Stored local media in file %r", fname)
|
||||
|
||||
@@ -115,7 +190,7 @@ class MediaRepository(object):
|
||||
"media_length": content_length,
|
||||
}
|
||||
|
||||
yield self._generate_local_thumbnails(media_id, media_info)
|
||||
yield self._generate_thumbnails(None, media_id, media_info)
|
||||
|
||||
defer.returnValue("mxc://%s/%s" % (self.server_name, media_id))
|
||||
|
||||
@@ -148,9 +223,10 @@ class MediaRepository(object):
|
||||
def _download_remote_file(self, server_name, media_id):
|
||||
file_id = random_string(24)
|
||||
|
||||
fname = self.filepaths.remote_media_filepath(
|
||||
fpath = self.filepaths.remote_media_filepath_rel(
|
||||
server_name, file_id
|
||||
)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self._makedirs(fname)
|
||||
|
||||
try:
|
||||
@@ -192,6 +268,8 @@ class MediaRepository(object):
|
||||
server_name, media_id)
|
||||
raise SynapseError(502, "Failed to fetch remote media")
|
||||
|
||||
yield self.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@@ -244,7 +322,7 @@ class MediaRepository(object):
|
||||
"filesystem_id": file_id,
|
||||
}
|
||||
|
||||
yield self._generate_remote_thumbnails(
|
||||
yield self._generate_thumbnails(
|
||||
server_name, media_id, media_info
|
||||
)
|
||||
|
||||
@@ -253,9 +331,8 @@ class MediaRepository(object):
|
||||
def _get_thumbnail_requirements(self, media_type):
|
||||
return self.thumbnail_requirements.get(media_type, ())
|
||||
|
||||
def _generate_thumbnail(self, input_path, t_path, t_width, t_height,
|
||||
def _generate_thumbnail(self, thumbnailer, t_width, t_height,
|
||||
t_method, t_type):
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
@@ -267,72 +344,105 @@ class MediaRepository(object):
|
||||
return
|
||||
|
||||
if t_method == "crop":
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
t_byte_source = thumbnailer.crop(t_width, t_height, t_type)
|
||||
elif t_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(t_width, t_height)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
t_byte_source = thumbnailer.scale(t_width, t_height, t_type)
|
||||
else:
|
||||
t_len = None
|
||||
t_byte_source = None
|
||||
|
||||
return t_len
|
||||
return t_byte_source
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_local_exact_thumbnail(self, media_id, t_width, t_height,
|
||||
t_method, t_type):
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
|
||||
t_len = yield preserve_context_over_fn(
|
||||
threads.deferToThread,
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._generate_thumbnail,
|
||||
input_path, t_path, t_width, t_height, t_method, t_type
|
||||
)
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
))
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
if t_len:
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue(t_path)
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def generate_remote_exact_thumbnail(self, server_name, file_id, media_id,
|
||||
t_width, t_height, t_method, t_type):
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
|
||||
t_len = yield preserve_context_over_fn(
|
||||
threads.deferToThread,
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
self._generate_thumbnail,
|
||||
input_path, t_path, t_width, t_height, t_method, t_type
|
||||
)
|
||||
thumbnailer, t_width, t_height, t_method, t_type
|
||||
))
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source,
|
||||
self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
if t_len:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue(t_path)
|
||||
defer.returnValue(output_path)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_local_thumbnails(self, media_id, media_info, url_cache=False):
|
||||
def _generate_thumbnails(self, server_name, media_id, media_info, url_cache=False):
|
||||
"""Generate and store thumbnails for an image.
|
||||
|
||||
Args:
|
||||
server_name(str|None): The server name if remote media, else None if local
|
||||
media_id(str)
|
||||
media_info(dict)
|
||||
url_cache(bool): If we are thumbnailing images downloaded for the URL cache,
|
||||
used exclusively by the url previewer
|
||||
|
||||
Returns:
|
||||
Deferred[dict]: Dict with "width" and "height" keys of original image
|
||||
"""
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info.get("filesystem_id")
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
if url_cache:
|
||||
if server_name:
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
elif url_cache:
|
||||
input_path = self.filepaths.url_cache_filepath(media_id)
|
||||
else:
|
||||
input_path = self.filepaths.local_media_filepath(media_id)
|
||||
@@ -348,135 +458,72 @@ class MediaRepository(object):
|
||||
)
|
||||
return
|
||||
|
||||
local_thumbnails = []
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails = {}
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "crop":
|
||||
thumbnails.setdefault((r_width, r_height, r_type), r_method)
|
||||
elif r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
thumbnails[(t_width, t_height, r_type)] = r_method
|
||||
|
||||
def generate_thumbnails():
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
if url_cache:
|
||||
t_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
))
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
if url_cache:
|
||||
t_path = self.filepaths.url_cache_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
else:
|
||||
t_path = self.filepaths.local_media_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
local_thumbnails.append((
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
))
|
||||
|
||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||
|
||||
for l in local_thumbnails:
|
||||
yield self.store.store_local_thumbnail(*l)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
"height": m_height,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _generate_remote_thumbnails(self, server_name, media_id, media_info):
|
||||
media_type = media_info["media_type"]
|
||||
file_id = media_info["filesystem_id"]
|
||||
requirements = self._get_thumbnail_requirements(media_type)
|
||||
if not requirements:
|
||||
return
|
||||
|
||||
remote_thumbnails = []
|
||||
|
||||
input_path = self.filepaths.remote_media_filepath(server_name, file_id)
|
||||
thumbnailer = Thumbnailer(input_path)
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
def generate_thumbnails():
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width, m_height, self.max_image_pixels
|
||||
)
|
||||
return
|
||||
|
||||
scales = set()
|
||||
crops = set()
|
||||
for r_width, r_height, r_method, r_type in requirements:
|
||||
if r_method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(r_width, r_height)
|
||||
scales.add((
|
||||
min(m_width, t_width), min(m_height, t_height), r_type,
|
||||
))
|
||||
elif r_method == "crop":
|
||||
crops.add((r_width, r_height, r_type))
|
||||
|
||||
for t_width, t_height, t_type in scales:
|
||||
t_method = "scale"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.iteritems():
|
||||
# Work out the correct file name for thumbnail
|
||||
if server_name:
|
||||
file_path = self.filepaths.remote_media_thumbnail_rel(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.scale(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
|
||||
for t_width, t_height, t_type in crops:
|
||||
if (t_width, t_height, t_type) in scales:
|
||||
# If the aspect ratio of the cropped thumbnail matches a purely
|
||||
# scaled one then there is no point in calculating a separate
|
||||
# thumbnail.
|
||||
continue
|
||||
t_method = "crop"
|
||||
t_path = self.filepaths.remote_media_thumbnail(
|
||||
server_name, file_id, t_width, t_height, t_type, t_method
|
||||
elif url_cache:
|
||||
file_path = self.filepaths.url_cache_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
self._makedirs(t_path)
|
||||
t_len = thumbnailer.crop(t_path, t_width, t_height, t_type)
|
||||
remote_thumbnails.append([
|
||||
else:
|
||||
file_path = self.filepaths.local_media_thumbnail_rel(
|
||||
media_id, t_width, t_height, t_type, t_method
|
||||
)
|
||||
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
thumbnailer.crop,
|
||||
t_width, t_height, t_type,
|
||||
))
|
||||
elif t_method == "scale":
|
||||
t_byte_source = yield make_deferred_yieldable(threads.deferToThread(
|
||||
thumbnailer.scale,
|
||||
t_width, t_height, t_type,
|
||||
))
|
||||
else:
|
||||
logger.error("Unrecognized method: %r", t_method)
|
||||
continue
|
||||
|
||||
if not t_byte_source:
|
||||
continue
|
||||
|
||||
try:
|
||||
# Write to disk
|
||||
output_path = yield self.write_to_file_and_backup(
|
||||
t_byte_source, file_path,
|
||||
)
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
t_len = os.path.getsize(output_path)
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
yield self.store.store_remote_media_thumbnail(
|
||||
server_name, media_id, file_id,
|
||||
t_width, t_height, t_type, t_method, t_len
|
||||
])
|
||||
|
||||
yield preserve_context_over_fn(threads.deferToThread, generate_thumbnails)
|
||||
|
||||
for r in remote_thumbnails:
|
||||
yield self.store.store_remote_media_thumbnail(*r)
|
||||
)
|
||||
else:
|
||||
yield self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
defer.returnValue({
|
||||
"width": m_width,
|
||||
@@ -497,6 +544,8 @@ class MediaRepository(object):
|
||||
|
||||
logger.info("Deleting: %r", key)
|
||||
|
||||
# TODO: Should we delete from the backup store
|
||||
|
||||
with (yield self.remote_media_linearizer.queue(key)):
|
||||
full_path = self.filepaths.remote_media_filepath(origin, file_id)
|
||||
try:
|
||||
|
||||
@@ -36,6 +36,9 @@ import cgi
|
||||
import ujson as json
|
||||
import urlparse
|
||||
import itertools
|
||||
import datetime
|
||||
import errno
|
||||
import shutil
|
||||
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -56,6 +59,7 @@ class PreviewUrlResource(Resource):
|
||||
self.store = hs.get_datastore()
|
||||
self.client = SpiderHttpClient(hs)
|
||||
self.media_repo = media_repo
|
||||
self.primary_base_path = media_repo.primary_base_path
|
||||
|
||||
self.url_preview_url_blacklist = hs.config.url_preview_url_blacklist
|
||||
|
||||
@@ -70,6 +74,10 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
self.downloads = {}
|
||||
|
||||
self._cleaner_loop = self.clock.looping_call(
|
||||
self._expire_url_cache_data, 10 * 1000
|
||||
)
|
||||
|
||||
def render_GET(self, request):
|
||||
self._async_render_GET(request)
|
||||
return NOT_DONE_YET
|
||||
@@ -130,7 +138,7 @@ class PreviewUrlResource(Resource):
|
||||
cache_result = yield self.store.get_url_cache(url, ts)
|
||||
if (
|
||||
cache_result and
|
||||
cache_result["download_ts"] + cache_result["expires"] > ts and
|
||||
cache_result["expires_ts"] > ts and
|
||||
cache_result["response_code"] / 100 == 2
|
||||
):
|
||||
respond_with_json_bytes(
|
||||
@@ -163,8 +171,8 @@ class PreviewUrlResource(Resource):
|
||||
logger.debug("got media_info of '%s'" % media_info)
|
||||
|
||||
if _is_media(media_info['media_type']):
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
media_info['filesystem_id'], media_info, url_cache=True,
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, media_info['filesystem_id'], media_info, url_cache=True,
|
||||
)
|
||||
|
||||
og = {
|
||||
@@ -209,8 +217,8 @@ class PreviewUrlResource(Resource):
|
||||
|
||||
if _is_media(image_info['media_type']):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
dims = yield self.media_repo._generate_local_thumbnails(
|
||||
image_info['filesystem_id'], image_info, url_cache=True,
|
||||
dims = yield self.media_repo._generate_thumbnails(
|
||||
None, image_info['filesystem_id'], image_info, url_cache=True,
|
||||
)
|
||||
if dims:
|
||||
og["og:image:width"] = dims['width']
|
||||
@@ -239,7 +247,7 @@ class PreviewUrlResource(Resource):
|
||||
url,
|
||||
media_info["response_code"],
|
||||
media_info["etag"],
|
||||
media_info["expires"],
|
||||
media_info["expires"] + media_info["created_ts"],
|
||||
json.dumps(og),
|
||||
media_info["filesystem_id"],
|
||||
media_info["created_ts"],
|
||||
@@ -253,10 +261,10 @@ class PreviewUrlResource(Resource):
|
||||
# we're most likely being explicitly triggered by a human rather than a
|
||||
# bot, so are we really a robot?
|
||||
|
||||
# XXX: horrible duplication with base_resource's _download_remote_file()
|
||||
file_id = random_string(24)
|
||||
file_id = datetime.date.today().isoformat() + '_' + random_string(16)
|
||||
|
||||
fname = self.filepaths.url_cache_filepath(file_id)
|
||||
fpath = self.filepaths.url_cache_filepath_rel(file_id)
|
||||
fname = os.path.join(self.primary_base_path, fpath)
|
||||
self.media_repo._makedirs(fname)
|
||||
|
||||
try:
|
||||
@@ -267,6 +275,8 @@ class PreviewUrlResource(Resource):
|
||||
)
|
||||
# FIXME: pass through 404s and other error messages nicely
|
||||
|
||||
yield self.media_repo.copy_to_backup(fpath)
|
||||
|
||||
media_type = headers["Content-Type"][0]
|
||||
time_now_ms = self.clock.time_msec()
|
||||
|
||||
@@ -328,6 +338,91 @@ class PreviewUrlResource(Resource):
|
||||
"etag": headers["ETag"][0] if "ETag" in headers else None,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _expire_url_cache_data(self):
|
||||
"""Clean up expired url cache content, media and thumbnails.
|
||||
"""
|
||||
|
||||
# TODO: Delete from backup media store
|
||||
|
||||
now = self.clock.time_msec()
|
||||
|
||||
# First we delete expired url cache entries
|
||||
media_ids = yield self.store.get_expired_url_cache(now)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
fname = self.filepaths.url_cache_filepath(media_id)
|
||||
try:
|
||||
os.remove(fname)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
removed_media.append(media_id)
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d entries from url cache", len(removed_media))
|
||||
|
||||
# Now we delete old images associated with the url cache.
|
||||
# These may be cached for a bit on the client (i.e., they
|
||||
# may have a room open with a preview url thing open).
|
||||
# So we wait a couple of days before deleting, just in case.
|
||||
expire_before = now - 2 * 24 * 60 * 60 * 1000
|
||||
media_ids = yield self.store.get_url_cache_media_before(expire_before)
|
||||
|
||||
removed_media = []
|
||||
for media_id in media_ids:
|
||||
fname = self.filepaths.url_cache_filepath(media_id)
|
||||
try:
|
||||
os.remove(fname)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_filepath_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
thumbnail_dir = self.filepaths.url_cache_thumbnail_directory(media_id)
|
||||
try:
|
||||
shutil.rmtree(thumbnail_dir)
|
||||
except OSError as e:
|
||||
# If the path doesn't exist, meh
|
||||
if e.errno != errno.ENOENT:
|
||||
logger.warn("Failed to remove media: %r: %s", media_id, e)
|
||||
continue
|
||||
|
||||
removed_media.append(media_id)
|
||||
|
||||
try:
|
||||
dirs = self.filepaths.url_cache_thumbnail_dirs_to_delete(media_id)
|
||||
for dir in dirs:
|
||||
os.rmdir(dir)
|
||||
except:
|
||||
pass
|
||||
|
||||
yield self.store.delete_url_cache_media(removed_media)
|
||||
|
||||
if removed_media:
|
||||
logger.info("Deleted %d media from url cache", len(removed_media))
|
||||
|
||||
|
||||
def decode_and_calc_og(body, media_uri, request_encoding=None):
|
||||
from lxml import etree
|
||||
|
||||
@@ -50,12 +50,16 @@ class Thumbnailer(object):
|
||||
else:
|
||||
return ((max_height * self.width) // self.height, max_height)
|
||||
|
||||
def scale(self, output_path, width, height, output_type):
|
||||
"""Rescales the image to the given dimensions"""
|
||||
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
||||
return self.save_image(scaled, output_type, output_path)
|
||||
def scale(self, width, height, output_type):
|
||||
"""Rescales the image to the given dimensions.
|
||||
|
||||
def crop(self, output_path, width, height, output_type):
|
||||
Returns:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
scaled = self.image.resize((width, height), Image.ANTIALIAS)
|
||||
return self._encode_image(scaled, output_type)
|
||||
|
||||
def crop(self, width, height, output_type):
|
||||
"""Rescales and crops the image to the given dimensions preserving
|
||||
aspect::
|
||||
(w_in / h_in) = (w_scaled / h_scaled)
|
||||
@@ -65,6 +69,9 @@ class Thumbnailer(object):
|
||||
Args:
|
||||
max_width: The largest possible width.
|
||||
max_height: The larget possible height.
|
||||
|
||||
Returns:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
if width * self.height > height * self.width:
|
||||
scaled_height = (width * self.height) // self.width
|
||||
@@ -82,13 +89,9 @@ class Thumbnailer(object):
|
||||
crop_left = (scaled_width - width) // 2
|
||||
crop_right = width + crop_left
|
||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
return self.save_image(cropped, output_type, output_path)
|
||||
return self._encode_image(cropped, output_type)
|
||||
|
||||
def save_image(self, output_image, output_type, output_path):
|
||||
def _encode_image(self, output_image, output_type):
|
||||
output_bytes_io = BytesIO()
|
||||
output_image.save(output_bytes_io, self.FORMATS[output_type], quality=80)
|
||||
output_bytes = output_bytes_io.getvalue()
|
||||
with open(output_path, "wb") as output_file:
|
||||
output_file.write(output_bytes)
|
||||
logger.info("Stored thumbnail in file %r", output_path)
|
||||
return len(output_bytes)
|
||||
return output_bytes_io
|
||||
|
||||
@@ -93,7 +93,7 @@ class UploadResource(Resource):
|
||||
# TODO(markjh): parse content-dispostion
|
||||
|
||||
content_uri = yield self.media_repo.create_content(
|
||||
media_type, upload_name, request.content.read(),
|
||||
media_type, upload_name, request.content,
|
||||
content_length, requester.user
|
||||
)
|
||||
|
||||
|
||||
@@ -31,6 +31,7 @@ from synapse.appservice.api import ApplicationServiceApi
|
||||
from synapse.appservice.scheduler import ApplicationServiceScheduler
|
||||
from synapse.crypto.keyring import Keyring
|
||||
from synapse.events.builder import EventBuilderFactory
|
||||
from synapse.events.spamcheck import SpamChecker
|
||||
from synapse.federation import initialize_http_replication
|
||||
from synapse.federation.send_queue import FederationRemoteSendQueue
|
||||
from synapse.federation.transport.client import TransportLayerClient
|
||||
@@ -50,6 +51,10 @@ from synapse.handlers.initial_sync import InitialSyncHandler
|
||||
from synapse.handlers.receipts import ReceiptsHandler
|
||||
from synapse.handlers.read_marker import ReadMarkerHandler
|
||||
from synapse.handlers.user_directory import UserDirectoyHandler
|
||||
from synapse.handlers.groups_local import GroupsLocalHandler
|
||||
from synapse.handlers.profile import ProfileHandler
|
||||
from synapse.groups.groups_server import GroupsServerHandler
|
||||
from synapse.groups.attestations import GroupAttestionRenewer, GroupAttestationSigning
|
||||
from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory
|
||||
from synapse.http.matrixfederationclient import MatrixFederationHttpClient
|
||||
from synapse.notifier import Notifier
|
||||
@@ -111,6 +116,7 @@ class HomeServer(object):
|
||||
'application_service_scheduler',
|
||||
'application_service_handler',
|
||||
'device_message_handler',
|
||||
'profile_handler',
|
||||
'notifier',
|
||||
'distributor',
|
||||
'client_resource',
|
||||
@@ -139,6 +145,11 @@ class HomeServer(object):
|
||||
'read_marker_handler',
|
||||
'action_generator',
|
||||
'user_directory_handler',
|
||||
'groups_local_handler',
|
||||
'groups_server_handler',
|
||||
'groups_attestation_signing',
|
||||
'groups_attestation_renewer',
|
||||
'spam_checker',
|
||||
]
|
||||
|
||||
def __init__(self, hostname, **kwargs):
|
||||
@@ -251,6 +262,9 @@ class HomeServer(object):
|
||||
def build_initial_sync_handler(self):
|
||||
return InitialSyncHandler(self)
|
||||
|
||||
def build_profile_handler(self):
|
||||
return ProfileHandler(self)
|
||||
|
||||
def build_event_sources(self):
|
||||
return EventSources(self)
|
||||
|
||||
@@ -309,6 +323,21 @@ class HomeServer(object):
|
||||
def build_user_directory_handler(self):
|
||||
return UserDirectoyHandler(self)
|
||||
|
||||
def build_groups_local_handler(self):
|
||||
return GroupsLocalHandler(self)
|
||||
|
||||
def build_groups_server_handler(self):
|
||||
return GroupsServerHandler(self)
|
||||
|
||||
def build_groups_attestation_signing(self):
|
||||
return GroupAttestationSigning(self)
|
||||
|
||||
def build_groups_attestation_renewer(self):
|
||||
return GroupAttestionRenewer(self)
|
||||
|
||||
def build_spam_checker(self):
|
||||
return SpamChecker(self)
|
||||
|
||||
def remove_pusher(self, app_id, push_key, user_id):
|
||||
return self.get_pusherpool().remove_pusher(app_id, push_key, user_id)
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import synapse.api.auth
|
||||
import synapse.federation.transaction_queue
|
||||
import synapse.federation.transport.client
|
||||
import synapse.handlers
|
||||
import synapse.handlers.auth
|
||||
import synapse.handlers.device
|
||||
@@ -27,3 +29,9 @@ class HomeServer(object):
|
||||
|
||||
def get_state_handler(self) -> synapse.state.StateHandler:
|
||||
pass
|
||||
|
||||
def get_federation_sender(self) -> synapse.federation.transaction_queue.TransactionQueue:
|
||||
pass
|
||||
|
||||
def get_federation_transport_client(self) -> synapse.federation.transport.client.TransportLayerClient:
|
||||
pass
|
||||
|
||||
@@ -288,6 +288,9 @@ class StateHandler(object):
|
||||
"""
|
||||
logger.debug("resolve_state_groups event_ids %s", event_ids)
|
||||
|
||||
# map from state group id to the state in that state group (where
|
||||
# 'state' is a map from state key to event id)
|
||||
# dict[int, dict[(str, str), str]]
|
||||
state_groups_ids = yield self.store.get_state_groups_ids(
|
||||
room_id, event_ids
|
||||
)
|
||||
@@ -320,11 +323,15 @@ class StateHandler(object):
|
||||
"Resolving state for %s with %d groups", room_id, len(state_groups_ids)
|
||||
)
|
||||
|
||||
# build a map from state key to the event_ids which set that state.
|
||||
# dict[(str, str), set[str])
|
||||
state = {}
|
||||
for st in state_groups_ids.values():
|
||||
for key, e_id in st.items():
|
||||
state.setdefault(key, set()).add(e_id)
|
||||
|
||||
# build a map from state key to the event_ids which set that state,
|
||||
# including only those where there are state keys in conflict.
|
||||
conflicted_state = {
|
||||
k: list(v)
|
||||
for k, v in state.items()
|
||||
@@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
|
||||
|
||||
logger.info("Asking for %d conflicted events", len(needed_events))
|
||||
|
||||
# dict[str, FrozenEvent]: a map from state event id to event. Only includes
|
||||
# the state events which are in conflict.
|
||||
state_map = yield state_map_factory(needed_events)
|
||||
|
||||
# get the ids of the auth events which allow us to authenticate the
|
||||
# conflicted state, picking only from the unconflicting state.
|
||||
#
|
||||
# dict[(str, str), str]: a map from state key to event id
|
||||
auth_events = _create_auth_events_from_maps(
|
||||
unconflicted_state, conflicted_state, state_map
|
||||
)
|
||||
|
||||
@@ -37,7 +37,7 @@ from .media_repository import MediaRepositoryStore
|
||||
from .rejections import RejectionsStore
|
||||
from .event_push_actions import EventPushActionsStore
|
||||
from .deviceinbox import DeviceInboxStore
|
||||
|
||||
from .group_server import GroupServerStore
|
||||
from .state import StateStore
|
||||
from .signatures import SignatureStore
|
||||
from .filtering import FilteringStore
|
||||
@@ -88,6 +88,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
):
|
||||
|
||||
def __init__(self, db_conn, hs):
|
||||
@@ -135,6 +136,9 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
db_conn, "pushers", "id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
)
|
||||
self._group_updates_id_gen = StreamIdGenerator(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
@@ -235,6 +239,18 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
|
||||
db_conn, "local_group_updates",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._group_updates_id_gen.get_current_token(),
|
||||
limit=1000,
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", min_group_updates_id,
|
||||
prefilled_cache=_group_updates_prefill,
|
||||
)
|
||||
|
||||
cur = LoggingTransaction(
|
||||
db_conn.cursor(),
|
||||
name="_find_stream_orderings_for_times_txn",
|
||||
|
||||
@@ -743,6 +743,33 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_update_txn(txn, table, keyvalues, updatevalues):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
|
||||
update_sql = "UPDATE %s SET %s %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
return txn.rowcount
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
@@ -768,27 +795,13 @@ class SQLBaseStore(object):
|
||||
table, keyvalues, updatevalues,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _simple_update_one_txn(txn, table, keyvalues, updatevalues):
|
||||
if keyvalues:
|
||||
where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.iterkeys())
|
||||
else:
|
||||
where = ""
|
||||
@classmethod
|
||||
def _simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
|
||||
rowcount = cls._simple_update_txn(txn, table, keyvalues, updatevalues)
|
||||
|
||||
update_sql = "UPDATE %s SET %s %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in updatevalues),
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
updatevalues.values() + keyvalues.values()
|
||||
)
|
||||
|
||||
if txn.rowcount == 0:
|
||||
if rowcount == 0:
|
||||
raise StoreError(404, "No row found")
|
||||
if txn.rowcount > 1:
|
||||
if rowcount > 1:
|
||||
raise StoreError(500, "More than one row matched")
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -21,7 +21,7 @@ from synapse.events.utils import prune_event
|
||||
|
||||
from synapse.util.async import ObservableDeferred
|
||||
from synapse.util.logcontext import (
|
||||
preserve_fn, PreserveLoggingContext, preserve_context_over_deferred
|
||||
preserve_fn, PreserveLoggingContext, make_deferred_yieldable
|
||||
)
|
||||
from synapse.util.logutils import log_function
|
||||
from synapse.util.metrics import Measure
|
||||
@@ -88,13 +88,23 @@ class _EventPeristenceQueue(object):
|
||||
def add_to_queue(self, room_id, events_and_contexts, backfilled):
|
||||
"""Add events to the queue, with the given persist_event options.
|
||||
|
||||
NB: due to the normal usage pattern of this method, it does *not*
|
||||
follow the synapse logcontext rules, and leaves the logcontext in
|
||||
place whether or not the returned deferred is ready.
|
||||
|
||||
Args:
|
||||
room_id (str):
|
||||
events_and_contexts (list[(EventBase, EventContext)]):
|
||||
backfilled (bool):
|
||||
|
||||
Returns:
|
||||
defer.Deferred: a deferred which will resolve once the events are
|
||||
persisted. Runs its callbacks *without* a logcontext.
|
||||
"""
|
||||
queue = self._event_persist_queues.setdefault(room_id, deque())
|
||||
if queue:
|
||||
# if the last item in the queue has the same `backfilled` setting,
|
||||
# we can just add these new events to that item.
|
||||
end_item = queue[-1]
|
||||
if end_item.backfilled == backfilled:
|
||||
end_item.events_and_contexts.extend(events_and_contexts)
|
||||
@@ -113,11 +123,11 @@ class _EventPeristenceQueue(object):
|
||||
def handle_queue(self, room_id, per_item_callback):
|
||||
"""Attempts to handle the queue for a room if not already being handled.
|
||||
|
||||
The given callback will be invoked with for each item in the queue,1
|
||||
The given callback will be invoked with for each item in the queue,
|
||||
of type _EventPersistQueueItem. The per_item_callback will continuously
|
||||
be called with new items, unless the queue becomnes empty. The return
|
||||
value of the function will be given to the deferreds waiting on the item,
|
||||
exceptions will be passed to the deferres as well.
|
||||
exceptions will be passed to the deferreds as well.
|
||||
|
||||
This function should therefore be called whenever anything is added
|
||||
to the queue.
|
||||
@@ -233,7 +243,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
deferreds = []
|
||||
for room_id, evs_ctxs in partitioned.iteritems():
|
||||
d = preserve_fn(self._event_persist_queue.add_to_queue)(
|
||||
d = self._event_persist_queue.add_to_queue(
|
||||
room_id, evs_ctxs,
|
||||
backfilled=backfilled,
|
||||
)
|
||||
@@ -242,7 +252,7 @@ class EventsStore(SQLBaseStore):
|
||||
for room_id in partitioned:
|
||||
self._maybe_start_persisting(room_id)
|
||||
|
||||
return preserve_context_over_deferred(
|
||||
return make_deferred_yieldable(
|
||||
defer.gatherResults(deferreds, consumeErrors=True)
|
||||
)
|
||||
|
||||
@@ -267,7 +277,7 @@ class EventsStore(SQLBaseStore):
|
||||
|
||||
self._maybe_start_persisting(event.room_id)
|
||||
|
||||
yield preserve_context_over_deferred(deferred)
|
||||
yield make_deferred_yieldable(deferred)
|
||||
|
||||
max_persisted_id = yield self._stream_id_gen.get_current_token()
|
||||
defer.returnValue((event.internal_metadata.stream_ordering, max_persisted_id))
|
||||
@@ -784,6 +794,9 @@ class EventsStore(SQLBaseStore):
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.is_host_joined, (room_id, host)
|
||||
)
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.was_host_joined, (room_id, host)
|
||||
)
|
||||
|
||||
self._invalidate_cache_and_stream(
|
||||
txn, self.get_users_in_room, (room_id,)
|
||||
@@ -1523,7 +1536,7 @@ class EventsStore(SQLBaseStore):
|
||||
if not allow_rejected:
|
||||
rows[:] = [r for r in rows if not r["rejects"]]
|
||||
|
||||
res = yield preserve_context_over_deferred(defer.gatherResults(
|
||||
res = yield make_deferred_yieldable(defer.gatherResults(
|
||||
[
|
||||
preserve_fn(self._get_event_from_row)(
|
||||
row["internal_metadata"], row["json"], row["redacts"],
|
||||
|
||||
1199
synapse/storage/group_server.py
Normal file
1199
synapse/storage/group_server.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -62,7 +62,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
def get_url_cache_txn(txn):
|
||||
# get the most recently cached result (relative to the given ts)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts <= ?"
|
||||
" ORDER BY download_ts DESC LIMIT 1"
|
||||
@@ -74,7 +74,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
# ...or if we've requested a timestamp older than the oldest
|
||||
# copy in the cache, return the oldest copy (if any)
|
||||
sql = (
|
||||
"SELECT response_code, etag, expires, og, media_id, download_ts"
|
||||
"SELECT response_code, etag, expires_ts, og, media_id, download_ts"
|
||||
" FROM local_media_repository_url_cache"
|
||||
" WHERE url = ? AND download_ts > ?"
|
||||
" ORDER BY download_ts ASC LIMIT 1"
|
||||
@@ -86,14 +86,14 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
return None
|
||||
|
||||
return dict(zip((
|
||||
'response_code', 'etag', 'expires', 'og', 'media_id', 'download_ts'
|
||||
'response_code', 'etag', 'expires_ts', 'og', 'media_id', 'download_ts'
|
||||
), row))
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache", get_url_cache_txn
|
||||
)
|
||||
|
||||
def store_url_cache(self, url, response_code, etag, expires, og, media_id,
|
||||
def store_url_cache(self, url, response_code, etag, expires_ts, og, media_id,
|
||||
download_ts):
|
||||
return self._simple_insert(
|
||||
"local_media_repository_url_cache",
|
||||
@@ -101,7 +101,7 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
"url": url,
|
||||
"response_code": response_code,
|
||||
"etag": etag,
|
||||
"expires": expires,
|
||||
"expires_ts": expires_ts,
|
||||
"og": og,
|
||||
"media_id": media_id,
|
||||
"download_ts": download_ts,
|
||||
@@ -238,3 +238,64 @@ class MediaRepositoryStore(SQLBaseStore):
|
||||
},
|
||||
)
|
||||
return self.runInteraction("delete_remote_media", delete_remote_media_txn)
|
||||
|
||||
def get_expired_url_cache(self, now_ts):
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository_url_cache"
|
||||
" WHERE expires_ts < ?"
|
||||
" ORDER BY expires_ts ASC"
|
||||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_expired_url_cache_txn(txn):
|
||||
txn.execute(sql, (now_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.runInteraction("get_expired_url_cache", _get_expired_url_cache_txn)
|
||||
|
||||
def delete_url_cache(self, media_ids):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_url_cache"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
def _delete_url_cache_txn(txn):
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return self.runInteraction("delete_url_cache", _delete_url_cache_txn)
|
||||
|
||||
def get_url_cache_media_before(self, before_ts):
|
||||
sql = (
|
||||
"SELECT media_id FROM local_media_repository"
|
||||
" WHERE created_ts < ? AND url_cache IS NOT NULL"
|
||||
" ORDER BY created_ts ASC"
|
||||
" LIMIT 500"
|
||||
)
|
||||
|
||||
def _get_url_cache_media_before_txn(txn):
|
||||
txn.execute(sql, (before_ts,))
|
||||
return [row[0] for row in txn]
|
||||
|
||||
return self.runInteraction(
|
||||
"get_url_cache_media_before", _get_url_cache_media_before_txn,
|
||||
)
|
||||
|
||||
def delete_url_cache_media(self, media_ids):
|
||||
def _delete_url_cache_media_txn(txn):
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
sql = (
|
||||
"DELETE FROM local_media_repository_thumbnails"
|
||||
" WHERE media_id = ?"
|
||||
)
|
||||
|
||||
txn.executemany(sql, [(media_id,) for media_id in media_ids])
|
||||
|
||||
return self.runInteraction(
|
||||
"delete_url_cache_media", _delete_url_cache_media_txn,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
# Remember to update this number every time a change is made to database
|
||||
# schema files, so the users will be informed on server restarts.
|
||||
SCHEMA_VERSION = 43
|
||||
SCHEMA_VERSION = 45
|
||||
|
||||
dir_path = os.path.abspath(os.path.dirname(__file__))
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import SQLBaseStore
|
||||
|
||||
|
||||
@@ -55,3 +57,99 @@ class ProfileStore(SQLBaseStore):
|
||||
updatevalues={"avatar_url": new_avatar_url},
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
def get_from_remote_profile_cache(self, user_id):
|
||||
return self._simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("displayname", "avatar_url",),
|
||||
allow_none=True,
|
||||
desc="get_from_remote_profile_cache",
|
||||
)
|
||||
|
||||
def add_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||
"""Ensure we are caching the remote user's profiles.
|
||||
|
||||
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||
would return true for the user.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="add_remote_profile_cache",
|
||||
)
|
||||
|
||||
def update_remote_profile_cache(self, user_id, displayname, avatar_url):
|
||||
return self._simple_update(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="update_remote_profile_cache",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def maybe_delete_remote_profile_cache(self, user_id):
|
||||
"""Check if we still care about the remote user's profile, and if we
|
||||
don't then remove their profile from the cache
|
||||
"""
|
||||
subscribed = yield self.is_subscribed_remote_profile_for_user(user_id)
|
||||
if not subscribed:
|
||||
yield self._simple_delete(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
def get_remote_profile_cache_entries_that_expire(self, last_checked):
|
||||
"""Get all users who haven't been checked since `last_checked`
|
||||
"""
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(txn):
|
||||
sql = """
|
||||
SELECT user_id, displayname, avatar_url
|
||||
FROM remote_profile_cache
|
||||
WHERE last_check < ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_checked,))
|
||||
|
||||
return self.cursor_to_dict(txn)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_remote_profile_cache_entries_that_expire",
|
||||
_get_remote_profile_cache_entries_that_expire_txn,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_subscribed_remote_profile_for_user(self, user_id):
|
||||
"""Check whether we are interested in a remote user's profile.
|
||||
"""
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
defer.returnValue(True)
|
||||
|
||||
res = yield self._simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
defer.returnValue(True)
|
||||
|
||||
@@ -533,6 +533,46 @@ class RoomMemberStore(SQLBaseStore):
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
@cachedInlineCallbacks()
|
||||
def was_host_joined(self, room_id, host):
|
||||
"""Check whether the server is or ever was in the room.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
host (str)
|
||||
|
||||
Returns:
|
||||
Deferred: Resolves to True if the host is/was in the room, otherwise
|
||||
False.
|
||||
"""
|
||||
if '%' in host or '_' in host:
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
sql = """
|
||||
SELECT user_id FROM room_memberships
|
||||
WHERE room_id = ?
|
||||
AND user_id LIKE ?
|
||||
AND membership = 'join'
|
||||
LIMIT 1
|
||||
"""
|
||||
|
||||
# We do need to be careful to ensure that host doesn't have any wild cards
|
||||
# in it, but we checked above for known ones and we'll check below that
|
||||
# the returned user actually has the correct domain.
|
||||
like_clause = "%:" + host
|
||||
|
||||
rows = yield self._execute("was_host_joined", None, sql, room_id, like_clause)
|
||||
|
||||
if not rows:
|
||||
defer.returnValue(False)
|
||||
|
||||
user_id = rows[0][0]
|
||||
if get_domain_from_id(user_id) != host:
|
||||
# This can only happen if the host name has something funky in it
|
||||
raise Exception("Invalid host name")
|
||||
|
||||
defer.returnValue(True)
|
||||
|
||||
def get_joined_hosts(self, room_id, state_entry):
|
||||
state_group = state_entry.state_group
|
||||
if not state_group:
|
||||
|
||||
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
38
synapse/storage/schema/delta/44/expire_url_cache.sql
Normal file
@@ -0,0 +1,38 @@
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE INDEX local_media_repository_url_idx ON local_media_repository(created_ts) WHERE url_cache IS NOT NULL;
|
||||
|
||||
-- we need to change `expires` to `expires_ts` so that we can index on it. SQLite doesn't support
|
||||
-- indices on expressions until 3.9.
|
||||
CREATE TABLE local_media_repository_url_cache_new(
|
||||
url TEXT,
|
||||
response_code INTEGER,
|
||||
etag TEXT,
|
||||
expires_ts BIGINT,
|
||||
og TEXT,
|
||||
media_id TEXT,
|
||||
download_ts BIGINT
|
||||
);
|
||||
|
||||
INSERT INTO local_media_repository_url_cache_new
|
||||
SELECT url, response_code, etag, expires + download_ts, og, media_id, download_ts FROM local_media_repository_url_cache;
|
||||
|
||||
DROP TABLE local_media_repository_url_cache;
|
||||
ALTER TABLE local_media_repository_url_cache_new RENAME TO local_media_repository_url_cache;
|
||||
|
||||
CREATE INDEX local_media_repository_url_cache_expires_idx ON local_media_repository_url_cache(expires_ts);
|
||||
CREATE INDEX local_media_repository_url_cache_by_url_download_ts ON local_media_repository_url_cache(url, download_ts);
|
||||
CREATE INDEX local_media_repository_url_cache_media_idx ON local_media_repository_url_cache(media_id);
|
||||
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
167
synapse/storage/schema/delta/45/group_server.sql
Normal file
@@ -0,0 +1,167 @@
|
||||
/* Copyright 2017 Vector Creations Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
CREATE TABLE groups (
|
||||
group_id TEXT NOT NULL,
|
||||
name TEXT, -- the display name of the room
|
||||
avatar_url TEXT,
|
||||
short_description TEXT,
|
||||
long_description TEXT
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX groups_idx ON groups(group_id);
|
||||
|
||||
|
||||
-- list of users the group server thinks are joined
|
||||
CREATE TABLE group_users (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the users membership can be seen by everyone
|
||||
);
|
||||
|
||||
|
||||
CREATE INDEX groups_users_g_idx ON group_users(group_id, user_id);
|
||||
CREATE INDEX groups_users_u_idx ON group_users(user_id);
|
||||
|
||||
-- list of users the group server thinks are invited
|
||||
CREATE TABLE group_invites (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX groups_invites_g_idx ON group_invites(group_id, user_id);
|
||||
CREATE INDEX groups_invites_u_idx ON group_invites(user_id);
|
||||
|
||||
|
||||
CREATE TABLE group_rooms (
|
||||
group_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the room can be seen by everyone
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX groups_rooms_g_idx ON group_rooms(group_id, room_id);
|
||||
CREATE INDEX groups_rooms_r_idx ON group_rooms(room_id);
|
||||
|
||||
|
||||
-- Rooms to include in the summary
|
||||
CREATE TABLE group_summary_rooms (
|
||||
group_id TEXT NOT NULL,
|
||||
room_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
room_order BIGINT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the room should be show to everyone
|
||||
UNIQUE (group_id, category_id, room_id, room_order),
|
||||
CHECK (room_order > 0)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX group_summary_rooms_g_idx ON group_summary_rooms(group_id, room_id, category_id);
|
||||
|
||||
|
||||
-- Categories to include in the summary
|
||||
CREATE TABLE group_summary_room_categories (
|
||||
group_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
cat_order BIGINT NOT NULL,
|
||||
UNIQUE (group_id, category_id, cat_order),
|
||||
CHECK (cat_order > 0)
|
||||
);
|
||||
|
||||
-- The categories in the group
|
||||
CREATE TABLE group_room_categories (
|
||||
group_id TEXT NOT NULL,
|
||||
category_id TEXT NOT NULL,
|
||||
profile TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the category should be show to everyone
|
||||
UNIQUE (group_id, category_id)
|
||||
);
|
||||
|
||||
-- The users to include in the group summary
|
||||
CREATE TABLE group_summary_users (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
user_order BIGINT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL -- whether the user should be show to everyone
|
||||
);
|
||||
|
||||
CREATE INDEX group_summary_users_g_idx ON group_summary_users(group_id);
|
||||
|
||||
-- The roles to include in the group summary
|
||||
CREATE TABLE group_summary_roles (
|
||||
group_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
role_order BIGINT NOT NULL,
|
||||
UNIQUE (group_id, role_id, role_order),
|
||||
CHECK (role_order > 0)
|
||||
);
|
||||
|
||||
|
||||
-- The roles in a groups
|
||||
CREATE TABLE group_roles (
|
||||
group_id TEXT NOT NULL,
|
||||
role_id TEXT NOT NULL,
|
||||
profile TEXT NOT NULL,
|
||||
is_public BOOLEAN NOT NULL, -- whether the role should be show to everyone
|
||||
UNIQUE (group_id, role_id)
|
||||
);
|
||||
|
||||
|
||||
-- List of attestations we've given out and need to renew
|
||||
CREATE TABLE group_attestations_renewals (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
valid_until_ms BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX group_attestations_renewals_g_idx ON group_attestations_renewals(group_id, user_id);
|
||||
CREATE INDEX group_attestations_renewals_u_idx ON group_attestations_renewals(user_id);
|
||||
CREATE INDEX group_attestations_renewals_v_idx ON group_attestations_renewals(valid_until_ms);
|
||||
|
||||
|
||||
-- List of attestations we've received from remotes and are interested in.
|
||||
CREATE TABLE group_attestations_remote (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
valid_until_ms BIGINT NOT NULL,
|
||||
attestation_json TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX group_attestations_remote_g_idx ON group_attestations_remote(group_id, user_id);
|
||||
CREATE INDEX group_attestations_remote_u_idx ON group_attestations_remote(user_id);
|
||||
CREATE INDEX group_attestations_remote_v_idx ON group_attestations_remote(valid_until_ms);
|
||||
|
||||
|
||||
-- The group membership for the HS's users
|
||||
CREATE TABLE local_group_membership (
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
is_admin BOOLEAN NOT NULL,
|
||||
membership TEXT NOT NULL,
|
||||
is_publicised BOOLEAN NOT NULL, -- if the user is publicising their membership
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
|
||||
CREATE INDEX local_group_membership_u_idx ON local_group_membership(user_id, group_id);
|
||||
CREATE INDEX local_group_membership_g_idx ON local_group_membership(group_id);
|
||||
|
||||
|
||||
CREATE TABLE local_group_updates (
|
||||
stream_id BIGINT NOT NULL,
|
||||
group_id TEXT NOT NULL,
|
||||
user_id TEXT NOT NULL,
|
||||
type TEXT NOT NULL,
|
||||
content TEXT NOT NULL
|
||||
);
|
||||
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
28
synapse/storage/schema/delta/45/profile_cache.sql
Normal file
@@ -0,0 +1,28 @@
|
||||
/* Copyright 2017 New Vector Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
|
||||
-- A subset of remote users whose profiles we have cached.
|
||||
-- Whether a user is in this table or not is defined by the storage function
|
||||
-- `is_subscribed_remote_profile_for_user`
|
||||
CREATE TABLE remote_profile_cache (
|
||||
user_id TEXT NOT NULL,
|
||||
displayname TEXT,
|
||||
avatar_url TEXT,
|
||||
last_check BIGINT NOT NULL
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX remote_profile_cache_user_id ON remote_profile_cache(user_id);
|
||||
CREATE INDEX remote_profile_cache_time ON remote_profile_cache(last_check);
|
||||
@@ -45,6 +45,7 @@ class EventSources(object):
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
to_device_key = self.store.get_to_device_stream_token()
|
||||
device_list_key = self.store.get_device_stream_token()
|
||||
groups_key = self.store.get_group_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=(
|
||||
@@ -65,6 +66,7 @@ class EventSources(object):
|
||||
push_rules_key=push_rules_key,
|
||||
to_device_key=to_device_key,
|
||||
device_list_key=device_list_key,
|
||||
groups_key=groups_key,
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
||||
@@ -73,6 +75,7 @@ class EventSources(object):
|
||||
push_rules_key, _ = self.store.get_push_rules_stream_token()
|
||||
to_device_key = self.store.get_to_device_stream_token()
|
||||
device_list_key = self.store.get_device_stream_token()
|
||||
groups_key = self.store.get_group_stream_token()
|
||||
|
||||
token = StreamToken(
|
||||
room_key=(
|
||||
@@ -93,5 +96,6 @@ class EventSources(object):
|
||||
push_rules_key=push_rules_key,
|
||||
to_device_key=to_device_key,
|
||||
device_list_key=device_list_key,
|
||||
groups_key=groups_key,
|
||||
)
|
||||
defer.returnValue(token)
|
||||
|
||||
@@ -156,6 +156,11 @@ class EventID(DomainSpecificString):
|
||||
SIGIL = "$"
|
||||
|
||||
|
||||
class GroupID(DomainSpecificString):
|
||||
"""Structure representing a group ID."""
|
||||
SIGIL = "+"
|
||||
|
||||
|
||||
class StreamToken(
|
||||
namedtuple("Token", (
|
||||
"room_key",
|
||||
@@ -166,6 +171,7 @@ class StreamToken(
|
||||
"push_rules_key",
|
||||
"to_device_key",
|
||||
"device_list_key",
|
||||
"groups_key",
|
||||
))
|
||||
):
|
||||
_SEPARATOR = "_"
|
||||
@@ -204,6 +210,7 @@ class StreamToken(
|
||||
or (int(other.push_rules_key) < int(self.push_rules_key))
|
||||
or (int(other.to_device_key) < int(self.to_device_key))
|
||||
or (int(other.device_list_key) < int(self.device_list_key))
|
||||
or (int(other.groups_key) < int(self.groups_key))
|
||||
)
|
||||
|
||||
def copy_and_advance(self, key, new_value):
|
||||
|
||||
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
|
||||
from .logcontext import (
|
||||
PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
|
||||
)
|
||||
from synapse.util import unwrapFirstError
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
@@ -53,6 +53,11 @@ class ObservableDeferred(object):
|
||||
|
||||
Cancelling or otherwise resolving an observer will not affect the original
|
||||
ObservableDeferred.
|
||||
|
||||
NB that it does not attempt to do anything with logcontexts; in general
|
||||
you should probably make_deferred_yieldable the deferreds
|
||||
returned by `observe`, and ensure that the original deferred runs its
|
||||
callbacks in the sentinel logcontext.
|
||||
"""
|
||||
|
||||
__slots__ = ["_deferred", "_observers", "_result"]
|
||||
@@ -155,7 +160,7 @@ def concurrently_execute(func, args, limit):
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
return preserve_context_over_deferred(defer.gatherResults([
|
||||
return logcontext.make_deferred_yieldable(defer.gatherResults([
|
||||
preserve_fn(_concurrently_execute_inner)()
|
||||
for _ in xrange(limit)
|
||||
], consumeErrors=True)).addErrback(unwrapFirstError)
|
||||
@@ -203,7 +208,26 @@ class Linearizer(object):
|
||||
except:
|
||||
logger.exception("Unexpected exception in Linearizer")
|
||||
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name, key)
|
||||
logger.info("Acquired linearizer lock %r for key %r", self.name,
|
||||
key)
|
||||
|
||||
# if the code holding the lock completes synchronously, then it
|
||||
# will recursively run the next claimant on the list. That can
|
||||
# relatively rapidly lead to stack exhaustion. This is essentially
|
||||
# the same problem as http://twistedmatrix.com/trac/ticket/9304.
|
||||
#
|
||||
# In order to break the cycle, we add a cheeky sleep(0) here to
|
||||
# ensure that we fall back to the reactor between each iteration.
|
||||
#
|
||||
# (There's no particular need for it to happen before we return
|
||||
# the context manager, but it needs to happen while we hold the
|
||||
# lock, and the context manager's exit code must be synchronous,
|
||||
# so actually this is the only sensible place.
|
||||
yield run_on_reactor()
|
||||
|
||||
else:
|
||||
logger.info("Acquired uncontended linearizer lock %r for key %r",
|
||||
self.name, key)
|
||||
|
||||
@contextmanager
|
||||
def _ctx_manager():
|
||||
@@ -211,7 +235,8 @@ class Linearizer(object):
|
||||
yield
|
||||
finally:
|
||||
logger.info("Releasing linearizer lock %r for key %r", self.name, key)
|
||||
new_defer.callback(None)
|
||||
with PreserveLoggingContext():
|
||||
new_defer.callback(None)
|
||||
current_d = self.key_to_defer.get(key)
|
||||
if current_d is new_defer:
|
||||
self.key_to_defer.pop(key, None)
|
||||
|
||||
51
synapse/util/logformatter.py
Normal file
51
synapse/util/logformatter.py
Normal file
@@ -0,0 +1,51 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import StringIO
|
||||
import logging
|
||||
import traceback
|
||||
|
||||
|
||||
class LogFormatter(logging.Formatter):
|
||||
"""Log formatter which gives more detail for exceptions
|
||||
|
||||
This is the same as the standard log formatter, except that when logging
|
||||
exceptions [typically via log.foo("msg", exc_info=1)], it prints the
|
||||
sequence that led up to the point at which the exception was caught.
|
||||
(Normally only stack frames between the point the exception was raised and
|
||||
where it was caught are logged).
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(LogFormatter, self).__init__(*args, **kwargs)
|
||||
|
||||
def formatException(self, ei):
|
||||
sio = StringIO.StringIO()
|
||||
(typ, val, tb) = ei
|
||||
|
||||
# log the stack above the exception capture point if possible, but
|
||||
# check that we actually have an f_back attribute to work around
|
||||
# https://twistedmatrix.com/trac/ticket/9305
|
||||
|
||||
if tb and hasattr(tb.tb_frame, 'f_back'):
|
||||
sio.write("Capture point (most recent call last):\n")
|
||||
traceback.print_stack(tb.tb_frame.f_back, None, sio)
|
||||
|
||||
traceback.print_exception(typ, val, tb, None, sio)
|
||||
s = sio.getvalue()
|
||||
sio.close()
|
||||
if s[-1:] == "\n":
|
||||
s = s[:-1]
|
||||
return s
|
||||
42
synapse/util/module_loader.py
Normal file
42
synapse/util/module_loader.py
Normal file
@@ -0,0 +1,42 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib
|
||||
|
||||
from synapse.config._base import ConfigError
|
||||
|
||||
|
||||
def load_module(provider):
|
||||
""" Loads a module with its config
|
||||
Take a dict with keys 'module' (the module name) and 'config'
|
||||
(the config dict).
|
||||
|
||||
Returns
|
||||
Tuple of (provider class, parsed config object)
|
||||
"""
|
||||
# We need to import the module, and then pick the class out of
|
||||
# that, so we split based on the last dot.
|
||||
module, clz = provider['module'].rsplit(".", 1)
|
||||
module = importlib.import_module(module)
|
||||
provider_class = getattr(module, clz)
|
||||
|
||||
try:
|
||||
provider_config = provider_class.parse_config(provider["config"])
|
||||
except Exception as e:
|
||||
raise ConfigError(
|
||||
"Failed to parse config for %r: %r" % (provider['module'], e)
|
||||
)
|
||||
|
||||
return provider_class, provider_config
|
||||
@@ -62,8 +62,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||
self.ratelimiter = hs.get_ratelimiter()
|
||||
self.ratelimiter.send_message.return_value = (True, 0)
|
||||
|
||||
hs.handlers = ProfileHandlers(hs)
|
||||
|
||||
self.store = hs.get_datastore()
|
||||
|
||||
self.frank = UserID.from_string("@1234ABCD:test")
|
||||
@@ -72,7 +70,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
yield self.store.create_profile(self.frank.localpart)
|
||||
|
||||
self.handler = hs.get_handlers().profile_handler
|
||||
self.handler = hs.get_profile_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_get_my_name(self):
|
||||
|
||||
@@ -40,13 +40,14 @@ class RegistrationTestCase(unittest.TestCase):
|
||||
self.hs = yield setup_test_homeserver(
|
||||
handlers=None,
|
||||
http_client=None,
|
||||
expire_access_token=True)
|
||||
expire_access_token=True,
|
||||
profile_handler=Mock(),
|
||||
)
|
||||
self.macaroon_generator = Mock(
|
||||
generate_access_token=Mock(return_value='secret'))
|
||||
self.hs.get_macaroon_generator = Mock(return_value=self.macaroon_generator)
|
||||
self.hs.handlers = RegistrationHandlers(self.hs)
|
||||
self.handler = self.hs.get_handlers().registration_handler
|
||||
self.hs.get_handlers().profile_handler = Mock()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_user_is_created_and_logged_in_if_doesnt_exist(self):
|
||||
|
||||
@@ -46,6 +46,7 @@ class ProfileTestCase(unittest.TestCase):
|
||||
resource_for_client=self.mock_resource,
|
||||
federation=Mock(),
|
||||
replication_layer=Mock(),
|
||||
profile_handler=self.mock_handler
|
||||
)
|
||||
|
||||
def _get_user_by_req(request=None, allow_guest=False):
|
||||
@@ -53,8 +54,6 @@ class ProfileTestCase(unittest.TestCase):
|
||||
|
||||
hs.get_v1auth().get_user_by_req = _get_user_by_req
|
||||
|
||||
hs.get_handlers().profile_handler = self.mock_handler
|
||||
|
||||
profile.register_servlets(hs, self.mock_resource)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -1032,7 +1032,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_topo_token_is_accepted(self):
|
||||
token = "t1-0_0_0_0_0_0_0_0"
|
||||
token = "t1-0_0_0_0_0_0_0_0_0"
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||
(self.room_id, token))
|
||||
@@ -1044,7 +1044,7 @@ class RoomMessageListTestCase(RestTestCase):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_stream_token_is_accepted_for_fwd_pagianation(self):
|
||||
token = "s0_0_0_0_0_0_0_0"
|
||||
token = "s0_0_0_0_0_0_0_0_0"
|
||||
(code, response) = yield self.mock_resource.trigger_get(
|
||||
"/rooms/%s/messages?access_token=x&from=%s" %
|
||||
(self.room_id, token))
|
||||
|
||||
@@ -47,6 +47,7 @@ class RegisterRestServletTestCase(unittest.TestCase):
|
||||
self.hs.get_auth_handler = Mock(return_value=self.auth_handler)
|
||||
self.hs.get_device_handler = Mock(return_value=self.device_handler)
|
||||
self.hs.config.enable_registration = True
|
||||
self.hs.config.auto_join_rooms = []
|
||||
|
||||
# init the thing we're testing
|
||||
self.servlet = RegisterRestServlet(self.hs)
|
||||
|
||||
@@ -1,76 +0,0 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2015, 2016 OpenMarket Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
|
||||
|
||||
class EventInjector:
|
||||
def __init__(self, hs):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastore()
|
||||
self.message_handler = hs.get_handlers().message_handler
|
||||
self.event_builder_factory = hs.get_event_builder_factory()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def create_room(self, room, user):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Create,
|
||||
"sender": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_room_member(self, room, user, membership):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Member,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"membership": membership},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
|
||||
defer.returnValue(event)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def inject_message(self, room, user, body):
|
||||
builder = self.event_builder_factory.new({
|
||||
"type": EventTypes.Message,
|
||||
"sender": user.to_string(),
|
||||
"state_key": user.to_string(),
|
||||
"room_id": room.to_string(),
|
||||
"content": {"body": body, "msgtype": u"message"},
|
||||
})
|
||||
|
||||
event, context = yield self.message_handler._create_new_client_event(
|
||||
builder
|
||||
)
|
||||
|
||||
yield self.store.persist_event(event, context)
|
||||
@@ -12,8 +12,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from synapse.util import async, logcontext
|
||||
from tests import unittest
|
||||
|
||||
from twisted.internet import defer
|
||||
@@ -38,7 +37,28 @@ class LinearizerTestCase(unittest.TestCase):
|
||||
with cm1:
|
||||
self.assertFalse(d2.called)
|
||||
|
||||
self.assertTrue(d2.called)
|
||||
|
||||
with (yield d2):
|
||||
pass
|
||||
|
||||
def test_lots_of_queued_things(self):
|
||||
# we have one slow thing, and lots of fast things queued up behind it.
|
||||
# it should *not* explode the stack.
|
||||
linearizer = Linearizer()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def func(i, sleep=False):
|
||||
with logcontext.LoggingContext("func(%s)" % i) as lc:
|
||||
with (yield linearizer.queue("")):
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
if sleep:
|
||||
yield async.sleep(0)
|
||||
|
||||
self.assertEqual(
|
||||
logcontext.LoggingContext.current_context(), lc)
|
||||
|
||||
func(0, sleep=True)
|
||||
for i in xrange(1, 100):
|
||||
func(i)
|
||||
|
||||
return func(1000)
|
||||
|
||||
@@ -94,3 +94,41 @@ class LoggingContextTestCase(unittest.TestCase):
|
||||
yield defer.succeed(None)
|
||||
|
||||
return self._test_preserve_fn(nonblocking_function)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable(self):
|
||||
# a function which retuns an incomplete deferred, but doesn't follow
|
||||
# the synapse rules.
|
||||
def blocking_function():
|
||||
d = defer.Deferred()
|
||||
reactor.callLater(0, d.callback, None)
|
||||
return d
|
||||
|
||||
sentinel_context = LoggingContext.current_context()
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
d1 = logcontext.make_deferred_yieldable(blocking_function())
|
||||
# make sure that the context was reset by make_deferred_yieldable
|
||||
self.assertIs(LoggingContext.current_context(), sentinel_context)
|
||||
|
||||
yield d1
|
||||
|
||||
# now it should be restored
|
||||
self._check_test_key("one")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def test_make_deferred_yieldable_on_non_deferred(self):
|
||||
"""Check that make_deferred_yieldable does the right thing when its
|
||||
argument isn't actually a deferred"""
|
||||
|
||||
with LoggingContext() as context_one:
|
||||
context_one.test_key = "one"
|
||||
|
||||
d1 = logcontext.make_deferred_yieldable("bum")
|
||||
self._check_test_key("one")
|
||||
|
||||
r = yield d1
|
||||
self.assertEqual(r, "bum")
|
||||
self._check_test_key("one")
|
||||
Reference in New Issue
Block a user