Merge branch 'develop' of https://github.com/matrix-org/synapse into sonar
Conflicts: poetry.lock
This commit is contained in:
@@ -380,7 +380,7 @@ jobs:
|
||||
# Attempt to check out the same branch of Complement as the PR. If it
|
||||
# doesn't exist, fallback to HEAD.
|
||||
- name: Checkout complement
|
||||
run: .ci/scripts/checkout_complement.sh
|
||||
run: synapse/.ci/scripts/checkout_complement.sh
|
||||
|
||||
- run: |
|
||||
set -o pipefail
|
||||
|
||||
@@ -0,0 +1 @@
|
||||
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
|
||||
@@ -0,0 +1 @@
|
||||
Reduce the amount of state we pull from the DB.
|
||||
@@ -0,0 +1 @@
|
||||
Faster room joins: when querying the current state of the room, wait for state to be populated.
|
||||
@@ -0,0 +1 @@
|
||||
Remove support for the non-standard groups/communities feature from Synapse.
|
||||
@@ -0,0 +1 @@
|
||||
Remove PyNaCl occurrences directly used in Synapse code.
|
||||
@@ -0,0 +1 @@
|
||||
Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged.
|
||||
@@ -0,0 +1 @@
|
||||
Fix potential memory leak when generating thumbnails.
|
||||
@@ -0,0 +1 @@
|
||||
Test Synapse against Complement with workers.
|
||||
@@ -0,0 +1 @@
|
||||
Remove support for the non-standard groups/communities feature from Synapse.
|
||||
@@ -0,0 +1 @@
|
||||
Fix a long-standing bug where a URL preview would break if the image failed to download.
|
||||
@@ -0,0 +1 @@
|
||||
Improve URL previews for pages with empty elements.
|
||||
@@ -0,0 +1 @@
|
||||
Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage.
|
||||
@@ -0,0 +1 @@
|
||||
Reduce the amount of state we pull from the DB.
|
||||
@@ -0,0 +1 @@
|
||||
Remove support for the non-standard groups/communities feature from Synapse.
|
||||
@@ -16,6 +16,7 @@
|
||||
|
||||
""" Starts a synapse client console. """
|
||||
import argparse
|
||||
import binascii
|
||||
import cmd
|
||||
import getpass
|
||||
import json
|
||||
@@ -26,9 +27,8 @@ import urllib
|
||||
from http import TwistedHttpClient
|
||||
from typing import Optional
|
||||
|
||||
import nacl.encoding
|
||||
import nacl.signing
|
||||
import urlparse
|
||||
from signedjson.key import NACL_ED25519, decode_verify_key_bytes
|
||||
from signedjson.sign import SignatureVerifyException, verify_signed_json
|
||||
|
||||
from twisted.internet import defer, reactor, threads
|
||||
@@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"]
|
||||
|
||||
|
||||
class SynapseCmd(cmd.Cmd):
|
||||
|
||||
"""Basic synapse command-line processor.
|
||||
|
||||
This processes commands from the user and calls the relevant HTTP methods.
|
||||
@@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd):
|
||||
pubKey = None
|
||||
pubKeyObj = yield self.http_client.do_request("GET", url)
|
||||
if "public_key" in pubKeyObj:
|
||||
pubKey = nacl.signing.VerifyKey(
|
||||
pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder
|
||||
pubKey = decode_verify_key_bytes(
|
||||
NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"])
|
||||
)
|
||||
else:
|
||||
print("No public key found in pubkey response!")
|
||||
|
||||
@@ -115,7 +115,9 @@ URL parameters:
|
||||
Body parameters:
|
||||
|
||||
- `password` - string, optional. If provided, the user's password is updated and all
|
||||
devices are logged out.
|
||||
devices are logged out, unless `logout_devices` is set to `false`.
|
||||
- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't
|
||||
logged out even when `password` is provided.
|
||||
- `displayname` - string, optional, defaults to the value of `user_id`.
|
||||
- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn)
|
||||
- `medium` - string. Kind of third-party ID, either `email` or `msisdn`.
|
||||
|
||||
@@ -191,7 +191,6 @@ information.
|
||||
^/_matrix/federation/v1/event_auth/
|
||||
^/_matrix/federation/v1/exchange_third_party_invite/
|
||||
^/_matrix/federation/v1/user/devices/
|
||||
^/_matrix/federation/v1/get_groups_publicised$
|
||||
^/_matrix/key/v2/query
|
||||
^/_matrix/federation/v1/hierarchy/
|
||||
|
||||
@@ -213,9 +212,6 @@ information.
|
||||
^/_matrix/client/(r0|v3|unstable)/devices$
|
||||
^/_matrix/client/versions$
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
|
||||
^/_matrix/client/(r0|v3|unstable)/joined_groups$
|
||||
^/_matrix/client/(r0|v3|unstable)/publicised_groups$
|
||||
^/_matrix/client/(r0|v3|unstable)/publicised_groups/
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/search$
|
||||
@@ -255,9 +251,7 @@ information.
|
||||
|
||||
Additionally, the following REST endpoints can be handled for GET requests:
|
||||
|
||||
^/_matrix/federation/v1/groups/
|
||||
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
|
||||
^/_matrix/client/(r0|v3|unstable)/groups/
|
||||
|
||||
Pagination requests can also be handled, but all requests for a given
|
||||
room must be routed to the same instance. Additionally, care must be taken to
|
||||
|
||||
Generated
+1
-55
@@ -187,17 +187,6 @@ category = "main"
|
||||
optional = false
|
||||
python-versions = "*"
|
||||
|
||||
[[package]]
|
||||
name = "coverage"
|
||||
version = "6.4"
|
||||
description = "Code coverage measurement for Python"
|
||||
category = "dev"
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
|
||||
[package.extras]
|
||||
toml = ["tomli"]
|
||||
|
||||
[[package]]
|
||||
name = "cryptography"
|
||||
version = "36.0.1"
|
||||
@@ -1574,7 +1563,7 @@ url_preview = ["lxml"]
|
||||
[metadata]
|
||||
lock-version = "1.1"
|
||||
python-versions = "^3.7.1"
|
||||
content-hash = "4ca5c8a4f817f99704a5a1ba466a50655073d2c899cf5034c426f84dc4afe0c7"
|
||||
content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831"
|
||||
|
||||
[metadata.files]
|
||||
attrs = [
|
||||
@@ -1713,49 +1702,6 @@ constantly = [
|
||||
{file = "constantly-15.1.0-py2.py3-none-any.whl", hash = "sha256:dd2fa9d6b1a51a83f0d7dd76293d734046aa176e384bf6e33b7e44880eb37c5d"},
|
||||
{file = "constantly-15.1.0.tar.gz", hash = "sha256:586372eb92059873e29eba4f9dec8381541b4d3834660707faf8ba59146dfc35"},
|
||||
]
|
||||
coverage = [
|
||||
{file = "coverage-6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50ed480b798febce113709846b11f5d5ed1e529c88d8ae92f707806c50297abf"},
|
||||
{file = "coverage-6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26f8f92699756cb7af2b30720de0c5bb8d028e923a95b6d0c891088025a1ac8f"},
|
||||
{file = "coverage-6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60c2147921da7f4d2d04f570e1838db32b95c5509d248f3fe6417e91437eaf41"},
|
||||
{file = "coverage-6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:750e13834b597eeb8ae6e72aa58d1d831b96beec5ad1d04479ae3772373a8088"},
|
||||
{file = "coverage-6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af5b9ee0fc146e907aa0f5fb858c3b3da9199d78b7bb2c9973d95550bd40f701"},
|
||||
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a022394996419142b33a0cf7274cb444c01d2bb123727c4bb0b9acabcb515dea"},
|
||||
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a78cf2c43b13aa6b56003707c5203f28585944c277c1f3f109c7b041b16bd39"},
|
||||
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9229d074e097f21dfe0643d9d0140ee7433814b3f0fc3706b4abffd1e3038632"},
|
||||
{file = "coverage-6.4-cp310-cp310-win32.whl", hash = "sha256:fb45fe08e1abc64eb836d187b20a59172053999823f7f6ef4f18a819c44ba16f"},
|
||||
{file = "coverage-6.4-cp310-cp310-win_amd64.whl", hash = "sha256:3cfd07c5889ddb96a401449109a8b97a165be9d67077df6802f59708bfb07720"},
|
||||
{file = "coverage-6.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:03014a74023abaf5a591eeeaf1ac66a73d54eba178ff4cb1fa0c0a44aae70383"},
|
||||
{file = "coverage-6.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c82f2cd69c71698152e943f4a5a6b83a3ab1db73b88f6e769fabc86074c3b08"},
|
||||
{file = "coverage-6.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b546cf2b1974ddc2cb222a109b37c6ed1778b9be7e6b0c0bc0cf0438d9e45a6"},
|
||||
{file = "coverage-6.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc173f1ce9ffb16b299f51c9ce53f66a62f4d975abe5640e976904066f3c835d"},
|
||||
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c53ad261dfc8695062fc8811ac7c162bd6096a05a19f26097f411bdf5747aee7"},
|
||||
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:eef5292b60b6de753d6e7f2d128d5841c7915fb1e3321c3a1fe6acfe76c38052"},
|
||||
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:543e172ce4c0de533fa892034cce260467b213c0ea8e39da2f65f9a477425211"},
|
||||
{file = "coverage-6.4-cp37-cp37m-win32.whl", hash = "sha256:00c8544510f3c98476bbd58201ac2b150ffbcce46a8c3e4fb89ebf01998f806a"},
|
||||
{file = "coverage-6.4-cp37-cp37m-win_amd64.whl", hash = "sha256:b84ab65444dcc68d761e95d4d70f3cfd347ceca5a029f2ffec37d4f124f61311"},
|
||||
{file = "coverage-6.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d548edacbf16a8276af13063a2b0669d58bbcfca7c55a255f84aac2870786a61"},
|
||||
{file = "coverage-6.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:033ebec282793bd9eb988d0271c211e58442c31077976c19c442e24d827d356f"},
|
||||
{file = "coverage-6.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:742fb8b43835078dd7496c3c25a1ec8d15351df49fb0037bffb4754291ef30ce"},
|
||||
{file = "coverage-6.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55fae115ef9f67934e9f1103c9ba826b4c690e4c5bcf94482b8b2398311bf9c"},
|
||||
{file = "coverage-6.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd698341626f3c77784858427bad0cdd54a713115b423d22ac83a28303d1d95"},
|
||||
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:62d382f7d77eeeaff14b30516b17bcbe80f645f5cf02bb755baac376591c653c"},
|
||||
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:016d7f5cf1c8c84f533a3c1f8f36126fbe00b2ec0ccca47cc5731c3723d327c6"},
|
||||
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:69432946f154c6add0e9ede03cc43b96e2ef2733110a77444823c053b1ff5166"},
|
||||
{file = "coverage-6.4-cp38-cp38-win32.whl", hash = "sha256:83bd142cdec5e4a5c4ca1d4ff6fa807d28460f9db919f9f6a31babaaa8b88426"},
|
||||
{file = "coverage-6.4-cp38-cp38-win_amd64.whl", hash = "sha256:4002f9e8c1f286e986fe96ec58742b93484195defc01d5cc7809b8f7acb5ece3"},
|
||||
{file = "coverage-6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4f52c272fdc82e7c65ff3f17a7179bc5f710ebc8ce8a5cadac81215e8326740"},
|
||||
{file = "coverage-6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b5578efe4038be02d76c344007b13119b2b20acd009a88dde8adec2de4f630b5"},
|
||||
{file = "coverage-6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8099ea680201c2221f8468c372198ceba9338a5fec0e940111962b03b3f716a"},
|
||||
{file = "coverage-6.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a00441f5ea4504f5abbc047589d09e0dc33eb447dc45a1a527c8b74bfdd32c65"},
|
||||
{file = "coverage-6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e76bd16f0e31bc2b07e0fb1379551fcd40daf8cdf7e24f31a29e442878a827c"},
|
||||
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8d2e80dd3438e93b19e1223a9850fa65425e77f2607a364b6fd134fcd52dc9df"},
|
||||
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:341e9c2008c481c5c72d0e0dbf64980a4b2238631a7f9780b0fe2e95755fb018"},
|
||||
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:21e6686a95025927775ac501e74f5940cdf6fe052292f3a3f7349b0abae6d00f"},
|
||||
{file = "coverage-6.4-cp39-cp39-win32.whl", hash = "sha256:968ed5407f9460bd5a591cefd1388cc00a8f5099de9e76234655ae48cfdbe2c3"},
|
||||
{file = "coverage-6.4-cp39-cp39-win_amd64.whl", hash = "sha256:e35217031e4b534b09f9b9a5841b9344a30a6357627761d4218818b865d45055"},
|
||||
{file = "coverage-6.4-pp36.pp37.pp38-none-any.whl", hash = "sha256:e637ae0b7b481905358624ef2e81d7fb0b1af55f5ff99f9ba05442a444b11e45"},
|
||||
{file = "coverage-6.4.tar.gz", hash = "sha256:727dafd7f67a6e1cad808dc884bd9c5a2f6ef1f8f6d2f22b37b96cb0080d4f49"},
|
||||
]
|
||||
cryptography = [
|
||||
{file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:73bc2d3f2444bcfeac67dd130ff2ea598ea5f20b40e36d19821b4df8c9c5037b"},
|
||||
{file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:2d87cdcb378d3cfed944dac30596da1968f88fb96d7fc34fdae30a99054b2e31"},
|
||||
|
||||
@@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0"
|
||||
canonicaljson = ">=1.4.0"
|
||||
# we use the type definitions added in signedjson 1.1.
|
||||
signedjson = ">=1.1.0"
|
||||
PyNaCl = ">=1.2.1"
|
||||
# validating SSL certs for IP addresses requires service_identity 18.1.
|
||||
service-identity = ">=18.1.0"
|
||||
# Twisted 18.9 introduces some logger improvements that the structured
|
||||
|
||||
@@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = {
|
||||
"devices": ["hidden"],
|
||||
"device_lists_outbound_pokes": ["sent"],
|
||||
"users_who_share_rooms": ["share_private"],
|
||||
"groups": ["is_public"],
|
||||
"group_rooms": ["is_public"],
|
||||
"group_users": ["is_public", "is_admin"],
|
||||
"group_summary_rooms": ["is_public"],
|
||||
"group_room_categories": ["is_public"],
|
||||
"group_summary_users": ["is_public"],
|
||||
"group_roles": ["is_public"],
|
||||
"local_group_membership": ["is_publicised", "is_admin"],
|
||||
"e2e_room_keys": ["is_verified"],
|
||||
"account_validity": ["email_sent"],
|
||||
"redactions": ["have_censored"],
|
||||
@@ -175,6 +167,22 @@ IGNORED_TABLES = {
|
||||
"ui_auth_sessions",
|
||||
"ui_auth_sessions_credentials",
|
||||
"ui_auth_sessions_ips",
|
||||
# Groups/communities is no longer supported.
|
||||
"group_attestations_remote",
|
||||
"group_attestations_renewals",
|
||||
"group_invites",
|
||||
"group_roles",
|
||||
"group_room_categories",
|
||||
"group_rooms",
|
||||
"group_summary_roles",
|
||||
"group_summary_room_categories",
|
||||
"group_summary_rooms",
|
||||
"group_summary_users",
|
||||
"group_users",
|
||||
"groups",
|
||||
"local_group_membership",
|
||||
"local_group_updates",
|
||||
"remote_profile_cache",
|
||||
}
|
||||
|
||||
|
||||
|
||||
+22
-23
@@ -29,12 +29,11 @@ from synapse.api.errors import (
|
||||
MissingClientTokenError,
|
||||
)
|
||||
from synapse.appservice import ApplicationService
|
||||
from synapse.events import EventBase
|
||||
from synapse.http import get_request_user_agent
|
||||
from synapse.http.site import SynapseRequest
|
||||
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
|
||||
from synapse.storage.databases.main.registration import TokenLookupResult
|
||||
from synapse.types import Requester, StateMap, UserID, create_requester
|
||||
from synapse.types import Requester, UserID, create_requester
|
||||
from synapse.util.caches.lrucache import LruCache
|
||||
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
|
||||
|
||||
@@ -61,8 +60,8 @@ class Auth:
|
||||
self.hs = hs
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self.state = hs.get_state_handler()
|
||||
self._account_validity_handler = hs.get_account_validity_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
|
||||
10000, "token_cache"
|
||||
@@ -79,9 +78,8 @@ class Auth:
|
||||
self,
|
||||
room_id: str,
|
||||
user_id: str,
|
||||
current_state: Optional[StateMap[EventBase]] = None,
|
||||
allow_departed_users: bool = False,
|
||||
) -> EventBase:
|
||||
) -> Tuple[str, Optional[str]]:
|
||||
"""Check if the user is in the room, or was at some point.
|
||||
Args:
|
||||
room_id: The room to check.
|
||||
@@ -99,29 +97,28 @@ class Auth:
|
||||
Raises:
|
||||
AuthError if the user is/was not in the room.
|
||||
Returns:
|
||||
Membership event for the user if the user was in the
|
||||
room. This will be the join event if they are currently joined to
|
||||
the room. This will be the leave event if they have left the room.
|
||||
The current membership of the user in the room and the
|
||||
membership event ID of the user.
|
||||
"""
|
||||
if current_state:
|
||||
member = current_state.get((EventTypes.Member, user_id), None)
|
||||
else:
|
||||
member = await self.state.get_current_state(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
|
||||
if member:
|
||||
membership = member.membership
|
||||
(
|
||||
membership,
|
||||
member_event_id,
|
||||
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||
user_id=user_id,
|
||||
room_id=room_id,
|
||||
)
|
||||
|
||||
if membership:
|
||||
if membership == Membership.JOIN:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
# XXX this looks totally bogus. Why do we not allow users who have been banned,
|
||||
# or those who were members previously and have been re-invited?
|
||||
if allow_departed_users and membership == Membership.LEAVE:
|
||||
forgot = await self.store.did_forget(user_id, room_id)
|
||||
if not forgot:
|
||||
return member
|
||||
return membership, member_event_id
|
||||
|
||||
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
|
||||
|
||||
@@ -602,8 +599,11 @@ class Auth:
|
||||
# We currently require the user is a "moderator" in the room. We do this
|
||||
# by checking if they would (theoretically) be able to change the
|
||||
# m.room.canonical_alias events
|
||||
power_level_event = await self.state.get_current_state(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
|
||||
power_level_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
)
|
||||
|
||||
auth_events = {}
|
||||
@@ -693,12 +693,11 @@ class Auth:
|
||||
# * The user is a non-guest user, and was ever in the room
|
||||
# * The user is a guest user, and has joined the room
|
||||
# else it will throw.
|
||||
member_event = await self.check_user_in_room(
|
||||
return await self.check_user_in_room(
|
||||
room_id, user_id, allow_departed_users=allow_departed_users
|
||||
)
|
||||
return member_event.membership, member_event.event_id
|
||||
except AuthError:
|
||||
visibility = await self.state.get_current_state(
|
||||
visibility = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if (
|
||||
|
||||
@@ -95,7 +95,6 @@ class EventTypes:
|
||||
Aliases: Final = "m.room.aliases"
|
||||
Redaction: Final = "m.room.redaction"
|
||||
ThirdPartyInvite: Final = "m.room.third_party_invite"
|
||||
RelatedGroups: Final = "m.room.related_groups"
|
||||
|
||||
RoomHistoryVisibility: Final = "m.room.history_visibility"
|
||||
CanonicalAlias: Final = "m.room.canonical_alias"
|
||||
|
||||
@@ -70,7 +70,6 @@ class ApplicationService:
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
hostname: str,
|
||||
id: str,
|
||||
sender: str,
|
||||
url: Optional[str] = None,
|
||||
@@ -88,7 +87,6 @@ class ApplicationService:
|
||||
) # url must not end with a slash
|
||||
self.hs_token = hs_token
|
||||
self.sender = sender
|
||||
self.server_name = hostname
|
||||
self.namespaces = self._check_namespaces(namespaces)
|
||||
self.id = id
|
||||
self.ip_range_whitelist = ip_range_whitelist
|
||||
|
||||
@@ -179,7 +179,6 @@ def _load_appservice(
|
||||
|
||||
return ApplicationService(
|
||||
token=as_info["as_token"],
|
||||
hostname=hostname,
|
||||
url=as_info["url"],
|
||||
namespaces=as_info["namespaces"],
|
||||
hs_token=as_info["hs_token"],
|
||||
|
||||
@@ -152,6 +152,7 @@ class ThirdPartyEventRules:
|
||||
self.third_party_rules = None
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
|
||||
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
|
||||
@@ -463,7 +464,7 @@ class ThirdPartyEventRules:
|
||||
Returns:
|
||||
A dict mapping (event type, state key) to state event.
|
||||
"""
|
||||
state_ids = await self.store.get_filtered_current_state_ids(room_id)
|
||||
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||
room_state_events = await self.store.get_events(state_ids.values())
|
||||
|
||||
state_events = {}
|
||||
|
||||
@@ -53,6 +53,7 @@ class FederationBase:
|
||||
self.spam_checker = hs.get_spam_checker()
|
||||
self.store = hs.get_datastores().main
|
||||
self._clock = hs.get_clock()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
async def _check_sigs_and_hash(
|
||||
self, room_version: RoomVersion, pdu: EventBase
|
||||
|
||||
@@ -118,6 +118,8 @@ class FederationServer(FederationBase):
|
||||
self.state = hs.get_state_handler()
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
|
||||
self.device_handler = hs.get_device_handler()
|
||||
|
||||
# Ensure the following handlers are loaded since they register callbacks
|
||||
@@ -1221,14 +1223,10 @@ class FederationServer(FederationBase):
|
||||
Raises:
|
||||
AuthError if the server does not match the ACL
|
||||
"""
|
||||
state_ids = await self.store.get_current_state_ids(room_id)
|
||||
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
|
||||
|
||||
if not acl_event_id:
|
||||
return
|
||||
|
||||
acl_event = await self.store.get_event(acl_event_id)
|
||||
if server_matches_acl_event(server_name, acl_event):
|
||||
acl_event = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.ServerACL, ""
|
||||
)
|
||||
if not acl_event or server_matches_acl_event(server_name, acl_event):
|
||||
return
|
||||
|
||||
raise AuthError(code=403, msg="Server is banned from room")
|
||||
|
||||
@@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
|
||||
self.store = hs.get_datastores().main
|
||||
self.state = hs.get_state_handler()
|
||||
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
|
||||
@@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
|
||||
room_id = receipt.room_id
|
||||
|
||||
# Work out which remote servers should be poked and poke them.
|
||||
domains_set = await self.state.get_current_hosts_in_room(room_id)
|
||||
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||
room_id
|
||||
)
|
||||
domains = [
|
||||
d
|
||||
for d in domains_set
|
||||
|
||||
@@ -166,7 +166,7 @@ class DeviceWorkerHandler:
|
||||
possibly_changed = set(changed)
|
||||
possibly_left = set()
|
||||
for room_id in rooms_changed:
|
||||
current_state_ids = await self.store.get_current_state_ids(room_id)
|
||||
current_state_ids = await self._state_storage.get_current_state_ids(room_id)
|
||||
|
||||
# The user may have left the room
|
||||
# TODO: Check if they actually did or if we were just invited.
|
||||
|
||||
@@ -45,6 +45,7 @@ class DirectoryHandler:
|
||||
self.appservice_handler = hs.get_application_service_handler()
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.config = hs.config
|
||||
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
||||
self.require_membership = hs.config.server.require_membership_for_aliases
|
||||
@@ -319,7 +320,7 @@ class DirectoryHandler:
|
||||
Raises:
|
||||
ShadowBanError if the requester has been shadow-banned.
|
||||
"""
|
||||
alias_event = await self.state.get_current_state(
|
||||
alias_event = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
|
||||
@@ -463,7 +464,11 @@ class DirectoryHandler:
|
||||
making_public = visibility == "public"
|
||||
if making_public:
|
||||
room_aliases = await self.store.get_aliases_for_room(room_id)
|
||||
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
|
||||
canonical_alias = (
|
||||
await self._storage_controllers.state.get_canonical_alias_for_room(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
if canonical_alias:
|
||||
room_aliases.append(canonical_alias)
|
||||
|
||||
|
||||
@@ -371,7 +371,7 @@ class FederationHandler:
|
||||
# First we try hosts that are already in the room
|
||||
# TODO: HEURISTIC ALERT.
|
||||
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self._storage_controllers.state.get_current_state(room_id)
|
||||
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
|
||||
@@ -750,7 +750,9 @@ class FederationHandler:
|
||||
# Note that this requires the /send_join request to come back to the
|
||||
# same server.
|
||||
if room_version.msc3083_join_rules:
|
||||
state_ids = await self.store.get_current_state_ids(room_id)
|
||||
state_ids = await self._state_storage_controller.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
if await self._event_auth_handler.has_restricted_join_rules(
|
||||
state_ids, room_version
|
||||
):
|
||||
@@ -1552,6 +1554,9 @@ class FederationHandler:
|
||||
success = await self.store.clear_partial_state_room(room_id)
|
||||
if success:
|
||||
logger.info("State resync complete for %s", room_id)
|
||||
self._storage_controllers.state.notify_room_un_partial_stated(
|
||||
room_id
|
||||
)
|
||||
|
||||
# TODO(faster_joins) update room stats and user directory?
|
||||
return
|
||||
|
||||
@@ -1584,9 +1584,11 @@ class FederationEventHandler:
|
||||
if guest_access == GuestAccess.CAN_JOIN:
|
||||
return
|
||||
|
||||
current_state_map = await self._state_handler.get_current_state(event.room_id)
|
||||
current_state = list(current_state_map.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state)
|
||||
current_state = await self._storage_controllers.state.get_current_state(
|
||||
event.room_id
|
||||
)
|
||||
current_state_list = list(current_state.values())
|
||||
await self._get_room_member_handler().kick_guest_users(current_state_list)
|
||||
|
||||
async def _check_for_soft_fail(
|
||||
self,
|
||||
@@ -1614,6 +1616,9 @@ class FederationEventHandler:
|
||||
room_version = await self._store.get_room_version_id(event.room_id)
|
||||
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
|
||||
|
||||
# The event types we want to pull from the "current" state.
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
|
||||
# Calculate the "current state".
|
||||
if state_ids is not None:
|
||||
# If we're explicitly given the state then we won't have all the
|
||||
@@ -1643,8 +1648,10 @@ class FederationEventHandler:
|
||||
)
|
||||
)
|
||||
else:
|
||||
current_state_ids = await self._state_handler.get_current_state_ids(
|
||||
event.room_id, latest_event_ids=extrem_ids
|
||||
current_state_ids = (
|
||||
await self._state_storage_controller.get_current_state_ids(
|
||||
event.room_id, StateFilter.from_types(auth_types)
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
@@ -1654,7 +1661,6 @@ class FederationEventHandler:
|
||||
)
|
||||
|
||||
# Now check if event pass auth against said current state
|
||||
auth_types = auth_types_for_event(room_version_obj, event)
|
||||
current_state_ids_list = [
|
||||
e for k, e in current_state_ids.items() if k in auth_types
|
||||
]
|
||||
|
||||
@@ -190,7 +190,7 @@ class InitialSyncHandler:
|
||||
if event.membership == Membership.JOIN:
|
||||
room_end_token = now_token.room_key
|
||||
deferred_room_state = run_in_background(
|
||||
self.state_handler.get_current_state, event.room_id
|
||||
self._state_storage_controller.get_current_state, event.room_id
|
||||
)
|
||||
elif event.membership == Membership.LEAVE:
|
||||
room_end_token = RoomStreamToken(
|
||||
@@ -407,7 +407,9 @@ class InitialSyncHandler:
|
||||
membership: str,
|
||||
is_peeking: bool,
|
||||
) -> JsonDict:
|
||||
current_state = await self.state.get_current_state(room_id=room_id)
|
||||
current_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id=room_id
|
||||
)
|
||||
|
||||
# TODO: These concurrently
|
||||
time_now = self.clock.time_msec()
|
||||
|
||||
+81
-39
@@ -28,6 +28,7 @@ from synapse.api.constants import (
|
||||
EventContentFields,
|
||||
EventTypes,
|
||||
GuestAccess,
|
||||
HistoryVisibility,
|
||||
Membership,
|
||||
RelationTypes,
|
||||
UserTypes,
|
||||
@@ -66,7 +67,7 @@ from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstErr
|
||||
from synapse.util.async_helpers import Linearizer, gather_results
|
||||
from synapse.util.caches.expiringcache import ExpiringCache
|
||||
from synapse.util.metrics import measure_func
|
||||
from synapse.visibility import filter_events_for_client
|
||||
from synapse.visibility import get_effective_room_visibility_from_state
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from synapse.events.third_party_rules import ThirdPartyEventRules
|
||||
@@ -124,7 +125,9 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
data = await self.state.get_current_state(room_id, event_type, state_key)
|
||||
data = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, event_type, state_key
|
||||
)
|
||||
elif membership == Membership.LEAVE:
|
||||
key = (event_type, state_key)
|
||||
# If the membership is not JOIN, then the event ID should exist.
|
||||
@@ -182,51 +185,31 @@ class MessageHandler:
|
||||
state_filter = state_filter or StateFilter.all()
|
||||
|
||||
if at_token:
|
||||
last_event = await self.store.get_last_event_in_room_before_stream_ordering(
|
||||
room_id,
|
||||
end_token=at_token.room_key,
|
||||
last_event_id = (
|
||||
await self.store.get_last_event_in_room_before_stream_ordering(
|
||||
room_id,
|
||||
end_token=at_token.room_key,
|
||||
)
|
||||
)
|
||||
|
||||
if not last_event:
|
||||
if not last_event_id:
|
||||
raise NotFoundError("Can't find event for token %s" % (at_token,))
|
||||
|
||||
# check whether the user is in the room at that time to determine
|
||||
# whether they should be treated as peeking.
|
||||
state_map = await self._state_storage_controller.get_state_for_event(
|
||||
last_event.event_id,
|
||||
StateFilter.from_types([(EventTypes.Member, user_id)]),
|
||||
)
|
||||
|
||||
joined = False
|
||||
membership_event = state_map.get((EventTypes.Member, user_id))
|
||||
if membership_event:
|
||||
joined = membership_event.membership == Membership.JOIN
|
||||
|
||||
is_peeking = not joined
|
||||
|
||||
visible_events = await filter_events_for_client(
|
||||
self._storage_controllers,
|
||||
user_id,
|
||||
[last_event],
|
||||
filter_send_to_client=False,
|
||||
is_peeking=is_peeking,
|
||||
)
|
||||
|
||||
if visible_events:
|
||||
room_state_events = (
|
||||
await self._state_storage_controller.get_state_for_events(
|
||||
[last_event.event_id], state_filter=state_filter
|
||||
)
|
||||
)
|
||||
room_state: Mapping[Any, EventBase] = room_state_events[
|
||||
last_event.event_id
|
||||
]
|
||||
else:
|
||||
if not await self._user_can_see_state_at_event(
|
||||
user_id, room_id, last_event_id
|
||||
):
|
||||
raise AuthError(
|
||||
403,
|
||||
"User %s not allowed to view events in room %s at token %s"
|
||||
% (user_id, room_id, at_token),
|
||||
)
|
||||
|
||||
room_state_events = (
|
||||
await self._state_storage_controller.get_state_for_events(
|
||||
[last_event_id], state_filter=state_filter
|
||||
)
|
||||
)
|
||||
room_state: Mapping[Any, EventBase] = room_state_events[last_event_id]
|
||||
else:
|
||||
(
|
||||
membership,
|
||||
@@ -236,7 +219,7 @@ class MessageHandler:
|
||||
)
|
||||
|
||||
if membership == Membership.JOIN:
|
||||
state_ids = await self.store.get_filtered_current_state_ids(
|
||||
state_ids = await self._state_storage_controller.get_current_state_ids(
|
||||
room_id, state_filter=state_filter
|
||||
)
|
||||
room_state = await self.store.get_events(state_ids.values())
|
||||
@@ -256,6 +239,65 @@ class MessageHandler:
|
||||
events = self._event_serializer.serialize_events(room_state.values(), now)
|
||||
return events
|
||||
|
||||
async def _user_can_see_state_at_event(
|
||||
self, user_id: str, room_id: str, event_id: str
|
||||
) -> bool:
|
||||
# check whether the user was in the room, and the history visibility,
|
||||
# at that time.
|
||||
state_map = await self._state_storage_controller.get_state_for_event(
|
||||
event_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Member, user_id),
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
membership = None
|
||||
membership_event = state_map.get((EventTypes.Member, user_id))
|
||||
if membership_event:
|
||||
membership = membership_event.membership
|
||||
|
||||
# if the user was a member of the room at the time of the event,
|
||||
# they can see it.
|
||||
if membership == Membership.JOIN:
|
||||
return True
|
||||
|
||||
# otherwise, it depends on the history visibility.
|
||||
visibility = get_effective_room_visibility_from_state(state_map)
|
||||
|
||||
if visibility == HistoryVisibility.JOINED:
|
||||
# we weren't a member at the time of the event, so we can't see this event.
|
||||
return False
|
||||
|
||||
# otherwise *invited* is good enough
|
||||
if membership == Membership.INVITE:
|
||||
return True
|
||||
|
||||
if visibility == HistoryVisibility.INVITED:
|
||||
# we weren't invited, so we can't see this event.
|
||||
return False
|
||||
|
||||
if visibility == HistoryVisibility.WORLD_READABLE:
|
||||
return True
|
||||
|
||||
# So it's SHARED, and the user was not a member at the time. The user cannot
|
||||
# see history, unless they have *subsequently* joined the room.
|
||||
#
|
||||
# XXX: if the user has subsequently joined and then left again,
|
||||
# ideally we would share history up to the point they left. But
|
||||
# we don't know when they left. We just treat it as though they
|
||||
# never joined, and restrict access.
|
||||
|
||||
(
|
||||
current_membership,
|
||||
_,
|
||||
) = await self.store.get_local_current_membership_for_user_in_room(
|
||||
user_id, event_id
|
||||
)
|
||||
return current_membership == Membership.JOIN
|
||||
|
||||
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
|
||||
"""Get all the joined members in the room and their profile information.
|
||||
|
||||
|
||||
@@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.clock = hs.get_clock()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.presence_router = hs.get_presence_router()
|
||||
self.state = hs.get_state_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
|
||||
self._event_pos,
|
||||
room_max_stream_ordering,
|
||||
)
|
||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
||||
(
|
||||
max_pos,
|
||||
deltas,
|
||||
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||
self._event_pos, room_max_stream_ordering
|
||||
)
|
||||
|
||||
|
||||
@@ -23,14 +23,7 @@ from synapse.api.errors import (
|
||||
StoreError,
|
||||
SynapseError,
|
||||
)
|
||||
from synapse.metrics.background_process_metrics import wrap_as_background_process
|
||||
from synapse.types import (
|
||||
JsonDict,
|
||||
Requester,
|
||||
UserID,
|
||||
create_requester,
|
||||
get_domain_from_id,
|
||||
)
|
||||
from synapse.types import JsonDict, Requester, UserID, create_requester
|
||||
from synapse.util.caches.descriptors import cached
|
||||
from synapse.util.stringutils import parse_and_validate_mxc_uri
|
||||
|
||||
@@ -50,9 +43,6 @@ class ProfileHandler:
|
||||
delegate to master when necessary.
|
||||
"""
|
||||
|
||||
PROFILE_UPDATE_MS = 60 * 1000
|
||||
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self.clock = hs.get_clock()
|
||||
@@ -73,11 +63,6 @@ class ProfileHandler:
|
||||
|
||||
self._third_party_rules = hs.get_third_party_event_rules()
|
||||
|
||||
if hs.config.worker.run_background_tasks:
|
||||
self.clock.looping_call(
|
||||
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
|
||||
)
|
||||
|
||||
async def get_profile(self, user_id: str) -> JsonDict:
|
||||
target_user = UserID.from_string(user_id)
|
||||
|
||||
@@ -116,30 +101,6 @@ class ProfileHandler:
|
||||
raise SynapseError(502, "Failed to fetch profile")
|
||||
raise e.to_synapse_error()
|
||||
|
||||
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
|
||||
"""Get the profile information from our local cache. If the user is
|
||||
ours then the profile information will always be correct. Otherwise,
|
||||
it may be out of date/missing.
|
||||
"""
|
||||
target_user = UserID.from_string(user_id)
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
displayname = await self.store.get_profile_displayname(
|
||||
target_user.localpart
|
||||
)
|
||||
avatar_url = await self.store.get_profile_avatar_url(
|
||||
target_user.localpart
|
||||
)
|
||||
except StoreError as e:
|
||||
if e.code == 404:
|
||||
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
|
||||
raise
|
||||
|
||||
return {"displayname": displayname, "avatar_url": avatar_url}
|
||||
else:
|
||||
profile = await self.store.get_from_remote_profile_cache(user_id)
|
||||
return profile or {}
|
||||
|
||||
async def get_displayname(self, target_user: UserID) -> Optional[str]:
|
||||
if self.hs.is_mine(target_user):
|
||||
try:
|
||||
@@ -509,45 +470,3 @@ class ProfileHandler:
|
||||
# so we act as if we couldn't find the profile.
|
||||
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
|
||||
raise
|
||||
|
||||
@wrap_as_background_process("Update remote profile")
|
||||
async def _update_remote_profile_cache(self) -> None:
|
||||
"""Called periodically to check profiles of remote users we haven't
|
||||
checked in a while.
|
||||
"""
|
||||
entries = await self.store.get_remote_profile_cache_entries_that_expire(
|
||||
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
|
||||
)
|
||||
|
||||
for user_id, displayname, avatar_url in entries:
|
||||
is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
|
||||
user_id
|
||||
)
|
||||
if not is_subscribed:
|
||||
await self.store.maybe_delete_remote_profile_cache(user_id)
|
||||
continue
|
||||
|
||||
try:
|
||||
profile = await self.federation.make_query(
|
||||
destination=get_domain_from_id(user_id),
|
||||
query_type="profile",
|
||||
args={"user_id": user_id},
|
||||
ignore_backoff=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Failed to get avatar_url")
|
||||
|
||||
await self.store.update_remote_profile_cache(
|
||||
user_id, displayname, avatar_url
|
||||
)
|
||||
continue
|
||||
|
||||
new_name = profile.get("displayname")
|
||||
if not isinstance(new_name, str):
|
||||
new_name = None
|
||||
new_avatar = profile.get("avatar_url")
|
||||
if not isinstance(new_avatar, str):
|
||||
new_avatar = None
|
||||
|
||||
# We always hit update to update the last_check timestamp
|
||||
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
|
||||
|
||||
@@ -87,6 +87,7 @@ class LoginDict(TypedDict):
|
||||
class RegistrationHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
self.auth = hs.get_auth()
|
||||
@@ -528,7 +529,7 @@ class RegistrationHandler:
|
||||
|
||||
if requires_invite:
|
||||
# If the server is in the room, check if the room is public.
|
||||
state = await self.store.get_filtered_current_state_ids(
|
||||
state = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
|
||||
)
|
||||
|
||||
|
||||
@@ -12,16 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Collection,
|
||||
Dict,
|
||||
FrozenSet,
|
||||
Iterable,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
|
||||
|
||||
import attr
|
||||
|
||||
@@ -256,13 +247,19 @@ class RelationsHandler:
|
||||
|
||||
return filtered_results
|
||||
|
||||
async def get_threads_for_events(
|
||||
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
|
||||
async def _get_threads_for_events(
|
||||
self,
|
||||
events_by_id: Dict[str, EventBase],
|
||||
relations_by_id: Dict[str, str],
|
||||
user_id: str,
|
||||
ignored_users: FrozenSet[str],
|
||||
) -> Dict[str, _ThreadAggregation]:
|
||||
"""Get the bundled aggregations for threads for the requested events.
|
||||
|
||||
Args:
|
||||
event_ids: Events to get aggregations for threads.
|
||||
events_by_id: A map of event_id to events to get aggregations for threads.
|
||||
relations_by_id: A map of event_id to the relation type, if one exists
|
||||
for that event.
|
||||
user_id: The user requesting the bundled aggregations.
|
||||
ignored_users: The users ignored by the requesting user.
|
||||
|
||||
@@ -273,16 +270,34 @@ class RelationsHandler:
|
||||
"""
|
||||
user = UserID.from_string(user_id)
|
||||
|
||||
# It is not valid to start a thread on an event which itself relates to another event.
|
||||
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
|
||||
|
||||
# Fetch thread summaries.
|
||||
summaries = await self._main_store.get_thread_summaries(event_ids)
|
||||
|
||||
# Only fetch participated for a limited selection based on what had
|
||||
# summaries.
|
||||
# Limit fetching whether the requester has participated in a thread to
|
||||
# events which are thread roots.
|
||||
thread_event_ids = [
|
||||
event_id for event_id, summary in summaries.items() if summary
|
||||
]
|
||||
participated = await self._main_store.get_threads_participated(
|
||||
thread_event_ids, user_id
|
||||
|
||||
# Pre-seed thread participation with whether the requester sent the event.
|
||||
participated = {
|
||||
event_id: events_by_id[event_id].sender == user_id
|
||||
for event_id in thread_event_ids
|
||||
}
|
||||
# For events the requester did not send, check the database for whether
|
||||
# the requester sent a threaded reply.
|
||||
participated.update(
|
||||
await self._main_store.get_threads_participated(
|
||||
[
|
||||
event_id
|
||||
for event_id in thread_event_ids
|
||||
if not participated[event_id]
|
||||
],
|
||||
user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Then subtract off the results for any ignored users.
|
||||
@@ -343,7 +358,8 @@ class RelationsHandler:
|
||||
count=thread_count,
|
||||
# If there's a thread summary it must also exist in the
|
||||
# participated dictionary.
|
||||
current_user_participated=participated[event_id],
|
||||
current_user_participated=events_by_id[event_id].sender == user_id
|
||||
or participated[event_id],
|
||||
)
|
||||
|
||||
return results
|
||||
@@ -401,9 +417,9 @@ class RelationsHandler:
|
||||
# events to be fetched. Thus, we check those first!
|
||||
|
||||
# Fetch thread summaries (but only for the directly requested events).
|
||||
threads = await self.get_threads_for_events(
|
||||
# It is not valid to start a thread on an event which itself relates to another event.
|
||||
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
|
||||
threads = await self._get_threads_for_events(
|
||||
events_by_id,
|
||||
relations_by_id,
|
||||
user_id,
|
||||
ignored_users,
|
||||
)
|
||||
|
||||
@@ -107,6 +107,7 @@ class EventContext:
|
||||
class RoomCreationHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.auth = hs.get_auth()
|
||||
self.clock = hs.get_clock()
|
||||
self.hs = hs
|
||||
@@ -468,7 +469,6 @@ class RoomCreationHandler:
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
(EventTypes.RoomEncryption, ""),
|
||||
(EventTypes.ServerACL, ""),
|
||||
(EventTypes.RelatedGroups, ""),
|
||||
(EventTypes.PowerLevels, ""),
|
||||
]
|
||||
|
||||
@@ -481,8 +481,10 @@ class RoomCreationHandler:
|
||||
if room_type == RoomTypes.SPACE:
|
||||
types_to_copy.append((EventTypes.SpaceChild, None))
|
||||
|
||||
old_room_state_ids = await self.store.get_filtered_current_state_ids(
|
||||
old_room_id, StateFilter.from_types(types_to_copy)
|
||||
old_room_state_ids = (
|
||||
await self._storage_controllers.state.get_current_state_ids(
|
||||
old_room_id, StateFilter.from_types(types_to_copy)
|
||||
)
|
||||
)
|
||||
# map from event_id to BaseEvent
|
||||
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
|
||||
@@ -559,8 +561,10 @@ class RoomCreationHandler:
|
||||
)
|
||||
|
||||
# Transfer membership events
|
||||
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
|
||||
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||
old_room_member_state_ids = (
|
||||
await self._storage_controllers.state.get_current_state_ids(
|
||||
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
|
||||
)
|
||||
)
|
||||
|
||||
# map from event_id to BaseEvent
|
||||
@@ -1329,6 +1333,7 @@ class TimestampLookupHandler:
|
||||
self.store = hs.get_datastores().main
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.federation_client = hs.get_federation_client()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
async def get_event_for_timestamp(
|
||||
self,
|
||||
@@ -1402,7 +1407,9 @@ class TimestampLookupHandler:
|
||||
)
|
||||
|
||||
# Find other homeservers from the given state in the room
|
||||
curr_state = await self.state_handler.get_current_state(room_id)
|
||||
curr_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id
|
||||
)
|
||||
curr_domains = get_domains_from_state(curr_state)
|
||||
likely_domains = [
|
||||
domain for domain, depth in curr_domains if domain != self.server_name
|
||||
|
||||
@@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
|
||||
class RoomListHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.hs = hs
|
||||
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
|
||||
self.response_cache: ResponseCache[
|
||||
@@ -274,7 +275,7 @@ class RoomListHandler:
|
||||
if aliases:
|
||||
result["aliases"] = aliases
|
||||
|
||||
current_state_ids = await self.store.get_current_state_ids(
|
||||
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id, on_invalidate=cache_context.invalidate
|
||||
)
|
||||
|
||||
|
||||
@@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.auth = hs.get_auth()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.config = hs.config
|
||||
@@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
# If the host is in the room, but not one of the authorised hosts
|
||||
# for restricted join rules, a remote join must be used.
|
||||
room_version = await self.store.get_room_version(room_id)
|
||||
current_state_ids = await self.store.get_current_state_ids(room_id)
|
||||
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
|
||||
# If restricted join rules are not being used, a local join can always
|
||||
# be used.
|
||||
@@ -1398,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
|
||||
txn_id: Optional[str],
|
||||
id_access_token: Optional[str] = None,
|
||||
) -> int:
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
room_state = await self._storage_controllers.state.get_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Member, user.to_string()),
|
||||
(EventTypes.CanonicalAlias, ""),
|
||||
(EventTypes.Name, ""),
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.RoomAvatar, ""),
|
||||
]
|
||||
),
|
||||
)
|
||||
|
||||
inviter_display_name = ""
|
||||
inviter_avatar_url = ""
|
||||
@@ -1794,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
|
||||
async def forget(self, user: UserID, room_id: str) -> None:
|
||||
user_id = user.to_string()
|
||||
|
||||
member = await self.state_handler.get_current_state(
|
||||
member = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
|
||||
)
|
||||
membership = member.membership if member else None
|
||||
|
||||
@@ -90,6 +90,7 @@ class RoomSummaryHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._event_auth_handler = hs.get_event_auth_handler()
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._server_name = hs.hostname
|
||||
self._federation_client = hs.get_federation_client()
|
||||
@@ -537,7 +538,7 @@ class RoomSummaryHandler:
|
||||
Returns:
|
||||
True if the room is accessible to the requesting user or server.
|
||||
"""
|
||||
state_ids = await self._store.get_current_state_ids(room_id)
|
||||
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||
|
||||
# If there's no state for the room, it isn't known.
|
||||
if not state_ids:
|
||||
@@ -702,7 +703,9 @@ class RoomSummaryHandler:
|
||||
# there should always be an entry
|
||||
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
|
||||
|
||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
create_event = await self._store.get_event(
|
||||
current_state_ids[(EventTypes.Create, "")]
|
||||
)
|
||||
@@ -760,7 +763,9 @@ class RoomSummaryHandler:
|
||||
"""
|
||||
|
||||
# look for child rooms/spaces.
|
||||
current_state_ids = await self._store.get_current_state_ids(room_id)
|
||||
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
|
||||
events = await self._store.get_events_as_list(
|
||||
[
|
||||
|
||||
@@ -348,7 +348,7 @@ class SearchHandler:
|
||||
state_results = {}
|
||||
if include_state:
|
||||
for room_id in {e.room_id for e in search_result.allowed_events}:
|
||||
state = await self.state_handler.get_current_state(room_id)
|
||||
state = await self._storage_controllers.state.get_current_state(room_id)
|
||||
state_results[room_id] = list(state.values())
|
||||
|
||||
aggregations = await self._relations_handler.get_bundled_aggregations(
|
||||
|
||||
@@ -40,6 +40,7 @@ class StatsHandler:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.hs = hs
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.state = hs.get_state_handler()
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
@@ -105,7 +106,10 @@ class StatsHandler:
|
||||
logger.debug(
|
||||
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
|
||||
)
|
||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
||||
(
|
||||
max_pos,
|
||||
deltas,
|
||||
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||
self.pos, room_max_stream_ordering
|
||||
)
|
||||
|
||||
|
||||
+28
-12
@@ -506,8 +506,10 @@ class SyncHandler:
|
||||
# ensure that we always include current state in the timeline
|
||||
current_state_ids: FrozenSet[str] = frozenset()
|
||||
if any(e.is_state() for e in recents):
|
||||
current_state_ids_map = await self.store.get_current_state_ids(
|
||||
room_id
|
||||
current_state_ids_map = (
|
||||
await self._state_storage_controller.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
)
|
||||
current_state_ids = frozenset(current_state_ids_map.values())
|
||||
|
||||
@@ -574,8 +576,11 @@ class SyncHandler:
|
||||
# ensure that we always include current state in the timeline
|
||||
current_state_ids = frozenset()
|
||||
if any(e.is_state() for e in loaded_recents):
|
||||
current_state_ids_map = await self.store.get_current_state_ids(
|
||||
room_id
|
||||
# FIXME(faster_joins): We use the partial state here as
|
||||
# we don't want to block `/sync` on finishing a lazy join.
|
||||
# Is this the correct way of doing it?
|
||||
current_state_ids_map = (
|
||||
await self.store.get_partial_current_state_ids(room_id)
|
||||
)
|
||||
current_state_ids = frozenset(current_state_ids_map.values())
|
||||
|
||||
@@ -621,21 +626,32 @@ class SyncHandler:
|
||||
)
|
||||
|
||||
async def get_state_after_event(
|
||||
self, event: EventBase, state_filter: Optional[StateFilter] = None
|
||||
self, event_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""
|
||||
Get the room state after the given event
|
||||
|
||||
Args:
|
||||
event: event of interest
|
||||
event_id: event of interest
|
||||
state_filter: The state filter used to fetch state from the database.
|
||||
"""
|
||||
state_ids = await self._state_storage_controller.get_state_ids_for_event(
|
||||
event.event_id, state_filter=state_filter or StateFilter.all()
|
||||
event_id, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
if event.is_state():
|
||||
|
||||
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
|
||||
# with redactions: if `event_id` is a redaction event, and we don't have the
|
||||
# original (possibly because it got purged), get_event will refuse to return
|
||||
# the redaction event, which isn't terribly helpful here.
|
||||
#
|
||||
# (To be fair, in that case we could assume it's *not* a state event, and
|
||||
# therefore we don't need to worry about it. But still, it seems cleaner just
|
||||
# to pull the metadata.)
|
||||
m = (await self.store.get_metadata_for_events([event_id]))[event_id]
|
||||
if m.state_key is not None and m.rejection_reason is None:
|
||||
state_ids = dict(state_ids)
|
||||
state_ids[(event.type, event.state_key)] = event.event_id
|
||||
state_ids[(m.event_type, m.state_key)] = event_id
|
||||
|
||||
return state_ids
|
||||
|
||||
async def get_state_at(
|
||||
@@ -654,14 +670,14 @@ class SyncHandler:
|
||||
# FIXME: This gets the state at the latest event before the stream ordering,
|
||||
# which might not be the same as the "current state" of the room at the time
|
||||
# of the stream token if there were multiple forward extremities at the time.
|
||||
last_event = await self.store.get_last_event_in_room_before_stream_ordering(
|
||||
last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
|
||||
room_id,
|
||||
end_token=stream_position.room_key,
|
||||
)
|
||||
|
||||
if last_event:
|
||||
if last_event_id:
|
||||
state = await self.get_state_after_event(
|
||||
last_event, state_filter=state_filter or StateFilter.all()
|
||||
last_event_id, state_filter=state_filter or StateFilter.all()
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -59,6 +59,7 @@ class FollowerTypingHandler:
|
||||
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.server_name = hs.config.server.server_name
|
||||
self.clock = hs.get_clock()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -131,7 +132,6 @@ class FollowerTypingHandler:
|
||||
return
|
||||
|
||||
try:
|
||||
users = await self.store.get_users_in_room(member.room_id)
|
||||
self._member_last_federation_poke[member] = self.clock.time_msec()
|
||||
|
||||
now = self.clock.time_msec()
|
||||
@@ -139,7 +139,10 @@ class FollowerTypingHandler:
|
||||
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
|
||||
)
|
||||
|
||||
for domain in {get_domain_from_id(u) for u in users}:
|
||||
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
|
||||
member.room_id
|
||||
)
|
||||
for domain in hosts:
|
||||
if domain != self.server_name:
|
||||
logger.debug("sending typing update to %s", domain)
|
||||
self.federation.build_and_send_edu(
|
||||
|
||||
@@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
super().__init__(hs)
|
||||
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.server_name = hs.hostname
|
||||
self.clock = hs.get_clock()
|
||||
self.notifier = hs.get_notifier()
|
||||
@@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
|
||||
logger.debug(
|
||||
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
|
||||
)
|
||||
max_pos, deltas = await self.store.get_current_state_deltas(
|
||||
(
|
||||
max_pos,
|
||||
deltas,
|
||||
) = await self._storage_controllers.state.get_current_state_deltas(
|
||||
self.pos, room_max_stream_ordering
|
||||
)
|
||||
|
||||
|
||||
@@ -194,6 +194,7 @@ class ModuleApi:
|
||||
self._store: Union[
|
||||
DataStore, "GenericWorkerSlavedStore"
|
||||
] = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._auth = hs.get_auth()
|
||||
self._auth_handler = auth_handler
|
||||
self._server_name = hs.hostname
|
||||
@@ -911,7 +912,7 @@ class ModuleApi:
|
||||
The filtered state events in the room.
|
||||
"""
|
||||
state_ids = yield defer.ensureDeferred(
|
||||
self._store.get_filtered_current_state_ids(
|
||||
self._storage_controllers.state.get_current_state_ids(
|
||||
room_id=room_id, state_filter=StateFilter.from_types(types)
|
||||
)
|
||||
)
|
||||
@@ -1289,20 +1290,16 @@ class ModuleApi:
|
||||
# regardless of their state key
|
||||
]
|
||||
"""
|
||||
state_filter = None
|
||||
if event_filter:
|
||||
# If a filter was provided, turn it into a StateFilter and retrieve a filtered
|
||||
# view of the state.
|
||||
state_filter = StateFilter.from_types(event_filter)
|
||||
state_ids = await self._store.get_filtered_current_state_ids(
|
||||
room_id,
|
||||
state_filter,
|
||||
)
|
||||
else:
|
||||
# If no filter was provided, get the whole state. We could also reuse the call
|
||||
# to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
|
||||
# but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
|
||||
# is, so using the latter when we can is better for perf.
|
||||
state_ids = await self._store.get_current_state_ids(room_id)
|
||||
|
||||
state_ids = await self._storage_controllers.state.get_current_state_ids(
|
||||
room_id,
|
||||
state_filter,
|
||||
)
|
||||
|
||||
state_events = await self._store.get_events(state_ids.values())
|
||||
|
||||
|
||||
+1
-1
@@ -681,7 +681,7 @@ class Notifier:
|
||||
return joined_room_ids, True
|
||||
|
||||
async def _is_world_readable(self, room_id: str) -> bool:
|
||||
state = await self.state_handler.get_current_state(
|
||||
state = await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.RoomHistoryVisibility, ""
|
||||
)
|
||||
if state and "history_visibility" in state.content:
|
||||
|
||||
@@ -255,7 +255,9 @@ class Mailer:
|
||||
user_display_name = user_id
|
||||
|
||||
async def _fetch_room_state(room_id: str) -> None:
|
||||
room_state = await self.store.get_current_state_ids(room_id)
|
||||
room_state = await self._state_storage_controller.get_current_state_ids(
|
||||
room_id
|
||||
)
|
||||
state_by_room[room_id] = room_state
|
||||
|
||||
# Run at most 3 of these at once: sync does 10 at a time but email
|
||||
|
||||
+27
-10
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
|
||||
assert_user_is_admin,
|
||||
)
|
||||
from synapse.storage.databases.main.room import RoomSortOrder
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.types import JsonDict, RoomID, UserID, create_requester
|
||||
from synapse.util import json_decoder
|
||||
|
||||
@@ -418,6 +419,7 @@ class RoomStateRestServlet(RestServlet):
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.clock = hs.get_clock()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
|
||||
@@ -430,7 +432,7 @@ class RoomStateRestServlet(RestServlet):
|
||||
if not ret:
|
||||
raise NotFoundError("Room not found")
|
||||
|
||||
event_ids = await self.store.get_current_state_ids(room_id)
|
||||
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
|
||||
events = await self.store.get_events(event_ids.values())
|
||||
now = self.clock.time_msec()
|
||||
room_state = self._event_serializer.serialize_events(events.values(), now)
|
||||
@@ -447,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.admin_handler = hs.get_admin_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.is_mine = hs.is_mine
|
||||
|
||||
async def on_POST(
|
||||
@@ -489,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
|
||||
)
|
||||
|
||||
# send invite if room has "JoinRules.INVITE"
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.JoinRules, ""
|
||||
)
|
||||
)
|
||||
if join_rules_event:
|
||||
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
|
||||
# update_membership with an action of "invite" can raise a
|
||||
@@ -535,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
super().__init__(hs)
|
||||
self.auth = hs.get_auth()
|
||||
self.store = hs.get_datastores().main
|
||||
self._state_storage_controller = hs.get_storage_controllers().state
|
||||
self.event_creation_handler = hs.get_event_creation_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self.is_mine_id = hs.is_mine_id
|
||||
@@ -552,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
user_to_add = content.get("user_id", requester.user.to_string())
|
||||
|
||||
# Figure out which local users currently have power in the room, if any.
|
||||
room_state = await self.state_handler.get_current_state(room_id)
|
||||
if not room_state:
|
||||
filtered_room_state = await self._state_storage_controller.get_current_state(
|
||||
room_id,
|
||||
StateFilter.from_types(
|
||||
[
|
||||
(EventTypes.Create, ""),
|
||||
(EventTypes.PowerLevels, ""),
|
||||
(EventTypes.JoinRules, ""),
|
||||
(EventTypes.Member, user_to_add),
|
||||
]
|
||||
),
|
||||
)
|
||||
if not filtered_room_state:
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
|
||||
|
||||
create_event = room_state[(EventTypes.Create, "")]
|
||||
power_levels = room_state.get((EventTypes.PowerLevels, ""))
|
||||
create_event = filtered_room_state[(EventTypes.Create, "")]
|
||||
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
|
||||
|
||||
if power_levels is not None:
|
||||
# We pick the local user with the highest power.
|
||||
@@ -633,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
|
||||
# Now we check if the user we're granting admin rights to is already in
|
||||
# the room. If not and it's not a public room we invite them.
|
||||
member_event = room_state.get((EventTypes.Member, user_to_add))
|
||||
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
|
||||
is_joined = False
|
||||
if member_event:
|
||||
is_joined = member_event.content["membership"] in (
|
||||
@@ -644,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
|
||||
if is_joined:
|
||||
return HTTPStatus.OK, {}
|
||||
|
||||
join_rules = room_state.get((EventTypes.JoinRules, ""))
|
||||
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
|
||||
is_public = False
|
||||
if join_rules:
|
||||
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
|
||||
|
||||
@@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet):
|
||||
if not isinstance(password, str) or len(password) > 512:
|
||||
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
|
||||
|
||||
logout_devices = body.get("logout_devices", True)
|
||||
if not isinstance(logout_devices, bool):
|
||||
raise SynapseError(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"'logout_devices' parameter is not of type boolean",
|
||||
)
|
||||
|
||||
deactivate = body.get("deactivated", False)
|
||||
if not isinstance(deactivate, bool):
|
||||
raise SynapseError(
|
||||
@@ -305,7 +312,6 @@ class UserRestServletV2(RestServlet):
|
||||
await self.store.set_server_admin(target_user, set_admin_to)
|
||||
|
||||
if password is not None:
|
||||
logout_devices = True
|
||||
new_password_hash = await self.auth_handler.hash(password)
|
||||
|
||||
await self.set_password_handler.set_password(
|
||||
|
||||
@@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
|
||||
self.clock = hs.get_clock()
|
||||
self._store = hs.get_datastores().main
|
||||
self._state = hs.get_state_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self.event_handler = hs.get_event_handler()
|
||||
self._event_serializer = hs.get_event_client_serializer()
|
||||
self._relations_handler = hs.get_relations_handler()
|
||||
@@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
|
||||
if include_unredacted_content and not await self.auth.is_server_admin(
|
||||
requester.user
|
||||
):
|
||||
power_level_event = await self._state.get_current_state(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
power_level_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, EventTypes.PowerLevels, ""
|
||||
)
|
||||
)
|
||||
|
||||
auth_events = {}
|
||||
|
||||
@@ -587,15 +587,16 @@ class MediaRepository:
|
||||
)
|
||||
return None
|
||||
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
with thumbnailer:
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
@@ -657,15 +658,16 @@ class MediaRepository:
|
||||
)
|
||||
return None
|
||||
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
with thumbnailer:
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
self._generate_thumbnail,
|
||||
thumbnailer,
|
||||
t_width,
|
||||
t_height,
|
||||
t_method,
|
||||
t_type,
|
||||
)
|
||||
|
||||
if t_byte_source:
|
||||
try:
|
||||
@@ -749,119 +751,134 @@ class MediaRepository:
|
||||
)
|
||||
return None
|
||||
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
with thumbnailer:
|
||||
m_width = thumbnailer.width
|
||||
m_height = thumbnailer.height
|
||||
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width,
|
||||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
return None
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.transpose
|
||||
)
|
||||
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails: Dict[Tuple[int, int, str], str] = {}
|
||||
for requirement in requirements:
|
||||
if requirement.method == "crop":
|
||||
thumbnails.setdefault(
|
||||
(requirement.width, requirement.height, requirement.media_type),
|
||||
requirement.method,
|
||||
if m_width * m_height >= self.max_image_pixels:
|
||||
logger.info(
|
||||
"Image too large to thumbnail %r x %r > %r",
|
||||
m_width,
|
||||
m_height,
|
||||
self.max_image_pixels,
|
||||
)
|
||||
elif requirement.method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(
|
||||
requirement.width, requirement.height
|
||||
return None
|
||||
|
||||
if thumbnailer.transpose_method is not None:
|
||||
m_width, m_height = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.transpose
|
||||
)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
thumbnails[
|
||||
(t_width, t_height, requirement.media_type)
|
||||
] = requirement.method
|
||||
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.items():
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
|
||||
)
|
||||
elif t_method == "scale":
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
|
||||
)
|
||||
else:
|
||||
logger.error("Unrecognized method: %r", t_method)
|
||||
continue
|
||||
|
||||
if not t_byte_source:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
url_cache=url_cache,
|
||||
thumbnail=ThumbnailInfo(
|
||||
width=t_width,
|
||||
height=t_height,
|
||||
method=t_method,
|
||||
type=t_type,
|
||||
),
|
||||
)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
|
||||
try:
|
||||
await self.media_storage.write_to_file(t_byte_source, f)
|
||||
await finish()
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
t_len = os.path.getsize(fname)
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
# Multiple remote media download requests can race (when
|
||||
# using multiple media repos), so this may throw a violation
|
||||
# constraint exception. If it does we'll delete the newly
|
||||
# generated thumbnail from disk (as we're in the ctx
|
||||
# manager).
|
||||
#
|
||||
# However: we've already called `finish()` so we may have
|
||||
# also written to the storage providers. This is preferable
|
||||
# to the alternative where we call `finish()` *after* this,
|
||||
# where we could end up having an entry in the DB but fail
|
||||
# to write the files to the storage providers.
|
||||
try:
|
||||
await self.store.store_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
t_method,
|
||||
t_len,
|
||||
)
|
||||
except Exception as e:
|
||||
thumbnail_exists = await self.store.get_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
)
|
||||
if not thumbnail_exists:
|
||||
raise e
|
||||
else:
|
||||
await self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
|
||||
# they have the same dimensions of a scaled one.
|
||||
thumbnails: Dict[Tuple[int, int, str], str] = {}
|
||||
for requirement in requirements:
|
||||
if requirement.method == "crop":
|
||||
thumbnails.setdefault(
|
||||
(requirement.width, requirement.height, requirement.media_type),
|
||||
requirement.method,
|
||||
)
|
||||
elif requirement.method == "scale":
|
||||
t_width, t_height = thumbnailer.aspect(
|
||||
requirement.width, requirement.height
|
||||
)
|
||||
t_width = min(m_width, t_width)
|
||||
t_height = min(m_height, t_height)
|
||||
thumbnails[
|
||||
(t_width, t_height, requirement.media_type)
|
||||
] = requirement.method
|
||||
|
||||
# Now we generate the thumbnails for each dimension, store it
|
||||
for (t_width, t_height, t_type), t_method in thumbnails.items():
|
||||
# Generate the thumbnail
|
||||
if t_method == "crop":
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
thumbnailer.crop,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
)
|
||||
elif t_method == "scale":
|
||||
t_byte_source = await defer_to_thread(
|
||||
self.hs.get_reactor(),
|
||||
thumbnailer.scale,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
)
|
||||
else:
|
||||
logger.error("Unrecognized method: %r", t_method)
|
||||
continue
|
||||
|
||||
if not t_byte_source:
|
||||
continue
|
||||
|
||||
file_info = FileInfo(
|
||||
server_name=server_name,
|
||||
file_id=file_id,
|
||||
url_cache=url_cache,
|
||||
thumbnail=ThumbnailInfo(
|
||||
width=t_width,
|
||||
height=t_height,
|
||||
method=t_method,
|
||||
type=t_type,
|
||||
),
|
||||
)
|
||||
|
||||
with self.media_storage.store_into_file(file_info) as (
|
||||
f,
|
||||
fname,
|
||||
finish,
|
||||
):
|
||||
try:
|
||||
await self.media_storage.write_to_file(t_byte_source, f)
|
||||
await finish()
|
||||
finally:
|
||||
t_byte_source.close()
|
||||
|
||||
t_len = os.path.getsize(fname)
|
||||
|
||||
# Write to database
|
||||
if server_name:
|
||||
# Multiple remote media download requests can race (when
|
||||
# using multiple media repos), so this may throw a violation
|
||||
# constraint exception. If it does we'll delete the newly
|
||||
# generated thumbnail from disk (as we're in the ctx
|
||||
# manager).
|
||||
#
|
||||
# However: we've already called `finish()` so we may have
|
||||
# also written to the storage providers. This is preferable
|
||||
# to the alternative where we call `finish()` *after* this,
|
||||
# where we could end up having an entry in the DB but fail
|
||||
# to write the files to the storage providers.
|
||||
try:
|
||||
await self.store.store_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
file_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
t_method,
|
||||
t_len,
|
||||
)
|
||||
except Exception as e:
|
||||
thumbnail_exists = (
|
||||
await self.store.get_remote_media_thumbnail(
|
||||
server_name,
|
||||
media_id,
|
||||
t_width,
|
||||
t_height,
|
||||
t_type,
|
||||
)
|
||||
)
|
||||
if not thumbnail_exists:
|
||||
raise e
|
||||
else:
|
||||
await self.store.store_local_thumbnail(
|
||||
media_id, t_width, t_height, t_type, t_method, t_len
|
||||
)
|
||||
|
||||
return {"width": m_width, "height": m_height}
|
||||
|
||||
|
||||
@@ -30,6 +30,9 @@ _xml_encoding_match = re.compile(
|
||||
)
|
||||
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
|
||||
|
||||
# Certain elements aren't meant for display.
|
||||
ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"}
|
||||
|
||||
|
||||
def _normalise_encoding(encoding: str) -> Optional[str]:
|
||||
"""Use the Python codec's name as the normalised entry."""
|
||||
@@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
|
||||
|
||||
og: Dict[str, Optional[str]] = {}
|
||||
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
|
||||
if "content" in tag.attrib:
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
if len(og) >= 50:
|
||||
logger.warning("Skipping OG for page with too many 'og:' tags")
|
||||
return {}
|
||||
og[tag.attrib["property"]] = tag.attrib["content"]
|
||||
for tag in tree.xpath(
|
||||
"//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
|
||||
):
|
||||
# if we've got more than 50 tags, someone is taking the piss
|
||||
if len(og) >= 50:
|
||||
logger.warning("Skipping OG for page with too many 'og:' tags")
|
||||
return {}
|
||||
|
||||
og[tag.attrib["property"]] = tag.attrib["content"]
|
||||
|
||||
# TODO: grab article: meta tags too, e.g.:
|
||||
|
||||
@@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
|
||||
|
||||
if "og:title" not in og:
|
||||
# do some basic spidering of the HTML
|
||||
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
|
||||
if title and title[0].text is not None:
|
||||
og["og:title"] = title[0].text.strip()
|
||||
# Attempt to find a title from the title tag, or the biggest header on the page.
|
||||
title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
|
||||
if title:
|
||||
og["og:title"] = title[0].strip()
|
||||
else:
|
||||
og["og:title"] = None
|
||||
|
||||
if "og:image" not in og:
|
||||
# TODO: extract a favicon failing all else
|
||||
meta_image = tree.xpath(
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
|
||||
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
|
||||
)
|
||||
# If a meta image is found, use it.
|
||||
if meta_image:
|
||||
og["og:image"] = meta_image[0]
|
||||
else:
|
||||
# Try to find images which are larger than 10px by 10px.
|
||||
#
|
||||
# TODO: consider inlined CSS styles as well as width & height attribs
|
||||
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
|
||||
images = sorted(
|
||||
@@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
|
||||
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
|
||||
),
|
||||
)
|
||||
# If no images were found, try to find *any* images.
|
||||
if not images:
|
||||
images = tree.xpath("//img[@src]")
|
||||
images = tree.xpath("//img[@src][1]")
|
||||
if images:
|
||||
og["og:image"] = images[0].attrib["src"]
|
||||
|
||||
# Finally, fallback to the favicon if nothing else.
|
||||
else:
|
||||
favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
|
||||
if favicons:
|
||||
og["og:image"] = favicons[0]
|
||||
|
||||
if "og:description" not in og:
|
||||
# Check the first meta description tag for content.
|
||||
meta_description = tree.xpath(
|
||||
"//*/meta"
|
||||
"[translate(@name, 'DESCRIPTION', 'description')='description']"
|
||||
"/@content"
|
||||
"//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
|
||||
)
|
||||
# If a meta description is found with content, use it.
|
||||
if meta_description:
|
||||
og["og:description"] = meta_description[0]
|
||||
else:
|
||||
@@ -306,6 +320,10 @@ def _iterate_over_text(
|
||||
if isinstance(el, str):
|
||||
yield el
|
||||
elif el.tag not in tags_to_ignore:
|
||||
# If the element isn't meant for display, ignore it.
|
||||
if el.get("role") in ARIA_ROLES_TO_IGNORE:
|
||||
continue
|
||||
|
||||
# el.text is the text before the first child, so we can immediately
|
||||
# return it if the text exists.
|
||||
if el.text:
|
||||
|
||||
@@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
og: The Open Graph dictionary. This is modified with image information.
|
||||
"""
|
||||
# If there's no image or it is blank, there's nothing to do.
|
||||
if "og:image" not in og or not og["og:image"]:
|
||||
if "og:image" not in og:
|
||||
return
|
||||
|
||||
# Remove the raw image URL, this will be replaced with an MXC URL, if successful.
|
||||
image_url = og.pop("og:image")
|
||||
if not image_url:
|
||||
return
|
||||
|
||||
# The image URL from the HTML might be relative to the previewed page,
|
||||
# convert it to an URL which can be requested directly.
|
||||
image_url = og["og:image"]
|
||||
url_parts = urlparse(image_url)
|
||||
if url_parts.scheme != "data":
|
||||
image_url = urljoin(media_info.uri, image_url)
|
||||
@@ -599,7 +603,16 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
# FIXME: it might be cleaner to use the same flow as the main /preview_url
|
||||
# request itself and benefit from the same caching etc. But for now we
|
||||
# just rely on the caching on the master request to speed things up.
|
||||
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
|
||||
try:
|
||||
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
|
||||
except Exception as e:
|
||||
# Pre-caching the image failed, don't block the entire URL preview.
|
||||
logger.warning(
|
||||
"Pre-caching image failed during URL preview: %s errored with %s",
|
||||
image_url,
|
||||
e,
|
||||
)
|
||||
return
|
||||
|
||||
if _is_media(image_info.media_type):
|
||||
# TODO: make sure we don't choke on white-on-transparent images
|
||||
@@ -611,13 +624,11 @@ class PreviewUrlResource(DirectServeJsonResource):
|
||||
og["og:image:width"] = dims["width"]
|
||||
og["og:image:height"] = dims["height"]
|
||||
else:
|
||||
logger.warning("Couldn't get dims for %s", og["og:image"])
|
||||
logger.warning("Couldn't get dims for %s", image_url)
|
||||
|
||||
og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
|
||||
og["og:image:type"] = image_info.media_type
|
||||
og["matrix:image:size"] = image_info.media_length
|
||||
else:
|
||||
del og["og:image"]
|
||||
|
||||
async def _handle_oembed_response(
|
||||
self, url: str, media_info: MediaInfo, expiration_ms: int
|
||||
|
||||
@@ -14,7 +14,8 @@
|
||||
# limitations under the License.
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from typing import Tuple
|
||||
from types import TracebackType
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
from PIL import Image
|
||||
|
||||
@@ -45,6 +46,9 @@ class Thumbnailer:
|
||||
Image.MAX_IMAGE_PIXELS = max_image_pixels
|
||||
|
||||
def __init__(self, input_path: str):
|
||||
# Have we closed the image?
|
||||
self._closed = False
|
||||
|
||||
try:
|
||||
self.image = Image.open(input_path)
|
||||
except OSError as e:
|
||||
@@ -89,7 +93,8 @@ class Thumbnailer:
|
||||
# Safety: `transpose` takes an int rather than e.g. an IntEnum.
|
||||
# self.transpose_method is set above to be a value in
|
||||
# EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
|
||||
self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
|
||||
with self.image:
|
||||
self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
|
||||
self.width, self.height = self.image.size
|
||||
self.transpose_method = None
|
||||
# We don't need EXIF any more
|
||||
@@ -122,9 +127,11 @@ class Thumbnailer:
|
||||
# If the image has transparency, use RGBA instead.
|
||||
if self.image.mode in ["1", "L", "P"]:
|
||||
if self.image.info.get("transparency", None) is not None:
|
||||
self.image = self.image.convert("RGBA")
|
||||
with self.image:
|
||||
self.image = self.image.convert("RGBA")
|
||||
else:
|
||||
self.image = self.image.convert("RGB")
|
||||
with self.image:
|
||||
self.image = self.image.convert("RGB")
|
||||
return self.image.resize((width, height), Image.ANTIALIAS)
|
||||
|
||||
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
|
||||
@@ -133,8 +140,8 @@ class Thumbnailer:
|
||||
Returns:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
scaled = self._resize(width, height)
|
||||
return self._encode_image(scaled, output_type)
|
||||
with self._resize(width, height) as scaled:
|
||||
return self._encode_image(scaled, output_type)
|
||||
|
||||
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
|
||||
"""Rescales and crops the image to the given dimensions preserving
|
||||
@@ -151,18 +158,21 @@ class Thumbnailer:
|
||||
BytesIO: the bytes of the encoded image ready to be written to disk
|
||||
"""
|
||||
if width * self.height > height * self.width:
|
||||
scaled_width = width
|
||||
scaled_height = (width * self.height) // self.width
|
||||
scaled_image = self._resize(width, scaled_height)
|
||||
crop_top = (scaled_height - height) // 2
|
||||
crop_bottom = height + crop_top
|
||||
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
|
||||
crop = (0, crop_top, width, crop_bottom)
|
||||
else:
|
||||
scaled_width = (height * self.width) // self.height
|
||||
scaled_image = self._resize(scaled_width, height)
|
||||
scaled_height = height
|
||||
crop_left = (scaled_width - width) // 2
|
||||
crop_right = width + crop_left
|
||||
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
|
||||
return self._encode_image(cropped, output_type)
|
||||
crop = (crop_left, 0, crop_right, height)
|
||||
|
||||
with self._resize(scaled_width, scaled_height) as scaled_image:
|
||||
with scaled_image.crop(crop) as cropped:
|
||||
return self._encode_image(cropped, output_type)
|
||||
|
||||
def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
|
||||
output_bytes_io = BytesIO()
|
||||
@@ -171,3 +181,42 @@ class Thumbnailer:
|
||||
output_image = output_image.convert("RGB")
|
||||
output_image.save(output_bytes_io, fmt, quality=80)
|
||||
return output_bytes_io
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the underlying image file.
|
||||
|
||||
Once closed no other functions can be called.
|
||||
|
||||
Can be called multiple times.
|
||||
"""
|
||||
|
||||
if self._closed:
|
||||
return
|
||||
|
||||
self._closed = True
|
||||
|
||||
# Since we run this on the finalizer then we need to handle `__init__`
|
||||
# raising an exception before it can define `self.image`.
|
||||
image = getattr(self, "image", None)
|
||||
if image is None:
|
||||
return
|
||||
|
||||
image.close()
|
||||
|
||||
def __enter__(self) -> "Thumbnailer":
|
||||
"""Make `Thumbnailer` a context manager that calls `close` on
|
||||
`__exit__`.
|
||||
"""
|
||||
return self
|
||||
|
||||
def __exit__(
|
||||
self,
|
||||
type: Optional[Type[BaseException]],
|
||||
value: Optional[BaseException],
|
||||
traceback: Optional[TracebackType],
|
||||
) -> None:
|
||||
self.close()
|
||||
|
||||
def __del__(self) -> None:
|
||||
# Make sure we actually do close the image, rather than leak data.
|
||||
self.close()
|
||||
|
||||
@@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
|
||||
def __init__(self, hs: "HomeServer"):
|
||||
self._server_notices_manager = hs.get_server_notices_manager()
|
||||
self._store = hs.get_datastores().main
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
self._auth = hs.get_auth()
|
||||
self._config = hs.config
|
||||
self._resouce_limited = False
|
||||
@@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
|
||||
currently_blocked = False
|
||||
pinned_state_event = None
|
||||
try:
|
||||
pinned_state_event = await self._state.get_current_state(
|
||||
room_id, event_type=EventTypes.Pinned
|
||||
pinned_state_event = (
|
||||
await self._storage_controllers.state.get_current_state_event(
|
||||
room_id, event_type=EventTypes.Pinned, state_key=""
|
||||
)
|
||||
)
|
||||
except AuthError:
|
||||
# The user has yet to join the server notices room
|
||||
|
||||
@@ -32,13 +32,11 @@ from typing import (
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
import attr
|
||||
from frozendict import frozendict
|
||||
from prometheus_client import Counter, Histogram
|
||||
from typing_extensions import Literal
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
|
||||
@@ -132,85 +130,20 @@ class StateHandler:
|
||||
self._state_resolution_handler = hs.get_state_resolution_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Literal[None] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> StateMap[EventBase]:
|
||||
...
|
||||
|
||||
@overload
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: str,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Optional[EventBase]:
|
||||
...
|
||||
|
||||
async def get_current_state(
|
||||
self,
|
||||
room_id: str,
|
||||
event_type: Optional[str] = None,
|
||||
state_key: str = "",
|
||||
latest_event_ids: Optional[List[str]] = None,
|
||||
) -> Union[Optional[EventBase], StateMap[EventBase]]:
|
||||
"""Retrieves the current state for the room. This is done by
|
||||
calling `get_latest_events_in_room` to get the leading edges of the
|
||||
event graph and then resolving any of the state conflicts.
|
||||
|
||||
This is equivalent to getting the state of an event that were to send
|
||||
next before receiving any new events.
|
||||
|
||||
Returns:
|
||||
If `event_type` is specified, then the method returns only the one
|
||||
event (or None) with that `event_type` and `state_key`.
|
||||
|
||||
Otherwise, a map from (type, state_key) to event.
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
state = ret.state
|
||||
|
||||
if event_type:
|
||||
event_id = state.get((event_type, state_key))
|
||||
event = None
|
||||
if event_id:
|
||||
event = await self.store.get_event(event_id, allow_none=True)
|
||||
return event
|
||||
|
||||
state_map = await self.store.get_events(
|
||||
list(state.values()), get_prev_content=False
|
||||
)
|
||||
return {
|
||||
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
|
||||
}
|
||||
|
||||
async def get_current_state_ids(
|
||||
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
|
||||
self,
|
||||
room_id: str,
|
||||
latest_event_ids: Collection[str],
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state, or the state at a set of events, for a room
|
||||
|
||||
Args:
|
||||
room_id:
|
||||
latest_event_ids: if given, the forward extremities to resolve. If
|
||||
None, we look them up from the database (via a cache).
|
||||
latest_event_ids: The forward extremities to resolve.
|
||||
|
||||
Returns:
|
||||
the state dict, mapping from (event_type, state_key) -> event_id
|
||||
"""
|
||||
if not latest_event_ids:
|
||||
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
assert latest_event_ids is not None
|
||||
|
||||
logger.debug("calling resolve_state_groups from get_current_state_ids")
|
||||
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return ret.state
|
||||
@@ -239,10 +172,6 @@ class StateHandler:
|
||||
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
|
||||
return await self.store.get_joined_users_from_state(room_id, entry)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
|
||||
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
|
||||
return await self.get_hosts_in_room_at_events(room_id, event_ids)
|
||||
|
||||
async def get_hosts_in_room_at_events(
|
||||
self, room_id: str, event_ids: Collection[str]
|
||||
) -> FrozenSet[str]:
|
||||
|
||||
@@ -71,13 +71,14 @@ class SQLBaseStore(metaclass=ABCMeta):
|
||||
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
|
||||
if members_changed:
|
||||
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
|
||||
self._attempt_to_invalidate_cache(
|
||||
"get_users_in_room_with_profiles", (room_id,)
|
||||
)
|
||||
|
||||
# Purge other caches based on room state.
|
||||
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
|
||||
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
|
||||
|
||||
def _attempt_to_invalidate_cache(
|
||||
self, cache_name: str, key: Optional[Collection[Any]]
|
||||
|
||||
@@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import (
|
||||
EventsPersistenceStorageController,
|
||||
)
|
||||
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
|
||||
from synapse.storage.controllers.state import StateGroupStorageController
|
||||
from synapse.storage.controllers.state import StateStorageController
|
||||
from synapse.storage.databases import Databases
|
||||
from synapse.storage.databases.main import DataStore
|
||||
|
||||
@@ -39,7 +39,7 @@ class StorageControllers:
|
||||
self.main = stores.main
|
||||
|
||||
self.purge_events = PurgeEventsStorageController(hs, stores)
|
||||
self.state = StateGroupStorageController(hs, stores)
|
||||
self.state = StateStorageController(hs, stores)
|
||||
|
||||
self.persistence = None
|
||||
if stores.persist_events:
|
||||
|
||||
@@ -994,7 +994,7 @@ class EventsPersistenceStorageController:
|
||||
|
||||
Assumes that we are only persisting events for one room at a time.
|
||||
"""
|
||||
existing_state = await self.main_store.get_current_state_ids(room_id)
|
||||
existing_state = await self.main_store.get_partial_current_state_ids(room_id)
|
||||
|
||||
to_delete = [key for key in existing_state if key not in current_state]
|
||||
|
||||
@@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController:
|
||||
# The server will leave the room, so we go and find out which remote
|
||||
# users will still be joined when we leave.
|
||||
if current_state is None:
|
||||
current_state = await self.main_store.get_current_state_ids(room_id)
|
||||
current_state = await self.main_store.get_partial_current_state_ids(room_id)
|
||||
current_state = dict(current_state)
|
||||
for key in delta.to_delete:
|
||||
current_state.pop(key, None)
|
||||
|
||||
@@ -14,19 +14,26 @@
|
||||
import logging
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Collection,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
)
|
||||
|
||||
from synapse.api.constants import EventTypes
|
||||
from synapse.events import EventBase
|
||||
from synapse.storage.state import StateFilter
|
||||
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
|
||||
from synapse.storage.util.partial_state_events_tracker import (
|
||||
PartialCurrentStateTracker,
|
||||
PartialStateEventsTracker,
|
||||
)
|
||||
from synapse.types import MutableStateMap, StateMap
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -36,17 +43,27 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class StateGroupStorageController:
|
||||
"""High level interface to fetching state for event."""
|
||||
class StateStorageController:
|
||||
"""High level interface to fetching state for an event, or the current state
|
||||
in a room.
|
||||
"""
|
||||
|
||||
def __init__(self, hs: "HomeServer", stores: "Databases"):
|
||||
self._is_mine_id = hs.is_mine_id
|
||||
self.stores = stores
|
||||
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
|
||||
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
|
||||
|
||||
def notify_event_un_partial_stated(self, event_id: str) -> None:
|
||||
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
|
||||
|
||||
def notify_room_un_partial_stated(self, room_id: str) -> None:
|
||||
"""Notify that the room no longer has any partial state.
|
||||
|
||||
Must be called after `DataStore.clear_partial_state_room`
|
||||
"""
|
||||
self._partial_state_room_tracker.notify_un_partial_stated(room_id)
|
||||
|
||||
async def get_state_group_delta(
|
||||
self, state_group: int
|
||||
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
|
||||
@@ -349,3 +366,127 @@ class StateGroupStorageController:
|
||||
return await self.stores.state.store_state_group(
|
||||
event_id, room_id, prev_group, delta_ids, current_state_ids
|
||||
)
|
||||
|
||||
async def get_current_state_ids(
|
||||
self,
|
||||
room_id: str,
|
||||
state_filter: Optional[StateFilter] = None,
|
||||
on_invalidate: Optional[Callable[[], None]] = None,
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state event ids for a room based on the
|
||||
current_state_events table.
|
||||
|
||||
If a state filter is given (that is not `StateFilter.all()`) the query
|
||||
result is *not* cached.
|
||||
|
||||
Args:
|
||||
room_id: The room to get the state IDs of. state_filter: The state
|
||||
filter used to fetch state from the
|
||||
database.
|
||||
on_invalidate: Callback for when the `get_current_state_ids` cache
|
||||
for the room gets invalidated.
|
||||
|
||||
Returns:
|
||||
The current state of the room.
|
||||
"""
|
||||
if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
|
||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
if state_filter and not state_filter.is_full():
|
||||
return await self.stores.main.get_partial_filtered_current_state_ids(
|
||||
room_id, state_filter
|
||||
)
|
||||
else:
|
||||
return await self.stores.main.get_partial_current_state_ids(
|
||||
room_id, on_invalidate=on_invalidate
|
||||
)
|
||||
|
||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||
"""Get canonical alias for room, if any
|
||||
|
||||
Args:
|
||||
room_id: The room ID
|
||||
|
||||
Returns:
|
||||
The canonical alias, if any
|
||||
"""
|
||||
|
||||
state = await self.get_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||
)
|
||||
|
||||
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
||||
if not event_id:
|
||||
return None
|
||||
|
||||
event = await self.stores.main.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
return event.content.get("canonical_alias")
|
||||
|
||||
async def get_current_state_deltas(
|
||||
self, prev_stream_id: int, max_stream_id: int
|
||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
"""Fetch a list of room state changes since the given stream id
|
||||
|
||||
Each entry in the result contains the following fields:
|
||||
- stream_id (int)
|
||||
- room_id (str)
|
||||
- type (str): event type
|
||||
- state_key (str):
|
||||
- event_id (str|None): new event_id for this state key. None if the
|
||||
state has been deleted.
|
||||
- prev_event_id (str|None): previous event_id for this state key. None
|
||||
if it's new state.
|
||||
|
||||
Args:
|
||||
prev_stream_id: point to get changes since (exclusive)
|
||||
max_stream_id: the point that we know has been correctly persisted
|
||||
- ie, an upper limit to return changes from.
|
||||
|
||||
Returns:
|
||||
A tuple consisting of:
|
||||
- the stream id which these results go up to
|
||||
- list of current_state_delta_stream rows. If it is empty, we are
|
||||
up to date.
|
||||
"""
|
||||
# FIXME(faster_joins): what do we do here?
|
||||
|
||||
return await self.stores.main.get_partial_current_state_deltas(
|
||||
prev_stream_id, max_stream_id
|
||||
)
|
||||
|
||||
async def get_current_state(
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[EventBase]:
|
||||
"""Same as `get_current_state_ids` but also fetches the events"""
|
||||
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
|
||||
|
||||
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
|
||||
|
||||
state_map = {}
|
||||
for key, event_id in state_map_ids.items():
|
||||
event = event_map.get(event_id)
|
||||
if event:
|
||||
state_map[key] = event
|
||||
|
||||
return state_map
|
||||
|
||||
async def get_current_state_event(
|
||||
self, room_id: str, event_type: str, state_key: str
|
||||
) -> Optional[EventBase]:
|
||||
"""Get the current state event for the given type/state_key."""
|
||||
|
||||
key = (event_type, state_key)
|
||||
state_map = await self.get_current_state(
|
||||
room_id, StateFilter.from_types((key,))
|
||||
)
|
||||
return state_map.get(key)
|
||||
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state."""
|
||||
|
||||
await self._partial_state_room_tracker.await_full_state(room_id)
|
||||
|
||||
return await self.stores.main.get_current_hosts_in_room(room_id)
|
||||
|
||||
@@ -151,10 +151,6 @@ class DataStore(
|
||||
],
|
||||
)
|
||||
|
||||
self._group_updates_id_gen = StreamIdGenerator(
|
||||
db_conn, "local_group_updates", "stream_id"
|
||||
)
|
||||
|
||||
self._cache_id_gen: Optional[MultiWriterIdGenerator]
|
||||
if isinstance(self.database_engine, PostgresEngine):
|
||||
# We set the `writers` to an empty list here as we don't care about
|
||||
@@ -197,20 +193,6 @@ class DataStore(
|
||||
prefilled_cache=curr_state_delta_prefill,
|
||||
)
|
||||
|
||||
_group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
|
||||
db_conn,
|
||||
"local_group_updates",
|
||||
entity_column="user_id",
|
||||
stream_column="stream_id",
|
||||
max_value=self._group_updates_id_gen.get_current_token(),
|
||||
limit=1000,
|
||||
)
|
||||
self._group_updates_stream_cache = StreamChangeCache(
|
||||
"_group_updates_stream_cache",
|
||||
min_group_updates_id,
|
||||
prefilled_cache=_group_updates_prefill,
|
||||
)
|
||||
|
||||
self._stream_order_on_start = self.get_room_max_stream_ordering()
|
||||
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
|
||||
|
||||
|
||||
@@ -29,11 +29,6 @@ class GroupServerStore(SQLBaseStore):
|
||||
db_conn: LoggingDatabaseConnection,
|
||||
hs: "HomeServer",
|
||||
):
|
||||
database.updates.register_background_index_update(
|
||||
update_name="local_group_updates_index",
|
||||
index_name="local_group_updates_stream_id_index",
|
||||
table="local_group_updates",
|
||||
columns=("stream_id",),
|
||||
unique=True,
|
||||
)
|
||||
# Register a legacy groups background update as a no-op.
|
||||
database.updates.register_noop_background_update("local_group_updates_index")
|
||||
super().__init__(database, db_conn, hs)
|
||||
|
||||
@@ -276,10 +276,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
|
||||
(SELECT 1
|
||||
FROM profiles
|
||||
WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
|
||||
AND NOT EXISTS
|
||||
(SELECT 1
|
||||
FROM groups
|
||||
WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
|
||||
AND NOT EXISTS
|
||||
(SELECT 1
|
||||
FROM room_memberships
|
||||
|
||||
@@ -11,11 +11,10 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from synapse.api.errors import StoreError
|
||||
from synapse.storage._base import SQLBaseStore
|
||||
from synapse.storage.database import LoggingTransaction
|
||||
from synapse.storage.databases.main.roommember import ProfileInfo
|
||||
|
||||
|
||||
@@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="get_profile_avatar_url",
|
||||
)
|
||||
|
||||
async def get_from_remote_profile_cache(
|
||||
self, user_id: str
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
return await self.db_pool.simple_select_one(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcols=("displayname", "avatar_url"),
|
||||
allow_none=True,
|
||||
desc="get_from_remote_profile_cache",
|
||||
)
|
||||
|
||||
async def create_profile(self, user_localpart: str) -> None:
|
||||
await self.db_pool.simple_insert(
|
||||
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
|
||||
@@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore):
|
||||
desc="set_profile_avatar_url",
|
||||
)
|
||||
|
||||
async def update_remote_profile_cache(
|
||||
self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
|
||||
) -> int:
|
||||
return await self.db_pool.simple_update(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
updatevalues={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="update_remote_profile_cache",
|
||||
)
|
||||
|
||||
async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
|
||||
"""Check if we still care about the remote user's profile, and if we
|
||||
don't then remove their profile from the cache
|
||||
"""
|
||||
subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
|
||||
if not subscribed:
|
||||
await self.db_pool.simple_delete(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
desc="delete_remote_profile_cache",
|
||||
)
|
||||
|
||||
async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
|
||||
"""Check whether we are interested in a remote user's profile."""
|
||||
res: Optional[str] = await self.db_pool.simple_select_one_onecol(
|
||||
table="group_users",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
return True
|
||||
|
||||
res = await self.db_pool.simple_select_one_onecol(
|
||||
table="group_invites",
|
||||
keyvalues={"user_id": user_id},
|
||||
retcol="user_id",
|
||||
allow_none=True,
|
||||
desc="should_update_remote_profile_cache_for_user",
|
||||
)
|
||||
|
||||
if res:
|
||||
return True
|
||||
return False
|
||||
|
||||
async def get_remote_profile_cache_entries_that_expire(
|
||||
self, last_checked: int
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Get all users who haven't been checked since `last_checked`"""
|
||||
|
||||
def _get_remote_profile_cache_entries_that_expire_txn(
|
||||
txn: LoggingTransaction,
|
||||
) -> List[Dict[str, str]]:
|
||||
sql = """
|
||||
SELECT user_id, displayname, avatar_url
|
||||
FROM remote_profile_cache
|
||||
WHERE last_check < ?
|
||||
"""
|
||||
|
||||
txn.execute(sql, (last_checked,))
|
||||
|
||||
return self.db_pool.cursor_to_dict(txn)
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_remote_profile_cache_entries_that_expire",
|
||||
_get_remote_profile_cache_entries_that_expire_txn,
|
||||
)
|
||||
|
||||
|
||||
class ProfileStore(ProfileWorkerStore):
|
||||
async def add_remote_profile_cache(
|
||||
self, user_id: str, displayname: str, avatar_url: str
|
||||
) -> None:
|
||||
"""Ensure we are caching the remote user's profiles.
|
||||
|
||||
This should only be called when `is_subscribed_remote_profile_for_user`
|
||||
would return true for the user.
|
||||
"""
|
||||
await self.db_pool.simple_upsert(
|
||||
table="remote_profile_cache",
|
||||
keyvalues={"user_id": user_id},
|
||||
values={
|
||||
"displayname": displayname,
|
||||
"avatar_url": avatar_url,
|
||||
"last_check": self._clock.time_msec(),
|
||||
},
|
||||
desc="add_remote_profile_cache",
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -393,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
"partial_state_events",
|
||||
"events",
|
||||
"federation_inbound_events_staging",
|
||||
"group_rooms",
|
||||
"local_current_membership",
|
||||
"partial_state_rooms_servers",
|
||||
"partial_state_rooms",
|
||||
@@ -413,7 +412,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
|
||||
"e2e_room_keys",
|
||||
"event_push_summary",
|
||||
"pusher_throttle",
|
||||
"group_summary_rooms",
|
||||
"room_account_data",
|
||||
"room_tags",
|
||||
# "rooms" happens last, to keep the foreign keys in the other tables
|
||||
|
||||
@@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
|
||||
keyvalues={"room_id": room_id},
|
||||
)
|
||||
|
||||
async def is_partial_state_room(self, room_id: str) -> bool:
|
||||
"""Checks if this room has partial state.
|
||||
|
||||
Returns true if this is a "partial-state" room, which means that the state
|
||||
at events in the room, and `current_state_events`, may not yet be
|
||||
complete.
|
||||
"""
|
||||
|
||||
entry = await self.db_pool.simple_select_one_onecol(
|
||||
table="partial_state_rooms",
|
||||
keyvalues={"room_id": room_id},
|
||||
retcol="room_id",
|
||||
allow_none=True,
|
||||
desc="is_partial_state_room",
|
||||
)
|
||||
|
||||
return entry is not None
|
||||
|
||||
|
||||
class _BackgroundUpdates:
|
||||
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
|
||||
|
||||
@@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
|
||||
|
||||
return True
|
||||
|
||||
@cached(iterable=True, max_entries=10000)
|
||||
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
|
||||
"""Get current hosts in room based on current state."""
|
||||
|
||||
# First we check if we already have `get_users_in_room` in the cache, as
|
||||
# we can just calculate result from that
|
||||
users = self.get_users_in_room.cache.get_immediate(
|
||||
(room_id,), None, update_metrics=False
|
||||
)
|
||||
if users is not None:
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
if isinstance(self.database_engine, Sqlite3Engine):
|
||||
# If we're using SQLite then let's just always use
|
||||
# `get_users_in_room` rather than funky SQL.
|
||||
users = await self.get_users_in_room(room_id)
|
||||
return {get_domain_from_id(u) for u in users}
|
||||
|
||||
# For PostgreSQL we can use a regex to pull out the domains from the
|
||||
# joined users in `current_state_events` via regex.
|
||||
|
||||
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
|
||||
sql = """
|
||||
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
|
||||
FROM current_state_events
|
||||
WHERE
|
||||
type = 'm.room.member'
|
||||
AND membership = 'join'
|
||||
AND room_id = ?
|
||||
"""
|
||||
txn.execute(sql, (room_id,))
|
||||
return {d for d, in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_hosts_in_room", get_current_hosts_in_room_txn
|
||||
)
|
||||
|
||||
async def get_joined_hosts(
|
||||
self, room_id: str, state_entry: "_StateCacheEntry"
|
||||
) -> FrozenSet[str]:
|
||||
|
||||
@@ -54,6 +54,7 @@ class EventMetadata:
|
||||
room_id: str
|
||||
event_type: str
|
||||
state_key: Optional[str]
|
||||
rejection_reason: Optional[str]
|
||||
|
||||
|
||||
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
|
||||
@@ -167,17 +168,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
)
|
||||
|
||||
sql = f"""
|
||||
SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e
|
||||
SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
|
||||
FROM events AS e
|
||||
LEFT JOIN state_events se USING (event_id)
|
||||
LEFT JOIN rejections r USING (event_id)
|
||||
WHERE {clause}
|
||||
"""
|
||||
|
||||
txn.execute(sql, args)
|
||||
return {
|
||||
event_id: EventMetadata(
|
||||
room_id=room_id, event_type=event_type, state_key=state_key
|
||||
room_id=room_id,
|
||||
event_type=event_type,
|
||||
state_key=state_key,
|
||||
rejection_reason=rejection_reason,
|
||||
)
|
||||
for event_id, room_id, event_type, state_key in txn
|
||||
for event_id, room_id, event_type, state_key, rejection_reason in txn
|
||||
}
|
||||
|
||||
result_map: Dict[str, EventMetadata] = {}
|
||||
@@ -236,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
Raises:
|
||||
NotFoundError if the room is unknown
|
||||
"""
|
||||
state_ids = await self.get_current_state_ids(room_id)
|
||||
state_ids = await self.get_partial_current_state_ids(room_id)
|
||||
|
||||
if not state_ids:
|
||||
raise NotFoundError(f"Current state for room {room_id} is empty")
|
||||
@@ -252,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return create_event
|
||||
|
||||
@cached(max_entries=100000, iterable=True)
|
||||
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
|
||||
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
|
||||
"""Get the current state event ids for a room based on the
|
||||
current_state_events table.
|
||||
|
||||
This may be the partial state if we're lazy joining the room.
|
||||
|
||||
Args:
|
||||
room_id: The room to get the state IDs of.
|
||||
|
||||
@@ -274,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
|
||||
|
||||
return await self.db_pool.runInteraction(
|
||||
"get_current_state_ids", _get_current_state_ids_txn
|
||||
"get_partial_current_state_ids", _get_current_state_ids_txn
|
||||
)
|
||||
|
||||
# FIXME: how should this be cached?
|
||||
async def get_filtered_current_state_ids(
|
||||
async def get_partial_filtered_current_state_ids(
|
||||
self, room_id: str, state_filter: Optional[StateFilter] = None
|
||||
) -> StateMap[str]:
|
||||
"""Get the current state event of a given type for a room based on the
|
||||
current_state_events table. This may not be as up-to-date as the result
|
||||
of doing a fresh state resolution as per state_handler.get_current_state
|
||||
|
||||
This may be the partial state if we're lazy joining the room.
|
||||
|
||||
Args:
|
||||
room_id
|
||||
state_filter: The state filter used to fetch state
|
||||
@@ -300,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
|
||||
if not where_clause:
|
||||
# We delegate to the cached version
|
||||
return await self.get_current_state_ids(room_id)
|
||||
return await self.get_partial_current_state_ids(room_id)
|
||||
|
||||
def _get_filtered_current_state_ids_txn(
|
||||
txn: LoggingTransaction,
|
||||
@@ -328,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
|
||||
)
|
||||
|
||||
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
|
||||
"""Get canonical alias for room, if any
|
||||
|
||||
Args:
|
||||
room_id: The room ID
|
||||
|
||||
Returns:
|
||||
The canonical alias, if any
|
||||
"""
|
||||
|
||||
state = await self.get_filtered_current_state_ids(
|
||||
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
|
||||
)
|
||||
|
||||
event_id = state.get((EventTypes.CanonicalAlias, ""))
|
||||
if not event_id:
|
||||
return None
|
||||
|
||||
event = await self.get_event(event_id, allow_none=True)
|
||||
if not event:
|
||||
return None
|
||||
|
||||
return event.content.get("canonical_alias")
|
||||
|
||||
@cached(max_entries=50000)
|
||||
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
|
||||
return await self.db_pool.simple_select_one_onecol(
|
||||
|
||||
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
|
||||
# attribute. TODO: can we get static analysis to enforce this?
|
||||
_curr_state_delta_stream_cache: StreamChangeCache
|
||||
|
||||
async def get_current_state_deltas(
|
||||
async def get_partial_current_state_deltas(
|
||||
self, prev_stream_id: int, max_stream_id: int
|
||||
) -> Tuple[int, List[Dict[str, Any]]]:
|
||||
"""Fetch a list of room state changes since the given stream id
|
||||
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
|
||||
- prev_event_id (str|None): previous event_id for this state key. None
|
||||
if it's new state.
|
||||
|
||||
This may be the partial state if we're lazy joining the room.
|
||||
|
||||
Args:
|
||||
prev_stream_id: point to get changes since (exclusive)
|
||||
max_stream_id: the point that we know has been correctly persisted
|
||||
|
||||
@@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
self,
|
||||
room_id: str,
|
||||
end_token: RoomStreamToken,
|
||||
) -> Optional[EventBase]:
|
||||
"""Returns the last event in a room at or before a stream ordering
|
||||
) -> Optional[str]:
|
||||
"""Returns the ID of the last event in a room at or before a stream ordering
|
||||
|
||||
Args:
|
||||
room_id
|
||||
end_token: The token used to stream from
|
||||
|
||||
Returns:
|
||||
The most recent event.
|
||||
The ID of the most recent event, or None if there are no events in the room
|
||||
before this stream ordering.
|
||||
"""
|
||||
|
||||
last_row = await self.get_room_event_before_stream_ordering(
|
||||
@@ -781,10 +782,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
|
||||
stream_ordering=end_token.stream,
|
||||
)
|
||||
if last_row:
|
||||
_, _, event_id = last_row
|
||||
event = await self.get_event(event_id, get_prev_content=True)
|
||||
return event
|
||||
|
||||
return last_row[2]
|
||||
return None
|
||||
|
||||
async def get_current_room_stream_token_for_room_id(
|
||||
|
||||
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
|
||||
(EventTypes.RoomHistoryVisibility, ""),
|
||||
)
|
||||
|
||||
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
|
||||
# Getting the partial state is fine, as we're not looking at membership
|
||||
# events.
|
||||
current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
|
||||
room_id, StateFilter.from_types(types_to_filter)
|
||||
)
|
||||
|
||||
|
||||
@@ -70,6 +70,7 @@ Changes in SCHEMA_VERSION = 70:
|
||||
|
||||
Changes in SCHEMA_VERSION = 71:
|
||||
- event_edges.room_id is no longer read from.
|
||||
- Tables related to groups are no longer accessed.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
|
||||
|
||||
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
|
||||
from synapse.storage.databases.main.events_worker import EventsWorkerStore
|
||||
from synapse.storage.databases.main.room import RoomWorkerStore
|
||||
from synapse.util import unwrapFirstError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
|
||||
observer_set.discard(observer)
|
||||
if not observer_set:
|
||||
del self._observers[event_id]
|
||||
|
||||
|
||||
class PartialCurrentStateTracker:
|
||||
"""Keeps track of which rooms have partial state, after partial-state joins"""
|
||||
|
||||
def __init__(self, store: RoomWorkerStore):
|
||||
self._store = store
|
||||
|
||||
# a map from room id to a set of Deferreds which are waiting for that room to be
|
||||
# un-partial-stated.
|
||||
self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
|
||||
|
||||
def notify_un_partial_stated(self, room_id: str) -> None:
|
||||
"""Notify that we now have full current state for a given room
|
||||
|
||||
Unblocks any callers to await_full_state() for that room.
|
||||
|
||||
Args:
|
||||
room_id: the room that now has full current state.
|
||||
"""
|
||||
observers = self._observers.pop(room_id, None)
|
||||
if not observers:
|
||||
return
|
||||
logger.info(
|
||||
"Notifying %i things waiting for un-partial-stating of room %s",
|
||||
len(observers),
|
||||
room_id,
|
||||
)
|
||||
with PreserveLoggingContext():
|
||||
for o in observers:
|
||||
o.callback(None)
|
||||
|
||||
async def await_full_state(self, room_id: str) -> None:
|
||||
# We add the deferred immediately so that the DB call to check for
|
||||
# partial state doesn't race when we unpartial the room.
|
||||
d: Deferred[None] = Deferred()
|
||||
self._observers.setdefault(room_id, set()).add(d)
|
||||
|
||||
try:
|
||||
# Check if the room has partial current state or not.
|
||||
has_partial_state = await self._store.is_partial_state_room(room_id)
|
||||
if not has_partial_state:
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"Awaiting un-partial-stating of room %s",
|
||||
room_id,
|
||||
)
|
||||
|
||||
await make_deferred_yieldable(d)
|
||||
|
||||
logger.info("Room has un-partial-stated")
|
||||
finally:
|
||||
# Remove the added observer, and remove the room entry if its empty.
|
||||
ds = self._observers.get(room_id)
|
||||
if ds is not None:
|
||||
ds.discard(d)
|
||||
if not ds:
|
||||
self._observers.pop(room_id, None)
|
||||
|
||||
+18
-10
@@ -162,16 +162,7 @@ async def filter_events_for_client(
|
||||
state = event_id_to_state[event.event_id]
|
||||
|
||||
# get the room_visibility at the time of the event.
|
||||
visibility_event = state.get(_HISTORY_VIS_KEY, None)
|
||||
if visibility_event:
|
||||
visibility = visibility_event.content.get(
|
||||
"history_visibility", HistoryVisibility.SHARED
|
||||
)
|
||||
else:
|
||||
visibility = HistoryVisibility.SHARED
|
||||
|
||||
if visibility not in VISIBILITY_PRIORITY:
|
||||
visibility = HistoryVisibility.SHARED
|
||||
visibility = get_effective_room_visibility_from_state(state)
|
||||
|
||||
# Always allow history visibility events on boundaries. This is done
|
||||
# by setting the effective visibility to the least restrictive
|
||||
@@ -267,6 +258,23 @@ async def filter_events_for_client(
|
||||
return [ev for ev in filtered_events if ev]
|
||||
|
||||
|
||||
def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
|
||||
"""Get the actual history vis, from a state map including the history_visibility event
|
||||
|
||||
Handles missing and invalid history visibility events.
|
||||
"""
|
||||
visibility_event = state.get(_HISTORY_VIS_KEY, None)
|
||||
if not visibility_event:
|
||||
return HistoryVisibility.SHARED
|
||||
|
||||
visibility = visibility_event.content.get(
|
||||
"history_visibility", HistoryVisibility.SHARED
|
||||
)
|
||||
if visibility not in VISIBILITY_PRIORITY:
|
||||
visibility = HistoryVisibility.SHARED
|
||||
return visibility
|
||||
|
||||
|
||||
async def filter_events_for_server(
|
||||
storage: StorageControllers,
|
||||
server_name: str,
|
||||
|
||||
@@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
appservice = ApplicationService(
|
||||
"abcd",
|
||||
self.hs.config.server.server_name,
|
||||
id="1234",
|
||||
namespaces={
|
||||
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
|
||||
@@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
appservice = ApplicationService(
|
||||
"abcd",
|
||||
self.hs.config.server.server_name,
|
||||
id="1234",
|
||||
namespaces={
|
||||
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
|
||||
|
||||
@@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None,
|
||||
"example.com",
|
||||
id="foo",
|
||||
rate_limited=True,
|
||||
sender="@as:example.com",
|
||||
@@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
|
||||
def test_allowed_appservice_via_can_requester_do_action(self):
|
||||
appservice = ApplicationService(
|
||||
None,
|
||||
"example.com",
|
||||
id="foo",
|
||||
rate_limited=False,
|
||||
sender="@as:example.com",
|
||||
|
||||
@@ -37,7 +37,6 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
|
||||
url=URL,
|
||||
token="unused",
|
||||
hs_token=TOKEN,
|
||||
hostname="myserver",
|
||||
)
|
||||
|
||||
def test_query_3pe_authenticates_token(self):
|
||||
|
||||
@@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
|
||||
sender="@as:test",
|
||||
url="some_url",
|
||||
token="some_token",
|
||||
hostname="matrix.org", # only used by get_groups_for_user
|
||||
)
|
||||
self.event = Mock(
|
||||
event_id="$abc:xyz",
|
||||
|
||||
@@ -12,10 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import nacl.signing
|
||||
import signedjson.types
|
||||
from unpaddedbase64 import decode_base64
|
||||
from signedjson.key import decode_signing_key_base64
|
||||
from signedjson.types import SigningKey
|
||||
|
||||
from synapse.api.room_versions import RoomVersions
|
||||
from synapse.crypto.event_signing import add_hashes_and_signatures
|
||||
@@ -25,7 +23,7 @@ from tests import unittest
|
||||
|
||||
# Perform these tests using given secret key so we get entirely deterministic
|
||||
# signatures output that we can test against.
|
||||
SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
|
||||
SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
|
||||
|
||||
KEY_ALG = "ed25519"
|
||||
KEY_VER = "1"
|
||||
@@ -36,14 +34,9 @@ HOSTNAME = "domain"
|
||||
|
||||
class EventSigningTestCase(unittest.TestCase):
|
||||
def setUp(self):
|
||||
# NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
|
||||
# monkeypatched to include new `alg` and `version` attributes. This is captured
|
||||
# by the `signedjson.types.SigningKey` protocol.
|
||||
self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment]
|
||||
SIGNING_KEY_SEED
|
||||
self.signing_key: SigningKey = decode_signing_key_base64(
|
||||
KEY_ALG, KEY_VER, SIGNING_KEY_SEED
|
||||
)
|
||||
self.signing_key.alg = KEY_ALG
|
||||
self.signing_key.version = KEY_VER
|
||||
|
||||
def test_sign_minimal(self):
|
||||
event_dict = {
|
||||
|
||||
@@ -19,8 +19,8 @@ import attr
|
||||
import canonicaljson
|
||||
import signedjson.key
|
||||
import signedjson.sign
|
||||
from nacl.signing import SigningKey
|
||||
from signedjson.key import encode_verify_key_base64, get_verify_key
|
||||
from signedjson.types import SigningKey
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.internet.defer import Deferred, ensureDeferred
|
||||
|
||||
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
|
||||
|
||||
class FederationSenderReceiptsTestCases(HomeserverTestCase):
|
||||
def make_homeserver(self, reactor, clock):
|
||||
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
|
||||
# Ensure a new Awaitable is created for each call.
|
||||
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
|
||||
["test", "host2"]
|
||||
)
|
||||
return self.setup_test_homeserver(
|
||||
state_handler=mock_state_handler,
|
||||
hs = self.setup_test_homeserver(
|
||||
federation_transport_client=Mock(spec=["send_transaction"]),
|
||||
)
|
||||
|
||||
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
|
||||
return_value=make_awaitable({"test", "host2"})
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
@override_config({"send_federation": True})
|
||||
def test_send_receipts(self):
|
||||
mock_send_transaction = (
|
||||
|
||||
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
|
||||
super().prepare(reactor, clock, hs)
|
||||
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
# create the room
|
||||
creator_user_id = self.register_user("kermit", "test")
|
||||
tok = self.login("kermit", "test")
|
||||
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self._storage_controllers.state.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
|
||||
|
||||
# the room should show that the new user is a member
|
||||
r = self.get_success(
|
||||
self.hs.get_state_handler().get_current_state(self._room_id)
|
||||
self._storage_controllers.state.get_current_state(self._room_id)
|
||||
)
|
||||
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
|
||||
|
||||
|
||||
@@ -697,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
|
||||
# Create an application service
|
||||
appservice = ApplicationService(
|
||||
token=random_string(10),
|
||||
hostname="example.com",
|
||||
id=random_string(10),
|
||||
sender="@as:example.com",
|
||||
rate_limited=False,
|
||||
@@ -776,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
|
||||
# Create an appservice that is interested in "local_user"
|
||||
appservice = ApplicationService(
|
||||
token=random_string(10),
|
||||
hostname="example.com",
|
||||
id=random_string(10),
|
||||
sender="@as:example.com",
|
||||
rate_limited=False,
|
||||
@@ -843,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
|
||||
self._service_token = "VERYSECRET"
|
||||
self._service = ApplicationService(
|
||||
self._service_token,
|
||||
"as1.invalid",
|
||||
"as1",
|
||||
"@as.sender:test",
|
||||
namespaces={
|
||||
|
||||
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
self.store = hs.get_datastores().main
|
||||
self.handler = hs.get_directory_handler()
|
||||
self.state_handler = hs.get_state_handler()
|
||||
self._storage_controllers = hs.get_storage_controllers()
|
||||
|
||||
# Create user
|
||||
self.admin_user = self.register_user("admin", "pass", admin=True)
|
||||
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
|
||||
def _get_canonical_alias(self):
|
||||
"""Get the canonical alias state of the room."""
|
||||
return self.get_success(
|
||||
self.state_handler.get_current_state(
|
||||
self._storage_controllers.state.get_current_state_event(
|
||||
self.room_id, EventTypes.CanonicalAlias, ""
|
||||
)
|
||||
)
|
||||
|
||||
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
)
|
||||
current_state = self.get_success(
|
||||
self.store.get_events_as_list(
|
||||
(self.get_success(self.store.get_current_state_ids(room_id))).values()
|
||||
(
|
||||
self.get_success(self.store.get_partial_current_state_ids(room_id))
|
||||
).values()
|
||||
)
|
||||
)
|
||||
|
||||
@@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
|
||||
self.get_success(d)
|
||||
|
||||
# sanity-check: the room should show that the new user is a member
|
||||
r = self.get_success(self.store.get_current_state_ids(room_id))
|
||||
r = self.get_success(self.store.get_partial_current_state_ids(room_id))
|
||||
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
|
||||
|
||||
return join_event
|
||||
|
||||
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
|
||||
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
|
||||
)
|
||||
|
||||
initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
|
||||
initial_state_map = self.get_success(
|
||||
main_store.get_partial_current_state_ids(room_id)
|
||||
)
|
||||
|
||||
auth_event_ids = [
|
||||
initial_state_map[("m.room.create", "")],
|
||||
|
||||
@@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
|
||||
|
||||
def get_joined_hosts_for_room(room_id: str):
|
||||
async def get_current_hosts_in_room(room_id: str):
|
||||
return {member.domain for member in self.room_members}
|
||||
|
||||
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
|
||||
hs.get_storage_controllers().state.get_current_hosts_in_room = (
|
||||
get_current_hosts_in_room
|
||||
)
|
||||
|
||||
async def get_users_in_room(room_id: str):
|
||||
return {str(u) for u in self.room_members}
|
||||
@@ -146,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
|
||||
)
|
||||
)
|
||||
|
||||
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
|
||||
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
|
||||
|
||||
self.datastore.get_to_device_stream_token = lambda: 0
|
||||
self.datastore.get_new_device_msgs_for_remote = (
|
||||
|
||||
@@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
hostname="test",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||
# Note: this user does not match the regex above, so that tests
|
||||
|
||||
@@ -2467,7 +2467,6 @@ PURGE_TABLES = [
|
||||
"event_push_actions",
|
||||
"event_search",
|
||||
"events",
|
||||
"group_rooms",
|
||||
"receipts_graph",
|
||||
"receipts_linearized",
|
||||
"room_aliases",
|
||||
@@ -2484,7 +2483,6 @@ PURGE_TABLES = [
|
||||
"e2e_room_keys",
|
||||
"event_push_summary",
|
||||
"pusher_throttle",
|
||||
"group_summary_rooms",
|
||||
"room_account_data",
|
||||
"room_tags",
|
||||
# "state_groups", # Current impl leaves orphaned state groups around.
|
||||
|
||||
@@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
appservice = ApplicationService(
|
||||
as_token,
|
||||
self.hs.config.server.server_name,
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
|
||||
sender=user_id,
|
||||
|
||||
@@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.service = ApplicationService(
|
||||
id="unique_identifier",
|
||||
token="some_token",
|
||||
hostname="example.com",
|
||||
sender="@asbot:example.com",
|
||||
namespaces={
|
||||
ApplicationService.NS_USERS: [
|
||||
@@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
|
||||
self.another_service = ApplicationService(
|
||||
id="another__identifier",
|
||||
token="another_token",
|
||||
hostname="example.com",
|
||||
sender="@as2bot:example.com",
|
||||
namespaces={
|
||||
ApplicationService.NS_USERS: [
|
||||
|
||||
@@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
appservice = ApplicationService(
|
||||
as_token,
|
||||
self.hs.config.server.server_name,
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||
sender="@as:test",
|
||||
@@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
appservice = ApplicationService(
|
||||
as_token,
|
||||
self.hs.config.server.server_name,
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||
sender="@as:test",
|
||||
|
||||
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
relation_type: str,
|
||||
assertion_callable: Callable[[JsonDict], None],
|
||||
expected_db_txn_for_event: int,
|
||||
access_token: Optional[str] = None,
|
||||
) -> None:
|
||||
"""
|
||||
Makes requests to various endpoints which should include bundled aggregations
|
||||
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
for relation-specific assertions.
|
||||
expected_db_txn_for_event: The number of database transactions which
|
||||
are expected for a call to /event/.
|
||||
access_token: The access token to user, defaults to self.user_token.
|
||||
"""
|
||||
access_token = access_token or self.user_token
|
||||
|
||||
def assert_bundle(event_json: JsonDict) -> None:
|
||||
"""Assert the expected values of the bundled aggregations."""
|
||||
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{self.room}/event/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
assert_bundle(channel.json_body)
|
||||
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{self.room}/messages?dir=b",
|
||||
access_token=self.user_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
|
||||
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
channel = self.make_request(
|
||||
"GET",
|
||||
f"/rooms/{self.room}/context/{self.parent_id}",
|
||||
access_token=self.user_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
assert_bundle(channel.json_body["event"])
|
||||
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
# Request sync.
|
||||
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
|
||||
channel = self.make_request(
|
||||
"GET", f"/sync?filter={filter}", access_token=self.user_token
|
||||
"GET", f"/sync?filter={filter}", access_token=access_token
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
|
||||
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
"/search",
|
||||
# Search term matches the parent message.
|
||||
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
|
||||
access_token=self.user_token,
|
||||
access_token=access_token,
|
||||
)
|
||||
self.assertEqual(200, channel.code, channel.json_body)
|
||||
chunk = [
|
||||
@@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
"""
|
||||
Test that threads get correctly bundled.
|
||||
"""
|
||||
self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
|
||||
# The root message is from "user", send replies as "user2".
|
||||
self._send_relation(
|
||||
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||
)
|
||||
channel = self._send_relation(
|
||||
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
|
||||
)
|
||||
thread_2 = channel.json_body["event_id"]
|
||||
|
||||
def assert_thread(bundled_aggregations: JsonDict) -> None:
|
||||
self.assertEqual(2, bundled_aggregations.get("count"))
|
||||
self.assertTrue(bundled_aggregations.get("current_user_participated"))
|
||||
# The latest thread event has some fields that don't matter.
|
||||
self.assert_dict(
|
||||
{
|
||||
"content": {
|
||||
"m.relates_to": {
|
||||
"event_id": self.parent_id,
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
}
|
||||
# This needs two assertion functions which are identical except for whether
|
||||
# the current_user_participated flag is True, create a factory for the
|
||||
# two versions.
|
||||
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
|
||||
def assert_thread(bundled_aggregations: JsonDict) -> None:
|
||||
self.assertEqual(2, bundled_aggregations.get("count"))
|
||||
self.assertEqual(
|
||||
participated, bundled_aggregations.get("current_user_participated")
|
||||
)
|
||||
# The latest thread event has some fields that don't matter.
|
||||
self.assert_dict(
|
||||
{
|
||||
"content": {
|
||||
"m.relates_to": {
|
||||
"event_id": self.parent_id,
|
||||
"rel_type": RelationTypes.THREAD,
|
||||
}
|
||||
},
|
||||
"event_id": thread_2,
|
||||
"sender": self.user2_id,
|
||||
"type": "m.room.test",
|
||||
},
|
||||
"event_id": thread_2,
|
||||
"sender": self.user_id,
|
||||
"type": "m.room.test",
|
||||
},
|
||||
bundled_aggregations.get("latest_event"),
|
||||
)
|
||||
bundled_aggregations.get("latest_event"),
|
||||
)
|
||||
|
||||
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
|
||||
return assert_thread
|
||||
|
||||
# The "user" sent the root event and is making queries for the bundled
|
||||
# aggregations: they have participated.
|
||||
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
|
||||
# The "user2" sent replies in the thread and is making queries for the
|
||||
# bundled aggregations: they have participated.
|
||||
#
|
||||
# Note that this re-uses some cached values, so the total number of
|
||||
# queries is much smaller.
|
||||
self._test_bundled_aggregations(
|
||||
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
|
||||
)
|
||||
|
||||
# A user with no interactions with the thread: they have not participated.
|
||||
user3_id, user3_token = self._create_user("charlie")
|
||||
self.helper.join(self.room, user=user3_id, tok=user3_token)
|
||||
self._test_bundled_aggregations(
|
||||
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
|
||||
)
|
||||
|
||||
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
|
||||
"""
|
||||
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
|
||||
bundled_aggregations["latest_event"].get("unsigned"),
|
||||
)
|
||||
|
||||
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
|
||||
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
|
||||
|
||||
def test_nested_thread(self) -> None:
|
||||
"""
|
||||
|
||||
@@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
|
||||
|
||||
self.appservice = ApplicationService(
|
||||
token="i_am_an_app_service",
|
||||
hostname="test",
|
||||
id="1234",
|
||||
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
|
||||
# Note: this user does not have to match the regex above
|
||||
|
||||
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
|
||||
|
||||
new_space_id = channel.json_body["replacement_room"]
|
||||
|
||||
state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
|
||||
state_ids = self.get_success(
|
||||
self.store.get_partial_current_state_ids(new_space_id)
|
||||
)
|
||||
|
||||
# Ensure the new room is still a space.
|
||||
create_event = self.get_success(
|
||||
@@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
|
||||
|
||||
new_room_id = channel.json_body["replacement_room"]
|
||||
|
||||
state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
|
||||
state_ids = self.get_success(
|
||||
self.store.get_partial_current_state_ids(new_room_id)
|
||||
)
|
||||
|
||||
# Ensure the new room is the same type as the old room.
|
||||
create_event = self.get_success(
|
||||
|
||||
@@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class CalcOgTestCase(unittest.TestCase):
|
||||
class OpenGraphFromHtmlTestCase(unittest.TestCase):
|
||||
if not lxml:
|
||||
skip = "url preview feature requires lxml"
|
||||
|
||||
@@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
|
||||
|
||||
# Another variant is a title with no content.
|
||||
html = b"""
|
||||
<html>
|
||||
<head><title></title></head>
|
||||
<body>
|
||||
<h1>Title</h1>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
|
||||
|
||||
def test_h1_as_title(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
@@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase):
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
|
||||
|
||||
def test_empty_description(self) -> None:
|
||||
"""Description tags with empty content should be ignored."""
|
||||
html = b"""
|
||||
<html>
|
||||
<meta property="og:description" content=""/>
|
||||
<meta property="og:description"/>
|
||||
<meta name="description" content=""/>
|
||||
<meta name="description"/>
|
||||
<meta name="description" content="Finally!"/>
|
||||
<body>
|
||||
<h1>Title</h1>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
|
||||
tree = decode_body(html, "http://example.com/test.html")
|
||||
og = parse_html_to_open_graph(tree)
|
||||
|
||||
self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
|
||||
|
||||
def test_missing_title_and_broken_h1(self) -> None:
|
||||
html = b"""
|
||||
<html>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user