1
0

Merge branch 'develop' of github.com:matrix-org/synapse into neilj/disable-mau-alerting-for-small-instances

This commit is contained in:
Neil Johnson
2019-10-11 09:41:10 +01:00
113 changed files with 2008 additions and 1120 deletions

1
.gitignore vendored
View File

@@ -7,6 +7,7 @@
*.egg-info
*.lock
*.pyc
*.snap
*.tac
_trial_temp/
_trial_temp*/

View File

@@ -47,7 +47,5 @@ prune debian
prune demo/etc
prune docker
prune mypy.ini
prune snap
prune stubs
exclude jenkins*
recursive-exclude jenkins *.sh

1
changelog.d/2380.bugfix Normal file
View File

@@ -0,0 +1 @@
Return an HTTP 404 instead of 400 when requesting a filter by ID that is unknown to the server. Thanks to @krombel for contributing this!

1
changelog.d/3436.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix a problem where users could be invited twice to the same group.

1
changelog.d/4088.bugfix Normal file
View File

@@ -0,0 +1 @@
Added domain validation when including a list of invitees upon room creation.

1
changelog.d/6084.misc Normal file
View File

@@ -0,0 +1 @@
Add snapcraft packaging information. Contributed by @devec0.

1
changelog.d/6109.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug when uploading a large file: Synapse responds with `M_UNKNOWN` while it should be `M_TOO_LARGE` according to spec. Contributed by Anshul Angaria.

1
changelog.d/6127.misc Normal file
View File

@@ -0,0 +1 @@
Add env var to turn on tracking of log context changes.

1
changelog.d/6137.misc Normal file
View File

@@ -0,0 +1 @@
Refactor configuration loading to allow better typechecking.

1
changelog.d/6139.misc Normal file
View File

@@ -0,0 +1 @@
Log responder when responding to media request.

1
changelog.d/6155.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix transferring notifications and tags when joining an upgraded room that is new to your server.

1
changelog.d/6156.misc Normal file
View File

@@ -0,0 +1 @@
Use Postgres ANY for selecting many values.

1
changelog.d/6159.misc Normal file
View File

@@ -0,0 +1 @@
Add more caching to `_get_joined_users_from_context` DB query.

1
changelog.d/6161.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug where guest account registration can wedge after restart.

1
changelog.d/6167.misc Normal file
View File

@@ -0,0 +1 @@
Add some logging to the rooms stats updates, to try to track down a flaky test.

1
changelog.d/6168.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix monthly active user reaping where reserved users are specified.

1
changelog.d/6170.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix /federation/v1/state endpoint for recent room versions.

1
changelog.d/6175.misc Normal file
View File

@@ -0,0 +1 @@
Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this.

1
changelog.d/6178.bugfix Normal file
View File

@@ -0,0 +1 @@
Make the `synapse_port_db` script create the right indexes on a new PostgreSQL database.

1
changelog.d/6179.misc Normal file
View File

@@ -0,0 +1 @@
Remove unused `timeout` parameter from `_get_public_room_list`.

1
changelog.d/6184.misc Normal file
View File

@@ -0,0 +1 @@
Update `user_filters` table to have a unique index, and non-null columns. Thanks to @pik for contributing this.

1
changelog.d/6185.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug where redacted events were sometimes incorrectly censored in the database, breaking APIs that attempted to fetch such events.

1
changelog.d/6186.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix bug where we were updating censored events as bytes rather than text, occaisonally causing invalid JSON being inserted breaking APIs that attempted to fetch such events.

1
changelog.d/6187.bugfix Normal file
View File

@@ -0,0 +1 @@
Fix occasional missed updates in the room and user directories.

1
changelog.d/6191.misc Normal file
View File

@@ -0,0 +1 @@
Add snapcraft packaging information. Contributed by @devec0.

View File

@@ -4,10 +4,6 @@ plugins=mypy_zope:plugin
follow_imports=skip
mypy_path=stubs
[mypy-synapse.config.homeserver]
# this is a mess because of the metaclass shenanigans
ignore_errors = True
[mypy-zope]
ignore_missing_imports = True
@@ -52,3 +48,15 @@ ignore_missing_imports = True
[mypy-signedjson.*]
ignore_missing_imports = True
[mypy-prometheus_client.*]
ignore_missing_imports = True
[mypy-service_identity.*]
ignore_missing_imports = True
[mypy-daemonize]
ignore_missing_imports = True
[mypy-sentry_sdk]
ignore_missing_imports = True

22
snap/snapcraft.yaml Normal file
View File

@@ -0,0 +1,22 @@
name: matrix-synapse
base: core18
version: git
summary: Reference Matrix homeserver
description: |
Synapse is the reference Matrix homeserver.
Matrix is a federated and decentralised instant messaging and VoIP system.
grade: stable
confinement: strict
apps:
matrix-synapse:
command: synctl --no-daemonize start $SNAP_COMMON/homeserver.yaml
stop-command: synctl -c $SNAP_COMMON stop
plugs: [network-bind, network]
daemon: simple
parts:
matrix-synapse:
source: .
plugin: python
python-version: python3

View File

@@ -17,6 +17,7 @@
""" This is a reference implementation of a Matrix home server.
"""
import os
import sys
# Check that we're not running on an unsupported Python version.
@@ -36,3 +37,10 @@ except ImportError:
pass
__version__ = "1.4.0"
if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
# We import here so that we don't have to install a bunch of deps when
# running the packaging tox test.
from synapse.util.patch_inline_callbacks import do_patch
do_patch()

View File

@@ -605,13 +605,13 @@ def run(hs):
@defer.inlineCallbacks
def generate_monthly_active_users():
current_mau_count = 0
reserved_count = 0
reserved_users = ()
store = hs.get_datastore()
if hs.config.limit_usage_by_mau or hs.config.mau_stats_only:
current_mau_count = yield store.get_monthly_active_count()
reserved_count = yield store.get_registered_reserved_users_count()
reserved_users = yield store.get_registered_reserved_users()
current_mau_gauge.set(float(current_mau_count))
registered_reserved_users_mau_gauge.set(float(reserved_count))
registered_reserved_users_mau_gauge.set(float(len(reserved_users)))
max_mau_gauge.set(float(hs.config.max_mau_value))
def start_generate_monthly_active_users():

View File

@@ -18,7 +18,9 @@
import argparse
import errno
import os
from collections import OrderedDict
from textwrap import dedent
from typing import Any, MutableMapping, Optional
from six import integer_types
@@ -51,7 +53,56 @@ Missing mandatory `server_name` config option.
"""
def path_exists(file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
class Config(object):
"""
A configuration section, containing configuration keys and values.
Attributes:
section (str): The section title of this config object, such as
"tls" or "logger". This is used to refer to it on the root
logger (for example, `config.tls.some_option`). Must be
defined in subclasses.
"""
section = None
def __init__(self, root_config=None):
self.root = root_config
def __getattr__(self, item: str) -> Any:
"""
Try and fetch a configuration option that does not exist on this class.
This is so that existing configs that rely on `self.value`, where value
is actually from a different config section, continue to work.
"""
if item in ["generate_config_section", "read_config"]:
raise AttributeError(item)
if self.root is None:
raise AttributeError(item)
else:
return self.root._get_unclassed_config(self.section, item)
@staticmethod
def parse_size(value):
if isinstance(value, integer_types):
@@ -88,22 +139,7 @@ class Config(object):
@classmethod
def path_exists(cls, file_path):
"""Check if a file exists
Unlike os.path.exists, this throws an exception if there is an error
checking if the file exists (for example, if there is a perms error on
the parent dir).
Returns:
bool: True if the file exists; False if not.
"""
try:
os.stat(file_path)
return True
except OSError as e:
if e.errno != errno.ENOENT:
raise e
return False
return path_exists(file_path)
@classmethod
def check_file(cls, file_path, config_name):
@@ -136,42 +172,106 @@ class Config(object):
with open(file_path) as file_stream:
return file_stream.read()
def invoke_all(self, name, *args, **kargs):
"""Invoke all instance methods with the given name and arguments in the
class's MRO.
class RootConfig(object):
"""
Holder of an application's configuration.
What configuration this object holds is defined by `config_classes`, a list
of Config classes that will be instantiated and given the contents of a
configuration file to read. They can then be accessed on this class by their
section name, defined in the Config or dynamically set to be the name of the
class, lower-cased and with "Config" removed.
"""
config_classes = []
def __init__(self):
self._configs = OrderedDict()
for config_class in self.config_classes:
if config_class.section is None:
raise ValueError("%r requires a section name" % (config_class,))
try:
conf = config_class(self)
except Exception as e:
raise Exception("Failed making %s: %r" % (config_class.section, e))
self._configs[config_class.section] = conf
def __getattr__(self, item: str) -> Any:
"""
Redirect lookups on this object either to config objects, or values on
config objects, so that `config.tls.blah` works, as well as legacy uses
of things like `config.server_name`. It will first look up the config
section name, and then values on those config classes.
"""
if item in self._configs.keys():
return self._configs[item]
return self._get_unclassed_config(None, item)
def _get_unclassed_config(self, asking_section: Optional[str], item: str):
"""
Fetch a config value from one of the instantiated config classes that
has not been fetched directly.
Args:
name (str): Name of function to invoke
asking_section: If this check is coming from a Config child, which
one? This section will not be asked if it has the value.
item: The configuration value key.
Raises:
AttributeError if no config classes have the config key. The body
will contain what sections were checked.
"""
for key, val in self._configs.items():
if key == asking_section:
continue
if item in dir(val):
return getattr(val, item)
raise AttributeError(item, "not found in %s" % (list(self._configs.keys()),))
def invoke_all(self, func_name: str, *args, **kwargs) -> MutableMapping[str, Any]:
"""
Invoke a function on all instantiated config objects this RootConfig is
configured to use.
Args:
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for cls in type(self).mro():
if name in cls.__dict__:
results.append(getattr(cls, name)(self, *args, **kargs))
return results
res = OrderedDict()
for name, config in self._configs.items():
if hasattr(config, func_name):
res[name] = getattr(config, func_name)(*args, **kwargs)
return res
@classmethod
def invoke_all_static(cls, name, *args, **kargs):
"""Invoke all static methods with the given name and arguments in the
class's MRO.
def invoke_all_static(cls, func_name: str, *args, **kwargs):
"""
Invoke a static function on config objects this RootConfig is
configured to use.
Args:
name (str): Name of function to invoke
func_name: Name of function to invoke
*args
**kwargs
Returns:
list: The list of the return values from each method called
ordered dictionary of config section name and the result of the
function from it.
"""
results = []
for c in cls.mro():
if name in c.__dict__:
results.append(getattr(c, name)(*args, **kargs))
return results
for config in cls.config_classes:
if hasattr(config, func_name):
getattr(config, func_name)(*args, **kwargs)
def generate_config(
self,
@@ -187,7 +287,8 @@ class Config(object):
tls_private_key_path=None,
acme_domain=None,
):
"""Build a default configuration file
"""
Build a default configuration file
This is used when the user explicitly asks us to generate a config file
(eg with --generate_config).
@@ -242,6 +343,7 @@ class Config(object):
Returns:
str: the yaml config file
"""
return "\n\n".join(
dedent(conf)
for conf in self.invoke_all(
@@ -257,7 +359,7 @@ class Config(object):
tls_certificate_path=tls_certificate_path,
tls_private_key_path=tls_private_key_path,
acme_domain=acme_domain,
)
).values()
)
@classmethod
@@ -444,7 +546,7 @@ class Config(object):
)
(config_path,) = config_files
if not cls.path_exists(config_path):
if not path_exists(config_path):
print("Generating config file %s" % (config_path,))
if config_args.data_directory:
@@ -469,7 +571,7 @@ class Config(object):
open_private_ports=config_args.open_private_ports,
)
if not cls.path_exists(config_dir_path):
if not path_exists(config_dir_path):
os.makedirs(config_dir_path)
with open(config_path, "w") as config_file:
config_file.write("# vim:ft=yaml\n\n")
@@ -518,7 +620,7 @@ class Config(object):
return obj
def parse_config_dict(self, config_dict, config_dir_path, data_dir_path):
def parse_config_dict(self, config_dict, config_dir_path=None, data_dir_path=None):
"""Read the information from the config dict into this Config object.
Args:
@@ -607,3 +709,6 @@ def find_config_files(search_paths):
else:
config_files.append(config_path)
return config_files
__all__ = ["Config", "RootConfig"]

