Compare commits
68 Commits
v0.99.3.1
...
neilj/cont
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
46a886194f | ||
|
|
e98aabf2eb | ||
|
|
8e85493b0c | ||
|
|
a33a5abc4c | ||
|
|
616e6a10bd | ||
|
|
db265f0642 | ||
|
|
9f5d206c4a | ||
|
|
43c707a010 | ||
|
|
40810b81d2 | ||
|
|
2a59e8e429 | ||
|
|
bd3435e982 | ||
|
|
c6a233a936 | ||
|
|
c192bf8970 | ||
|
|
4a2e13631d | ||
|
|
4a4d5c4fd6 | ||
|
|
e4d473d855 | ||
|
|
e8419554ff | ||
|
|
8f549c1177 | ||
|
|
7efd1d87c2 | ||
|
|
3039d61baf | ||
|
|
6f226eed42 | ||
|
|
66e78700a2 | ||
|
|
ac45b0df0b | ||
|
|
8530090b16 | ||
|
|
5bec8d660d | ||
|
|
297bf2547e | ||
|
|
4ef5d17b96 | ||
|
|
24232514bf | ||
|
|
c75e2017f1 | ||
|
|
4c552ed78a | ||
|
|
39fb971e85 | ||
|
|
862d6e5ba5 | ||
|
|
3715c124b3 | ||
|
|
057715aaa2 | ||
|
|
9fbbc3d9e5 | ||
|
|
1666c0696a | ||
|
|
d461c65465 | ||
|
|
62988f73fd | ||
|
|
bb925b1bd7 | ||
|
|
54a87a7b08 | ||
|
|
215c15d049 | ||
|
|
50b5f08740 | ||
|
|
e0f219789d | ||
|
|
aee4ea8ba8 | ||
|
|
902cdc63b6 | ||
|
|
d688a51736 | ||
|
|
c7296bcb98 | ||
|
|
7a91b9d81c | ||
|
|
248014379e | ||
|
|
4e5f0f7ca0 | ||
|
|
40e56997bc | ||
|
|
d035d62f6b | ||
|
|
4eeb2c2f07 | ||
|
|
2e060774ad | ||
|
|
17d7bacbcf | ||
|
|
4b91c313a9 | ||
|
|
1f6d6f918a | ||
|
|
a65763a5d6 | ||
|
|
015b3622eb | ||
|
|
f570916a3e | ||
|
|
91c3513668 | ||
|
|
71dcb275f1 | ||
|
|
aa1e017864 | ||
|
|
a5798de067 | ||
|
|
acaa18f7dd | ||
|
|
d5a5d1c632 | ||
|
|
b7fa834c40 | ||
|
|
197fae1639 |
11
CHANGES.md
11
CHANGES.md
@@ -1,14 +1,3 @@
|
||||
Synapse 0.99.3.1 (2019-05-03)
|
||||
=============================
|
||||
|
||||
Security update
|
||||
---------------
|
||||
|
||||
This release includes two security fixes:
|
||||
|
||||
- Switch to using a cryptographically-secure random number generator for token strings, ensuring they cannot be predicted by an attacker. Thanks to @opnsec for identifying and responsibly disclosing this issue! ([\#5133](https://github.com/matrix-org/synapse/issues/5133))
|
||||
- Blacklist 0.0.0.0 and :: by default for URL previews. Thanks to @opnsec for identifying and responsibly disclosing this issue too! ([\#5134](https://github.com/matrix-org/synapse/issues/5134))
|
||||
|
||||
Synapse 0.99.3 (2019-04-01)
|
||||
===========================
|
||||
|
||||
|
||||
1
changelog.d/4474.misc
Normal file
1
changelog.d/4474.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add test to verify threepid auth check added in #4435.
|
||||
1
changelog.d/4555.bugfix
Normal file
1
changelog.d/4555.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Avoid redundant URL encoding of redirect URL for SSO login in the fallback login page. Fixes a regression introduced in [#4220](https://github.com/matrix-org/synapse/pull/4220). Contributed by Marcel Fabian Krüger ("[zaugin](https://github.com/zauguin)").
|
||||
1
changelog.d/4942.bugfix
Normal file
1
changelog.d/4942.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix bug where presence updates were sent to all servers in a room when a new server joined, rather than to just the new server.
|
||||
1
changelog.d/4947.feature
Normal file
1
changelog.d/4947.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add ability for password provider modules to bind email addresses to users upon registration.
|
||||
1
changelog.d/4949.misc
Normal file
1
changelog.d/4949.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix/improve some docstrings in the replication code.
|
||||
2
changelog.d/4953.misc
Normal file
2
changelog.d/4953.misc
Normal file
@@ -0,0 +1,2 @@
|
||||
Split synapse.replication.tcp.streams into smaller files.
|
||||
|
||||
1
changelog.d/4954.misc
Normal file
1
changelog.d/4954.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor replication row generation/parsing.
|
||||
1
changelog.d/4955.bugfix
Normal file
1
changelog.d/4955.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix sync bug which made accepting invites unreliable in worker-mode synapses.
|
||||
1
changelog.d/4956.bugfix
Normal file
1
changelog.d/4956.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix sync bug which made accepting invites unreliable in worker-mode synapses.
|
||||
1
changelog.d/4959.misc
Normal file
1
changelog.d/4959.misc
Normal file
@@ -0,0 +1 @@
|
||||
Run `black` to clean up formatting on `synapse/storage/roommember.py` and `synapse/storage/events.py`.
|
||||
1
changelog.d/4965.misc
Normal file
1
changelog.d/4965.misc
Normal file
@@ -0,0 +1 @@
|
||||
Remove log line for password via the admin API.
|
||||
1
changelog.d/4968.misc
Normal file
1
changelog.d/4968.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix typo in TLS filenames in docker/README.md. Also add the '-p' commandline option to the 'docker run' example. Contributed by Jurrie Overgoor.
|
||||
2
changelog.d/4969.misc
Normal file
2
changelog.d/4969.misc
Normal file
@@ -0,0 +1,2 @@
|
||||
Refactor room version definitions.
|
||||
|
||||
1
changelog.d/4974.misc
Normal file
1
changelog.d/4974.misc
Normal file
@@ -0,0 +1 @@
|
||||
Add `config.signing_key_path` that can be read by `synapse.config` utility.
|
||||
1
changelog.d/4981.bugfix
Normal file
1
changelog.d/4981.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
start.sh: Fix the --no-rate-limit option for messages and make it bypass rate limit on registration and login too.
|
||||
1
changelog.d/4982.misc
Normal file
1
changelog.d/4982.misc
Normal file
@@ -0,0 +1 @@
|
||||
Track which identity server is used when binding a threepid and use that for unbinding, as per MSC1915.
|
||||
1
changelog.d/4985.misc
Normal file
1
changelog.d/4985.misc
Normal file
@@ -0,0 +1 @@
|
||||
Rewrite KeyringTestCase as a HomeserverTestCase.
|
||||
1
changelog.d/4987.misc
Normal file
1
changelog.d/4987.misc
Normal file
@@ -0,0 +1 @@
|
||||
README updates: Corrected the default POSTGRES_USER. Added port forwarding hint in TLS section.
|
||||
1
changelog.d/4989.feature
Normal file
1
changelog.d/4989.feature
Normal file
@@ -0,0 +1 @@
|
||||
Remove presence list support as per MSC 1819.
|
||||
1
changelog.d/4990.bugfix
Normal file
1
changelog.d/4990.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Transfer related groups on room upgrade.
|
||||
1
changelog.d/4991.feature
Normal file
1
changelog.d/4991.feature
Normal file
@@ -0,0 +1 @@
|
||||
Reduce CPU usage starting pushers during start up.
|
||||
1
changelog.d/4996.misc
Normal file
1
changelog.d/4996.misc
Normal file
@@ -0,0 +1 @@
|
||||
Run `black` on the remainder of `synapse/storage/`.
|
||||
1
changelog.d/4998.misc
Normal file
1
changelog.d/4998.misc
Normal file
@@ -0,0 +1 @@
|
||||
Fix grammar in get_current_users_in_room and give it a docstring.
|
||||
1
changelog.d/4999.bugfix
Normal file
1
changelog.d/4999.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Prevent the ability to kick users from a room they aren't in.
|
||||
1
changelog.d/5002.feature
Normal file
1
changelog.d/5002.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add a delete group admin API.
|
||||
1
changelog.d/5003.bugfix
Normal file
1
changelog.d/5003.bugfix
Normal file
@@ -0,0 +1 @@
|
||||
Fix issue #4596 so synapse_port_db script works with --curses option on Python 3. Contributed by Anders Jensen-Waud <anders@jensenwaud.com>.
|
||||
1
changelog.d/5007.misc
Normal file
1
changelog.d/5007.misc
Normal file
@@ -0,0 +1 @@
|
||||
Refactor synapse.storage._base._simple_select_list_paginate.
|
||||
1
changelog.d/5010.feature
Normal file
1
changelog.d/5010.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add config option to block users from looking up 3PIDs.
|
||||
1
changelog.d/5020.feature
Normal file
1
changelog.d/5020.feature
Normal file
@@ -0,0 +1 @@
|
||||
Add context to phonehome stats.
|
||||
6
debian/changelog
vendored
6
debian/changelog
vendored
@@ -1,9 +1,3 @@
|
||||
matrix-synapse-py3 (0.99.3.1) stable; urgency=medium
|
||||
|
||||
* New synapse release 0.99.3.1.
|
||||
|
||||
-- Synapse Packaging team <packages@matrix.org> Fri, 03 May 2019 16:02:43 +0100
|
||||
|
||||
matrix-synapse-py3 (0.99.3) stable; urgency=medium
|
||||
|
||||
[ Richard van der Hoff ]
|
||||
|
||||
@@ -27,17 +27,27 @@ for port in 8080 8081 8082; do
|
||||
--config-path "$DIR/etc/$port.config" \
|
||||
--report-stats no
|
||||
|
||||
printf '\n\n# Customisation made by demo/start.sh\n' >> $DIR/etc/$port.config
|
||||
echo 'enable_registration: true' >> $DIR/etc/$port.config
|
||||
|
||||
# Check script parameters
|
||||
if [ $# -eq 1 ]; then
|
||||
if [ $1 = "--no-rate-limit" ]; then
|
||||
# Set high limits in config file to disable rate limiting
|
||||
perl -p -i -e 's/rc_messages_per_second.*/rc_messages_per_second: 1000/g' $DIR/etc/$port.config
|
||||
perl -p -i -e 's/rc_message_burst_count.*/rc_message_burst_count: 1000/g' $DIR/etc/$port.config
|
||||
# messages rate limit
|
||||
echo 'rc_messages_per_second: 1000' >> $DIR/etc/$port.config
|
||||
echo 'rc_message_burst_count: 1000' >> $DIR/etc/$port.config
|
||||
|
||||
# registration rate limit
|
||||
printf 'rc_registration:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config
|
||||
|
||||
# login rate limit
|
||||
echo 'rc_login:' >> $DIR/etc/$port.config
|
||||
printf ' address:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config
|
||||
printf ' account:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config
|
||||
printf ' failed_attempts:\n per_second: 1000\n burst_count: 1000\n' >> $DIR/etc/$port.config
|
||||
fi
|
||||
fi
|
||||
|
||||
perl -p -i -e 's/^enable_registration:.*/enable_registration: true/g' $DIR/etc/$port.config
|
||||
|
||||
if ! grep -F "full_twisted_stacktraces" -q $DIR/etc/$port.config; then
|
||||
echo "full_twisted_stacktraces: true" >> $DIR/etc/$port.config
|
||||
fi
|
||||
|
||||
@@ -31,6 +31,7 @@ docker run \
|
||||
--mount type=volume,src=synapse-data,dst=/data \
|
||||
-e SYNAPSE_SERVER_NAME=my.matrix.host \
|
||||
-e SYNAPSE_REPORT_STATS=yes \
|
||||
-p 8448:8448 \
|
||||
matrixdotorg/synapse:latest
|
||||
```
|
||||
|
||||
@@ -57,9 +58,10 @@ configuration file there. Multiple application services are supported.
|
||||
Synapse requires a valid TLS certificate. You can do one of the following:
|
||||
|
||||
* Provide your own certificate and key (as
|
||||
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.crt` and
|
||||
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.key`, or elsewhere by providing an
|
||||
entire config as `${SYNAPSE_CONFIG_PATH}`).
|
||||
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.crt` and
|
||||
`${DATA_PATH}/${SYNAPSE_SERVER_NAME}.tls.key`, or elsewhere by providing an
|
||||
entire config as `${SYNAPSE_CONFIG_PATH}`). In this case, you should forward
|
||||
traffic to port 8448 in the container, for example with `-p 443:8448`.
|
||||
|
||||
* Use a reverse proxy to terminate incoming TLS, and forward the plain http
|
||||
traffic to port 8008 in the container. In this case you should set `-e
|
||||
@@ -137,7 +139,7 @@ Database specific values (will use SQLite if not set):
|
||||
**NOTE**: You are highly encouraged to use postgresql! Please use the compose
|
||||
file to make it easier to deploy.
|
||||
* `POSTGRES_USER` - The user for the synapse postgres database. [default:
|
||||
`matrix`]
|
||||
`synapse`]
|
||||
|
||||
Mail server specific values (will not send emails if not set):
|
||||
|
||||
|
||||
14
docs/admin_api/delete_group.md
Normal file
14
docs/admin_api/delete_group.md
Normal file
@@ -0,0 +1,14 @@
|
||||
# Delete a local group
|
||||
|
||||
This API lets a server admin delete a local group. Doing so will kick all
|
||||
users out of the group so that their clients will correctly handle the group
|
||||
being deleted.
|
||||
|
||||
|
||||
The API is:
|
||||
|
||||
```
|
||||
POST /_matrix/client/r0/admin/delete_group/<group_id>
|
||||
```
|
||||
|
||||
including an `access_token` of a server admin.
|
||||
@@ -236,6 +236,9 @@ listeners:
|
||||
# - medium: 'email'
|
||||
# address: 'reserved_user@example.com'
|
||||
|
||||
# Used by phonehome stats to group together related servers.
|
||||
#server_context: context
|
||||
|
||||
|
||||
## TLS ##
|
||||
|
||||
@@ -506,12 +509,11 @@ uploads_path: "DATADIR/uploads"
|
||||
# height: 600
|
||||
# method: scale
|
||||
|
||||
# Is the preview URL API enabled?
|
||||
# Is the preview URL API enabled? If enabled, you *must* specify
|
||||
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
|
||||
# denied from accessing.
|
||||
#
|
||||
# 'false' by default: uncomment the following to enable it (and specify a
|
||||
# url_preview_ip_range_blacklist blacklist).
|
||||
#
|
||||
#url_preview_enabled: true
|
||||
#url_preview_enabled: false
|
||||
|
||||
# List of IP address CIDR ranges that the URL preview spider is denied
|
||||
# from accessing. There are no defaults: you must explicitly
|
||||
@@ -521,12 +523,6 @@ uploads_path: "DATADIR/uploads"
|
||||
# synapse to issue arbitrary GET requests to your internal services,
|
||||
# causing serious security issues.
|
||||
#
|
||||
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
|
||||
# listed here, since they correspond to unroutable addresses.)
|
||||
#
|
||||
# This must be specified if url_preview_enabled is set. It is recommended that
|
||||
# you uncomment the following list as a starting point.
|
||||
#
|
||||
#url_preview_ip_range_blacklist:
|
||||
# - '127.0.0.0/8'
|
||||
# - '10.0.0.0/8'
|
||||
@@ -537,7 +533,7 @@ uploads_path: "DATADIR/uploads"
|
||||
# - '::1/128'
|
||||
# - 'fe80::/64'
|
||||
# - 'fc00::/7'
|
||||
|
||||
#
|
||||
# List of IP address CIDR ranges that the URL preview spider is allowed
|
||||
# to access even if they are specified in url_preview_ip_range_blacklist.
|
||||
# This is useful for specifying exceptions to wide-ranging blacklisted
|
||||
@@ -672,6 +668,10 @@ uploads_path: "DATADIR/uploads"
|
||||
# - medium: msisdn
|
||||
# pattern: '\+44'
|
||||
|
||||
# Enable 3PIDs lookup requests to identity servers from this server.
|
||||
#
|
||||
#enable_3pid_lookup: true
|
||||
|
||||
# If set, allows registration of standard or admin accounts by anyone who
|
||||
# has the shared secret, even if registration is otherwise disabled.
|
||||
#
|
||||
|
||||
@@ -811,7 +811,7 @@ class CursesProgress(Progress):
|
||||
middle_space = 1
|
||||
|
||||
items = self.tables.items()
|
||||
items.sort(key=lambda i: (i[1]["perc"], i[0]))
|
||||
items = sorted(items, key=lambda i: (i[1]["perc"], i[0]))
|
||||
|
||||
for i, (table, data) in enumerate(items):
|
||||
if i + 2 >= rows:
|
||||
|
||||
@@ -27,4 +27,4 @@ try:
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__version__ = "0.99.3.1"
|
||||
__version__ = "0.99.3"
|
||||
|
||||
@@ -69,6 +69,7 @@ class EventTypes(object):
|
||||
Redaction = "m.room.redaction"
|
||||
ThirdPartyInvite = "m.room.third_party_invite"
|
||||
Encryption = "m.room.encryption"
|
||||
RelatedGroups = "m.room.related_groups"
|
||||
|
||||
RoomHistoryVisibility = "m.room.history_visibility"
|
||||
CanonicalAlias = "m.room.canonical_alias"
|
||||
@@ -102,46 +103,6 @@ class ThirdPartyEntityKind(object):
|
||||
LOCATION = "location"
|
||||
|
||||
|
||||
class RoomVersions(object):
|
||||
V1 = "1"
|
||||
V2 = "2"
|
||||
V3 = "3"
|
||||
STATE_V2_TEST = "state-v2-test"
|
||||
|
||||
|
||||
class RoomDisposition(object):
|
||||
STABLE = "stable"
|
||||
UNSTABLE = "unstable"
|
||||
|
||||
|
||||
# the version we will give rooms which are created on this server
|
||||
DEFAULT_ROOM_VERSION = RoomVersions.V1
|
||||
|
||||
# vdh-test-version is a placeholder to get room versioning support working and tested
|
||||
# until we have a working v2.
|
||||
KNOWN_ROOM_VERSIONS = {
|
||||
RoomVersions.V1,
|
||||
RoomVersions.V2,
|
||||
RoomVersions.V3,
|
||||
RoomVersions.STATE_V2_TEST,
|
||||
RoomVersions.V3,
|
||||
}
|
||||
|
||||
|
||||
class EventFormatVersions(object):
|
||||
"""This is an internal enum for tracking the version of the event format,
|
||||
independently from the room version.
|
||||
"""
|
||||
V1 = 1
|
||||
V2 = 2
|
||||
|
||||
|
||||
KNOWN_EVENT_FORMAT_VERSIONS = {
|
||||
EventFormatVersions.V1,
|
||||
EventFormatVersions.V2,
|
||||
}
|
||||
|
||||
|
||||
ServerNoticeMsgType = "m.server_notice"
|
||||
ServerNoticeLimitReached = "m.server_notice.usage_limit_reached"
|
||||
|
||||
|
||||
91
synapse/api/room_versions.py
Normal file
91
synapse/api/room_versions.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import attr
|
||||
|
||||
|
||||
class EventFormatVersions(object):
|
||||
"""This is an internal enum for tracking the version of the event format,
|
||||
independently from the room version.
|
||||
"""
|
||||
V1 = 1 # $id:server format
|
||||
V2 = 2 # MSC1659-style $hash format: introduced for room v3
|
||||
|
||||
|
||||
KNOWN_EVENT_FORMAT_VERSIONS = {
|
||||
EventFormatVersions.V1,
|
||||
EventFormatVersions.V2,
|
||||
}
|
||||
|
||||
|
||||
class StateResolutionVersions(object):
|
||||
"""Enum to identify the state resolution algorithms"""
|
||||
V1 = 1 # room v1 state res
|
||||
V2 = 2 # MSC1442 state res: room v2 and later
|
||||
|
||||
|
||||
class RoomDisposition(object):
|
||||
STABLE = "stable"
|
||||
UNSTABLE = "unstable"
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class RoomVersion(object):
|
||||
"""An object which describes the unique attributes of a room version."""
|
||||
|
||||
identifier = attr.ib() # str; the identifier for this version
|
||||
disposition = attr.ib() # str; one of the RoomDispositions
|
||||
event_format = attr.ib() # int; one of the EventFormatVersions
|
||||
state_res = attr.ib() # int; one of the StateResolutionVersions
|
||||
|
||||
|
||||
class RoomVersions(object):
|
||||
V1 = RoomVersion(
|
||||
"1",
|
||||
RoomDisposition.STABLE,
|
||||
EventFormatVersions.V1,
|
||||
StateResolutionVersions.V1,
|
||||
)
|
||||
STATE_V2_TEST = RoomVersion(
|
||||
"state-v2-test",
|
||||
RoomDisposition.UNSTABLE,
|
||||
EventFormatVersions.V1,
|
||||
StateResolutionVersions.V2,
|
||||
)
|
||||
V2 = RoomVersion(
|
||||
"2",
|
||||
RoomDisposition.STABLE,
|
||||
EventFormatVersions.V1,
|
||||
StateResolutionVersions.V2,
|
||||
)
|
||||
V3 = RoomVersion(
|
||||
"3",
|
||||
RoomDisposition.STABLE,
|
||||
EventFormatVersions.V2,
|
||||
StateResolutionVersions.V2,
|
||||
)
|
||||
|
||||
|
||||
# the version we will give rooms which are created on this server
|
||||
DEFAULT_ROOM_VERSION = RoomVersions.V1
|
||||
|
||||
|
||||
KNOWN_ROOM_VERSIONS = {
|
||||
v.identifier: v for v in (
|
||||
RoomVersions.V1,
|
||||
RoomVersions.V2,
|
||||
RoomVersions.V3,
|
||||
RoomVersions.STATE_V2_TEST,
|
||||
)
|
||||
} # type: dict[str, RoomVersion]
|
||||
@@ -38,7 +38,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.transactions import SlavedTransactionStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.replication.tcp.streams import ReceiptsStream
|
||||
from synapse.replication.tcp.streams._base import ReceiptsStream
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
from synapse.types import ReadReceipt
|
||||
|
||||
@@ -518,6 +518,7 @@ def run(hs):
|
||||
uptime = 0
|
||||
|
||||
stats["homeserver"] = hs.config.server_name
|
||||
stats["server_context"] = hs.config.server_context
|
||||
stats["timestamp"] = now
|
||||
stats["uptime_seconds"] = uptime
|
||||
version = sys.version_info
|
||||
@@ -558,7 +559,6 @@ def run(hs):
|
||||
|
||||
stats["database_engine"] = hs.get_datastore().database_engine_name
|
||||
stats["database_server_version"] = hs.get_datastore().get_server_version()
|
||||
|
||||
logger.info("Reporting stats to matrix.org: %s" % (stats,))
|
||||
try:
|
||||
yield hs.get_simple_http_client().put_json(
|
||||
|
||||
@@ -48,6 +48,7 @@ from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.slave.storage.room import RoomStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.replication.tcp.streams.events import EventsStreamEventRow
|
||||
from synapse.rest.client.v1 import events
|
||||
from synapse.rest.client.v1.initial_sync import InitialSyncRestServlet
|
||||
from synapse.rest.client.v1.room import RoomInitialSyncRestServlet
|
||||
@@ -369,7 +370,9 @@ class SyncReplicationHandler(ReplicationClientHandler):
|
||||
# We shouldn't get multiple rows per token for events stream, so
|
||||
# we don't need to optimise this for multiple rows.
|
||||
for row in rows:
|
||||
event = yield self.store.get_event(row.event_id)
|
||||
if row.type != EventsStreamEventRow.TypeId:
|
||||
continue
|
||||
event = yield self.store.get_event(row.data.event_id)
|
||||
extra_users = ()
|
||||
if event.type == EventTypes.Member:
|
||||
extra_users = (event.state_key,)
|
||||
|
||||
@@ -36,6 +36,10 @@ from synapse.replication.slave.storage.client_ips import SlavedClientIpStore
|
||||
from synapse.replication.slave.storage.events import SlavedEventStore
|
||||
from synapse.replication.slave.storage.registration import SlavedRegistrationStore
|
||||
from synapse.replication.tcp.client import ReplicationClientHandler
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStream,
|
||||
EventsStreamCurrentStateRow,
|
||||
)
|
||||
from synapse.rest.client.v2_alpha import user_directory
|
||||
from synapse.server import HomeServer
|
||||
from synapse.storage.engines import create_engine
|
||||
@@ -73,19 +77,18 @@ class UserDirectorySlaveStore(
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
self._current_state_delta_pos = events_max
|
||||
|
||||
def stream_positions(self):
|
||||
result = super(UserDirectorySlaveStore, self).stream_positions()
|
||||
result["current_state_deltas"] = self._current_state_delta_pos
|
||||
return result
|
||||
|
||||
def process_replication_rows(self, stream_name, token, rows):
|
||||
if stream_name == "current_state_deltas":
|
||||
self._current_state_delta_pos = token
|
||||
if stream_name == EventsStream.NAME:
|
||||
self._stream_id_gen.advance(token)
|
||||
for row in rows:
|
||||
if row.type != EventsStreamCurrentStateRow.TypeId:
|
||||
continue
|
||||
self._curr_state_delta_stream_cache.entity_has_changed(
|
||||
row.room_id, token
|
||||
row.data.room_id, token
|
||||
)
|
||||
return super(UserDirectorySlaveStore, self).process_replication_rows(
|
||||
stream_name, token, rows
|
||||
@@ -170,7 +173,7 @@ class UserDirectoryReplicationHandler(ReplicationClientHandler):
|
||||
yield super(UserDirectoryReplicationHandler, self).on_rdata(
|
||||
stream_name, token, rows
|
||||
)
|
||||
if stream_name == "current_state_deltas":
|
||||
if stream_name == EventsStream.NAME:
|
||||
run_in_background(self._notify_directory)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -42,7 +42,8 @@ class KeyConfig(Config):
|
||||
if "signing_key" in config:
|
||||
self.signing_key = read_signing_keys([config["signing_key"]])
|
||||
else:
|
||||
self.signing_key = self.read_signing_key(config["signing_key_path"])
|
||||
self.signing_key_path = config["signing_key_path"]
|
||||
self.signing_key = self.read_signing_key(self.signing_key_path)
|
||||
|
||||
self.old_signing_keys = self.read_old_signing_keys(
|
||||
config.get("old_signing_keys", {})
|
||||
|
||||
@@ -33,6 +33,7 @@ class RegistrationConfig(Config):
|
||||
|
||||
self.registrations_require_3pid = config.get("registrations_require_3pid", [])
|
||||
self.allowed_local_3pids = config.get("allowed_local_3pids", [])
|
||||
self.enable_3pid_lookup = config.get("enable_3pid_lookup", True)
|
||||
self.registration_shared_secret = config.get("registration_shared_secret")
|
||||
|
||||
self.bcrypt_rounds = config.get("bcrypt_rounds", 12)
|
||||
@@ -97,6 +98,10 @@ class RegistrationConfig(Config):
|
||||
# - medium: msisdn
|
||||
# pattern: '\\+44'
|
||||
|
||||
# Enable 3PIDs lookup requests to identity servers from this server.
|
||||
#
|
||||
#enable_3pid_lookup: true
|
||||
|
||||
# If set, allows registration of standard or admin accounts by anyone who
|
||||
# has the shared secret, even if registration is otherwise disabled.
|
||||
#
|
||||
|
||||
@@ -186,21 +186,17 @@ class ContentRepositoryConfig(Config):
|
||||
except ImportError:
|
||||
raise ConfigError(MISSING_NETADDR)
|
||||
|
||||
if "url_preview_ip_range_blacklist" not in config:
|
||||
if "url_preview_ip_range_blacklist" in config:
|
||||
self.url_preview_ip_range_blacklist = IPSet(
|
||||
config["url_preview_ip_range_blacklist"]
|
||||
)
|
||||
else:
|
||||
raise ConfigError(
|
||||
"For security, you must specify an explicit target IP address "
|
||||
"blacklist in url_preview_ip_range_blacklist for url previewing "
|
||||
"to work"
|
||||
)
|
||||
|
||||
self.url_preview_ip_range_blacklist = IPSet(
|
||||
config["url_preview_ip_range_blacklist"]
|
||||
)
|
||||
|
||||
# we always blacklist '0.0.0.0' and '::', which are supposed to be
|
||||
# unroutable addresses.
|
||||
self.url_preview_ip_range_blacklist.update(['0.0.0.0', '::'])
|
||||
|
||||
self.url_preview_ip_range_whitelist = IPSet(
|
||||
config.get("url_preview_ip_range_whitelist", ())
|
||||
)
|
||||
@@ -264,12 +260,11 @@ class ContentRepositoryConfig(Config):
|
||||
#thumbnail_sizes:
|
||||
%(formatted_thumbnail_sizes)s
|
||||
|
||||
# Is the preview URL API enabled?
|
||||
# Is the preview URL API enabled? If enabled, you *must* specify
|
||||
# an explicit url_preview_ip_range_blacklist of IPs that the spider is
|
||||
# denied from accessing.
|
||||
#
|
||||
# 'false' by default: uncomment the following to enable it (and specify a
|
||||
# url_preview_ip_range_blacklist blacklist).
|
||||
#
|
||||
#url_preview_enabled: true
|
||||
#url_preview_enabled: false
|
||||
|
||||
# List of IP address CIDR ranges that the URL preview spider is denied
|
||||
# from accessing. There are no defaults: you must explicitly
|
||||
@@ -279,12 +274,6 @@ class ContentRepositoryConfig(Config):
|
||||
# synapse to issue arbitrary GET requests to your internal services,
|
||||
# causing serious security issues.
|
||||
#
|
||||
# (0.0.0.0 and :: are always blacklisted, whether or not they are explicitly
|
||||
# listed here, since they correspond to unroutable addresses.)
|
||||
#
|
||||
# This must be specified if url_preview_enabled is set. It is recommended that
|
||||
# you uncomment the following list as a starting point.
|
||||
#
|
||||
#url_preview_ip_range_blacklist:
|
||||
# - '127.0.0.0/8'
|
||||
# - '10.0.0.0/8'
|
||||
@@ -295,7 +284,7 @@ class ContentRepositoryConfig(Config):
|
||||
# - '::1/128'
|
||||
# - 'fe80::/64'
|
||||
# - 'fc00::/7'
|
||||
|
||||
#
|
||||
# List of IP address CIDR ranges that the URL preview spider is allowed
|
||||
# to access even if they are specified in url_preview_ip_range_blacklist.
|
||||
# This is useful for specifying exceptions to wide-ranging blacklisted
|
||||
|
||||
@@ -37,6 +37,7 @@ class ServerConfig(Config):
|
||||
|
||||
def read_config(self, config):
|
||||
self.server_name = config["server_name"]
|
||||
self.server_context = config.get("server_context", None)
|
||||
|
||||
try:
|
||||
parse_and_validate_server_name(self.server_name)
|
||||
@@ -484,6 +485,9 @@ class ServerConfig(Config):
|
||||
#mau_limit_reserved_threepids:
|
||||
# - medium: 'email'
|
||||
# address: 'reserved_user@example.com'
|
||||
|
||||
# Used by phonehome stats to group together related servers.
|
||||
#server_context: context
|
||||
""" % locals()
|
||||
|
||||
def read_arguments(self, args):
|
||||
|
||||
@@ -20,15 +20,9 @@ from signedjson.key import decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
from unpaddedbase64 import decode_base64
|
||||
|
||||
from synapse.api.constants import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
EventTypes,
|
||||
JoinRules,
|
||||
Membership,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, JoinRules, Membership
|
||||
from synapse.api.errors import AuthError, EventSizeError, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -452,16 +446,18 @@ def check_redaction(room_version, event, auth_events):
|
||||
if user_level >= redact_level:
|
||||
return False
|
||||
|
||||
if room_version in (RoomVersions.V1, RoomVersions.V2,):
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
if not v:
|
||||
raise RuntimeError("Unrecognized room version %r" % (room_version,))
|
||||
|
||||
if v.event_format == EventFormatVersions.V1:
|
||||
redacter_domain = get_domain_from_id(event.event_id)
|
||||
redactee_domain = get_domain_from_id(event.redacts)
|
||||
if redacter_domain == redactee_domain:
|
||||
return True
|
||||
elif room_version == RoomVersions.V3:
|
||||
else:
|
||||
event.internal_metadata.recheck_redaction = True
|
||||
return True
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %r" % (room_version,))
|
||||
|
||||
raise AuthError(
|
||||
403,
|
||||
|
||||
@@ -21,7 +21,7 @@ import six
|
||||
|
||||
from unpaddedbase64 import encode_base64
|
||||
|
||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventFormatVersions, RoomVersions
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
|
||||
from synapse.util.caches import intern_dict
|
||||
from synapse.util.frozenutils import freeze
|
||||
|
||||
@@ -351,18 +351,13 @@ def room_version_to_event_format(room_version):
|
||||
Returns:
|
||||
int
|
||||
"""
|
||||
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
|
||||
if not v:
|
||||
# We should have already checked version, so this should not happen
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
if room_version in (
|
||||
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
|
||||
):
|
||||
return EventFormatVersions.V1
|
||||
elif room_version in (RoomVersions.V3,):
|
||||
return EventFormatVersions.V2
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
return v.event_format
|
||||
|
||||
|
||||
def event_type_from_format_version(format_version):
|
||||
|
||||
@@ -17,21 +17,17 @@ import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
from synapse.api.constants import MAX_DEPTH
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_EVENT_FORMAT_VERSIONS,
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
MAX_DEPTH,
|
||||
EventFormatVersions,
|
||||
)
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
from synapse.types import EventID
|
||||
from synapse.util.stringutils import random_string
|
||||
|
||||
from . import (
|
||||
_EventInternalMetadata,
|
||||
event_type_from_format_version,
|
||||
room_version_to_event_format,
|
||||
)
|
||||
from . import _EventInternalMetadata, event_type_from_format_version
|
||||
|
||||
|
||||
@attr.s(slots=True, cmp=False, frozen=True)
|
||||
@@ -170,21 +166,34 @@ class EventBuilderFactory(object):
|
||||
def new(self, room_version, key_values):
|
||||
"""Generate an event builder appropriate for the given room version
|
||||
|
||||
Deprecated: use for_room_version with a RoomVersion object instead
|
||||
|
||||
Args:
|
||||
room_version (str): Version of the room that we're creating an
|
||||
event builder for
|
||||
room_version (str): Version of the room that we're creating an event builder
|
||||
for
|
||||
key_values (dict): Fields used as the basis of the new event
|
||||
|
||||
Returns:
|
||||
EventBuilder
|
||||
"""
|
||||
|
||||
# There's currently only the one event version defined
|
||||
if room_version not in KNOWN_ROOM_VERSIONS:
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
if not v:
|
||||
raise Exception(
|
||||
"No event format defined for version %r" % (room_version,)
|
||||
)
|
||||
return self.for_room_version(v, key_values)
|
||||
|
||||
def for_room_version(self, room_version, key_values):
|
||||
"""Generate an event builder appropriate for the given room version
|
||||
|
||||
Args:
|
||||
room_version (synapse.api.room_versions.RoomVersion):
|
||||
Version of the room that we're creating an event builder for
|
||||
key_values (dict): Fields used as the basis of the new event
|
||||
|
||||
Returns:
|
||||
EventBuilder
|
||||
"""
|
||||
return EventBuilder(
|
||||
store=self.store,
|
||||
state=self.state,
|
||||
@@ -192,7 +201,7 @@ class EventBuilderFactory(object):
|
||||
clock=self.clock,
|
||||
hostname=self.hostname,
|
||||
signing_key=self.signing_key,
|
||||
format_version=room_version_to_event_format(room_version),
|
||||
format_version=room_version.event_format,
|
||||
type=key_values["type"],
|
||||
state_key=key_values.get("state_key"),
|
||||
room_id=key_values["room_id"],
|
||||
@@ -222,7 +231,6 @@ def create_local_event_from_event_dict(clock, hostname, signing_key,
|
||||
FrozenEvent
|
||||
"""
|
||||
|
||||
# There's currently only the one event version defined
|
||||
if format_version not in KNOWN_EVENT_FORMAT_VERSIONS:
|
||||
raise Exception(
|
||||
"No event format defined for version %r" % (format_version,)
|
||||
|
||||
@@ -15,8 +15,9 @@
|
||||
|
||||
from six import string_types
|
||||
|
||||
from synapse.api.constants import EventFormatVersions, EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.api.room_versions import EventFormatVersions
|
||||
from synapse.types import EventID, RoomID, UserID
|
||||
|
||||
|
||||
|
||||
@@ -20,8 +20,9 @@ import six
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import DeferredList
|
||||
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership, RoomVersions
|
||||
from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, EventFormatVersions
|
||||
from synapse.crypto.event_signing import check_event_content_hash
|
||||
from synapse.events import event_type_from_format_version
|
||||
from synapse.events.utils import prune_event
|
||||
@@ -274,9 +275,12 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||
# now let's look for events where the sender's domain is different to the
|
||||
# event id's domain (normally only the case for joins/leaves), and add additional
|
||||
# checks. Only do this if the room version has a concept of event ID domain
|
||||
if room_version in (
|
||||
RoomVersions.V1, RoomVersions.V2, RoomVersions.STATE_V2_TEST,
|
||||
):
|
||||
# (ie, the room version uses old-style non-hash event IDs).
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
if not v:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
if v.event_format == EventFormatVersions.V1:
|
||||
pdus_to_check_event_id = [
|
||||
p for p in pdus_to_check
|
||||
if p.sender_domain != get_domain_from_id(p.pdu.event_id)
|
||||
@@ -289,10 +293,6 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
|
||||
|
||||
for p, d in zip(pdus_to_check_event_id, more_deferreds):
|
||||
p.deferreds.append(d)
|
||||
elif room_version in (RoomVersions.V3,):
|
||||
pass # No further checks needed, as event IDs are hashes here
|
||||
else:
|
||||
raise RuntimeError("Unrecognized room version %s" % (room_version,))
|
||||
|
||||
# replace lists of deferreds with single Deferreds
|
||||
return [_flatten_deferred_list(p.deferreds) for p in pdus_to_check]
|
||||
|
||||
@@ -25,12 +25,7 @@ from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventTypes,
|
||||
Membership,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
CodeMessageException,
|
||||
Codes,
|
||||
@@ -38,6 +33,11 @@ from synapse.api.errors import (
|
||||
HttpResponseException,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventFormatVersions,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.events import builder, room_version_to_event_format
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
from synapse.util import logcontext, unwrapFirstError
|
||||
@@ -570,7 +570,7 @@ class FederationClient(FederationBase):
|
||||
Deferred[tuple[str, FrozenEvent, int]]: resolves to a tuple of
|
||||
`(origin, event, event_format)` where origin is the remote
|
||||
homeserver which generated the event, and event_format is one of
|
||||
`synapse.api.constants.EventFormatVersions`.
|
||||
`synapse.api.room_versions.EventFormatVersions`.
|
||||
|
||||
Fails with a ``SynapseError`` if the chosen remote server
|
||||
returns a 300/400 code.
|
||||
@@ -592,7 +592,7 @@ class FederationClient(FederationBase):
|
||||
|
||||
# Note: If not supplied, the room version may be either v1 or v2,
|
||||
# however either way the event format version will be v1.
|
||||
room_version = ret.get("room_version", RoomVersions.V1)
|
||||
room_version = ret.get("room_version", RoomVersions.V1.identifier)
|
||||
event_format = room_version_to_event_format(room_version)
|
||||
|
||||
pdu_dict = ret.get("event", None)
|
||||
@@ -695,7 +695,9 @@ class FederationClient(FederationBase):
|
||||
room_version = None
|
||||
for e in state:
|
||||
if (e.type, e.state_key) == (EventTypes.Create, ""):
|
||||
room_version = e.content.get("room_version", RoomVersions.V1)
|
||||
room_version = e.content.get(
|
||||
"room_version", RoomVersions.V1.identifier
|
||||
)
|
||||
break
|
||||
|
||||
if room_version is None:
|
||||
@@ -802,11 +804,10 @@ class FederationClient(FederationBase):
|
||||
raise err
|
||||
|
||||
# Otherwise, we assume that the remote server doesn't understand
|
||||
# the v2 invite API.
|
||||
|
||||
if room_version in (RoomVersions.V1, RoomVersions.V2):
|
||||
pass # We'll fall through
|
||||
else:
|
||||
# the v2 invite API. That's ok provided the room uses old-style event
|
||||
# IDs.
|
||||
v = KNOWN_ROOM_VERSIONS.get(room_version)
|
||||
if v.event_format != EventFormatVersions.V1:
|
||||
raise SynapseError(
|
||||
400,
|
||||
"User's homeserver does not support this room version",
|
||||
|
||||
@@ -25,7 +25,7 @@ from twisted.internet import defer
|
||||
from twisted.internet.abstract import isIPAddress
|
||||
from twisted.python import failure
|
||||
|
||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS, EventTypes, Membership
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
@@ -34,6 +34,7 @@ from synapse.api.errors import (
|
||||
NotFoundError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.events import room_version_to_event_format
|
||||
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
|
||||
|
||||
@@ -55,7 +55,12 @@ class FederationRemoteSendQueue(object):
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
self.presence_map = {} # Pending presence map user_id -> UserPresenceState
|
||||
self.presence_changed = SortedDict() # Stream position -> user_id
|
||||
self.presence_changed = SortedDict() # Stream position -> list[user_id]
|
||||
|
||||
# Stores the destinations we need to explicitly send presence to about a
|
||||
# given user.
|
||||
# Stream position -> (user_id, destinations)
|
||||
self.presence_destinations = SortedDict()
|
||||
|
||||
self.keyed_edu = {} # (destination, key) -> EDU
|
||||
self.keyed_edu_changed = SortedDict() # stream position -> (destination, key)
|
||||
@@ -77,7 +82,7 @@ class FederationRemoteSendQueue(object):
|
||||
|
||||
for queue_name in [
|
||||
"presence_map", "presence_changed", "keyed_edu", "keyed_edu_changed",
|
||||
"edus", "device_messages", "pos_time",
|
||||
"edus", "device_messages", "pos_time", "presence_destinations",
|
||||
]:
|
||||
register(queue_name, getattr(self, queue_name))
|
||||
|
||||
@@ -121,6 +126,15 @@ class FederationRemoteSendQueue(object):
|
||||
for user_id in uids
|
||||
)
|
||||
|
||||
keys = self.presence_destinations.keys()
|
||||
i = self.presence_destinations.bisect_left(position_to_delete)
|
||||
for key in keys[:i]:
|
||||
del self.presence_destinations[key]
|
||||
|
||||
user_ids.update(
|
||||
user_id for user_id, _ in self.presence_destinations.values()
|
||||
)
|
||||
|
||||
to_del = [
|
||||
user_id for user_id in self.presence_map if user_id not in user_ids
|
||||
]
|
||||
@@ -209,6 +223,20 @@ class FederationRemoteSendQueue(object):
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_presence_to_destinations(self, states, destinations):
|
||||
"""As per FederationSender
|
||||
|
||||
Args:
|
||||
states (list[UserPresenceState])
|
||||
destinations (list[str])
|
||||
"""
|
||||
for state in states:
|
||||
pos = self._next_pos()
|
||||
self.presence_map.update({state.user_id: state for state in states})
|
||||
self.presence_destinations[pos] = (state.user_id, destinations)
|
||||
|
||||
self.notifier.on_new_replication_data()
|
||||
|
||||
def send_device_messages(self, destination):
|
||||
"""As per FederationSender"""
|
||||
pos = self._next_pos()
|
||||
@@ -261,6 +289,16 @@ class FederationRemoteSendQueue(object):
|
||||
state=self.presence_map[user_id],
|
||||
)))
|
||||
|
||||
# Fetch presence to send to destinations
|
||||
i = self.presence_destinations.bisect_right(from_token)
|
||||
j = self.presence_destinations.bisect_right(to_token) + 1
|
||||
|
||||
for pos, (user_id, dests) in self.presence_destinations.items()[i:j]:
|
||||
rows.append((pos, PresenceDestinationsRow(
|
||||
state=self.presence_map[user_id],
|
||||
destinations=list(dests),
|
||||
)))
|
||||
|
||||
# Fetch changes keyed edus
|
||||
i = self.keyed_edu_changed.bisect_right(from_token)
|
||||
j = self.keyed_edu_changed.bisect_right(to_token) + 1
|
||||
@@ -357,6 +395,29 @@ class PresenceRow(BaseFederationRow, namedtuple("PresenceRow", (
|
||||
buff.presence.append(self.state)
|
||||
|
||||
|
||||
class PresenceDestinationsRow(BaseFederationRow, namedtuple("PresenceDestinationsRow", (
|
||||
"state", # UserPresenceState
|
||||
"destinations", # list[str]
|
||||
))):
|
||||
TypeId = "pd"
|
||||
|
||||
@staticmethod
|
||||
def from_data(data):
|
||||
return PresenceDestinationsRow(
|
||||
state=UserPresenceState.from_dict(data["state"]),
|
||||
destinations=data["dests"],
|
||||
)
|
||||
|
||||
def to_data(self):
|
||||
return {
|
||||
"state": self.state.as_dict(),
|
||||
"dests": self.destinations,
|
||||
}
|
||||
|
||||
def add_to_buffer(self, buff):
|
||||
buff.presence_destinations.append((self.state, self.destinations))
|
||||
|
||||
|
||||
class KeyedEduRow(BaseFederationRow, namedtuple("KeyedEduRow", (
|
||||
"key", # tuple(str) - the edu key passed to send_edu
|
||||
"edu", # Edu
|
||||
@@ -428,6 +489,7 @@ TypeToRow = {
|
||||
Row.TypeId: Row
|
||||
for Row in (
|
||||
PresenceRow,
|
||||
PresenceDestinationsRow,
|
||||
KeyedEduRow,
|
||||
EduRow,
|
||||
DeviceRow,
|
||||
@@ -437,6 +499,7 @@ TypeToRow = {
|
||||
|
||||
ParsedFederationStreamData = namedtuple("ParsedFederationStreamData", (
|
||||
"presence", # list(UserPresenceState)
|
||||
"presence_destinations", # list of tuples of UserPresenceState and destinations
|
||||
"keyed_edus", # dict of destination -> { key -> Edu }
|
||||
"edus", # dict of destination -> [Edu]
|
||||
"device_destinations", # set of destinations
|
||||
@@ -458,6 +521,7 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||
|
||||
buff = ParsedFederationStreamData(
|
||||
presence=[],
|
||||
presence_destinations=[],
|
||||
keyed_edus={},
|
||||
edus={},
|
||||
device_destinations=set(),
|
||||
@@ -476,6 +540,11 @@ def process_rows_for_federation(transaction_queue, rows):
|
||||
if buff.presence:
|
||||
transaction_queue.send_presence(buff.presence)
|
||||
|
||||
for state, destinations in buff.presence_destinations:
|
||||
transaction_queue.send_presence_to_destinations(
|
||||
states=[state], destinations=destinations,
|
||||
)
|
||||
|
||||
for destination, edu_map in iteritems(buff.keyed_edus):
|
||||
for key, edu in edu_map.items():
|
||||
transaction_queue.send_edu(edu, key)
|
||||
|
||||
@@ -371,7 +371,7 @@ class FederationSender(object):
|
||||
return
|
||||
|
||||
# First we queue up the new presence by user ID, so multiple presence
|
||||
# updates in quick successtion are correctly handled
|
||||
# updates in quick succession are correctly handled.
|
||||
# We only want to send presence for our own users, so lets always just
|
||||
# filter here just in case.
|
||||
self.pending_presence.update({
|
||||
@@ -402,6 +402,23 @@ class FederationSender(object):
|
||||
finally:
|
||||
self._processing_pending_presence = False
|
||||
|
||||
def send_presence_to_destinations(self, states, destinations):
|
||||
"""Send the given presence states to the given destinations.
|
||||
|
||||
Args:
|
||||
states (list[UserPresenceState])
|
||||
destinations (list[str])
|
||||
"""
|
||||
|
||||
if not states or not self.hs.config.use_presence:
|
||||
# No-op if presence is disabled.
|
||||
return
|
||||
|
||||
for destination in destinations:
|
||||
if destination == self.server_name:
|
||||
continue
|
||||
self._get_per_destination_queue(destination).send_presence(states)
|
||||
|
||||
@measure_func("txnqueue._process_presence")
|
||||
@defer.inlineCallbacks
|
||||
def _process_presence_inner(self, states):
|
||||
|
||||
@@ -21,8 +21,8 @@ import re
|
||||
from twisted.internet import defer
|
||||
|
||||
import synapse
|
||||
from synapse.api.constants import RoomVersions
|
||||
from synapse.api.errors import Codes, FederationDeniedError, SynapseError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.api.urls import FEDERATION_V1_PREFIX, FEDERATION_V2_PREFIX
|
||||
from synapse.http.endpoint import parse_and_validate_server_name
|
||||
from synapse.http.server import JsonResource
|
||||
@@ -513,7 +513,7 @@ class FederationV1InviteServlet(BaseFederationServlet):
|
||||
# state resolution algorithm, and we don't use that for processing
|
||||
# invites
|
||||
content = yield self.handler.on_invite_request(
|
||||
origin, content, room_version=RoomVersions.V1,
|
||||
origin, content, room_version=RoomVersions.V1.identifier,
|
||||
)
|
||||
|
||||
# V1 federation API is defined to return a content of `[200, {...}]`
|
||||
|
||||
@@ -22,6 +22,7 @@ from twisted.internet import defer
|
||||
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.types import GroupID, RoomID, UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -896,6 +897,78 @@ class GroupsServerHandler(object):
|
||||
"group_id": group_id,
|
||||
})
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_group(self, group_id, requester_user_id):
|
||||
"""Deletes a group, kicking out all current members.
|
||||
|
||||
Only group admins or server admins can call this request
|
||||
|
||||
Args:
|
||||
group_id (str)
|
||||
request_user_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
|
||||
yield self.check_group_is_ours(
|
||||
group_id, requester_user_id,
|
||||
and_exists=True,
|
||||
)
|
||||
|
||||
# Only server admins or group admins can delete groups.
|
||||
|
||||
is_admin = yield self.store.is_user_admin_in_group(
|
||||
group_id, requester_user_id
|
||||
)
|
||||
|
||||
if not is_admin:
|
||||
is_admin = yield self.auth.is_server_admin(
|
||||
UserID.from_string(requester_user_id),
|
||||
)
|
||||
|
||||
if not is_admin:
|
||||
raise SynapseError(403, "User is not an admin")
|
||||
|
||||
# Before deleting the group lets kick everyone out of it
|
||||
users = yield self.store.get_users_in_group(
|
||||
group_id, include_private=True,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _kick_user_from_group(user_id):
|
||||
if self.hs.is_mine_id(user_id):
|
||||
groups_local = self.hs.get_groups_local_handler()
|
||||
yield groups_local.user_removed_from_group(group_id, user_id, {})
|
||||
else:
|
||||
yield self.transport_client.remove_user_from_group_notification(
|
||||
get_domain_from_id(user_id), group_id, user_id, {}
|
||||
)
|
||||
yield self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
|
||||
# We kick users out in the order of:
|
||||
# 1. Non-admins
|
||||
# 2. Other admins
|
||||
# 3. The requester
|
||||
#
|
||||
# This is so that if the deletion fails for some reason other admins or
|
||||
# the requester still has auth to retry.
|
||||
non_admins = []
|
||||
admins = []
|
||||
for u in users:
|
||||
if u["user_id"] == requester_user_id:
|
||||
continue
|
||||
if u["is_admin"]:
|
||||
admins.append(u["user_id"])
|
||||
else:
|
||||
non_admins.append(u["user_id"])
|
||||
|
||||
yield concurrently_execute(_kick_user_from_group, non_admins, 10)
|
||||
yield concurrently_execute(_kick_user_from_group, admins, 10)
|
||||
yield _kick_user_from_group(requester_user_id)
|
||||
|
||||
yield self.store.delete_group(group_id)
|
||||
|
||||
|
||||
def _parse_join_policy_from_contents(content):
|
||||
"""Given a content for a request, return the specified join policy or None
|
||||
|
||||
@@ -912,7 +912,7 @@ class AuthHandler(BaseHandler):
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def delete_threepid(self, user_id, medium, address):
|
||||
def delete_threepid(self, user_id, medium, address, id_server=None):
|
||||
"""Attempts to unbind the 3pid on the identity servers and deletes it
|
||||
from the local database.
|
||||
|
||||
@@ -920,6 +920,10 @@ class AuthHandler(BaseHandler):
|
||||
user_id (str)
|
||||
medium (str)
|
||||
address (str)
|
||||
id_server (str|None): Use the given identity server when unbinding
|
||||
any threepids. If None then will attempt to unbind using the
|
||||
identity server specified when binding (if known).
|
||||
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: Returns True if successfully unbound the 3pid on
|
||||
@@ -937,6 +941,7 @@ class AuthHandler(BaseHandler):
|
||||
{
|
||||
'medium': medium,
|
||||
'address': address,
|
||||
'id_server': id_server,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -43,12 +43,15 @@ class DeactivateAccountHandler(BaseHandler):
|
||||
hs.get_reactor().callWhenRunning(self._start_user_parting)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deactivate_account(self, user_id, erase_data):
|
||||
def deactivate_account(self, user_id, erase_data, id_server=None):
|
||||
"""Deactivate a user's account
|
||||
|
||||
Args:
|
||||
user_id (str): ID of user to be deactivated
|
||||
erase_data (bool): whether to GDPR-erase the user's data
|
||||
id_server (str|None): Use the given identity server when unbinding
|
||||
any threepids. If None then will attempt to unbind using the
|
||||
identity server specified when binding (if known).
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True if identity server supports removing
|
||||
@@ -74,6 +77,7 @@ class DeactivateAccountHandler(BaseHandler):
|
||||
{
|
||||
'medium': threepid['medium'],
|
||||
'address': threepid['address'],
|
||||
'id_server': id_server,
|
||||
},
|
||||
)
|
||||
identity_server_supports_unbinding &= result
|
||||
|
||||
@@ -68,7 +68,7 @@ class DirectoryHandler(BaseHandler):
|
||||
# TODO(erikj): Add transactions.
|
||||
# TODO(erikj): Check if there is a current association.
|
||||
if not servers:
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
servers = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
if not servers:
|
||||
@@ -268,7 +268,7 @@ class DirectoryHandler(BaseHandler):
|
||||
Codes.NOT_FOUND
|
||||
)
|
||||
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
extra_servers = set(get_domain_from_id(u) for u in users)
|
||||
servers = set(extra_servers) | set(servers)
|
||||
|
||||
|
||||
@@ -102,7 +102,7 @@ class EventStreamHandler(BaseHandler):
|
||||
# Send down presence.
|
||||
if event.state_key == auth_user_id:
|
||||
# Send down presence for everyone in the room.
|
||||
users = yield self.state.get_current_user_in_room(event.room_id)
|
||||
users = yield self.state.get_current_users_in_room(event.room_id)
|
||||
states = yield presence_handler.get_states(
|
||||
users,
|
||||
as_event=True,
|
||||
|
||||
@@ -29,13 +29,7 @@ from unpaddedbase64 import decode_base64
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventTypes,
|
||||
Membership,
|
||||
RejectedReason,
|
||||
RoomVersions,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, Membership, RejectedReason
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
CodeMessageException,
|
||||
@@ -44,6 +38,7 @@ from synapse.api.errors import (
|
||||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersions
|
||||
from synapse.crypto.event_signing import compute_event_signature
|
||||
from synapse.event_auth import auth_types_for_event
|
||||
from synapse.events.validator import EventValidator
|
||||
@@ -1733,7 +1728,9 @@ class FederationHandler(BaseHandler):
|
||||
# invalid, and it would fail auth checks anyway.
|
||||
raise SynapseError(400, "No create event in state")
|
||||
|
||||
room_version = create_event.content.get("room_version", RoomVersions.V1)
|
||||
room_version = create_event.content.get(
|
||||
"room_version", RoomVersions.V1.identifier,
|
||||
)
|
||||
|
||||
missing_auth_events = set()
|
||||
for e in itertools.chain(auth_events, state, [event]):
|
||||
|
||||
@@ -132,6 +132,14 @@ class IdentityHandler(BaseHandler):
|
||||
}
|
||||
)
|
||||
logger.debug("bound threepid %r to %s", creds, mxid)
|
||||
|
||||
# Remember where we bound the threepid
|
||||
yield self.store.add_user_bound_threepid(
|
||||
user_id=mxid,
|
||||
medium=data["medium"],
|
||||
address=data["address"],
|
||||
id_server=id_server,
|
||||
)
|
||||
except CodeMessageException as e:
|
||||
data = json.loads(e.msg) # XXX WAT?
|
||||
defer.returnValue(data)
|
||||
@@ -140,9 +148,48 @@ class IdentityHandler(BaseHandler):
|
||||
def try_unbind_threepid(self, mxid, threepid):
|
||||
"""Removes a binding from an identity server
|
||||
|
||||
Args:
|
||||
mxid (str): Matrix user ID of binding to be removed
|
||||
threepid (dict): Dict with medium & address of binding to be
|
||||
removed, and an optional id_server.
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
|
||||
Returns:
|
||||
Deferred[bool]: True on success, otherwise False if the identity
|
||||
server doesn't support unbinding (or no identity server found to
|
||||
contact).
|
||||
"""
|
||||
if threepid.get("id_server"):
|
||||
id_servers = [threepid["id_server"]]
|
||||
else:
|
||||
id_servers = yield self.store.get_id_servers_user_bound(
|
||||
user_id=mxid,
|
||||
medium=threepid["medium"],
|
||||
address=threepid["address"],
|
||||
)
|
||||
|
||||
# We don't know where to unbind, so we don't have a choice but to return
|
||||
if not id_servers:
|
||||
defer.returnValue(False)
|
||||
|
||||
changed = True
|
||||
for id_server in id_servers:
|
||||
changed &= yield self.try_unbind_threepid_with_id_server(
|
||||
mxid, threepid, id_server,
|
||||
)
|
||||
|
||||
defer.returnValue(changed)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def try_unbind_threepid_with_id_server(self, mxid, threepid, id_server):
|
||||
"""Removes a binding from an identity server
|
||||
|
||||
Args:
|
||||
mxid (str): Matrix user ID of binding to be removed
|
||||
threepid (dict): Dict with medium & address of binding to be removed
|
||||
id_server (str): Identity server to unbind from
|
||||
|
||||
Raises:
|
||||
SynapseError: If we failed to contact the identity server
|
||||
@@ -151,21 +198,13 @@ class IdentityHandler(BaseHandler):
|
||||
Deferred[bool]: True on success, otherwise False if the identity
|
||||
server doesn't support unbinding
|
||||
"""
|
||||
logger.debug("unbinding threepid %r from %s", threepid, mxid)
|
||||
if not self.trusted_id_servers:
|
||||
logger.warn("Can't unbind threepid: no trusted ID servers set in config")
|
||||
defer.returnValue(False)
|
||||
|
||||
# We don't track what ID server we added 3pids on (perhaps we ought to)
|
||||
# but we assume that any of the servers in the trusted list are in the
|
||||
# same ID server federation, so we can pick any one of them to send the
|
||||
# deletion request to.
|
||||
id_server = next(iter(self.trusted_id_servers))
|
||||
|
||||
url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
|
||||
content = {
|
||||
"mxid": mxid,
|
||||
"threepid": threepid,
|
||||
"threepid": {
|
||||
"medium": threepid["medium"],
|
||||
"address": threepid["address"],
|
||||
},
|
||||
}
|
||||
|
||||
# we abuse the federation http client to sign the request, but we have to send it
|
||||
@@ -188,16 +227,24 @@ class IdentityHandler(BaseHandler):
|
||||
content,
|
||||
headers,
|
||||
)
|
||||
changed = True
|
||||
except HttpResponseException as e:
|
||||
changed = False
|
||||
if e.code in (400, 404, 501,):
|
||||
# The remote server probably doesn't support unbinding (yet)
|
||||
logger.warn("Received %d response while unbinding threepid", e.code)
|
||||
defer.returnValue(False)
|
||||
else:
|
||||
logger.error("Failed to unbind threepid on identity server: %s", e)
|
||||
raise SynapseError(502, "Failed to contact identity server")
|
||||
|
||||
defer.returnValue(True)
|
||||
yield self.store.remove_user_bound_threepid(
|
||||
user_id=mxid,
|
||||
medium=threepid["medium"],
|
||||
address=threepid["address"],
|
||||
id_server=id_server,
|
||||
)
|
||||
|
||||
defer.returnValue(changed)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs):
|
||||
|
||||
@@ -22,7 +22,7 @@ from canonicaljson import encode_canonical_json, json
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import succeed
|
||||
|
||||
from synapse.api.constants import EventTypes, Membership, RoomVersions
|
||||
from synapse.api.constants import EventTypes, Membership
|
||||
from synapse.api.errors import (
|
||||
AuthError,
|
||||
Codes,
|
||||
@@ -30,6 +30,7 @@ from synapse.api.errors import (
|
||||
NotFoundError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.api.urls import ConsentURIBuilder
|
||||
from synapse.events.utils import serialize_event
|
||||
from synapse.events.validator import EventValidator
|
||||
@@ -191,7 +192,7 @@ class MessageHandler(object):
|
||||
"Getting joined members after leaving is not implemented"
|
||||
)
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
users_with_profile = yield self.state.get_current_users_in_room(room_id)
|
||||
|
||||
# If this is an AS, double check that they are allowed to see the members.
|
||||
# This can either be because the AS user is in the room or because there
|
||||
@@ -603,7 +604,9 @@ class EventCreationHandler(object):
|
||||
"""
|
||||
|
||||
if event.is_state() and (event.type, event.state_key) == (EventTypes.Create, ""):
|
||||
room_version = event.content.get("room_version", RoomVersions.V1)
|
||||
room_version = event.content.get(
|
||||
"room_version", RoomVersions.V1.identifier
|
||||
)
|
||||
else:
|
||||
room_version = yield self.store.get_room_version(event.room_id)
|
||||
|
||||
|
||||
@@ -31,9 +31,11 @@ from prometheus_client import Counter
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
import synapse.metrics
|
||||
from synapse.api.constants import EventTypes, Membership, PresenceState
|
||||
from synapse.api.errors import SynapseError
|
||||
from synapse.metrics import LaterGauge
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.storage.presence import UserPresenceState
|
||||
from synapse.types import UserID, get_domain_from_id
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
@@ -98,6 +100,7 @@ class PresenceHandler(object):
|
||||
self.hs = hs
|
||||
self.is_mine = hs.is_mine
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastore()
|
||||
self.wheel_timer = WheelTimer()
|
||||
@@ -110,30 +113,6 @@ class PresenceHandler(object):
|
||||
federation_registry.register_edu_handler(
|
||||
"m.presence", self.incoming_presence
|
||||
)
|
||||
federation_registry.register_edu_handler(
|
||||
"m.presence_invite",
|
||||
lambda origin, content: self.invite_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
observer_user=UserID.from_string(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
federation_registry.register_edu_handler(
|
||||
"m.presence_accept",
|
||||
lambda origin, content: self.accept_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
observer_user=UserID.from_string(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
federation_registry.register_edu_handler(
|
||||
"m.presence_deny",
|
||||
lambda origin, content: self.deny_presence(
|
||||
observed_user=UserID.from_string(content["observed_user"]),
|
||||
observer_user=UserID.from_string(content["observer_user"]),
|
||||
)
|
||||
)
|
||||
|
||||
distributor = hs.get_distributor()
|
||||
distributor.observe("user_joined_room", self.user_joined_room)
|
||||
|
||||
active_presence = self.store.take_presence_startup_info()
|
||||
|
||||
@@ -220,6 +199,15 @@ class PresenceHandler(object):
|
||||
LaterGauge("synapse_handlers_presence_wheel_timer_size", "", [],
|
||||
lambda: len(self.wheel_timer))
|
||||
|
||||
# Used to handle sending of presence to newly joined users/servers
|
||||
if hs.config.use_presence:
|
||||
self.notifier.add_replication_callback(self.notify_new_event)
|
||||
|
||||
# Presence is best effort and quickly heals itself, so lets just always
|
||||
# stream from the current state when we restart.
|
||||
self._event_pos = self.store.get_current_events_token()
|
||||
self._event_processing = False
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_shutdown(self):
|
||||
"""Gets called when shutting down. This lets us persist any updates that
|
||||
@@ -750,162 +738,6 @@ class PresenceHandler(object):
|
||||
|
||||
yield self._update_states([prev_state.copy_and_replace(**new_fields)])
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def user_joined_room(self, user, room_id):
|
||||
"""Called (via the distributor) when a user joins a room. This funciton
|
||||
sends presence updates to servers, either:
|
||||
1. the joining user is a local user and we send their presence to
|
||||
all servers in the room.
|
||||
2. the joining user is a remote user and so we send presence for all
|
||||
local users in the room.
|
||||
"""
|
||||
# We only need to send presence to servers that don't have it yet. We
|
||||
# don't need to send to local clients here, as that is done as part
|
||||
# of the event stream/sync.
|
||||
# TODO: Only send to servers not already in the room.
|
||||
if self.is_mine(user):
|
||||
state = yield self.current_state_for_user(user.to_string())
|
||||
|
||||
self._push_to_remotes([state])
|
||||
else:
|
||||
user_ids = yield self.store.get_users_in_room(room_id)
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
|
||||
states = yield self.current_state_for_users(user_ids)
|
||||
|
||||
self._push_to_remotes(list(states.values()))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_presence_list(self, observer_user, accepted=None):
|
||||
"""Returns the presence for all users in their presence list.
|
||||
"""
|
||||
if not self.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
presence_list = yield self.store.get_presence_list(
|
||||
observer_user.localpart, accepted=accepted
|
||||
)
|
||||
|
||||
results = yield self.get_states(
|
||||
target_user_ids=[row["observed_user_id"] for row in presence_list],
|
||||
as_event=False,
|
||||
)
|
||||
|
||||
now = self.clock.time_msec()
|
||||
results[:] = [format_user_presence_state(r, now) for r in results]
|
||||
|
||||
is_accepted = {
|
||||
row["observed_user_id"]: row["accepted"] for row in presence_list
|
||||
}
|
||||
|
||||
for result in results:
|
||||
result.update({
|
||||
"accepted": is_accepted,
|
||||
})
|
||||
|
||||
defer.returnValue(results)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def send_presence_invite(self, observer_user, observed_user):
|
||||
"""Sends a presence invite.
|
||||
"""
|
||||
yield self.store.add_presence_list_pending(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
if self.is_mine(observed_user):
|
||||
yield self.invite_presence(observed_user, observer_user)
|
||||
else:
|
||||
yield self.federation.build_and_send_edu(
|
||||
destination=observed_user.domain,
|
||||
edu_type="m.presence_invite",
|
||||
content={
|
||||
"observed_user": observed_user.to_string(),
|
||||
"observer_user": observer_user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def invite_presence(self, observed_user, observer_user):
|
||||
"""Handles new presence invites.
|
||||
"""
|
||||
if not self.is_mine(observed_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
# TODO: Don't auto accept
|
||||
if self.is_mine(observer_user):
|
||||
yield self.accept_presence(observed_user, observer_user)
|
||||
else:
|
||||
self.federation.build_and_send_edu(
|
||||
destination=observer_user.domain,
|
||||
edu_type="m.presence_accept",
|
||||
content={
|
||||
"observed_user": observed_user.to_string(),
|
||||
"observer_user": observer_user.to_string(),
|
||||
}
|
||||
)
|
||||
|
||||
state_dict = yield self.get_state(observed_user, as_event=False)
|
||||
state_dict = format_user_presence_state(state_dict, self.clock.time_msec())
|
||||
|
||||
self.federation.build_and_send_edu(
|
||||
destination=observer_user.domain,
|
||||
edu_type="m.presence",
|
||||
content={
|
||||
"push": [state_dict]
|
||||
}
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def accept_presence(self, observed_user, observer_user):
|
||||
"""Handles a m.presence_accept EDU. Mark a presence invite from a
|
||||
local or remote user as accepted in a local user's presence list.
|
||||
Starts polling for presence updates from the local or remote user.
|
||||
Args:
|
||||
observed_user(UserID): The user to update in the presence list.
|
||||
observer_user(UserID): The owner of the presence list to update.
|
||||
"""
|
||||
yield self.store.set_presence_list_accepted(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def deny_presence(self, observed_user, observer_user):
|
||||
"""Handle a m.presence_deny EDU. Removes a local or remote user from a
|
||||
local user's presence list.
|
||||
Args:
|
||||
observed_user(UserID): The local or remote user to remove from the
|
||||
list.
|
||||
observer_user(UserID): The local owner of the presence list.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
yield self.store.del_presence_list(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
# TODO(paul): Inform the user somehow?
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def drop(self, observed_user, observer_user):
|
||||
"""Remove a local or remote user from a local user's presence list and
|
||||
unsubscribe the local user from updates that user.
|
||||
Args:
|
||||
observed_user(UserId): The local or remote user to remove from the
|
||||
list.
|
||||
observer_user(UserId): The local owner of the presence list.
|
||||
Returns:
|
||||
A Deferred.
|
||||
"""
|
||||
if not self.is_mine(observer_user):
|
||||
raise SynapseError(400, "User is not hosted on this Home Server")
|
||||
|
||||
yield self.store.del_presence_list(
|
||||
observer_user.localpart, observed_user.to_string()
|
||||
)
|
||||
|
||||
# TODO: Inform the remote that we've dropped the presence list.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def is_visible(self, observed_user, observer_user):
|
||||
"""Returns whether a user can see another user's presence.
|
||||
@@ -920,11 +752,7 @@ class PresenceHandler(object):
|
||||
if observer_room_ids & observed_room_ids:
|
||||
defer.returnValue(True)
|
||||
|
||||
accepted_observers = yield self.store.get_presence_list_observers_accepted(
|
||||
observed_user.to_string()
|
||||
)
|
||||
|
||||
defer.returnValue(observer_user.to_string() in accepted_observers)
|
||||
defer.returnValue(False)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_all_presence_updates(self, last_id, current_id):
|
||||
@@ -945,6 +773,140 @@ class PresenceHandler(object):
|
||||
rows = yield self.store.get_all_presence_updates(last_id, current_id)
|
||||
defer.returnValue(rows)
|
||||
|
||||
def notify_new_event(self):
|
||||
"""Called when new events have happened. Handles users and servers
|
||||
joining rooms and require being sent presence.
|
||||
"""
|
||||
|
||||
if self._event_processing:
|
||||
return
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _process_presence():
|
||||
assert not self._event_processing
|
||||
|
||||
self._event_processing = True
|
||||
try:
|
||||
yield self._unsafe_process()
|
||||
finally:
|
||||
self._event_processing = False
|
||||
|
||||
run_as_background_process("presence.notify_new_event", _process_presence)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _unsafe_process(self):
|
||||
# Loop round handling deltas until we're up to date
|
||||
while True:
|
||||
with Measure(self.clock, "presence_delta"):
|
||||
deltas = yield self.store.get_current_state_deltas(self._event_pos)
|
||||
if not deltas:
|
||||
return
|
||||
|
||||
yield self._handle_state_delta(deltas)
|
||||
|
||||
self._event_pos = deltas[-1]["stream_id"]
|
||||
|
||||
# Expose current event processing position to prometheus
|
||||
synapse.metrics.event_processing_positions.labels("presence").set(
|
||||
self._event_pos
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _handle_state_delta(self, deltas):
|
||||
"""Process current state deltas to find new joins that need to be
|
||||
handled.
|
||||
"""
|
||||
for delta in deltas:
|
||||
typ = delta["type"]
|
||||
state_key = delta["state_key"]
|
||||
room_id = delta["room_id"]
|
||||
event_id = delta["event_id"]
|
||||
prev_event_id = delta["prev_event_id"]
|
||||
|
||||
logger.debug("Handling: %r %r, %s", typ, state_key, event_id)
|
||||
|
||||
if typ != EventTypes.Member:
|
||||
continue
|
||||
|
||||
event = yield self.store.get_event(event_id)
|
||||
if event.content.get("membership") != Membership.JOIN:
|
||||
# We only care about joins
|
||||
continue
|
||||
|
||||
if prev_event_id:
|
||||
prev_event = yield self.store.get_event(prev_event_id)
|
||||
if prev_event.content.get("membership") == Membership.JOIN:
|
||||
# Ignore changes to join events.
|
||||
continue
|
||||
|
||||
yield self._on_user_joined_room(room_id, state_key)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _on_user_joined_room(self, room_id, user_id):
|
||||
"""Called when we detect a user joining the room via the current state
|
||||
delta stream.
|
||||
|
||||
Args:
|
||||
room_id (str)
|
||||
user_id (str)
|
||||
|
||||
Returns:
|
||||
Deferred
|
||||
"""
|
||||
|
||||
if self.is_mine_id(user_id):
|
||||
# If this is a local user then we need to send their presence
|
||||
# out to hosts in the room (who don't already have it)
|
||||
|
||||
# TODO: We should be able to filter the hosts down to those that
|
||||
# haven't previously seen the user
|
||||
|
||||
state = yield self.current_state_for_user(user_id)
|
||||
hosts = yield self.state.get_current_hosts_in_room(room_id)
|
||||
|
||||
# Filter out ourselves.
|
||||
hosts = set(host for host in hosts if host != self.server_name)
|
||||
|
||||
self.federation.send_presence_to_destinations(
|
||||
states=[state],
|
||||
destinations=hosts,
|
||||
)
|
||||
else:
|
||||
# A remote user has joined the room, so we need to:
|
||||
# 1. Check if this is a new server in the room
|
||||
# 2. If so send any presence they don't already have for
|
||||
# local users in the room.
|
||||
|
||||
# TODO: We should be able to filter the users down to those that
|
||||
# the server hasn't previously seen
|
||||
|
||||
# TODO: Check that this is actually a new server joining the
|
||||
# room.
|
||||
|
||||
user_ids = yield self.state.get_current_users_in_room(room_id)
|
||||
user_ids = list(filter(self.is_mine_id, user_ids))
|
||||
|
||||
states = yield self.current_state_for_users(user_ids)
|
||||
|
||||
# Filter out old presence, i.e. offline presence states where
|
||||
# the user hasn't been active for a week. We can change this
|
||||
# depending on what we want the UX to be, but at the least we
|
||||
# should filter out offline presence where the state is just the
|
||||
# default state.
|
||||
now = self.clock.time_msec()
|
||||
states = [
|
||||
state for state in states.values()
|
||||
if state.state != PresenceState.OFFLINE
|
||||
or now - state.last_active_ts < 7 * 24 * 60 * 60 * 1000
|
||||
or state.status_msg is not None
|
||||
]
|
||||
|
||||
if states:
|
||||
self.federation.send_presence_to_destinations(
|
||||
states=states,
|
||||
destinations=[get_domain_from_id(user_id)],
|
||||
)
|
||||
|
||||
|
||||
def should_notify(old_state, new_state):
|
||||
"""Decides if a presence state change should be sent to interested parties.
|
||||
@@ -1086,10 +1048,7 @@ class PresenceEventSource(object):
|
||||
updates for
|
||||
"""
|
||||
user_id = user.to_string()
|
||||
plist = yield self.store.get_presence_list_accepted(
|
||||
user.localpart, on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
users_interested_in = set(row["observed_user_id"] for row in plist)
|
||||
users_interested_in = set()
|
||||
users_interested_in.add(user_id) # So that we receive our own presence
|
||||
|
||||
users_who_share_room = yield self.store.get_users_who_share_room_with_user(
|
||||
@@ -1294,10 +1253,6 @@ def get_interested_parties(store, states):
|
||||
for room_id in room_ids:
|
||||
room_ids_to_states.setdefault(room_id, []).append(state)
|
||||
|
||||
plist = yield store.get_presence_list_observers_accepted(state.user_id)
|
||||
for u in plist:
|
||||
users_to_states.setdefault(u, []).append(state)
|
||||
|
||||
# Always notify self
|
||||
users_to_states.setdefault(state.user_id, []).append(state)
|
||||
|
||||
|
||||
@@ -153,6 +153,7 @@ class RegistrationHandler(BaseHandler):
|
||||
user_type=None,
|
||||
default_display_name=None,
|
||||
address=None,
|
||||
bind_emails=[],
|
||||
):
|
||||
"""Registers a new client on the server.
|
||||
|
||||
@@ -172,6 +173,7 @@ class RegistrationHandler(BaseHandler):
|
||||
default_display_name (unicode|None): if set, the new user's displayname
|
||||
will be set to this. Defaults to 'localpart'.
|
||||
address (str|None): the IP address used to perform the registration.
|
||||
bind_emails (List[str]): list of emails to bind to this account.
|
||||
Returns:
|
||||
A tuple of (user_id, access_token).
|
||||
Raises:
|
||||
@@ -261,6 +263,21 @@ class RegistrationHandler(BaseHandler):
|
||||
if not self.hs.config.user_consent_at_registration:
|
||||
yield self._auto_join_rooms(user_id)
|
||||
|
||||
# Bind any specified emails to this account
|
||||
current_time = self.hs.get_clock().time_msec()
|
||||
for email in bind_emails:
|
||||
# generate threepid dict
|
||||
threepid_dict = {
|
||||
"medium": "email",
|
||||
"address": email,
|
||||
"validated_at": current_time,
|
||||
}
|
||||
|
||||
# Bind email to new account
|
||||
yield self._register_email_threepid(
|
||||
user_id, threepid_dict, None, False,
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, token))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
|
||||
@@ -25,14 +25,9 @@ from six import iteritems, string_types
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import (
|
||||
DEFAULT_ROOM_VERSION,
|
||||
KNOWN_ROOM_VERSIONS,
|
||||
EventTypes,
|
||||
JoinRules,
|
||||
RoomCreationPreset,
|
||||
)
|
||||
from synapse.api.constants import EventTypes, JoinRules, RoomCreationPreset
|
||||
from synapse.api.errors import AuthError, Codes, NotFoundError, StoreError, SynapseError
|
||||
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import RoomAlias, RoomID, RoomStreamToken, StreamToken, UserID
|
||||
from synapse.util import stringutils
|
||||
@@ -285,6 +280,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
(EventTypes.Encryption, ""),
|
||||
(EventTypes.ServerACL, ""),
|
||||
(EventTypes.RelatedGroups, ""),
|
||||
)
|
||||
|
||||
old_room_state_ids = yield self.store.get_filtered_current_state_ids(
|
||||
@@ -479,7 +475,7 @@ class RoomCreationHandler(BaseHandler):
|
||||
if ratelimit:
|
||||
yield self.ratelimit(requester)
|
||||
|
||||
room_version = config.get("room_version", DEFAULT_ROOM_VERSION)
|
||||
room_version = config.get("room_version", DEFAULT_ROOM_VERSION.identifier)
|
||||
if not isinstance(room_version, string_types):
|
||||
raise SynapseError(
|
||||
400,
|
||||
|
||||
@@ -167,7 +167,7 @@ class RoomListHandler(BaseHandler):
|
||||
if not latest_event_ids:
|
||||
return
|
||||
|
||||
joined_users = yield self.state_handler.get_current_user_in_room(
|
||||
joined_users = yield self.state_handler.get_current_users_in_room(
|
||||
room_id, latest_event_ids,
|
||||
)
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ class RoomMemberHandler(object):
|
||||
self.clock = hs.get_clock()
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self._server_notices_mxid = self.config.server_notices_mxid
|
||||
self._enable_lookup = hs.config.enable_3pid_lookup
|
||||
|
||||
@abc.abstractmethod
|
||||
def _remote_join(self, requester, remote_room_hosts, room_id, user, content):
|
||||
@@ -421,6 +422,9 @@ class RoomMemberHandler(object):
|
||||
room_id, latest_event_ids=latest_event_ids,
|
||||
)
|
||||
|
||||
# TODO: Refactor into dictionary of explicitly allowed transitions
|
||||
# between old and new state, with specific error messages for some
|
||||
# transitions and generic otherwise
|
||||
old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
|
||||
if old_state_id:
|
||||
old_state = yield self.store.get_event(old_state_id, allow_none=True)
|
||||
@@ -446,6 +450,9 @@ class RoomMemberHandler(object):
|
||||
if same_sender and same_membership and same_content:
|
||||
defer.returnValue(old_state)
|
||||
|
||||
if old_membership in ["ban", "leave"] and action == "kick":
|
||||
raise AuthError(403, "The target user is not in the room")
|
||||
|
||||
# we don't allow people to reject invites to the server notice
|
||||
# room, but they can leave it once they are joined.
|
||||
if (
|
||||
@@ -459,6 +466,9 @@ class RoomMemberHandler(object):
|
||||
"You cannot reject this invite",
|
||||
errcode=Codes.CANNOT_LEAVE_SERVER_NOTICE_ROOM,
|
||||
)
|
||||
else:
|
||||
if action == "kick":
|
||||
raise AuthError(403, "The target user is not in the room")
|
||||
|
||||
is_host_in_room = yield self._is_host_in_room(current_state_ids)
|
||||
|
||||
@@ -729,6 +739,10 @@ class RoomMemberHandler(object):
|
||||
Returns:
|
||||
str: the matrix ID of the 3pid, or None if it is not recognized.
|
||||
"""
|
||||
if not self._enable_lookup:
|
||||
raise SynapseError(
|
||||
403, "Looking up third-party identifiers is denied from this server",
|
||||
)
|
||||
try:
|
||||
data = yield self.simple_http_client.get_json(
|
||||
"%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server,),
|
||||
|
||||
@@ -1049,11 +1049,11 @@ class SyncHandler(object):
|
||||
# TODO: Be more clever than this, i.e. remove users who we already
|
||||
# share a room with?
|
||||
for room_id in newly_joined_rooms:
|
||||
joined_users = yield self.state.get_current_user_in_room(room_id)
|
||||
joined_users = yield self.state.get_current_users_in_room(room_id)
|
||||
newly_joined_users.update(joined_users)
|
||||
|
||||
for room_id in newly_left_rooms:
|
||||
left_users = yield self.state.get_current_user_in_room(room_id)
|
||||
left_users = yield self.state.get_current_users_in_room(room_id)
|
||||
newly_left_users.update(left_users)
|
||||
|
||||
# TODO: Check that these users are actually new, i.e. either they
|
||||
@@ -1213,7 +1213,7 @@ class SyncHandler(object):
|
||||
|
||||
extra_users_ids = set(newly_joined_users)
|
||||
for room_id in newly_joined_rooms:
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
extra_users_ids.update(users)
|
||||
extra_users_ids.discard(user.to_string())
|
||||
|
||||
@@ -1855,7 +1855,7 @@ class SyncHandler(object):
|
||||
extrems = yield self.store.get_forward_extremeties_for_room(
|
||||
room_id, stream_ordering,
|
||||
)
|
||||
users_in_room = yield self.state.get_current_user_in_room(
|
||||
users_in_room = yield self.state.get_current_users_in_room(
|
||||
room_id, extrems,
|
||||
)
|
||||
if user_id in users_in_room:
|
||||
|
||||
@@ -218,7 +218,7 @@ class TypingHandler(object):
|
||||
@defer.inlineCallbacks
|
||||
def _push_remote(self, member, typing):
|
||||
try:
|
||||
users = yield self.state.get_current_user_in_room(member.room_id)
|
||||
users = yield self.state.get_current_users_in_room(member.room_id)
|
||||
self._member_last_federation_poke[member] = self.clock.time_msec()
|
||||
|
||||
now = self.clock.time_msec()
|
||||
@@ -261,7 +261,7 @@ class TypingHandler(object):
|
||||
)
|
||||
return
|
||||
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
domains = set(get_domain_from_id(u) for u in users)
|
||||
|
||||
if self.server_name in domains:
|
||||
|
||||
@@ -276,7 +276,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
# ignore the change
|
||||
return
|
||||
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
users_with_profile = yield self.state.get_current_users_in_room(room_id)
|
||||
|
||||
# Remove every user from the sharing tables for that room.
|
||||
for user_id in iterkeys(users_with_profile):
|
||||
@@ -325,7 +325,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
room_id
|
||||
)
|
||||
# Now we update users who share rooms with users.
|
||||
users_with_profile = yield self.state.get_current_user_in_room(room_id)
|
||||
users_with_profile = yield self.state.get_current_users_in_room(room_id)
|
||||
|
||||
if is_public:
|
||||
yield self.store.add_users_in_public_rooms(room_id, (user_id,))
|
||||
|
||||
@@ -74,14 +74,14 @@ class ModuleApi(object):
|
||||
return self._auth_handler.check_user_exists(user_id)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def register(self, localpart, displayname=None):
|
||||
def register(self, localpart, displayname=None, emails=[]):
|
||||
"""Registers a new user with given localpart and optional
|
||||
displayname.
|
||||
displayname, emails.
|
||||
|
||||
Args:
|
||||
localpart (str): The localpart of the new user.
|
||||
displayname (str|None): The displayname of the new user. If None,
|
||||
the user's displayname will default to `localpart`.
|
||||
displayname (str|None): The displayname of the new user.
|
||||
emails (List[str]): Emails to bind to the new user.
|
||||
|
||||
Returns:
|
||||
Deferred: a 2-tuple of (user_id, access_token)
|
||||
@@ -90,6 +90,7 @@ class ModuleApi(object):
|
||||
reg = self.hs.get_registration_handler()
|
||||
user_id, access_token = yield reg.register(
|
||||
localpart=localpart, default_display_name=displayname,
|
||||
bind_emails=emails,
|
||||
)
|
||||
|
||||
defer.returnValue((user_id, access_token))
|
||||
|
||||
@@ -72,8 +72,15 @@ class EmailPusher(object):
|
||||
|
||||
self._is_processing = False
|
||||
|
||||
def on_started(self):
|
||||
if self.mailer is not None:
|
||||
def on_started(self, should_check_for_notifs):
|
||||
"""Called when this pusher has been started.
|
||||
|
||||
Args:
|
||||
should_check_for_notifs (bool): Whether we should immediately
|
||||
check for push to send. Set to False only if it's known there
|
||||
is nothing to send
|
||||
"""
|
||||
if should_check_for_notifs and self.mailer is not None:
|
||||
self._start_processing()
|
||||
|
||||
def on_stop(self):
|
||||
|
||||
@@ -112,8 +112,16 @@ class HttpPusher(object):
|
||||
self.data_minus_url.update(self.data)
|
||||
del self.data_minus_url['url']
|
||||
|
||||
def on_started(self):
|
||||
self._start_processing()
|
||||
def on_started(self, should_check_for_notifs):
|
||||
"""Called when this pusher has been started.
|
||||
|
||||
Args:
|
||||
should_check_for_notifs (bool): Whether we should immediately
|
||||
check for push to send. Set to False only if it's known there
|
||||
is nothing to send
|
||||
"""
|
||||
if should_check_for_notifs:
|
||||
self._start_processing()
|
||||
|
||||
def on_new_notifications(self, min_stream_ordering, max_stream_ordering):
|
||||
self.max_stream_ordering = max(max_stream_ordering, self.max_stream_ordering or 0)
|
||||
|
||||
@@ -21,6 +21,7 @@ from twisted.internet import defer
|
||||
from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.push import PusherConfigException
|
||||
from synapse.push.pusher import PusherFactory
|
||||
from synapse.util.async_helpers import concurrently_execute
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -197,7 +198,7 @@ class PusherPool:
|
||||
p = r
|
||||
|
||||
if p:
|
||||
self._start_pusher(p)
|
||||
yield self._start_pusher(p)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_pushers(self):
|
||||
@@ -208,10 +209,14 @@ class PusherPool:
|
||||
"""
|
||||
pushers = yield self.store.get_all_pushers()
|
||||
logger.info("Starting %d pushers", len(pushers))
|
||||
for pusherdict in pushers:
|
||||
self._start_pusher(pusherdict)
|
||||
|
||||
# Stagger starting up the pushers so we don't completely drown the
|
||||
# process on start up.
|
||||
yield concurrently_execute(self._start_pusher, pushers, 10)
|
||||
|
||||
logger.info("Started pushers")
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _start_pusher(self, pusherdict):
|
||||
"""Start the given pusher
|
||||
|
||||
@@ -248,7 +253,22 @@ class PusherPool:
|
||||
if appid_pushkey in byuser:
|
||||
byuser[appid_pushkey].on_stop()
|
||||
byuser[appid_pushkey] = p
|
||||
p.on_started()
|
||||
|
||||
# Check if there *may* be push to process. We do this as this check is a
|
||||
# lot cheaper to do than actually fetching the exact rows we need to
|
||||
# push.
|
||||
user_id = pusherdict["user_name"]
|
||||
last_stream_ordering = pusherdict["last_stream_ordering"]
|
||||
if last_stream_ordering:
|
||||
have_notifs = yield self.store.get_if_maybe_push_in_range_for_user(
|
||||
user_id, last_stream_ordering,
|
||||
)
|
||||
else:
|
||||
# We always want to default to starting up the pusher rather than
|
||||
# risk missing push.
|
||||
have_notifs = True
|
||||
|
||||
p.on_started(have_notifs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def remove_pusher(self, app_id, pushkey, user_id):
|
||||
|
||||
@@ -16,6 +16,10 @@
|
||||
import logging
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.replication.tcp.streams.events import (
|
||||
EventsStreamCurrentStateRow,
|
||||
EventsStreamEventRow,
|
||||
)
|
||||
from synapse.storage.event_federation import EventFederationWorkerStore
|
||||
from synapse.storage.event_push_actions import EventPushActionsWorkerStore
|
||||
from synapse.storage.events_worker import EventsWorkerStore
|
||||
@@ -79,11 +83,7 @@ class SlavedEventStore(EventFederationWorkerStore,
|
||||
if stream_name == "events":
|
||||
self._stream_id_gen.advance(token)
|
||||
for row in rows:
|
||||
self.invalidate_caches_for_event(
|
||||
token, row.event_id, row.room_id, row.type, row.state_key,
|
||||
row.redacts,
|
||||
backfilled=False,
|
||||
)
|
||||
self._process_event_stream_row(token, row)
|
||||
elif stream_name == "backfill":
|
||||
self._backfill_id_gen.advance(-token)
|
||||
for row in rows:
|
||||
@@ -96,6 +96,23 @@ class SlavedEventStore(EventFederationWorkerStore,
|
||||
stream_name, token, rows
|
||||
)
|
||||
|
||||
def _process_event_stream_row(self, token, row):
|
||||
data = row.data
|
||||
|
||||
if row.type == EventsStreamEventRow.TypeId:
|
||||
self.invalidate_caches_for_event(
|
||||
token, data.event_id, data.room_id, data.type, data.state_key,
|
||||
data.redacts,
|
||||
backfilled=False,
|
||||
)
|
||||
elif row.type == EventsStreamCurrentStateRow.TypeId:
|
||||
if data.type == EventTypes.Member:
|
||||
self.get_rooms_for_user_with_stream_ordering.invalidate(
|
||||
(data.state_key, ),
|
||||
)
|
||||
else:
|
||||
raise Exception("Unknown events stream row type %s" % (row.type, ))
|
||||
|
||||
def invalidate_caches_for_event(self, stream_ordering, event_id, room_id,
|
||||
etype, state_key, redacts, backfilled):
|
||||
self._invalidate_get_event_cache(event_id)
|
||||
|
||||
@@ -39,16 +39,6 @@ class SlavedPresenceStore(BaseSlavedStore):
|
||||
_get_presence_for_user = PresenceStore.__dict__["_get_presence_for_user"]
|
||||
get_presence_for_users = PresenceStore.__dict__["get_presence_for_users"]
|
||||
|
||||
# XXX: This is a bit broken because we don't persist the accepted list in a
|
||||
# way that can be replicated. This means that we don't have a way to
|
||||
# invalidate the cache correctly.
|
||||
get_presence_list_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_accepted"
|
||||
]
|
||||
get_presence_list_observers_accepted = PresenceStore.__dict__[
|
||||
"get_presence_list_observers_accepted"
|
||||
]
|
||||
|
||||
def get_current_presence_token(self):
|
||||
return self._presence_id_gen.get_current_token()
|
||||
|
||||
|
||||
@@ -103,10 +103,19 @@ class ReplicationClientHandler(object):
|
||||
hs.get_reactor().connectTCP(host, port, self.factory)
|
||||
|
||||
def on_rdata(self, stream_name, token, rows):
|
||||
"""Called when we get new replication data. By default this just pokes
|
||||
the slave store.
|
||||
"""Called to handle a batch of replication data with a given stream token.
|
||||
|
||||
Can be overriden in subclasses to handle more.
|
||||
By default this just pokes the slave store. Can be overridden in subclasses to
|
||||
handle more.
|
||||
|
||||
Args:
|
||||
stream_name (str): name of the replication stream for this batch of rows
|
||||
token (int): stream token for this batch of rows
|
||||
rows (list): a list of Stream.ROW_TYPE objects as returned by
|
||||
Stream.parse_row.
|
||||
|
||||
Returns:
|
||||
Deferred|None
|
||||
"""
|
||||
logger.debug("Received rdata %s -> %s", stream_name, token)
|
||||
return self.store.process_replication_rows(stream_name, token, rows)
|
||||
|
||||
@@ -42,8 +42,8 @@ indicate which side is sending, these are *not* included on the wire::
|
||||
> POSITION backfill 1
|
||||
> POSITION caches 1
|
||||
> RDATA caches 2 ["get_user_by_id",["@01register-user:localhost:8823"],1490197670513]
|
||||
> RDATA events 14 ["$149019767112vOHxz:localhost:8823",
|
||||
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]
|
||||
> RDATA events 14 ["ev", ["$149019767112vOHxz:localhost:8823",
|
||||
"!AFDCvgApUmpdfVjIXm:localhost:8823","m.room.guest_access","",null]]
|
||||
< PING 1490197675618
|
||||
> ERROR server stopping
|
||||
* connection closed by server *
|
||||
@@ -605,7 +605,7 @@ class ClientReplicationStreamProtocol(BaseReplicationStreamProtocol):
|
||||
inbound_rdata_count.labels(stream_name).inc()
|
||||
|
||||
try:
|
||||
row = STREAMS_MAP[stream_name].ROW_TYPE(*cmd.row)
|
||||
row = STREAMS_MAP[stream_name].parse_row(cmd.row)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"[%s] Failed to parse RDATA: %r %r",
|
||||
|
||||
@@ -30,7 +30,8 @@ from synapse.metrics.background_process_metrics import run_as_background_process
|
||||
from synapse.util.metrics import Measure, measure_func
|
||||
|
||||
from .protocol import ServerReplicationStreamProtocol
|
||||
from .streams import STREAMS_MAP, FederationStream
|
||||
from .streams import STREAMS_MAP
|
||||
from .streams.federation import FederationStream
|
||||
|
||||
stream_updates_counter = Counter("synapse_replication_tcp_resource_stream_updates",
|
||||
"", ["stream_name"])
|
||||
|
||||
49
synapse/replication/tcp/streams/__init__.py
Normal file
49
synapse/replication/tcp/streams/__init__.py
Normal file
@@ -0,0 +1,49 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines all the valid streams that clients can subscribe to, and the format
|
||||
of the rows returned by each stream.
|
||||
|
||||
Each stream is defined by the following information:
|
||||
|
||||
stream name: The name of the stream
|
||||
row type: The type that is used to serialise/deserialse the row
|
||||
current_token: The function that returns the current token for the stream
|
||||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
|
||||
from . import _base, events, federation
|
||||
|
||||
STREAMS_MAP = {
|
||||
stream.NAME: stream
|
||||
for stream in (
|
||||
events.EventsStream,
|
||||
_base.BackfillStream,
|
||||
_base.PresenceStream,
|
||||
_base.TypingStream,
|
||||
_base.ReceiptsStream,
|
||||
_base.PushRulesStream,
|
||||
_base.PushersStream,
|
||||
_base.CachesStream,
|
||||
_base.PublicRoomsStream,
|
||||
_base.DeviceListsStream,
|
||||
_base.ToDeviceStream,
|
||||
federation.FederationStream,
|
||||
_base.TagAccountDataStream,
|
||||
_base.AccountDataStream,
|
||||
_base.GroupServerStream,
|
||||
)
|
||||
}
|
||||
@@ -1,5 +1,6 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -13,16 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Defines all the valid streams that clients can subscribe to, and the format
|
||||
of the rows returned by each stream.
|
||||
|
||||
Each stream is defined by the following information:
|
||||
|
||||
stream name: The name of the stream
|
||||
row type: The type that is used to serialise/deserialse the row
|
||||
current_token: The function that returns the current token for the stream
|
||||
update_function: The function that returns a list of updates between two tokens
|
||||
"""
|
||||
import itertools
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
@@ -34,14 +26,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_EVENTS_BEHIND = 10000
|
||||
|
||||
|
||||
EventStreamRow = namedtuple("EventStreamRow", (
|
||||
"event_id", # str
|
||||
"room_id", # str
|
||||
"type", # str
|
||||
"state_key", # str, optional
|
||||
"redacts", # str, optional
|
||||
))
|
||||
BackfillStreamRow = namedtuple("BackfillStreamRow", (
|
||||
"event_id", # str
|
||||
"room_id", # str
|
||||
@@ -96,10 +80,6 @@ DeviceListsStreamRow = namedtuple("DeviceListsStreamRow", (
|
||||
ToDeviceStreamRow = namedtuple("ToDeviceStreamRow", (
|
||||
"entity", # str
|
||||
))
|
||||
FederationStreamRow = namedtuple("FederationStreamRow", (
|
||||
"type", # str, the type of data as defined in the BaseFederationRows
|
||||
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
|
||||
))
|
||||
TagAccountDataStreamRow = namedtuple("TagAccountDataStreamRow", (
|
||||
"user_id", # str
|
||||
"room_id", # str
|
||||
@@ -111,12 +91,6 @@ AccountDataStreamRow = namedtuple("AccountDataStream", (
|
||||
"data_type", # str
|
||||
"data", # dict
|
||||
))
|
||||
CurrentStateDeltaStreamRow = namedtuple("CurrentStateDeltaStream", (
|
||||
"room_id", # str
|
||||
"type", # str
|
||||
"state_key", # str
|
||||
"event_id", # str, optional
|
||||
))
|
||||
GroupsStreamRow = namedtuple("GroupsStreamRow", (
|
||||
"group_id", # str
|
||||
"user_id", # str
|
||||
@@ -132,9 +106,24 @@ class Stream(object):
|
||||
time it was called up until the point `advance_current_token` was called.
|
||||
"""
|
||||
NAME = None # The name of the stream
|
||||
ROW_TYPE = None # The type of the row
|
||||
ROW_TYPE = None # The type of the row. Used by the default impl of parse_row.
|
||||
_LIMITED = True # Whether the update function takes a limit
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
"""Parse a row received over replication
|
||||
|
||||
By default, assumes that the row data is an array object and passes its contents
|
||||
to the constructor of the ROW_TYPE for this stream.
|
||||
|
||||
Args:
|
||||
row: row data from the incoming RDATA command, after json decoding
|
||||
|
||||
Returns:
|
||||
ROW_TYPE object for this stream
|
||||
"""
|
||||
return cls.ROW_TYPE(*row)
|
||||
|
||||
def __init__(self, hs):
|
||||
# The token from which we last asked for updates
|
||||
self.last_token = self.current_token()
|
||||
@@ -162,8 +151,10 @@ class Stream(object):
|
||||
until the `upto_token`
|
||||
|
||||
Returns:
|
||||
(list(ROW_TYPE), int): list of updates plus the token used as an
|
||||
upper bound of the updates (i.e. the "current token")
|
||||
Deferred[Tuple[List[Tuple[int, Any]], int]:
|
||||
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
|
||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||
sent over the replication steam.
|
||||
"""
|
||||
updates, current_token = yield self.get_updates_since(self.last_token)
|
||||
self.last_token = current_token
|
||||
@@ -176,8 +167,10 @@ class Stream(object):
|
||||
stream updates
|
||||
|
||||
Returns:
|
||||
(list(ROW_TYPE), int): list of updates plus the token used as an
|
||||
upper bound of the updates (i.e. the "current token")
|
||||
Deferred[Tuple[List[Tuple[int, Any]], int]:
|
||||
Resolves to a pair ``(updates, current_token)``, where ``updates`` is a
|
||||
list of ``(token, row)`` entries. ``row`` will be json-serialised and
|
||||
sent over the replication steam.
|
||||
"""
|
||||
if from_token in ("NOW", "now"):
|
||||
defer.returnValue(([], self.upto_token))
|
||||
@@ -202,7 +195,7 @@ class Stream(object):
|
||||
from_token, current_token,
|
||||
)
|
||||
|
||||
updates = [(row[0], self.ROW_TYPE(*row[1:])) for row in rows]
|
||||
updates = [(row[0], row[1:]) for row in rows]
|
||||
|
||||
# check we didn't get more rows than the limit.
|
||||
# doing it like this allows the update_function to be a generator.
|
||||
@@ -232,20 +225,6 @@ class Stream(object):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class EventsStream(Stream):
|
||||
"""We received a new event, or an event went from being an outlier to not
|
||||
"""
|
||||
NAME = "events"
|
||||
ROW_TYPE = EventStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
self.current_token = store.get_current_events_token
|
||||
self.update_function = store.get_all_new_forward_event_rows
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
|
||||
class BackfillStream(Stream):
|
||||
"""We fetched some old events and either we had never seen that event before
|
||||
or it went from being an outlier to not.
|
||||
@@ -400,22 +379,6 @@ class ToDeviceStream(Stream):
|
||||
super(ToDeviceStream, self).__init__(hs)
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
"""Data to be sent over federation. Only available when master has federation
|
||||
sending disabled.
|
||||
"""
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.current_token = federation_sender.get_current_token
|
||||
self.update_function = federation_sender.get_replication_rows
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
|
||||
|
||||
class TagAccountDataStream(Stream):
|
||||
"""Someone added/removed a tag for a room
|
||||
"""
|
||||
@@ -459,21 +422,6 @@ class AccountDataStream(Stream):
|
||||
defer.returnValue(results)
|
||||
|
||||
|
||||
class CurrentStateDeltaStream(Stream):
|
||||
"""Current state for a room was changed
|
||||
"""
|
||||
NAME = "current_state_deltas"
|
||||
ROW_TYPE = CurrentStateDeltaStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
store = hs.get_datastore()
|
||||
|
||||
self.current_token = store.get_max_current_state_delta_stream_id
|
||||
self.update_function = store.get_all_updated_current_state_deltas
|
||||
|
||||
super(CurrentStateDeltaStream, self).__init__(hs)
|
||||
|
||||
|
||||
class GroupServerStream(Stream):
|
||||
NAME = "groups"
|
||||
ROW_TYPE = GroupsStreamRow
|
||||
@@ -485,26 +433,3 @@ class GroupServerStream(Stream):
|
||||
self.update_function = store.get_all_groups_changes
|
||||
|
||||
super(GroupServerStream, self).__init__(hs)
|
||||
|
||||
|
||||
STREAMS_MAP = {
|
||||
stream.NAME: stream
|
||||
for stream in (
|
||||
EventsStream,
|
||||
BackfillStream,
|
||||
PresenceStream,
|
||||
TypingStream,
|
||||
ReceiptsStream,
|
||||
PushRulesStream,
|
||||
PushersStream,
|
||||
CachesStream,
|
||||
PublicRoomsStream,
|
||||
DeviceListsStream,
|
||||
ToDeviceStream,
|
||||
FederationStream,
|
||||
TagAccountDataStream,
|
||||
AccountDataStream,
|
||||
CurrentStateDeltaStream,
|
||||
GroupServerStream,
|
||||
)
|
||||
}
|
||||
146
synapse/replication/tcp/streams/events.py
Normal file
146
synapse/replication/tcp/streams/events.py
Normal file
@@ -0,0 +1,146 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import heapq
|
||||
|
||||
import attr
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from ._base import Stream
|
||||
|
||||
|
||||
"""Handling of the 'events' replication stream
|
||||
|
||||
This stream contains rows of various types. Each row therefore contains a 'type'
|
||||
identifier before the real data. For example::
|
||||
|
||||
RDATA events batch ["state", ["!room:id", "m.type", "", "$event:id"]]
|
||||
RDATA events 12345 ["ev", ["$event:id", "!room:id", "m.type", null, null]]
|
||||
|
||||
An "ev" row is sent for each new event. The fields in the data part are:
|
||||
|
||||
* The new event id
|
||||
* The room id for the event
|
||||
* The type of the new event
|
||||
* The state key of the event, for state events
|
||||
* The event id of an event which is redacted by this event.
|
||||
|
||||
A "state" row is sent whenever the "current state" in a room changes. The fields in the
|
||||
data part are:
|
||||
|
||||
* The room id for the state change
|
||||
* The event type of the state which has changed
|
||||
* The state_key of the state which has changed
|
||||
* The event id of the new state
|
||||
|
||||
"""
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class EventsStreamRow(object):
|
||||
"""A parsed row from the events replication stream"""
|
||||
type = attr.ib() # str: the TypeId of one of the *EventsStreamRows
|
||||
data = attr.ib() # BaseEventsStreamRow
|
||||
|
||||
|
||||
class BaseEventsStreamRow(object):
|
||||
"""Base class for rows to be sent in the events stream.
|
||||
|
||||
Specifies how to identify, serialize and deserialize the different types.
|
||||
"""
|
||||
|
||||
TypeId = None # Unique string that ids the type. Must be overriden in sub classes.
|
||||
|
||||
@classmethod
|
||||
def from_data(cls, data):
|
||||
"""Parse the data from the replication stream into a row.
|
||||
|
||||
By default we just call the constructor with the data list as arguments
|
||||
|
||||
Args:
|
||||
data: The value of the data object from the replication stream
|
||||
"""
|
||||
return cls(*data)
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class EventsStreamEventRow(BaseEventsStreamRow):
|
||||
TypeId = "ev"
|
||||
|
||||
event_id = attr.ib() # str
|
||||
room_id = attr.ib() # str
|
||||
type = attr.ib() # str
|
||||
state_key = attr.ib() # str, optional
|
||||
redacts = attr.ib() # str, optional
|
||||
|
||||
|
||||
@attr.s(slots=True, frozen=True)
|
||||
class EventsStreamCurrentStateRow(BaseEventsStreamRow):
|
||||
TypeId = "state"
|
||||
|
||||
room_id = attr.ib() # str
|
||||
type = attr.ib() # str
|
||||
state_key = attr.ib() # str
|
||||
event_id = attr.ib() # str, optional
|
||||
|
||||
|
||||
TypeToRow = {
|
||||
Row.TypeId: Row
|
||||
for Row in (
|
||||
EventsStreamEventRow,
|
||||
EventsStreamCurrentStateRow,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
class EventsStream(Stream):
|
||||
"""We received a new event, or an event went from being an outlier to not
|
||||
"""
|
||||
NAME = "events"
|
||||
|
||||
def __init__(self, hs):
|
||||
self._store = hs.get_datastore()
|
||||
self.current_token = self._store.get_current_events_token
|
||||
|
||||
super(EventsStream, self).__init__(hs)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def update_function(self, from_token, current_token, limit=None):
|
||||
event_rows = yield self._store.get_all_new_forward_event_rows(
|
||||
from_token, current_token, limit,
|
||||
)
|
||||
event_updates = (
|
||||
(row[0], EventsStreamEventRow.TypeId, row[1:])
|
||||
for row in event_rows
|
||||
)
|
||||
|
||||
state_rows = yield self._store.get_all_updated_current_state_deltas(
|
||||
from_token, current_token, limit
|
||||
)
|
||||
state_updates = (
|
||||
(row[0], EventsStreamCurrentStateRow.TypeId, row[1:])
|
||||
for row in state_rows
|
||||
)
|
||||
|
||||
all_updates = heapq.merge(event_updates, state_updates)
|
||||
|
||||
defer.returnValue(all_updates)
|
||||
|
||||
@classmethod
|
||||
def parse_row(cls, row):
|
||||
(typ, data) = row
|
||||
data = TypeToRow[typ].from_data(data)
|
||||
return EventsStreamRow(typ, data)
|
||||
39
synapse/replication/tcp/streams/federation.py
Normal file
39
synapse/replication/tcp/streams/federation.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# Copyright 2017 Vector Creations Ltd
|
||||
# Copyright 2019 New Vector Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from collections import namedtuple
|
||||
|
||||
from ._base import Stream
|
||||
|
||||
FederationStreamRow = namedtuple("FederationStreamRow", (
|
||||
"type", # str, the type of data as defined in the BaseFederationRows
|
||||
"data", # dict, serialization of a federation.send_queue.BaseFederationRow
|
||||
))
|
||||
|
||||
|
||||
class FederationStream(Stream):
|
||||
"""Data to be sent over federation. Only available when master has federation
|
||||
sending disabled.
|
||||
"""
|
||||
NAME = "federation"
|
||||
ROW_TYPE = FederationStreamRow
|
||||
|
||||
def __init__(self, hs):
|
||||
federation_sender = hs.get_federation_sender()
|
||||
|
||||
self.current_token = federation_sender.get_current_token
|
||||
self.update_function = federation_sender.get_replication_rows
|
||||
|
||||
super(FederationStream, self).__init__(hs)
|
||||
@@ -499,7 +499,7 @@ class ShutdownRoomRestServlet(ClientV1RestServlet):
|
||||
# desirable in case the first attempt at blocking the room failed below.
|
||||
yield self.store.block_room(room_id, requester_user_id)
|
||||
|
||||
users = yield self.state.get_current_user_in_room(room_id)
|
||||
users = yield self.state.get_current_users_in_room(room_id)
|
||||
kicked_users = []
|
||||
failed_to_kick_users = []
|
||||
for user_id in users:
|
||||
@@ -647,8 +647,6 @@ class ResetPasswordRestServlet(ClientV1RestServlet):
|
||||
assert_params_in_dict(params, ["new_password"])
|
||||
new_password = params['new_password']
|
||||
|
||||
logger.info("new_password: %r", new_password)
|
||||
|
||||
yield self._set_password_handler.set_password(
|
||||
target_user_id, new_password, requester
|
||||
)
|
||||
@@ -786,6 +784,31 @@ class SearchUsersRestServlet(ClientV1RestServlet):
|
||||
defer.returnValue((200, ret))
|
||||
|
||||
|
||||
class DeleteGroupAdminRestServlet(ClientV1RestServlet):
|
||||
"""Allows deleting of local groups
|
||||
"""
|
||||
PATTERNS = client_path_patterns("/admin/delete_group/(?P<group_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(DeleteGroupAdminRestServlet, self).__init__(hs)
|
||||
self.group_server = hs.get_groups_server_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, group_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
is_admin = yield self.auth.is_server_admin(requester.user)
|
||||
|
||||
if not is_admin:
|
||||
raise AuthError(403, "You are not a server admin")
|
||||
|
||||
if not self.is_mine_id(group_id):
|
||||
raise SynapseError(400, "Can only delete local groups")
|
||||
|
||||
yield self.group_server.delete_group(group_id, requester.user.to_string())
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
WhoisRestServlet(hs).register(http_server)
|
||||
PurgeMediaCacheRestServlet(hs).register(http_server)
|
||||
@@ -801,3 +824,4 @@ def register_servlets(hs, http_server):
|
||||
ListMediaInRoom(hs).register(http_server)
|
||||
UserRegisterServlet(hs).register(http_server)
|
||||
VersionServlet(hs).register(http_server)
|
||||
DeleteGroupAdminRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -93,72 +93,5 @@ class PresenceStatusRestServlet(ClientV1RestServlet):
|
||||
return (200, {})
|
||||
|
||||
|
||||
class PresenceListRestServlet(ClientV1RestServlet):
|
||||
PATTERNS = client_path_patterns("/presence/list/(?P<user_id>[^/]*)")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(PresenceListRestServlet, self).__init__(hs)
|
||||
self.presence_handler = hs.get_presence_handler()
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_GET(self, request, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if requester.user != user:
|
||||
raise SynapseError(400, "Cannot get another user's presence list")
|
||||
|
||||
presence = yield self.presence_handler.get_presence_list(
|
||||
observer_user=user, accepted=True
|
||||
)
|
||||
|
||||
defer.returnValue((200, presence))
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def on_POST(self, request, user_id):
|
||||
requester = yield self.auth.get_user_by_req(request)
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
if not self.hs.is_mine(user):
|
||||
raise SynapseError(400, "User not hosted on this Home Server")
|
||||
|
||||
if requester.user != user:
|
||||
raise SynapseError(
|
||||
400, "Cannot modify another user's presence list")
|
||||
|
||||
content = parse_json_object_from_request(request)
|
||||
|
||||
if "invite" in content:
|
||||
for u in content["invite"]:
|
||||
if not isinstance(u, string_types):
|
||||
raise SynapseError(400, "Bad invite value.")
|
||||
if len(u) == 0:
|
||||
continue
|
||||
invited_user = UserID.from_string(u)
|
||||
yield self.presence_handler.send_presence_invite(
|
||||
observer_user=user, observed_user=invited_user
|
||||
)
|
||||
|
||||
if "drop" in content:
|
||||
for u in content["drop"]:
|
||||
if not isinstance(u, string_types):
|
||||
raise SynapseError(400, "Bad drop value.")
|
||||
if len(u) == 0:
|
||||
continue
|
||||
dropped_user = UserID.from_string(u)
|
||||
yield self.presence_handler.drop(
|
||||
observer_user=user, observed_user=dropped_user
|
||||
)
|
||||
|
||||
defer.returnValue((200, {}))
|
||||
|
||||
def on_OPTIONS(self, request):
|
||||
return (200, {})
|
||||
|
||||
|
||||
def register_servlets(hs, http_server):
|
||||
PresenceStatusRestServlet(hs).register(http_server)
|
||||
PresenceListRestServlet(hs).register(http_server)
|
||||
|
||||
@@ -215,6 +215,7 @@ class DeactivateAccountRestServlet(RestServlet):
|
||||
)
|
||||
result = yield self._deactivate_account_handler.deactivate_account(
|
||||
requester.user.to_string(), erase,
|
||||
id_server=body.get("id_server"),
|
||||
)
|
||||
if result:
|
||||
id_server_unbind_result = "success"
|
||||
@@ -363,7 +364,7 @@ class ThreepidRestServlet(RestServlet):
|
||||
|
||||
|
||||
class ThreepidDeleteRestServlet(RestServlet):
|
||||
PATTERNS = client_v2_patterns("/account/3pid/delete$", releases=())
|
||||
PATTERNS = client_v2_patterns("/account/3pid/delete$")
|
||||
|
||||
def __init__(self, hs):
|
||||
super(ThreepidDeleteRestServlet, self).__init__()
|
||||
@@ -380,7 +381,7 @@ class ThreepidDeleteRestServlet(RestServlet):
|
||||
|
||||
try:
|
||||
ret = yield self.auth_handler.delete_threepid(
|
||||
user_id, body['medium'], body['address']
|
||||
user_id, body['medium'], body['address'], body.get("id_server"),
|
||||
)
|
||||
except Exception:
|
||||
# NB. This endpoint should succeed if there is nothing to
|
||||
|
||||
@@ -16,7 +16,7 @@ import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import DEFAULT_ROOM_VERSION, RoomDisposition, RoomVersions
|
||||
from synapse.api.room_versions import DEFAULT_ROOM_VERSION, KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.servlet import RestServlet
|
||||
|
||||
from ._base import client_v2_patterns
|
||||
@@ -48,12 +48,10 @@ class CapabilitiesRestServlet(RestServlet):
|
||||
response = {
|
||||
"capabilities": {
|
||||
"m.room_versions": {
|
||||
"default": DEFAULT_ROOM_VERSION,
|
||||
"default": DEFAULT_ROOM_VERSION.identifier,
|
||||
"available": {
|
||||
RoomVersions.V1: RoomDisposition.STABLE,
|
||||
RoomVersions.V2: RoomDisposition.STABLE,
|
||||
RoomVersions.STATE_V2_TEST: RoomDisposition.UNSTABLE,
|
||||
RoomVersions.V3: RoomDisposition.STABLE,
|
||||
v.identifier: v.disposition
|
||||
for v in KNOWN_ROOM_VERSIONS.values()
|
||||
},
|
||||
},
|
||||
"m.change_password": {"enabled": change_password},
|
||||
|
||||
@@ -17,8 +17,8 @@ import logging
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import KNOWN_ROOM_VERSIONS
|
||||
from synapse.api.errors import Codes, SynapseError
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
|
||||
from synapse.http.servlet import (
|
||||
RestServlet,
|
||||
assert_params_in_dict,
|
||||
|
||||
@@ -24,7 +24,8 @@ from frozendict import frozendict
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import EventTypes, RoomVersions
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
from synapse.events.snapshot import EventContext
|
||||
from synapse.state import v1, v2
|
||||
from synapse.util.async_helpers import Linearizer
|
||||
@@ -160,10 +161,21 @@ class StateHandler(object):
|
||||
defer.returnValue(state)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_current_user_in_room(self, room_id, latest_event_ids=None):
|
||||
def get_current_users_in_room(self, room_id, latest_event_ids=None):
|
||||
"""
|
||||
Get the users who are currently in a room.
|
||||
|
||||
Args:
|
||||
room_id (str): The ID of the room.
|
||||
latest_event_ids (List[str]|None): Precomputed list of latest
|
||||
event IDs. Will be computed if None.
|
||||
Returns:
|
||||
Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
|
||||
profileinfo.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
|
||||
logger.debug("calling resolve_state_groups from get_current_user_in_room")
|
||||
logger.debug("calling resolve_state_groups from get_current_users_in_room")
|
||||
entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
|
||||
defer.returnValue(joined_users)
|
||||
@@ -603,22 +615,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
|
||||
Deferred[dict[(str, str), str]]:
|
||||
a map from (type, state_key) to event_id.
|
||||
"""
|
||||
if room_version == RoomVersions.V1:
|
||||
v = KNOWN_ROOM_VERSIONS[room_version]
|
||||
if v.state_res == StateResolutionVersions.V1:
|
||||
return v1.resolve_events_with_store(
|
||||
state_sets, event_map, state_res_store.get_events,
|
||||
)
|
||||
elif room_version in (
|
||||
RoomVersions.STATE_V2_TEST, RoomVersions.V2, RoomVersions.V3,
|
||||
):
|
||||
else:
|
||||
return v2.resolve_events_with_store(
|
||||
room_version, state_sets, event_map, state_res_store,
|
||||
)
|
||||
else:
|
||||
# This should only happen if we added a version but forgot to add it to
|
||||
# the list above.
|
||||
raise Exception(
|
||||
"No state resolution algorithm defined for version %r" % (room_version,)
|
||||
)
|
||||
|
||||
|
||||
@attr.s
|
||||
|
||||
@@ -21,8 +21,9 @@ from six import iteritems, iterkeys, itervalues
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse import event_auth
|
||||
from synapse.api.constants import EventTypes, RoomVersions
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.errors import AuthError
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -275,7 +276,9 @@ def _resolve_auth_events(events, auth_events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(
|
||||
RoomVersions.V1, event, auth_events,
|
||||
RoomVersions.V1.identifier,
|
||||
event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
do_size_check=False,
|
||||
)
|
||||
@@ -291,7 +294,9 @@ def _resolve_normal_events(events, auth_events):
|
||||
try:
|
||||
# The signatures have already been checked at this point
|
||||
event_auth.check(
|
||||
RoomVersions.V1, event, auth_events,
|
||||
RoomVersions.V1.identifier,
|
||||
event,
|
||||
auth_events,
|
||||
do_sig_check=False,
|
||||
do_size_check=False,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,7 @@ var show_login = function() {
|
||||
$("#loading").hide();
|
||||
|
||||
var this_page = window.location.origin + window.location.pathname;
|
||||
$("#sso_redirect_url").val(encodeURIComponent(this_page));
|
||||
$("#sso_redirect_url").val(this_page);
|
||||
|
||||
if (matrixLogin.serverAcceptsPassword) {
|
||||
$("#password_flow").show();
|
||||
|
||||
@@ -18,6 +18,8 @@ import calendar
|
||||
import logging
|
||||
import time
|
||||
|
||||
from twisted.internet import defer
|
||||
|
||||
from synapse.api.constants import PresenceState
|
||||
from synapse.storage.devices import DeviceStore
|
||||
from synapse.storage.user_erasure_store import UserErasureStore
|
||||
@@ -61,48 +63,60 @@ from .util.id_generators import ChainedIdGenerator, IdGenerator, StreamIdGenerat
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DataStore(RoomMemberStore, RoomStore,
|
||||
RegistrationStore, StreamStore, ProfileStore,
|
||||
PresenceStore, TransactionStore,
|
||||
DirectoryStore, KeyStore, StateStore, SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
EventsStore,
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
EndToEndRoomKeyStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersStore,
|
||||
):
|
||||
|
||||
class DataStore(
|
||||
RoomMemberStore,
|
||||
RoomStore,
|
||||
RegistrationStore,
|
||||
StreamStore,
|
||||
ProfileStore,
|
||||
PresenceStore,
|
||||
TransactionStore,
|
||||
DirectoryStore,
|
||||
KeyStore,
|
||||
StateStore,
|
||||
SignatureStore,
|
||||
ApplicationServiceStore,
|
||||
EventsStore,
|
||||
EventFederationStore,
|
||||
MediaRepositoryStore,
|
||||
RejectionsStore,
|
||||
FilteringStore,
|
||||
PusherStore,
|
||||
PushRuleStore,
|
||||
ApplicationServiceTransactionStore,
|
||||
ReceiptsStore,
|
||||
EndToEndKeyStore,
|
||||
EndToEndRoomKeyStore,
|
||||
SearchStore,
|
||||
TagsStore,
|
||||
AccountDataStore,
|
||||
EventPushActionsStore,
|
||||
OpenIdStore,
|
||||
ClientIpStore,
|
||||
DeviceStore,
|
||||
DeviceInboxStore,
|
||||
UserDirectoryStore,
|
||||
GroupServerStore,
|
||||
UserErasureStore,
|
||||
MonthlyActiveUsersStore,
|
||||
):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.hs = hs
|
||||
self._clock = hs.get_clock()
|
||||
self.database_engine = hs.database_engine
|
||||
|
||||
self._stream_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")]
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
extra_tables=[("local_invites", "stream_id")],
|
||||
)
|
||||
self._backfill_id_gen = StreamIdGenerator(
|
||||
db_conn, "events", "stream_ordering", step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")]
|
||||
db_conn,
|
||||
"events",
|
||||
"stream_ordering",
|
||||
step=-1,
|
||||
extra_tables=[("ex_outlier_stream", "event_stream_ordering")],
|
||||
)
|
||||
self._presence_id_gen = StreamIdGenerator(
|
||||
db_conn, "presence_stream", "stream_id"
|
||||
@@ -114,7 +128,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
db_conn, "public_room_list_stream", "stream_id"
|
||||
)
|
||||
self._device_list_id_gen = StreamIdGenerator(
|
||||
db_conn, "device_lists_stream", "stream_id",
|
||||
db_conn, "device_lists_stream", "stream_id"
|
||||
)
|
||||
|
||||
self._access_tokens_id_gen = IdGenerator(db_conn, "access_tokens", "id")
|
||||
@@ -125,16 +139,15 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self._stream_id_gen, db_conn, "push_rules_stream", "stream_id"
|
||||
)
|
||||
self._pushers_id_gen = StreamIdGenerator(
|
||||
db_conn, "pushers", "id",
|
||||
extra_tables=[("deleted_pushers", "stream_id")],
|
||||
db_conn, "pushers", "id", extra_tables=[("deleted_pushers", "stream_id")]
|
||||
)
|
||||
self._group_updates_id_gen = StreamIdGenerator(
|
||||
db_conn, "local_group_updates", "stream_id",
|
||||
db_conn, "local_group_updates", "stream_id"
|
||||
)
|
||||
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
self._cache_id_gen = StreamIdGenerator(
|
||||
db_conn, "cache_invalidation_stream", "stream_id",
|
||||
db_conn, "cache_invalidation_stream", "stream_id"
|
||||
)
|
||||
else:
|
||||
self._cache_id_gen = None
|
||||
@@ -142,72 +155,82 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
self._presence_on_startup = self._get_active_presence(db_conn)
|
||||
|
||||
presence_cache_prefill, min_presence_val = self._get_cache_dict(
|
||||
db_conn, "presence_stream",
|
||||
db_conn,
|
||||
"presence_stream",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._presence_id_gen.get_current_token(),
|
||||
)
|
||||
self.presence_stream_cache = StreamChangeCache(
|
||||
"PresenceStreamChangeCache", min_presence_val,
|
||||
prefilled_cache=presence_cache_prefill
|
||||
"PresenceStreamChangeCache",
|
||||
min_presence_val,
|
||||
prefilled_cache=presence_cache_prefill,
|
||||
)
|
||||
|
||||
max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
|
||||
device_inbox_prefill, min_device_inbox_id = self._get_cache_dict(
|
||||
db_conn, "device_inbox",
|
||||
db_conn,
|
||||
"device_inbox",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=max_device_inbox_id,
|
||||
limit=1000,
|
||||
)
|
||||
self._device_inbox_stream_cache = StreamChangeCache(
|
||||
"DeviceInboxStreamChangeCache", min_device_inbox_id,
|
||||
"DeviceInboxStreamChangeCache",
|
||||
min_device_inbox_id,
|
||||
prefilled_cache=device_inbox_prefill,
|
||||
)
|
||||
# The federation outbox and the local device inbox uses the same
|
||||
# stream_id generator.
|
||||
device_outbox_prefill, min_device_outbox_id = self._get_cache_dict(
|
||||
db_conn, "device_federation_outbox",
|
||||
db_conn,
|
||||
"device_federation_outbox",
|
||||
entity_column="destination",
|
||||
stream_column="stream_id",
|
||||
max_value=max_device_inbox_id,
|
||||
limit=1000,
|
||||
)
|
||||
self._device_federation_outbox_stream_cache = StreamChangeCache(
|
||||
"DeviceFederationOutboxStreamChangeCache", min_device_outbox_id,
|
||||
"DeviceFederationOutboxStreamChangeCache",
|
||||
min_device_outbox_id,
|
||||
prefilled_cache=device_outbox_prefill,
|
||||
)
|
||||
|
||||
device_list_max = self._device_list_id_gen.get_current_token()
|
||||
self._device_list_stream_cache = StreamChangeCache(
|
||||
"DeviceListStreamChangeCache", device_list_max,
|
||||
"DeviceListStreamChangeCache", device_list_max
|
||||
)
|
||||
self._device_list_federation_stream_cache = StreamChangeCache(
|
||||
"DeviceListFederationStreamChangeCache", device_list_max,
|
||||
"DeviceListFederationStreamChangeCache", device_list_max
|
||||
)
|
||||
|
||||
events_max = self._stream_id_gen.get_current_token()
|
||||
curr_state_delta_prefill, min_curr_state_delta_id = self._get_cache_dict(
|
||||
db_conn, "current_state_delta_stream",
|
||||
db_conn,
|
||||
"current_state_delta_stream",
|
||||
entity_column="room_id",
|
||||
stream_column="stream_id",
|
||||
max_value=events_max, # As we share the stream id with events token
|
||||
limit=1000,
|
||||
)
|
||||
self._curr_state_delta_stream_cache = StreamChangeCache(
|
||||
"_curr_state_delta_stream_cache", min_curr_state_delta_id,
|
||||
"_curr_state_delta_stream_cache",
|
||||
min_curr_state_delta_id,
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self._get_cache_dict(
|
||||
db_conn, "local_group_updates",
|
||||
db_conn,
|
||||
"local_group_updates",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._group_updates_id_gen.get_current_token(),
|
||||
limit=1000,
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache", min_group_updates_id,
|
||||
"_group_updates_stream_cache",
|
||||
min_group_updates_id,
|
||||
prefilled_cache=_group_updates_prefill,
|
||||
)
|
||||
|
||||
@@ -250,6 +273,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"""
|
||||
Counts the number of users who used this homeserver in the last 24 hours.
|
||||
"""
|
||||
|
||||
def _count_users(txn):
|
||||
yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24)
|
||||
|
||||
@@ -277,6 +301,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
Returns counts globaly for a given user as well as breaking
|
||||
by platform
|
||||
"""
|
||||
|
||||
def _count_r30_users(txn):
|
||||
thirty_days_in_secs = 86400 * 30
|
||||
now = int(self._clock.time())
|
||||
@@ -313,8 +338,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
"""
|
||||
|
||||
results = {}
|
||||
txn.execute(sql, (thirty_days_ago_in_secs,
|
||||
thirty_days_ago_in_secs))
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
for row in txn:
|
||||
if row[0] == 'unknown':
|
||||
@@ -341,8 +365,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
) u
|
||||
"""
|
||||
|
||||
txn.execute(sql, (thirty_days_ago_in_secs,
|
||||
thirty_days_ago_in_secs))
|
||||
txn.execute(sql, (thirty_days_ago_in_secs, thirty_days_ago_in_secs))
|
||||
|
||||
count, = txn.fetchone()
|
||||
results['all'] = count
|
||||
@@ -356,15 +379,14 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
Returns millisecond unixtime for start of UTC day.
|
||||
"""
|
||||
now = time.gmtime()
|
||||
today_start = calendar.timegm((
|
||||
now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0,
|
||||
))
|
||||
today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0))
|
||||
return today_start * 1000
|
||||
|
||||
def generate_user_daily_visits(self):
|
||||
"""
|
||||
Generates daily visit data for use in cohort/ retention analysis
|
||||
"""
|
||||
|
||||
def _generate_user_daily_visits(txn):
|
||||
logger.info("Calling _generate_user_daily_visits")
|
||||
today_start = self._get_start_of_day()
|
||||
@@ -395,25 +417,29 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
# often to minimise this case.
|
||||
if today_start > self._last_user_visit_update:
|
||||
yesterday_start = today_start - a_day_in_milliseconds
|
||||
txn.execute(sql, (
|
||||
yesterday_start, yesterday_start,
|
||||
self._last_user_visit_update, today_start
|
||||
))
|
||||
txn.execute(
|
||||
sql,
|
||||
(
|
||||
yesterday_start,
|
||||
yesterday_start,
|
||||
self._last_user_visit_update,
|
||||
today_start,
|
||||
),
|
||||
)
|
||||
self._last_user_visit_update = today_start
|
||||
|
||||
txn.execute(sql, (
|
||||
today_start, today_start,
|
||||
self._last_user_visit_update,
|
||||
now
|
||||
))
|
||||
txn.execute(
|
||||
sql, (today_start, today_start, self._last_user_visit_update, now)
|
||||
)
|
||||
# Update _last_user_visit_update to now. The reason to do this
|
||||
# rather just clamping to the beginning of the day is to limit
|
||||
# the size of the join - meaning that the query can be run more
|
||||
# frequently
|
||||
self._last_user_visit_update = now
|
||||
|
||||
return self.runInteraction("generate_user_daily_visits",
|
||||
_generate_user_daily_visits)
|
||||
return self.runInteraction(
|
||||
"generate_user_daily_visits", _generate_user_daily_visits
|
||||
)
|
||||
|
||||
def get_users(self):
|
||||
"""Function to reterive a list of users in users table.
|
||||
@@ -425,15 +451,11 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
return self._simple_select_list(
|
||||
table="users",
|
||||
keyvalues={},
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
desc="get_users",
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_users_paginate(self, order, start, limit):
|
||||
"""Function to reterive a paginated list of users from
|
||||
users list. This will return a json object, which contains
|
||||
@@ -446,27 +468,19 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
Returns:
|
||||
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
|
||||
"""
|
||||
is_guest = 0
|
||||
i_start = (int)(start)
|
||||
i_limit = (int)(limit)
|
||||
return self.get_user_list_paginate(
|
||||
users = yield self.runInteraction(
|
||||
"get_users_paginate",
|
||||
self._simple_select_list_paginate_txn,
|
||||
table="users",
|
||||
keyvalues={
|
||||
"is_guest": is_guest
|
||||
},
|
||||
pagevalues=[
|
||||
order,
|
||||
i_limit,
|
||||
i_start
|
||||
],
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
desc="get_users_paginate",
|
||||
keyvalues={"is_guest": False},
|
||||
orderby=order,
|
||||
start=start,
|
||||
limit=limit,
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
)
|
||||
count = yield self.runInteraction("get_users_paginate", self.get_user_count_txn)
|
||||
retval = {"users": users, "total": count}
|
||||
defer.returnValue(retval)
|
||||
|
||||
def search_users(self, term):
|
||||
"""Function to search users list for one or more users with
|
||||
@@ -482,12 +496,7 @@ class DataStore(RoomMemberStore, RoomStore,
|
||||
table="users",
|
||||
term=term,
|
||||
col="name",
|
||||
retcols=[
|
||||
"name",
|
||||
"password_hash",
|
||||
"is_guest",
|
||||
"admin"
|
||||
],
|
||||
retcols=["name", "password_hash", "is_guest", "admin"],
|
||||
desc="search_users",
|
||||
)
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ try:
|
||||
MAX_TXN_ID = sys.maxint - 1
|
||||
except AttributeError:
|
||||
# python 3 does not have a maximum int value
|
||||
MAX_TXN_ID = 2**63 - 1
|
||||
MAX_TXN_ID = 2 ** 63 - 1
|
||||
|
||||
sql_logger = logging.getLogger("synapse.storage.SQL")
|
||||
transaction_logger = logging.getLogger("synapse.storage.txn")
|
||||
@@ -76,12 +76,18 @@ class LoggingTransaction(object):
|
||||
"""An object that almost-transparently proxies for the 'txn' object
|
||||
passed to the constructor. Adds logging and metrics to the .execute()
|
||||
method."""
|
||||
|
||||
__slots__ = [
|
||||
"txn", "name", "database_engine", "after_callbacks", "exception_callbacks",
|
||||
"txn",
|
||||
"name",
|
||||
"database_engine",
|
||||
"after_callbacks",
|
||||
"exception_callbacks",
|
||||
]
|
||||
|
||||
def __init__(self, txn, name, database_engine, after_callbacks,
|
||||
exception_callbacks):
|
||||
def __init__(
|
||||
self, txn, name, database_engine, after_callbacks, exception_callbacks
|
||||
):
|
||||
object.__setattr__(self, "txn", txn)
|
||||
object.__setattr__(self, "name", name)
|
||||
object.__setattr__(self, "database_engine", database_engine)
|
||||
@@ -110,6 +116,7 @@ class LoggingTransaction(object):
|
||||
def execute_batch(self, sql, args):
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
from psycopg2.extras import execute_batch
|
||||
|
||||
self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
|
||||
else:
|
||||
for val in args:
|
||||
@@ -134,10 +141,7 @@ class LoggingTransaction(object):
|
||||
sql = self.database_engine.convert_param_style(sql)
|
||||
if args:
|
||||
try:
|
||||
sql_logger.debug(
|
||||
"[SQL values] {%s} %r",
|
||||
self.name, args[0]
|
||||
)
|
||||
sql_logger.debug("[SQL values] {%s} %r", self.name, args[0])
|
||||
except Exception:
|
||||
# Don't let logging failures stop SQL from working
|
||||
pass
|
||||
@@ -145,9 +149,7 @@ class LoggingTransaction(object):
|
||||
start = time.time()
|
||||
|
||||
try:
|
||||
return func(
|
||||
sql, *args
|
||||
)
|
||||
return func(sql, *args)
|
||||
except Exception as e:
|
||||
logger.debug("[SQL FAIL] {%s} %s", self.name, e)
|
||||
raise
|
||||
@@ -176,11 +178,9 @@ class PerformanceCounters(object):
|
||||
counters = []
|
||||
for name, (count, cum_time) in iteritems(self.current_counters):
|
||||
prev_count, prev_time = self.previous_counters.get(name, (0, 0))
|
||||
counters.append((
|
||||
(cum_time - prev_time) / interval_duration,
|
||||
count - prev_count,
|
||||
name
|
||||
))
|
||||
counters.append(
|
||||
((cum_time - prev_time) / interval_duration, count - prev_count, name)
|
||||
)
|
||||
|
||||
self.previous_counters = dict(self.current_counters)
|
||||
|
||||
@@ -212,8 +212,9 @@ class SQLBaseStore(object):
|
||||
self._txn_perf_counters = PerformanceCounters()
|
||||
self._get_event_counters = PerformanceCounters()
|
||||
|
||||
self._get_event_cache = Cache("*getEvent*", keylen=3,
|
||||
max_entries=hs.config.event_cache_size)
|
||||
self._get_event_cache = Cache(
|
||||
"*getEvent*", keylen=3, max_entries=hs.config.event_cache_size
|
||||
)
|
||||
|
||||
self._event_fetch_lock = threading.Condition()
|
||||
self._event_fetch_list = []
|
||||
@@ -239,7 +240,7 @@ class SQLBaseStore(object):
|
||||
0.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
@@ -271,7 +272,7 @@ class SQLBaseStore(object):
|
||||
15.0,
|
||||
run_as_background_process,
|
||||
"upsert_safety_check",
|
||||
self._check_safe_to_upsert
|
||||
self._check_safe_to_upsert,
|
||||
)
|
||||
|
||||
def start_profiling(self):
|
||||
@@ -298,13 +299,16 @@ class SQLBaseStore(object):
|
||||
|
||||
perf_logger.info(
|
||||
"Total database time: %.3f%% {%s} {%s}",
|
||||
ratio * 100, top_three_counters, top_3_event_counters
|
||||
ratio * 100,
|
||||
top_three_counters,
|
||||
top_3_event_counters,
|
||||
)
|
||||
|
||||
self._clock.looping_call(loop, 10000)
|
||||
|
||||
def _new_transaction(self, conn, desc, after_callbacks, exception_callbacks,
|
||||
func, *args, **kwargs):
|
||||
def _new_transaction(
|
||||
self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
|
||||
):
|
||||
start = time.time()
|
||||
txn_id = self._TXN_ID
|
||||
|
||||
@@ -312,7 +316,7 @@ class SQLBaseStore(object):
|
||||
# growing really large.
|
||||
self._TXN_ID = (self._TXN_ID + 1) % (MAX_TXN_ID)
|
||||
|
||||
name = "%s-%x" % (desc, txn_id, )
|
||||
name = "%s-%x" % (desc, txn_id)
|
||||
|
||||
transaction_logger.debug("[TXN START] {%s}", name)
|
||||
|
||||
@@ -323,7 +327,10 @@ class SQLBaseStore(object):
|
||||
try:
|
||||
txn = conn.cursor()
|
||||
txn = LoggingTransaction(
|
||||
txn, name, self.database_engine, after_callbacks,
|
||||
txn,
|
||||
name,
|
||||
self.database_engine,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
)
|
||||
r = func(txn, *args, **kwargs)
|
||||
@@ -334,7 +341,10 @@ class SQLBaseStore(object):
|
||||
# transaction.
|
||||
logger.warning(
|
||||
"[TXN OPERROR] {%s} %s %d/%d",
|
||||
name, exception_to_unicode(e), i, N
|
||||
name,
|
||||
exception_to_unicode(e),
|
||||
i,
|
||||
N,
|
||||
)
|
||||
if i < N:
|
||||
i += 1
|
||||
@@ -342,8 +352,7 @@ class SQLBaseStore(object):
|
||||
conn.rollback()
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
"[TXN EROLL] {%s} %s", name, exception_to_unicode(e1)
|
||||
)
|
||||
continue
|
||||
raise
|
||||
@@ -357,7 +366,8 @@ class SQLBaseStore(object):
|
||||
except self.database_engine.module.Error as e1:
|
||||
logger.warning(
|
||||
"[TXN EROLL] {%s} %s",
|
||||
name, exception_to_unicode(e1),
|
||||
name,
|
||||
exception_to_unicode(e1),
|
||||
)
|
||||
continue
|
||||
raise
|
||||
@@ -396,16 +406,17 @@ class SQLBaseStore(object):
|
||||
exception_callbacks = []
|
||||
|
||||
if LoggingContext.current_context() == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db txn '%s' from sentinel context",
|
||||
desc,
|
||||
)
|
||||
logger.warn("Starting db txn '%s' from sentinel context", desc)
|
||||
|
||||
try:
|
||||
result = yield self.runWithConnection(
|
||||
self._new_transaction,
|
||||
desc, after_callbacks, exception_callbacks, func,
|
||||
*args, **kwargs
|
||||
desc,
|
||||
after_callbacks,
|
||||
exception_callbacks,
|
||||
func,
|
||||
*args,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
for after_callback, after_args, after_kwargs in after_callbacks:
|
||||
@@ -434,7 +445,7 @@ class SQLBaseStore(object):
|
||||
parent_context = LoggingContext.current_context()
|
||||
if parent_context == LoggingContext.sentinel:
|
||||
logger.warn(
|
||||
"Starting db connection from sentinel context: metrics will be lost",
|
||||
"Starting db connection from sentinel context: metrics will be lost"
|
||||
)
|
||||
parent_context = None
|
||||
|
||||
@@ -453,9 +464,7 @@ class SQLBaseStore(object):
|
||||
return func(conn, *args, **kwargs)
|
||||
|
||||
with PreserveLoggingContext():
|
||||
result = yield self._db_pool.runWithConnection(
|
||||
inner_func, *args, **kwargs
|
||||
)
|
||||
result = yield self._db_pool.runWithConnection(inner_func, *args, **kwargs)
|
||||
|
||||
defer.returnValue(result)
|
||||
|
||||
@@ -469,9 +478,7 @@ class SQLBaseStore(object):
|
||||
A list of dicts where the key is the column header.
|
||||
"""
|
||||
col_headers = list(intern(str(column[0])) for column in cursor.description)
|
||||
results = list(
|
||||
dict(zip(col_headers, row)) for row in cursor
|
||||
)
|
||||
results = list(dict(zip(col_headers, row)) for row in cursor)
|
||||
return results
|
||||
|
||||
def _execute(self, desc, decoder, query, *args):
|
||||
@@ -485,6 +492,7 @@ class SQLBaseStore(object):
|
||||
Returns:
|
||||
The result of decoder(results)
|
||||
"""
|
||||
|
||||
def interaction(txn):
|
||||
txn.execute(query, args)
|
||||
if decoder:
|
||||
@@ -498,8 +506,7 @@ class SQLBaseStore(object):
|
||||
# no complex WHERE clauses, just a dict of values for columns.
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_insert(self, table, values, or_ignore=False,
|
||||
desc="_simple_insert"):
|
||||
def _simple_insert(self, table, values, or_ignore=False, desc="_simple_insert"):
|
||||
"""Executes an INSERT query on the named table.
|
||||
|
||||
Args:
|
||||
@@ -511,10 +518,7 @@ class SQLBaseStore(object):
|
||||
`or_ignore` is True
|
||||
"""
|
||||
try:
|
||||
yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_insert_txn, table, values,
|
||||
)
|
||||
yield self.runInteraction(desc, self._simple_insert_txn, table, values)
|
||||
except self.database_engine.module.IntegrityError:
|
||||
# We have to do or_ignore flag at this layer, since we can't reuse
|
||||
# a cursor after we receive an error from the db.
|
||||
@@ -530,15 +534,13 @@ class SQLBaseStore(object):
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys),
|
||||
", ".join("?" for _ in keys)
|
||||
", ".join("?" for _ in keys),
|
||||
)
|
||||
|
||||
txn.execute(sql, vals)
|
||||
|
||||
def _simple_insert_many(self, table, values, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_insert_many_txn, table, values
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_insert_many_txn, table, values)
|
||||
|
||||
@staticmethod
|
||||
def _simple_insert_many_txn(txn, table, values):
|
||||
@@ -553,24 +555,18 @@ class SQLBaseStore(object):
|
||||
#
|
||||
# The sort is to ensure that we don't rely on dictionary iteration
|
||||
# order.
|
||||
keys, vals = zip(*[
|
||||
zip(
|
||||
*(sorted(i.items(), key=lambda kv: kv[0]))
|
||||
)
|
||||
for i in values
|
||||
if i
|
||||
])
|
||||
keys, vals = zip(
|
||||
*[zip(*(sorted(i.items(), key=lambda kv: kv[0]))) for i in values if i]
|
||||
)
|
||||
|
||||
for k in keys:
|
||||
if k != keys[0]:
|
||||
raise RuntimeError(
|
||||
"All items must have the same keys"
|
||||
)
|
||||
raise RuntimeError("All items must have the same keys")
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES(%s)" % (
|
||||
table,
|
||||
", ".join(k for k in keys[0]),
|
||||
", ".join("?" for _ in keys[0])
|
||||
", ".join("?" for _ in keys[0]),
|
||||
)
|
||||
|
||||
txn.executemany(sql, vals)
|
||||
@@ -583,7 +579,7 @@ class SQLBaseStore(object):
|
||||
values,
|
||||
insertion_values={},
|
||||
desc="_simple_upsert",
|
||||
lock=True
|
||||
lock=True,
|
||||
):
|
||||
"""
|
||||
|
||||
@@ -599,7 +595,7 @@ class SQLBaseStore(object):
|
||||
|
||||
Args:
|
||||
table (str): The table to upsert into
|
||||
keyvalues (dict): The unique key tables and their new values
|
||||
keyvalues (dict): The unique key columns and their new values
|
||||
values (dict): The nonunique columns and their new values
|
||||
insertion_values (dict): additional key/values to use only when
|
||||
inserting
|
||||
@@ -631,17 +627,11 @@ class SQLBaseStore(object):
|
||||
|
||||
# presumably we raced with another transaction: let's retry.
|
||||
logger.warn(
|
||||
"%s when upserting into %s; retrying: %s", e.__name__, table, e
|
||||
"IntegrityError when upserting into %s; retrying: %s", table, e
|
||||
)
|
||||
|
||||
def _simple_upsert_txn(
|
||||
self,
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values={},
|
||||
lock=True,
|
||||
self, txn, table, keyvalues, values, insertion_values={}, lock=True
|
||||
):
|
||||
"""
|
||||
Pick the UPSERT method which works best on the platform. Either the
|
||||
@@ -665,11 +655,7 @@ class SQLBaseStore(object):
|
||||
and table not in self._unsafe_to_upsert_tables
|
||||
):
|
||||
return self._simple_upsert_txn_native_upsert(
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
values,
|
||||
insertion_values=insertion_values,
|
||||
txn, table, keyvalues, values, insertion_values=insertion_values
|
||||
)
|
||||
else:
|
||||
return self._simple_upsert_txn_emulated(
|
||||
@@ -714,7 +700,7 @@ class SQLBaseStore(object):
|
||||
# SELECT instead to see if it exists.
|
||||
sql = "SELECT 1 FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(keyvalues.values())
|
||||
txn.execute(sql, sqlargs)
|
||||
@@ -726,7 +712,7 @@ class SQLBaseStore(object):
|
||||
sql = "UPDATE %s SET %s WHERE %s" % (
|
||||
table,
|
||||
", ".join("%s = ?" % (k,) for k in values),
|
||||
" AND ".join(_getwhere(k) for k in keyvalues)
|
||||
" AND ".join(_getwhere(k) for k in keyvalues),
|
||||
)
|
||||
sqlargs = list(values.values()) + list(keyvalues.values())
|
||||
|
||||
@@ -773,19 +759,14 @@ class SQLBaseStore(object):
|
||||
latter = "NOTHING"
|
||||
else:
|
||||
allvalues.update(values)
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
)
|
||||
latter = "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
|
||||
|
||||
sql = (
|
||||
"INSERT INTO %s (%s) VALUES (%s) "
|
||||
"ON CONFLICT (%s) DO %s"
|
||||
) % (
|
||||
sql = ("INSERT INTO %s (%s) VALUES (%s) " "ON CONFLICT (%s) DO %s") % (
|
||||
table,
|
||||
", ".join(k for k in allvalues),
|
||||
", ".join("?" for _ in allvalues),
|
||||
", ".join(k for k in keyvalues),
|
||||
latter
|
||||
latter,
|
||||
)
|
||||
txn.execute(sql, list(allvalues.values()))
|
||||
|
||||
@@ -870,8 +851,8 @@ class SQLBaseStore(object):
|
||||
latter = "NOTHING"
|
||||
value_values = [() for x in range(len(key_values))]
|
||||
else:
|
||||
latter = (
|
||||
"UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in value_names)
|
||||
latter = "UPDATE SET " + ", ".join(
|
||||
k + "=EXCLUDED." + k for k in value_names
|
||||
)
|
||||
|
||||
sql = "INSERT INTO %s (%s) VALUES (%s) ON CONFLICT (%s) DO %s" % (
|
||||
@@ -889,8 +870,9 @@ class SQLBaseStore(object):
|
||||
|
||||
return txn.execute_batch(sql, args)
|
||||
|
||||
def _simple_select_one(self, table, keyvalues, retcols,
|
||||
allow_none=False, desc="_simple_select_one"):
|
||||
def _simple_select_one(
|
||||
self, table, keyvalues, retcols, allow_none=False, desc="_simple_select_one"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning multiple columns from it.
|
||||
|
||||
@@ -903,14 +885,17 @@ class SQLBaseStore(object):
|
||||
statement returns no rows
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_txn,
|
||||
table, keyvalues, retcols, allow_none,
|
||||
desc, self._simple_select_one_txn, table, keyvalues, retcols, allow_none
|
||||
)
|
||||
|
||||
def _simple_select_one_onecol(self, table, keyvalues, retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol"):
|
||||
def _simple_select_one_onecol(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=False,
|
||||
desc="_simple_select_one_onecol",
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which is expected to
|
||||
return a single row, returning a single column from it.
|
||||
|
||||
@@ -922,17 +907,18 @@ class SQLBaseStore(object):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_one_onecol_txn,
|
||||
table, keyvalues, retcol, allow_none=allow_none,
|
||||
table,
|
||||
keyvalues,
|
||||
retcol,
|
||||
allow_none=allow_none,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_one_onecol_txn(cls, txn, table, keyvalues, retcol,
|
||||
allow_none=False):
|
||||
def _simple_select_one_onecol_txn(
|
||||
cls, txn, table, keyvalues, retcol, allow_none=False
|
||||
):
|
||||
ret = cls._simple_select_onecol_txn(
|
||||
txn,
|
||||
table=table,
|
||||
keyvalues=keyvalues,
|
||||
retcol=retcol,
|
||||
txn, table=table, keyvalues=keyvalues, retcol=retcol
|
||||
)
|
||||
|
||||
if ret:
|
||||
@@ -945,12 +931,7 @@ class SQLBaseStore(object):
|
||||
|
||||
@staticmethod
|
||||
def _simple_select_onecol_txn(txn, table, keyvalues, retcol):
|
||||
sql = (
|
||||
"SELECT %(retcol)s FROM %(table)s"
|
||||
) % {
|
||||
"retcol": retcol,
|
||||
"table": table,
|
||||
}
|
||||
sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
|
||||
|
||||
if keyvalues:
|
||||
sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
|
||||
@@ -960,8 +941,9 @@ class SQLBaseStore(object):
|
||||
|
||||
return [r[0] for r in txn]
|
||||
|
||||
def _simple_select_onecol(self, table, keyvalues, retcol,
|
||||
desc="_simple_select_onecol"):
|
||||
def _simple_select_onecol(
|
||||
self, table, keyvalues, retcol, desc="_simple_select_onecol"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which returns a list
|
||||
comprising of the values of the named column from the selected rows.
|
||||
|
||||
@@ -974,13 +956,12 @@ class SQLBaseStore(object):
|
||||
Deferred: Results in a list
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_onecol_txn,
|
||||
table, keyvalues, retcol
|
||||
desc, self._simple_select_onecol_txn, table, keyvalues, retcol
|
||||
)
|
||||
|
||||
def _simple_select_list(self, table, keyvalues, retcols,
|
||||
desc="_simple_select_list"):
|
||||
def _simple_select_list(
|
||||
self, table, keyvalues, retcols, desc="_simple_select_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@@ -994,9 +975,7 @@ class SQLBaseStore(object):
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_txn,
|
||||
table, keyvalues, retcols
|
||||
desc, self._simple_select_list_txn, table, keyvalues, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1016,22 +995,26 @@ class SQLBaseStore(object):
|
||||
sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
else:
|
||||
sql = "SELECT %s FROM %s" % (
|
||||
", ".join(retcols),
|
||||
table
|
||||
)
|
||||
sql = "SELECT %s FROM %s" % (", ".join(retcols), table)
|
||||
txn.execute(sql)
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _simple_select_many_batch(self, table, column, iterable, retcols,
|
||||
keyvalues={}, desc="_simple_select_many_batch",
|
||||
batch_size=100):
|
||||
def _simple_select_many_batch(
|
||||
self,
|
||||
table,
|
||||
column,
|
||||
iterable,
|
||||
retcols,
|
||||
keyvalues={},
|
||||
desc="_simple_select_many_batch",
|
||||
batch_size=100,
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@@ -1053,14 +1036,17 @@ class SQLBaseStore(object):
|
||||
it_list = list(iterable)
|
||||
|
||||
chunks = [
|
||||
it_list[i:i + batch_size]
|
||||
for i in range(0, len(it_list), batch_size)
|
||||
it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
|
||||
]
|
||||
for chunk in chunks:
|
||||
rows = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_many_txn,
|
||||
table, column, chunk, keyvalues, retcols
|
||||
table,
|
||||
column,
|
||||
chunk,
|
||||
keyvalues,
|
||||
retcols,
|
||||
)
|
||||
|
||||
results.extend(rows)
|
||||
@@ -1089,9 +1075,7 @@ class SQLBaseStore(object):
|
||||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
@@ -1099,19 +1083,14 @@ class SQLBaseStore(object):
|
||||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
|
||||
txn.execute(sql, values)
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
def _simple_update(self, table, keyvalues, updatevalues, desc):
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -1127,15 +1106,13 @@ class SQLBaseStore(object):
|
||||
where,
|
||||
)
|
||||
|
||||
txn.execute(
|
||||
update_sql,
|
||||
list(updatevalues.values()) + list(keyvalues.values())
|
||||
)
|
||||
txn.execute(update_sql, list(updatevalues.values()) + list(keyvalues.values()))
|
||||
|
||||
return txn.rowcount
|
||||
|
||||
def _simple_update_one(self, table, keyvalues, updatevalues,
|
||||
desc="_simple_update_one"):
|
||||
def _simple_update_one(
|
||||
self, table, keyvalues, updatevalues, desc="_simple_update_one"
|
||||
):
|
||||
"""Executes an UPDATE query on the named table, setting new values for
|
||||
columns in a row matching the key values.
|
||||
|
||||
@@ -1154,9 +1131,7 @@ class SQLBaseStore(object):
|
||||
the update column in the 'keyvalues' dict as well.
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_update_one_txn,
|
||||
table, keyvalues, updatevalues,
|
||||
desc, self._simple_update_one_txn, table, keyvalues, updatevalues
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1169,12 +1144,11 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
@staticmethod
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols,
|
||||
allow_none=False):
|
||||
def _simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
|
||||
select_sql = "SELECT %s FROM %s WHERE %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(select_sql, list(keyvalues.values()))
|
||||
@@ -1197,9 +1171,7 @@ class SQLBaseStore(object):
|
||||
table : string giving the table name
|
||||
keyvalues : dict of column names and values to select the row with
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_one_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_one_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_one_txn(txn, table, keyvalues):
|
||||
@@ -1212,7 +1184,7 @@ class SQLBaseStore(object):
|
||||
"""
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
txn.execute(sql, list(keyvalues.values()))
|
||||
@@ -1222,15 +1194,13 @@ class SQLBaseStore(object):
|
||||
raise StoreError(500, "More than one row matched (%s)" % (table,))
|
||||
|
||||
def _simple_delete(self, table, keyvalues, desc):
|
||||
return self.runInteraction(
|
||||
desc, self._simple_delete_txn, table, keyvalues
|
||||
)
|
||||
return self.runInteraction(desc, self._simple_delete_txn, table, keyvalues)
|
||||
|
||||
@staticmethod
|
||||
def _simple_delete_txn(txn, table, keyvalues):
|
||||
sql = "DELETE FROM %s WHERE %s" % (
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k, ) for k in keyvalues)
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
)
|
||||
|
||||
return txn.execute(sql, list(keyvalues.values()))
|
||||
@@ -1260,9 +1230,7 @@ class SQLBaseStore(object):
|
||||
|
||||
clauses = []
|
||||
values = []
|
||||
clauses.append(
|
||||
"%s IN (%s)" % (column, ",".join("?" for _ in iterable))
|
||||
)
|
||||
clauses.append("%s IN (%s)" % (column, ",".join("?" for _ in iterable)))
|
||||
values.extend(iterable)
|
||||
|
||||
for key, value in iteritems(keyvalues):
|
||||
@@ -1270,14 +1238,12 @@ class SQLBaseStore(object):
|
||||
values.append(value)
|
||||
|
||||
if clauses:
|
||||
sql = "%s WHERE %s" % (
|
||||
sql,
|
||||
" AND ".join(clauses),
|
||||
)
|
||||
sql = "%s WHERE %s" % (sql, " AND ".join(clauses))
|
||||
return txn.execute(sql, values)
|
||||
|
||||
def _get_cache_dict(self, db_conn, table, entity_column, stream_column,
|
||||
max_value, limit=100000):
|
||||
def _get_cache_dict(
|
||||
self, db_conn, table, entity_column, stream_column, max_value, limit=100000
|
||||
):
|
||||
# Fetch a mapping of room_id -> max stream position for "recent" rooms.
|
||||
# It doesn't really matter how many we get, the StreamChangeCache will
|
||||
# do the right thing to ensure it respects the max size of cache.
|
||||
@@ -1297,10 +1263,7 @@ class SQLBaseStore(object):
|
||||
txn = db_conn.cursor()
|
||||
txn.execute(sql, (int(max_value),))
|
||||
|
||||
cache = {
|
||||
row[0]: int(row[1])
|
||||
for row in txn
|
||||
}
|
||||
cache = {row[0]: int(row[1]) for row in txn}
|
||||
|
||||
txn.close()
|
||||
|
||||
@@ -1342,9 +1305,7 @@ class SQLBaseStore(object):
|
||||
# be safe.
|
||||
for chunk in batch_iter(members_changed, 50):
|
||||
keys = itertools.chain([room_id], chunk)
|
||||
self._send_invalidation_to_replication(
|
||||
txn, _CURRENT_STATE_CACHE_NAME, keys,
|
||||
)
|
||||
self._send_invalidation_to_replication(txn, _CURRENT_STATE_CACHE_NAME, keys)
|
||||
|
||||
def _invalidate_state_caches(self, room_id, members_changed):
|
||||
"""Invalidates caches that are based on the current state, but does
|
||||
@@ -1355,28 +1316,13 @@ class SQLBaseStore(object):
|
||||
members_changed (iterable[str]): The user_ids of members that have
|
||||
changed
|
||||
"""
|
||||
for member in members_changed:
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_rooms_for_user_with_stream_ordering", (member,),
|
||||
)
|
||||
|
||||
for host in set(get_domain_from_id(u) for u in members_changed):
|
||||
self._attempt_to_invalidate_cache(
|
||||
"is_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"was_host_joined", (room_id, host,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
self._attempt_to_invalidate_cache("was_host_joined", (room_id, host))
|
||||
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_room_summary", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_current_state_ids", (room_id,),
|
||||
)
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
|
||||
|
||||
def _attempt_to_invalidate_cache(self, cache_name, key):
|
||||
"""Attempts to invalidate the cache of the given name, ignoring if the
|
||||
@@ -1424,7 +1370,7 @@ class SQLBaseStore(object):
|
||||
"cache_func": cache_name,
|
||||
"keys": list(keys),
|
||||
"invalidation_ts": self.clock.time_msec(),
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
def get_all_updated_caches(self, last_id, current_id, limit):
|
||||
@@ -1440,11 +1386,10 @@ class SQLBaseStore(object):
|
||||
" FROM cache_invalidation_stream"
|
||||
" WHERE stream_id > ? ORDER BY stream_id ASC LIMIT ?"
|
||||
)
|
||||
txn.execute(sql, (last_id, limit,))
|
||||
txn.execute(sql, (last_id, limit))
|
||||
return txn.fetchall()
|
||||
return self.runInteraction(
|
||||
"get_all_updated_caches", get_all_updated_caches_txn
|
||||
)
|
||||
|
||||
return self.runInteraction("get_all_updated_caches", get_all_updated_caches_txn)
|
||||
|
||||
def get_cache_stream_token(self):
|
||||
if self._cache_id_gen:
|
||||
@@ -1452,33 +1397,61 @@ class SQLBaseStore(object):
|
||||
else:
|
||||
return 0
|
||||
|
||||
def _simple_select_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="_simple_select_list_paginate"):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
def _simple_select_list_paginate(
|
||||
self,
|
||||
table,
|
||||
keyvalues,
|
||||
orderby,
|
||||
start,
|
||||
limit,
|
||||
retcols,
|
||||
order_direction="ASC",
|
||||
desc="_simple_select_list_paginate",
|
||||
):
|
||||
"""
|
||||
Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
|
||||
Args:
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, Any] | None):
|
||||
keyvalues (dict[str, T] | None):
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
orderby (str): Column to order the results by.
|
||||
start (int): Index to begin the query at.
|
||||
limit (int): Number of results to return.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
order (str): order the select by this column
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to reterive
|
||||
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
"""
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
table,
|
||||
keyvalues,
|
||||
orderby,
|
||||
start,
|
||||
limit,
|
||||
retcols,
|
||||
order_direction=order_direction,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _simple_select_list_paginate_txn(cls, txn, table, keyvalues, pagevalues, retcols):
|
||||
"""Executes a SELECT query on the named table with start and limit,
|
||||
def _simple_select_list_paginate_txn(
|
||||
cls,
|
||||
txn,
|
||||
table,
|
||||
keyvalues,
|
||||
orderby,
|
||||
start,
|
||||
limit,
|
||||
retcols,
|
||||
order_direction="ASC",
|
||||
):
|
||||
"""
|
||||
Executes a SELECT query on the named table with start and limit,
|
||||
of row numbers, which may return zero or number of rows from start to limit,
|
||||
returning the result as a list of dicts.
|
||||
|
||||
@@ -1488,67 +1461,33 @@ class SQLBaseStore(object):
|
||||
keyvalues (dict[str, T] | None):
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
pagevalues ([]):
|
||||
order (str): order the select by this column
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to reterive
|
||||
orderby (str): Column to order the results by.
|
||||
start (int): Index to begin the query at.
|
||||
limit (int): Number of results to return.
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
order_direction (str): Whether the results should be ordered "ASC" or "DESC".
|
||||
Returns:
|
||||
defer.Deferred: resolves to list[dict[str, Any]]
|
||||
|
||||
"""
|
||||
if order_direction not in ["ASC", "DESC"]:
|
||||
raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
|
||||
|
||||
if keyvalues:
|
||||
sql = "SELECT %s FROM %s WHERE %s ORDER BY %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" AND ".join("%s = ?" % (k,) for k in keyvalues),
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()) + list(pagevalues))
|
||||
where_clause = "WHERE " + " AND ".join("%s = ?" % (k,) for k in keyvalues)
|
||||
else:
|
||||
sql = "SELECT %s FROM %s ORDER BY %s" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
" ? ASC LIMIT ? OFFSET ?"
|
||||
)
|
||||
txn.execute(sql, pagevalues)
|
||||
where_clause = ""
|
||||
|
||||
sql = "SELECT %s FROM %s %s ORDER BY %s %s LIMIT ? OFFSET ?" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
where_clause,
|
||||
orderby,
|
||||
order_direction,
|
||||
)
|
||||
txn.execute(sql, list(keyvalues.values()) + [limit, start])
|
||||
|
||||
return cls.cursor_to_dict(txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_user_list_paginate(self, table, keyvalues, pagevalues, retcols,
|
||||
desc="get_user_list_paginate"):
|
||||
"""Get a list of users from start row to a limit number of rows. This will
|
||||
return a json object with users and total number of users in users list.
|
||||
|
||||
Args:
|
||||
table (str): the table name
|
||||
keyvalues (dict[str, Any] | None):
|
||||
column names and values to select the rows with, or None to not
|
||||
apply a WHERE clause.
|
||||
pagevalues ([]):
|
||||
order (str): order the select by this column
|
||||
start (int): start number to begin the query from
|
||||
limit (int): number of rows to reterive
|
||||
retcols (iterable[str]): the names of the columns to return
|
||||
Returns:
|
||||
defer.Deferred: resolves to json object {list[dict[str, Any]], count}
|
||||
"""
|
||||
users = yield self.runInteraction(
|
||||
desc,
|
||||
self._simple_select_list_paginate_txn,
|
||||
table, keyvalues, pagevalues, retcols
|
||||
)
|
||||
count = yield self.runInteraction(
|
||||
desc,
|
||||
self.get_user_count_txn
|
||||
)
|
||||
retval = {
|
||||
"users": users,
|
||||
"total": count
|
||||
}
|
||||
defer.returnValue(retval)
|
||||
|
||||
def get_user_count_txn(self, txn):
|
||||
"""Get a total number of registered users in the users list.
|
||||
|
||||
@@ -1561,8 +1500,9 @@ class SQLBaseStore(object):
|
||||
txn.execute(sql_count)
|
||||
return txn.fetchone()[0]
|
||||
|
||||
def _simple_search_list(self, table, term, col, retcols,
|
||||
desc="_simple_search_list"):
|
||||
def _simple_search_list(
|
||||
self, table, term, col, retcols, desc="_simple_search_list"
|
||||
):
|
||||
"""Executes a SELECT query on the named table, which may return zero or
|
||||
more rows, returning the result as a list of dicts.
|
||||
|
||||
@@ -1577,9 +1517,7 @@ class SQLBaseStore(object):
|
||||
"""
|
||||
|
||||
return self.runInteraction(
|
||||
desc,
|
||||
self._simple_search_list_txn,
|
||||
table, term, col, retcols
|
||||
desc, self._simple_search_list_txn, table, term, col, retcols
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -1598,11 +1536,7 @@ class SQLBaseStore(object):
|
||||
defer.Deferred: resolves to list[dict[str, Any]] or None
|
||||
"""
|
||||
if term:
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (
|
||||
", ".join(retcols),
|
||||
table,
|
||||
col
|
||||
)
|
||||
sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
|
||||
termvalues = ["%%" + term + "%%"]
|
||||
txn.execute(sql, termvalues)
|
||||
else:
|
||||
@@ -1623,6 +1557,7 @@ class _RollbackButIsFineException(Exception):
|
||||
""" This exception is used to rollback a transaction without implying
|
||||
something went wrong.
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -41,7 +41,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
account_max = self.get_max_account_data_stream_id()
|
||||
self._account_data_stream_cache = StreamChangeCache(
|
||||
"AccountDataAndTagsChangeCache", account_max,
|
||||
"AccountDataAndTagsChangeCache", account_max
|
||||
)
|
||||
|
||||
super(AccountDataWorkerStore, self).__init__(db_conn, hs)
|
||||
@@ -68,8 +68,10 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
def get_account_data_for_user_txn(txn):
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "account_data", {"user_id": user_id},
|
||||
["account_data_type", "content"]
|
||||
txn,
|
||||
"account_data",
|
||||
{"user_id": user_id},
|
||||
["account_data_type", "content"],
|
||||
)
|
||||
|
||||
global_account_data = {
|
||||
@@ -77,8 +79,10 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
}
|
||||
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "room_account_data", {"user_id": user_id},
|
||||
["room_id", "account_data_type", "content"]
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id},
|
||||
["room_id", "account_data_type", "content"],
|
||||
)
|
||||
|
||||
by_room = {}
|
||||
@@ -100,10 +104,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"""
|
||||
result = yield self._simple_select_one_onecol(
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"account_data_type": data_type,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "account_data_type": data_type},
|
||||
retcol="content",
|
||||
desc="get_global_account_data_by_type_for_user",
|
||||
allow_none=True,
|
||||
@@ -124,10 +125,13 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
Returns:
|
||||
A deferred dict of the room account_data
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_txn(txn):
|
||||
rows = self._simple_select_list_txn(
|
||||
txn, "room_account_data", {"user_id": user_id, "room_id": room_id},
|
||||
["account_data_type", "content"]
|
||||
txn,
|
||||
"room_account_data",
|
||||
{"user_id": user_id, "room_id": room_id},
|
||||
["account_data_type", "content"],
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -150,6 +154,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
A deferred of the room account_data for that type, or None if
|
||||
there isn't any set.
|
||||
"""
|
||||
|
||||
def get_account_data_for_room_and_type_txn(txn):
|
||||
content_json = self._simple_select_one_onecol_txn(
|
||||
txn,
|
||||
@@ -160,18 +165,18 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
retcol="content",
|
||||
allow_none=True
|
||||
allow_none=True,
|
||||
)
|
||||
|
||||
return json.loads(content_json) if content_json else None
|
||||
|
||||
return self.runInteraction(
|
||||
"get_account_data_for_room_and_type",
|
||||
get_account_data_for_room_and_type_txn,
|
||||
"get_account_data_for_room_and_type", get_account_data_for_room_and_type_txn
|
||||
)
|
||||
|
||||
def get_all_updated_account_data(self, last_global_id, last_room_id,
|
||||
current_id, limit):
|
||||
def get_all_updated_account_data(
|
||||
self, last_global_id, last_room_id, current_id, limit
|
||||
):
|
||||
"""Get all the client account_data that has changed on the server
|
||||
Args:
|
||||
last_global_id(int): The position to fetch from for top level data
|
||||
@@ -201,6 +206,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
txn.execute(sql, (last_room_id, current_id, limit))
|
||||
room_results = txn.fetchall()
|
||||
return (global_results, room_results)
|
||||
|
||||
return self.runInteraction(
|
||||
"get_all_updated_account_data_txn", get_updated_account_data_txn
|
||||
)
|
||||
@@ -224,9 +230,7 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
|
||||
txn.execute(sql, (user_id, stream_id))
|
||||
|
||||
global_account_data = {
|
||||
row[0]: json.loads(row[1]) for row in txn
|
||||
}
|
||||
global_account_data = {row[0]: json.loads(row[1]) for row in txn}
|
||||
|
||||
sql = (
|
||||
"SELECT room_id, account_data_type, content FROM room_account_data"
|
||||
@@ -255,7 +259,8 @@ class AccountDataWorkerStore(SQLBaseStore):
|
||||
@cachedInlineCallbacks(num_args=2, cache_context=True, max_entries=5000)
|
||||
def is_ignored_by(self, ignored_user_id, ignorer_user_id, cache_context):
|
||||
ignored_account_data = yield self.get_global_account_data_by_type_for_user(
|
||||
"m.ignored_user_list", ignorer_user_id,
|
||||
"m.ignored_user_list",
|
||||
ignorer_user_id,
|
||||
on_invalidate=cache_context.invalidate,
|
||||
)
|
||||
if not ignored_account_data:
|
||||
@@ -307,10 +312,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
"room_id": room_id,
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
},
|
||||
values={"stream_id": next_id, "content": content_json},
|
||||
lock=False,
|
||||
)
|
||||
|
||||
@@ -324,9 +326,9 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self.get_account_data_for_room.invalidate((user_id, room_id,))
|
||||
self.get_account_data_for_room.invalidate((user_id, room_id))
|
||||
self.get_account_data_for_room_and_type.prefill(
|
||||
(user_id, room_id, account_data_type,), content,
|
||||
(user_id, room_id, account_data_type), content
|
||||
)
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
@@ -351,14 +353,8 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
yield self._simple_upsert(
|
||||
desc="add_user_account_data",
|
||||
table="account_data",
|
||||
keyvalues={
|
||||
"user_id": user_id,
|
||||
"account_data_type": account_data_type,
|
||||
},
|
||||
values={
|
||||
"stream_id": next_id,
|
||||
"content": content_json,
|
||||
},
|
||||
keyvalues={"user_id": user_id, "account_data_type": account_data_type},
|
||||
values={"stream_id": next_id, "content": content_json},
|
||||
lock=False,
|
||||
)
|
||||
|
||||
@@ -370,12 +366,10 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
# transaction.
|
||||
yield self._update_max_stream_id(next_id)
|
||||
|
||||
self._account_data_stream_cache.entity_has_changed(
|
||||
user_id, next_id,
|
||||
)
|
||||
self._account_data_stream_cache.entity_has_changed(user_id, next_id)
|
||||
self.get_account_data_for_user.invalidate((user_id,))
|
||||
self.get_global_account_data_by_type_for_user.invalidate(
|
||||
(account_data_type, user_id,)
|
||||
(account_data_type, user_id)
|
||||
)
|
||||
|
||||
result = self._account_data_id_gen.get_current_token()
|
||||
@@ -387,6 +381,7 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
Args:
|
||||
next_id(int): The the revision to advance to.
|
||||
"""
|
||||
|
||||
def _update(txn):
|
||||
update_max_id_sql = (
|
||||
"UPDATE account_data_max_stream_id"
|
||||
@@ -394,7 +389,5 @@ class AccountDataStore(AccountDataWorkerStore):
|
||||
" WHERE stream_id < ?"
|
||||
)
|
||||
txn.execute(update_max_id_sql, (next_id, next_id))
|
||||
return self.runInteraction(
|
||||
"update_account_data_max_stream_id",
|
||||
_update,
|
||||
)
|
||||
|
||||
return self.runInteraction("update_account_data_max_stream_id", _update)
|
||||
|
||||
@@ -51,8 +51,7 @@ def _make_exclusive_regex(services_cache):
|
||||
class ApplicationServiceWorkerStore(SQLBaseStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
self.services_cache = load_appservices(
|
||||
hs.hostname,
|
||||
hs.config.app_service_config_files
|
||||
hs.hostname, hs.config.app_service_config_files
|
||||
)
|
||||
self.exclusive_user_regex = _make_exclusive_regex(self.services_cache)
|
||||
|
||||
@@ -122,8 +121,9 @@ class ApplicationServiceStore(ApplicationServiceWorkerStore):
|
||||
pass
|
||||
|
||||
|
||||
class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
EventsWorkerStore):
|
||||
class ApplicationServiceTransactionWorkerStore(
|
||||
ApplicationServiceWorkerStore, EventsWorkerStore
|
||||
):
|
||||
@defer.inlineCallbacks
|
||||
def get_appservices_by_state(self, state):
|
||||
"""Get a list of application services based on their state.
|
||||
@@ -135,9 +135,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
may be empty.
|
||||
"""
|
||||
results = yield self._simple_select_list(
|
||||
"application_services_state",
|
||||
dict(state=state),
|
||||
["as_id"]
|
||||
"application_services_state", dict(state=state), ["as_id"]
|
||||
)
|
||||
# NB: This assumes this class is linked with ApplicationServiceStore
|
||||
as_list = self.get_app_services()
|
||||
@@ -180,9 +178,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
A Deferred which resolves when the state was set successfully.
|
||||
"""
|
||||
return self._simple_upsert(
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
dict(state=state)
|
||||
"application_services_state", dict(as_id=service.id), dict(state=state)
|
||||
)
|
||||
|
||||
def create_appservice_txn(self, service, events):
|
||||
@@ -195,6 +191,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
Returns:
|
||||
AppServiceTransaction: A new transaction.
|
||||
"""
|
||||
|
||||
def _create_appservice_txn(txn):
|
||||
# work out new txn id (highest txn id for this service += 1)
|
||||
# The highest id may be the last one sent (in which case it is last_txn)
|
||||
@@ -204,7 +201,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
|
||||
txn.execute(
|
||||
"SELECT MAX(txn_id) FROM application_services_txns WHERE as_id=?",
|
||||
(service.id,)
|
||||
(service.id,),
|
||||
)
|
||||
highest_txn_id = txn.fetchone()[0]
|
||||
if highest_txn_id is None:
|
||||
@@ -217,16 +214,11 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
txn.execute(
|
||||
"INSERT INTO application_services_txns(as_id, txn_id, event_ids) "
|
||||
"VALUES(?,?,?)",
|
||||
(service.id, new_txn_id, event_ids)
|
||||
)
|
||||
return AppServiceTransaction(
|
||||
service=service, id=new_txn_id, events=events
|
||||
(service.id, new_txn_id, event_ids),
|
||||
)
|
||||
return AppServiceTransaction(service=service, id=new_txn_id, events=events)
|
||||
|
||||
return self.runInteraction(
|
||||
"create_appservice_txn",
|
||||
_create_appservice_txn,
|
||||
)
|
||||
return self.runInteraction("create_appservice_txn", _create_appservice_txn)
|
||||
|
||||
def complete_appservice_txn(self, txn_id, service):
|
||||
"""Completes an application service transaction.
|
||||
@@ -252,26 +244,26 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
"appservice: Completing a transaction which has an ID > 1 from "
|
||||
"the last ID sent to this AS. We've either dropped events or "
|
||||
"sent it to the AS out of order. FIX ME. last_txn=%s "
|
||||
"completing_txn=%s service_id=%s", last_txn_id, txn_id,
|
||||
service.id
|
||||
"completing_txn=%s service_id=%s",
|
||||
last_txn_id,
|
||||
txn_id,
|
||||
service.id,
|
||||
)
|
||||
|
||||
# Set current txn_id for AS to 'txn_id'
|
||||
self._simple_upsert_txn(
|
||||
txn, "application_services_state", dict(as_id=service.id),
|
||||
dict(last_txn=txn_id)
|
||||
txn,
|
||||
"application_services_state",
|
||||
dict(as_id=service.id),
|
||||
dict(last_txn=txn_id),
|
||||
)
|
||||
|
||||
# Delete txn
|
||||
self._simple_delete_txn(
|
||||
txn, "application_services_txns",
|
||||
dict(txn_id=txn_id, as_id=service.id)
|
||||
txn, "application_services_txns", dict(txn_id=txn_id, as_id=service.id)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"complete_appservice_txn",
|
||||
_complete_appservice_txn,
|
||||
)
|
||||
return self.runInteraction("complete_appservice_txn", _complete_appservice_txn)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def get_oldest_unsent_txn(self, service):
|
||||
@@ -284,13 +276,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
A Deferred which resolves to an AppServiceTransaction or
|
||||
None.
|
||||
"""
|
||||
|
||||
def _get_oldest_unsent_txn(txn):
|
||||
# Monotonically increasing txn ids, so just select the smallest
|
||||
# one in the txns table (we delete them when they are sent)
|
||||
txn.execute(
|
||||
"SELECT * FROM application_services_txns WHERE as_id=?"
|
||||
" ORDER BY txn_id ASC LIMIT 1",
|
||||
(service.id,)
|
||||
(service.id,),
|
||||
)
|
||||
rows = self.cursor_to_dict(txn)
|
||||
if not rows:
|
||||
@@ -301,8 +294,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
return entry
|
||||
|
||||
entry = yield self.runInteraction(
|
||||
"get_oldest_unsent_appservice_txn",
|
||||
_get_oldest_unsent_txn,
|
||||
"get_oldest_unsent_appservice_txn", _get_oldest_unsent_txn
|
||||
)
|
||||
|
||||
if not entry:
|
||||
@@ -312,14 +304,14 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
defer.returnValue(AppServiceTransaction(
|
||||
service=service, id=entry["txn_id"], events=events
|
||||
))
|
||||
defer.returnValue(
|
||||
AppServiceTransaction(service=service, id=entry["txn_id"], events=events)
|
||||
)
|
||||
|
||||
def _get_last_txn(self, txn, service_id):
|
||||
txn.execute(
|
||||
"SELECT last_txn FROM application_services_state WHERE as_id=?",
|
||||
(service_id,)
|
||||
(service_id,),
|
||||
)
|
||||
last_txn_id = txn.fetchone()
|
||||
if last_txn_id is None or last_txn_id[0] is None: # no row exists
|
||||
@@ -332,6 +324,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
txn.execute(
|
||||
"UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
|
||||
)
|
||||
|
||||
return self.runInteraction(
|
||||
"set_appservice_last_pos", set_appservice_last_pos_txn
|
||||
)
|
||||
@@ -362,7 +355,7 @@ class ApplicationServiceTransactionWorkerStore(ApplicationServiceWorkerStore,
|
||||
return upper_bound, [row[1] for row in rows]
|
||||
|
||||
upper_bound, event_ids = yield self.runInteraction(
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn,
|
||||
"get_new_events_for_appservice", get_new_events_for_appservice_txn
|
||||
)
|
||||
|
||||
events = yield self._get_events(event_ids)
|
||||
|
||||
@@ -94,16 +94,13 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
self._all_done = False
|
||||
|
||||
def start_doing_background_updates(self):
|
||||
run_as_background_process(
|
||||
"background_updates", self._run_background_updates,
|
||||
)
|
||||
run_as_background_process("background_updates", self._run_background_updates)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _run_background_updates(self):
|
||||
logger.info("Starting background schema updates")
|
||||
while True:
|
||||
yield self.hs.get_clock().sleep(
|
||||
self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.)
|
||||
yield self.hs.get_clock().sleep(self.BACKGROUND_UPDATE_INTERVAL_MS / 1000.0)
|
||||
|
||||
try:
|
||||
result = yield self.do_next_background_update(
|
||||
@@ -187,8 +184,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def _do_background_update(self, update_name, desired_duration_ms):
|
||||
logger.info("Starting update batch on background update '%s'",
|
||||
update_name)
|
||||
logger.info("Starting update batch on background update '%s'", update_name)
|
||||
|
||||
update_handler = self._background_update_handlers[update_name]
|
||||
|
||||
@@ -210,7 +206,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
progress_json = yield self._simple_select_one_onecol(
|
||||
"background_updates",
|
||||
keyvalues={"update_name": update_name},
|
||||
retcol="progress_json"
|
||||
retcol="progress_json",
|
||||
)
|
||||
|
||||
progress = json.loads(progress_json)
|
||||
@@ -224,7 +220,9 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
logger.info(
|
||||
"Updating %r. Updated %r items in %rms."
|
||||
" (total_rate=%r/ms, current_rate=%r/ms, total_updated=%r, batch_size=%r)",
|
||||
update_name, items_updated, duration_ms,
|
||||
update_name,
|
||||
items_updated,
|
||||
duration_ms,
|
||||
performance.total_items_per_ms(),
|
||||
performance.average_items_per_ms(),
|
||||
performance.total_item_count,
|
||||
@@ -264,6 +262,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
Args:
|
||||
update_name (str): Name of update
|
||||
"""
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def noop_update(progress, batch_size):
|
||||
yield self._end_background_update(update_name)
|
||||
@@ -271,10 +270,16 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
self.register_background_update_handler(update_name, noop_update)
|
||||
|
||||
def register_background_index_update(self, update_name, index_name,
|
||||
table, columns, where_clause=None,
|
||||
unique=False,
|
||||
psql_only=False):
|
||||
def register_background_index_update(
|
||||
self,
|
||||
update_name,
|
||||
index_name,
|
||||
table,
|
||||
columns,
|
||||
where_clause=None,
|
||||
unique=False,
|
||||
psql_only=False,
|
||||
):
|
||||
"""Helper for store classes to do a background index addition
|
||||
|
||||
To use:
|
||||
@@ -320,7 +325,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
"name": index_name,
|
||||
"table": table,
|
||||
"columns": ", ".join(columns),
|
||||
"where_clause": "WHERE " + where_clause if where_clause else ""
|
||||
"where_clause": "WHERE " + where_clause if where_clause else "",
|
||||
}
|
||||
logger.debug("[SQL] %s", sql)
|
||||
c.execute(sql)
|
||||
@@ -387,7 +392,7 @@ class BackgroundUpdateStore(SQLBaseStore):
|
||||
|
||||
return self._simple_insert(
|
||||
"background_updates",
|
||||
{"update_name": update_name, "progress_json": progress_json}
|
||||
{"update_name": update_name, "progress_json": progress_json},
|
||||
)
|
||||
|
||||
def _end_background_update(self, update_name):
|
||||
|
||||
@@ -37,9 +37,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def __init__(self, db_conn, hs):
|
||||
|
||||
self.client_ip_last_seen = Cache(
|
||||
name="client_ip_last_seen",
|
||||
keylen=4,
|
||||
max_entries=50000 * CACHE_SIZE_FACTOR,
|
||||
name="client_ip_last_seen", keylen=4, max_entries=50000 * CACHE_SIZE_FACTOR
|
||||
)
|
||||
|
||||
super(ClientIpStore, self).__init__(db_conn, hs)
|
||||
@@ -66,13 +64,11 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
)
|
||||
|
||||
self.register_background_update_handler(
|
||||
"user_ips_analyze",
|
||||
self._analyze_user_ip,
|
||||
"user_ips_analyze", self._analyze_user_ip
|
||||
)
|
||||
|
||||
self.register_background_update_handler(
|
||||
"user_ips_remove_dupes",
|
||||
self._remove_user_ip_dupes,
|
||||
"user_ips_remove_dupes", self._remove_user_ip_dupes
|
||||
)
|
||||
|
||||
# Register a unique index
|
||||
@@ -86,8 +82,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
# Drop the old non-unique index
|
||||
self.register_background_update_handler(
|
||||
"user_ips_drop_nonunique_index",
|
||||
self._remove_user_ip_nonunique,
|
||||
"user_ips_drop_nonunique_index", self._remove_user_ip_nonunique
|
||||
)
|
||||
|
||||
# (user_id, access_token, ip,) -> (user_agent, device_id, last_seen)
|
||||
@@ -104,9 +99,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def _remove_user_ip_nonunique(self, progress, batch_size):
|
||||
def f(conn):
|
||||
txn = conn.cursor()
|
||||
txn.execute(
|
||||
"DROP INDEX IF EXISTS user_ips_user_ip"
|
||||
)
|
||||
txn.execute("DROP INDEX IF EXISTS user_ips_user_ip")
|
||||
txn.close()
|
||||
|
||||
yield self.runWithConnection(f)
|
||||
@@ -124,9 +117,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
def user_ips_analyze(txn):
|
||||
txn.execute("ANALYZE user_ips")
|
||||
|
||||
yield self.runInteraction(
|
||||
"user_ips_analyze", user_ips_analyze
|
||||
)
|
||||
yield self.runInteraction("user_ips_analyze", user_ips_analyze)
|
||||
|
||||
yield self._end_background_update("user_ips_analyze")
|
||||
|
||||
@@ -151,7 +142,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
LIMIT 1
|
||||
OFFSET ?
|
||||
""",
|
||||
(begin_last_seen, batch_size)
|
||||
(begin_last_seen, batch_size),
|
||||
)
|
||||
row = txn.fetchone()
|
||||
if row:
|
||||
@@ -169,7 +160,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
|
||||
logger.info(
|
||||
"Scanning for duplicate 'user_ips' rows in range: %s <= last_seen < %s",
|
||||
begin_last_seen, end_last_seen,
|
||||
begin_last_seen,
|
||||
end_last_seen,
|
||||
)
|
||||
|
||||
def remove(txn):
|
||||
@@ -207,8 +199,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
INNER JOIN user_ips USING (user_id, access_token, ip)
|
||||
GROUP BY user_id, access_token, ip
|
||||
HAVING count(*) > 1
|
||||
""".format(clause),
|
||||
args
|
||||
""".format(
|
||||
clause
|
||||
),
|
||||
args,
|
||||
)
|
||||
res = txn.fetchall()
|
||||
|
||||
@@ -254,7 +248,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
DELETE FROM user_ips
|
||||
WHERE user_id = ? AND access_token = ? AND ip = ? AND last_seen < ?
|
||||
""",
|
||||
(user_id, access_token, ip, last_seen)
|
||||
(user_id, access_token, ip, last_seen),
|
||||
)
|
||||
if txn.rowcount == count - 1:
|
||||
# We deleted all but one of the duplicate rows, i.e. there
|
||||
@@ -263,7 +257,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
continue
|
||||
elif txn.rowcount >= count:
|
||||
raise Exception(
|
||||
"We deleted more duplicate rows from 'user_ips' than expected",
|
||||
"We deleted more duplicate rows from 'user_ips' than expected"
|
||||
)
|
||||
|
||||
# The previous step didn't delete enough rows, so we fallback to
|
||||
@@ -275,7 +269,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
DELETE FROM user_ips
|
||||
WHERE user_id = ? AND access_token = ? AND ip = ?
|
||||
""",
|
||||
(user_id, access_token, ip)
|
||||
(user_id, access_token, ip),
|
||||
)
|
||||
|
||||
# Add in one to be the last_seen
|
||||
@@ -285,7 +279,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
VALUES (?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen)
|
||||
(user_id, access_token, ip, device_id, user_agent, last_seen),
|
||||
)
|
||||
|
||||
self._background_update_progress_txn(
|
||||
@@ -300,8 +294,9 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
defer.returnValue(batch_size)
|
||||
|
||||
@defer.inlineCallbacks
|
||||
def insert_client_ip(self, user_id, access_token, ip, user_agent, device_id,
|
||||
now=None):
|
||||
def insert_client_ip(
|
||||
self, user_id, access_token, ip, user_agent, device_id, now=None
|
||||
):
|
||||
if not now:
|
||||
now = int(self._clock.time_msec())
|
||||
key = (user_id, access_token, ip)
|
||||
@@ -329,13 +324,10 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
to_update = self._batch_row_update
|
||||
self._batch_row_update = {}
|
||||
return self.runInteraction(
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn,
|
||||
to_update,
|
||||
"_update_client_ips_batch", self._update_client_ips_batch_txn, to_update
|
||||
)
|
||||
|
||||
return run_as_background_process(
|
||||
"update_client_ips", update,
|
||||
)
|
||||
return run_as_background_process("update_client_ips", update)
|
||||
|
||||
def _update_client_ips_batch_txn(self, txn, to_update):
|
||||
if "user_ips" in self._unsafe_to_upsert_tables or (
|
||||
@@ -383,7 +375,8 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
res = yield self.runInteraction(
|
||||
"get_last_client_ip_by_device",
|
||||
self._get_last_client_ip_by_device_txn,
|
||||
user_id, device_id,
|
||||
user_id,
|
||||
device_id,
|
||||
retcols=(
|
||||
"user_id",
|
||||
"access_token",
|
||||
@@ -416,7 +409,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
bindings = []
|
||||
if device_id is None:
|
||||
where_clauses.append("user_id = ?")
|
||||
bindings.extend((user_id, ))
|
||||
bindings.extend((user_id,))
|
||||
else:
|
||||
where_clauses.append("(user_id = ? AND device_id = ?)")
|
||||
bindings.extend((user_id, device_id))
|
||||
@@ -428,9 +421,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
"SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips "
|
||||
"WHERE %(where)s "
|
||||
"GROUP BY user_id, device_id"
|
||||
) % {
|
||||
"where": " OR ".join(where_clauses),
|
||||
}
|
||||
) % {"where": " OR ".join(where_clauses)}
|
||||
|
||||
sql = (
|
||||
"SELECT %(retcols)s FROM user_ips "
|
||||
@@ -462,9 +453,7 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
rows = yield self._simple_select_list(
|
||||
table="user_ips",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=[
|
||||
"access_token", "ip", "user_agent", "last_seen"
|
||||
],
|
||||
retcols=["access_token", "ip", "user_agent", "last_seen"],
|
||||
desc="get_user_ip_and_agents",
|
||||
)
|
||||
|
||||
@@ -472,12 +461,14 @@ class ClientIpStore(background_updates.BackgroundUpdateStore):
|
||||
((row["access_token"], row["ip"]), (row["user_agent"], row["last_seen"]))
|
||||
for row in rows
|
||||
)
|
||||
defer.returnValue(list(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||
))
|
||||
defer.returnValue(
|
||||
list(
|
||||
{
|
||||
"access_token": access_token,
|
||||
"ip": ip,
|
||||
"user_agent": user_agent,
|
||||
"last_seen": last_seen,
|
||||
}
|
||||
for (access_token, ip), (user_agent, last_seen) in iteritems(results)
|
||||
)
|
||||
)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user