diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 92977ea5a0..3ae4441bdf 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -380,7 +380,7 @@ jobs: # Attempt to check out the same branch of Complement as the PR. If it # doesn't exist, fallback to HEAD. - name: Checkout complement - run: .ci/scripts/checkout_complement.sh + run: synapse/.ci/scripts/checkout_complement.sh - run: | set -o pipefail diff --git a/changelog.d/12766.bugfix b/changelog.d/12766.bugfix new file mode 100644 index 0000000000..912c3deb70 --- /dev/null +++ b/changelog.d/12766.bugfix @@ -0,0 +1 @@ +Implement [MSC3816](https://github.com/matrix-org/matrix-spec-proposals/pull/3816): sending the root event in a thread should count as "participated" in it. diff --git a/changelog.d/12811.misc b/changelog.d/12811.misc new file mode 100644 index 0000000000..d57e1aca6b --- /dev/null +++ b/changelog.d/12811.misc @@ -0,0 +1 @@ +Reduce the amount of state we pull from the DB. diff --git a/changelog.d/12872.misc b/changelog.d/12872.misc new file mode 100644 index 0000000000..f60a756f21 --- /dev/null +++ b/changelog.d/12872.misc @@ -0,0 +1 @@ +Faster room joins: when querying the current state of the room, wait for state to be populated. diff --git a/changelog.d/12899.removal b/changelog.d/12899.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12899.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/changelog.d/12902.misc b/changelog.d/12902.misc new file mode 100644 index 0000000000..3ee8f92552 --- /dev/null +++ b/changelog.d/12902.misc @@ -0,0 +1 @@ +Remove PyNaCl occurrences directly used in Synapse code. 
\ No newline at end of file diff --git a/changelog.d/12905.bugfix b/changelog.d/12905.bugfix new file mode 100644 index 0000000000..67e95d0398 --- /dev/null +++ b/changelog.d/12905.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.58.0 where `/sync` would fail if the most recent event in a room was a redaction of an event that has since been purged. diff --git a/changelog.d/12932.bugfix b/changelog.d/12932.bugfix new file mode 100644 index 0000000000..506f92b427 --- /dev/null +++ b/changelog.d/12932.bugfix @@ -0,0 +1 @@ +Fix potential memory leak when generating thumbnails. diff --git a/changelog.d/12933.misc b/changelog.d/12933.misc new file mode 100644 index 0000000000..e29bf02407 --- /dev/null +++ b/changelog.d/12933.misc @@ -0,0 +1 @@ +Test Synapse against Complement with workers. diff --git a/changelog.d/12936.removal b/changelog.d/12936.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12936.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/changelog.d/12950.bugfix b/changelog.d/12950.bugfix new file mode 100644 index 0000000000..e835d9aa72 --- /dev/null +++ b/changelog.d/12950.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug where a URL preview would break if the image failed to download. diff --git a/changelog.d/12951.feature b/changelog.d/12951.feature new file mode 100644 index 0000000000..f885be9fe4 --- /dev/null +++ b/changelog.d/12951.feature @@ -0,0 +1 @@ +Improve URL previews for pages with empty elements. diff --git a/changelog.d/12952.feature b/changelog.d/12952.feature new file mode 100644 index 0000000000..7329bcc3d4 --- /dev/null +++ b/changelog.d/12952.feature @@ -0,0 +1 @@ +Allow updating a user's password using the admin API without logging out their devices. Contributed by @jcgruenhage. 
diff --git a/changelog.d/12964.misc b/changelog.d/12964.misc new file mode 100644 index 0000000000..d57e1aca6b --- /dev/null +++ b/changelog.d/12964.misc @@ -0,0 +1 @@ +Reduce the amount of state we pull from the DB. diff --git a/changelog.d/12966.removal b/changelog.d/12966.removal new file mode 100644 index 0000000000..41f6fae5da --- /dev/null +++ b/changelog.d/12966.removal @@ -0,0 +1 @@ +Remove support for the non-standard groups/communities feature from Synapse. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 856dd437db..895b2a7af1 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -16,6 +16,7 @@ """ Starts a synapse client console. """ import argparse +import binascii import cmd import getpass import json @@ -26,9 +27,8 @@ import urllib from http import TwistedHttpClient from typing import Optional -import nacl.encoding -import nacl.signing import urlparse +from signedjson.key import NACL_ED25519, decode_verify_key_bytes from signedjson.sign import SignatureVerifyException, verify_signed_json from twisted.internet import defer, reactor, threads @@ -41,7 +41,6 @@ TRUSTED_ID_SERVERS = ["localhost:8001"] class SynapseCmd(cmd.Cmd): - """Basic synapse command-line processor. This processes commands from the user and calls the relevant HTTP methods. 
@@ -420,8 +419,8 @@ class SynapseCmd(cmd.Cmd): pubKey = None pubKeyObj = yield self.http_client.do_request("GET", url) if "public_key" in pubKeyObj: - pubKey = nacl.signing.VerifyKey( - pubKeyObj["public_key"], encoder=nacl.encoding.HexEncoder + pubKey = decode_verify_key_bytes( + NACL_ED25519, binascii.unhexlify(pubKeyObj["public_key"]) ) else: print("No public key found in pubkey response!") diff --git a/docs/admin_api/user_admin_api.md b/docs/admin_api/user_admin_api.md index c8794299e7..62f89e8cba 100644 --- a/docs/admin_api/user_admin_api.md +++ b/docs/admin_api/user_admin_api.md @@ -115,7 +115,9 @@ URL parameters: Body parameters: - `password` - string, optional. If provided, the user's password is updated and all - devices are logged out. + devices are logged out, unless `logout_devices` is set to `false`. +- `logout_devices` - bool, optional, defaults to `true`. If set to false, devices aren't + logged out even when `password` is provided. - `displayname` - string, optional, defaults to the value of `user_id`. - `threepids` - array, optional, allows setting the third-party IDs (email, msisdn) - `medium` - string. Kind of third-party ID, either `email` or `msisdn`. diff --git a/docs/workers.md b/docs/workers.md index 78973a498c..6969c424d8 100644 --- a/docs/workers.md +++ b/docs/workers.md @@ -191,7 +191,6 @@ information. ^/_matrix/federation/v1/event_auth/ ^/_matrix/federation/v1/exchange_third_party_invite/ ^/_matrix/federation/v1/user/devices/ - ^/_matrix/federation/v1/get_groups_publicised$ ^/_matrix/key/v2/query ^/_matrix/federation/v1/hierarchy/ @@ -213,9 +212,6 @@ information. 
^/_matrix/client/(r0|v3|unstable)/devices$ ^/_matrix/client/versions$ ^/_matrix/client/(api/v1|r0|v3|unstable)/voip/turnServer$ - ^/_matrix/client/(r0|v3|unstable)/joined_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups$ - ^/_matrix/client/(r0|v3|unstable)/publicised_groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/rooms/.*/event/ ^/_matrix/client/(api/v1|r0|v3|unstable)/joined_rooms$ ^/_matrix/client/(api/v1|r0|v3|unstable)/search$ @@ -255,9 +251,7 @@ information. Additionally, the following REST endpoints can be handled for GET requests: - ^/_matrix/federation/v1/groups/ ^/_matrix/client/(api/v1|r0|v3|unstable)/pushrules/ - ^/_matrix/client/(r0|v3|unstable)/groups/ Pagination requests can also be handled, but all requests for a given room must be routed to the same instance. Additionally, care must be taken to diff --git a/poetry.lock b/poetry.lock index efbdc7d2f9..7c561e3182 100644 --- a/poetry.lock +++ b/poetry.lock @@ -187,17 +187,6 @@ category = "main" optional = false python-versions = "*" -[[package]] -name = "coverage" -version = "6.4" -description = "Code coverage measurement for Python" -category = "dev" -optional = false -python-versions = ">=3.7" - -[package.extras] -toml = ["tomli"] - [[package]] name = "cryptography" version = "36.0.1" @@ -1574,7 +1563,7 @@ url_preview = ["lxml"] [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "4ca5c8a4f817f99704a5a1ba466a50655073d2c899cf5034c426f84dc4afe0c7" +content-hash = "539e5326f401472d1ffc8325d53d72e544cd70156b3f43f32f1285c4c131f831" [metadata.files] attrs = [ @@ -1713,49 +1702,6 @@ constantly = [ {file = "constantly-15.1.0-py2.py3-none-any.whl", hash = "sha256:dd2fa9d6b1a51a83f0d7dd76293d734046aa176e384bf6e33b7e44880eb37c5d"}, {file = "constantly-15.1.0.tar.gz", hash = "sha256:586372eb92059873e29eba4f9dec8381541b4d3834660707faf8ba59146dfc35"}, ] -coverage = [ - {file = "coverage-6.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = 
"sha256:50ed480b798febce113709846b11f5d5ed1e529c88d8ae92f707806c50297abf"}, - {file = "coverage-6.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:26f8f92699756cb7af2b30720de0c5bb8d028e923a95b6d0c891088025a1ac8f"}, - {file = "coverage-6.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:60c2147921da7f4d2d04f570e1838db32b95c5509d248f3fe6417e91437eaf41"}, - {file = "coverage-6.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:750e13834b597eeb8ae6e72aa58d1d831b96beec5ad1d04479ae3772373a8088"}, - {file = "coverage-6.4-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:af5b9ee0fc146e907aa0f5fb858c3b3da9199d78b7bb2c9973d95550bd40f701"}, - {file = "coverage-6.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:a022394996419142b33a0cf7274cb444c01d2bb123727c4bb0b9acabcb515dea"}, - {file = "coverage-6.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:5a78cf2c43b13aa6b56003707c5203f28585944c277c1f3f109c7b041b16bd39"}, - {file = "coverage-6.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9229d074e097f21dfe0643d9d0140ee7433814b3f0fc3706b4abffd1e3038632"}, - {file = "coverage-6.4-cp310-cp310-win32.whl", hash = "sha256:fb45fe08e1abc64eb836d187b20a59172053999823f7f6ef4f18a819c44ba16f"}, - {file = "coverage-6.4-cp310-cp310-win_amd64.whl", hash = "sha256:3cfd07c5889ddb96a401449109a8b97a165be9d67077df6802f59708bfb07720"}, - {file = "coverage-6.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:03014a74023abaf5a591eeeaf1ac66a73d54eba178ff4cb1fa0c0a44aae70383"}, - {file = "coverage-6.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9c82f2cd69c71698152e943f4a5a6b83a3ab1db73b88f6e769fabc86074c3b08"}, - {file = "coverage-6.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = 
"sha256:7b546cf2b1974ddc2cb222a109b37c6ed1778b9be7e6b0c0bc0cf0438d9e45a6"}, - {file = "coverage-6.4-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc173f1ce9ffb16b299f51c9ce53f66a62f4d975abe5640e976904066f3c835d"}, - {file = "coverage-6.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c53ad261dfc8695062fc8811ac7c162bd6096a05a19f26097f411bdf5747aee7"}, - {file = "coverage-6.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:eef5292b60b6de753d6e7f2d128d5841c7915fb1e3321c3a1fe6acfe76c38052"}, - {file = "coverage-6.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:543e172ce4c0de533fa892034cce260467b213c0ea8e39da2f65f9a477425211"}, - {file = "coverage-6.4-cp37-cp37m-win32.whl", hash = "sha256:00c8544510f3c98476bbd58201ac2b150ffbcce46a8c3e4fb89ebf01998f806a"}, - {file = "coverage-6.4-cp37-cp37m-win_amd64.whl", hash = "sha256:b84ab65444dcc68d761e95d4d70f3cfd347ceca5a029f2ffec37d4f124f61311"}, - {file = "coverage-6.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:d548edacbf16a8276af13063a2b0669d58bbcfca7c55a255f84aac2870786a61"}, - {file = "coverage-6.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:033ebec282793bd9eb988d0271c211e58442c31077976c19c442e24d827d356f"}, - {file = "coverage-6.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:742fb8b43835078dd7496c3c25a1ec8d15351df49fb0037bffb4754291ef30ce"}, - {file = "coverage-6.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55fae115ef9f67934e9f1103c9ba826b4c690e4c5bcf94482b8b2398311bf9c"}, - {file = "coverage-6.4-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cd698341626f3c77784858427bad0cdd54a713115b423d22ac83a28303d1d95"}, - {file = "coverage-6.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:62d382f7d77eeeaff14b30516b17bcbe80f645f5cf02bb755baac376591c653c"}, - {file = 
"coverage-6.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:016d7f5cf1c8c84f533a3c1f8f36126fbe00b2ec0ccca47cc5731c3723d327c6"}, - {file = "coverage-6.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:69432946f154c6add0e9ede03cc43b96e2ef2733110a77444823c053b1ff5166"}, - {file = "coverage-6.4-cp38-cp38-win32.whl", hash = "sha256:83bd142cdec5e4a5c4ca1d4ff6fa807d28460f9db919f9f6a31babaaa8b88426"}, - {file = "coverage-6.4-cp38-cp38-win_amd64.whl", hash = "sha256:4002f9e8c1f286e986fe96ec58742b93484195defc01d5cc7809b8f7acb5ece3"}, - {file = "coverage-6.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:e4f52c272fdc82e7c65ff3f17a7179bc5f710ebc8ce8a5cadac81215e8326740"}, - {file = "coverage-6.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b5578efe4038be02d76c344007b13119b2b20acd009a88dde8adec2de4f630b5"}, - {file = "coverage-6.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8099ea680201c2221f8468c372198ceba9338a5fec0e940111962b03b3f716a"}, - {file = "coverage-6.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a00441f5ea4504f5abbc047589d09e0dc33eb447dc45a1a527c8b74bfdd32c65"}, - {file = "coverage-6.4-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e76bd16f0e31bc2b07e0fb1379551fcd40daf8cdf7e24f31a29e442878a827c"}, - {file = "coverage-6.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:8d2e80dd3438e93b19e1223a9850fa65425e77f2607a364b6fd134fcd52dc9df"}, - {file = "coverage-6.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:341e9c2008c481c5c72d0e0dbf64980a4b2238631a7f9780b0fe2e95755fb018"}, - {file = "coverage-6.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:21e6686a95025927775ac501e74f5940cdf6fe052292f3a3f7349b0abae6d00f"}, - {file = "coverage-6.4-cp39-cp39-win32.whl", hash = "sha256:968ed5407f9460bd5a591cefd1388cc00a8f5099de9e76234655ae48cfdbe2c3"}, - {file = 
"coverage-6.4-cp39-cp39-win_amd64.whl", hash = "sha256:e35217031e4b534b09f9b9a5841b9344a30a6357627761d4218818b865d45055"}, - {file = "coverage-6.4-pp36.pp37.pp38-none-any.whl", hash = "sha256:e637ae0b7b481905358624ef2e81d7fb0b1af55f5ff99f9ba05442a444b11e45"}, - {file = "coverage-6.4.tar.gz", hash = "sha256:727dafd7f67a6e1cad808dc884bd9c5a2f6ef1f8f6d2f22b37b96cb0080d4f49"}, -] cryptography = [ {file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_universal2.whl", hash = "sha256:73bc2d3f2444bcfeac67dd130ff2ea598ea5f20b40e36d19821b4df8c9c5037b"}, {file = "cryptography-36.0.1-cp36-abi3-macosx_10_10_x86_64.whl", hash = "sha256:2d87cdcb378d3cfed944dac30596da1968f88fb96d7fc34fdae30a99054b2e31"}, diff --git a/pyproject.toml b/pyproject.toml index f483db0fce..9b3064e945 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -113,7 +113,6 @@ unpaddedbase64 = ">=2.1.0" canonicaljson = ">=1.4.0" # we use the type definitions added in signedjson 1.1. signedjson = ">=1.1.0" -PyNaCl = ">=1.2.1" # validating SSL certs for IP addresses requires service_identity 18.1. 
service-identity = ">=18.1.0" # Twisted 18.9 introduces some logger improvements that the structured diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py index d7dfa92bd1..361b51d2fa 100755 --- a/synapse/_scripts/synapse_port_db.py +++ b/synapse/_scripts/synapse_port_db.py @@ -102,14 +102,6 @@ BOOLEAN_COLUMNS = { "devices": ["hidden"], "device_lists_outbound_pokes": ["sent"], "users_who_share_rooms": ["share_private"], - "groups": ["is_public"], - "group_rooms": ["is_public"], - "group_users": ["is_public", "is_admin"], - "group_summary_rooms": ["is_public"], - "group_room_categories": ["is_public"], - "group_summary_users": ["is_public"], - "group_roles": ["is_public"], - "local_group_membership": ["is_publicised", "is_admin"], "e2e_room_keys": ["is_verified"], "account_validity": ["email_sent"], "redactions": ["have_censored"], @@ -175,6 +167,22 @@ IGNORED_TABLES = { "ui_auth_sessions", "ui_auth_sessions_credentials", "ui_auth_sessions_ips", + # Groups/communities is no longer supported. 
+ "group_attestations_remote", + "group_attestations_renewals", + "group_invites", + "group_roles", + "group_room_categories", + "group_rooms", + "group_summary_roles", + "group_summary_room_categories", + "group_summary_rooms", + "group_summary_users", + "group_users", + "groups", + "local_group_membership", + "local_group_updates", + "remote_profile_cache", } diff --git a/synapse/api/auth.py b/synapse/api/auth.py index 931750668e..5a410f805a 100644 --- a/synapse/api/auth.py +++ b/synapse/api/auth.py @@ -29,12 +29,11 @@ from synapse.api.errors import ( MissingClientTokenError, ) from synapse.appservice import ApplicationService -from synapse.events import EventBase from synapse.http import get_request_user_agent from synapse.http.site import SynapseRequest from synapse.logging.opentracing import active_span, force_tracing, start_active_span from synapse.storage.databases.main.registration import TokenLookupResult -from synapse.types import Requester, StateMap, UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.caches.lrucache import LruCache from synapse.util.macaroons import get_value_from_macaroon, satisfy_expiry @@ -61,8 +60,8 @@ class Auth: self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastores().main - self.state = hs.get_state_handler() self._account_validity_handler = hs.get_account_validity_handler() + self._storage_controllers = hs.get_storage_controllers() self.token_cache: LruCache[str, Tuple[str, bool]] = LruCache( 10000, "token_cache" @@ -79,9 +78,8 @@ class Auth: self, room_id: str, user_id: str, - current_state: Optional[StateMap[EventBase]] = None, allow_departed_users: bool = False, - ) -> EventBase: + ) -> Tuple[str, Optional[str]]: """Check if the user is in the room, or was at some point. Args: room_id: The room to check. @@ -99,29 +97,28 @@ class Auth: Raises: AuthError if the user is/was not in the room. 
Returns: - Membership event for the user if the user was in the - room. This will be the join event if they are currently joined to - the room. This will be the leave event if they have left the room. + The current membership of the user in the room and the + membership event ID of the user. """ - if current_state: - member = current_state.get((EventTypes.Member, user_id), None) - else: - member = await self.state.get_current_state( - room_id=room_id, event_type=EventTypes.Member, state_key=user_id - ) - if member: - membership = member.membership + ( + membership, + member_event_id, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id=user_id, + room_id=room_id, + ) + if membership: if membership == Membership.JOIN: - return member + return membership, member_event_id # XXX this looks totally bogus. Why do we not allow users who have been banned, # or those who were members previously and have been re-invited? if allow_departed_users and membership == Membership.LEAVE: forgot = await self.store.did_forget(user_id, room_id) if not forgot: - return member + return membership, member_event_id raise AuthError(403, "User %s not in room %s" % (user_id, room_id)) @@ -602,8 +599,11 @@ class Auth: # We currently require the user is a "moderator" in the room. We do this # by checking if they would (theoretically) be able to change the # m.room.canonical_alias events - power_level_event = await self.state.get_current_state( - room_id, EventTypes.PowerLevels, "" + + power_level_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.PowerLevels, "" + ) ) auth_events = {} @@ -693,12 +693,11 @@ class Auth: # * The user is a non-guest user, and was ever in the room # * The user is a guest user, and has joined the room # else it will throw. 
- member_event = await self.check_user_in_room( + return await self.check_user_in_room( room_id, user_id, allow_departed_users=allow_departed_users ) - return member_event.membership, member_event.event_id except AuthError: - visibility = await self.state.get_current_state( + visibility = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.RoomHistoryVisibility, "" ) if ( diff --git a/synapse/api/constants.py b/synapse/api/constants.py index f03fdd6dae..e1d31cabed 100644 --- a/synapse/api/constants.py +++ b/synapse/api/constants.py @@ -95,7 +95,6 @@ class EventTypes: Aliases: Final = "m.room.aliases" Redaction: Final = "m.room.redaction" ThirdPartyInvite: Final = "m.room.third_party_invite" - RelatedGroups: Final = "m.room.related_groups" RoomHistoryVisibility: Final = "m.room.history_visibility" CanonicalAlias: Final = "m.room.canonical_alias" diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py index ed92c2e910..0dfa00df44 100644 --- a/synapse/appservice/__init__.py +++ b/synapse/appservice/__init__.py @@ -70,7 +70,6 @@ class ApplicationService: def __init__( self, token: str, - hostname: str, id: str, sender: str, url: Optional[str] = None, @@ -88,7 +87,6 @@ class ApplicationService: ) # url must not end with a slash self.hs_token = hs_token self.sender = sender - self.server_name = hostname self.namespaces = self._check_namespaces(namespaces) self.id = id self.ip_range_whitelist = ip_range_whitelist diff --git a/synapse/config/appservice.py b/synapse/config/appservice.py index 24498e7944..16f93273b3 100644 --- a/synapse/config/appservice.py +++ b/synapse/config/appservice.py @@ -179,7 +179,6 @@ def _load_appservice( return ApplicationService( token=as_info["as_token"], - hostname=hostname, url=as_info["url"], namespaces=as_info["namespaces"], hs_token=as_info["hs_token"], diff --git a/synapse/events/third_party_rules.py b/synapse/events/third_party_rules.py index 9f4ff9799c..35f3f3690f 100644 --- 
a/synapse/events/third_party_rules.py +++ b/synapse/events/third_party_rules.py @@ -152,6 +152,7 @@ class ThirdPartyEventRules: self.third_party_rules = None self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._check_event_allowed_callbacks: List[CHECK_EVENT_ALLOWED_CALLBACK] = [] self._on_create_room_callbacks: List[ON_CREATE_ROOM_CALLBACK] = [] @@ -463,7 +464,7 @@ class ThirdPartyEventRules: Returns: A dict mapping (event type, state key) to state event. """ - state_ids = await self.store.get_filtered_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) room_state_events = await self.store.get_events(state_ids.values()) state_events = {} diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index a6232e048b..2522bf78fc 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -53,6 +53,7 @@ class FederationBase: self.spam_checker = hs.get_spam_checker() self.store = hs.get_datastores().main self._clock = hs.get_clock() + self._storage_controllers = hs.get_storage_controllers() async def _check_sigs_and_hash( self, room_version: RoomVersion, pdu: EventBase diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 12591dc8db..3e1518f1f6 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -118,6 +118,8 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() + self._state_storage_controller = hs.get_storage_controllers().state + self.device_handler = hs.get_device_handler() # Ensure the following handlers are loaded since they register callbacks @@ -1221,14 +1223,10 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = await 
self.store.get_current_state_ids(room_id) - acl_event_id = state_ids.get((EventTypes.ServerACL, "")) - - if not acl_event_id: - return - - acl_event = await self.store.get_event(acl_event_id) - if server_matches_acl_event(server_name, acl_event): + acl_event = await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.ServerACL, "" + ) + if not acl_event or server_matches_acl_event(server_name, acl_event): return raise AuthError(code=403, msg="Server is banned from room") diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index dbe303ed9b..99a794c042 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -245,6 +245,8 @@ class FederationSender(AbstractFederationSender): self.store = hs.get_datastores().main self.state = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() + self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -602,7 +604,9 @@ class FederationSender(AbstractFederationSender): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains_set = await self.state.get_current_hosts_in_room(room_id) + domains_set = await self._storage_controllers.state.get_current_hosts_in_room( + room_id + ) domains = [ d for d in domains_set diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 72faf2ee38..a0cbeedc30 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -166,7 +166,7 @@ class DeviceWorkerHandler: possibly_changed = set(changed) possibly_left = set() for room_id in rooms_changed: - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._state_storage.get_current_state_ids(room_id) # The user may have left the room # TODO: Check if they actually did or if we were just invited. 
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py index 4aa33df884..1459a046de 100644 --- a/synapse/handlers/directory.py +++ b/synapse/handlers/directory.py @@ -45,6 +45,7 @@ class DirectoryHandler: self.appservice_handler = hs.get_application_service_handler() self.event_creation_handler = hs.get_event_creation_handler() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.config = hs.config self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.require_membership = hs.config.server.require_membership_for_aliases @@ -319,7 +320,7 @@ class DirectoryHandler: Raises: ShadowBanError if the requester has been shadow-banned. """ - alias_event = await self.state.get_current_state( + alias_event = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.CanonicalAlias, "" ) @@ -463,7 +464,11 @@ class DirectoryHandler: making_public = visibility == "public" if making_public: room_aliases = await self.store.get_aliases_for_room(room_id) - canonical_alias = await self.store.get_canonical_alias_for_room(room_id) + canonical_alias = ( + await self._storage_controllers.state.get_canonical_alias_for_room( + room_id + ) + ) if canonical_alias: room_aliases.append(canonical_alias) diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 659f279441..6a143440d3 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -371,7 +371,7 @@ class FederationHandler: # First we try hosts that are already in the room # TODO: HEURISTIC ALERT. - curr_state = await self.state_handler.get_current_state(room_id) + curr_state = await self._storage_controllers.state.get_current_state(room_id) curr_domains = get_domains_from_state(curr_state) @@ -750,7 +750,9 @@ class FederationHandler: # Note that this requires the /send_join request to come back to the # same server. 
if room_version.msc3083_join_rules: - state_ids = await self.store.get_current_state_ids(room_id) + state_ids = await self._state_storage_controller.get_current_state_ids( + room_id + ) if await self._event_auth_handler.has_restricted_join_rules( state_ids, room_version ): @@ -1552,6 +1554,9 @@ class FederationHandler: success = await self.store.clear_partial_state_room(room_id) if success: logger.info("State resync complete for %s", room_id) + self._storage_controllers.state.notify_room_un_partial_stated( + room_id + ) # TODO(faster_joins) update room stats and user directory? return diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 549b066dd9..87a0608359 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -1584,9 +1584,11 @@ class FederationEventHandler: if guest_access == GuestAccess.CAN_JOIN: return - current_state_map = await self._state_handler.get_current_state(event.room_id) - current_state = list(current_state_map.values()) - await self._get_room_member_handler().kick_guest_users(current_state) + current_state = await self._storage_controllers.state.get_current_state( + event.room_id + ) + current_state_list = list(current_state.values()) + await self._get_room_member_handler().kick_guest_users(current_state_list) async def _check_for_soft_fail( self, @@ -1614,6 +1616,9 @@ class FederationEventHandler: room_version = await self._store.get_room_version_id(event.room_id) room_version_obj = KNOWN_ROOM_VERSIONS[room_version] + # The event types we want to pull from the "current" state. + auth_types = auth_types_for_event(room_version_obj, event) + # Calculate the "current state". 
if state_ids is not None: # If we're explicitly given the state then we won't have all the @@ -1643,8 +1648,10 @@ class FederationEventHandler: ) ) else: - current_state_ids = await self._state_handler.get_current_state_ids( - event.room_id, latest_event_ids=extrem_ids + current_state_ids = ( + await self._state_storage_controller.get_current_state_ids( + event.room_id, StateFilter.from_types(auth_types) + ) ) logger.debug( @@ -1654,7 +1661,6 @@ class FederationEventHandler: ) # Now check if event pass auth against said current state - auth_types = auth_types_for_event(room_version_obj, event) current_state_ids_list = [ e for k, e in current_state_ids.items() if k in auth_types ] diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py index d2b489e816..85b472f250 100644 --- a/synapse/handlers/initial_sync.py +++ b/synapse/handlers/initial_sync.py @@ -190,7 +190,7 @@ class InitialSyncHandler: if event.membership == Membership.JOIN: room_end_token = now_token.room_key deferred_room_state = run_in_background( - self.state_handler.get_current_state, event.room_id + self._state_storage_controller.get_current_state, event.room_id ) elif event.membership == Membership.LEAVE: room_end_token = RoomStreamToken( @@ -407,7 +407,9 @@ class InitialSyncHandler: membership: str, is_peeking: bool, ) -> JsonDict: - current_state = await self.state.get_current_state(room_id=room_id) + current_state = await self._storage_controllers.state.get_current_state( + room_id=room_id + ) # TODO: These concurrently time_now = self.clock.time_msec() diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index cf7c2d1979..f455158a2c 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -28,6 +28,7 @@ from synapse.api.constants import ( EventContentFields, EventTypes, GuestAccess, + HistoryVisibility, Membership, RelationTypes, UserTypes, @@ -66,7 +67,7 @@ from synapse.util import json_decoder, json_encoder, log_failure, 
unwrapFirstErr from synapse.util.async_helpers import Linearizer, gather_results from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.metrics import measure_func -from synapse.visibility import filter_events_for_client +from synapse.visibility import get_effective_room_visibility_from_state if TYPE_CHECKING: from synapse.events.third_party_rules import ThirdPartyEventRules @@ -124,7 +125,9 @@ class MessageHandler: ) if membership == Membership.JOIN: - data = await self.state.get_current_state(room_id, event_type, state_key) + data = await self._storage_controllers.state.get_current_state_event( + room_id, event_type, state_key + ) elif membership == Membership.LEAVE: key = (event_type, state_key) # If the membership is not JOIN, then the event ID should exist. @@ -182,51 +185,31 @@ class MessageHandler: state_filter = state_filter or StateFilter.all() if at_token: - last_event = await self.store.get_last_event_in_room_before_stream_ordering( - room_id, - end_token=at_token.room_key, + last_event_id = ( + await self.store.get_last_event_in_room_before_stream_ordering( + room_id, + end_token=at_token.room_key, + ) ) - if not last_event: + if not last_event_id: raise NotFoundError("Can't find event for token %s" % (at_token,)) - # check whether the user is in the room at that time to determine - # whether they should be treated as peeking. 
- state_map = await self._state_storage_controller.get_state_for_event( - last_event.event_id, - StateFilter.from_types([(EventTypes.Member, user_id)]), - ) - - joined = False - membership_event = state_map.get((EventTypes.Member, user_id)) - if membership_event: - joined = membership_event.membership == Membership.JOIN - - is_peeking = not joined - - visible_events = await filter_events_for_client( - self._storage_controllers, - user_id, - [last_event], - filter_send_to_client=False, - is_peeking=is_peeking, - ) - - if visible_events: - room_state_events = ( - await self._state_storage_controller.get_state_for_events( - [last_event.event_id], state_filter=state_filter - ) - ) - room_state: Mapping[Any, EventBase] = room_state_events[ - last_event.event_id - ] - else: + if not await self._user_can_see_state_at_event( + user_id, room_id, last_event_id + ): raise AuthError( 403, "User %s not allowed to view events in room %s at token %s" % (user_id, room_id, at_token), ) + + room_state_events = ( + await self._state_storage_controller.get_state_for_events( + [last_event_id], state_filter=state_filter + ) + ) + room_state: Mapping[Any, EventBase] = room_state_events[last_event_id] else: ( membership, @@ -236,7 +219,7 @@ class MessageHandler: ) if membership == Membership.JOIN: - state_ids = await self.store.get_filtered_current_state_ids( + state_ids = await self._state_storage_controller.get_current_state_ids( room_id, state_filter=state_filter ) room_state = await self.store.get_events(state_ids.values()) @@ -256,6 +239,65 @@ class MessageHandler: events = self._event_serializer.serialize_events(room_state.values(), now) return events + async def _user_can_see_state_at_event( + self, user_id: str, room_id: str, event_id: str + ) -> bool: + # check whether the user was in the room, and the history visibility, + # at that time. 
+ state_map = await self._state_storage_controller.get_state_for_event( + event_id, + StateFilter.from_types( + [ + (EventTypes.Member, user_id), + (EventTypes.RoomHistoryVisibility, ""), + ] + ), + ) + + membership = None + membership_event = state_map.get((EventTypes.Member, user_id)) + if membership_event: + membership = membership_event.membership + + # if the user was a member of the room at the time of the event, + # they can see it. + if membership == Membership.JOIN: + return True + + # otherwise, it depends on the history visibility. + visibility = get_effective_room_visibility_from_state(state_map) + + if visibility == HistoryVisibility.JOINED: + # we weren't a member at the time of the event, so we can't see this event. + return False + + # otherwise *invited* is good enough + if membership == Membership.INVITE: + return True + + if visibility == HistoryVisibility.INVITED: + # we weren't invited, so we can't see this event. + return False + + if visibility == HistoryVisibility.WORLD_READABLE: + return True + + # So it's SHARED, and the user was not a member at the time. The user cannot + # see history, unless they have *subsequently* joined the room. + # + # XXX: if the user has subsequently joined and then left again, + # ideally we would share history up to the point they left. But + # we don't know when they left. We just treat it as though they + # never joined, and restrict access. + + ( + current_membership, + _, + ) = await self.store.get_local_current_membership_for_user_in_room( + user_id, room_id + ) + return current_membership == Membership.JOIN + async def get_joined_members(self, requester: Requester, room_id: str) -> dict: """Get all the joined members in the room and their profile information. 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py index bf112b9e1e..895ea63ed3 100644 --- a/synapse/handlers/presence.py +++ b/synapse/handlers/presence.py @@ -134,6 +134,7 @@ class BasePresenceHandler(abc.ABC): def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.presence_router = hs.get_presence_router() self.state = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -1348,7 +1349,10 @@ class PresenceHandler(BasePresenceHandler): self._event_pos, room_max_stream_ordering, ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self._event_pos, room_max_stream_ordering ) diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index 239b0aa744..6eed3826a7 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -23,14 +23,7 @@ from synapse.api.errors import ( StoreError, SynapseError, ) -from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.types import ( - JsonDict, - Requester, - UserID, - create_requester, - get_domain_from_id, -) +from synapse.types import JsonDict, Requester, UserID, create_requester from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -50,9 +43,6 @@ class ProfileHandler: delegate to master when necessary. 
""" - PROFILE_UPDATE_MS = 60 * 1000 - PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000 - def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main self.clock = hs.get_clock() @@ -73,11 +63,6 @@ class ProfileHandler: self._third_party_rules = hs.get_third_party_event_rules() - if hs.config.worker.run_background_tasks: - self.clock.looping_call( - self._update_remote_profile_cache, self.PROFILE_UPDATE_MS - ) - async def get_profile(self, user_id: str) -> JsonDict: target_user = UserID.from_string(user_id) @@ -116,30 +101,6 @@ class ProfileHandler: raise SynapseError(502, "Failed to fetch profile") raise e.to_synapse_error() - async def get_profile_from_cache(self, user_id: str) -> JsonDict: - """Get the profile information from our local cache. If the user is - ours then the profile information will always be correct. Otherwise, - it may be out of date/missing. - """ - target_user = UserID.from_string(user_id) - if self.hs.is_mine(target_user): - try: - displayname = await self.store.get_profile_displayname( - target_user.localpart - ) - avatar_url = await self.store.get_profile_avatar_url( - target_user.localpart - ) - except StoreError as e: - if e.code == 404: - raise SynapseError(404, "Profile was not found", Codes.NOT_FOUND) - raise - - return {"displayname": displayname, "avatar_url": avatar_url} - else: - profile = await self.store.get_from_remote_profile_cache(user_id) - return profile or {} - async def get_displayname(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): try: @@ -509,45 +470,3 @@ class ProfileHandler: # so we act as if we couldn't find the profile. raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN) raise - - @wrap_as_background_process("Update remote profile") - async def _update_remote_profile_cache(self) -> None: - """Called periodically to check profiles of remote users we haven't - checked in a while. 
- """ - entries = await self.store.get_remote_profile_cache_entries_that_expire( - last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS - ) - - for user_id, displayname, avatar_url in entries: - is_subscribed = await self.store.is_subscribed_remote_profile_for_user( - user_id - ) - if not is_subscribed: - await self.store.maybe_delete_remote_profile_cache(user_id) - continue - - try: - profile = await self.federation.make_query( - destination=get_domain_from_id(user_id), - query_type="profile", - args={"user_id": user_id}, - ignore_backoff=True, - ) - except Exception: - logger.exception("Failed to get avatar_url") - - await self.store.update_remote_profile_cache( - user_id, displayname, avatar_url - ) - continue - - new_name = profile.get("displayname") - if not isinstance(new_name, str): - new_name = None - new_avatar = profile.get("avatar_url") - if not isinstance(new_avatar, str): - new_avatar = None - - # We always hit update to update the last_check timestamp - await self.store.update_remote_profile_cache(user_id, new_name, new_avatar) diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 05bb1e0225..338204287f 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -87,6 +87,7 @@ class LoginDict(TypedDict): class RegistrationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self.hs = hs self.auth = hs.get_auth() @@ -528,7 +529,7 @@ class RegistrationHandler: if requires_invite: # If the server is in the room, check if the room is public. 
- state = await self.store.get_filtered_current_state_ids( + state = await self._storage_controllers.state.get_current_state_ids( room_id, StateFilter.from_types([(EventTypes.JoinRules, "")]) ) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 9a1cc11bb3..0b63cd2186 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -12,16 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import ( - TYPE_CHECKING, - Collection, - Dict, - FrozenSet, - Iterable, - List, - Optional, - Tuple, -) +from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple import attr @@ -256,13 +247,19 @@ class RelationsHandler: return filtered_results - async def get_threads_for_events( - self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str] + async def _get_threads_for_events( + self, + events_by_id: Dict[str, EventBase], + relations_by_id: Dict[str, str], + user_id: str, + ignored_users: FrozenSet[str], ) -> Dict[str, _ThreadAggregation]: """Get the bundled aggregations for threads for the requested events. Args: - event_ids: Events to get aggregations for threads. + events_by_id: A map of event_id to events to get aggregations for threads. + relations_by_id: A map of event_id to the relation type, if one exists + for that event. user_id: The user requesting the bundled aggregations. ignored_users: The users ignored by the requesting user. @@ -273,16 +270,34 @@ class RelationsHandler: """ user = UserID.from_string(user_id) + # It is not valid to start a thread on an event which itself relates to another event. + event_ids = [eid for eid in events_by_id.keys() if eid not in relations_by_id] + # Fetch thread summaries. summaries = await self._main_store.get_thread_summaries(event_ids) - # Only fetch participated for a limited selection based on what had - # summaries. 
+ # Limit fetching whether the requester has participated in a thread to + # events which are thread roots. thread_event_ids = [ event_id for event_id, summary in summaries.items() if summary ] - participated = await self._main_store.get_threads_participated( - thread_event_ids, user_id + + # Pre-seed thread participation with whether the requester sent the event. + participated = { + event_id: events_by_id[event_id].sender == user_id + for event_id in thread_event_ids + } + # For events the requester did not send, check the database for whether + # the requester sent a threaded reply. + participated.update( + await self._main_store.get_threads_participated( + [ + event_id + for event_id in thread_event_ids + if not participated[event_id] + ], + user_id, + ) ) # Then subtract off the results for any ignored users. @@ -343,7 +358,8 @@ class RelationsHandler: count=thread_count, # If there's a thread summary it must also exist in the # participated dictionary. - current_user_participated=participated[event_id], + current_user_participated=events_by_id[event_id].sender == user_id + or participated[event_id], ) return results @@ -401,9 +417,9 @@ class RelationsHandler: # events to be fetched. Thus, we check those first! # Fetch thread summaries (but only for the directly requested events). - threads = await self.get_threads_for_events( - # It is not valid to start a thread on an event which itself relates to another event. 
- [eid for eid in events_by_id.keys() if eid not in relations_by_id], + threads = await self._get_threads_for_events( + events_by_id, + relations_by_id, user_id, ignored_users, ) diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py index 5c91d33f58..520663f172 100644 --- a/synapse/handlers/room.py +++ b/synapse/handlers/room.py @@ -107,6 +107,7 @@ class EventContext: class RoomCreationHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.clock = hs.get_clock() self.hs = hs @@ -468,7 +469,6 @@ class RoomCreationHandler: (EventTypes.RoomAvatar, ""), (EventTypes.RoomEncryption, ""), (EventTypes.ServerACL, ""), - (EventTypes.RelatedGroups, ""), (EventTypes.PowerLevels, ""), ] @@ -481,8 +481,10 @@ class RoomCreationHandler: if room_type == RoomTypes.SPACE: types_to_copy.append((EventTypes.SpaceChild, None)) - old_room_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types(types_to_copy) + old_room_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types(types_to_copy) + ) ) # map from event_id to BaseEvent old_room_state_events = await self.store.get_events(old_room_state_ids.values()) @@ -559,8 +561,10 @@ class RoomCreationHandler: ) # Transfer membership events - old_room_member_state_ids = await self.store.get_filtered_current_state_ids( - old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + old_room_member_state_ids = ( + await self._storage_controllers.state.get_current_state_ids( + old_room_id, StateFilter.from_types([(EventTypes.Member, None)]) + ) ) # map from event_id to BaseEvent @@ -1329,6 +1333,7 @@ class TimestampLookupHandler: self.store = hs.get_datastores().main self.state_handler = hs.get_state_handler() self.federation_client = hs.get_federation_client() + self._storage_controllers = 
hs.get_storage_controllers() async def get_event_for_timestamp( self, @@ -1402,7 +1407,9 @@ class TimestampLookupHandler: ) # Find other homeservers from the given state in the room - curr_state = await self.state_handler.get_current_state(room_id) + curr_state = await self._storage_controllers.state.get_current_state( + room_id + ) curr_domains = get_domains_from_state(curr_state) likely_domains = [ domain for domain, depth in curr_domains if domain != self.server_name diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index f3577b5d5a..183d4ae3c4 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -50,6 +50,7 @@ EMPTY_THIRD_PARTY_ID = ThirdPartyInstanceID(None, None) class RoomListHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.hs = hs self.enable_room_list_search = hs.config.roomdirectory.enable_room_list_search self.response_cache: ResponseCache[ @@ -274,7 +275,7 @@ class RoomListHandler: if aliases: result["aliases"] = aliases - current_state_ids = await self.store.get_current_state_ids( + current_state_ids = await self._storage_controllers.state.get_current_state_ids( room_id, on_invalidate=cache_context.invalidate ) diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 00662dc961..d1199a0644 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -68,6 +68,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.auth = hs.get_auth() self.state_handler = hs.get_state_handler() self.config = hs.config @@ -994,7 +995,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): # If the host is in the room, but not one of the authorised hosts # for restricted join rules, a remote join must be used. 
room_version = await self.store.get_room_version(room_id) - current_state_ids = await self.store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) # If restricted join rules are not being used, a local join can always # be used. @@ -1398,7 +1401,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta): txn_id: Optional[str], id_access_token: Optional[str] = None, ) -> int: - room_state = await self.state_handler.get_current_state(room_id) + room_state = await self._storage_controllers.state.get_current_state( + room_id, + StateFilter.from_types( + [ + (EventTypes.Member, user.to_string()), + (EventTypes.CanonicalAlias, ""), + (EventTypes.Name, ""), + (EventTypes.Create, ""), + (EventTypes.JoinRules, ""), + (EventTypes.RoomAvatar, ""), + ] + ), + ) inviter_display_name = "" inviter_avatar_url = "" @@ -1794,7 +1809,7 @@ class RoomMemberMasterHandler(RoomMemberHandler): async def forget(self, user: UserID, room_id: str) -> None: user_id = user.to_string() - member = await self.state_handler.get_current_state( + member = await self._storage_controllers.state.get_current_state_event( room_id=room_id, event_type=EventTypes.Member, state_key=user_id ) membership = member.membership if member else None diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py index 75aee6a111..13098f56ed 100644 --- a/synapse/handlers/room_summary.py +++ b/synapse/handlers/room_summary.py @@ -90,6 +90,7 @@ class RoomSummaryHandler: def __init__(self, hs: "HomeServer"): self._event_auth_handler = hs.get_event_auth_handler() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._event_serializer = hs.get_event_client_serializer() self._server_name = hs.hostname self._federation_client = hs.get_federation_client() @@ -537,7 +538,7 @@ class RoomSummaryHandler: Returns: True if the room is accessible to the requesting user or server. 
""" - state_ids = await self._store.get_current_state_ids(room_id) + state_ids = await self._storage_controllers.state.get_current_state_ids(room_id) # If there's no state for the room, it isn't known. if not state_ids: @@ -702,7 +703,9 @@ class RoomSummaryHandler: # there should always be an entry assert stats is not None, "unable to retrieve stats for %s" % (room_id,) - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) create_event = await self._store.get_event( current_state_ids[(EventTypes.Create, "")] ) @@ -760,7 +763,9 @@ class RoomSummaryHandler: """ # look for child rooms/spaces. - current_state_ids = await self._store.get_current_state_ids(room_id) + current_state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id + ) events = await self._store.get_events_as_list( [ diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py index 659f99f7e2..bcab98c6d5 100644 --- a/synapse/handlers/search.py +++ b/synapse/handlers/search.py @@ -348,7 +348,7 @@ class SearchHandler: state_results = {} if include_state: for room_id in {e.room_id for e in search_result.allowed_events}: - state = await self.state_handler.get_current_state(room_id) + state = await self._storage_controllers.state.get_current_state(room_id) state_results[room_id] = list(state.values()) aggregations = await self._relations_handler.get_bundled_aggregations( diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index 436cd971ce..f45e06eb0e 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -40,6 +40,7 @@ class StatsHandler: def __init__(self, hs: "HomeServer"): self.hs = hs self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.state = hs.get_state_handler() self.server_name = hs.hostname self.clock = hs.get_clock() @@ -105,7 +106,10 @@ class StatsHandler: logger.debug( 
"Processing room stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index b5859dcb28..b4ead79f97 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -506,8 +506,10 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids: FrozenSet[str] = frozenset() if any(e.is_state() for e in recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + current_state_ids_map = ( + await self._state_storage_controller.get_current_state_ids( + room_id + ) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -574,8 +576,11 @@ class SyncHandler: # ensure that we always include current state in the timeline current_state_ids = frozenset() if any(e.is_state() for e in loaded_recents): - current_state_ids_map = await self.store.get_current_state_ids( - room_id + # FIXME(faster_joins): We use the partial state here as + # we don't want to block `/sync` on finishing a lazy join. + # Is this the correct way of doing it? + current_state_ids_map = ( + await self.store.get_partial_current_state_ids(room_id) ) current_state_ids = frozenset(current_state_ids_map.values()) @@ -621,21 +626,32 @@ class SyncHandler: ) async def get_state_after_event( - self, event: EventBase, state_filter: Optional[StateFilter] = None + self, event_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """ Get the room state after the given event Args: - event: event of interest + event_id: event of interest state_filter: The state filter used to fetch state from the database. 
""" state_ids = await self._state_storage_controller.get_state_ids_for_event( - event.event_id, state_filter=state_filter or StateFilter.all() + event_id, state_filter=state_filter or StateFilter.all() ) - if event.is_state(): + + # using get_metadata_for_events here (instead of get_event) sidesteps an issue + # with redactions: if `event_id` is a redaction event, and we don't have the + # original (possibly because it got purged), get_event will refuse to return + # the redaction event, which isn't terribly helpful here. + # + # (To be fair, in that case we could assume it's *not* a state event, and + # therefore we don't need to worry about it. But still, it seems cleaner just + # to pull the metadata.) + m = (await self.store.get_metadata_for_events([event_id]))[event_id] + if m.state_key is not None and m.rejection_reason is None: state_ids = dict(state_ids) - state_ids[(event.type, event.state_key)] = event.event_id + state_ids[(m.event_type, m.state_key)] = event_id + return state_ids async def get_state_at( @@ -654,14 +670,14 @@ class SyncHandler: # FIXME: This gets the state at the latest event before the stream ordering, # which might not be the same as the "current state" of the room at the time # of the stream token if there were multiple forward extremities at the time. 
- last_event = await self.store.get_last_event_in_room_before_stream_ordering( + last_event_id = await self.store.get_last_event_in_room_before_stream_ordering( room_id, end_token=stream_position.room_key, ) - if last_event: + if last_event_id: state = await self.get_state_after_event( - last_event, state_filter=state_filter or StateFilter.all() + last_event_id, state_filter=state_filter or StateFilter.all() ) else: diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 0aeab86bbb..d104ea07fe 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -59,6 +59,7 @@ class FollowerTypingHandler: def __init__(self, hs: "HomeServer"): self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.config.server.server_name self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id @@ -131,7 +132,6 @@ class FollowerTypingHandler: return try: - users = await self.store.get_users_in_room(member.room_id) self._member_last_federation_poke[member] = self.clock.time_msec() now = self.clock.time_msec() @@ -139,7 +139,10 @@ class FollowerTypingHandler: now=now, obj=member, then=now + FEDERATION_PING_INTERVAL ) - for domain in {get_domain_from_id(u) for u in users}: + hosts = await self._storage_controllers.state.get_current_hosts_in_room( + member.room_id + ) + for domain in hosts: if domain != self.server_name: logger.debug("sending typing update to %s", domain) self.federation.build_and_send_edu( diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 74f7fdfe6c..8c3c52e1ca 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -56,6 +56,7 @@ class UserDirectoryHandler(StateDeltasHandler): super().__init__(hs) self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.server_name = hs.hostname self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -174,7 +175,10 
@@ class UserDirectoryHandler(StateDeltasHandler): logger.debug( "Processing user stats %s->%s", self.pos, room_max_stream_ordering ) - max_pos, deltas = await self.store.get_current_state_deltas( + ( + max_pos, + deltas, + ) = await self._storage_controllers.state.get_current_state_deltas( self.pos, room_max_stream_ordering ) diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index b7451fc870..a8ad575fcd 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -194,6 +194,7 @@ class ModuleApi: self._store: Union[ DataStore, "GenericWorkerSlavedStore" ] = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._auth_handler = auth_handler self._server_name = hs.hostname @@ -911,7 +912,7 @@ class ModuleApi: The filtered state events in the room. """ state_ids = yield defer.ensureDeferred( - self._store.get_filtered_current_state_ids( + self._storage_controllers.state.get_current_state_ids( room_id=room_id, state_filter=StateFilter.from_types(types) ) ) @@ -1289,20 +1290,16 @@ class ModuleApi: # regardless of their state key ] """ + state_filter = None if event_filter: # If a filter was provided, turn it into a StateFilter and retrieve a filtered # view of the state. state_filter = StateFilter.from_types(event_filter) - state_ids = await self._store.get_filtered_current_state_ids( - room_id, - state_filter, - ) - else: - # If no filter was provided, get the whole state. We could also reuse the call - # to get_filtered_current_state_ids above, with `state_filter = StateFilter.all()`, - # but get_filtered_current_state_ids isn't cached and `get_current_state_ids` - # is, so using the latter when we can is better for perf. 
- state_ids = await self._store.get_current_state_ids(room_id) + + state_ids = await self._storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + ) state_events = await self._store.get_events(state_ids.values()) diff --git a/synapse/notifier.py b/synapse/notifier.py index 1100434b3f..54b0ec4b97 100644 --- a/synapse/notifier.py +++ b/synapse/notifier.py @@ -681,7 +681,7 @@ class Notifier: return joined_room_ids, True async def _is_world_readable(self, room_id: str) -> bool: - state = await self.state_handler.get_current_state( + state = await self._storage_controllers.state.get_current_state_event( room_id, EventTypes.RoomHistoryVisibility, "" ) if state and "history_visibility" in state.content: diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py index 63aefd07f5..015c19b2d9 100644 --- a/synapse/push/mailer.py +++ b/synapse/push/mailer.py @@ -255,7 +255,9 @@ class Mailer: user_display_name = user_id async def _fetch_room_state(room_id: str) -> None: - room_state = await self.store.get_current_state_ids(room_id) + room_state = await self._state_storage_controller.get_current_state_ids( + room_id + ) state_by_room[room_id] = room_state # Run at most 3 of these at once: sync does 10 at a time but email diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py index 356d6f74d7..9d953d58de 100644 --- a/synapse/rest/admin/rooms.py +++ b/synapse/rest/admin/rooms.py @@ -34,6 +34,7 @@ from synapse.rest.admin._base import ( assert_user_is_admin, ) from synapse.storage.databases.main.room import RoomSortOrder +from synapse.storage.state import StateFilter from synapse.types import JsonDict, RoomID, UserID, create_requester from synapse.util import json_decoder @@ -418,6 +419,7 @@ class RoomStateRestServlet(RestServlet): def __init__(self, hs: "HomeServer"): self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.clock = hs.get_clock() self._event_serializer = 
hs.get_event_client_serializer() @@ -430,7 +432,7 @@ class RoomStateRestServlet(RestServlet): if not ret: raise NotFoundError("Room not found") - event_ids = await self.store.get_current_state_ids(room_id) + event_ids = await self._storage_controllers.state.get_current_state_ids(room_id) events = await self.store.get_events(event_ids.values()) now = self.clock.time_msec() room_state = self._event_serializer.serialize_events(events.values(), now) @@ -447,7 +449,8 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): super().__init__(hs) self.auth = hs.get_auth() self.admin_handler = hs.get_admin_handler() - self.state_handler = hs.get_state_handler() + self.store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self.is_mine = hs.is_mine async def on_POST( @@ -489,8 +492,11 @@ class JoinRoomAliasServlet(ResolveRoomIdMixin, RestServlet): ) # send invite if room has "JoinRules.INVITE" - room_state = await self.state_handler.get_current_state(room_id) - join_rules_event = room_state.get((EventTypes.JoinRules, "")) + join_rules_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.JoinRules, "" + ) + ) if join_rules_event: if not (join_rules_event.content.get("join_rule") == JoinRules.PUBLIC): # update_membership with an action of "invite" can raise a @@ -535,6 +541,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): super().__init__(hs) self.auth = hs.get_auth() self.store = hs.get_datastores().main + self._state_storage_controller = hs.get_storage_controllers().state self.event_creation_handler = hs.get_event_creation_handler() self.state_handler = hs.get_state_handler() self.is_mine_id = hs.is_mine_id @@ -552,12 +559,22 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): user_to_add = content.get("user_id", requester.user.to_string()) # Figure out which local users currently have power in the room, if any. 
- room_state = await self.state_handler.get_current_state(room_id) - if not room_state: + filtered_room_state = await self._state_storage_controller.get_current_state( + room_id, + StateFilter.from_types( + [ + (EventTypes.Create, ""), + (EventTypes.PowerLevels, ""), + (EventTypes.JoinRules, ""), + (EventTypes.Member, user_to_add), + ] + ), + ) + if not filtered_room_state: raise SynapseError(HTTPStatus.BAD_REQUEST, "Server not in room") - create_event = room_state[(EventTypes.Create, "")] - power_levels = room_state.get((EventTypes.PowerLevels, "")) + create_event = filtered_room_state[(EventTypes.Create, "")] + power_levels = filtered_room_state.get((EventTypes.PowerLevels, "")) if power_levels is not None: # We pick the local user with the highest power. @@ -633,7 +650,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): # Now we check if the user we're granting admin rights to is already in # the room. If not and it's not a public room we invite them. - member_event = room_state.get((EventTypes.Member, user_to_add)) + member_event = filtered_room_state.get((EventTypes.Member, user_to_add)) is_joined = False if member_event: is_joined = member_event.content["membership"] in ( @@ -644,7 +661,7 @@ class MakeRoomAdminRestServlet(ResolveRoomIdMixin, RestServlet): if is_joined: return HTTPStatus.OK, {} - join_rules = room_state.get((EventTypes.JoinRules, "")) + join_rules = filtered_room_state.get((EventTypes.JoinRules, "")) is_public = False if join_rules: is_public = join_rules.content.get("join_rule") == JoinRules.PUBLIC diff --git a/synapse/rest/admin/users.py b/synapse/rest/admin/users.py index 8e29ada8a0..f0614a2897 100644 --- a/synapse/rest/admin/users.py +++ b/synapse/rest/admin/users.py @@ -226,6 +226,13 @@ class UserRestServletV2(RestServlet): if not isinstance(password, str) or len(password) > 512: raise SynapseError(HTTPStatus.BAD_REQUEST, "Invalid password") + logout_devices = body.get("logout_devices", True) + if not 
isinstance(logout_devices, bool): + raise SynapseError( + HTTPStatus.BAD_REQUEST, + "'logout_devices' parameter is not of type boolean", + ) + deactivate = body.get("deactivated", False) if not isinstance(deactivate, bool): raise SynapseError( @@ -305,7 +312,6 @@ class UserRestServletV2(RestServlet): await self.store.set_server_admin(target_user, set_admin_to) if password is not None: - logout_devices = True new_password_hash = await self.auth_handler.hash(password) await self.set_password_handler.set_password( diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 7a5ce8ad0e..a26e976492 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -650,6 +650,7 @@ class RoomEventServlet(RestServlet): self.clock = hs.get_clock() self._store = hs.get_datastores().main self._state = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() self.event_handler = hs.get_event_handler() self._event_serializer = hs.get_event_client_serializer() self._relations_handler = hs.get_relations_handler() @@ -673,8 +674,10 @@ class RoomEventServlet(RestServlet): if include_unredacted_content and not await self.auth.is_server_admin( requester.user ): - power_level_event = await self._state.get_current_state( - room_id, EventTypes.PowerLevels, "" + power_level_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, EventTypes.PowerLevels, "" + ) ) auth_events = {} diff --git a/synapse/rest/media/v1/media_repository.py b/synapse/rest/media/v1/media_repository.py index 20af366538..a551458a9f 100644 --- a/synapse/rest/media/v1/media_repository.py +++ b/synapse/rest/media/v1/media_repository.py @@ -587,15 +587,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + 
self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -657,15 +658,16 @@ class MediaRepository: ) return None - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), - self._generate_thumbnail, - thumbnailer, - t_width, - t_height, - t_method, - t_type, - ) + with thumbnailer: + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + self._generate_thumbnail, + thumbnailer, + t_width, + t_height, + t_method, + t_type, + ) if t_byte_source: try: @@ -749,119 +751,134 @@ class MediaRepository: ) return None - m_width = thumbnailer.width - m_height = thumbnailer.height + with thumbnailer: + m_width = thumbnailer.width + m_height = thumbnailer.height - if m_width * m_height >= self.max_image_pixels: - logger.info( - "Image too large to thumbnail %r x %r > %r", - m_width, - m_height, - self.max_image_pixels, - ) - return None - - if thumbnailer.transpose_method is not None: - m_width, m_height = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.transpose - ) - - # We deduplicate the thumbnail sizes by ignoring the cropped versions if - # they have the same dimensions of a scaled one. 
- thumbnails: Dict[Tuple[int, int, str], str] = {} - for requirement in requirements: - if requirement.method == "crop": - thumbnails.setdefault( - (requirement.width, requirement.height, requirement.media_type), - requirement.method, + if m_width * m_height >= self.max_image_pixels: + logger.info( + "Image too large to thumbnail %r x %r > %r", + m_width, + m_height, + self.max_image_pixels, ) - elif requirement.method == "scale": - t_width, t_height = thumbnailer.aspect( - requirement.width, requirement.height + return None + + if thumbnailer.transpose_method is not None: + m_width, m_height = await defer_to_thread( + self.hs.get_reactor(), thumbnailer.transpose ) - t_width = min(m_width, t_width) - t_height = min(m_height, t_height) - thumbnails[ - (t_width, t_height, requirement.media_type) - ] = requirement.method - # Now we generate the thumbnails for each dimension, store it - for (t_width, t_height, t_type), t_method in thumbnails.items(): - # Generate the thumbnail - if t_method == "crop": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.crop, t_width, t_height, t_type - ) - elif t_method == "scale": - t_byte_source = await defer_to_thread( - self.hs.get_reactor(), thumbnailer.scale, t_width, t_height, t_type - ) - else: - logger.error("Unrecognized method: %r", t_method) - continue - - if not t_byte_source: - continue - - file_info = FileInfo( - server_name=server_name, - file_id=file_id, - url_cache=url_cache, - thumbnail=ThumbnailInfo( - width=t_width, - height=t_height, - method=t_method, - type=t_type, - ), - ) - - with self.media_storage.store_into_file(file_info) as (f, fname, finish): - try: - await self.media_storage.write_to_file(t_byte_source, f) - await finish() - finally: - t_byte_source.close() - - t_len = os.path.getsize(fname) - - # Write to database - if server_name: - # Multiple remote media download requests can race (when - # using multiple media repos), so this may throw a violation - # constraint exception. 
If it does we'll delete the newly - # generated thumbnail from disk (as we're in the ctx - # manager). - # - # However: we've already called `finish()` so we may have - # also written to the storage providers. This is preferable - # to the alternative where we call `finish()` *after* this, - # where we could end up having an entry in the DB but fail - # to write the files to the storage providers. - try: - await self.store.store_remote_media_thumbnail( - server_name, - media_id, - file_id, - t_width, - t_height, - t_type, - t_method, - t_len, - ) - except Exception as e: - thumbnail_exists = await self.store.get_remote_media_thumbnail( - server_name, - media_id, - t_width, - t_height, - t_type, - ) - if not thumbnail_exists: - raise e - else: - await self.store.store_local_thumbnail( - media_id, t_width, t_height, t_type, t_method, t_len + # We deduplicate the thumbnail sizes by ignoring the cropped versions if + # they have the same dimensions of a scaled one. + thumbnails: Dict[Tuple[int, int, str], str] = {} + for requirement in requirements: + if requirement.method == "crop": + thumbnails.setdefault( + (requirement.width, requirement.height, requirement.media_type), + requirement.method, ) + elif requirement.method == "scale": + t_width, t_height = thumbnailer.aspect( + requirement.width, requirement.height + ) + t_width = min(m_width, t_width) + t_height = min(m_height, t_height) + thumbnails[ + (t_width, t_height, requirement.media_type) + ] = requirement.method + + # Now we generate the thumbnails for each dimension, store it + for (t_width, t_height, t_type), t_method in thumbnails.items(): + # Generate the thumbnail + if t_method == "crop": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.crop, + t_width, + t_height, + t_type, + ) + elif t_method == "scale": + t_byte_source = await defer_to_thread( + self.hs.get_reactor(), + thumbnailer.scale, + t_width, + t_height, + t_type, + ) + else: + logger.error("Unrecognized method: 
%r", t_method) + continue + + if not t_byte_source: + continue + + file_info = FileInfo( + server_name=server_name, + file_id=file_id, + url_cache=url_cache, + thumbnail=ThumbnailInfo( + width=t_width, + height=t_height, + method=t_method, + type=t_type, + ), + ) + + with self.media_storage.store_into_file(file_info) as ( + f, + fname, + finish, + ): + try: + await self.media_storage.write_to_file(t_byte_source, f) + await finish() + finally: + t_byte_source.close() + + t_len = os.path.getsize(fname) + + # Write to database + if server_name: + # Multiple remote media download requests can race (when + # using multiple media repos), so this may throw a violation + # constraint exception. If it does we'll delete the newly + # generated thumbnail from disk (as we're in the ctx + # manager). + # + # However: we've already called `finish()` so we may have + # also written to the storage providers. This is preferable + # to the alternative where we call `finish()` *after* this, + # where we could end up having an entry in the DB but fail + # to write the files to the storage providers. + try: + await self.store.store_remote_media_thumbnail( + server_name, + media_id, + file_id, + t_width, + t_height, + t_type, + t_method, + t_len, + ) + except Exception as e: + thumbnail_exists = ( + await self.store.get_remote_media_thumbnail( + server_name, + media_id, + t_width, + t_height, + t_type, + ) + ) + if not thumbnail_exists: + raise e + else: + await self.store.store_local_thumbnail( + media_id, t_width, t_height, t_type, t_method, t_len + ) return {"width": m_width, "height": m_height} diff --git a/synapse/rest/media/v1/preview_html.py b/synapse/rest/media/v1/preview_html.py index 13ec7ab533..ed8f21a483 100644 --- a/synapse/rest/media/v1/preview_html.py +++ b/synapse/rest/media/v1/preview_html.py @@ -30,6 +30,9 @@ _xml_encoding_match = re.compile( ) _content_type_match = re.compile(r'.*; *charset="?(.*?)"?(;|$)', flags=re.I) +# Certain elements aren't meant for display. 
+ARIA_ROLES_TO_IGNORE = {"directory", "menu", "menubar", "toolbar"} + def _normalise_encoding(encoding: str) -> Optional[str]: """Use the Python codec's name as the normalised entry.""" @@ -174,13 +177,15 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "og:video:secure_url": "https://www.youtube.com/v/LXDBoHyjmtw?version=3", og: Dict[str, Optional[str]] = {} - for tag in tree.xpath("//*/meta[starts-with(@property, 'og:')]"): - if "content" in tag.attrib: - # if we've got more than 50 tags, someone is taking the piss - if len(og) >= 50: - logger.warning("Skipping OG for page with too many 'og:' tags") - return {} - og[tag.attrib["property"]] = tag.attrib["content"] + for tag in tree.xpath( + "//*/meta[starts-with(@property, 'og:')][@content][not(@content='')]" + ): + # if we've got more than 50 tags, someone is taking the piss + if len(og) >= 50: + logger.warning("Skipping OG for page with too many 'og:' tags") + return {} + + og[tag.attrib["property"]] = tag.attrib["content"] # TODO: grab article: meta tags too, e.g.: @@ -192,21 +197,23 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: # "article:modified_time" content="2016-04-01T18:31:53+00:00" /> if "og:title" not in og: - # do some basic spidering of the HTML - title = tree.xpath("(//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1]") - if title and title[0].text is not None: - og["og:title"] = title[0].text.strip() + # Attempt to find a title from the title tag, or the biggest header on the page. + title = tree.xpath("((//title)[1] | (//h1)[1] | (//h2)[1] | (//h3)[1])/text()") + if title: + og["og:title"] = title[0].strip() else: og["og:title"] = None if "og:image" not in og: - # TODO: extract a favicon failing all else meta_image = tree.xpath( - "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image']/@content" + "//*/meta[translate(@itemprop, 'IMAGE', 'image')='image'][not(@content='')]/@content[1]" ) + # If a meta image is found, use it. 
if meta_image: og["og:image"] = meta_image[0] else: + # Try to find images which are larger than 10px by 10px. + # # TODO: consider inlined CSS styles as well as width & height attribs images = tree.xpath("//img[@src][number(@width)>10][number(@height)>10]") images = sorted( @@ -215,17 +222,24 @@ def parse_html_to_open_graph(tree: "etree.Element") -> Dict[str, Optional[str]]: -1 * float(i.attrib["width"]) * float(i.attrib["height"]) ), ) + # If no images were found, try to find *any* images. if not images: - images = tree.xpath("//img[@src]") + images = tree.xpath("//img[@src][1]") if images: og["og:image"] = images[0].attrib["src"] + # Finally, fallback to the favicon if nothing else. + else: + favicons = tree.xpath("//link[@href][contains(@rel, 'icon')]/@href[1]") + if favicons: + og["og:image"] = favicons[0] + if "og:description" not in og: + # Check the first meta description tag for content. meta_description = tree.xpath( - "//*/meta" - "[translate(@name, 'DESCRIPTION', 'description')='description']" - "/@content" + "//*/meta[translate(@name, 'DESCRIPTION', 'description')='description'][not(@content='')]/@content[1]" ) + # If a meta description is found with content, use it. if meta_description: og["og:description"] = meta_description[0] else: @@ -306,6 +320,10 @@ def _iterate_over_text( if isinstance(el, str): yield el elif el.tag not in tags_to_ignore: + # If the element isn't meant for display, ignore it. + if el.get("role") in ARIA_ROLES_TO_IGNORE: + continue + # el.text is the text before the first child, so we can immediately # return it if the text exists. if el.text: diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py index 2b2db63bf7..54a849eac9 100644 --- a/synapse/rest/media/v1/preview_url_resource.py +++ b/synapse/rest/media/v1/preview_url_resource.py @@ -586,12 +586,16 @@ class PreviewUrlResource(DirectServeJsonResource): og: The Open Graph dictionary. This is modified with image information. 
""" # If there's no image or it is blank, there's nothing to do. - if "og:image" not in og or not og["og:image"]: + if "og:image" not in og: + return + + # Remove the raw image URL, this will be replaced with an MXC URL, if successful. + image_url = og.pop("og:image") + if not image_url: return # The image URL from the HTML might be relative to the previewed page, # convert it to an URL which can be requested directly. - image_url = og["og:image"] url_parts = urlparse(image_url) if url_parts.scheme != "data": image_url = urljoin(media_info.uri, image_url) @@ -599,7 +603,16 @@ class PreviewUrlResource(DirectServeJsonResource): # FIXME: it might be cleaner to use the same flow as the main /preview_url # request itself and benefit from the same caching etc. But for now we # just rely on the caching on the master request to speed things up. - image_info = await self._handle_url(image_url, user, allow_data_urls=True) + try: + image_info = await self._handle_url(image_url, user, allow_data_urls=True) + except Exception as e: + # Pre-caching the image failed, don't block the entire URL preview. 
+ logger.warning( + "Pre-caching image failed during URL preview: %s errored with %s", + image_url, + e, + ) + return if _is_media(image_info.media_type): # TODO: make sure we don't choke on white-on-transparent images @@ -611,13 +624,11 @@ class PreviewUrlResource(DirectServeJsonResource): og["og:image:width"] = dims["width"] og["og:image:height"] = dims["height"] else: - logger.warning("Couldn't get dims for %s", og["og:image"]) + logger.warning("Couldn't get dims for %s", image_url) og["og:image"] = f"mxc://{self.server_name}/{image_info.filesystem_id}" og["og:image:type"] = image_info.media_type og["matrix:image:size"] = image_info.media_length - else: - del og["og:image"] async def _handle_oembed_response( self, url: str, media_info: MediaInfo, expiration_ms: int diff --git a/synapse/rest/media/v1/thumbnailer.py b/synapse/rest/media/v1/thumbnailer.py index 390491eb83..9b93b9b4f6 100644 --- a/synapse/rest/media/v1/thumbnailer.py +++ b/synapse/rest/media/v1/thumbnailer.py @@ -14,7 +14,8 @@ # limitations under the License. import logging from io import BytesIO -from typing import Tuple +from types import TracebackType +from typing import Optional, Tuple, Type from PIL import Image @@ -45,6 +46,9 @@ class Thumbnailer: Image.MAX_IMAGE_PIXELS = max_image_pixels def __init__(self, input_path: str): + # Have we closed the image? + self._closed = False + try: self.image = Image.open(input_path) except OSError as e: @@ -89,7 +93,8 @@ class Thumbnailer: # Safety: `transpose` takes an int rather than e.g. an IntEnum. # self.transpose_method is set above to be a value in # EXIF_TRANSPOSE_MAPPINGS, and that only contains correct values. 
- self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] + with self.image: + self.image = self.image.transpose(self.transpose_method) # type: ignore[arg-type] self.width, self.height = self.image.size self.transpose_method = None # We don't need EXIF any more @@ -122,9 +127,11 @@ class Thumbnailer: # If the image has transparency, use RGBA instead. if self.image.mode in ["1", "L", "P"]: if self.image.info.get("transparency", None) is not None: - self.image = self.image.convert("RGBA") + with self.image: + self.image = self.image.convert("RGBA") else: - self.image = self.image.convert("RGB") + with self.image: + self.image = self.image.convert("RGB") return self.image.resize((width, height), Image.ANTIALIAS) def scale(self, width: int, height: int, output_type: str) -> BytesIO: @@ -133,8 +140,8 @@ class Thumbnailer: Returns: BytesIO: the bytes of the encoded image ready to be written to disk """ - scaled = self._resize(width, height) - return self._encode_image(scaled, output_type) + with self._resize(width, height) as scaled: + return self._encode_image(scaled, output_type) def crop(self, width: int, height: int, output_type: str) -> BytesIO: """Rescales and crops the image to the given dimensions preserving @@ -151,18 +158,21 @@ class Thumbnailer: BytesIO: the bytes of the encoded image ready to be written to disk """ if width * self.height > height * self.width: + scaled_width = width scaled_height = (width * self.height) // self.width - scaled_image = self._resize(width, scaled_height) crop_top = (scaled_height - height) // 2 crop_bottom = height + crop_top - cropped = scaled_image.crop((0, crop_top, width, crop_bottom)) + crop = (0, crop_top, width, crop_bottom) else: scaled_width = (height * self.width) // self.height - scaled_image = self._resize(scaled_width, height) + scaled_height = height crop_left = (scaled_width - width) // 2 crop_right = width + crop_left - cropped = scaled_image.crop((crop_left, 0, crop_right, height)) - 
return self._encode_image(cropped, output_type) + crop = (crop_left, 0, crop_right, height) + + with self._resize(scaled_width, scaled_height) as scaled_image: + with scaled_image.crop(crop) as cropped: + return self._encode_image(cropped, output_type) def _encode_image(self, output_image: Image.Image, output_type: str) -> BytesIO: output_bytes_io = BytesIO() @@ -171,3 +181,42 @@ class Thumbnailer: output_image = output_image.convert("RGB") output_image.save(output_bytes_io, fmt, quality=80) return output_bytes_io + + def close(self) -> None: + """Closes the underlying image file. + + Once closed no other functions can be called. + + Can be called multiple times. + """ + + if self._closed: + return + + self._closed = True + + # Since we run this on the finalizer then we need to handle `__init__` + # raising an exception before it can define `self.image`. + image = getattr(self, "image", None) + if image is None: + return + + image.close() + + def __enter__(self) -> "Thumbnailer": + """Make `Thumbnailer` a context manager that calls `close` on + `__exit__`. + """ + return self + + def __exit__( + self, + type: Optional[Type[BaseException]], + value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: + self.close() + + def __del__(self) -> None: + # Make sure we actually do close the image, rather than leak data. 
+ self.close() diff --git a/synapse/server_notices/resource_limits_server_notices.py b/synapse/server_notices/resource_limits_server_notices.py index b5f3a0c74e..6863020778 100644 --- a/synapse/server_notices/resource_limits_server_notices.py +++ b/synapse/server_notices/resource_limits_server_notices.py @@ -36,6 +36,7 @@ class ResourceLimitsServerNotices: def __init__(self, hs: "HomeServer"): self._server_notices_manager = hs.get_server_notices_manager() self._store = hs.get_datastores().main + self._storage_controllers = hs.get_storage_controllers() self._auth = hs.get_auth() self._config = hs.config self._resouce_limited = False @@ -178,8 +179,10 @@ class ResourceLimitsServerNotices: currently_blocked = False pinned_state_event = None try: - pinned_state_event = await self._state.get_current_state( - room_id, event_type=EventTypes.Pinned + pinned_state_event = ( + await self._storage_controllers.state.get_current_state_event( + room_id, event_type=EventTypes.Pinned, state_key="" + ) ) except AuthError: # The user has yet to join the server notices room diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index bf09f5128a..da25f20ae5 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -32,13 +32,11 @@ from typing import ( Set, Tuple, Union, - overload, ) import attr from frozendict import frozendict from prometheus_client import Counter, Histogram -from typing_extensions import Literal from synapse.api.constants import EventTypes from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions @@ -132,85 +130,20 @@ class StateHandler: self._state_resolution_handler = hs.get_state_resolution_handler() self._storage_controllers = hs.get_storage_controllers() - @overload - async def get_current_state( - self, - room_id: str, - event_type: Literal[None] = None, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> StateMap[EventBase]: - ... 
- - @overload - async def get_current_state( - self, - room_id: str, - event_type: str, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> Optional[EventBase]: - ... - - async def get_current_state( - self, - room_id: str, - event_type: Optional[str] = None, - state_key: str = "", - latest_event_ids: Optional[List[str]] = None, - ) -> Union[Optional[EventBase], StateMap[EventBase]]: - """Retrieves the current state for the room. This is done by - calling `get_latest_events_in_room` to get the leading edges of the - event graph and then resolving any of the state conflicts. - - This is equivalent to getting the state of an event that were to send - next before receiving any new events. - - Returns: - If `event_type` is specified, then the method returns only the one - event (or None) with that `event_type` and `state_key`. - - Otherwise, a map from (type, state_key) to event. - """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) - assert latest_event_ids is not None - - logger.debug("calling resolve_state_groups from get_current_state") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) - state = ret.state - - if event_type: - event_id = state.get((event_type, state_key)) - event = None - if event_id: - event = await self.store.get_event(event_id, allow_none=True) - return event - - state_map = await self.store.get_events( - list(state.values()), get_prev_content=False - ) - return { - key: state_map[e_id] for key, e_id in state.items() if e_id in state_map - } - async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Collection[str]] = None + self, + room_id: str, + latest_event_ids: Collection[str], ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room Args: room_id: - latest_event_ids: if given, the forward extremities to resolve. If - None, we look them up from the database (via a cache). 
+ latest_event_ids: The forward extremities to resolve. Returns: the state dict, mapping from (event_type, state_key) -> event_id """ - if not latest_event_ids: - latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id) - assert latest_event_ids is not None - logger.debug("calling resolve_state_groups from get_current_state_ids") ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return ret.state @@ -239,10 +172,6 @@ class StateHandler: entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids) return await self.store.get_joined_users_from_state(room_id, entry) - async def get_current_hosts_in_room(self, room_id: str) -> FrozenSet[str]: - event_ids = await self.store.get_latest_event_ids_in_room(room_id) - return await self.get_hosts_in_room_at_events(room_id, event_ids) - async def get_hosts_in_room_at_events( self, room_id: str, event_ids: Collection[str] ) -> FrozenSet[str]: diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8df80664a2..abfc56b061 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -71,13 +71,14 @@ class SQLBaseStore(metaclass=ABCMeta): self._attempt_to_invalidate_cache("is_host_joined", (room_id, host)) if members_changed: self._attempt_to_invalidate_cache("get_users_in_room", (room_id,)) + self._attempt_to_invalidate_cache("get_current_hosts_in_room", (room_id,)) self._attempt_to_invalidate_cache( "get_users_in_room_with_profiles", (room_id,) ) # Purge other caches based on room state. 
self._attempt_to_invalidate_cache("get_room_summary", (room_id,)) - self._attempt_to_invalidate_cache("get_current_state_ids", (room_id,)) + self._attempt_to_invalidate_cache("get_partial_current_state_ids", (room_id,)) def _attempt_to_invalidate_cache( self, cache_name: str, key: Optional[Collection[Any]] diff --git a/synapse/storage/controllers/__init__.py b/synapse/storage/controllers/__init__.py index 992261d07b..55649719f6 100644 --- a/synapse/storage/controllers/__init__.py +++ b/synapse/storage/controllers/__init__.py @@ -18,7 +18,7 @@ from synapse.storage.controllers.persist_events import ( EventsPersistenceStorageController, ) from synapse.storage.controllers.purge_events import PurgeEventsStorageController -from synapse.storage.controllers.state import StateGroupStorageController +from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main import DataStore @@ -39,7 +39,7 @@ class StorageControllers: self.main = stores.main self.purge_events = PurgeEventsStorageController(hs, stores) - self.state = StateGroupStorageController(hs, stores) + self.state = StateStorageController(hs, stores) self.persistence = None if stores.persist_events: diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index ef8c135b12..4caaa81808 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -994,7 +994,7 @@ class EventsPersistenceStorageController: Assumes that we are only persisting events for one room at a time. 
""" - existing_state = await self.main_store.get_current_state_ids(room_id) + existing_state = await self.main_store.get_partial_current_state_ids(room_id) to_delete = [key for key in existing_state if key not in current_state] @@ -1083,7 +1083,7 @@ class EventsPersistenceStorageController: # The server will leave the room, so we go and find out which remote # users will still be joined when we leave. if current_state is None: - current_state = await self.main_store.get_current_state_ids(room_id) + current_state = await self.main_store.get_partial_current_state_ids(room_id) current_state = dict(current_state) for key in delta.to_delete: current_state.pop(key, None) diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 0f09953086..3b4cdb67eb 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -14,19 +14,26 @@ import logging from typing import ( TYPE_CHECKING, + Any, Awaitable, + Callable, Collection, Dict, Iterable, List, Mapping, Optional, + Set, Tuple, ) +from synapse.api.constants import EventTypes from synapse.events import EventBase from synapse.storage.state import StateFilter -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) from synapse.types import MutableStateMap, StateMap if TYPE_CHECKING: @@ -36,17 +43,27 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) -class StateGroupStorageController: - """High level interface to fetching state for event.""" +class StateStorageController: + """High level interface to fetching state for an event, or the current state + in a room. 
+ """ def __init__(self, hs: "HomeServer", stores: "Databases"): self._is_mine_id = hs.is_mine_id self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) + self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) def notify_event_un_partial_stated(self, event_id: str) -> None: self._partial_state_events_tracker.notify_un_partial_stated(event_id) + def notify_room_un_partial_stated(self, room_id: str) -> None: + """Notify that the room no longer has any partial state. + + Must be called after `DataStore.clear_partial_state_room` + """ + self._partial_state_room_tracker.notify_un_partial_stated(room_id) + async def get_state_group_delta( self, state_group: int ) -> Tuple[Optional[int], Optional[StateMap[str]]]: @@ -349,3 +366,127 @@ class StateGroupStorageController: return await self.stores.state.store_state_group( event_id, room_id, prev_group, delta_ids, current_state_ids ) + + async def get_current_state_ids( + self, + room_id: str, + state_filter: Optional[StateFilter] = None, + on_invalidate: Optional[Callable[[], None]] = None, + ) -> StateMap[str]: + """Get the current state event ids for a room based on the + current_state_events table. + + If a state filter is given (that is not `StateFilter.all()`) the query + result is *not* cached. + + Args: + room_id: The room to get the state IDs of. state_filter: The state + filter used to fetch state from the + database. + on_invalidate: Callback for when the `get_current_state_ids` cache + for the room gets invalidated. + + Returns: + The current state of the room. 
+ """ + if not state_filter or state_filter.must_await_full_state(self._is_mine_id): + await self._partial_state_room_tracker.await_full_state(room_id) + + if state_filter and not state_filter.is_full(): + return await self.stores.main.get_partial_filtered_current_state_ids( + room_id, state_filter + ) + else: + return await self.stores.main.get_partial_current_state_ids( + room_id, on_invalidate=on_invalidate + ) + + async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: + """Get canonical alias for room, if any + + Args: + room_id: The room ID + + Returns: + The canonical alias, if any + """ + + state = await self.get_current_state_ids( + room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) + ) + + event_id = state.get((EventTypes.CanonicalAlias, "")) + if not event_id: + return None + + event = await self.stores.main.get_event(event_id, allow_none=True) + if not event: + return None + + return event.content.get("canonical_alias") + + async def get_current_state_deltas( + self, prev_stream_id: int, max_stream_id: int + ) -> Tuple[int, List[Dict[str, Any]]]: + """Fetch a list of room state changes since the given stream id + + Each entry in the result contains the following fields: + - stream_id (int) + - room_id (str) + - type (str): event type + - state_key (str): + - event_id (str|None): new event_id for this state key. None if the + state has been deleted. + - prev_event_id (str|None): previous event_id for this state key. None + if it's new state. + + Args: + prev_stream_id: point to get changes since (exclusive) + max_stream_id: the point that we know has been correctly persisted + - ie, an upper limit to return changes from. + + Returns: + A tuple consisting of: + - the stream id which these results go up to + - list of current_state_delta_stream rows. If it is empty, we are + up to date. + """ + # FIXME(faster_joins): what do we do here? 
+ + return await self.stores.main.get_partial_current_state_deltas( + prev_stream_id, max_stream_id + ) + + async def get_current_state( + self, room_id: str, state_filter: Optional[StateFilter] = None + ) -> StateMap[EventBase]: + """Same as `get_current_state_ids` but also fetches the events""" + state_map_ids = await self.get_current_state_ids(room_id, state_filter) + + event_map = await self.stores.main.get_events(list(state_map_ids.values())) + + state_map = {} + for key, event_id in state_map_ids.items(): + event = event_map.get(event_id) + if event: + state_map[key] = event + + return state_map + + async def get_current_state_event( + self, room_id: str, event_type: str, state_key: str + ) -> Optional[EventBase]: + """Get the current state event for the given type/state_key.""" + + key = (event_type, state_key) + state_map = await self.get_current_state( + room_id, StateFilter.from_types((key,)) + ) + return state_map.get(key) + + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + await self._partial_state_room_tracker.await_full_state(room_id) + + return await self.stores.main.get_current_hosts_in_room(room_id) diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py index d545a1c002..11d9d16c19 100644 --- a/synapse/storage/databases/main/__init__.py +++ b/synapse/storage/databases/main/__init__.py @@ -151,10 +151,6 @@ class DataStore( ], ) - self._group_updates_id_gen = StreamIdGenerator( - db_conn, "local_group_updates", "stream_id" - ) - self._cache_id_gen: Optional[MultiWriterIdGenerator] if isinstance(self.database_engine, PostgresEngine): # We set the `writers` to an empty list here as we don't care about @@ -197,20 +193,6 @@ class DataStore( prefilled_cache=curr_state_delta_prefill, ) - _group_updates_prefill, min_group_updates_id = self.db_pool.get_cache_dict( - db_conn, - "local_group_updates", - entity_column="user_id", - 
stream_column="stream_id", - max_value=self._group_updates_id_gen.get_current_token(), - limit=1000, - ) - self._group_updates_stream_cache = StreamChangeCache( - "_group_updates_stream_cache", - min_group_updates_id, - prefilled_cache=_group_updates_prefill, - ) - self._stream_order_on_start = self.get_room_max_stream_ordering() self._min_stream_order_on_start = self.get_room_min_stream_ordering() diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py index da21a50144..c15a7136b6 100644 --- a/synapse/storage/databases/main/group_server.py +++ b/synapse/storage/databases/main/group_server.py @@ -29,11 +29,6 @@ class GroupServerStore(SQLBaseStore): db_conn: LoggingDatabaseConnection, hs: "HomeServer", ): - database.updates.register_background_index_update( - update_name="local_group_updates_index", - index_name="local_group_updates_stream_id_index", - table="local_group_updates", - columns=("stream_id",), - unique=True, - ) + # Register a legacy groups background update as a no-op. 
+ database.updates.register_noop_background_update("local_group_updates_index") super().__init__(database, db_conn, hs) diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index 40ac377ca9..deffdc19ce 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -276,10 +276,6 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore): (SELECT 1 FROM profiles WHERE profiles.avatar_url = '{media_prefix}' || lmr.media_id) - AND NOT EXISTS - (SELECT 1 - FROM groups - WHERE groups.avatar_url = '{media_prefix}' || lmr.media_id) AND NOT EXISTS (SELECT 1 FROM room_memberships diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py index e197b7203e..a1747f04ce 100644 --- a/synapse/storage/databases/main/profile.py +++ b/synapse/storage/databases/main/profile.py @@ -11,11 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, List, Optional +from typing import Optional from synapse.api.errors import StoreError from synapse.storage._base import SQLBaseStore -from synapse.storage.database import LoggingTransaction from synapse.storage.databases.main.roommember import ProfileInfo @@ -55,17 +54,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="get_profile_avatar_url", ) - async def get_from_remote_profile_cache( - self, user_id: str - ) -> Optional[Dict[str, Any]]: - return await self.db_pool.simple_select_one( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - retcols=("displayname", "avatar_url"), - allow_none=True, - desc="get_from_remote_profile_cache", - ) - async def create_profile(self, user_localpart: str) -> None: await self.db_pool.simple_insert( table="profiles", values={"user_id": user_localpart}, desc="create_profile" @@ -91,97 +79,6 @@ class ProfileWorkerStore(SQLBaseStore): desc="set_profile_avatar_url", ) - async def update_remote_profile_cache( - self, user_id: str, displayname: Optional[str], avatar_url: Optional[str] - ) -> int: - return await self.db_pool.simple_update( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - updatevalues={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="update_remote_profile_cache", - ) - - async def maybe_delete_remote_profile_cache(self, user_id: str) -> None: - """Check if we still care about the remote user's profile, and if we - don't then remove their profile from the cache - """ - subscribed = await self.is_subscribed_remote_profile_for_user(user_id) - if not subscribed: - await self.db_pool.simple_delete( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - desc="delete_remote_profile_cache", - ) - - async def is_subscribed_remote_profile_for_user(self, user_id: str) -> bool: - """Check whether we are interested in a remote user's profile.""" - res: Optional[str] = await 
self.db_pool.simple_select_one_onecol( - table="group_users", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - - res = await self.db_pool.simple_select_one_onecol( - table="group_invites", - keyvalues={"user_id": user_id}, - retcol="user_id", - allow_none=True, - desc="should_update_remote_profile_cache_for_user", - ) - - if res: - return True - return False - - async def get_remote_profile_cache_entries_that_expire( - self, last_checked: int - ) -> List[Dict[str, str]]: - """Get all users who haven't been checked since `last_checked`""" - - def _get_remote_profile_cache_entries_that_expire_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - sql = """ - SELECT user_id, displayname, avatar_url - FROM remote_profile_cache - WHERE last_check < ? - """ - - txn.execute(sql, (last_checked,)) - - return self.db_pool.cursor_to_dict(txn) - - return await self.db_pool.runInteraction( - "get_remote_profile_cache_entries_that_expire", - _get_remote_profile_cache_entries_that_expire_txn, - ) - class ProfileStore(ProfileWorkerStore): - async def add_remote_profile_cache( - self, user_id: str, displayname: str, avatar_url: str - ) -> None: - """Ensure we are caching the remote user's profiles. - - This should only be called when `is_subscribed_remote_profile_for_user` - would return true for the user. 
- """ - await self.db_pool.simple_upsert( - table="remote_profile_cache", - keyvalues={"user_id": user_id}, - values={ - "displayname": displayname, - "avatar_url": avatar_url, - "last_check": self._clock.time_msec(), - }, - desc="add_remote_profile_cache", - ) + pass diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py index 2353c120e9..ba385f9fc4 100644 --- a/synapse/storage/databases/main/purge_events.py +++ b/synapse/storage/databases/main/purge_events.py @@ -393,7 +393,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "partial_state_events", "events", "federation_inbound_events_staging", - "group_rooms", "local_current_membership", "partial_state_rooms_servers", "partial_state_rooms", @@ -413,7 +412,6 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore): "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", # "rooms" happens last, to keep the foreign keys in the other tables diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index cfd8ce1624..68d4fc2e64 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -1139,6 +1139,24 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): keyvalues={"room_id": room_id}, ) + async def is_partial_state_room(self, room_id: str) -> bool: + """Checks if this room has partial state. + + Returns true if this is a "partial-state" room, which means that the state + at events in the room, and `current_state_events`, may not yet be + complete. 
+ """ + + entry = await self.db_pool.simple_select_one_onecol( + table="partial_state_rooms", + keyvalues={"room_id": room_id}, + retcol="room_id", + allow_none=True, + desc="is_partial_state_room", + ) + + return entry is not None + class _BackgroundUpdates: REMOVE_TOMESTONED_ROOMS_BG_UPDATE = "remove_tombstoned_rooms_from_directory" diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index e222b7bd1f..31bc8c5601 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -893,6 +893,43 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True + @cached(iterable=True, max_entries=10000) + async def get_current_hosts_in_room(self, room_id: str) -> Set[str]: + """Get current hosts in room based on current state.""" + + # First we check if we already have `get_users_in_room` in the cache, as + # we can just calculate result from that + users = self.get_users_in_room.cache.get_immediate( + (room_id,), None, update_metrics=False + ) + if users is not None: + return {get_domain_from_id(u) for u in users} + + if isinstance(self.database_engine, Sqlite3Engine): + # If we're using SQLite then let's just always use + # `get_users_in_room` rather than funky SQL. + users = await self.get_users_in_room(room_id) + return {get_domain_from_id(u) for u in users} + + # For PostgreSQL we can use a regex to pull out the domains from the + # joined users in `current_state_events` via regex. + + def get_current_hosts_in_room_txn(txn: LoggingTransaction) -> Set[str]: + sql = """ + SELECT DISTINCT substring(state_key FROM '@[^:]*:(.*)$') + FROM current_state_events + WHERE + type = 'm.room.member' + AND membership = 'join' + AND room_id = ? 
+ """ + txn.execute(sql, (room_id,)) + return {d for d, in txn} + + return await self.db_pool.runInteraction( + "get_current_hosts_in_room", get_current_hosts_in_room_txn + ) + async def get_joined_hosts( self, room_id: str, state_entry: "_StateCacheEntry" ) -> FrozenSet[str]: diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index a07ad85582..bdd00273cd 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -54,6 +54,7 @@ class EventMetadata: room_id: str event_type: str state_key: Optional[str] + rejection_reason: Optional[str] def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion: @@ -167,17 +168,22 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): ) sql = f""" - SELECT e.event_id, e.room_id, e.type, se.state_key FROM events AS e + SELECT e.event_id, e.room_id, e.type, se.state_key, r.reason + FROM events AS e LEFT JOIN state_events se USING (event_id) + LEFT JOIN rejections r USING (event_id) WHERE {clause} """ txn.execute(sql, args) return { event_id: EventMetadata( - room_id=room_id, event_type=event_type, state_key=state_key + room_id=room_id, + event_type=event_type, + state_key=state_key, + rejection_reason=rejection_reason, ) - for event_id, room_id, event_type, state_key in txn + for event_id, room_id, event_type, state_key, rejection_reason in txn } result_map: Dict[str, EventMetadata] = {} @@ -236,7 +242,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): Raises: NotFoundError if the room is unknown """ - state_ids = await self.get_current_state_ids(room_id) + state_ids = await self.get_partial_current_state_ids(room_id) if not state_ids: raise NotFoundError(f"Current state for room {room_id} is empty") @@ -252,10 +258,12 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return create_event @cached(max_entries=100000, iterable=True) - async def get_current_state_ids(self, room_id: 
str) -> StateMap[str]: + async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the current_state_events table. + This may be the partial state if we're lazy joining the room. + Args: room_id: The room to get the state IDs of. @@ -274,17 +282,19 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): return {(intern_string(r[0]), intern_string(r[1])): r[2] for r in txn} return await self.db_pool.runInteraction( - "get_current_state_ids", _get_current_state_ids_txn + "get_partial_current_state_ids", _get_current_state_ids_txn ) # FIXME: how should this be cached? - async def get_filtered_current_state_ids( + async def get_partial_filtered_current_state_ids( self, room_id: str, state_filter: Optional[StateFilter] = None ) -> StateMap[str]: """Get the current state event of a given type for a room based on the current_state_events table. This may not be as up-to-date as the result of doing a fresh state resolution as per state_handler.get_current_state + This may be the partial state if we're lazy joining the room. 
+ Args: room_id state_filter: The state filter used to fetch state @@ -300,7 +310,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): if not where_clause: # We delegate to the cached version - return await self.get_current_state_ids(room_id) + return await self.get_partial_current_state_ids(room_id) def _get_filtered_current_state_ids_txn( txn: LoggingTransaction, @@ -328,30 +338,6 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): "get_filtered_current_state_ids", _get_filtered_current_state_ids_txn ) - async def get_canonical_alias_for_room(self, room_id: str) -> Optional[str]: - """Get canonical alias for room, if any - - Args: - room_id: The room ID - - Returns: - The canonical alias, if any - """ - - state = await self.get_filtered_current_state_ids( - room_id, StateFilter.from_types([(EventTypes.CanonicalAlias, "")]) - ) - - event_id = state.get((EventTypes.CanonicalAlias, "")) - if not event_id: - return None - - event = await self.get_event(event_id, allow_none=True) - if not event: - return None - - return event.content.get("canonical_alias") - @cached(max_entries=50000) async def _get_state_group_for_event(self, event_id: str) -> Optional[int]: return await self.db_pool.simple_select_one_onecol( diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 188afec332..445213e12a 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -27,7 +27,7 @@ class StateDeltasStore(SQLBaseStore): # attribute. TODO: can we get static analysis to enforce this? 
_curr_state_delta_stream_cache: StreamChangeCache - async def get_current_state_deltas( + async def get_partial_current_state_deltas( self, prev_stream_id: int, max_stream_id: int ) -> Tuple[int, List[Dict[str, Any]]]: """Fetch a list of room state changes since the given stream id @@ -42,6 +42,8 @@ class StateDeltasStore(SQLBaseStore): - prev_event_id (str|None): previous event_id for this state key. None if it's new state. + This may be the partial state if we're lazy joining the room. + Args: prev_stream_id: point to get changes since (exclusive) max_stream_id: the point that we know has been correctly persisted diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index 0e3a23a140..8e88784d3c 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -765,15 +765,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): self, room_id: str, end_token: RoomStreamToken, - ) -> Optional[EventBase]: - """Returns the last event in a room at or before a stream ordering + ) -> Optional[str]: + """Returns the ID of the last event in a room at or before a stream ordering Args: room_id end_token: The token used to stream from Returns: - The most recent event. + The ID of the most recent event, or None if there are no events in the room + before this stream ordering. 
""" last_row = await self.get_room_event_before_stream_ordering( @@ -781,10 +782,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): stream_ordering=end_token.stream, ) if last_row: - _, _, event_id = last_row - event = await self.get_event(event_id, get_prev_content=True) - return event - + return last_row[2] return None async def get_current_room_stream_token_for_room_id( diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 2282242e9d..ddb25b5cea 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -441,7 +441,9 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): (EventTypes.RoomHistoryVisibility, ""), ) - current_state_ids = await self.get_filtered_current_state_ids( # type: ignore[attr-defined] + # Getting the partial state is fine, as we're not looking at membership + # events. + current_state_ids = await self.get_partial_filtered_current_state_ids( # type: ignore[attr-defined] room_id, StateFilter.from_types(types_to_filter) ) diff --git a/synapse/storage/schema/__init__.py b/synapse/storage/schema/__init__.py index 19466150d4..5843fae605 100644 --- a/synapse/storage/schema/__init__.py +++ b/synapse/storage/schema/__init__.py @@ -70,6 +70,7 @@ Changes in SCHEMA_VERSION = 70: Changes in SCHEMA_VERSION = 71: - event_edges.room_id is no longer read from. + - Tables related to groups are no longer accessed. 
""" diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py index a61a951ef0..211437cfaa 100644 --- a/synapse/storage/util/partial_state_events_tracker.py +++ b/synapse/storage/util/partial_state_events_tracker.py @@ -21,6 +21,7 @@ from twisted.internet.defer import Deferred from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.storage.databases.main.events_worker import EventsWorkerStore +from synapse.storage.databases.main.room import RoomWorkerStore from synapse.util import unwrapFirstError logger = logging.getLogger(__name__) @@ -118,3 +119,62 @@ class PartialStateEventsTracker: observer_set.discard(observer) if not observer_set: del self._observers[event_id] + + +class PartialCurrentStateTracker: + """Keeps track of which rooms have partial state, after partial-state joins""" + + def __init__(self, store: RoomWorkerStore): + self._store = store + + # a map from room id to a set of Deferreds which are waiting for that room to be + # un-partial-stated. + self._observers: Dict[str, Set[Deferred[None]]] = defaultdict(set) + + def notify_un_partial_stated(self, room_id: str) -> None: + """Notify that we now have full current state for a given room + + Unblocks any callers to await_full_state() for that room. + + Args: + room_id: the room that now has full current state. + """ + observers = self._observers.pop(room_id, None) + if not observers: + return + logger.info( + "Notifying %i things waiting for un-partial-stating of room %s", + len(observers), + room_id, + ) + with PreserveLoggingContext(): + for o in observers: + o.callback(None) + + async def await_full_state(self, room_id: str) -> None: + # We add the deferred immediately so that the DB call to check for + # partial state doesn't race when we unpartial the room. 
+ d: Deferred[None] = Deferred() + self._observers.setdefault(room_id, set()).add(d) + + try: + # Check if the room has partial current state or not. + has_partial_state = await self._store.is_partial_state_room(room_id) + if not has_partial_state: + return + + logger.info( + "Awaiting un-partial-stating of room %s", + room_id, + ) + + await make_deferred_yieldable(d) + + logger.info("Room has un-partial-stated") + finally: + # Remove the added observer, and remove the room entry if its empty. + ds = self._observers.get(room_id) + if ds is not None: + ds.discard(d) + if not ds: + self._observers.pop(room_id, None) diff --git a/synapse/visibility.py b/synapse/visibility.py index 97548c14e3..8aaa8c709f 100644 --- a/synapse/visibility.py +++ b/synapse/visibility.py @@ -162,16 +162,7 @@ async def filter_events_for_client( state = event_id_to_state[event.event_id] # get the room_visibility at the time of the event. - visibility_event = state.get(_HISTORY_VIS_KEY, None) - if visibility_event: - visibility = visibility_event.content.get( - "history_visibility", HistoryVisibility.SHARED - ) - else: - visibility = HistoryVisibility.SHARED - - if visibility not in VISIBILITY_PRIORITY: - visibility = HistoryVisibility.SHARED + visibility = get_effective_room_visibility_from_state(state) # Always allow history visibility events on boundaries. This is done # by setting the effective visibility to the least restrictive @@ -267,6 +258,23 @@ async def filter_events_for_client( return [ev for ev in filtered_events if ev] +def get_effective_room_visibility_from_state(state: StateMap[EventBase]) -> str: + """Get the actual history vis, from a state map including the history_visibility event + + Handles missing and invalid history visibility events. 
+ """ + visibility_event = state.get(_HISTORY_VIS_KEY, None) + if not visibility_event: + return HistoryVisibility.SHARED + + visibility = visibility_event.content.get( + "history_visibility", HistoryVisibility.SHARED + ) + if visibility not in VISIBILITY_PRIORITY: + visibility = HistoryVisibility.SHARED + return visibility + + async def filter_events_for_server( storage: StorageControllers, server_name: str, diff --git a/tests/api/test_auth.py b/tests/api/test_auth.py index d547df8a64..bc75ddd3e9 100644 --- a/tests/api/test_auth.py +++ b/tests/api/test_auth.py @@ -404,7 +404,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] @@ -433,7 +432,6 @@ class AuthTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( "abcd", - self.hs.config.server.server_name, id="1234", namespaces={ "users": [{"regex": "@_appservice.*:sender", "exclusive": True}] diff --git a/tests/api/test_ratelimiting.py b/tests/api/test_ratelimiting.py index 483d5463ad..f661a9ff8e 100644 --- a/tests/api/test_ratelimiting.py +++ b/tests/api/test_ratelimiting.py @@ -31,7 +31,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_ratelimited_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=True, sender="@as:example.com", @@ -62,7 +61,6 @@ class TestRatelimiter(unittest.HomeserverTestCase): def test_allowed_appservice_via_can_requester_do_action(self): appservice = ApplicationService( None, - "example.com", id="foo", rate_limited=False, sender="@as:example.com", diff --git a/tests/appservice/test_api.py b/tests/appservice/test_api.py index 3e0db4dd98..532b676365 100644 --- a/tests/appservice/test_api.py +++ b/tests/appservice/test_api.py @@ -37,7 +37,6 @@ class ApplicationServiceApiTestCase(unittest.HomeserverTestCase): url=URL, 
token="unused", hs_token=TOKEN, - hostname="myserver", ) def test_query_3pe_authenticates_token(self): diff --git a/tests/appservice/test_appservice.py b/tests/appservice/test_appservice.py index 7135362f76..3018d3fc6f 100644 --- a/tests/appservice/test_appservice.py +++ b/tests/appservice/test_appservice.py @@ -33,7 +33,6 @@ class ApplicationServiceTestCase(unittest.TestCase): sender="@as:test", url="some_url", token="some_token", - hostname="matrix.org", # only used by get_groups_for_user ) self.event = Mock( event_id="$abc:xyz", diff --git a/tests/crypto/test_event_signing.py b/tests/crypto/test_event_signing.py index 06e0545a4f..8fa710c9dc 100644 --- a/tests/crypto/test_event_signing.py +++ b/tests/crypto/test_event_signing.py @@ -12,10 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. - -import nacl.signing -import signedjson.types -from unpaddedbase64 import decode_base64 +from signedjson.key import decode_signing_key_base64 +from signedjson.types import SigningKey from synapse.api.room_versions import RoomVersions from synapse.crypto.event_signing import add_hashes_and_signatures @@ -25,7 +23,7 @@ from tests import unittest # Perform these tests using given secret key so we get entirely deterministic # signatures output that we can test against. -SIGNING_KEY_SEED = decode_base64("YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1") +SIGNING_KEY_SEED = "YJDBA9Xnr2sVqXD9Vj7XVUnmFZcZrlw8Md7kMW+3XA1" KEY_ALG = "ed25519" KEY_VER = "1" @@ -36,14 +34,9 @@ HOSTNAME = "domain" class EventSigningTestCase(unittest.TestCase): def setUp(self): - # NB: `signedjson` expects `nacl.signing.SigningKey` instances which have been - # monkeypatched to include new `alg` and `version` attributes. This is captured - # by the `signedjson.types.SigningKey` protocol. 
- self.signing_key: signedjson.types.SigningKey = nacl.signing.SigningKey( # type: ignore[assignment] - SIGNING_KEY_SEED + self.signing_key: SigningKey = decode_signing_key_base64( + KEY_ALG, KEY_VER, SIGNING_KEY_SEED ) - self.signing_key.alg = KEY_ALG - self.signing_key.version = KEY_VER def test_sign_minimal(self): event_dict = { diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index d00ef24ca8..820a1a54e2 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -19,8 +19,8 @@ import attr import canonicaljson import signedjson.key import signedjson.sign -from nacl.signing import SigningKey from signedjson.key import encode_verify_key_base64, get_verify_key +from signedjson.types import SigningKey from twisted.internet import defer from twisted.internet.defer import Deferred, ensureDeferred diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index b5be727fe4..01a1db6115 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -30,16 +30,16 @@ from tests.unittest import HomeserverTestCase, override_config class FederationSenderReceiptsTestCases(HomeserverTestCase): def make_homeserver(self, reactor, clock): - mock_state_handler = Mock(spec=["get_current_hosts_in_room"]) - # Ensure a new Awaitable is created for each call. 
- mock_state_handler.get_current_hosts_in_room.return_value = make_awaitable( - ["test", "host2"] - ) - return self.setup_test_homeserver( - state_handler=mock_state_handler, + hs = self.setup_test_homeserver( federation_transport_client=Mock(spec=["send_transaction"]), ) + hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( + return_value=make_awaitable({"test", "host2"}) + ) + + return hs + @override_config({"send_federation": True}) def test_send_receipts(self): mock_send_transaction = ( diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py index b19365b81a..413b3c9426 100644 --- a/tests/federation/test_federation_server.py +++ b/tests/federation/test_federation_server.py @@ -134,6 +134,8 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer): super().prepare(reactor, clock, hs) + self._storage_controllers = hs.get_storage_controllers() + # create the room creator_user_id = self.register_user("kermit", "test") tok = self.login("kermit", "test") @@ -207,7 +209,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") @@ -258,7 +260,7 @@ class SendJoinFederationTests(unittest.FederatingHomeserverTestCase): # the room should show that the new user is a member r = self.get_success( - self.hs.get_state_handler().get_current_state(self._room_id) + self._storage_controllers.state.get_current_state(self._room_id) ) self.assertEqual(r[("m.room.member", joining_user)].membership, "join") diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 0e100c404d..d96d5aa138 100644 --- 
a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -697,7 +697,6 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase): # Create an application service appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -776,7 +775,6 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase) # Create an appservice that is interested in "local_user" appservice = ApplicationService( token=random_string(10), - hostname="example.com", id=random_string(10), sender="@as:example.com", rate_limited=False, @@ -843,7 +841,6 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase): self._service_token = "VERYSECRET" self._service = ApplicationService( self._service_token, - "as1.invalid", "as1", "@as.sender:test", namespaces={ diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 11ad44223d..53d49ca896 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -298,6 +298,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.handler = hs.get_directory_handler() self.state_handler = hs.get_state_handler() + self._storage_controllers = hs.get_storage_controllers() # Create user self.admin_user = self.register_user("admin", "pass", admin=True) @@ -335,7 +336,7 @@ class CanonicalAliasTestCase(unittest.HomeserverTestCase): def _get_canonical_alias(self): """Get the canonical alias state of the room.""" return self.get_success( - self.state_handler.get_current_state( + self._storage_controllers.state.get_current_state_event( self.room_id, EventTypes.CanonicalAlias, "" ) ) diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py index 500c9ccfbc..e0eda545b9 100644 --- a/tests/handlers/test_federation.py +++ b/tests/handlers/test_federation.py @@ -237,7 
+237,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): ) current_state = self.get_success( self.store.get_events_as_list( - (self.get_success(self.store.get_current_state_ids(room_id))).values() + ( + self.get_success(self.store.get_partial_current_state_ids(room_id)) + ).values() ) ) @@ -512,7 +514,7 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase): self.get_success(d) # sanity-check: the room should show that the new user is a member - r = self.get_success(self.store.get_current_state_ids(room_id)) + r = self.get_success(self.store.get_partial_current_state_ids(room_id)) self.assertEqual(r[(EventTypes.Member, other_user)], join_event.event_id) return join_event diff --git a/tests/handlers/test_federation_event.py b/tests/handlers/test_federation_event.py index 1d5b2492c0..1a36c25c41 100644 --- a/tests/handlers/test_federation_event.py +++ b/tests/handlers/test_federation_event.py @@ -91,7 +91,9 @@ class FederationEventHandlerTests(unittest.FederatingHomeserverTestCase): event_injection.inject_member_event(self.hs, room_id, OTHER_USER, "join") ) - initial_state_map = self.get_success(main_store.get_current_state_ids(room_id)) + initial_state_map = self.get_success( + main_store.get_partial_current_state_ids(room_id) + ) auth_event_ids = [ initial_state_map[("m.room.create", "")], diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 057256cecd..7af1333126 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -129,10 +129,12 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): hs.get_event_auth_handler().check_host_in_room = check_host_in_room - def get_joined_hosts_for_room(room_id: str): + async def get_current_hosts_in_room(room_id: str): return {member.domain for member in self.room_members} - self.datastore.get_joined_hosts_for_room = get_joined_hosts_for_room + hs.get_storage_controllers().state.get_current_hosts_in_room = ( + get_current_hosts_in_room + 
) async def get_users_in_room(room_id: str): return {str(u) for u in self.room_members} @@ -146,7 +148,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): ) ) - self.datastore.get_current_state_deltas = Mock(return_value=(0, None)) + self.datastore.get_partial_current_state_deltas = Mock(return_value=(0, None)) self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = ( diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py index a68c2ffd45..9e39cd97e5 100644 --- a/tests/handlers/test_user_directory.py +++ b/tests/handlers/test_user_directory.py @@ -60,7 +60,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not match the regex above, so that tests diff --git a/tests/rest/admin/test_room.py b/tests/rest/admin/test_room.py index 608d3f2dc3..ca6af9417b 100644 --- a/tests/rest/admin/test_room.py +++ b/tests/rest/admin/test_room.py @@ -2467,7 +2467,6 @@ PURGE_TABLES = [ "event_push_actions", "event_search", "events", - "group_rooms", "receipts_graph", "receipts_linearized", "room_aliases", @@ -2484,7 +2483,6 @@ PURGE_TABLES = [ "e2e_room_keys", "event_push_summary", "pusher_throttle", - "group_summary_rooms", "room_account_data", "room_tags", # "state_groups", # Current impl leaves orphaned state groups around. 
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py index e0a11da97b..a43a137273 100644 --- a/tests/rest/client/test_account.py +++ b/tests/rest/client/test_account.py @@ -548,7 +548,6 @@ class WhoamiTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": user_id, "exclusive": True}]}, sender=user_id, diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py index 4920468f7a..f4ea1209d9 100644 --- a/tests/rest/client/test_login.py +++ b/tests/rest/client/test_login.py @@ -1112,7 +1112,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.service = ApplicationService( id="unique_identifier", token="some_token", - hostname="example.com", sender="@asbot:example.com", namespaces={ ApplicationService.NS_USERS: [ @@ -1125,7 +1124,6 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase): self.another_service = ApplicationService( id="another__identifier", token="another_token", - hostname="example.com", sender="@as2bot:example.com", namespaces={ ApplicationService.NS_USERS: [ diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py index 9aebf1735a..afb08b2736 100644 --- a/tests/rest/client/test_register.py +++ b/tests/rest/client/test_register.py @@ -56,7 +56,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", @@ -80,7 +79,6 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( as_token, - self.hs.config.server.server_name, id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py 
index bc9cc51b92..62e4db23ef 100644 --- a/tests/rest/client/test_relations.py +++ b/tests/rest/client/test_relations.py @@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): relation_type: str, assertion_callable: Callable[[JsonDict], None], expected_db_txn_for_event: int, + access_token: Optional[str] = None, ) -> None: """ Makes requests to various endpoints which should include bundled aggregations @@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): for relation-specific assertions. expected_db_txn_for_event: The number of database transactions which are expected for a call to /event/. + access_token: The access token to use, defaults to self.user_token. """ + access_token = access_token or self.user_token def assert_bundle(event_json: JsonDict) -> None: """Assert the expected values of the bundled aggregations.""" @@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/event/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body) @@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/messages?dir=b", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"])) @@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): channel = self.make_request( "GET", f"/rooms/{self.room}/context/{self.parent_id}", - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) assert_bundle(channel.json_body["event"]) @@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): # Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}') channel = self.make_request( - "GET", f"/sync?filter={filter}", access_token=self.user_token + "GET", f"/sync?filter={filter}", access_token=access_token ) self.assertEqual(200, channel.code, channel.json_body) room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"] @@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): "/search", # Search term matches the parent message. content={"search_categories": {"room_events": {"search_term": "Hi"}}}, - access_token=self.user_token, + access_token=access_token, ) self.assertEqual(200, channel.code, channel.json_body) chunk = [ @@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): """ Test that threads get correctly bundled. """ - self._send_relation(RelationTypes.THREAD, "m.room.test") - channel = self._send_relation(RelationTypes.THREAD, "m.room.test") + # The root message is from "user", send replies as "user2". + self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) + channel = self._send_relation( + RelationTypes.THREAD, "m.room.test", access_token=self.user2_token + ) thread_2 = channel.json_body["event_id"] - def assert_thread(bundled_aggregations: JsonDict) -> None: - self.assertEqual(2, bundled_aggregations.get("count")) - self.assertTrue(bundled_aggregations.get("current_user_participated")) - # The latest thread event has some fields that don't matter. - self.assert_dict( - { - "content": { - "m.relates_to": { - "event_id": self.parent_id, - "rel_type": RelationTypes.THREAD, - } + # This needs two assertion functions which are identical except for whether + # the current_user_participated flag is True, create a factory for the + # two versions. 
+ def _gen_assert(participated: bool) -> Callable[[JsonDict], None]: + def assert_thread(bundled_aggregations: JsonDict) -> None: + self.assertEqual(2, bundled_aggregations.get("count")) + self.assertEqual( + participated, bundled_aggregations.get("current_user_participated") + ) + # The latest thread event has some fields that don't matter. + self.assert_dict( + { + "content": { + "m.relates_to": { + "event_id": self.parent_id, + "rel_type": RelationTypes.THREAD, + } + }, + "event_id": thread_2, + "sender": self.user2_id, + "type": "m.room.test", }, - "event_id": thread_2, - "sender": self.user_id, - "type": "m.room.test", - }, - bundled_aggregations.get("latest_event"), - ) + bundled_aggregations.get("latest_event"), + ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + return assert_thread + + # The "user" sent the root event and is making queries for the bundled + # aggregations: they have participated. + self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8) + # The "user2" sent replies in the thread and is making queries for the + # bundled aggregations: they have participated. + # + # Note that this re-uses some cached values, so the total number of + # queries is much smaller. + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token + ) + + # A user with no interactions with the thread: they have not participated. 
+ user3_id, user3_token = self._create_user("charlie") + self.helper.join(self.room, user=user3_id, tok=user3_token) + self._test_bundled_aggregations( + RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token + ) def test_thread_with_bundled_aggregations_for_latest(self) -> None: """ @@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase): bundled_aggregations["latest_event"].get("unsigned"), ) - self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9) + self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8) def test_nested_thread(self) -> None: """ diff --git a/tests/rest/client/test_room_batch.py b/tests/rest/client/test_room_batch.py index 1b7ee08ab2..9d5cb60d16 100644 --- a/tests/rest/client/test_room_batch.py +++ b/tests/rest/client/test_room_batch.py @@ -71,7 +71,6 @@ class RoomBatchTestCase(unittest.HomeserverTestCase): self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, # Note: this user does not have to match the regex above diff --git a/tests/rest/client/test_upgrade_room.py b/tests/rest/client/test_upgrade_room.py index a21cbe9fa8..98c1039d33 100644 --- a/tests/rest/client/test_upgrade_room.py +++ b/tests/rest/client/test_upgrade_room.py @@ -249,7 +249,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_space_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_space_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_space_id) + ) # Ensure the new room is still a space. 
create_event = self.get_success( @@ -284,7 +286,9 @@ class UpgradeRoomTest(unittest.HomeserverTestCase): new_room_id = channel.json_body["replacement_room"] - state_ids = self.get_success(self.store.get_current_state_ids(new_room_id)) + state_ids = self.get_success( + self.store.get_partial_current_state_ids(new_room_id) + ) # Ensure the new room is the same type as the old room. create_event = self.get_success( diff --git a/tests/rest/media/v1/test_html_preview.py b/tests/rest/media/v1/test_html_preview.py index 62e308814d..ea9e5889bf 100644 --- a/tests/rest/media/v1/test_html_preview.py +++ b/tests/rest/media/v1/test_html_preview.py @@ -145,7 +145,7 @@ class SummarizeTestCase(unittest.TestCase): ) -class CalcOgTestCase(unittest.TestCase): +class OpenGraphFromHtmlTestCase(unittest.TestCase): if not lxml: skip = "url preview feature requires lxml" @@ -235,6 +235,21 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": None, "og:description": "Some text."}) + # Another variant is a title with no content. + html = b""" + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Title"}) + def test_h1_as_title(self) -> None: html = b""" @@ -250,6 +265,26 @@ class CalcOgTestCase(unittest.TestCase): self.assertEqual(og, {"og:title": "Title", "og:description": "Some text."}) + def test_empty_description(self) -> None: + """Description tags with empty content should be ignored.""" + html = b""" + + + + + + + +

Title

+ + + """ + + tree = decode_body(html, "http://example.com/test.html") + og = parse_html_to_open_graph(tree) + + self.assertEqual(og, {"og:title": "Title", "og:description": "Finally!"}) + def test_missing_title_and_broken_h1(self) -> None: html = b""" diff --git a/tests/rest/media/v1/test_url_preview.py b/tests/rest/media/v1/test_url_preview.py index 3b24d0ace6..2c321f8d04 100644 --- a/tests/rest/media/v1/test_url_preview.py +++ b/tests/rest/media/v1/test_url_preview.py @@ -656,6 +656,41 @@ class URLPreviewTests(unittest.HomeserverTestCase): server.data, ) + def test_nonexistent_image(self) -> None: + """If the preview image doesn't exist, ensure some data is returned.""" + self.lookups["matrix.org"] = [(IPv4Address, "10.1.2.3")] + + end_content = ( + b"""""" + ) + + channel = self.make_request( + "GET", + "preview_url?url=http://matrix.org", + shorthand=False, + await_result=False, + ) + self.pump() + + client = self.reactor.tcpClients[0][2].buildProtocol(None) + server = AccumulatingProtocol() + server.makeConnection(FakeTransport(client, self.reactor)) + client.makeConnection(FakeTransport(server, self.reactor)) + client.dataReceived( + ( + b"HTTP/1.0 200 OK\r\nContent-Length: %d\r\n" + b'Content-Type: text/html; charset="utf8"\r\n\r\n' + ) + % (len(end_content),) + + end_content + ) + + self.pump() + self.assertEqual(channel.code, 200) + + # The image should not be in the result. + self.assertNotIn("og:image", channel.json_body) + def test_data_url(self) -> None: """ Requesting to preview a data URL is not supported. 
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py index a76718e8f9..2ff88e64a5 100644 --- a/tests/storage/test_events.py +++ b/tests/storage/test_events.py @@ -32,6 +32,7 @@ class ExtremPruneTestCase(HomeserverTestCase): def prepare(self, reactor, clock, homeserver): self.state = self.hs.get_state_handler() self._persistence = self.hs.get_storage_controllers().persistence + self._state_storage_controller = self.hs.get_storage_controllers().state self.store = self.hs.get_datastores().main self.register_user("user", "pass") @@ -104,7 +105,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -137,7 +138,9 @@ class ExtremPruneTestCase(HomeserverTestCase): # setting. The state resolution across the old and new event will then # include it, and so the resolved state won't match the new state. 
state_before_gap = dict( - self.get_success(self.state.get_current_state_ids(self.room_id)) + self.get_success( + self._state_storage_controller.get_current_state_ids(self.room_id) + ) ) state_before_gap.pop(("m.room.history_visibility", "")) @@ -181,7 +184,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -213,7 +216,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -255,7 +258,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -299,7 +302,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) @@ -335,7 +338,7 @@ class ExtremPruneTestCase(HomeserverTestCase): ) state_before_gap = self.get_success( - self.state.get_current_state_ids(self.room_id) + self._state_storage_controller.get_current_state_ids(self.room_id) ) self.persist_event(remote_event_2, state=state_before_gap) diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py index 92cd0dfc05..8dfaa0559b 100644 --- a/tests/storage/test_purge.py +++ b/tests/storage/test_purge.py @@ -102,9 +102,10 @@ class PurgeTests(HomeserverTestCase): first = self.helper.send(self.room_id, body="test1") # Get the current room state. 
- state_handler = self.hs.get_state_handler() create_event = self.get_success( - state_handler.get_current_state(self.room_id, "m.room.create", "") + self._storage_controllers.state.get_current_state_event( + self.room_id, "m.room.create", "" + ) ) self.assertIsNotNone(create_event) diff --git a/tests/storage/test_room.py b/tests/storage/test_room.py index d497a19f63..3c79dabc9f 100644 --- a/tests/storage/test_room.py +++ b/tests/storage/test_room.py @@ -72,7 +72,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): # Room events need the full datastore, for persist_event() and # get_room_state() self.store = hs.get_datastores().main - self._storage = hs.get_storage_controllers() + self._storage_controllers = hs.get_storage_controllers() self.event_factory = hs.get_event_factory() self.room = RoomID.from_string("!abcde:test") @@ -88,7 +88,7 @@ class RoomEventsStoreTestCase(HomeserverTestCase): def inject_room_event(self, **kwargs): self.get_success( - self._storage.persistence.persist_event( + self._storage_controllers.persistence.persist_event( self.event_factory.create_event(room_id=self.room.to_string(), **kwargs) ) ) @@ -101,7 +101,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) @@ -118,7 +120,9 @@ class RoomEventsStoreTestCase(HomeserverTestCase): ) state = self.get_success( - self.store.get_current_state(room_id=self.room.to_string()) + self._storage_controllers.state.get_current_state( + room_id=self.room.to_string() + ) ) self.assertEqual(1, len(state)) diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py index 7f1964eb6a..5b60cf5285 100644 --- a/tests/storage/test_user_directory.py +++ b/tests/storage/test_user_directory.py @@ -134,7 +134,6 @@ class 
UserDirectoryInitialPopulationTestcase(HomeserverTestCase): def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: self.appservice = ApplicationService( token="i_am_an_app_service", - hostname="test", id="1234", namespaces={"users": [{"regex": r"@as_user.*", "exclusive": True}]}, sender="@as:test", diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py index 303e190b6c..cae14151c0 100644 --- a/tests/storage/util/test_partial_state_events_tracker.py +++ b/tests/storage/util/test_partial_state_events_tracker.py @@ -17,8 +17,12 @@ from unittest import mock from twisted.internet.defer import CancelledError, ensureDeferred -from synapse.storage.util.partial_state_events_tracker import PartialStateEventsTracker +from synapse.storage.util.partial_state_events_tracker import ( + PartialCurrentStateTracker, + PartialStateEventsTracker, +) +from tests.test_utils import make_awaitable from tests.unittest import TestCase @@ -115,3 +119,56 @@ class PartialStateEventsTrackerTestCase(TestCase): self.tracker.notify_un_partial_stated("event1") self.successResultOf(d2) + + +class PartialCurrentStateTrackerTestCase(TestCase): + def setUp(self) -> None: + self.mock_store = mock.Mock(spec_set=["is_partial_state_room"]) + + self.tracker = PartialCurrentStateTracker(self.mock_store) + + def test_does_not_block_for_full_state_rooms(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(False) + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_blocks_for_partial_room_state(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d = ensureDeferred(self.tracker.await_full_state("room_id")) + + # there should be no result yet + self.assertNoResult(d) + + # notifying that the room has been de-partial-stated should unblock + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d) + + 
def test_un_partial_state_race(self): + # We should correctly handle race between awaiting the state and us + # un-partialling the state + async def is_partial_state_room(events): + self.tracker.notify_un_partial_stated("room_id") + return True + + self.mock_store.is_partial_state_room.side_effect = is_partial_state_room + + self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id"))) + + def test_cancellation(self): + self.mock_store.is_partial_state_room.return_value = make_awaitable(True) + + d1 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d1) + + d2 = ensureDeferred(self.tracker.await_full_state("room_id")) + self.assertNoResult(d2) + + d1.cancel() + self.assertFailure(d1, CancelledError) + + # d2 should still be waiting! + self.assertNoResult(d2) + + self.tracker.notify_un_partial_stated("room_id") + self.successResultOf(d2) diff --git a/tests/test_mau.py b/tests/test_mau.py index 5bbc361aa2..f14fcb7db9 100644 --- a/tests/test_mau.py +++ b/tests/test_mau.py @@ -105,7 +105,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender:test", namespaces={"users": [{"regex": "@as_*", "exclusive": True}]}, @@ -251,7 +250,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_1, - hostname=self.hs.hostname, id="SomeASID", sender="@as_sender_1:test", namespaces={"users": [{"regex": "@as_1.*", "exclusive": True}]}, @@ -262,7 +260,6 @@ class TestMauLimit(unittest.HomeserverTestCase): self.store.services_cache.append( ApplicationService( token=as_token_2, - hostname=self.hs.hostname, id="AnotherASID", sender="@as_sender_2:test", namespaces={"users": [{"regex": "@as_2.*", "exclusive": True}]},