135
synapse/config/_base.pyi Normal file
View File

@@ -0,0 +1,135 @@
from typing import Any, List, Optional
from synapse.config import (
api,
appservice,
captcha,
cas,
consent_config,
database,
emailconfig,
groups,
jwt_config,
key,
logger,
metrics,
password,
password_auth_providers,
push,
ratelimiting,
registration,
repository,
room_directory,
saml2_config,
server,
server_notices_config,
spam_checker,
stats,
third_party_event_rules,
tls,
tracer,
user_directory,
voip,
workers,
)
class ConfigError(Exception): ...
MISSING_REPORT_STATS_CONFIG_INSTRUCTIONS: str
MISSING_REPORT_STATS_SPIEL: str
MISSING_SERVER_NAME: str
def path_exists(file_path: str): ...
class RootConfig:
server: server.ServerConfig
tls: tls.TlsConfig
database: database.DatabaseConfig
logging: logger.LoggingConfig
ratelimit: ratelimiting.RatelimitConfig
media: repository.ContentRepositoryConfig
captcha: captcha.CaptchaConfig
voip: voip.VoipConfig
registration: registration.RegistrationConfig
metrics: metrics.MetricsConfig
api: api.ApiConfig
appservice: appservice.AppServiceConfig
key: key.KeyConfig
saml2: saml2_config.SAML2Config
cas: cas.CasConfig
jwt: jwt_config.JWTConfig
password: password.PasswordConfig
email: emailconfig.EmailConfig
worker: workers.WorkerConfig
authproviders: password_auth_providers.PasswordAuthProviderConfig
push: push.PushConfig
spamchecker: spam_checker.SpamCheckerConfig
groups: groups.GroupsConfig
userdirectory: user_directory.UserDirectoryConfig
consent: consent_config.ConsentConfig
stats: stats.StatsConfig
servernotices: server_notices_config.ServerNoticesConfig
roomdirectory: room_directory.RoomDirectoryConfig
thirdpartyrules: third_party_event_rules.ThirdPartyRulesConfig
tracer: tracer.TracerConfig
config_classes: List = ...
def __init__(self) -> None: ...
def invoke_all(self, func_name: str, *args: Any, **kwargs: Any): ...
@classmethod
def invoke_all_static(cls, func_name: str, *args: Any, **kwargs: Any) -> None: ...
def __getattr__(self, item: str): ...
def parse_config_dict(
self,
config_dict: Any,
config_dir_path: Optional[Any] = ...,
data_dir_path: Optional[Any] = ...,
) -> None: ...
read_config: Any = ...
def generate_config(
self,
config_dir_path: str,
data_dir_path: str,
server_name: str,
generate_secrets: bool = ...,
report_stats: Optional[str] = ...,
open_private_ports: bool = ...,
listeners: Optional[Any] = ...,
database_conf: Optional[Any] = ...,
tls_certificate_path: Optional[str] = ...,
tls_private_key_path: Optional[str] = ...,
acme_domain: Optional[str] = ...,
): ...
@classmethod
def load_or_generate_config(cls, description: Any, argv: Any): ...
@classmethod
def load_config(cls, description: Any, argv: Any): ...
@classmethod
def add_arguments_to_parser(cls, config_parser: Any) -> None: ...
@classmethod
def load_config_with_parser(cls, parser: Any, argv: Any): ...
def generate_missing_files(
self, config_dict: dict, config_dir_path: str
) -> None: ...
class Config:
root: RootConfig
def __init__(self, root_config: Optional[RootConfig] = ...) -> None: ...
def __getattr__(self, item: str, from_root: bool = ...): ...
@staticmethod
def parse_size(value: Any): ...
@staticmethod
def parse_duration(value: Any): ...
@staticmethod
def abspath(file_path: Optional[str]): ...
@classmethod
def path_exists(cls, file_path: str): ...
@classmethod
def check_file(cls, file_path: str, config_name: str): ...
@classmethod
def ensure_directory(cls, dir_path: str): ...
@classmethod
def read_file(cls, file_path: str, config_name: str): ...
def read_config_files(config_files: List[str]): ...
def find_config_files(search_paths: List[str]): ...

View File

