1
0

Merge branch 'develop' of https://github.com/matrix-org/synapse into sonar

 Conflicts:
	poetry.lock
This commit is contained in:
Michael Telatynski
2022-06-07 10:14:42 +01:00
107 changed files with 1195 additions and 853 deletions
+1 -1
View File
@@ -380,7 +380,7 @@ jobs:
# Attempt to check out the same branch of Complement as the PR. If it
# doesn't exist, fallback to HEAD.
- name: Checkout complement
run: .ci/scripts/checkout_complement.sh
run: synapse/.ci/scripts/checkout_complement.sh
- run: |
set -o pipefail
+1
View File
@@ -0,0 +1 @@
Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it.
+1
View File
@@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.
+1
View File
@@ -0,0 +1 @@
Faster room joins: when querying the current state of the room, wait for state to be populated.
+1
View File
@@ -0,0 +1 @@
Remove support for the non-standard groups/communities feature from Synapse.
+1
View File
@@ -0,0 +1 @@
Remove PyNaCl occurrences directly used in Synapse code.
+1
View File
@@ -0,0 +1 @@
Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged.
+1
View File
@@ -0,0 +1 @@
Fix potential memory leak when generating thumbnails.
+1
View File
@@ -0,0 +1 @@
Test Synapse against Complement with workers.
+1
View File
@@ -0,0 +1 @@
Remove support for the non-standard groups/communities feature from Synapse.
+1
View File
@@ -0,0 +1 @@
Fix a long-standing bug where a URL preview would break if the image failed to download.
+1
View File
@@ -0,0 +1 @@
Improve URL previews for pages with empty elements.
+1
View File
@@ -0,0 +1 @@
Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage.
+1
View File
@@ -0,0 +1 @@
Reduce the amount of state we pull from the DB.
+1
View File
@@ -0,0 +1 @@
Remove support for the non-standard groups/communities feature from Synapse.
+4 -5
View File
@@ -16,6 +16,7 @@
""" Starts a synapse client console. """
import argparse
import binascii
import cmd
import getpass
import json
@@ -26,9 +27,8 @@ import urllib
from http import TwistedHttpClient
from typing import Optional
import nacl.encoding
import nacl.signing
import urlparse
from signedjson.key import NACL_ED25519, decode_verify_key_bytes
from signedjson.sign import SignatureVerifyException, verify_signed_json
from twisted.internet import defer, reactor, threads
@@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"]
class SynapseCmd(cmd.Cmd):
"""Basic synapse command-line processor.
This processes commands from the user and calls the relevant HTTP methods.
@@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd):
pubKey = None
pubKeyObj = yield self.http_client.do_request("GET", url)
if "public_key" in pubKeyObj:
pubKey = nacl.signing.VerifyKey(
pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder
pubKey = decode_verify_key_bytes(
NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"])
)
else:
print("No public key found in pubkey response!")
+3 -1
View File
@@ -115,7 +115,9 @@ URL parameters:
Body parameters:
- `password` - string, optional. If provided, the user's password is updated and all
devices are logged out.
devices are logged out, unless `logout_devices` is set to `false`.
- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't
logged out even when `password` is provided.
- `displayname` - string, optional, defaults to the value of `user_id`.
- `threepids` - array, optional, allows setting the third-party IDs (email, msisdn)
- `medium` - string. Kind of third-party ID, either `email` or `msisdn`.
-6
View File
@@ -191,7 +191,6 @@ information.
^/_matrix/federation/v1/event_auth/
^/_matrix/federation/v1/exchange_third_party_invite/
^/_matrix/federation/v1/user/devices/
^/_matrix/federation/v1/get_groups_publicised$
^/_matrix/key/v2/query
^/_matrix/federation/v1/hierarchy/
@@ -213,9 +212,6 @@ information.
^/_matrix/client/(r0|v3|unstable)/devices$
^/_matrix/client/versions$
^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$
^/_matrix/client/(r0|v3|unstable)/joined_groups$
^/_matrix/client/(r0|v3|unstable)/publicised_groups$
^/_matrix/client/(r0|v3|unstable)/publicised_groups/
^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/
^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$
^/_matrix/client/(api/v1|r0|v3|unstable)/search$
@@ -255,9 +251,7 @@ information.
Additionally, the following REST endpoints can be handled for GET requests:
^/_matrix/federation/v1/groups/
^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/
^/_matrix/client/(r0|v3|unstable)/groups/
Pagination requests can also be handled, but all requests for a given
room must be routed to the same instance. Additionally, care must be taken to
Generated
+1 -55
View File
@@ -187,17 +187,6 @@ category = "main"
optional = false
python-versions = "*"
[[package]]
name = "coverage"
version = "6.4"
description = "Code coverage measurement for Python"
category = "dev"
optional = false
python-versions = ">=3.7"
[package.extras]
toml = ["tomli"]
[[package]]
name = "cryptography"
version = "36.0.1"
@@ -1574,7 +1563,7 @@ url_preview = ["lxml"]
[metadata]
lock-version = "1.1"
python-versions = "^3.7.1"
content-hash = "4ca5c8a4f817f99704a5a1ba466a50655073d2c899cf5034c426f84dc4afe0c7"
content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831"
[metadata.files]
attrs = [
@@ -1713,49 +1702,6 @@ constantly = [
{file = "constantly-15.1.0-py2.py3-none-any.whl", hash = "sha256:dd2fa9d6b1a51a83f0d7dd76293d734046aa176e384bf6e33b7e44880eb37c5d"},
{file = "constantly-15.1.0.tar.gz", hash = "sha256:586372eb92059873e29eba4f9dec8381541b4d3834660707faf8ba59146dfc35"},
]
coverage = [
{file = "coverage-6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:50ed480b798febce113709846b11f5d5ed1e529c88d8ae92f707806c50297abf"},
{file = "coverage-6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26f8f92699756cb7af2b30720de0c5bb8d028e923a95b6d0c891088025a1ac8f"},
{file = "coverage-6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60c2147921da7f4d2d04f570e1838db32b95c5509d248f3fe6417e91437eaf41"},
{file = "coverage-6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:750e13834b597eeb8ae6e72aa58d1d831b96beec5ad1d04479ae3772373a8088"},
{file = "coverage-6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af5b9ee0fc146e907aa0f5fb858c3b3da9199d78b7bb2c9973d95550bd40f701"},
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a022394996419142b33a0cf7274cb444c01d2bb123727c4bb0b9acabcb515dea"},
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a78cf2c43b13aa6b56003707c5203f28585944c277c1f3f109c7b041b16bd39"},
{file = "coverage-6.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9229d074e097f21dfe0643d9d0140ee7433814b3f0fc3706b4abffd1e3038632"},
{file = "coverage-6.4-cp310-cp310-win32.whl", hash = "sha256:fb45fe08e1abc64eb836d187b20a59172053999823f7f6ef4f18a819c44ba16f"},
{file = "coverage-6.4-cp310-cp310-win_amd64.whl", hash = "sha256:3cfd07c5889ddb96a401449109a8b97a165be9d67077df6802f59708bfb07720"},
{file = "coverage-6.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:03014a74023abaf5a591eeeaf1ac66a73d54eba178ff4cb1fa0c0a44aae70383"},
{file = "coverage-6.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c82f2cd69c71698152e943f4a5a6b83a3ab1db73b88f6e769fabc86074c3b08"},
{file = "coverage-6.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7b546cf2b1974ddc2cb222a109b37c6ed1778b9be7e6b0c0bc0cf0438d9e45a6"},
{file = "coverage-6.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc173f1ce9ffb16b299f51c9ce53f66a62f4d975abe5640e976904066f3c835d"},
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c53ad261dfc8695062fc8811ac7c162bd6096a05a19f26097f411bdf5747aee7"},
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:eef5292b60b6de753d6e7f2d128d5841c7915fb1e3321c3a1fe6acfe76c38052"},
{file = "coverage-6.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:543e172ce4c0de533fa892034cce260467b213c0ea8e39da2f65f9a477425211"},
{file = "coverage-6.4-cp37-cp37m-win32.whl", hash = "sha256:00c8544510f3c98476bbd58201ac2b150ffbcce46a8c3e4fb89ebf01998f806a"},
{file = "coverage-6.4-cp37-cp37m-win_amd64.whl", hash = "sha256:b84ab65444dcc68d761e95d4d70f3cfd347ceca5a029f2ffec37d4f124f61311"},
{file = "coverage-6.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d548edacbf16a8276af13063a2b0669d58bbcfca7c55a255f84aac2870786a61"},
{file = "coverage-6.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:033ebec282793bd9eb988d0271c211e58442c31077976c19c442e24d827d356f"},
{file = "coverage-6.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:742fb8b43835078dd7496c3c25a1ec8d15351df49fb0037bffb4754291ef30ce"},
{file = "coverage-6.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55fae115ef9f67934e9f1103c9ba826b4c690e4c5bcf94482b8b2398311bf9c"},
{file = "coverage-6.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd698341626f3c77784858427bad0cdd54a713115b423d22ac83a28303d1d95"},
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:62d382f7d77eeeaff14b30516b17bcbe80f645f5cf02bb755baac376591c653c"},
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:016d7f5cf1c8c84f533a3c1f8f36126fbe00b2ec0ccca47cc5731c3723d327c6"},
{file = "coverage-6.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:69432946f154c6add0e9ede03cc43b96e2ef2733110a77444823c053b1ff5166"},
{file = "coverage-6.4-cp38-cp38-win32.whl", hash = "sha256:83bd142cdec5e4a5c4ca1d4ff6fa807d28460f9db919f9f6a31babaaa8b88426"},
{file = "coverage-6.4-cp38-cp38-win_amd64.whl", hash = "sha256:4002f9e8c1f286e986fe96ec58742b93484195defc01d5cc7809b8f7acb5ece3"},
{file = "coverage-6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4f52c272fdc82e7c65ff3f17a7179bc5f710ebc8ce8a5cadac81215e8326740"},
{file = "coverage-6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b5578efe4038be02d76c344007b13119b2b20acd009a88dde8adec2de4f630b5"},
{file = "coverage-6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8099ea680201c2221f8468c372198ceba9338a5fec0e940111962b03b3f716a"},
{file = "coverage-6.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a00441f5ea4504f5abbc047589d09e0dc33eb447dc45a1a527c8b74bfdd32c65"},
{file = "coverage-6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e76bd16f0e31bc2b07e0fb1379551fcd40daf8cdf7e24f31a29e442878a827c"},
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8d2e80dd3438e93b19e1223a9850fa65425e77f2607a364b6fd134fcd52dc9df"},
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:341e9c2008c481c5c72d0e0dbf64980a4b2238631a7f9780b0fe2e95755fb018"},
{file = "coverage-6.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:21e6686a95025927775ac501e74f5940cdf6fe052292f3a3f7349b0abae6d00f"},
{file = "coverage-6.4-cp39-cp39-win32.whl", hash = "sha256:968ed5407f9460bd5a591cefd1388cc00a8f5099de9e76234655ae48cfdbe2c3"},
{file = "coverage-6.4-cp39-cp39-win_amd64.whl", hash = "sha256:e35217031e4b534b09f9b9a5841b9344a30a6357627761d4218818b865d45055"},
{file = "coverage-6.4-pp36.pp37.pp38-none-any.whl", hash = "sha256:e637ae0b7b481905358624ef2e81d7fb0b1af55f5ff99f9ba05442a444b11e45"},
{file = "coverage-6.4.tar.gz", hash = "sha256:727dafd7f67a6e1cad808dc884bd9c5a2f6ef1f8f6d2f22b37b96cb0080d4f49"},
]
cryptography = [
{file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:73bc2d3f2444bcfeac67dd130ff2ea598ea5f20b40e36d19821b4df8c9c5037b"},
{file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:2d87cdcb378d3cfed944dac30596da1968f88fb96d7fc34fdae30a99054b2e31"},
-1
View File
@@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0"
canonicaljson = ">=1.4.0"
# we use the type definitions added in signedjson 1.1.
signedjson = ">=1.1.0"
PyNaCl = ">=1.2.1"
# validating SSL certs for IP addresses requires service_identity 18.1.
service-identity = ">=18.1.0"
# Twisted 18.9 introduces some logger improvements that the structured
+16 -8
View File
@@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = {
"devices": ["hidden"],
"device_lists_outbound_pokes": ["sent"],
"users_who_share_rooms": ["share_private"],
"groups": ["is_public"],
"group_rooms": ["is_public"],
"group_users": ["is_public", "is_admin"],
"group_summary_rooms": ["is_public"],
"group_room_categories": ["is_public"],
"group_summary_users": ["is_public"],
"group_roles": ["is_public"],
"local_group_membership": ["is_publicised", "is_admin"],
"e2e_room_keys": ["is_verified"],
"account_validity": ["email_sent"],
"redactions": ["have_censored"],
@@ -175,6 +167,22 @@ IGNORED_TABLES = {
"ui_auth_sessions",
"ui_auth_sessions_credentials",
"ui_auth_sessions_ips",
# Groups/communities is no longer supported.
"group_attestations_remote",
"group_attestations_renewals",
"group_invites",
"group_roles",
"group_room_categories",
"group_rooms",
"group_summary_roles",
"group_summary_room_categories",
"group_summary_rooms",
"group_summary_users",
"group_users",
"groups",
"local_group_membership",
"local_group_updates",
"remote_profile_cache",
}
+22 -23
View File
@@ -29,12 +29,11 @@ from synapse.api.errors import (
MissingClientTokenError,
)
from synapse.appservice import ApplicationService
from synapse.events import EventBase
from synapse.http import get_request_user_agent
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.storage.databases.main.registration import TokenLookupResult
from synapse.types import Requester, StateMap, UserID, create_requester
from synapse.types import Requester, UserID, create_requester
from synapse.util.caches.lrucache import LruCache
from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry
@@ -61,8 +60,8 @@ class Auth:
self.hs = hs
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._account_validity_handler = hs.get_account_validity_handler()
self._storage_controllers = hs.get_storage_controllers()
self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache(
10000, "token_cache"
@@ -79,9 +78,8 @@ class Auth:
self,
room_id: str,
user_id: str,
current_state: Optional[StateMap[EventBase]] = None,
allow_departed_users: bool = False,
) -> EventBase:
) -> Tuple[str, Optional[str]]:
"""Check if the user is in the room, or was at some point.
Args:
room_id: The room to check.
@@ -99,29 +97,28 @@ class Auth:
Raises:
AuthError if the user is/was not in the room.
Returns:
Membership event for the user if the user was in the
room. This will be the join event if they are currently joined to
the room. This will be the leave event if they have left the room.
The current membership of the user in the room and the
membership event ID of the user.
"""
if current_state:
member = current_state.get((EventTypes.Member, user_id), None)
else:
member = await self.state.get_current_state(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
if member:
membership = member.membership
(
membership,
member_event_id,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id=user_id,
room_id=room_id,
)
if membership:
if membership == Membership.JOIN:
return member
return membership, member_event_id
# XXX this looks totally bogus. Why do we not allow users who have been banned,
# or those who were members previously and have been re-invited?
if allow_departed_users and membership == Membership.LEAVE:
forgot = await self.store.did_forget(user_id, room_id)
if not forgot:
return member
return membership, member_event_id
raise AuthError(403, "User %s not in room %s" % (user_id, room_id))
@@ -602,8 +599,11 @@ class Auth:
# We currently require the user is a "moderator" in the room. We do this
# by checking if they would (theoretically) be able to change the
# m.room.canonical_alias events
power_level_event = await self.state.get_current_state(
room_id, EventTypes.PowerLevels, ""
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
)
auth_events = {}
@@ -693,12 +693,11 @@ class Auth:
# * The user is a non-guest user, and was ever in the room
# * The user is a guest user, and has joined the room
# else it will throw.
member_event = await self.check_user_in_room(
return await self.check_user_in_room(
room_id, user_id, allow_departed_users=allow_departed_users
)
return member_event.membership, member_event.event_id
except AuthError:
visibility = await self.state.get_current_state(
visibility = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if (
-1
View File
@@ -95,7 +95,6 @@ class EventTypes:
Aliases: Final = "m.room.aliases"
Redaction: Final = "m.room.redaction"
ThirdPartyInvite: Final = "m.room.third_party_invite"
RelatedGroups: Final = "m.room.related_groups"
RoomHistoryVisibility: Final = "m.room.history_visibility"
CanonicalAlias: Final = "m.room.canonical_alias"
-2
View File
@@ -70,7 +70,6 @@ class ApplicationService:
def __init__(
self,
token: str,
hostname: str,
id: str,
sender: str,
url: Optional[str] = None,
@@ -88,7 +87,6 @@ class ApplicationService:
) # url must not end with a slash
self.hs_token = hs_token
self.sender = sender
self.server_name = hostname
self.namespaces = self._check_namespaces(namespaces)
self.id = id
self.ip_range_whitelist = ip_range_whitelist
-1
View File
@@ -179,7 +179,6 @@ def _load_appservice(
return ApplicationService(
token=as_info["as_token"],
hostname=hostname,
url=as_info["url"],
namespaces=as_info["namespaces"],
hs_token=as_info["hs_token"],
+2 -1
View File
@@ -152,6 +152,7 @@ class ThirdPartyEventRules:
self.third_party_rules = None
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = []
self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = []
@@ -463,7 +464,7 @@ class ThirdPartyEventRules:
Returns:
A dict mapping (event type, state key) to state event.
"""
state_ids = await self.store.get_filtered_current_state_ids(room_id)
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
room_state_events = await self.store.get_events(state_ids.values())
state_events = {}
+1
View File
@@ -53,6 +53,7 @@ class FederationBase:
self.spam_checker = hs.get_spam_checker()
self.store = hs.get_datastores().main
self._clock = hs.get_clock()
self._storage_controllers = hs.get_storage_controllers()
async def _check_sigs_and_hash(
self, room_version: RoomVersion, pdu: EventBase
+6 -8
View File
@@ -118,6 +118,8 @@ class FederationServer(FederationBase):
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._state_storage_controller = hs.get_storage_controllers().state
self.device_handler = hs.get_device_handler()
# Ensure the following handlers are loaded since they register callbacks
@@ -1221,14 +1223,10 @@ class FederationServer(FederationBase):
Raises:
AuthError if the server does not match the ACL
"""
state_ids = await self.store.get_current_state_ids(room_id)
acl_event_id = state_ids.get((EventTypes.ServerACL, ""))
if not acl_event_id:
return
acl_event = await self.store.get_event(acl_event_id)
if server_matches_acl_event(server_name, acl_event):
acl_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.ServerACL, ""
)
if not acl_event or server_matches_acl_event(server_name, acl_event):
return
raise AuthError(code=403, msg="Server is banned from room")
+5 -1
View File
@@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender):
self.store = hs.get_datastores().main
self.state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender):
room_id = receipt.room_id
# Work out which remote servers should be poked and poke them.
domains_set = await self.state.get_current_hosts_in_room(room_id)
domains_set = await self._storage_controllers.state.get_current_hosts_in_room(
room_id
)
domains = [
d
for d in domains_set
+1 -1
View File
@@ -166,7 +166,7 @@ class DeviceWorkerHandler:
possibly_changed = set(changed)
possibly_left = set()
for room_id in rooms_changed:
current_state_ids = await self.store.get_current_state_ids(room_id)
current_state_ids = await self._state_storage.get_current_state_ids(room_id)
# The user may have left the room
# TODO: Check if they actually did or if we were just invited.
+7 -2
View File
@@ -45,6 +45,7 @@ class DirectoryHandler:
self.appservice_handler = hs.get_application_service_handler()
self.event_creation_handler = hs.get_event_creation_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.config = hs.config
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.require_membership = hs.config.server.require_membership_for_aliases
@@ -319,7 +320,7 @@ class DirectoryHandler:
Raises:
ShadowBanError if the requester has been shadow-banned.
"""
alias_event = await self.state.get_current_state(
alias_event = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.CanonicalAlias, ""
)
@@ -463,7 +464,11 @@ class DirectoryHandler:
making_public = visibility == "public"
if making_public:
room_aliases = await self.store.get_aliases_for_room(room_id)
canonical_alias = await self.store.get_canonical_alias_for_room(room_id)
canonical_alias = (
await self._storage_controllers.state.get_canonical_alias_for_room(
room_id
)
)
if canonical_alias:
room_aliases.append(canonical_alias)
+7 -2
View File
@@ -371,7 +371,7 @@ class FederationHandler:
# First we try hosts that are already in the room
# TODO: HEURISTIC ALERT.
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self._storage_controllers.state.get_current_state(room_id)
curr_domains = get_domains_from_state(curr_state)
@@ -750,7 +750,9 @@ class FederationHandler:
# Note that this requires the /send_join request to come back to the
# same server.
if room_version.msc3083_join_rules:
state_ids = await self.store.get_current_state_ids(room_id)
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id
)
if await self._event_auth_handler.has_restricted_join_rules(
state_ids, room_version
):
@@ -1552,6 +1554,9 @@ class FederationHandler:
success = await self.store.clear_partial_state_room(room_id)
if success:
logger.info("State resync complete for %s", room_id)
self._storage_controllers.state.notify_room_un_partial_stated(
room_id
)
# TODO(faster_joins) update room stats and user directory?
return
+12 -6
View File
@@ -1584,9 +1584,11 @@ class FederationEventHandler:
if guest_access == GuestAccess.CAN_JOIN:
return
current_state_map = await self._state_handler.get_current_state(event.room_id)
current_state = list(current_state_map.values())
await self._get_room_member_handler().kick_guest_users(current_state)
current_state = await self._storage_controllers.state.get_current_state(
event.room_id
)
current_state_list = list(current_state.values())
await self._get_room_member_handler().kick_guest_users(current_state_list)
async def _check_for_soft_fail(
self,
@@ -1614,6 +1616,9 @@ class FederationEventHandler:
room_version = await self._store.get_room_version_id(event.room_id)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
# The event types we want to pull from the "current" state.
auth_types = auth_types_for_event(room_version_obj, event)
# Calculate the "current state".
if state_ids is not None:
# If we're explicitly given the state then we won't have all the
@@ -1643,8 +1648,10 @@ class FederationEventHandler:
)
)
else:
current_state_ids = await self._state_handler.get_current_state_ids(
event.room_id, latest_event_ids=extrem_ids
current_state_ids = (
await self._state_storage_controller.get_current_state_ids(
event.room_id, StateFilter.from_types(auth_types)
)
)
logger.debug(
@@ -1654,7 +1661,6 @@ class FederationEventHandler:
)
# Now check if event pass auth against said current state
auth_types = auth_types_for_event(room_version_obj, event)
current_state_ids_list = [
e for k, e in current_state_ids.items() if k in auth_types
]
+4 -2
View File
@@ -190,7 +190,7 @@ class InitialSyncHandler:
if event.membership == Membership.JOIN:
room_end_token = now_token.room_key
deferred_room_state = run_in_background(
self.state_handler.get_current_state, event.room_id
self._state_storage_controller.get_current_state, event.room_id
)
elif event.membership == Membership.LEAVE:
room_end_token = RoomStreamToken(
@@ -407,7 +407,9 @@ class InitialSyncHandler:
membership: str,
is_peeking: bool,
) -> JsonDict:
current_state = await self.state.get_current_state(room_id=room_id)
current_state = await self._storage_controllers.state.get_current_state(
room_id=room_id
)
# TODO: These concurrently
time_now = self.clock.time_msec()
+81 -39
View File
@@ -28,6 +28,7 @@ from synapse.api.constants import (
EventContentFields,
EventTypes,
GuestAccess,
HistoryVisibility,
Membership,
RelationTypes,
UserTypes,
@@ -66,7 +67,7 @@ from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstErr
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import measure_func
from synapse.visibility import filter_events_for_client
from synapse.visibility import get_effective_room_visibility_from_state
if TYPE_CHECKING:
from synapse.events.third_party_rules import ThirdPartyEventRules
@@ -124,7 +125,9 @@ class MessageHandler:
)
if membership == Membership.JOIN:
data = await self.state.get_current_state(room_id, event_type, state_key)
data = await self._storage_controllers.state.get_current_state_event(
room_id, event_type, state_key
)
elif membership == Membership.LEAVE:
key = (event_type, state_key)
# If the membership is not JOIN, then the event ID should exist.
@@ -182,51 +185,31 @@ class MessageHandler:
state_filter = state_filter or StateFilter.all()
if at_token:
last_event = await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=at_token.room_key,
last_event_id = (
await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=at_token.room_key,
)
)
if not last_event:
if not last_event_id:
raise NotFoundError("Can't find event for token %s" % (at_token,))
# check whether the user is in the room at that time to determine
# whether they should be treated as peeking.
state_map = await self._state_storage_controller.get_state_for_event(
last_event.event_id,
StateFilter.from_types([(EventTypes.Member, user_id)]),
)
joined = False
membership_event = state_map.get((EventTypes.Member, user_id))
if membership_event:
joined = membership_event.membership == Membership.JOIN
is_peeking = not joined
visible_events = await filter_events_for_client(
self._storage_controllers,
user_id,
[last_event],
filter_send_to_client=False,
is_peeking=is_peeking,
)
if visible_events:
room_state_events = (
await self._state_storage_controller.get_state_for_events(
[last_event.event_id], state_filter=state_filter
)
)
room_state: Mapping[Any, EventBase] = room_state_events[
last_event.event_id
]
else:
if not await self._user_can_see_state_at_event(
user_id, room_id, last_event_id
):
raise AuthError(
403,
"User %s not allowed to view events in room %s at token %s"
% (user_id, room_id, at_token),
)
room_state_events = (
await self._state_storage_controller.get_state_for_events(
[last_event_id], state_filter=state_filter
)
)
room_state: Mapping[Any, EventBase] = room_state_events[last_event_id]
else:
(
membership,
@@ -236,7 +219,7 @@ class MessageHandler:
)
if membership == Membership.JOIN:
state_ids = await self.store.get_filtered_current_state_ids(
state_ids = await self._state_storage_controller.get_current_state_ids(
room_id, state_filter=state_filter
)
room_state = await self.store.get_events(state_ids.values())
@@ -256,6 +239,65 @@ class MessageHandler:
events = self._event_serializer.serialize_events(room_state.values(), now)
return events
async def _user_can_see_state_at_event(
self, user_id: str, room_id: str, event_id: str
) -> bool:
# check whether the user was in the room, and the history visibility,
# at that time.
state_map = await self._state_storage_controller.get_state_for_event(
event_id,
StateFilter.from_types(
[
(EventTypes.Member, user_id),
(EventTypes.RoomHistoryVisibility, ""),
]
),
)
membership = None
membership_event = state_map.get((EventTypes.Member, user_id))
if membership_event:
membership = membership_event.membership
# if the user was a member of the room at the time of the event,
# they can see it.
if membership == Membership.JOIN:
return True
# otherwise, it depends on the history visibility.
visibility = get_effective_room_visibility_from_state(state_map)
if visibility == HistoryVisibility.JOINED:
# we weren't a member at the time of the event, so we can't see this event.
return False
# otherwise *invited* is good enough
if membership == Membership.INVITE:
return True
if visibility == HistoryVisibility.INVITED:
# we weren't invited, so we can't see this event.
return False
if visibility == HistoryVisibility.WORLD_READABLE:
return True
# So it's SHARED, and the user was not a member at the time. The user cannot
# see history, unless they have *subsequently* joined the room.
#
# XXX: if the user has subsequently joined and then left again,
# ideally we would share history up to the point they left. But
# we don't know when they left. We just treat it as though they
# never joined, and restrict access.
(
current_membership,
_,
) = await self.store.get_local_current_membership_for_user_in_room(
user_id, event_id
)
return current_membership == Membership.JOIN
async def get_joined_members(self, requester: Requester, room_id: str) -> dict:
"""Get all the joined members in the room and their profile information.
+5 -1
View File
@@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC):
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.presence_router = hs.get_presence_router()
self.state = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler):
self._event_pos,
room_max_stream_ordering,
)
max_pos, deltas = await self.store.get_current_state_deltas(
(
max_pos,
deltas,
) = await self._storage_controllers.state.get_current_state_deltas(
self._event_pos, room_max_stream_ordering
)
+1 -82
View File
@@ -23,14 +23,7 @@ from synapse.api.errors import (
StoreError,
SynapseError,
)
from synapse.metrics.background_process_metrics import wrap_as_background_process
from synapse.types import (
JsonDict,
Requester,
UserID,
create_requester,
get_domain_from_id,
)
from synapse.types import JsonDict, Requester, UserID, create_requester
from synapse.util.caches.descriptors import cached
from synapse.util.stringutils import parse_and_validate_mxc_uri
@@ -50,9 +43,6 @@ class ProfileHandler:
delegate to master when necessary.
"""
PROFILE_UPDATE_MS = 60 * 1000
PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self.clock = hs.get_clock()
@@ -73,11 +63,6 @@ class ProfileHandler:
self._third_party_rules = hs.get_third_party_event_rules()
if hs.config.worker.run_background_tasks:
self.clock.looping_call(
self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
)
async def get_profile(self, user_id: str) -> JsonDict:
target_user = UserID.from_string(user_id)
@@ -116,30 +101,6 @@ class ProfileHandler:
raise SynapseError(502, "Failed to fetch profile")
raise e.to_synapse_error()
async def get_profile_from_cache(self, user_id: str) -> JsonDict:
"""Get the profile information from our local cache. If the user is
ours then the profile information will always be correct. Otherwise,
it may be out of date/missing.
"""
target_user = UserID.from_string(user_id)
if self.hs.is_mine(target_user):
try:
displayname = await self.store.get_profile_displayname(
target_user.localpart
)
avatar_url = await self.store.get_profile_avatar_url(
target_user.localpart
)
except StoreError as e:
if e.code == 404:
raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND)
raise
return {"displayname": displayname, "avatar_url": avatar_url}
else:
profile = await self.store.get_from_remote_profile_cache(user_id)
return profile or {}
async def get_displayname(self, target_user: UserID) -> Optional[str]:
if self.hs.is_mine(target_user):
try:
@@ -509,45 +470,3 @@ class ProfileHandler:
# so we act as if we couldn't find the profile.
raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
raise
@wrap_as_background_process("Update remote profile")
async def _update_remote_profile_cache(self) -> None:
"""Called periodically to check profiles of remote users we haven't
checked in a while.
"""
entries = await self.store.get_remote_profile_cache_entries_that_expire(
last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
)
for user_id, displayname, avatar_url in entries:
is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
user_id
)
if not is_subscribed:
await self.store.maybe_delete_remote_profile_cache(user_id)
continue
try:
profile = await self.federation.make_query(
destination=get_domain_from_id(user_id),
query_type="profile",
args={"user_id": user_id},
ignore_backoff=True,
)
except Exception:
logger.exception("Failed to get avatar_url")
await self.store.update_remote_profile_cache(
user_id, displayname, avatar_url
)
continue
new_name = profile.get("displayname")
if not isinstance(new_name, str):
new_name = None
new_avatar = profile.get("avatar_url")
if not isinstance(new_avatar, str):
new_avatar = None
# We always hit update to update the last_check timestamp
await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
+2 -1
View File
@@ -87,6 +87,7 @@ class LoginDict(TypedDict):
class RegistrationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self.hs = hs
self.auth = hs.get_auth()
@@ -528,7 +529,7 @@ class RegistrationHandler:
if requires_invite:
# If the server is in the room, check if the room is public.
state = await self.store.get_filtered_current_state_ids(
state = await self._storage_controllers.state.get_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.JoinRules, "")])
)
+37 -21
View File
@@ -12,16 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import (
TYPE_CHECKING,
Collection,
Dict,
FrozenSet,
Iterable,
List,
Optional,
Tuple,
)
from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
import attr
@@ -256,13 +247,19 @@ class RelationsHandler:
return filtered_results
async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
async def _get_threads_for_events(
self,
events_by_id: Dict[str, EventBase],
relations_by_id: Dict[str, str],
user_id: str,
ignored_users: FrozenSet[str],
) -> Dict[str, _ThreadAggregation]:
"""Get the bundled aggregations for threads for the requested events.
Args:
event_ids: Events to get aggregations for threads.
events_by_id: A map of event_id to events to get aggregations for threads.
relations_by_id: A map of event_id to the relation type, if one exists
for that event.
user_id: The user requesting the bundled aggregations.
ignored_users: The users ignored by the requesting user.
@@ -273,16 +270,34 @@ class RelationsHandler:
"""
user = UserID.from_string(user_id)
# It is not valid to start a thread on an event which itself relates to another event.
event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id]
# Fetch thread summaries.
summaries = await self._main_store.get_thread_summaries(event_ids)
# Only fetch participated for a limited selection based on what had
# summaries.
# Limit fetching whether the requester has participated in a thread to
# events which are thread roots.
thread_event_ids = [
event_id for event_id, summary in summaries.items() if summary
]
participated = await self._main_store.get_threads_participated(
thread_event_ids, user_id
# Pre-seed thread participation with whether the requester sent the event.
participated = {
event_id: events_by_id[event_id].sender == user_id
for event_id in thread_event_ids
}
# For events the requester did not send, check the database for whether
# the requester sent a threaded reply.
participated.update(
await self._main_store.get_threads_participated(
[
event_id
for event_id in thread_event_ids
if not participated[event_id]
],
user_id,
)
)
# Then subtract off the results for any ignored users.
@@ -343,7 +358,8 @@ class RelationsHandler:
count=thread_count,
# If there's a thread summary it must also exist in the
# participated dictionary.
current_user_participated=participated[event_id],
current_user_participated=events_by_id[event_id].sender == user_id
or participated[event_id],
)
return results
@@ -401,9 +417,9 @@ class RelationsHandler:
# events to be fetched. Thus, we check those first!
# Fetch thread summaries (but only for the directly requested events).
threads = await self.get_threads_for_events(
# It is not valid to start a thread on an event which itself relates to another event.
[eid for eid in events_by_id.keys() if eid not in relations_by_id],
threads = await self._get_threads_for_events(
events_by_id,
relations_by_id,
user_id,
ignored_users,
)
+13 -6
View File
@@ -107,6 +107,7 @@ class EventContext:
class RoomCreationHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.clock = hs.get_clock()
self.hs = hs
@@ -468,7 +469,6 @@ class RoomCreationHandler:
(EventTypes.RoomAvatar, ""),
(EventTypes.RoomEncryption, ""),
(EventTypes.ServerACL, ""),
(EventTypes.RelatedGroups, ""),
(EventTypes.PowerLevels, ""),
]
@@ -481,8 +481,10 @@ class RoomCreationHandler:
if room_type == RoomTypes.SPACE:
types_to_copy.append((EventTypes.SpaceChild, None))
old_room_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
old_room_state_ids = (
await self._storage_controllers.state.get_current_state_ids(
old_room_id, StateFilter.from_types(types_to_copy)
)
)
# map from event_id to BaseEvent
old_room_state_events = await self.store.get_events(old_room_state_ids.values())
@@ -559,8 +561,10 @@ class RoomCreationHandler:
)
# Transfer membership events
old_room_member_state_ids = await self.store.get_filtered_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
old_room_member_state_ids = (
await self._storage_controllers.state.get_current_state_ids(
old_room_id, StateFilter.from_types([(EventTypes.Member, None)])
)
)
# map from event_id to BaseEvent
@@ -1329,6 +1333,7 @@ class TimestampLookupHandler:
self.store = hs.get_datastores().main
self.state_handler = hs.get_state_handler()
self.federation_client = hs.get_federation_client()
self._storage_controllers = hs.get_storage_controllers()
async def get_event_for_timestamp(
self,
@@ -1402,7 +1407,9 @@ class TimestampLookupHandler:
)
# Find other homeservers from the given state in the room
curr_state = await self.state_handler.get_current_state(room_id)
curr_state = await self._storage_controllers.state.get_current_state(
room_id
)
curr_domains = get_domains_from_state(curr_state)
likely_domains = [
domain for domain, depth in curr_domains if domain != self.server_name
+2 -1
View File
@@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None)
class RoomListHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.hs = hs
self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search
self.response_cache: ResponseCache[
@@ -274,7 +275,7 @@ class RoomListHandler:
if aliases:
result["aliases"] = aliases
current_state_ids = await self.store.get_current_state_ids(
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id, on_invalidate=cache_context.invalidate
)
+18 -3
View File
@@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.auth = hs.get_auth()
self.state_handler = hs.get_state_handler()
self.config = hs.config
@@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
# If the host is in the room, but not one of the authorised hosts
# for restricted join rules, a remote join must be used.
room_version = await self.store.get_room_version(room_id)
current_state_ids = await self.store.get_current_state_ids(room_id)
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id
)
# If restricted join rules are not being used, a local join can always
# be used.
@@ -1398,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
txn_id: Optional[str],
id_access_token: Optional[str] = None,
) -> int:
room_state = await self.state_handler.get_current_state(room_id)
room_state = await self._storage_controllers.state.get_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Member, user.to_string()),
(EventTypes.CanonicalAlias, ""),
(EventTypes.Name, ""),
(EventTypes.Create, ""),
(EventTypes.JoinRules, ""),
(EventTypes.RoomAvatar, ""),
]
),
)
inviter_display_name = ""
inviter_avatar_url = ""
@@ -1794,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
async def forget(self, user: UserID, room_id: str) -> None:
user_id = user.to_string()
member = await self.state_handler.get_current_state(
member = await self._storage_controllers.state.get_current_state_event(
room_id=room_id, event_type=EventTypes.Member, state_key=user_id
)
membership = member.membership if member else None
+8 -3
View File
@@ -90,6 +90,7 @@ class RoomSummaryHandler:
def __init__(self, hs: "HomeServer"):
self._event_auth_handler = hs.get_event_auth_handler()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._event_serializer = hs.get_event_client_serializer()
self._server_name = hs.hostname
self._federation_client = hs.get_federation_client()
@@ -537,7 +538,7 @@ class RoomSummaryHandler:
Returns:
True if the room is accessible to the requesting user or server.
"""
state_ids = await self._store.get_current_state_ids(room_id)
state_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
# If there's no state for the room, it isn't known.
if not state_ids:
@@ -702,7 +703,9 @@ class RoomSummaryHandler:
# there should always be an entry
assert stats is not None, "unable to retrieve stats for %s" % (room_id,)
current_state_ids = await self._store.get_current_state_ids(room_id)
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id
)
create_event = await self._store.get_event(
current_state_ids[(EventTypes.Create, "")]
)
@@ -760,7 +763,9 @@ class RoomSummaryHandler:
"""
# look for child rooms/spaces.
current_state_ids = await self._store.get_current_state_ids(room_id)
current_state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id
)
events = await self._store.get_events_as_list(
[
+1 -1
View File
@@ -348,7 +348,7 @@ class SearchHandler:
state_results = {}
if include_state:
for room_id in {e.room_id for e in search_result.allowed_events}:
state = await self.state_handler.get_current_state(room_id)
state = await self._storage_controllers.state.get_current_state(room_id)
state_results[room_id] = list(state.values())
aggregations = await self._relations_handler.get_bundled_aggregations(
+5 -1
View File
@@ -40,6 +40,7 @@ class StatsHandler:
def __init__(self, hs: "HomeServer"):
self.hs = hs
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.state = hs.get_state_handler()
self.server_name = hs.hostname
self.clock = hs.get_clock()
@@ -105,7 +106,10 @@ class StatsHandler:
logger.debug(
"Processing room stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = await self.store.get_current_state_deltas(
(
max_pos,
deltas,
) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
+28 -12
View File
@@ -506,8 +506,10 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids: FrozenSet[str] = frozenset()
if any(e.is_state() for e in recents):
current_state_ids_map = await self.store.get_current_state_ids(
room_id
current_state_ids_map = (
await self._state_storage_controller.get_current_state_ids(
room_id
)
)
current_state_ids = frozenset(current_state_ids_map.values())
@@ -574,8 +576,11 @@ class SyncHandler:
# ensure that we always include current state in the timeline
current_state_ids = frozenset()
if any(e.is_state() for e in loaded_recents):
current_state_ids_map = await self.store.get_current_state_ids(
room_id
# FIXME(faster_joins): We use the partial state here as
# we don't want to block `/sync` on finishing a lazy join.
# Is this the correct way of doing it?
current_state_ids_map = (
await self.store.get_partial_current_state_ids(room_id)
)
current_state_ids = frozenset(current_state_ids_map.values())
@@ -621,21 +626,32 @@ class SyncHandler:
)
async def get_state_after_event(
self, event: EventBase, state_filter: Optional[StateFilter] = None
self, event_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""
Get the room state after the given event
Args:
event: event of interest
event_id: event of interest
state_filter: The state filter used to fetch state from the database.
"""
state_ids = await self._state_storage_controller.get_state_ids_for_event(
event.event_id, state_filter=state_filter or StateFilter.all()
event_id, state_filter=state_filter or StateFilter.all()
)
if event.is_state():
# using get_metadata_for_events here (instead of get_event) sidesteps an issue
# with redactions: if `event_id` is a redaction event, and we don't have the
# original (possibly because it got purged), get_event will refuse to return
# the redaction event, which isn't terribly helpful here.
#
# (To be fair, in that case we could assume it's *not* a state event, and
# therefore we don't need to worry about it. But still, it seems cleaner just
# to pull the metadata.)
m = (await self.store.get_metadata_for_events([event_id]))[event_id]
if m.state_key is not None and m.rejection_reason is None:
state_ids = dict(state_ids)
state_ids[(event.type, event.state_key)] = event.event_id
state_ids[(m.event_type, m.state_key)] = event_id
return state_ids
async def get_state_at(
@@ -654,14 +670,14 @@ class SyncHandler:
# FIXME: This gets the state at the latest event before the stream ordering,
# which might not be the same as the "current state" of the room at the time
# of the stream token if there were multiple forward extremities at the time.
last_event = await self.store.get_last_event_in_room_before_stream_ordering(
last_event_id = await self.store.get_last_event_in_room_before_stream_ordering(
room_id,
end_token=stream_position.room_key,
)
if last_event:
if last_event_id:
state = await self.get_state_after_event(
last_event, state_filter=state_filter or StateFilter.all()
last_event_id, state_filter=state_filter or StateFilter.all()
)
else:
+5 -2
View File
@@ -59,6 +59,7 @@ class FollowerTypingHandler:
def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.config.server.server_name
self.clock = hs.get_clock()
self.is_mine_id = hs.is_mine_id
@@ -131,7 +132,6 @@ class FollowerTypingHandler:
return
try:
users = await self.store.get_users_in_room(member.room_id)
self._member_last_federation_poke[member] = self.clock.time_msec()
now = self.clock.time_msec()
@@ -139,7 +139,10 @@ class FollowerTypingHandler:
now=now, obj=member, then=now + FEDERATION_PING_INTERVAL
)
for domain in {get_domain_from_id(u) for u in users}:
hosts = await self._storage_controllers.state.get_current_hosts_in_room(
member.room_id
)
for domain in hosts:
if domain != self.server_name:
logger.debug("sending typing update to %s", domain)
self.federation.build_and_send_edu(
+5 -1
View File
@@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler):
super().__init__(hs)
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.server_name = hs.hostname
self.clock = hs.get_clock()
self.notifier = hs.get_notifier()
@@ -174,7 +175,10 @@ class UserDirectoryHandler(StateDeltasHandler):
logger.debug(
"Processing user stats %s->%s", self.pos, room_max_stream_ordering
)
max_pos, deltas = await self.store.get_current_state_deltas(
(
max_pos,
deltas,
) = await self._storage_controllers.state.get_current_state_deltas(
self.pos, room_max_stream_ordering
)
+8 -11
View File
@@ -194,6 +194,7 @@ class ModuleApi:
self._store: Union[
DataStore, "GenericWorkerSlavedStore"
] = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._auth_handler = auth_handler
self._server_name = hs.hostname
@@ -911,7 +912,7 @@ class ModuleApi:
The filtered state events in the room.
"""
state_ids = yield defer.ensureDeferred(
self._store.get_filtered_current_state_ids(
self._storage_controllers.state.get_current_state_ids(
room_id=room_id, state_filter=StateFilter.from_types(types)
)
)
@@ -1289,20 +1290,16 @@ class ModuleApi:
# regardless of their state key
]
"""
state_filter = None
if event_filter:
# If a filter was provided, turn it into a StateFilter and retrieve a filtered
# view of the state.
state_filter = StateFilter.from_types(event_filter)
state_ids = await self._store.get_filtered_current_state_ids(
room_id,
state_filter,
)
else:
# If no filter was provided, get the whole state. We could also reuse the call
# to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`,
# but get_filtered_current_state_ids isn't cached and `get_current_state_ids`
# is, so using the latter when we can is better for perf.
state_ids = await self._store.get_current_state_ids(room_id)
state_ids = await self._storage_controllers.state.get_current_state_ids(
room_id,
state_filter,
)
state_events = await self._store.get_events(state_ids.values())
+1 -1
View File
@@ -681,7 +681,7 @@ class Notifier:
return joined_room_ids, True
async def _is_world_readable(self, room_id: str) -> bool:
state = await self.state_handler.get_current_state(
state = await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.RoomHistoryVisibility, ""
)
if state and "history_visibility" in state.content:
+3 -1
View File
@@ -255,7 +255,9 @@ class Mailer:
user_display_name = user_id
async def _fetch_room_state(room_id: str) -> None:
room_state = await self.store.get_current_state_ids(room_id)
room_state = await self._state_storage_controller.get_current_state_ids(
room_id
)
state_by_room[room_id] = room_state
# Run at most 3 of these at once: sync does 10 at a time but email
+27 -10
View File
@@ -34,6 +34,7 @@ from synapse.rest.admin._base import (
assert_user_is_admin,
)
from synapse.storage.databases.main.room import RoomSortOrder
from synapse.storage.state import StateFilter
from synapse.types import JsonDict, RoomID, UserID, create_requester
from synapse.util import json_decoder
@@ -418,6 +419,7 @@ class RoomStateRestServlet(RestServlet):
def __init__(self, hs: "HomeServer"):
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.clock = hs.get_clock()
self._event_serializer = hs.get_event_client_serializer()
@@ -430,7 +432,7 @@ class RoomStateRestServlet(RestServlet):
if not ret:
raise NotFoundError("Room not found")
event_ids = await self.store.get_current_state_ids(room_id)
event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
events = await self.store.get_events(event_ids.values())
now = self.clock.time_msec()
room_state = self._event_serializer.serialize_events(events.values(), now)
@@ -447,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.admin_handler = hs.get_admin_handler()
self.state_handler = hs.get_state_handler()
self.store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self.is_mine = hs.is_mine
async def on_POST(
@@ -489,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet):
)
# send invite if room has "JoinRules.INVITE"
room_state = await self.state_handler.get_current_state(room_id)
join_rules_event = room_state.get((EventTypes.JoinRules, ""))
join_rules_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.JoinRules, ""
)
)
if join_rules_event:
if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC):
# update_membership with an action of "invite" can raise a
@@ -535,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
super().__init__(hs)
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
self._state_storage_controller = hs.get_storage_controllers().state
self.event_creation_handler = hs.get_event_creation_handler()
self.state_handler = hs.get_state_handler()
self.is_mine_id = hs.is_mine_id
@@ -552,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
user_to_add = content.get("user_id", requester.user.to_string())
# Figure out which local users currently have power in the room, if any.
room_state = await self.state_handler.get_current_state(room_id)
if not room_state:
filtered_room_state = await self._state_storage_controller.get_current_state(
room_id,
StateFilter.from_types(
[
(EventTypes.Create, ""),
(EventTypes.PowerLevels, ""),
(EventTypes.JoinRules, ""),
(EventTypes.Member, user_to_add),
]
),
)
if not filtered_room_state:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room")
create_event = room_state[(EventTypes.Create, "")]
power_levels = room_state.get((EventTypes.PowerLevels, ""))
create_event = filtered_room_state[(EventTypes.Create, "")]
power_levels = filtered_room_state.get((EventTypes.PowerLevels, ""))
if power_levels is not None:
# We pick the local user with the highest power.
@@ -633,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
# Now we check if the user we're granting admin rights to is already in
# the room. If not and it's not a public room we invite them.
member_event = room_state.get((EventTypes.Member, user_to_add))
member_event = filtered_room_state.get((EventTypes.Member, user_to_add))
is_joined = False
if member_event:
is_joined = member_event.content["membership"] in (
@@ -644,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet):
if is_joined:
return HTTPStatus.OK, {}
join_rules = room_state.get((EventTypes.JoinRules, ""))
join_rules = filtered_room_state.get((EventTypes.JoinRules, ""))
is_public = False
if join_rules:
is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC
+7 -1
View File
@@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet):
if not isinstance(password, str) or len(password) > 512:
raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password")
logout_devices = body.get("logout_devices", True)
if not isinstance(logout_devices, bool):
raise SynapseError(
HTTPStatus.BAD_REQUEST,
"'logout_devices' parameter is not of type boolean",
)
deactivate = body.get("deactivated", False)
if not isinstance(deactivate, bool):
raise SynapseError(
@@ -305,7 +312,6 @@ class UserRestServletV2(RestServlet):
await self.store.set_server_admin(target_user, set_admin_to)
if password is not None:
logout_devices = True
new_password_hash = await self.auth_handler.hash(password)
await self.set_password_handler.set_password(
+5 -2
View File
@@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet):
self.clock = hs.get_clock()
self._store = hs.get_datastores().main
self._state = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
self.event_handler = hs.get_event_handler()
self._event_serializer = hs.get_event_client_serializer()
self._relations_handler = hs.get_relations_handler()
@@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet):
if include_unredacted_content and not await self.auth.is_server_admin(
requester.user
):
power_level_event = await self._state.get_current_state(
room_id, EventTypes.PowerLevels, ""
power_level_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, EventTypes.PowerLevels, ""
)
)
auth_events = {}
+143 -126
View File
@@ -587,15 +587,16 @@ class MediaRepository:
)
return None
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
t_width,
t_height,
t_method,
t_type,
)
with thumbnailer:
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
t_width,
t_height,
t_method,
t_type,
)
if t_byte_source:
try:
@@ -657,15 +658,16 @@ class MediaRepository:
)
return None
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
t_width,
t_height,
t_method,
t_type,
)
with thumbnailer:
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
self._generate_thumbnail,
thumbnailer,
t_width,
t_height,
t_method,
t_type,
)
if t_byte_source:
try:
@@ -749,119 +751,134 @@ class MediaRepository:
)
return None
m_width = thumbnailer.width
m_height = thumbnailer.height
with thumbnailer:
m_width = thumbnailer.width
m_height = thumbnailer.height
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width,
m_height,
self.max_image_pixels,
)
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
for requirement in requirements:
if requirement.method == "crop":
thumbnails.setdefault(
(requirement.width, requirement.height, requirement.media_type),
requirement.method,
if m_width * m_height >= self.max_image_pixels:
logger.info(
"Image too large to thumbnail %r x %r > %r",
m_width,
m_height,
self.max_image_pixels,
)
elif requirement.method == "scale":
t_width, t_height = thumbnailer.aspect(
requirement.width, requirement.height
return None
if thumbnailer.transpose_method is not None:
m_width, m_height = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.transpose
)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
thumbnails[
(t_width, t_height, requirement.media_type)
] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type
)
elif t_method == "scale":
t_byte_source = await defer_to_thread(
self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type
)
else:
logger.error("Unrecognized method: %r", t_method)
continue
if not t_byte_source:
continue
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
url_cache=url_cache,
thumbnail=ThumbnailInfo(
width=t_width,
height=t_height,
method=t_method,
type=t_type,
),
)
with self.media_storage.store_into_file(file_info) as (f, fname, finish):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()
t_len = os.path.getsize(fname)
# Write to database
if server_name:
# Multiple remote media download requests can race (when
# using multiple media repos), so this may throw a violation
# constraint exception. If it does we'll delete the newly
# generated thumbnail from disk (as we're in the ctx
# manager).
#
# However: we've already called `finish()` so we may have
# also written to the storage providers. This is preferable
# to the alternative where we call `finish()` *after* this,
# where we could end up having an entry in the DB but fail
# to write the files to the storage providers.
try:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
except Exception as e:
thumbnail_exists = await self.store.get_remote_media_thumbnail(
server_name,
media_id,
t_width,
t_height,
t_type,
)
if not thumbnail_exists:
raise e
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
# We deduplicate the thumbnail sizes by ignoring the cropped versions if
# they have the same dimensions of a scaled one.
thumbnails: Dict[Tuple[int, int, str], str] = {}
for requirement in requirements:
if requirement.method == "crop":
thumbnails.setdefault(
(requirement.width, requirement.height, requirement.media_type),
requirement.method,
)
elif requirement.method == "scale":
t_width, t_height = thumbnailer.aspect(
requirement.width, requirement.height
)
t_width = min(m_width, t_width)
t_height = min(m_height, t_height)
thumbnails[
(t_width, t_height, requirement.media_type)
] = requirement.method
# Now we generate the thumbnails for each dimension, store it
for (t_width, t_height, t_type), t_method in thumbnails.items():
# Generate the thumbnail
if t_method == "crop":
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
thumbnailer.crop,
t_width,
t_height,
t_type,
)
elif t_method == "scale":
t_byte_source = await defer_to_thread(
self.hs.get_reactor(),
thumbnailer.scale,
t_width,
t_height,
t_type,
)
else:
logger.error("Unrecognized method: %r", t_method)
continue
if not t_byte_source:
continue
file_info = FileInfo(
server_name=server_name,
file_id=file_id,
url_cache=url_cache,
thumbnail=ThumbnailInfo(
width=t_width,
height=t_height,
method=t_method,
type=t_type,
),
)
with self.media_storage.store_into_file(file_info) as (
f,
fname,
finish,
):
try:
await self.media_storage.write_to_file(t_byte_source, f)
await finish()
finally:
t_byte_source.close()
t_len = os.path.getsize(fname)
# Write to database
if server_name:
# Multiple remote media download requests can race (when
# using multiple media repos), so this may throw a violation
# constraint exception. If it does we'll delete the newly
# generated thumbnail from disk (as we're in the ctx
# manager).
#
# However: we've already called `finish()` so we may have
# also written to the storage providers. This is preferable
# to the alternative where we call `finish()` *after* this,
# where we could end up having an entry in the DB but fail
# to write the files to the storage providers.
try:
await self.store.store_remote_media_thumbnail(
server_name,
media_id,
file_id,
t_width,
t_height,
t_type,
t_method,
t_len,
)
except Exception as e:
thumbnail_exists = (
await self.store.get_remote_media_thumbnail(
server_name,
media_id,
t_width,
t_height,
t_type,
)
)
if not thumbnail_exists:
raise e
else:
await self.store.store_local_thumbnail(
media_id, t_width, t_height, t_type, t_method, t_len
)
return {"width": m_width, "height": m_height}
+35 -17
View File
@@ -30,6 +30,9 @@ _xml_encoding_match = re.compile(
)
_content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I)
# Certain elements aren't meant for display.
ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"}
def _normalise_encoding(encoding: str) -> Optional[str]:
"""Use the Python codec's name as the normalised entry."""
@@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
# "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3",
og: Dict[str, Optional[str]] = {}
for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"):
if "content" in tag.attrib:
# if we've got more than 50 tags, someone is taking the piss
if len(og) >= 50:
logger.warning("Skipping OG for page with too many 'og:' tags")
return {}
og[tag.attrib["property"]] = tag.attrib["content"]
for tag in tree.xpath(
"//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]"
):
# if we've got more than 50 tags, someone is taking the piss
if len(og) >= 50:
logger.warning("Skipping OG for page with too many 'og:' tags")
return {}
og[tag.attrib["property"]] = tag.attrib["content"]
# TODO: grab article: meta tags too, e.g.:
@@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
# "article:modified_time" content="2016-04-01T18:31:53+00:00" />
if "og:title" not in og:
# do some basic spidering of the HTML
title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]")
if title and title[0].text is not None:
og["og:title"] = title[0].text.strip()
# Attempt to find a title from the title tag, or the biggest header on the page.
title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()")
if title:
og["og:title"] = title[0].strip()
else:
og["og:title"] = None
if "og:image" not in og:
# TODO: extract a favicon failing all else
meta_image = tree.xpath(
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content"
"//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]"
)
# If a meta image is found, use it.
if meta_image:
og["og:image"] = meta_image[0]
else:
# Try to find images which are larger than 10px by 10px.
#
# TODO: consider inlined CSS styles as well as width & height attribs
images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]")
images = sorted(
@@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]:
-1 * float(i.attrib["width"]) * float(i.attrib["height"])
),
)
# If no images were found, try to find *any* images.
if not images:
images = tree.xpath("//img[@src]")
images = tree.xpath("//img[@src][1]")
if images:
og["og:image"] = images[0].attrib["src"]
# Finally, fallback to the favicon if nothing else.
else:
favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]")
if favicons:
og["og:image"] = favicons[0]
if "og:description" not in og:
# Check the first meta description tag for content.
meta_description = tree.xpath(
"//*/meta"
"[translate(@name, 'DESCRIPTION', 'description')='description']"
"/@content"
"//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]"
)
# If a meta description is found with content, use it.
if meta_description:
og["og:description"] = meta_description[0]
else:
@@ -306,6 +320,10 @@ def _iterate_over_text(
if isinstance(el, str):
yield el
elif el.tag not in tags_to_ignore:
# If the element isn't meant for display, ignore it.
if el.get("role") in ARIA_ROLES_TO_IGNORE:
continue
# el.text is the text before the first child, so we can immediately
# return it if the text exists.
if el.text:
+17 -6
View File
@@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource):
og: The Open Graph dictionary. This is modified with image information.
"""
# If there's no image or it is blank, there's nothing to do.
if "og:image" not in og or not og["og:image"]:
if "og:image" not in og:
return
# Remove the raw image URL, this will be replaced with an MXC URL, if successful.
image_url = og.pop("og:image")
if not image_url:
return
# The image URL from the HTML might be relative to the previewed page,
# convert it to an URL which can be requested directly.
image_url = og["og:image"]
url_parts = urlparse(image_url)
if url_parts.scheme != "data":
image_url = urljoin(media_info.uri, image_url)
@@ -599,7 +603,16 @@ class PreviewUrlResource(DirectServeJsonResource):
# FIXME: it might be cleaner to use the same flow as the main /preview_url
# request itself and benefit from the same caching etc. But for now we
# just rely on the caching on the master request to speed things up.
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
try:
image_info = await self._handle_url(image_url, user, allow_data_urls=True)
except Exception as e:
# Pre-caching the image failed, don't block the entire URL preview.
logger.warning(
"Pre-caching image failed during URL preview: %s errored with %s",
image_url,
e,
)
return
if _is_media(image_info.media_type):
# TODO: make sure we don't choke on white-on-transparent images
@@ -611,13 +624,11 @@ class PreviewUrlResource(DirectServeJsonResource):
og["og:image:width"] = dims["width"]
og["og:image:height"] = dims["height"]
else:
logger.warning("Couldn't get dims for %s", og["og:image"])
logger.warning("Couldn't get dims for %s", image_url)
og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}"
og["og:image:type"] = image_info.media_type
og["matrix:image:size"] = image_info.media_length
else:
del og["og:image"]
async def _handle_oembed_response(
self, url: str, media_info: MediaInfo, expiration_ms: int
+60 -11
View File
@@ -14,7 +14,8 @@
# limitations under the License.
import logging
from io import BytesIO
from typing import Tuple
from types import TracebackType
from typing import Optional, Tuple, Type
from PIL import Image
@@ -45,6 +46,9 @@ class Thumbnailer:
Image.MAX_IMAGE_PIXELS = max_image_pixels
def __init__(self, input_path: str):
# Have we closed the image?
self._closed = False
try:
self.image = Image.open(input_path)
except OSError as e:
@@ -89,7 +93,8 @@ class Thumbnailer:
# Safety: `transpose` takes an int rather than e.g. an IntEnum.
# self.transpose_method is set above to be a value in
# EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values.
self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
with self.image:
self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type]
self.width, self.height = self.image.size
self.transpose_method = None
# We don't need EXIF any more
@@ -122,9 +127,11 @@ class Thumbnailer:
# If the image has transparency, use RGBA instead.
if self.image.mode in ["1", "L", "P"]:
if self.image.info.get("transparency", None) is not None:
self.image = self.image.convert("RGBA")
with self.image:
self.image = self.image.convert("RGBA")
else:
self.image = self.image.convert("RGB")
with self.image:
self.image = self.image.convert("RGB")
return self.image.resize((width, height), Image.ANTIALIAS)
def scale(self, width: int, height: int, output_type: str) -> BytesIO:
@@ -133,8 +140,8 @@ class Thumbnailer:
Returns:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
scaled = self._resize(width, height)
return self._encode_image(scaled, output_type)
with self._resize(width, height) as scaled:
return self._encode_image(scaled, output_type)
def crop(self, width: int, height: int, output_type: str) -> BytesIO:
"""Rescales and crops the image to the given dimensions preserving
@@ -151,18 +158,21 @@ class Thumbnailer:
BytesIO: the bytes of the encoded image ready to be written to disk
"""
if width * self.height > height * self.width:
scaled_width = width
scaled_height = (width * self.height) // self.width
scaled_image = self._resize(width, scaled_height)
crop_top = (scaled_height - height) // 2
crop_bottom = height + crop_top
cropped = scaled_image.crop((0, crop_top, width, crop_bottom))
crop = (0, crop_top, width, crop_bottom)
else:
scaled_width = (height * self.width) // self.height
scaled_image = self._resize(scaled_width, height)
scaled_height = height
crop_left = (scaled_width - width) // 2
crop_right = width + crop_left
cropped = scaled_image.crop((crop_left, 0, crop_right, height))
return self._encode_image(cropped, output_type)
crop = (crop_left, 0, crop_right, height)
with self._resize(scaled_width, scaled_height) as scaled_image:
with scaled_image.crop(crop) as cropped:
return self._encode_image(cropped, output_type)
def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO:
output_bytes_io = BytesIO()
@@ -171,3 +181,42 @@ class Thumbnailer:
output_image = output_image.convert("RGB")
output_image.save(output_bytes_io, fmt, quality=80)
return output_bytes_io
def close(self) -> None:
"""Closes the underlying image file.
Once closed no other functions can be called.
Can be called multiple times.
"""
if self._closed:
return
self._closed = True
# Since we run this on the finalizer then we need to handle `__init__`
# raising an exception before it can define `self.image`.
image = getattr(self, "image", None)
if image is None:
return
image.close()
def __enter__(self) -> "Thumbnailer":
"""Make `Thumbnailer` a context manager that calls `close` on
`__exit__`.
"""
return self
def __exit__(
self,
type: Optional[Type[BaseException]],
value: Optional[BaseException],
traceback: Optional[TracebackType],
) -> None:
self.close()
def __del__(self) -> None:
# Make sure we actually do close the image, rather than leak data.
self.close()
@@ -36,6 +36,7 @@ class ResourceLimitsServerNotices:
def __init__(self, hs: "HomeServer"):
self._server_notices_manager = hs.get_server_notices_manager()
self._store = hs.get_datastores().main
self._storage_controllers = hs.get_storage_controllers()
self._auth = hs.get_auth()
self._config = hs.config
self._resouce_limited = False
@@ -178,8 +179,10 @@ class ResourceLimitsServerNotices:
currently_blocked = False
pinned_state_event = None
try:
pinned_state_event = await self._state.get_current_state(
room_id, event_type=EventTypes.Pinned
pinned_state_event = (
await self._storage_controllers.state.get_current_state_event(
room_id, event_type=EventTypes.Pinned, state_key=""
)
)
except AuthError:
# The user has yet to join the server notices room
+4 -75
View File
@@ -32,13 +32,11 @@ from typing import (
Set,
Tuple,
Union,
overload,
)
import attr
from frozendict import frozendict
from prometheus_client import Counter, Histogram
from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -132,85 +130,20 @@ class StateHandler:
self._state_resolution_handler = hs.get_state_resolution_handler()
self._storage_controllers = hs.get_storage_controllers()
@overload
async def get_current_state(
self,
room_id: str,
event_type: Literal[None] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> StateMap[EventBase]:
...
@overload
async def get_current_state(
self,
room_id: str,
event_type: str,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Optional[EventBase]:
...
async def get_current_state(
self,
room_id: str,
event_type: Optional[str] = None,
state_key: str = "",
latest_event_ids: Optional[List[str]] = None,
) -> Union[Optional[EventBase], StateMap[EventBase]]:
"""Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
Returns:
If `event_type` is specified, then the method returns only the one
event (or None) with that `event_type` and `state_key`.
Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
event = await self.store.get_event(event_id, allow_none=True)
return event
state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
return {
key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
async def get_current_state_ids(
self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
self,
room_id: str,
latest_event_ids: Collection[str],
) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
room_id:
latest_event_ids: if given, the forward extremities to resolve. If
None, we look them up from the database (via a cache).
latest_event_ids: The forward extremities to resolve.
Returns:
the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return ret.state
@@ -239,10 +172,6 @@ class StateHandler:
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
return await self.store.get_joined_users_from_state(room_id, entry)
async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]:
event_ids = await self.store.get_latest_event_ids_in_room(room_id)
return await self.get_hosts_in_room_at_events(room_id, event_ids)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
) -> FrozenSet[str]:
+2 -1
View File
@@ -71,13 +71,14 @@ class SQLBaseStore(metaclass=ABCMeta):
self._attempt_to_invalidate_cache("is_host_joined", (room_id, host))
if members_changed:
self._attempt_to_invalidate_cache("get_users_in_room", (room_id,))
self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,))
self._attempt_to_invalidate_cache(
"get_users_in_room_with_profiles", (room_id,)
)
# Purge other caches based on room state.
self._attempt_to_invalidate_cache("get_room_summary", (room_id,))
self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,))
self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,))
def _attempt_to_invalidate_cache(
self, cache_name: str, key: Optional[Collection[Any]]
+2 -2
View File
@@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import (
EventsPersistenceStorageController,
)
from synapse.storage.controllers.purge_events import PurgeEventsStorageController
from synapse.storage.controllers.state import StateGroupStorageController
from synapse.storage.controllers.state import StateStorageController
from synapse.storage.databases import Databases
from synapse.storage.databases.main import DataStore
@@ -39,7 +39,7 @@ class StorageControllers:
self.main = stores.main
self.purge_events = PurgeEventsStorageController(hs, stores)
self.state = StateGroupStorageController(hs, stores)
self.state = StateStorageController(hs, stores)
self.persistence = None
if stores.persist_events:
@@ -994,7 +994,7 @@ class EventsPersistenceStorageController:
Assumes that we are only persisting events for one room at a time.
"""
existing_state = await self.main_store.get_current_state_ids(room_id)
existing_state = await self.main_store.get_partial_current_state_ids(room_id)
to_delete = [key for key in existing_state if key not in current_state]
@@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController:
# The server will leave the room, so we go and find out which remote
# users will still be joined when we leave.
if current_state is None:
current_state = await self.main_store.get_current_state_ids(room_id)
current_state = await self.main_store.get_partial_current_state_ids(room_id)
current_state = dict(current_state)
for key in delta.to_delete:
current_state.pop(key, None)
+144 -3
View File
@@ -14,19 +14,26 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Collection,
Dict,
Iterable,
List,
Mapping,
Optional,
Set,
Tuple,
)
from synapse.api.constants import EventTypes
from synapse.events import EventBase
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
PartialStateEventsTracker,
)
from synapse.types import MutableStateMap, StateMap
if TYPE_CHECKING:
@@ -36,17 +43,27 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
class StateGroupStorageController:
"""High level interface to fetching state for event."""
class StateStorageController:
"""High level interface to fetching state for an event, or the current state
in a room.
"""
def __init__(self, hs: "HomeServer", stores: "Databases"):
self._is_mine_id = hs.is_mine_id
self.stores = stores
self._partial_state_events_tracker = PartialStateEventsTracker(stores.main)
self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main)
def notify_event_un_partial_stated(self, event_id: str) -> None:
self._partial_state_events_tracker.notify_un_partial_stated(event_id)
def notify_room_un_partial_stated(self, room_id: str) -> None:
"""Notify that the room no longer has any partial state.
Must be called after `DataStore.clear_partial_state_room`
"""
self._partial_state_room_tracker.notify_un_partial_stated(room_id)
async def get_state_group_delta(
self, state_group: int
) -> Tuple[Optional[int], Optional[StateMap[str]]]:
@@ -349,3 +366,127 @@ class StateGroupStorageController:
return await self.stores.state.store_state_group(
event_id, room_id, prev_group, delta_ids, current_state_ids
)
async def get_current_state_ids(
self,
room_id: str,
state_filter: Optional[StateFilter] = None,
on_invalidate: Optional[Callable[[], None]] = None,
) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
If a state filter is given (that is not `StateFilter.all()`) the query
result is *not* cached.
Args:
room_id: The room to get the state IDs of. state_filter: The state
filter used to fetch state from the
database.
on_invalidate: Callback for when the `get_current_state_ids` cache
for the room gets invalidated.
Returns:
The current state of the room.
"""
if not state_filter or state_filter.must_await_full_state(self._is_mine_id):
await self._partial_state_room_tracker.await_full_state(room_id)
if state_filter and not state_filter.is_full():
return await self.stores.main.get_partial_filtered_current_state_ids(
room_id, state_filter
)
else:
return await self.stores.main.get_partial_current_state_ids(
room_id, on_invalidate=on_invalidate
)
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
room_id: The room ID
Returns:
The canonical alias, if any
"""
state = await self.get_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return None
event = await self.stores.main.get_event(event_id, allow_none=True)
if not event:
return None
return event.content.get("canonical_alias")
async def get_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
Each entry in the result contains the following fields:
- stream_id (int)
- room_id (str)
- type (str): event type
- state_key (str):
- event_id (str|None): new event_id for this state key. None if the
state has been deleted.
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
- ie, an upper limit to return changes from.
Returns:
A tuple consisting of:
- the stream id which these results go up to
- list of current_state_delta_stream rows. If it is empty, we are
up to date.
"""
# FIXME(faster_joins): what do we do here?
return await self.stores.main.get_partial_current_state_deltas(
prev_stream_id, max_stream_id
)
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
"""Same as `get_current_state_ids` but also fetches the events"""
state_map_ids = await self.get_current_state_ids(room_id, state_filter)
event_map = await self.stores.main.get_events(list(state_map_ids.values()))
state_map = {}
for key, event_id in state_map_ids.items():
event = event_map.get(event_id)
if event:
state_map[key] = event
return state_map
async def get_current_state_event(
self, room_id: str, event_type: str, state_key: str
) -> Optional[EventBase]:
"""Get the current state event for the given type/state_key."""
key = (event_type, state_key)
state_map = await self.get_current_state(
room_id, StateFilter.from_types((key,))
)
return state_map.get(key)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
@@ -151,10 +151,6 @@ class DataStore(
],
)
self._group_updates_id_gen = StreamIdGenerator(
db_conn, "local_group_updates", "stream_id"
)
self._cache_id_gen: Optional[MultiWriterIdGenerator]
if isinstance(self.database_engine, PostgresEngine):
# We set the `writers` to an empty list here as we don't care about
@@ -197,20 +193,6 @@ class DataStore(
prefilled_cache=curr_state_delta_prefill,
)
_group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict(
db_conn,
"local_group_updates",
entity_column="user_id",
stream_column="stream_id",
max_value=self._group_updates_id_gen.get_current_token(),
limit=1000,
)
self._group_updates_stream_cache = StreamChangeCache(
"_group_updates_stream_cache",
min_group_updates_id,
prefilled_cache=_group_updates_prefill,
)
self._stream_order_on_start = self.get_room_max_stream_ordering()
self._min_stream_order_on_start = self.get_room_min_stream_ordering()
@@ -29,11 +29,6 @@ class GroupServerStore(SQLBaseStore):
db_conn: LoggingDatabaseConnection,
hs: "HomeServer",
):
database.updates.register_background_index_update(
update_name="local_group_updates_index",
index_name="local_group_updates_stream_id_index",
table="local_group_updates",
columns=("stream_id",),
unique=True,
)
# Register a legacy groups background update as a no-op.
database.updates.register_noop_background_update("local_group_updates_index")
super().__init__(database, db_conn, hs)
@@ -276,10 +276,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
(SELECT 1
FROM profiles
WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
(SELECT 1
FROM groups
WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id)
AND NOT EXISTS
(SELECT 1
FROM room_memberships
+2 -105
View File
@@ -11,11 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Optional
from synapse.api.errors import StoreError
from synapse.storage._base import SQLBaseStore
from synapse.storage.database import LoggingTransaction
from synapse.storage.databases.main.roommember import ProfileInfo
@@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="get_profile_avatar_url",
)
async def get_from_remote_profile_cache(
self, user_id: str
) -> Optional[Dict[str, Any]]:
return await self.db_pool.simple_select_one(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
retcols=("displayname", "avatar_url"),
allow_none=True,
desc="get_from_remote_profile_cache",
)
async def create_profile(self, user_localpart: str) -> None:
await self.db_pool.simple_insert(
table="profiles", values={"user_id": user_localpart}, desc="create_profile"
@@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore):
desc="set_profile_avatar_url",
)
async def update_remote_profile_cache(
self, user_id: str, displayname: Optional[str], avatar_url: Optional[str]
) -> int:
return await self.db_pool.simple_update(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
updatevalues={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="update_remote_profile_cache",
)
async def maybe_delete_remote_profile_cache(self, user_id: str) -> None:
"""Check if we still care about the remote user's profile, and if we
don't then remove their profile from the cache
"""
subscribed = await self.is_subscribed_remote_profile_for_user(user_id)
if not subscribed:
await self.db_pool.simple_delete(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
desc="delete_remote_profile_cache",
)
async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool:
"""Check whether we are interested in a remote user's profile."""
res: Optional[str] = await self.db_pool.simple_select_one_onecol(
table="group_users",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
return True
res = await self.db_pool.simple_select_one_onecol(
table="group_invites",
keyvalues={"user_id": user_id},
retcol="user_id",
allow_none=True,
desc="should_update_remote_profile_cache_for_user",
)
if res:
return True
return False
async def get_remote_profile_cache_entries_that_expire(
self, last_checked: int
) -> List[Dict[str, str]]:
"""Get all users who haven't been checked since `last_checked`"""
def _get_remote_profile_cache_entries_that_expire_txn(
txn: LoggingTransaction,
) -> List[Dict[str, str]]:
sql = """
SELECT user_id, displayname, avatar_url
FROM remote_profile_cache
WHERE last_check < ?
"""
txn.execute(sql, (last_checked,))
return self.db_pool.cursor_to_dict(txn)
return await self.db_pool.runInteraction(
"get_remote_profile_cache_entries_that_expire",
_get_remote_profile_cache_entries_that_expire_txn,
)
class ProfileStore(ProfileWorkerStore):
async def add_remote_profile_cache(
self, user_id: str, displayname: str, avatar_url: str
) -> None:
"""Ensure we are caching the remote user's profiles.
This should only be called when `is_subscribed_remote_profile_for_user`
would return true for the user.
"""
await self.db_pool.simple_upsert(
table="remote_profile_cache",
keyvalues={"user_id": user_id},
values={
"displayname": displayname,
"avatar_url": avatar_url,
"last_check": self._clock.time_msec(),
},
desc="add_remote_profile_cache",
)
pass
@@ -393,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"partial_state_events",
"events",
"federation_inbound_events_staging",
"group_rooms",
"local_current_membership",
"partial_state_rooms_servers",
"partial_state_rooms",
@@ -413,7 +412,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"room_account_data",
"room_tags",
# "rooms" happens last, to keep the foreign keys in the other tables
+18
View File
@@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
keyvalues={"room_id": room_id},
)
async def is_partial_state_room(self, room_id: str) -> bool:
"""Checks if this room has partial state.
Returns true if this is a "partial-state" room, which means that the state
at events in the room, and `current_state_events`, may not yet be
complete.
"""
entry = await self.db_pool.simple_select_one_onecol(
table="partial_state_rooms",
keyvalues={"room_id": room_id},
retcol="room_id",
allow_none=True,
desc="is_partial_state_room",
)
return entry is not None
class _BackgroundUpdates:
REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory"
@@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return True
@cached(iterable=True, max_entries=10000)
async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
"""Get current hosts in room based on current state."""
# First we check if we already have `get_users_in_room` in the cache, as
# we can just calculate result from that
users = self.get_users_in_room.cache.get_immediate(
(room_id,), None, update_metrics=False
)
if users is not None:
return {get_domain_from_id(u) for u in users}
if isinstance(self.database_engine, Sqlite3Engine):
# If we're using SQLite then let's just always use
# `get_users_in_room` rather than funky SQL.
users = await self.get_users_in_room(room_id)
return {get_domain_from_id(u) for u in users}
# For PostgreSQL we can use a regex to pull out the domains from the
# joined users in `current_state_events` via regex.
def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]:
sql = """
SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$')
FROM current_state_events
WHERE
type = 'm.room.member'
AND membership = 'join'
AND room_id = ?
"""
txn.execute(sql, (room_id,))
return {d for d, in txn}
return await self.db_pool.runInteraction(
"get_current_hosts_in_room", get_current_hosts_in_room_txn
)
async def get_joined_hosts(
self, room_id: str, state_entry: "_StateCacheEntry"
) -> FrozenSet[str]:
+18 -32
View File
@@ -54,6 +54,7 @@ class EventMetadata:
room_id: str
event_type: str
state_key: Optional[str]
rejection_reason: Optional[str]
def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
@@ -167,17 +168,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
)
sql = f"""
SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e
SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason
FROM events AS e
LEFT JOIN state_events se USING (event_id)
LEFT JOIN rejections r USING (event_id)
WHERE {clause}
"""
txn.execute(sql, args)
return {
event_id: EventMetadata(
room_id=room_id, event_type=event_type, state_key=state_key
room_id=room_id,
event_type=event_type,
state_key=state_key,
rejection_reason=rejection_reason,
)
for event_id, room_id, event_type, state_key in txn
for event_id, room_id, event_type, state_key, rejection_reason in txn
}
result_map: Dict[str, EventMetadata] = {}
@@ -236,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
Raises:
NotFoundError if the room is unknown
"""
state_ids = await self.get_current_state_ids(room_id)
state_ids = await self.get_partial_current_state_ids(room_id)
if not state_ids:
raise NotFoundError(f"Current state for room {room_id} is empty")
@@ -252,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return create_event
@cached(max_entries=100000, iterable=True)
async def get_current_state_ids(self, room_id: str) -> StateMap[str]:
async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
"""Get the current state event ids for a room based on the
current_state_events table.
This may be the partial state if we're lazy joining the room.
Args:
room_id: The room to get the state IDs of.
@@ -274,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn}
return await self.db_pool.runInteraction(
"get_current_state_ids", _get_current_state_ids_txn
"get_partial_current_state_ids", _get_current_state_ids_txn
)
# FIXME: how should this be cached?
async def get_filtered_current_state_ids(
async def get_partial_filtered_current_state_ids(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[str]:
"""Get the current state event of a given type for a room based on the
current_state_events table. This may not be as up-to-date as the result
of doing a fresh state resolution as per state_handler.get_current_state
This may be the partial state if we're lazy joining the room.
Args:
room_id
state_filter: The state filter used to fetch state
@@ -300,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
if not where_clause:
# We delegate to the cached version
return await self.get_current_state_ids(room_id)
return await self.get_partial_current_state_ids(room_id)
def _get_filtered_current_state_ids_txn(
txn: LoggingTransaction,
@@ -328,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
"get_filtered_current_state_ids", _get_filtered_current_state_ids_txn
)
async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]:
"""Get canonical alias for room, if any
Args:
room_id: The room ID
Returns:
The canonical alias, if any
"""
state = await self.get_filtered_current_state_ids(
room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")])
)
event_id = state.get((EventTypes.CanonicalAlias, ""))
if not event_id:
return None
event = await self.get_event(event_id, allow_none=True)
if not event:
return None
return event.content.get("canonical_alias")
@cached(max_entries=50000)
async def _get_state_group_for_event(self, event_id: str) -> Optional[int]:
return await self.db_pool.simple_select_one_onecol(
@@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore):
# attribute. TODO: can we get static analysis to enforce this?
_curr_state_delta_stream_cache: StreamChangeCache
async def get_current_state_deltas(
async def get_partial_current_state_deltas(
self, prev_stream_id: int, max_stream_id: int
) -> Tuple[int, List[Dict[str, Any]]]:
"""Fetch a list of room state changes since the given stream id
@@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore):
- prev_event_id (str|None): previous event_id for this state key. None
if it's new state.
This may be the partial state if we're lazy joining the room.
Args:
prev_stream_id: point to get changes since (exclusive)
max_stream_id: the point that we know has been correctly persisted
+5 -7
View File
@@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
self,
room_id: str,
end_token: RoomStreamToken,
) -> Optional[EventBase]:
"""Returns the last event in a room at or before a stream ordering
) -> Optional[str]:
"""Returns the ID of the last event in a room at or before a stream ordering
Args:
room_id
end_token: The token used to stream from
Returns:
The most recent event.
The ID of the most recent event, or None if there are no events in the room
before this stream ordering.
"""
last_row = await self.get_room_event_before_stream_ordering(
@@ -781,10 +782,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
stream_ordering=end_token.stream,
)
if last_row:
_, _, event_id = last_row
event = await self.get_event(event_id, get_prev_content=True)
return event
return last_row[2]
return None
async def get_current_room_stream_token_for_room_id(
@@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
(EventTypes.RoomHistoryVisibility, ""),
)
current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined]
# Getting the partial state is fine, as we're not looking at membership
# events.
current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined]
room_id, StateFilter.from_types(types_to_filter)
)
+1
View File
@@ -70,6 +70,7 @@ Changes in SCHEMA_VERSION = 70:
Changes in SCHEMA_VERSION = 71:
- event_edges.room_id is no longer read from.
- Tables related to groups are no longer accessed.
"""
@@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred
from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.databases.main.room import RoomWorkerStore
from synapse.util import unwrapFirstError
logger = logging.getLogger(__name__)
@@ -118,3 +119,62 @@ class PartialStateEventsTracker:
observer_set.discard(observer)
if not observer_set:
del self._observers[event_id]
class PartialCurrentStateTracker:
"""Keeps track of which rooms have partial state, after partial-state joins"""
def __init__(self, store: RoomWorkerStore):
self._store = store
# a map from room id to a set of Deferreds which are waiting for that room to be
# un-partial-stated.
self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set)
def notify_un_partial_stated(self, room_id: str) -> None:
"""Notify that we now have full current state for a given room
Unblocks any callers to await_full_state() for that room.
Args:
room_id: the room that now has full current state.
"""
observers = self._observers.pop(room_id, None)
if not observers:
return
logger.info(
"Notifying %i things waiting for un-partial-stating of room %s",
len(observers),
room_id,
)
with PreserveLoggingContext():
for o in observers:
o.callback(None)
async def await_full_state(self, room_id: str) -> None:
# We add the deferred immediately so that the DB call to check for
# partial state doesn't race when we unpartial the room.
d: Deferred[None] = Deferred()
self._observers.setdefault(room_id, set()).add(d)
try:
# Check if the room has partial current state or not.
has_partial_state = await self._store.is_partial_state_room(room_id)
if not has_partial_state:
return
logger.info(
"Awaiting un-partial-stating of room %s",
room_id,
)
await make_deferred_yieldable(d)
logger.info("Room has un-partial-stated")
finally:
# Remove the added observer, and remove the room entry if its empty.
ds = self._observers.get(room_id)
if ds is not None:
ds.discard(d)
if not ds:
self._observers.pop(room_id, None)
+18 -10
View File
@@ -162,16 +162,7 @@ async def filter_events_for_client(
state = event_id_to_state[event.event_id]
# get the room_visibility at the time of the event.
visibility_event = state.get(_HISTORY_VIS_KEY, None)
if visibility_event:
visibility = visibility_event.content.get(
"history_visibility", HistoryVisibility.SHARED
)
else:
visibility = HistoryVisibility.SHARED
if visibility not in VISIBILITY_PRIORITY:
visibility = HistoryVisibility.SHARED
visibility = get_effective_room_visibility_from_state(state)
# Always allow history visibility events on boundaries. This is done
# by setting the effective visibility to the least restrictive
@@ -267,6 +258,23 @@ async def filter_events_for_client(
return [ev for ev in filtered_events if ev]
def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str:
"""Get the actual history vis, from a state map including the history_visibility event
Handles missing and invalid history visibility events.
"""
visibility_event = state.get(_HISTORY_VIS_KEY, None)
if not visibility_event:
return HistoryVisibility.SHARED
visibility = visibility_event.content.get(
"history_visibility", HistoryVisibility.SHARED
)
if visibility not in VISIBILITY_PRIORITY:
visibility = HistoryVisibility.SHARED
return visibility
async def filter_events_for_server(
storage: StorageControllers,
server_name: str,
-2
View File
@@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
@@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
"abcd",
self.hs.config.server.server_name,
id="1234",
namespaces={
"users": [{"regex": "@_appservice.*:sender", "exclusive": True}]
-2
View File
@@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_ratelimited_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
"example.com",
id="foo",
rate_limited=True,
sender="@as:example.com",
@@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase):
def test_allowed_appservice_via_can_requester_do_action(self):
appservice = ApplicationService(
None,
"example.com",
id="foo",
rate_limited=False,
sender="@as:example.com",
-1
View File
@@ -37,7 +37,6 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase):
url=URL,
token="unused",
hs_token=TOKEN,
hostname="myserver",
)
def test_query_3pe_authenticates_token(self):
-1
View File
@@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase):
sender="@as:test",
url="some_url",
token="some_token",
hostname="matrix.org", # only used by get_groups_for_user
)
self.event = Mock(
event_id="$abc:xyz",
+5 -12
View File
@@ -12,10 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import nacl.signing
import signedjson.types
from unpaddedbase64 import decode_base64
from signedjson.key import decode_signing_key_base64
from signedjson.types import SigningKey
from synapse.api.room_versions import RoomVersions
from synapse.crypto.event_signing import add_hashes_and_signatures
@@ -25,7 +23,7 @@ from tests import unittest
# Perform these tests using given secret key so we get entirely deterministic
# signatures output that we can test against.
SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1")
SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1"
KEY_ALG = "ed25519"
KEY_VER = "1"
@@ -36,14 +34,9 @@ HOSTNAME = "domain"
class EventSigningTestCase(unittest.TestCase):
def setUp(self):
# NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been
# monkeypatched to include new `alg` and `version` attributes. This is captured
# by the `signedjson.types.SigningKey` protocol.
self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment]
SIGNING_KEY_SEED
self.signing_key: SigningKey = decode_signing_key_base64(
KEY_ALG, KEY_VER, SIGNING_KEY_SEED
)
self.signing_key.alg = KEY_ALG
self.signing_key.version = KEY_VER
def test_sign_minimal(self):
event_dict = {
+1 -1
View File
@@ -19,8 +19,8 @@ import attr
import canonicaljson
import signedjson.key
import signedjson.sign
from nacl.signing import SigningKey
from signedjson.key import encode_verify_key_base64, get_verify_key
from signedjson.types import SigningKey
from twisted.internet import defer
from twisted.internet.defer import Deferred, ensureDeferred
+7 -7
View File
@@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config
class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
mock_state_handler = Mock(spec=["get_current_hosts_in_room"])
# Ensure a new Awaitable is created for each call.
mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable(
["test", "host2"]
)
return self.setup_test_homeserver(
state_handler=mock_state_handler,
hs = self.setup_test_homeserver(
federation_transport_client=Mock(spec=["send_transaction"]),
)
hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
return_value=make_awaitable({"test", "host2"})
)
return hs
@override_config({"send_federation": True})
def test_send_receipts(self):
mock_send_transaction = (
+4 -2
View File
@@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
super().prepare(reactor, clock, hs)
self._storage_controllers = hs.get_storage_controllers()
# create the room
creator_user_id = self.register_user("kermit", "test")
tok = self.login("kermit", "test")
@@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
@@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
# the room should show that the new user is a member
r = self.get_success(
self.hs.get_state_handler().get_current_state(self._room_id)
self._storage_controllers.state.get_current_state(self._room_id)
)
self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
-3
View File
@@ -697,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
# Create an application service
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -776,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Create an appservice that is interested in "local_user"
appservice = ApplicationService(
token=random_string(10),
hostname="example.com",
id=random_string(10),
sender="@as:example.com",
rate_limited=False,
@@ -843,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
self._service_token = "VERYSECRET"
self._service = ApplicationService(
self._service_token,
"as1.invalid",
"as1",
"@as.sender:test",
namespaces={
+2 -1
View File
@@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.handler = hs.get_directory_handler()
self.state_handler = hs.get_state_handler()
self._storage_controllers = hs.get_storage_controllers()
# Create user
self.admin_user = self.register_user("admin", "pass", admin=True)
@@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase):
def _get_canonical_alias(self):
"""Get the canonical alias state of the room."""
return self.get_success(
self.state_handler.get_current_state(
self._storage_controllers.state.get_current_state_event(
self.room_id, EventTypes.CanonicalAlias, ""
)
)
+4 -2
View File
@@ -237,7 +237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
)
current_state = self.get_success(
self.store.get_events_as_list(
(self.get_success(self.store.get_current_state_ids(room_id))).values()
(
self.get_success(self.store.get_partial_current_state_ids(room_id))
).values()
)
)
@@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
self.get_success(d)
# sanity-check: the room should show that the new user is a member
r = self.get_success(self.store.get_current_state_ids(room_id))
r = self.get_success(self.store.get_partial_current_state_ids(room_id))
self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id)
return join_event
+3 -1
View File
@@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase):
event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join")
)
initial_state_map = self.get_success(main_store.get_current_state_ids(room_id))
initial_state_map = self.get_success(
main_store.get_partial_current_state_ids(room_id)
)
auth_event_ids = [
initial_state_map[("m.room.create", "")],
+5 -3
View File
@@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
hs.get_event_auth_handler().check_host_in_room = check_host_in_room
def get_joined_hosts_for_room(room_id: str):
async def get_current_hosts_in_room(room_id: str):
return {member.domain for member in self.room_members}
self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room
hs.get_storage_controllers().state.get_current_hosts_in_room = (
get_current_hosts_in_room
)
async def get_users_in_room(room_id: str):
return {str(u) for u in self.room_members}
@@ -146,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
)
)
self.datastore.get_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None))
self.datastore.get_to_device_stream_token = lambda: 0
self.datastore.get_new_device_msgs_for_remote = (
-1
View File
@@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not match the regex above, so that tests
-2
View File
@@ -2467,7 +2467,6 @@ PURGE_TABLES = [
"event_push_actions",
"event_search",
"events",
"group_rooms",
"receipts_graph",
"receipts_linearized",
"room_aliases",
@@ -2484,7 +2483,6 @@ PURGE_TABLES = [
"e2e_room_keys",
"event_push_summary",
"pusher_throttle",
"group_summary_rooms",
"room_account_data",
"room_tags",
# "state_groups", # Current impl leaves orphaned state groups around.
-1
View File
@@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": user_id, "exclusive": True}]},
sender=user_id,
-2
View File
@@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.service = ApplicationService(
id="unique_identifier",
token="some_token",
hostname="example.com",
sender="@asbot:example.com",
namespaces={
ApplicationService.NS_USERS: [
@@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
self.another_service = ApplicationService(
id="another__identifier",
token="another_token",
hostname="example.com",
sender="@as2bot:example.com",
namespaces={
ApplicationService.NS_USERS: [
-2
View File
@@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
@@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
appservice = ApplicationService(
as_token,
self.hs.config.server.server_name,
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
sender="@as:test",
+59 -26
View File
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
access_token: The access token to user, defaults to self.user_token.
"""
access_token = access_token or self.user_token
def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
"GET", f"/sync?filter={filter}", access_token=self.user_token
"GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
access_token=self.user_token,
access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
@@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
Test that threads get correctly bundled.
"""
self._send_relation(RelationTypes.THREAD, "m.room.test")
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
# The root message is from "user", send replies as "user2".
self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
channel = self._send_relation(
RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
)
thread_2 = channel.json_body["event_id"]
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
# This needs two assertion functions which are identical except for whether
# the current_user_participated flag is True, create a factory for the
# two versions.
def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertEqual(
participated, bundled_aggregations.get("current_user_participated")
)
# The latest thread event has some fields that don't matter.
self.assert_dict(
{
"content": {
"m.relates_to": {
"event_id": self.parent_id,
"rel_type": RelationTypes.THREAD,
}
},
"event_id": thread_2,
"sender": self.user2_id,
"type": "m.room.test",
},
"event_id": thread_2,
"sender": self.user_id,
"type": "m.room.test",
},
bundled_aggregations.get("latest_event"),
)
bundled_aggregations.get("latest_event"),
)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
return assert_thread
# The "user" sent the root event and is making queries for the bundled
# aggregations: they have participated.
self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
# The "user2" sent replies in the thread and is making queries for the
# bundled aggregations: they have participated.
#
# Note that this re-uses some cached values, so the total number of
# queries is much smaller.
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
)
# A user with no interactions with the thread: they have not participated.
user3_id, user3_token = self._create_user("charlie")
self.helper.join(self.room, user=user3_id, tok=user3_token)
self._test_bundled_aggregations(
RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
)
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
def test_nested_thread(self) -> None:
"""
-1
View File
@@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase):
self.appservice = ApplicationService(
token="i_am_an_app_service",
hostname="test",
id="1234",
namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]},
# Note: this user does not have to match the regex above
+6 -2
View File
@@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_space_id = channel.json_body["replacement_room"]
state_ids = self.get_success(self.store.get_current_state_ids(new_space_id))
state_ids = self.get_success(
self.store.get_partial_current_state_ids(new_space_id)
)
# Ensure the new room is still a space.
create_event = self.get_success(
@@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase):
new_room_id = channel.json_body["replacement_room"]
state_ids = self.get_success(self.store.get_current_state_ids(new_room_id))
state_ids = self.get_success(
self.store.get_partial_current_state_ids(new_room_id)
)
# Ensure the new room is the same type as the old room.
create_event = self.get_success(
+36 -1
View File
@@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase):
)
class CalcOgTestCase(unittest.TestCase):
class OpenGraphFromHtmlTestCase(unittest.TestCase):
if not lxml:
skip = "url preview feature requires lxml"
@@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": None, "og:description": "Some text."})
# Another variant is a title with no content.
html = b"""
<html>
<head><title></title></head>
<body>
<h1>Title</h1>
</body>
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Title"})
def test_h1_as_title(self) -> None:
html = b"""
<html>
@@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase):
self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."})
def test_empty_description(self) -> None:
"""Description tags with empty content should be ignored."""
html = b"""
<html>
<meta property="og:description" content=""/>
<meta property="og:description"/>
<meta name="description" content=""/>
<meta name="description"/>
<meta name="description" content="Finally!"/>
<body>
<h1>Title</h1>
</body>
</html>
"""
tree = decode_body(html, "http://example.com/test.html")
og = parse_html_to_open_graph(tree)
self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"})
def test_missing_title_and_broken_h1(self) -> None:
html = b"""
<html>

Some files were not shown because too many files have changed in this diff Show More