@@ -18,6 +18,8 @@ from ._base import Config
class ApiConfig(Config):
section = "api"
def read_config(self, config, **kwargs):
self.room_invite_state_types = config.get(
"room_invite_state_types",

View File

@@ -30,6 +30,8 @@ logger = logging.getLogger(__name__)
class AppServiceConfig(Config):
section = "appservice"
def read_config(self, config, **kwargs):
self.app_service_config_files = config.get("app_service_config_files", [])
self.notify_appservices = config.get("notify_appservices", True)

View File

@@ -16,6 +16,8 @@ from ._base import Config
class CaptchaConfig(Config):
section = "captcha"
def read_config(self, config, **kwargs):
self.recaptcha_private_key = config.get("recaptcha_private_key")
self.recaptcha_public_key = config.get("recaptcha_public_key")

View File

@@ -22,6 +22,8 @@ class CasConfig(Config):
cas_server_url: URL of CAS server
"""
section = "cas"
def read_config(self, config, **kwargs):
cas_config = config.get("cas_config", None)
if cas_config:

View File

@@ -73,6 +73,9 @@ DEFAULT_CONFIG = """\
class ConsentConfig(Config):
section = "consent"
def __init__(self, *args):
super(ConsentConfig, self).__init__(*args)

View File

@@ -21,6 +21,8 @@ from ._base import Config
class DatabaseConfig(Config):
section = "database"
def read_config(self, config, **kwargs):
self.event_cache_size = self.parse_size(config.get("event_cache_size", "10K"))

View File

@@ -28,6 +28,8 @@ from ._base import Config, ConfigError
class EmailConfig(Config):
section = "email"
def read_config(self, config, **kwargs):
# TODO: We should separate better the email configuration from the notification
# and account validity config.

View File

@@ -17,6 +17,8 @@ from ._base import Config
class GroupsConfig(Config):
section = "groups"
def read_config(self, config, **kwargs):
self.enable_group_creation = config.get("enable_group_creation", False)
self.group_creation_prefix = config.get("group_creation_prefix", "")

View File

@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from ._base import RootConfig
from .api import ApiConfig
from .appservice import AppServiceConfig
from .captcha import CaptchaConfig
@@ -46,36 +47,37 @@ from .voip import VoipConfig
from .workers import WorkerConfig
class HomeServerConfig(
ServerConfig,
TlsConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
KeyConfig,
SAML2Config,
CasConfig,
JWTConfig,
PasswordConfig,
EmailConfig,
WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
):
pass
class HomeServerConfig(RootConfig):
config_classes = [
ServerConfig,
TlsConfig,
DatabaseConfig,
LoggingConfig,
RatelimitConfig,
ContentRepositoryConfig,
CaptchaConfig,
VoipConfig,
RegistrationConfig,
MetricsConfig,
ApiConfig,
AppServiceConfig,
KeyConfig,
SAML2Config,
CasConfig,
JWTConfig,
PasswordConfig,
EmailConfig,
WorkerConfig,
PasswordAuthProviderConfig,
PushConfig,
SpamCheckerConfig,
GroupsConfig,
UserDirectoryConfig,
ConsentConfig,
StatsConfig,
ServerNoticesConfig,
RoomDirectoryConfig,
ThirdPartyRulesConfig,
TracerConfig,
]

View File

@@ -23,6 +23,8 @@ MISSING_JWT = """Missing jwt library. This is required for jwt login.
class JWTConfig(Config):
section = "jwt"
def read_config(self, config, **kwargs):
jwt_config = config.get("jwt_config", None)
if jwt_config:

View File

@@ -92,6 +92,8 @@ class TrustedKeyServer(object):
class KeyConfig(Config):
section = "key"
def read_config(self, config, config_dir_path, **kwargs):
# the signing key can be specified inline or in a separate file
if "signing_key" in config:

View File

@@ -84,6 +84,8 @@ root:
class LoggingConfig(Config):
section = "logging"
def read_config(self, config, **kwargs):
self.log_config = self.abspath(config.get("log_config"))
self.no_redirect_stdio = config.get("no_redirect_stdio", False)

View File

@@ -34,6 +34,8 @@ class MetricsFlags(object):
class MetricsConfig(Config):
section = "metrics"
def read_config(self, config, **kwargs):
self.enable_metrics = config.get("enable_metrics", False)
self.report_stats = config.get("report_stats", None)

View File

@@ -20,6 +20,8 @@ class PasswordConfig(Config):
"""Password login configuration
"""
section = "password"
def read_config(self, config, **kwargs):
password_config = config.get("password_config", {})
if password_config is None:

View File

@@ -23,6 +23,8 @@ LDAP_PROVIDER = "ldap_auth_provider.LdapAuthProvider"
class PasswordAuthProviderConfig(Config):
section = "authproviders"
def read_config(self, config, **kwargs):
self.password_providers = [] # type: List[Any]
providers = []

View File

@@ -18,6 +18,8 @@ from ._base import Config
class PushConfig(Config):
section = "push"
def read_config(self, config, **kwargs):
push_config = config.get("push", {})
self.push_include_content = push_config.get("include_content", True)

View File

@@ -36,6 +36,8 @@ class FederationRateLimitConfig(object):
class RatelimitConfig(Config):
section = "ratelimiting"
def read_config(self, config, **kwargs):
# Load the new-style messages config if it exists. Otherwise fall back

View File

@@ -24,6 +24,8 @@ from synapse.util.stringutils import random_string_with_symbols
class AccountValidityConfig(Config):
section = "accountvalidity"
def __init__(self, config, synapse_config):
self.enabled = config.get("enabled", False)
self.renew_by_email_enabled = "renew_at" in config
@@ -77,6 +79,8 @@ class AccountValidityConfig(Config):
class RegistrationConfig(Config):
section = "registration"
def read_config(self, config, **kwargs):
self.enable_registration = bool(
strtobool(str(config.get("enable_registration", False)))

View File

@@ -78,6 +78,8 @@ def parse_thumbnail_requirements(thumbnail_sizes):
class ContentRepositoryConfig(Config):
section = "media"
def read_config(self, config, **kwargs):
# Only enable the media repo if either the media repo is enabled or the

View File

@@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class RoomDirectoryConfig(Config):
section = "roomdirectory"
def read_config(self, config, **kwargs):
self.enable_room_list_search = config.get("enable_room_list_search", True)

View File

@@ -55,6 +55,8 @@ def _dict_merge(merge_dict, into_dict):
class SAML2Config(Config):
section = "saml2"
def read_config(self, config, **kwargs):
self.saml2_enabled = False

View File

@@ -58,6 +58,8 @@ on how to configure the new listener.
class ServerConfig(Config):
section = "server"
def read_config(self, config, **kwargs):
self.server_name = config["server_name"]
self.server_context = config.get("server_context", None)

View File

@@ -59,6 +59,8 @@ class ServerNoticesConfig(Config):
None if server notices are not enabled.
"""
section = "servernotices"
def __init__(self, *args):
super(ServerNoticesConfig, self).__init__(*args)
self.server_notices_mxid = None

View File

@@ -19,6 +19,8 @@ from ._base import Config
class SpamCheckerConfig(Config):
section = "spamchecker"
def read_config(self, config, **kwargs):
self.spam_checker = None

View File

@@ -25,6 +25,8 @@ class StatsConfig(Config):
Configuration for the behaviour of synapse's stats engine
"""
section = "stats"
def read_config(self, config, **kwargs):
self.stats_enabled = True
self.stats_bucket_size = 86400 * 1000

View File

@@ -19,6 +19,8 @@ from ._base import Config
class ThirdPartyRulesConfig(Config):
section = "thirdpartyrules"
def read_config(self, config, **kwargs):
self.third_party_event_rules = None

View File

@@ -18,6 +18,7 @@ import os
import warnings
from datetime import datetime
from hashlib import sha256
from typing import List
import six
@@ -33,7 +34,9 @@ logger = logging.getLogger(__name__)
class TlsConfig(Config):
def read_config(self, config, config_dir_path, **kwargs):
section = "tls"
def read_config(self, config: dict, config_dir_path: str, **kwargs):
acme_config = config.get("acme", None)
if acme_config is None:
@@ -57,7 +60,7 @@ class TlsConfig(Config):
self.tls_certificate_file = self.abspath(config.get("tls_certificate_path"))
self.tls_private_key_file = self.abspath(config.get("tls_private_key_path"))
if self.has_tls_listener():
if self.root.server.has_tls_listener():
if not self.tls_certificate_file:
raise ConfigError(
"tls_certificate_path must be specified if TLS-enabled listeners are "
@@ -108,7 +111,7 @@ class TlsConfig(Config):
)
# Support globs (*) in whitelist values
self.federation_certificate_verification_whitelist = []
self.federation_certificate_verification_whitelist = [] # type: List[str]
for entry in fed_whitelist_entries:
try:
entry_regex = glob_to_regex(entry.encode("ascii").decode("ascii"))

View File

@@ -19,6 +19,8 @@ from ._base import Config, ConfigError
class TracerConfig(Config):
section = "tracing"
def read_config(self, config, **kwargs):
opentracing_config = config.get("opentracing")
if opentracing_config is None:

View File

@@ -21,6 +21,8 @@ class UserDirectoryConfig(Config):
Configuration for the behaviour of the /user_directory API
"""
section = "userdirectory"
def read_config(self, config, **kwargs):
self.user_directory_search_enabled = True
self.user_directory_search_all_users = False

View File

@@ -16,6 +16,8 @@ from ._base import Config
class VoipConfig(Config):
section = "voip"
def read_config(self, config, **kwargs):
self.turn_uris = config.get("turn_uris", [])
self.turn_shared_secret = config.get("turn_shared_secret")

View File

@@ -21,6 +21,8 @@ class WorkerConfig(Config):
They have their own pid_file and listener configuration. They use the
replication_url to talk to the main synapse process."""
section = "worker"
def read_config(self, config, **kwargs):
self.worker_app = config.get("worker_app")

View File

@@ -36,7 +36,6 @@ from synapse.api.errors import (
UnsupportedRoomVersionError,
)
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.crypto.event_signing import compute_event_signature
from synapse.events import room_version_to_event_format
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
from synapse.federation.persistence import TransactionActions
@@ -322,18 +321,6 @@ class FederationServer(FederationBase):
pdus = yield self.handler.get_state_for_pdu(room_id, event_id)
auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus])
for event in auth_chain:
# We sign these again because there was a bug where we
# incorrectly signed things the first time round
if self.hs.is_mine_id(event.event_id):
event.signatures.update(
compute_event_signature(
event.get_pdu_json(),
self.hs.hostname,
self.hs.config.signing_key[0],
)
)
return {
"pdus": [pdu.get_pdu_json() for pdu in pdus],
"auth_chain": [pdu.get_pdu_json() for pdu in auth_chain],

View File

@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2017 Vector Creations Ltd
# Copyright 2018 New Vector Ltd
# Copyright 2019 Michael Telatynski <7t3chguy@gmail.com>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -20,16 +21,16 @@ from six import string_types
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
from synapse.util.async_helpers import concurrently_execute
logger = logging.getLogger(__name__)
# TODO: Allow users to "knock" or simpkly join depending on rules
# TODO: Allow users to "knock" or simply join depending on rules
# TODO: Federation admin APIs
# TODO: is_priveged flag to users and is_public to users and rooms
# TODO: is_privileged flag to users and is_public to users and rooms
# TODO: Audit log for admins (profile updates, membership changes, users who tried
# to join but were rejected, etc)
# TODO: Flairs
@@ -590,7 +591,18 @@ class GroupsServerHandler(object):
)
# TODO: Check if user knocked
# TODO: Check if user is already invited
invited_users = yield self.store.get_invited_users_in_group(group_id)
if user_id in invited_users:
raise SynapseError(
400, "User already invited to group", errcode=Codes.BAD_STATE
)
user_results = yield self.store.get_users_in_group(
group_id, include_private=True
)
if user_id in [user_result["user_id"] for user_result in user_results]:
raise SynapseError(400, "User already in group")
content = {
"profile": {"name": group["name"], "avatar_url": group["avatar_url"]},

View File

@@ -803,17 +803,25 @@ class PresenceHandler(object):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "presence_delta"):
deltas = yield self.store.get_current_state_deltas(self._event_pos)
if not deltas:
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self._event_pos == room_max_stream_ordering:
return
logger.debug(
"Processing presence stats %s->%s",
self._event_pos,
room_max_stream_ordering,
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
yield self._handle_state_delta(deltas)
self._event_pos = deltas[-1]["stream_id"]
self._event_pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("presence").set(
self._event_pos
max_pos
)
@defer.inlineCallbacks

View File

@@ -217,10 +217,9 @@ class RegistrationHandler(BaseHandler):
else:
# autogen a sequential user ID
attempts = 0
user = None
while not user:
localpart = yield self._generate_user_id(attempts > 0)
localpart = yield self._generate_user_id()
user = UserID(localpart, self.hs.hostname)
user_id = user.to_string()
yield self.check_user_id_not_appservice_exclusive(user_id)
@@ -238,7 +237,6 @@ class RegistrationHandler(BaseHandler):
# if user id is taken, just generate another
user = None
user_id = None
attempts += 1
if not self.hs.config.user_consent_at_registration:
yield self._auto_join_rooms(user_id)
@@ -379,10 +377,10 @@ class RegistrationHandler(BaseHandler):
)
@defer.inlineCallbacks
def _generate_user_id(self, reseed=False):
if reseed or self._next_generated_user_id is None:
def _generate_user_id(self):
if self._next_generated_user_id is None:
with (yield self._generate_user_id_linearizer.queue(())):
if reseed or self._next_generated_user_id is None:
if self._next_generated_user_id is None:
self._next_generated_user_id = (
yield self.store.find_next_generated_user_id_localpart()
)

View File

@@ -28,6 +28,7 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.http.endpoint import parse_and_validate_server_name
from synapse.storage.state import StateFilter
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
from synapse.util import stringutils
@@ -554,7 +555,8 @@ class RoomCreationHandler(BaseHandler):
invite_list = config.get("invite", [])
for i in invite_list:
try:
UserID.from_string(i)
uid = UserID.from_string(i)
parse_and_validate_server_name(uid.domain)
except Exception:
raise SynapseError(400, "Invalid user_id: %s" % (i,))

View File

@@ -88,16 +88,8 @@ class RoomListHandler(BaseHandler):
# appservice specific lists.
logger.info("Bypassing cache as search request.")
# XXX: Quick hack to stop room directory queries taking too long.
# Timeout request after 60s. Probably want a more fundamental
# solution at some point
timeout = self.clock.time() + 60
return self._get_public_room_list(
limit,
since_token,
search_filter,
network_tuple=network_tuple,
timeout=timeout,
limit, since_token, search_filter, network_tuple=network_tuple
)
key = (limit, since_token, network_tuple)
@@ -118,7 +110,6 @@ class RoomListHandler(BaseHandler):
search_filter=None,
network_tuple=EMPTY_THIRD_PARTY_ID,
from_federation=False,
timeout=None,
):
"""Generate a public room list.
Args:
@@ -131,8 +122,6 @@ class RoomListHandler(BaseHandler):
Setting to None returns all public rooms across all lists.
from_federation (bool): Whether this request originated from a
federating server or a client. Used for room filtering.
timeout (int|None): Amount of seconds to wait for a response before
timing out. TODO
"""
# Pagination tokens work by storing the room ID sent in the last batch,

View File

@@ -203,23 +203,11 @@ class RoomMemberHandler(object):
prev_member_event = yield self.store.get_event(prev_member_event_id)
newly_joined = prev_member_event.membership != Membership.JOIN
if newly_joined:
# Copy over user state if we're joining an upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
yield self._user_joined_room(target, room_id)
# Copy over direct message status and room tags if this is a join
# on an upgraded room
# Check if this is an upgraded room
predecessor = yield self.store.get_room_predecessor(room_id)
if predecessor:
# It is an upgraded room. Copy over old tags
self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], room_id, user_id
)
# Copy over push rules
yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], room_id, user_id
)
elif event.membership == Membership.LEAVE:
if prev_member_event_id:
prev_member_event = yield self.store.get_event(prev_member_event_id)
@@ -463,10 +451,16 @@ class RoomMemberHandler(object):
if requester.is_guest:
content["kind"] = "guest"
ret = yield self._remote_join(
remote_join_response = yield self._remote_join(
requester, remote_room_hosts, room_id, target, content
)
return ret
# Copy over user state if this is a join on an remote upgraded room
yield self.copy_user_state_if_room_upgrade(
room_id, requester.user.to_string()
)
return remote_join_response
elif effective_membership_state == Membership.LEAVE:
if not is_host_in_room:
@@ -503,6 +497,38 @@ class RoomMemberHandler(object):
)
return res
@defer.inlineCallbacks
def copy_user_state_if_room_upgrade(self, new_room_id, user_id):
"""Copy user-specific information when they join a new room if that new room is the
result of a room upgrade
Args:
new_room_id (str): The ID of the room the user is joining
user_id (str): The ID of the user
Returns:
Deferred
"""
# Check if the new room is an upgraded room
predecessor = yield self.store.get_room_predecessor(new_room_id)
if not predecessor:
return
logger.debug(
"Found predecessor for %s: %s. Copying over room tags and push " "rules",
new_room_id,
predecessor,
)
# It is an upgraded room. Copy over old tags
yield self.copy_room_tags_and_direct_to_room(
predecessor["room_id"], new_room_id, user_id
)
# Copy over push rules
yield self.store.copy_push_rules_from_room_to_room_for_user(
predecessor["room_id"], new_room_id, user_id
)
@defer.inlineCallbacks
def send_membership_event(self, requester, event, context, ratelimit=True):
"""

View File

@@ -87,21 +87,23 @@ class StatsHandler(StateDeltasHandler):
# Be sure to read the max stream_ordering *before* checking if there are any outstanding
# deltas, since there is otherwise a chance that we could miss updates which arrive
# after we check the deltas.
room_max_stream_ordering = yield self.store.get_room_max_stream_ordering()
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering:
break
deltas = yield self.store.get_current_state_deltas(self.pos)
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
if deltas:
logger.debug("Handling %d state deltas", len(deltas))
room_deltas, user_deltas = yield self._handle_deltas(deltas)
max_pos = deltas[-1]["stream_id"]
else:
room_deltas = {}
user_deltas = {}
max_pos = room_max_stream_ordering
# Then count deltas for total_events and total_event_bytes.
room_count, user_count = yield self.store.get_changes_room_total_events_and_bytes(
@@ -293,6 +295,7 @@ class StatsHandler(StateDeltasHandler):
room_state["guest_access"] = event_content.get("guest_access")
for room_id, state in room_to_state_updates.items():
logger.info("Updating room_stats_state for %s: %s", room_id, state)
yield self.store.update_room_state(room_id, state)
return room_to_stats_deltas, user_to_stats_deltas

View File

@@ -138,21 +138,28 @@ class UserDirectoryHandler(StateDeltasHandler):
# Loop round handling deltas until we're up to date
while True:
with Measure(self.clock, "user_dir_delta"):
deltas = yield self.store.get_current_state_deltas(self.pos)
if not deltas:
room_max_stream_ordering = self.store.get_room_max_stream_ordering()
if self.pos == room_max_stream_ordering:
return
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = yield self.store.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
logger.info("Handling %d state deltas", len(deltas))
yield self._handle_deltas(deltas)
self.pos = deltas[-1]["stream_id"]
self.pos = max_pos
# Expose current event processing position to prometheus
synapse.metrics.event_processing_positions.labels("user_dir").set(
self.pos
max_pos
)
yield self.store.update_user_directory_stream_pos(self.pos)
yield self.store.update_user_directory_stream_pos(max_pos)
@defer.inlineCallbacks
def _handle_deltas(self, deltas):

View File

@@ -17,7 +17,7 @@ import logging
from twisted.internet import defer
from synapse.api.errors import AuthError, Codes, StoreError, SynapseError
from synapse.api.errors import AuthError, NotFoundError, StoreError, SynapseError
from synapse.http.servlet import RestServlet, parse_json_object_from_request
from synapse.types import UserID
@@ -52,13 +52,15 @@ class GetFilterRestServlet(RestServlet):
raise SynapseError(400, "Invalid filter_id")
try:
filter = yield self.filtering.get_user_filter(
filter_collection = yield self.filtering.get_user_filter(
user_localpart=target_user.localpart, filter_id=filter_id
)
except StoreError as e:
if e.code != 404:
raise
raise NotFoundError("No such filter")
return 200, filter.get_filter_json()
except (KeyError, StoreError):
raise SynapseError(400, "No such filter", errcode=Codes.NOT_FOUND)
return 200, filter_collection.get_filter_json()
class CreateFilterRestServlet(RestServlet):

View File

@@ -21,7 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, StoreError, SynapseError
from synapse.api.filtering import DEFAULT_FILTER_COLLECTION, FilterCollection
from synapse.events.utils import (
format_event_for_client_v2_without_room_id,
@@ -119,25 +119,32 @@ class SyncRestServlet(RestServlet):
request_key = (user, timeout, since, filter_id, full_state, device_id)
if filter_id:
if filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter = FilterCollection(filter_object)
else:
filter = yield self.filtering.get_user_filter(user.localpart, filter_id)
if filter_id is None:
filter_collection = DEFAULT_FILTER_COLLECTION
elif filter_id.startswith("{"):
try:
filter_object = json.loads(filter_id)
set_timeline_upper_limit(
filter_object, self.hs.config.filter_timeline_limit
)
except Exception:
raise SynapseError(400, "Invalid filter JSON")
self.filtering.check_valid_filter(filter_object)
filter_collection = FilterCollection(filter_object)
else:
filter = DEFAULT_FILTER_COLLECTION
try:
filter_collection = yield self.filtering.get_user_filter(
user.localpart, filter_id
)
except StoreError as err:
if err.code != 404:
raise
# fix up the description and errcode to be more useful
raise SynapseError(400, "No such filter", errcode=Codes.INVALID_PARAM)
sync_config = SyncConfig(
user=user,
filter_collection=filter,
filter_collection=filter_collection,
is_guest=requester.is_guest,
request_key=request_key,
device_id=device_id,
@@ -171,7 +178,7 @@ class SyncRestServlet(RestServlet):
time_now = self.clock.time_msec()
response_content = yield self.encode_response(
time_now, sync_result, requester.access_token_id, filter
time_now, sync_result, requester.access_token_id, filter_collection
)
return 200, response_content

View File

@@ -195,7 +195,7 @@ def respond_with_responder(request, responder, media_type, file_size, upload_nam
respond_404(request)
return
logger.debug("Responding to media request with responder %s")
logger.debug("Responding to media request with responder %s", responder)
add_file_headers(request, media_type, file_size, upload_name)
try:
with responder:

View File

@@ -270,7 +270,7 @@ class PreviewUrlResource(DirectServeResource):
logger.debug("Calculated OG for %s as %s" % (url, og))
jsonog = json.dumps(og).encode("utf8")
jsonog = json.dumps(og)
# store OG in history-aware DB cache
yield self.store.store_url_cache(
@@ -283,7 +283,7 @@ class PreviewUrlResource(DirectServeResource):
media_info["created_ts"],
)
return jsonog
return jsonog.encode("utf8")
@defer.inlineCallbacks
def _download_url(self, url, user):

View File

@@ -17,7 +17,7 @@ import logging
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import SynapseError
from synapse.api.errors import Codes, SynapseError
from synapse.http.server import (
DirectServeResource,
respond_with_json,
@@ -56,7 +56,11 @@ class UploadResource(DirectServeResource):
if content_length is None:
raise SynapseError(msg="Request must specify a Content-Length", code=400)
if int(content_length) > self.max_upload_size:
raise SynapseError(msg="Upload request body is too large", code=413)
raise SynapseError(
msg="Upload request body is too large",
code=413,
errcode=Codes.TOO_LARGE,
)
upload_name = parse_string(request, b"filename", encoding=None)
if upload_name:

View File

@@ -20,6 +20,7 @@ import random
import sys
import threading
import time
from typing import Iterable, Tuple
from six import PY2, iteritems, iterkeys, itervalues
from six.moves import builtins, intern, range
@@ -30,7 +31,7 @@ from prometheus_client import Histogram
from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from synapse.types import get_domain_from_id
@@ -550,8 +551,9 @@ class SQLBaseStore(object):
return func(conn, *args, **kwargs)
with PreserveLoggingContext():
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
result = yield make_deferred_yieldable(
self._db_pool.runWithConnection(inner_func, *args, **kwargs)
)
return result
@@ -1162,19 +1164,18 @@ class SQLBaseStore(object):
if not iterable:
return []
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
clauses = []
values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
values.append(value)
if clauses:
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
sql = "SELECT %s FROM %s WHERE %s" % (
", ".join(retcols),
table,
" AND ".join(clauses),
)
txn.execute(sql, values)
return cls.cursor_to_dict(txn)
@@ -1323,10 +1324,8 @@ class SQLBaseStore(object):
sql = "DELETE FROM %s" % table
clauses = []
values = []
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
values.extend(iterable)
clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
clauses = [clause]
for key, value in iteritems(keyvalues):
clauses.append("%s = ?" % (key,))
@@ -1693,3 +1692,30 @@ def db_to_json(db_content):
except Exception:
logging.warning("Tried to decode '%r' as JSON and failed", db_content)
raise
def make_in_list_sql_clause(
database_engine, column: str, iterable: Iterable
) -> Tuple[str, Iterable]:
"""Returns an SQL clause that checks the given column is in the iterable.
On SQLite this expands to `column IN (?, ?, ...)`, whereas on Postgres
it expands to `column = ANY(?)`. While both DBs support the `IN` form,
using the `ANY` form on postgres means that it views queries with
different length iterables as the same, helping the query stats.
Args:
database_engine
column: Name of the column
iterable: The values to check the column against.
Returns:
A tuple of SQL query and the args
"""
if database_engine.supports_using_any_list:
# This should hopefully be faster, but also makes postgres query
# stats easier to understand.
return "%s = ANY(?)" % (column,), [list(iterable)]
else:
return "%s IN (%s)" % (column, ",".join("?" for _ in iterable)), list(iterable)

View File

@@ -33,16 +33,9 @@ logger = logging.getLogger(__name__)
LAST_SEEN_GRANULARITY = 120 * 1000
class ClientIpStore(background_updates.BackgroundUpdateStore):
class ClientIpBackgroundUpdateStore(background_updates.BackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
super(ClientIpBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"user_ips_device_index",
@@ -92,19 +85,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
"devices_last_seen", self._devices_last_seen_update
)
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def _remove_user_ip_nonunique(self, progress, batch_size):
def f(conn):
@@ -303,6 +283,110 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
return batch_size
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
class ClientIpStore(ClientIpBackgroundUpdateStore):
def __init__(self, db_conn, hs):
self.client_ip_last_seen = Cache(
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
)
super(ClientIpStore, self).__init__(db_conn, hs)
self.user_ips_max_age = hs.config.user_ips_max_age
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
self._batch_row_update = {}
self._client_ip_looper = self._clock.looping_call(
self._update_client_ips_batch, 5 * 1000
)
self.hs.get_reactor().addSystemEventTrigger(
"before", "shutdown", self._update_client_ips_batch
)
if self.user_ips_max_age:
self._clock.looping_call(self._prune_old_user_ips, 5 * 1000)
@defer.inlineCallbacks
def insert_client_ip(
self, user_id, access_token, ip, user_agent, device_id, now=None
@@ -454,85 +538,6 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
)
@defer.inlineCallbacks
def _devices_last_seen_update(self, progress, batch_size):
"""Background update to insert last seen info into devices table
"""
last_user_id = progress.get("last_user_id", "")
last_device_id = progress.get("last_device_id", "")
def _devices_last_seen_update_txn(txn):
# This consists of two queries:
#
# 1. The sub-query searches for the next N devices and joins
# against user_ips to find the max last_seen associated with
# that device.
# 2. The outer query then joins again against user_ips on
# user/device/last_seen. This *should* hopefully only
# return one row, but if it does return more than one then
# we'll just end up updating the same device row multiple
# times, which is fine.
if self.database_engine.supports_tuple_comparison:
where_clause = "(user_id, device_id) > (?, ?)"
where_args = [last_user_id, last_device_id]
else:
# We explicitly do a `user_id >= ? AND (...)` here to ensure
# that an index is used, as doing `user_id > ? OR (user_id = ? AND ...)`
# makes it hard for query optimiser to tell that it can use the
# index on user_id
where_clause = "user_id >= ? AND (user_id > ? OR device_id > ?)"
where_args = [last_user_id, last_user_id, last_device_id]
sql = """
SELECT
last_seen, ip, user_agent, user_id, device_id
FROM (
SELECT
user_id, device_id, MAX(u.last_seen) AS last_seen
FROM devices
INNER JOIN user_ips AS u USING (user_id, device_id)
WHERE %(where_clause)s
GROUP BY user_id, device_id
ORDER BY user_id ASC, device_id ASC
LIMIT ?
) c
INNER JOIN user_ips AS u USING (user_id, device_id, last_seen)
""" % {
"where_clause": where_clause
}
txn.execute(sql, where_args + [batch_size])
rows = txn.fetchall()
if not rows:
return 0
sql = """
UPDATE devices
SET last_seen = ?, ip = ?, user_agent = ?
WHERE user_id = ? AND device_id = ?
"""
txn.execute_batch(sql, rows)
_, _, _, user_id, device_id = rows[-1]
self._background_update_progress_txn(
txn,
"devices_last_seen",
{"last_user_id": user_id, "last_device_id": device_id},
)
return len(rows)
updated = yield self.runInteraction(
"_devices_last_seen_update", _devices_last_seen_update_txn
)
if not updated:
yield self._end_background_update("devices_last_seen")
return updated
@wrap_as_background_process("prune_old_user_ips")
async def _prune_old_user_ips(self):
"""Removes entries in user IPs older than the configured period.

View File

@@ -20,7 +20,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.logging.opentracing import log_kv, set_tag, trace
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util.caches.expiringcache import ExpiringCache
@@ -208,11 +208,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
)
class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
class DeviceInboxBackgroundUpdateStore(BackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
super(DeviceInboxBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_inbox_stream_index",
@@ -225,6 +225,26 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
self.DEVICE_INBOX_STREAM_ID, self._background_drop_index_device_inbox
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1
class DeviceInboxStore(DeviceInboxWorkerStore, DeviceInboxBackgroundUpdateStore):
DEVICE_INBOX_STREAM_ID = "device_inbox_stream_drop"
def __init__(self, db_conn, hs):
super(DeviceInboxStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) to the last stream_id that has been
# deleted up to. This is so that we can no op deletions.
self._last_device_delete_cache = ExpiringCache(
@@ -358,15 +378,15 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
else:
if not devices:
continue
sql = (
"SELECT device_id FROM devices"
" WHERE user_id = ? AND device_id IN ("
+ ",".join("?" * len(devices))
+ ")"
clause, args = make_in_list_sql_clause(
txn.database_engine, "device_id", devices
)
sql = "SELECT device_id FROM devices WHERE user_id = ? AND " + clause
# TODO: Maybe this needs to be done in batches if there are
# too many local devices for a given user.
txn.execute(sql, [user_id] + devices)
txn.execute(sql, [user_id] + list(args))
for row in txn:
# Only insert into the local inbox if the device exists on
# this server
@@ -435,16 +455,3 @@ class DeviceInboxStore(DeviceInboxWorkerStore, BackgroundUpdateStore):
return self.runInteraction(
"get_all_new_device_messages", get_all_new_device_messages_txn
)
@defer.inlineCallbacks
def _background_drop_index_device_inbox(self, progress, batch_size):
def reindex_txn(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
txn.close()
yield self.runWithConnection(reindex_txn)
yield self._end_background_update(self.DEVICE_INBOX_STREAM_ID)
return 1

View File

@@ -28,7 +28,12 @@ from synapse.logging.opentracing import (
whitelisted_homeserver,
)
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import Cache, SQLBaseStore, db_to_json
from synapse.storage._base import (
Cache,
SQLBaseStore,
db_to_json,
make_in_list_sql_clause,
)
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
@@ -448,11 +453,14 @@ class DeviceWorkerStore(SQLBaseStore):
sql = """
SELECT DISTINCT user_id FROM device_lists_stream
WHERE stream_id > ?
AND user_id IN (%s)
AND
"""
for chunk in batch_iter(to_check, 100):
txn.execute(sql % (",".join("?" for _ in chunk),), (from_key,) + chunk)
clause, args = make_in_list_sql_clause(
txn.database_engine, "user_id", chunk
)
txn.execute(sql + clause, (from_key,) + tuple(args))
changes.update(user_id for user_id, in txn)
return changes
@@ -512,17 +520,9 @@ class DeviceWorkerStore(SQLBaseStore):
return results
class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
class DeviceBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
super(DeviceBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
"device_lists_stream_idx",
@@ -555,6 +555,31 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
self._drop_device_list_streams_non_unique_indexes,
)
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1
class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(DeviceStore, self).__init__(db_conn, hs)
# Map of (user_id, device_id) -> bool. If there is an entry that implies
# the device exists.
self.device_id_exists_cache = Cache(
name="device_id_exists", keylen=2, max_entries=10000
)
self._clock.looping_call(self._prune_old_outbound_device_pokes, 60 * 60 * 1000)
@defer.inlineCallbacks
def store_device(self, user_id, device_id, initial_device_display_name):
"""Ensure the given device is known; add it to the store if not
@@ -910,15 +935,3 @@ class DeviceStore(DeviceWorkerStore, BackgroundUpdateStore):
"_prune_old_outbound_device_pokes",
_prune_txn,
)
@defer.inlineCallbacks
def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
def f(conn):
txn = conn.cursor()
txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
txn.close()
yield self.runWithConnection(f)
yield self._end_background_update(DROP_DEVICE_LIST_STREAMS_NON_UNIQUE_INDEXES)
return 1

View File

@@ -22,6 +22,13 @@ class PostgresEngine(object):
def __init__(self, database_module, database_config):
self.module = database_module
self.module.extensions.register_type(self.module.extensions.UNICODE)
# Disables passing `bytes` to txn.execute, c.f. #6186. If you do
# actually want to use bytes than wrap it in `bytearray`.
def _disable_bytes_adapter(_):
raise Exception("Passing bytes to DB is disabled.")
self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
self.synchronous_commit = database_config.get("synchronous_commit", True)
self._version = None # unknown as yet
@@ -79,6 +86,12 @@ class PostgresEngine(object):
"""
return True
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return True
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
# https://www.postgresql.org/docs/current/static/errcodes-appendix.html

View File

@@ -46,6 +46,12 @@ class Sqlite3Engine(object):
"""
return self.module.sqlite_version_info >= (3, 15, 0)
@property
def supports_using_any_list(self):
"""Do we support using `a = ANY(?)` and passing a list
"""
return False
def check_database(self, txn):
pass

View File

@@ -25,7 +25,7 @@ from twisted.internet import defer
from synapse.api.errors import StoreError
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.events_worker import EventsWorkerStore
from synapse.storage.signatures import SignatureWorkerStore
from synapse.util.caches.descriptors import cached
@@ -68,7 +68,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
else:
results = set()
base_sql = "SELECT auth_id FROM event_auth WHERE event_id IN (%s)"
base_sql = "SELECT auth_id FROM event_auth WHERE "
front = set(event_ids)
while front:
@@ -76,7 +76,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
front_list = list(front)
chunks = [front_list[x : x + 100] for x in range(0, len(front), 100)]
for chunk in chunks:
txn.execute(base_sql % (",".join(["?"] * len(chunk)),), chunk)
clause, args = make_in_list_sql_clause(
txn.database_engine, "event_id", chunk
)
txn.execute(base_sql + clause, list(args))
new_front.update([r[0] for r in txn])
new_front -= results

View File

@@ -23,7 +23,7 @@ from functools import wraps
from six import iteritems, text_type
from six.moves import range
from canonicaljson import encode_canonical_json, json
from canonicaljson import json
from prometheus_client import Counter, Histogram
from twisted.internet import defer
@@ -39,6 +39,7 @@ from synapse.logging.utils import log_function
from synapse.metrics import BucketCollector
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.state import StateResolutionStore
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.event_federation import EventFederationStore
from synapse.storage.events_worker import EventsWorkerStore
@@ -641,14 +642,16 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
NOT events.outlier
AND rejections.event_id IS NULL
""" % (
",".join("?" for _ in batch),
AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", batch
)
txn.execute(sql, batch)
txn.execute(sql + clause, args)
results.extend(r[0] for r in txn if not json.loads(r[1]).get("soft_failed"))
for chunk in batch_iter(event_ids, 100):
@@ -695,13 +698,15 @@ class EventsStore(
LEFT JOIN rejections USING (event_id)
LEFT JOIN event_json USING (event_id)
WHERE
event_id IN (%s)
AND NOT events.outlier
""" % (
",".join("?" for _ in to_recursively_check),
NOT events.outlier
AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", to_recursively_check
)
txn.execute(sql, to_recursively_check)
txn.execute(sql + clause, args)
to_recursively_check = []
for event_id, prev_event_id, metadata, rejected in txn:
@@ -1543,10 +1548,14 @@ class EventsStore(
" FROM events as e"
" LEFT JOIN rejections as rej USING (event_id)"
" LEFT JOIN redactions as r ON e.event_id = r.redacts"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(ev_map)),)
" WHERE "
)
txn.execute(sql, list(ev_map))
clause, args = make_in_list_sql_clause(
self.database_engine, "e.event_id", list(ev_map)
)
txn.execute(sql + clause, args)
rows = self.cursor_to_dict(txn)
for row in rows:
event = ev_map[row["event_id"]]
@@ -1632,9 +1641,7 @@ class EventsStore(
and original_event.internal_metadata.is_redacted()
):
# Redaction was allowed
pruned_json = encode_canonical_json(
prune_event_dict(original_event.get_dict())
)
pruned_json = encode_json(prune_event_dict(original_event.get_dict()))
else:
# Redaction wasn't allowed
pruned_json = None
@@ -2251,11 +2258,12 @@ class EventsStore(
sql = """
SELECT DISTINCT state_group FROM event_to_state_groups
LEFT JOIN events_to_purge AS ep USING (event_id)
WHERE state_group IN (%s) AND ep.event_id IS NULL
""" % (
",".join("?" for _ in current_search),
WHERE ep.event_id IS NULL AND
"""
clause, args = make_in_list_sql_clause(
txn.database_engine, "state_group", current_search
)
txn.execute(sql, list(current_search))
txn.execute(sql + clause, list(args))
referenced = set(sg for sg, in txn)
referenced_groups |= referenced

View File

@@ -21,6 +21,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
logger = logging.getLogger(__name__)
@@ -71,6 +72,19 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
"redactions_received_ts", self._redactions_received_ts
)
# This index gets deleted in `event_fix_redactions_bytes` update
self.register_background_index_update(
"event_fix_redactions_bytes_create_index",
index_name="redactions_censored_redacts",
table="redactions",
columns=["redacts"],
where_clause="have_censored",
)
self.register_background_update_handler(
"event_fix_redactions_bytes", self._event_fix_redactions_bytes
)
@defer.inlineCallbacks
def _background_reindex_fields_sender(self, progress, batch_size):
target_min_stream_id = progress["target_min_stream_id_inclusive"]
@@ -312,12 +326,13 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
INNER JOIN event_json USING (event_id)
LEFT JOIN rejections USING (event_id)
WHERE
prev_event_id IN (%s)
AND NOT events.outlier
""" % (
",".join("?" for _ in to_check),
NOT events.outlier
AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "prev_event_id", to_check
)
txn.execute(sql, to_check)
txn.execute(sql + clause, list(args))
for prev_event_id, event_id, metadata, rejected in txn:
if event_id in graph:
@@ -458,3 +473,33 @@ class EventsBackgroundUpdatesStore(BackgroundUpdateStore):
yield self._end_background_update("redactions_received_ts")
return count
@defer.inlineCallbacks
def _event_fix_redactions_bytes(self, progress, batch_size):
"""Undoes hex encoded censored redacted event JSON.
"""
def _event_fix_redactions_bytes_txn(txn):
# This update is quite fast due to new index.
txn.execute(
"""
UPDATE event_json
SET
json = convert_from(json::bytea, 'utf8')
FROM redactions
WHERE
redactions.have_censored
AND event_json.event_id = redactions.redacts
AND json NOT LIKE '{%';
"""
)
txn.execute("DROP INDEX redactions_censored_redacts")
yield self.runInteraction(
"_event_fix_redactions_bytes", _event_fix_redactions_bytes_txn
)
yield self._end_background_update("event_fix_redactions_bytes")
return 1

View File

@@ -31,12 +31,11 @@ from synapse.events.snapshot import EventContext # noqa: F401
from synapse.events.utils import prune_event
from synapse.logging.context import LoggingContext, PreserveLoggingContext
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.types import get_domain_from_id
from synapse.util import batch_iter
from synapse.util.metrics import Measure
from ._base import SQLBaseStore
logger = logging.getLogger(__name__)
@@ -623,10 +622,14 @@ class EventsWorkerStore(SQLBaseStore):
" rej.reason "
" FROM event_json as e"
" LEFT JOIN rejections as rej USING (event_id)"
" WHERE e.event_id IN (%s)"
) % (",".join(["?"] * len(evs)),)
" WHERE "
)
txn.execute(sql, evs)
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", evs
)
txn.execute(sql + clause, args)
for row in txn:
event_id = row[0]
@@ -640,11 +643,11 @@ class EventsWorkerStore(SQLBaseStore):
}
# check for redactions
redactions_sql = (
"SELECT event_id, redacts FROM redactions WHERE redacts IN (%s)"
) % (",".join(["?"] * len(evs)),)
redactions_sql = "SELECT event_id, redacts FROM redactions WHERE "
txn.execute(redactions_sql, evs)
clause, args = make_in_list_sql_clause(txn.database_engine, "redacts", evs)
txn.execute(redactions_sql + clause, args)
for (redacter, redacted) in txn:
d = event_dict.get(redacted)
@@ -753,10 +756,11 @@ class EventsWorkerStore(SQLBaseStore):
results = set()
def have_seen_events_txn(txn, chunk):
sql = "SELECT event_id FROM events as e WHERE e.event_id IN (%s)" % (
",".join("?" * len(chunk)),
sql = "SELECT event_id FROM events as e WHERE "
clause, args = make_in_list_sql_clause(
txn.database_engine, "e.event_id", chunk
)
txn.execute(sql, chunk)
txn.execute(sql + clause, args)
for (event_id,) in txn:
results.add(event_id)

View File

@@ -51,7 +51,7 @@ class FilteringStore(SQLBaseStore):
"SELECT filter_id FROM user_filters "
"WHERE user_id = ? AND filter_json = ?"
)
txn.execute(sql, (user_localpart, def_json))
txn.execute(sql, (user_localpart, bytearray(def_json)))
filter_id_response = txn.fetchone()
if filter_id_response is not None:
return filter_id_response[0]
@@ -68,7 +68,7 @@ class FilteringStore(SQLBaseStore):
"INSERT INTO user_filters (user_id, filter_id, filter_json)"
"VALUES(?, ?, ?)"
)
txn.execute(sql, (user_localpart, filter_id, def_json))
txn.execute(sql, (user_localpart, filter_id, bytearray(def_json)))
return filter_id

View File

@@ -15,11 +15,9 @@
from synapse.storage.background_updates import BackgroundUpdateStore
class MediaRepositoryStore(BackgroundUpdateStore):
"""Persistence for attachments and avatars"""
class MediaRepositoryBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
super(MediaRepositoryBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_index_update(
update_name="local_media_repository_url_idx",
@@ -29,6 +27,13 @@ class MediaRepositoryStore(BackgroundUpdateStore):
where_clause="url_cache IS NOT NULL",
)
class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
"""Persistence for attachments and avatars"""
def __init__(self, db_conn, hs):
super(MediaRepositoryStore, self).__init__(db_conn, hs)
def get_local_media(self, media_id):
"""Get the metadata for a local piece of media
Returns:

View File

@@ -32,7 +32,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
super(MonthlyActiveUsersStore, self).__init__(None, hs)
self._clock = hs.get_clock()
self.hs = hs
self.reserved_users = ()
# Do not add more reserved users than the total allowable number
self._new_transaction(
dbconn,
@@ -51,7 +50,6 @@ class MonthlyActiveUsersStore(SQLBaseStore):
txn (cursor):
threepids (list[dict]): List of threepid dicts to reserve
"""
reserved_user_list = []
for tp in threepids:
user_id = self.get_user_id_by_threepid_txn(txn, tp["medium"], tp["address"])
@@ -60,10 +58,8 @@ class MonthlyActiveUsersStore(SQLBaseStore):
is_support = self.is_support_user_txn(txn, user_id)
if not is_support:
self.upsert_monthly_active_user_txn(txn, user_id)
reserved_user_list.append(user_id)
else:
logger.warning("mau limit reserved threepid %s not found in db" % tp)
self.reserved_users = tuple(reserved_user_list)
@defer.inlineCallbacks
def reap_monthly_active_users(self):
@@ -74,8 +70,11 @@ class MonthlyActiveUsersStore(SQLBaseStore):
Deferred[]
"""
def _reap_users(txn):
# Purge stale users
def _reap_users(txn, reserved_users):
"""
Args:
reserved_users (tuple): reserved users to preserve
"""
thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30)
query_args = [thirty_days_ago]
@@ -83,20 +82,19 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
if len(reserved_users) > 0:
# questionmarks is a hack to overcome sqlite not supporting
# tuples in 'WHERE IN %s'
questionmarks = "?" * len(self.reserved_users)
question_marks = ",".join("?" * len(reserved_users))
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
",".join(questionmarks)
)
query_args.extend(reserved_users)
sql = base_sql + " AND user_id NOT IN ({})".format(question_marks)
else:
sql = base_sql
txn.execute(sql, query_args)
max_mau_value = self.hs.config.max_mau_value
if self.hs.config.limit_usage_by_mau:
# If MAU user count still exceeds the MAU threshold, then delete on
# a least recently active basis.
@@ -106,31 +104,52 @@ class MonthlyActiveUsersStore(SQLBaseStore):
# While Postgres does not require 'LIMIT', but also does not support
# negative LIMIT values. So there is no way to write it that both can
# support
safe_guard = self.hs.config.max_mau_value - len(self.reserved_users)
# Must be greater than zero for postgres
safe_guard = safe_guard if safe_guard > 0 else 0
query_args = [safe_guard]
base_sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
if len(reserved_users) == 0:
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
ORDER BY timestamp DESC
LIMIT ?
)
"""
"""
txn.execute(sql, (max_mau_value,))
# Need if/else since 'AND user_id NOT IN ({})' fails on Postgres
# when len(reserved_users) == 0. Works fine on sqlite.
if len(self.reserved_users) > 0:
query_args.extend(self.reserved_users)
sql = base_sql + """ AND user_id NOT IN ({})""".format(
",".join(questionmarks)
)
else:
sql = base_sql
txn.execute(sql, query_args)
# Must be >= 0 for postgres
num_of_non_reserved_users_to_remove = max(
max_mau_value - len(reserved_users), 0
)
yield self.runInteraction("reap_monthly_active_users", _reap_users)
# It is important to filter reserved users twice to guard
# against the case where the reserved user is present in the
# SELECT, meaning that a legitmate mau is deleted.
sql = """
DELETE FROM monthly_active_users
WHERE user_id NOT IN (
SELECT user_id FROM monthly_active_users
WHERE user_id NOT IN ({})
ORDER BY timestamp DESC
LIMIT ?
)
AND user_id NOT IN ({})
""".format(
question_marks, question_marks
)
query_args = [
*reserved_users,
num_of_non_reserved_users_to_remove,
*reserved_users,
]
txn.execute(sql, query_args)
reserved_users = yield self.get_registered_reserved_users()
yield self.runInteraction(
"reap_monthly_active_users", _reap_users, reserved_users
)
# It seems poor to invalidate the whole cache, Postgres supports
# 'Returning' which would allow me to invalidate only the
# specific users, but sqlite has no way to do this and instead
@@ -159,21 +178,25 @@ class MonthlyActiveUsersStore(SQLBaseStore):
return self.runInteraction("count_users", _count_users)
@defer.inlineCallbacks
def get_registered_reserved_users_count(self):
"""Of the reserved threepids defined in config, how many are associated
def get_registered_reserved_users(self):
"""Of the reserved threepids defined in config, which are associated
with registered users?
Returns:
Defered[int]: Number of real reserved users
Defered[list]: Real reserved users
"""
count = 0
for tp in self.hs.config.mau_limits_reserved_threepids:
users = []
for tp in self.hs.config.mau_limits_reserved_threepids[
: self.hs.config.max_mau_value
]:
user_id = yield self.hs.get_datastore().get_user_id_by_threepid(
tp["medium"], tp["address"]
)
if user_id:
count = count + 1
return count
users.append(user_id)
return users
@defer.inlineCallbacks
def upsert_monthly_active_user(self, user_id):

View File

@@ -18,11 +18,10 @@ from collections import namedtuple
from twisted.internet import defer
from synapse.api.constants import PresenceState
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.util import batch_iter
from synapse.util.caches.descriptors import cached, cachedList
from ._base import SQLBaseStore
class UserPresenceState(
namedtuple(
@@ -119,14 +118,13 @@ class PresenceStore(SQLBaseStore):
)
# Delete old rows to stop database from getting really big
sql = (
"DELETE FROM presence_stream WHERE" " stream_id < ?" " AND user_id IN (%s)"
)
sql = "DELETE FROM presence_stream WHERE stream_id < ? AND "
for states in batch_iter(presence_states, 50):
args = [stream_id]
args.extend(s.user_id for s in states)
txn.execute(sql % (",".join("?" for _ in states),), args)
clause, args = make_in_list_sql_clause(
self.database_engine, "user_id", [s.user_id for s in states]
)
txn.execute(sql + clause, [stream_id] + list(args))
def get_all_presence_updates(self, last_id, current_id):
if last_id == current_id:

View File

@@ -241,7 +241,7 @@ class PusherStore(PusherWorkerStore):
"device_display_name": device_display_name,
"ts": pushkey_ts,
"lang": lang,
"data": encode_canonical_json(data),
"data": bytearray(encode_canonical_json(data)),
"last_stream_ordering": last_stream_ordering,
"profile_tag": profile_tag,
"id": stream_id,

View File

@@ -21,12 +21,11 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from ._base import SQLBaseStore
from .util.id_generators import StreamIdGenerator
logger = logging.getLogger(__name__)
@@ -217,24 +216,26 @@ class ReceiptsWorkerStore(SQLBaseStore):
def f(txn):
if from_key:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id > ? AND stream_id <= ?"
) % (",".join(["?"] * len(room_ids)))
args = list(room_ids)
args.extend([from_key, to_key])
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id > ? AND stream_id <= ? AND
"""
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
txn.execute(sql, args)
txn.execute(sql + clause, [from_key, to_key] + list(args))
else:
sql = (
"SELECT * FROM receipts_linearized WHERE"
" room_id IN (%s) AND stream_id <= ?"
) % (",".join(["?"] * len(room_ids)))
sql = """
SELECT * FROM receipts_linearized WHERE
stream_id <= ? AND
"""
args = list(room_ids)
args.append(to_key)
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
txn.execute(sql, args)
txn.execute(sql + clause, [to_key] + list(args))
return self.cursor_to_dict(txn)
@@ -433,13 +434,19 @@ class ReceiptsStore(ReceiptsWorkerStore):
# we need to points in graph -> linearized form.
# TODO: Make this better.
def graph_to_linear(txn):
query = (
"SELECT event_id WHERE room_id = ? AND stream_ordering IN ("
" SELECT max(stream_ordering) WHERE event_id IN (%s)"
")"
) % (",".join(["?"] * len(event_ids)))
clause, args = make_in_list_sql_clause(
self.database_engine, "event_id", event_ids
)
txn.execute(query, [room_id] + event_ids)
sql = """
SELECT event_id WHERE room_id = ? AND stream_ordering IN (
SELECT max(stream_ordering) WHERE %s
)
""" % (
clause,
)
txn.execute(sql, [room_id] + list(args))
rows = txn.fetchall()
if rows:
return rows[0][0]

View File

@@ -787,13 +787,14 @@ class RegistrationWorkerStore(SQLBaseStore):
)
class RegistrationStore(
class RegistrationBackgroundUpdateStore(
RegistrationWorkerStore, background_updates.BackgroundUpdateStore
):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
super(RegistrationBackgroundUpdateStore, self).__init__(db_conn, hs)
self.clock = hs.get_clock()
self.config = hs.config
self.register_background_index_update(
"access_tokens_device_index",
@@ -809,8 +810,6 @@ class RegistrationStore(
columns=["creation_ts"],
)
self._account_validity = hs.config.account_validity
# we no longer use refresh tokens, but it's possible that some people
# might have a background update queued to build this index. Just
# clear the background update.
@@ -824,17 +823,6 @@ class RegistrationStore(
"users_set_deactivated_flag", self._background_update_set_deactivated_flag
)
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def _background_update_set_deactivated_flag(self, progress, batch_size):
"""Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
@@ -896,6 +884,54 @@ class RegistrationStore(
return nb_processed
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
class RegistrationStore(RegistrationBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RegistrationStore, self).__init__(db_conn, hs)
self._account_validity = hs.config.account_validity
# Create a background job for culling expired 3PID validity tokens
def start_cull():
# run as a background process to make sure that the database transactions
# have a logcontext to report to
return run_as_background_process(
"cull_expired_threepid_validation_tokens",
self.cull_expired_threepid_validation_tokens,
)
hs.get_clock().looping_call(start_cull, THIRTY_MINUTES_IN_MS)
@defer.inlineCallbacks
def add_access_token_to_user(self, user_id, token, device_id, valid_until_ms):
"""Adds an access token for the given user.
@@ -1244,36 +1280,6 @@ class RegistrationStore(
desc="get_users_pending_deactivation",
)
@defer.inlineCallbacks
def _bg_user_threepids_grandfather(self, progress, batch_size):
"""We now track which identity servers a user binds their 3PID to, so
we need to handle the case of existing bindings where we didn't track
this.
We do this by grandfathering in existing user threepids assuming that
they used one of the server configured trusted identity servers.
"""
id_servers = set(self.config.trusted_third_party_id_servers)
def _bg_user_threepids_grandfather_txn(txn):
sql = """
INSERT INTO user_threepid_id_server
(user_id, medium, address, id_server)
SELECT user_id, medium, address, ?
FROM user_threepids
"""
txn.executemany(sql, [(id_server,) for id_server in id_servers])
if id_servers:
yield self.runInteraction(
"_bg_user_threepids_grandfather", _bg_user_threepids_grandfather_txn
)
yield self._end_background_update("user_threepids_grandfather")
return 1
def validate_threepid_session(self, session_id, client_secret, token, current_ts):
"""Attempt to validate a threepid session using a token
@@ -1465,17 +1471,6 @@ class RegistrationStore(
self.clock.time_msec(),
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)
@defer.inlineCallbacks
def set_user_deactivated_status(self, user_id, deactivated):
"""Set the `deactivated` property for the provided user to the provided value.
@@ -1491,3 +1486,14 @@ class RegistrationStore(
user_id,
deactivated,
)
def set_user_deactivated_status_txn(self, txn, user_id, deactivated):
self._simple_update_one_txn(
txn=txn,
table="users",
keyvalues={"name": user_id},
updatevalues={"deactivated": 1 if deactivated else 0},
)
self._invalidate_cache_and_stream(
txn, self.get_user_deactivated_status, (user_id,)
)

View File

@@ -26,13 +26,14 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, Membership
from synapse.metrics import LaterGauge
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage._base import LoggingTransaction
from synapse.storage._base import LoggingTransaction, make_in_list_sql_clause
from synapse.storage.background_updates import BackgroundUpdateStore
from synapse.storage.engines import Sqlite3Engine
from synapse.storage.events_worker import EventsWorkerStore
from synapse.types import get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
from synapse.util.metrics import Measure
from synapse.util.stringutils import to_ascii
@@ -371,6 +372,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
results = []
if membership_list:
if self._current_state_events_membership_up_to_date:
clause, args = make_in_list_sql_clause(
self.database_engine, "c.membership", membership_list
)
sql = """
SELECT room_id, e.sender, c.membership, event_id, e.stream_ordering
FROM current_state_events AS c
@@ -378,11 +382,14 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND c.membership IN (%s)
AND %s
""" % (
",".join("?" * len(membership_list))
clause,
)
else:
clause, args = make_in_list_sql_clause(
self.database_engine, "m.membership", membership_list
)
sql = """
SELECT room_id, e.sender, m.membership, event_id, e.stream_ordering
FROM current_state_events AS c
@@ -391,12 +398,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
WHERE
c.type = 'm.room.member'
AND state_key = ?
AND m.membership IN (%s)
AND %s
""" % (
",".join("?" * len(membership_list))
clause,
)
txn.execute(sql, (user_id, *membership_list))
txn.execute(sql, (user_id, *args))
results = [RoomsForUser(**r) for r in self.cursor_to_dict(txn)]
if do_invite:
@@ -572,25 +579,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
missing_member_event_ids.append(event_id)
if missing_member_event_ids:
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=missing_member_event_ids,
retcols=("user_id", "display_name", "avatar_url"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_joined_users_from_context",
)
users_in_room.update(
{
to_ascii(row["user_id"]): ProfileInfo(
avatar_url=to_ascii(row["avatar_url"]),
display_name=to_ascii(row["display_name"]),
)
for row in rows
}
event_to_memberships = yield self._get_joined_profiles_from_event_ids(
missing_member_event_ids
)
users_in_room.update((row for row in event_to_memberships.values() if row))
if event is not None and event.type == EventTypes.Member:
if event.membership == Membership.JOIN:
@@ -602,6 +594,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return users_in_room
@cached(max_entries=10000)
def _get_joined_profile_from_event_id(self, event_id):
raise NotImplementedError()
@cachedList(
cached_method_name="_get_joined_profile_from_event_id",
list_name="event_ids",
inlineCallbacks=True,
)
def _get_joined_profiles_from_event_ids(self, event_ids):
"""For given set of member event_ids check if they point to a join
event and if so return the associated user and profile info.
Args:
event_ids (Iterable[str]): The member event IDs to lookup
Returns:
Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
to `user_id` and ProfileInfo (or None if not join event).
"""
rows = yield self._simple_select_many_batch(
table="room_memberships",
column="event_id",
iterable=event_ids,
retcols=("user_id", "display_name", "avatar_url", "event_id"),
keyvalues={"membership": Membership.JOIN},
batch_size=500,
desc="_get_membership_from_event_ids",
)
return {
row["event_id"]: (
row["user_id"],
ProfileInfo(
avatar_url=row["avatar_url"], display_name=row["display_name"]
),
)
for row in rows
}
@cachedInlineCallbacks(max_entries=10000)
def is_host_joined(self, room_id, host):
if "%" in host or "_" in host:
@@ -794,9 +827,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return set(room_ids)
class RoomMemberStore(RoomMemberWorkerStore):
class RoomMemberBackgroundUpdateStore(BackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
super(RoomMemberBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
_MEMBERSHIP_PROFILE_UPDATE_NAME, self._background_add_membership_profile
)
@@ -812,112 +845,6 @@ class RoomMemberStore(RoomMemberWorkerStore):
where_clause="forgotten = 1",
)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
@defer.inlineCallbacks
def _background_add_membership_profile(self, progress, batch_size):
target_min_stream_id = progress.get(
@@ -1052,6 +979,117 @@ class RoomMemberStore(RoomMemberWorkerStore):
return row_count
class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(RoomMemberStore, self).__init__(db_conn, hs)
def _store_room_members_txn(self, txn, events, backfilled):
"""Store a room member in the database.
"""
self._simple_insert_many_txn(
txn,
table="room_memberships",
values=[
{
"event_id": event.event_id,
"user_id": event.state_key,
"sender": event.user_id,
"room_id": event.room_id,
"membership": event.membership,
"display_name": event.content.get("displayname", None),
"avatar_url": event.content.get("avatar_url", None),
}
for event in events
],
)
for event in events:
txn.call_after(
self._membership_stream_cache.entity_has_changed,
event.state_key,
event.internal_metadata.stream_ordering,
)
txn.call_after(
self.get_invited_rooms_for_user.invalidate, (event.state_key,)
)
# We update the local_invites table only if the event is "current",
# i.e., its something that has just happened. If the event is an
# outlier it is only current if its an "out of band membership",
# like a remote invite or a rejection of a remote invite.
is_new_state = not backfilled and (
not event.internal_metadata.is_outlier()
or event.internal_metadata.is_out_of_band_membership()
)
is_mine = self.hs.is_mine_id(event.state_key)
if is_new_state and is_mine:
if event.membership == Membership.INVITE:
self._simple_insert_txn(
txn,
table="local_invites",
values={
"event_id": event.event_id,
"invitee": event.state_key,
"inviter": event.sender,
"room_id": event.room_id,
"stream_id": event.internal_metadata.stream_ordering,
},
)
else:
sql = (
"UPDATE local_invites SET stream_id = ?, replaced_by = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
txn.execute(
sql,
(
event.internal_metadata.stream_ordering,
event.event_id,
event.room_id,
event.state_key,
),
)
@defer.inlineCallbacks
def locally_reject_invite(self, user_id, room_id):
sql = (
"UPDATE local_invites SET stream_id = ?, locally_rejected = ? WHERE"
" room_id = ? AND invitee = ? AND locally_rejected is NULL"
" AND replaced_by is NULL"
)
def f(txn, stream_ordering):
txn.execute(sql, (stream_ordering, True, room_id, user_id))
with self._stream_id_gen.get_next() as stream_ordering:
yield self.runInteraction("locally_reject_invite", f, stream_ordering)
def forget(self, user_id, room_id):
"""Indicate that user_id wishes to discard history for room_id."""
def f(txn):
sql = (
"UPDATE"
" room_memberships"
" SET"
" forgotten = 1"
" WHERE"
" user_id = ?"
" AND"
" room_id = ?"
)
txn.execute(sql, (user_id, room_id))
self._invalidate_cache_and_stream(txn, self.did_forget, (user_id, room_id))
self._invalidate_cache_and_stream(
txn, self.get_forgotten_rooms_for_user, (user_id,)
)
return self.runInteraction("forget_membership", f)
class _JoinedHostsCache(object):
"""Cache for joined hosts in a room that is optimised to handle updates
via state deltas.

View File

@@ -0,0 +1,25 @@
/* Copyright 2019 The Matrix.org Foundation C.I.C.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-- There was a bug where we may have updated censored redactions as bytes,
-- which can (somehow) cause json to be inserted hex encoded. These updates go
-- and undoes any such hex encoded JSON.
INSERT into background_updates (update_name, progress_json)
VALUES ('event_fix_redactions_bytes_create_index', '{}');
INSERT into background_updates (update_name, progress_json, depends_on)
VALUES ('event_fix_redactions_bytes', '{}', 'event_fix_redactions_bytes_create_index');

View File

@@ -5,42 +5,48 @@ from synapse.storage.engines import PostgresEngine
logger = logging.getLogger(__name__)
"""
This migration updates the user_filters table as follows:
- drops any (user_id, filter_id) duplicates
- makes the columns NON-NULLable
- turns the index into a UNIQUE index
"""
def run_upgrade(cur, database_engine, *args, **kwargs):
pass
def run_create(cur, database_engine, *args, **kwargs):
if isinstance(database_engine, PostgresEngine):
select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS
SELECT DISTINCT ON (user_id, filter_id) user_id, filter_id, filter_json
FROM user_filters;
FROM user_filters
"""
else:
select_clause = """
CREATE TEMPORARY TABLE user_filters_migration AS
SELECT * FROM user_filters GROUP BY user_id, filter_id;
SELECT * FROM user_filters GROUP BY user_id, filter_id
"""
sql = (
"""
BEGIN;
%s
DROP INDEX user_filters_by_user_id_filter_id;
DELETE FROM user_filters;
ALTER TABLE user_filters
ALTER COLUMN user_id SET NOT NULL
ALTER COLUMN filter_id SET NOT NULL
ALTER COLUMN filter_json SET NOT NULL;
INSERT INTO user_filters(user_id, filter_id, filter_json)
SELECT * FROM user_filters_migration;
DROP TABLE user_filters_migration;
CREATE UNIQUE INDEX user_filters_by_user_id_filter_id_unique
ON user_filters(user_id, filter_id);
END;
"""
% select_clause
sql = """
DROP TABLE IF EXISTS user_filters_migration;
DROP INDEX IF EXISTS user_filters_unique;
CREATE TABLE user_filters_migration (
user_id TEXT NOT NULL,
filter_id BIGINT NOT NULL,
filter_json BYTEA NOT NULL
);
INSERT INTO user_filters_migration (user_id, filter_id, filter_json)
%s;
CREATE UNIQUE INDEX user_filters_unique ON user_filters_migration
(user_id, filter_id);
DROP TABLE user_filters;
ALTER TABLE user_filters_migration RENAME TO user_filters;
""" % (
select_clause,
)
if isinstance(database_engine, PostgresEngine):
cur.execute(sql)
else:
cur.executescript(sql)
def run_create(cur, database_engine, *args, **kwargs):
pass

View File

@@ -24,6 +24,7 @@ from canonicaljson import json
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.storage._base import make_in_list_sql_clause
from synapse.storage.engines import PostgresEngine, Sqlite3Engine
from .background_updates import BackgroundUpdateStore
@@ -36,7 +37,7 @@ SearchEntry = namedtuple(
)
class SearchStore(BackgroundUpdateStore):
class SearchBackgroundUpdateStore(BackgroundUpdateStore):
EVENT_SEARCH_UPDATE_NAME = "event_search"
EVENT_SEARCH_ORDER_UPDATE_NAME = "event_search_order"
@@ -44,7 +45,7 @@ class SearchStore(BackgroundUpdateStore):
EVENT_SEARCH_USE_GIN_POSTGRES_NAME = "event_search_postgres_gin"
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
super(SearchBackgroundUpdateStore, self).__init__(db_conn, hs)
if not hs.config.enable_search:
return
@@ -289,29 +290,6 @@ class SearchStore(BackgroundUpdateStore):
return num_rows
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
def store_search_entries_txn(self, txn, entries):
"""Add entries to the search table
@@ -358,6 +336,34 @@ class SearchStore(BackgroundUpdateStore):
# This should be unreachable.
raise Exception("Unrecognized database engine")
class SearchStore(SearchBackgroundUpdateStore):
def __init__(self, db_conn, hs):
super(SearchStore, self).__init__(db_conn, hs)
def store_event_search_txn(self, txn, event, key, value):
"""Add event to the search table
Args:
txn (cursor):
event (EventBase):
key (str):
value (str):
"""
self.store_search_entries_txn(
txn,
(
SearchEntry(
key=key,
value=value,
event_id=event.event_id,
room_id=event.room_id,
stream_ordering=event.internal_metadata.stream_ordering,
origin_server_ts=event.origin_server_ts,
),
),
)
@defer.inlineCallbacks
def search_msgs(self, room_ids, search_term, keys):
"""Performs a full text search over events with given keys.
@@ -380,8 +386,10 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = []
for key in keys:
@@ -487,8 +495,10 @@ class SearchStore(BackgroundUpdateStore):
# Make sure we don't explode because the person is in too many rooms.
# We filter the results below regardless.
if len(room_ids) < 500:
clauses.append("room_id IN (%s)" % (",".join(["?"] * len(room_ids)),))
args.extend(room_ids)
clause, args = make_in_list_sql_clause(
self.database_engine, "room_id", room_ids
)
clauses = [clause]
local_clauses = []
for key in keys:

View File

@@ -353,8 +353,158 @@ class StateFilter(object):
return member_filter, non_member_filter
class StateGroupBackgroundUpdateStore(SQLBaseStore):
"""Defines functions related to state groups needed to run the state backgroud
updates.
"""
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
# this inherits from EventsWorkerStore because it calls self.get_events
class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
class StateGroupWorkerStore(
EventsWorkerStore, StateGroupBackgroundUpdateStore, SQLBaseStore
):
"""The parts of StateGroupStore that can be called from workers.
"""
@@ -694,107 +844,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return results
def _get_state_groups_from_groups_txn(
self, txn, groups, state_filter=StateFilter.all()
):
results = {group: {} for group in groups}
where_clause, where_args = state_filter.make_sql_filter_clause()
# Unless the filter clause is empty, we're going to append it after an
# existing where clause
if where_clause:
where_clause = " AND (%s)" % (where_clause,)
if isinstance(self.database_engine, PostgresEngine):
# Temporarily disable sequential scans in this transaction. This is
# a temporary hack until we can add the right indices in
txn.execute("SET LOCAL enable_seqscan=off")
# The below query walks the state_group tree so that the "state"
# table includes all state_groups in the tree. It then joins
# against `state_groups_state` to fetch the latest state.
# It assumes that previous state groups are always numerically
# lesser.
# The PARTITION is used to get the event_id in the greatest state
# group for the given type, state_key.
# This may return multiple rows per (type, state_key), but last_value
# should be the same.
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT DISTINCT type, state_key, last_value(event_id) OVER (
PARTITION BY type, state_key ORDER BY state_group ASC
ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
) AS event_id FROM state_groups_state
WHERE state_group IN (
SELECT state_group FROM state
)
"""
for group in groups:
args = [group]
args.extend(where_args)
txn.execute(sql + where_clause, args)
for row in txn:
typ, state_key, event_id = row
key = (typ, state_key)
results[group][key] = event_id
else:
max_entries_returned = state_filter.max_entries_returned()
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
for group in groups:
next_group = group
while next_group:
# We did this before by getting the list of group ids, and
# then passing that list to sqlite to get latest event for
# each (type, state_key). However, that was terribly slow
# without the right indices (which we can't add until
# after we finish deduping state, which requires this func)
args = [next_group]
args.extend(where_args)
txn.execute(
"SELECT type, state_key, event_id FROM state_groups_state"
" WHERE state_group = ? " + where_clause,
args,
)
results[group].update(
((typ, state_key), event_id)
for typ, state_key, event_id in txn
if (typ, state_key) not in results[group]
)
# If the number of entries in the (type,state_key)->event_id dict
# matches the number of (type,state_keys) types we were searching
# for, then we must have found them all, so no need to go walk
# further down the tree... UNLESS our types filter contained
# wildcards (i.e. Nones) in which case we have to do an exhaustive
# search
if (
max_entries_returned is not None
and len(results[group]) == max_entries_returned
):
break
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
return results
@defer.inlineCallbacks
def get_state_for_events(self, event_ids, state_filter=StateFilter.all()):
"""Given a list of event_ids and type tuples, return a list of state
@@ -1238,66 +1287,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return self.runInteraction("store_state_group", _store_state_group_txn)
def _count_state_group_hops_txn(self, txn, state_group):
"""Given a state group, count how many hops there are in the tree.
This is used to ensure the delta chains don't get too long.
"""
if isinstance(self.database_engine, PostgresEngine):
sql = """
WITH RECURSIVE state(state_group) AS (
VALUES(?::bigint)
UNION ALL
SELECT prev_state_group FROM state_group_edges e, state s
WHERE s.state_group = e.state_group
)
SELECT count(*) FROM state;
"""
txn.execute(sql, (state_group,))
row = txn.fetchone()
if row and row[0]:
return row[0]
else:
return 0
else:
# We don't use WITH RECURSIVE on sqlite3 as there are distributions
# that ship with an sqlite3 version that doesn't support it (e.g. wheezy)
next_group = state_group
count = 0
while next_group:
next_group = self._simple_select_one_onecol_txn(
txn,
table="state_group_edges",
keyvalues={"state_group": next_group},
retcol="prev_state_group",
allow_none=True,
)
if next_group:
count += 1
return count
class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
class StateBackgroundUpdateStore(
StateGroupBackgroundUpdateStore, BackgroundUpdateStore
):
STATE_GROUP_DEDUPLICATION_UPDATE_NAME = "state_group_state_deduplication"
STATE_GROUP_INDEX_UPDATE_NAME = "state_group_state_type_index"
@@ -1305,7 +1298,7 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
EVENT_STATE_GROUP_INDEX_UPDATE_NAME = "event_to_state_groups_sg_index"
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
super(StateBackgroundUpdateStore, self).__init__(db_conn, hs)
self.register_background_update_handler(
self.STATE_GROUP_DEDUPLICATION_UPDATE_NAME,
self._background_deduplicate_state,
@@ -1327,34 +1320,6 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
columns=["state_group"],
)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)
@defer.inlineCallbacks
def _background_deduplicate_state(self, progress, batch_size):
"""This background update will slowly deduplicate state by reencoding
@@ -1527,3 +1492,54 @@ class StateStore(StateGroupWorkerStore, BackgroundUpdateStore):
yield self._end_background_update(self.STATE_GROUP_INDEX_UPDATE_NAME)
return 1
class StateStore(StateGroupWorkerStore, StateBackgroundUpdateStore):
""" Keeps track of the state at a given event.
This is done by the concept of `state groups`. Every event is a assigned
a state group (identified by an arbitrary string), which references a
collection of state events. The current state of an event is then the
collection of state events referenced by the event's state group.
Hence, every change in the current state causes a new state group to be
generated. However, if no change happens (e.g., if we get a message event
with only one parent it inherits the state group from its parent.)
There are three tables:
* `state_groups`: Stores group name, first event with in the group and
room id.
* `event_to_state_groups`: Maps events to state groups.
* `state_groups_state`: Maps state group to state events.
"""
def __init__(self, db_conn, hs):
super(StateStore, self).__init__(db_conn, hs)
def _store_event_state_mappings_txn(self, txn, events_and_contexts):
state_groups = {}
for event, context in events_and_contexts:
if event.internal_metadata.is_outlier():
continue
# if the event was rejected, just give it the same state as its
# predecessor.
if context.rejected:
state_groups[event.event_id] = context.prev_group
continue
state_groups[event.event_id] = context.state_group
self._simple_insert_many_txn(
txn,
table="event_to_state_groups",
values=[
{"state_group": state_group_id, "event_id": event_id}
for event_id, state_group_id in iteritems(state_groups)
],
)
for event_id, state_group_id in iteritems(state_groups):
txn.call_after(
self._get_state_group_for_event.prefill, (event_id,), state_group_id
)

View File

@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
class StateDeltasStore(SQLBaseStore):
def get_current_state_deltas(self, prev_stream_id):
def get_current_state_deltas(self, prev_stream_id: int, max_stream_id: int):
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
@@ -36,15 +36,27 @@ class StateDeltasStore(SQLBaseStore):
Args:
prev_stream_id (int): point to get changes since (exclusive)
max_stream_id (int): the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
Deferred[list[dict]]: results
Deferred[tuple[int, list[dict]]: A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
"""
prev_stream_id = int(prev_stream_id)
# check we're not going backwards
assert prev_stream_id <= max_stream_id
if not self._curr_state_delta_stream_cache.has_any_entity_changed(
prev_stream_id
):
return []
# if the CSDs haven't changed between prev_stream_id and now, we
# know for certain that they haven't changed between prev_stream_id and
# max_stream_id.
return max_stream_id, []
def get_current_state_deltas_txn(txn):
# First we calculate the max stream id that will give us less than
@@ -54,21 +66,29 @@ class StateDeltasStore(SQLBaseStore):
sql = """
SELECT stream_id, count(*)
FROM current_state_delta_stream
WHERE stream_id > ?
WHERE stream_id > ? AND stream_id <= ?
GROUP BY stream_id
ORDER BY stream_id ASC
LIMIT 100
"""
txn.execute(sql, (prev_stream_id,))
txn.execute(sql, (prev_stream_id, max_stream_id))
total = 0
max_stream_id = prev_stream_id
for max_stream_id, count in txn:
for stream_id, count in txn:
total += count
if total > 100:
# We arbitarily limit to 100 entries to ensure we don't
# select toooo many.
logger.debug(
"Clipping current_state_delta_stream rows to stream_id %i",
stream_id,
)
clipped_stream_id = stream_id
break
else:
# if there's no problem, we may as well go right up to the max_stream_id
clipped_stream_id = max_stream_id
# Now actually get the deltas
sql = """
@@ -77,8 +97,8 @@ class StateDeltasStore(SQLBaseStore):
WHERE ? < stream_id AND stream_id <= ?
ORDER BY stream_id ASC
"""
txn.execute(sql, (prev_stream_id, max_stream_id))
return self.cursor_to_dict(txn)
txn.execute(sql, (prev_stream_id, clipped_stream_id))
return clipped_stream_id, self.cursor_to_dict(txn)
return self.runInteraction(
"get_current_state_deltas", get_current_state_deltas_txn

View File

@@ -332,6 +332,9 @@ class StatsStore(StateDeltasStore):
def _bulk_update_stats_delta_txn(txn):
for stats_type, stats_updates in updates.items():
for stats_id, fields in stats_updates.items():
logger.info(
"Updating %s stats for %s: %s", stats_type, stats_id, fields
)
self._update_stats_delta_txn(
txn,
ts=ts,

Some files were not shown because too many files have changed in this diff Show More