
Compare commits


7 Commits

Author SHA1 Message Date
Patrick Cloke
b685c5e7f1 Move note above changes. 2021-01-27 11:02:04 -05:00
Patrick Cloke
71c46652a2 Copy the upgrade note to 1.26.0. 2021-01-27 10:52:45 -05:00
Patrick Cloke
73ed289bd2 1.26.0 2021-01-27 10:50:37 -05:00
Patrick Cloke
69961c7e9f Tweak changes. 2021-01-25 08:26:42 -05:00
Patrick Cloke
a01605c136 1.26.0rc2 2021-01-25 08:25:40 -05:00
Erik Johnston
056327457f Fix chain cover update to handle events with duplicate auth events (#9210) 2021-01-22 19:44:08 +00:00
Erik Johnston
28f255d5f3 Bump psycopg2 version (#9204)
As we use `execute_values` with the `fetch` parameter.
2021-01-22 11:14:49 +00:00
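The commit message above is terse; as background, here is a minimal standalone sketch (not Synapse code; the table and DSN are made up) of what `execute_values` with the `fetch` parameter does, and why it forces psycopg2 >= 2.8:

```python
import psycopg2
from psycopg2.extras import execute_values

conn = psycopg2.connect("dbname=synapse_demo")  # hypothetical DSN
with conn, conn.cursor() as cur:
    rows = [(1, "alice"), (2, "bob")]
    # fetch=True collects the rows produced by the RETURNING clause; the
    # parameter only exists from psycopg2 2.8, hence the version bump.
    returned = execute_values(
        cur,
        "INSERT INTO demo_users (id, name) VALUES %s RETURNING id",
        rows,
        fetch=True,
    )
    print(returned)  # e.g. [(1,), (2,)]
```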
45 changed files with 326 additions and 721 deletions

View File

@@ -9,3 +9,5 @@ apt-get update
 apt-get install -y python3.5 python3.5-dev python3-pip libxml2-dev libxslt-dev xmlsec1 zlib1g-dev tox
 export LANG="C.UTF-8"
+exec tox -e py35-old,combine

View File

@@ -1,9 +1,36 @@
+Synapse 1.26.0 (2021-01-27)
+===========================
+
+This release brings a new schema version for Synapse and rolling back to a previous
+version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
+on these changes and for general upgrade guidance.
+
+No significant changes since 1.26.0rc2.
+
+
+Synapse 1.26.0rc2 (2021-01-25)
+==============================
+
+Bugfixes
+--------
+
+- Fix receipts and account data not being sent down sync. Introduced in v1.26.0rc1. ([\#9193](https://github.com/matrix-org/synapse/issues/9193), [\#9195](https://github.com/matrix-org/synapse/issues/9195))
+- Fix chain cover update to handle events with duplicate auth events. Introduced in v1.26.0rc1. ([\#9210](https://github.com/matrix-org/synapse/issues/9210))
+
+
+Internal Changes
+----------------
+
+- Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration. ([\#9189](https://github.com/matrix-org/synapse/issues/9189))
+- Bump minimum `psycopg2` version to v2.8. ([\#9204](https://github.com/matrix-org/synapse/issues/9204))
+
+
 Synapse 1.26.0rc1 (2021-01-20)
 ==============================
 
 This release brings a new schema version for Synapse and rolling back to a previous
 version is not trivial. Please review [UPGRADE.rst](UPGRADE.rst) for more details
 on these changes and for general upgrade guidance.
 
 Features
 --------

View File

@@ -1 +0,0 @@
Add tests to `test_user.UsersListTestCase` for List Users Admin API.

View File

@@ -1 +0,0 @@
Various improvements to the federation client.

View File

@@ -1 +0,0 @@
Add link to Matrix VoIP tester for turn-howto.

View File

@@ -1 +0,0 @@
Fix a long-standing bug where Synapse would return a 500 error when a thumbnail did not exist (and auto-generation of thumbnails was not enabled).

View File

@@ -1 +0,0 @@
Speed up chain cover calculation when persisting a batch of state events at once.

View File

@@ -1 +0,0 @@
Add a `long_description_type` to the package metadata.

View File

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

View File

@@ -1 +0,0 @@
Emit an error at startup if different Identity Providers are configured with the same `idp_id`.

View File

@@ -1 +0,0 @@
Speed up batch insertion when using PostgreSQL.

View File

@@ -1 +0,0 @@
Add an `oidc-` prefix to any `idp_id`s which are given in the `oidc_providers` configuration.

View File

@@ -1 +0,0 @@
Improve performance of concurrent use of `StreamIDGenerators`.

View File

@@ -1 +0,0 @@
Add some missing source directories to the automatic linting script.

View File

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

View File

@@ -1 +0,0 @@
Fix receipts or account data not being sent down sync. Introduced in v1.26.0rc1.

View File

@@ -232,12 +232,6 @@ Here are a few things to try:
   (Understanding the output is beyond the scope of this document!)
 
-* You can test your Matrix homeserver TURN setup with https://test.voip.librepush.net/.
-  Note that this test is not fully reliable yet, so don't be discouraged if
-  the test fails.
-  [Here](https://github.com/matrix-org/voip-tester) is the github repo of the
-  source of the tester, where you can file bug reports.
-
 * There is a WebRTC test tool at
   https://webrtc.github.io/samples/src/content/peerconnection/trickle-ice/. To
   use it, you will need a username/password for your TURN server. You can

View File

@@ -80,8 +80,7 @@ else
   # then lint everything!
   if [[ -z ${files+x} ]]; then
     # Lint all source code files and directories
-    # Note: this list aims the mirror the one in tox.ini
-    files=("synapse" "docker" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark" "stubs" ".buildkite")
+    files=("synapse" "tests" "scripts-dev" "scripts" "contrib" "synctl" "setup.py" "synmark")
   fi
 fi

View File

@@ -121,7 +121,6 @@ setup(
     include_package_data=True,
     zip_safe=False,
     long_description=long_description,
-    long_description_content_type="text/x-rst",
     python_requires="~=3.5",
     classifiers=[
         "Development Status :: 5 - Production/Stable",

View File

@@ -48,7 +48,7 @@ try:
 except ImportError:
     pass
 
-__version__ = "1.26.0rc1"
+__version__ = "1.26.0"
 
 if bool(os.environ.get("SYNAPSE_TEST_PATCH_LOG_CONTEXTS", False)):
     # We import here so that we don't have to install a bunch of deps when

View File

@@ -15,7 +15,6 @@
 # limitations under the License.
 
 import string
-from collections import Counter
 from typing import Iterable, Optional, Tuple, Type
 
 import attr
@@ -44,16 +43,6 @@ class OIDCConfig(Config):
         except DependencyException as e:
             raise ConfigError(e.message) from e
 
-        # check we don't have any duplicate idp_ids now. (The SSO handler will also
-        # check for duplicates when the REST listeners get registered, but that happens
-        # after synapse has forked so doesn't give nice errors.)
-        c = Counter([i.idp_id for i in self.oidc_providers])
-        for idp_id, count in c.items():
-            if count > 1:
-                raise ConfigError(
-                    "Multiple OIDC providers have the idp_id %r." % idp_id
-                )
-
         public_baseurl = self.public_baseurl
         self.oidc_callback_url = public_baseurl + "_synapse/oidc/callback"
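For readers skimming the hunk above: the block on the left-hand side is a config-time guard against duplicate `idp_id`s. A standalone sketch of the same idea (simplified, not the exact Synapse code):

```python
from collections import Counter

class ConfigError(Exception):
    pass

def check_unique_idp_ids(idp_ids):
    # Fail fast at config-parse time so the error is reported before the
    # process forks, rather than when the REST listeners get registered.
    for idp_id, count in Counter(idp_ids).items():
        if count > 1:
            raise ConfigError("Multiple OIDC providers have the idp_id %r." % idp_id)

check_unique_idp_ids(["oidc-github", "oidc-google"])    # passes silently
# check_unique_idp_ids(["oidc-github", "oidc-github"])  # would raise ConfigError
```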

View File

@@ -18,7 +18,6 @@ import copy
import itertools import itertools
import logging import logging
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
@@ -27,6 +26,7 @@ from typing import (
List, List,
Mapping, Mapping,
Optional, Optional,
Sequence,
Tuple, Tuple,
TypeVar, TypeVar,
Union, Union,
@@ -61,9 +61,6 @@ from synapse.util import unwrapFirstError
from synapse.util.caches.expiringcache import ExpiringCache from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.retryutils import NotRetryingDestination from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"]) sent_queries_counter = Counter("synapse_federation_client_sent_queries", "", ["type"])
@@ -83,10 +80,10 @@ class InvalidResponseError(RuntimeError):
class FederationClient(FederationBase): class FederationClient(FederationBase):
def __init__(self, hs: "HomeServer"): def __init__(self, hs):
super().__init__(hs) super().__init__(hs)
self.pdu_destination_tried = {} # type: Dict[str, Dict[str, int]] self.pdu_destination_tried = {}
self._clock.looping_call(self._clear_tried_cache, 60 * 1000) self._clock.looping_call(self._clear_tried_cache, 60 * 1000)
self.state = hs.get_state_handler() self.state = hs.get_state_handler()
self.transport_layer = hs.get_federation_transport_client() self.transport_layer = hs.get_federation_transport_client()
@@ -119,32 +116,33 @@ class FederationClient(FederationBase):
self.pdu_destination_tried[event_id] = destination_dict self.pdu_destination_tried[event_id] = destination_dict
@log_function @log_function
async def make_query( def make_query(
self, self,
destination: str, destination,
query_type: str, query_type,
args: dict, args,
retry_on_dns_fail: bool = False, retry_on_dns_fail=False,
ignore_backoff: bool = False, ignore_backoff=False,
) -> JsonDict: ):
"""Sends a federation Query to a remote homeserver of the given type """Sends a federation Query to a remote homeserver of the given type
and arguments. and arguments.
Args: Args:
destination: Domain name of the remote homeserver destination (str): Domain name of the remote homeserver
query_type: Category of the query type; should match the query_type (str): Category of the query type; should match the
handler name used in register_query_handler(). handler name used in register_query_handler().
args: Mapping of strings to strings containing the details args (dict): Mapping of strings to strings containing the details
of the query request. of the query request.
ignore_backoff: true to ignore the historical backoff data ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway. and try the request anyway.
Returns: Returns:
The JSON object from the response a Awaitable which will eventually yield a JSON object from the
response
""" """
sent_queries_counter.labels(query_type).inc() sent_queries_counter.labels(query_type).inc()
return await self.transport_layer.make_query( return self.transport_layer.make_query(
destination, destination,
query_type, query_type,
args, args,
@@ -153,52 +151,42 @@ class FederationClient(FederationBase):
) )
@log_function @log_function
async def query_client_keys( def query_client_keys(self, destination, content, timeout):
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Query device keys for a device hosted on a remote server. """Query device keys for a device hosted on a remote server.
Args: Args:
destination: Domain name of the remote homeserver destination (str): Domain name of the remote homeserver
content: The query content. content (dict): The query content.
Returns: Returns:
The JSON object from the response an Awaitable which will eventually yield a JSON object from the
response
""" """
sent_queries_counter.labels("client_device_keys").inc() sent_queries_counter.labels("client_device_keys").inc()
return await self.transport_layer.query_client_keys( return self.transport_layer.query_client_keys(destination, content, timeout)
destination, content, timeout
)
@log_function @log_function
async def query_user_devices( def query_user_devices(self, destination, user_id, timeout=30000):
self, destination: str, user_id: str, timeout: int = 30000
) -> JsonDict:
"""Query the device keys for a list of user ids hosted on a remote """Query the device keys for a list of user ids hosted on a remote
server. server.
""" """
sent_queries_counter.labels("user_devices").inc() sent_queries_counter.labels("user_devices").inc()
return await self.transport_layer.query_user_devices( return self.transport_layer.query_user_devices(destination, user_id, timeout)
destination, user_id, timeout
)
@log_function @log_function
async def claim_client_keys( def claim_client_keys(self, destination, content, timeout):
self, destination: str, content: JsonDict, timeout: int
) -> JsonDict:
"""Claims one-time keys for a device hosted on a remote server. """Claims one-time keys for a device hosted on a remote server.
Args: Args:
destination: Domain name of the remote homeserver destination (str): Domain name of the remote homeserver
content: The query content. content (dict): The query content.
Returns: Returns:
The JSON object from the response an Awaitable which will eventually yield a JSON object from the
response
""" """
sent_queries_counter.labels("client_one_time_keys").inc() sent_queries_counter.labels("client_one_time_keys").inc()
return await self.transport_layer.claim_client_keys( return self.transport_layer.claim_client_keys(destination, content, timeout)
destination, content, timeout
)
async def backfill( async def backfill(
self, dest: str, room_id: str, limit: int, extremities: Iterable[str] self, dest: str, room_id: str, limit: int, extremities: Iterable[str]
@@ -207,10 +195,10 @@ class FederationClient(FederationBase):
given destination server. given destination server.
Args: Args:
dest: The remote homeserver to ask. dest (str): The remote homeserver to ask.
room_id: The room_id to backfill. room_id (str): The room_id to backfill.
limit: The maximum number of events to return. limit (int): The maximum number of events to return.
extremities: our current backwards extremities, to backfill from extremities (list): our current backwards extremities, to backfill from
""" """
logger.debug("backfill extrem=%s", extremities) logger.debug("backfill extrem=%s", extremities)
@@ -382,7 +370,7 @@ class FederationClient(FederationBase):
for events that have failed their checks for events that have failed their checks
Returns: Returns:
A list of PDUs that have valid signatures and hashes. Deferred : A list of PDUs that have valid signatures and hashes.
""" """
deferreds = self._check_sigs_and_hashes(room_version, pdus) deferreds = self._check_sigs_and_hashes(room_version, pdus)
@@ -430,9 +418,7 @@ class FederationClient(FederationBase):
else: else:
return [p for p in valid_pdus if p] return [p for p in valid_pdus if p]
async def get_event_auth( async def get_event_auth(self, destination, room_id, event_id):
self, destination: str, room_id: str, event_id: str
) -> List[EventBase]:
res = await self.transport_layer.get_event_auth(destination, room_id, event_id) res = await self.transport_layer.get_event_auth(destination, room_id, event_id)
room_version = await self.store.get_room_version(room_id) room_version = await self.store.get_room_version(room_id)
@@ -714,16 +700,18 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request) return await self._try_destination_list("send_join", destinations, send_request)
async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict: async def _do_send_join(self, destination: str, pdu: EventBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
return await self.transport_layer.send_join_v2( content = await self.transport_layer.send_join_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@@ -781,7 +769,7 @@ class FederationClient(FederationBase):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
return await self.transport_layer.send_invite_v2( content = await self.transport_layer.send_invite_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
@@ -791,6 +779,7 @@ class FederationClient(FederationBase):
"invite_room_state": pdu.unsigned.get("invite_room_state", []), "invite_room_state": pdu.unsigned.get("invite_room_state", []),
}, },
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@@ -853,16 +842,18 @@ class FederationClient(FederationBase):
"send_leave", destinations, send_request "send_leave", destinations, send_request
) )
async def _do_send_leave(self, destination: str, pdu: EventBase) -> JsonDict: async def _do_send_leave(self, destination, pdu):
time_now = self._clock.time_msec() time_now = self._clock.time_msec()
try: try:
return await self.transport_layer.send_leave_v2( content = await self.transport_layer.send_leave_v2(
destination=destination, destination=destination,
room_id=pdu.room_id, room_id=pdu.room_id,
event_id=pdu.event_id, event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now), content=pdu.get_pdu_json(time_now),
) )
return content
except HttpResponseException as e: except HttpResponseException as e:
if e.code in [400, 404]: if e.code in [400, 404]:
err = e.to_synapse_error() err = e.to_synapse_error()
@@ -888,7 +879,7 @@ class FederationClient(FederationBase):
# content. # content.
return resp[1] return resp[1]
async def get_public_rooms( def get_public_rooms(
self, self,
remote_server: str, remote_server: str,
limit: Optional[int] = None, limit: Optional[int] = None,
@@ -896,7 +887,7 @@ class FederationClient(FederationBase):
search_filter: Optional[Dict] = None, search_filter: Optional[Dict] = None,
include_all_networks: bool = False, include_all_networks: bool = False,
third_party_instance_id: Optional[str] = None, third_party_instance_id: Optional[str] = None,
) -> JsonDict: ):
"""Get the list of public rooms from a remote homeserver """Get the list of public rooms from a remote homeserver
Args: Args:
@@ -910,7 +901,8 @@ class FederationClient(FederationBase):
party instance party instance
Returns: Returns:
The response from the remote server. Awaitable[Dict[str, Any]]: The response from the remote server, or None if
`remote_server` is the same as the local server_name
Raises: Raises:
HttpResponseException: There was an exception returned from the remote server HttpResponseException: There was an exception returned from the remote server
@@ -918,7 +910,7 @@ class FederationClient(FederationBase):
requests over federation requests over federation
""" """
return await self.transport_layer.get_public_rooms( return self.transport_layer.get_public_rooms(
remote_server, remote_server,
limit, limit,
since_token, since_token,
@@ -931,7 +923,7 @@ class FederationClient(FederationBase):
self, self,
destination: str, destination: str,
room_id: str, room_id: str,
earliest_events_ids: Iterable[str], earliest_events_ids: Sequence[str],
latest_events: Iterable[EventBase], latest_events: Iterable[EventBase],
limit: int, limit: int,
min_depth: int, min_depth: int,
@@ -982,9 +974,7 @@ class FederationClient(FederationBase):
return signed_events return signed_events
async def forward_third_party_invite( async def forward_third_party_invite(self, destinations, room_id, event_dict):
self, destinations: Iterable[str], room_id: str, event_dict: JsonDict
) -> None:
for destination in destinations: for destination in destinations:
if destination == self.server_name: if destination == self.server_name:
continue continue
@@ -993,7 +983,7 @@ class FederationClient(FederationBase):
await self.transport_layer.exchange_third_party_invite( await self.transport_layer.exchange_third_party_invite(
destination=destination, room_id=room_id, event_dict=event_dict destination=destination, room_id=room_id, event_dict=event_dict
) )
return return None
except CodeMessageException: except CodeMessageException:
raise raise
except Exception as e: except Exception as e:
@@ -1005,7 +995,7 @@ class FederationClient(FederationBase):
async def get_room_complexity( async def get_room_complexity(
self, destination: str, room_id: str self, destination: str, room_id: str
) -> Optional[JsonDict]: ) -> Optional[dict]:
""" """
Fetch the complexity of a remote room from another server. Fetch the complexity of a remote room from another server.
@@ -1018,9 +1008,10 @@ class FederationClient(FederationBase):
could not fetch the complexity. could not fetch the complexity.
""" """
try: try:
return await self.transport_layer.get_room_complexity( complexity = await self.transport_layer.get_room_complexity(
destination=destination, room_id=room_id destination=destination, room_id=room_id
) )
return complexity
except CodeMessageException as e: except CodeMessageException as e:
# We didn't manage to get it -- probably a 404. We are okay if other # We didn't manage to get it -- probably a 404. We are okay if other
# servers don't give it to us. # servers don't give it to us.

View File

@@ -86,8 +86,8 @@ REQUIREMENTS = [
 CONDITIONAL_REQUIREMENTS = {
     "matrix-synapse-ldap3": ["matrix-synapse-ldap3>=0.1"],
-    # we use execute_batch, which arrived in psycopg 2.7.
-    "postgres": ["psycopg2>=2.7"],
+    # we use execute_values with the fetch param, which arrived in psycopg 2.8.
+    "postgres": ["psycopg2>=2.8"],
     # ACME support is required to provision TLS certificates from authorities
     # that use the protocol, such as Let's Encrypt.
     "acme": [

View File

@@ -83,32 +83,17 @@ class UsersRestServletV2(RestServlet):
     The parameter `deactivated` can be used to include deactivated users.
     """
 
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs):
         self.hs = hs
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
         self.admin_handler = hs.get_admin_handler()
 
-    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+    async def on_GET(self, request):
         await assert_requester_is_admin(self.auth, request)
 
         start = parse_integer(request, "from", default=0)
         limit = parse_integer(request, "limit", default=100)
 
-        if start < 0:
-            raise SynapseError(
-                400,
-                "Query parameter from must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
-        if limit < 0:
-            raise SynapseError(
-                400,
-                "Query parameter limit must be a string representing a positive integer.",
-                errcode=Codes.INVALID_PARAM,
-            )
-
         user_id = parse_string(request, "user_id", default=None)
         name = parse_string(request, "name", default=None)
         guests = parse_boolean(request, "guests", default=True)
@@ -118,7 +103,7 @@ class UsersRestServletV2(RestServlet):
             start, limit, user_id, name, guests, deactivated
         )
         ret = {"users": users, "total": total}
-        if (start + limit) < total:
+        if len(users) >= limit:
             ret["next_token"] = str(start + len(users))
 
         return 200, ret
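The two `next_token` conditions in the last hunk differ in one edge case (toy illustration below, not the servlet code): the `(start + limit) < total` form only emits a token when rows remain beyond the current page, so a response whose count lands exactly on the limit does not advertise a pointless empty next page.

```python
def page(total, start, limit, users_returned):
    ret = {"users": users_returned, "total": total}
    if (start + limit) < total:
        ret["next_token"] = str(start + len(users_returned))
    return ret

print(page(100, 90, 10, list(range(10))))  # exact last page: no next_token
print(page(101, 90, 10, list(range(10))))  # rows remain: next_token == "100"
```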

View File

@@ -300,7 +300,6 @@ class FileInfo:
         thumbnail_height (int)
         thumbnail_method (str)
         thumbnail_type (str): Content type of thumbnail, e.g. image/png
-        thumbnail_length (int): The size of the media file, in bytes.
     """
 
     def __init__(
@@ -313,7 +312,6 @@
         thumbnail_height=None,
         thumbnail_method=None,
         thumbnail_type=None,
-        thumbnail_length=None,
     ):
         self.server_name = server_name
         self.file_id = file_id
@@ -323,7 +321,6 @@
         self.thumbnail_height = thumbnail_height
         self.thumbnail_method = thumbnail_method
         self.thumbnail_type = thumbnail_type
-        self.thumbnail_length = thumbnail_length
 
 def get_filename_from_headers(headers: Dict[bytes, List[bytes]]) -> Optional[str]:

View File

@@ -16,7 +16,7 @@
import logging import logging
from typing import TYPE_CHECKING, Any, Dict, List, Optional from typing import TYPE_CHECKING
from twisted.web.http import Request from twisted.web.http import Request
@@ -106,17 +106,31 @@ class ThumbnailResource(DirectServeJsonResource):
return return
thumbnail_infos = await self.store.get_local_media_thumbnails(media_id) thumbnail_infos = await self.store.get_local_media_thumbnails(media_id)
await self._select_and_respond_with_thumbnail(
request, if thumbnail_infos:
width, thumbnail_info = self._select_thumbnail(
height, width, height, method, m_type, thumbnail_infos
method, )
m_type,
thumbnail_infos, file_info = FileInfo(
media_id, server_name=None,
url_cache=media_info["url_cache"], file_id=media_id,
server_name=None, url_cache=media_info["url_cache"],
) thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder(request, responder, t_type, t_length)
else:
logger.info("Couldn't find any generated thumbnails")
respond_404(request)
async def _select_or_generate_local_thumbnail( async def _select_or_generate_local_thumbnail(
self, self,
@@ -262,64 +276,26 @@ class ThumbnailResource(DirectServeJsonResource):
thumbnail_infos = await self.store.get_remote_media_thumbnails( thumbnail_infos = await self.store.get_remote_media_thumbnails(
server_name, media_id server_name, media_id
) )
await self._select_and_respond_with_thumbnail(
request,
width,
height,
method,
m_type,
thumbnail_infos,
media_info["filesystem_id"],
url_cache=None,
server_name=server_name,
)
async def _select_and_respond_with_thumbnail(
self,
request: Request,
desired_width: int,
desired_height: int,
desired_method: str,
desired_type: str,
thumbnail_infos: List[Dict[str, Any]],
file_id: str,
url_cache: Optional[str] = None,
server_name: Optional[str] = None,
) -> None:
"""
Respond to a request with an appropriate thumbnail from the previously generated thumbnails.
Args:
request: The incoming request.
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
"""
if thumbnail_infos: if thumbnail_infos:
file_info = self._select_thumbnail( thumbnail_info = self._select_thumbnail(
desired_width, width, height, method, m_type, thumbnail_infos
desired_height,
desired_method,
desired_type,
thumbnail_infos,
file_id,
url_cache,
server_name,
) )
if not file_info: file_info = FileInfo(
logger.info("Couldn't find a thumbnail matching the desired inputs") server_name=server_name,
respond_404(request) file_id=media_info["filesystem_id"],
return thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
)
t_type = file_info.thumbnail_type
t_length = thumbnail_info["thumbnail_length"]
responder = await self.media_storage.fetch_media(file_info) responder = await self.media_storage.fetch_media(file_info)
await respond_with_responder( await respond_with_responder(request, responder, t_type, t_length)
request, responder, file_info.thumbnail_type, file_info.thumbnail_length
)
else: else:
logger.info("Failed to find any generated thumbnails") logger.info("Failed to find any generated thumbnails")
respond_404(request) respond_404(request)
@@ -330,117 +306,67 @@ class ThumbnailResource(DirectServeJsonResource):
desired_height: int, desired_height: int,
desired_method: str, desired_method: str,
desired_type: str, desired_type: str,
thumbnail_infos: List[Dict[str, Any]], thumbnail_infos,
file_id: str, ) -> dict:
url_cache: Optional[str],
server_name: Optional[str],
) -> Optional[FileInfo]:
"""
Choose an appropriate thumbnail from the previously generated thumbnails.
Args:
desired_width: The desired width, the returned thumbnail may be larger than this.
desired_height: The desired height, the returned thumbnail may be larger than this.
desired_method: The desired method used to generate the thumbnail.
desired_type: The desired content-type of the thumbnail.
thumbnail_infos: A list of dictionaries of candidate thumbnails.
file_id: The ID of the media that a thumbnail is being requested for.
url_cache: The URL cache value.
server_name: The server name, if this is a remote thumbnail.
Returns:
The thumbnail which best matches the desired parameters.
"""
desired_method = desired_method.lower()
# The chosen thumbnail.
thumbnail_info = None
d_w = desired_width d_w = desired_width
d_h = desired_height d_h = desired_height
if desired_method == "crop": if desired_method.lower() == "crop":
# Thumbnails that match equal or larger sizes of desired width/height.
crop_info_list = [] crop_info_list = []
# Other thumbnails.
crop_info_list2 = [] crop_info_list2 = []
for info in thumbnail_infos: for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "crop":
continue
t_w = info["thumbnail_width"] t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"] t_h = info["thumbnail_height"]
aspect_quality = abs(d_w * t_h - d_h * t_w) t_method = info["thumbnail_method"]
min_quality = 0 if d_w <= t_w and d_h <= t_h else 1 if t_method == "crop":
size_quality = abs((d_w - t_w) * (d_h - t_h)) aspect_quality = abs(d_w * t_h - d_h * t_w)
type_quality = desired_type != info["thumbnail_type"] min_quality = 0 if d_w <= t_w and d_h <= t_h else 1
length_quality = info["thumbnail_length"] size_quality = abs((d_w - t_w) * (d_h - t_h))
if t_w >= d_w or t_h >= d_h: type_quality = desired_type != info["thumbnail_type"]
crop_info_list.append( length_quality = info["thumbnail_length"]
( if t_w >= d_w or t_h >= d_h:
aspect_quality, crop_info_list.append(
min_quality, (
size_quality, aspect_quality,
type_quality, min_quality,
length_quality, size_quality,
info, type_quality,
length_quality,
info,
)
) )
) else:
else: crop_info_list2.append(
crop_info_list2.append( (
( aspect_quality,
aspect_quality, min_quality,
min_quality, size_quality,
size_quality, type_quality,
type_quality, length_quality,
length_quality, info,
info, )
) )
)
if crop_info_list: if crop_info_list:
thumbnail_info = min(crop_info_list)[-1] return min(crop_info_list)[-1]
elif crop_info_list2: else:
thumbnail_info = min(crop_info_list2)[-1] return min(crop_info_list2)[-1]
elif desired_method == "scale": else:
# Thumbnails that match equal or larger sizes of desired width/height.
info_list = [] info_list = []
# Other thumbnails.
info_list2 = [] info_list2 = []
for info in thumbnail_infos: for info in thumbnail_infos:
# Skip thumbnails generated with different methods.
if info["thumbnail_method"] != "scale":
continue
t_w = info["thumbnail_width"] t_w = info["thumbnail_width"]
t_h = info["thumbnail_height"] t_h = info["thumbnail_height"]
t_method = info["thumbnail_method"]
size_quality = abs((d_w - t_w) * (d_h - t_h)) size_quality = abs((d_w - t_w) * (d_h - t_h))
type_quality = desired_type != info["thumbnail_type"] type_quality = desired_type != info["thumbnail_type"]
length_quality = info["thumbnail_length"] length_quality = info["thumbnail_length"]
if t_w >= d_w or t_h >= d_h: if t_method == "scale" and (t_w >= d_w or t_h >= d_h):
info_list.append((size_quality, type_quality, length_quality, info)) info_list.append((size_quality, type_quality, length_quality, info))
else: elif t_method == "scale":
info_list2.append( info_list2.append(
(size_quality, type_quality, length_quality, info) (size_quality, type_quality, length_quality, info)
) )
if info_list: if info_list:
thumbnail_info = min(info_list)[-1] return min(info_list)[-1]
elif info_list2: else:
thumbnail_info = min(info_list2)[-1] return min(info_list2)[-1]
if thumbnail_info:
return FileInfo(
file_id=file_id,
url_cache=url_cache,
server_name=server_name,
thumbnail=True,
thumbnail_width=thumbnail_info["thumbnail_width"],
thumbnail_height=thumbnail_info["thumbnail_height"],
thumbnail_type=thumbnail_info["thumbnail_type"],
thumbnail_method=thumbnail_info["thumbnail_method"],
thumbnail_length=thumbnail_info["thumbnail_length"],
)
# No matching thumbnail was found.
return None

View File

@@ -262,18 +262,13 @@ class LoggingTransaction:
         return self.txn.description
 
     def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
-        """Similar to `executemany`, except `txn.rowcount` will not be correct
-        afterwards.
-        More efficient than `executemany` on PostgreSQL
-        """
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
-            self.executemany(sql, args)
+            for val in args:
+                self.execute(sql, val)
 
     def execute_values(self, sql: str, *args: Any) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
@@ -893,7 +888,7 @@ class DatabasePool:
             ", ".join("?" for _ in keys[0]),
         )
 
-        txn.execute_batch(sql, vals)
+        txn.executemany(sql, vals)
 
     async def simple_upsert(
         self,
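As background for the `execute_batch` fallback shown above, here is a minimal sketch (made-up table and DSN, not Synapse code) of psycopg2's `execute_batch`, which groups many parameter sets into far fewer round-trips than a plain per-row loop; note that `cursor.rowcount` afterwards only reflects the final batch:

```python
import psycopg2
from psycopg2.extras import execute_batch

conn = psycopg2.connect("dbname=synapse_demo")  # hypothetical DSN
with conn, conn.cursor() as cur:
    args = [(i, "name-%d" % i) for i in range(10_000)]
    # Sends the statements in pages of 1000 instead of one network
    # round-trip per row, which is where the speed-up comes from.
    execute_batch(
        cur,
        "INSERT INTO demo_users (id, name) VALUES (%s, %s)",
        args,
        page_size=1000,
    )
```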

View File

@@ -897,7 +897,7 @@ class DeviceWorkerStore(SQLBaseStore):
                 DELETE FROM device_lists_outbound_last_success
                 WHERE destination = ? AND user_id = ?
             """
-            txn.execute_batch(sql, ((row[0], row[1]) for row in rows))
+            txn.executemany(sql, ((row[0], row[1]) for row in rows))
 
             logger.info("Pruned %d device list outbound pokes", count)
 
@@ -1343,7 +1343,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         # Delete older entries in the table, as we really only care about
         # when the latest change happened.
-        txn.execute_batch(
+        txn.executemany(
             """
                 DELETE FROM device_lists_stream
                 WHERE user_id = ? AND device_id = ? AND stream_id < ?

View File

@@ -487,7 +487,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             VALUES (?, ?, ?, ?, ?, ?)
         """
 
-        txn.execute_batch(
+        txn.executemany(
             sql,
             (
                 _gen_entry(user_id, actions)
@@ -803,7 +803,7 @@ class EventPushActionsWorkerStore(SQLBaseStore):
             ],
         )
 
-        txn.execute_batch(
+        txn.executemany(
             """
                 UPDATE event_push_summary
                 SET notif_count = ?, unread_count = ?, stream_ordering = ?

View File

@@ -473,9 +473,8 @@ class PersistEventsStore:
txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain, txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
) )
@classmethod @staticmethod
def _add_chain_cover_index( def _add_chain_cover_index(
cls,
txn, txn,
db_pool: DatabasePool, db_pool: DatabasePool,
event_to_room_id: Dict[str, str], event_to_room_id: Dict[str, str],
@@ -615,17 +614,60 @@ class PersistEventsStore:
if not events_to_calc_chain_id_for: if not events_to_calc_chain_id_for:
return return
# Allocate chain ID/sequence numbers to each new event. # We now calculate the chain IDs/sequence numbers for the events. We
new_chain_tuples = cls._allocate_chain_ids( # do this by looking at the chain ID and sequence number of any auth
txn, # event with the same type/state_key and incrementing the sequence
db_pool, # number by one. If there was no match or the chain ID/sequence
event_to_room_id, # number is already taken we generate a new chain.
event_to_types, #
event_to_auth_chain, # We need to do this in a topologically sorted order as we want to
events_to_calc_chain_id_for, # generate chain IDs/sequence numbers of an event's auth events
chain_map, # before the event itself.
) chains_tuples_allocated = set() # type: Set[Tuple[int, int]]
chain_map.update(new_chain_tuples) new_chain_tuples = {} # type: Dict[str, Tuple[int, int]]
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
existing_chain_id = None
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map[auth_id]
break
new_chain_tuple = None
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if (proposed_new_id, proposed_new_seq) not in chains_tuples_allocated:
already_allocated = db_pool.simple_select_one_onecol_txn(
txn,
table="event_auth_chains",
keyvalues={
"chain_id": proposed_new_id,
"sequence_number": proposed_new_seq,
},
retcol="event_id",
allow_none=True,
)
if already_allocated:
# Mark it as already allocated so we don't need to hit
# the DB again.
chains_tuples_allocated.add((proposed_new_id, proposed_new_seq))
else:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
if not new_chain_tuple:
new_chain_tuple = (db_pool.event_chain_id_gen.get_next_id_txn(txn), 1)
chains_tuples_allocated.add(new_chain_tuple)
chain_map[event_id] = new_chain_tuple
new_chain_tuples[event_id] = new_chain_tuple
db_pool.simple_insert_many_txn( db_pool.simple_insert_many_txn(
txn, txn,
@@ -752,137 +794,6 @@ class PersistEventsStore:
], ],
) )
@staticmethod
def _allocate_chain_ids(
txn,
db_pool: DatabasePool,
event_to_room_id: Dict[str, str],
event_to_types: Dict[str, Tuple[str, str]],
event_to_auth_chain: Dict[str, List[str]],
events_to_calc_chain_id_for: Set[str],
chain_map: Dict[str, Tuple[int, int]],
) -> Dict[str, Tuple[int, int]]:
"""Allocates, but does not persist, chain ID/sequence numbers for the
events in `events_to_calc_chain_id_for`. (c.f. _add_chain_cover_index
for info on args)
"""
# We now calculate the chain IDs/sequence numbers for the events. We do
# this by looking at the chain ID and sequence number of any auth event
# with the same type/state_key and incrementing the sequence number by
# one. If there was no match or the chain ID/sequence number is already
# taken we generate a new chain.
#
# We try to reduce the number of times that we hit the database by
# batching up calls, to make this more efficient when persisting large
# numbers of state events (e.g. during joins).
#
# We do this by:
# 1. Calculating for each event which auth event will be used to
# inherit the chain ID, i.e. converting the auth chain graph to a
# tree that we can allocate chains on. We also keep track of which
# existing chain IDs have been referenced.
# 2. Fetching the max allocated sequence number for each referenced
# existing chain ID, generating a map from chain ID to the max
# allocated sequence number.
# 3. Iterating over the tree and allocating a chain ID/seq no. to the
# new event, by incrementing the sequence number from the
# referenced event's chain ID/seq no. and checking that the
# incremented sequence number hasn't already been allocated (by
# looking in the map generated in the previous step). We generate a
# new chain if the sequence number has already been allocated.
#
existing_chains = set() # type: Set[int]
tree = [] # type: List[Tuple[str, Optional[str]]]
# We need to do this in a topologically sorted order as we want to
# generate chain IDs/sequence numbers of an event's auth events before
# the event itself.
for event_id in sorted_topologically(
events_to_calc_chain_id_for, event_to_auth_chain
):
for auth_id in event_to_auth_chain.get(event_id, []):
if event_to_types.get(event_id) == event_to_types.get(auth_id):
existing_chain_id = chain_map.get(auth_id)
if existing_chain_id:
existing_chains.add(existing_chain_id[0])
tree.append((event_id, auth_id))
break
else:
tree.append((event_id, None))
# Fetch the current max sequence number for each existing referenced chain.
sql = """
SELECT chain_id, MAX(sequence_number) FROM event_auth_chains
WHERE %s
GROUP BY chain_id
"""
clause, args = make_in_list_sql_clause(
db_pool.engine, "chain_id", existing_chains
)
txn.execute(sql % (clause,), args)
chain_to_max_seq_no = {row[0]: row[1] for row in txn} # type: Dict[Any, int]
# Allocate the new events chain ID/sequence numbers.
#
# To reduce the number of calls to the database we don't allocate a
# chain ID number in the loop, instead we use a temporary `object()` for
# each new chain ID. Once we've done the loop we generate the necessary
# number of new chain IDs in one call, replacing all temporary
# objects with real allocated chain IDs.
unallocated_chain_ids = set() # type: Set[object]
new_chain_tuples = {} # type: Dict[str, Tuple[Any, int]]
for event_id, auth_event_id in tree:
# If we reference an auth_event_id we fetch the allocated chain ID,
# either from the existing `chain_map` or the newly generated
# `new_chain_tuples` map.
existing_chain_id = None
if auth_event_id:
existing_chain_id = new_chain_tuples.get(auth_event_id)
if not existing_chain_id:
existing_chain_id = chain_map[auth_event_id]
new_chain_tuple = None # type: Optional[Tuple[Any, int]]
if existing_chain_id:
# We found a chain ID/sequence number candidate, check its
# not already taken.
proposed_new_id = existing_chain_id[0]
proposed_new_seq = existing_chain_id[1] + 1
if chain_to_max_seq_no[proposed_new_id] < proposed_new_seq:
new_chain_tuple = (
proposed_new_id,
proposed_new_seq,
)
# If we need to start a new chain we allocate a temporary chain ID.
if not new_chain_tuple:
new_chain_tuple = (object(), 1)
unallocated_chain_ids.add(new_chain_tuple[0])
new_chain_tuples[event_id] = new_chain_tuple
chain_to_max_seq_no[new_chain_tuple[0]] = new_chain_tuple[1]
# Generate new chain IDs for all unallocated chain IDs.
newly_allocated_chain_ids = db_pool.event_chain_id_gen.get_next_mult_txn(
txn, len(unallocated_chain_ids)
)
# Map from potentially temporary chain ID to real chain ID
chain_id_to_allocated_map = dict(
zip(unallocated_chain_ids, newly_allocated_chain_ids)
) # type: Dict[Any, int]
chain_id_to_allocated_map.update((c, c) for c in existing_chains)
return {
event_id: (chain_id_to_allocated_map[chain_id], seq)
for event_id, (chain_id, seq) in new_chain_tuples.items()
}
def _persist_transaction_ids_txn( def _persist_transaction_ids_txn(
self, self,
txn: LoggingTransaction, txn: LoggingTransaction,
@@ -965,7 +876,7 @@ class PersistEventsStore:
WHERE room_id = ? AND type = ? AND state_key = ? WHERE room_id = ? AND type = ? AND state_key = ?
) )
""" """
txn.execute_batch( txn.executemany(
sql, sql,
( (
( (
@@ -984,7 +895,7 @@ class PersistEventsStore:
) )
# Now we actually update the current_state_events table # Now we actually update the current_state_events table
txn.execute_batch( txn.executemany(
"DELETE FROM current_state_events" "DELETE FROM current_state_events"
" WHERE room_id = ? AND type = ? AND state_key = ?", " WHERE room_id = ? AND type = ? AND state_key = ?",
( (
@@ -996,7 +907,7 @@ class PersistEventsStore:
# We include the membership in the current state table, hence we do # We include the membership in the current state table, hence we do
# a lookup when we insert. This assumes that all events have already # a lookup when we insert. This assumes that all events have already
# been inserted into room_memberships. # been inserted into room_memberships.
txn.execute_batch( txn.executemany(
"""INSERT INTO current_state_events """INSERT INTO current_state_events
(room_id, type, state_key, event_id, membership) (room_id, type, state_key, event_id, membership)
VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) VALUES (?, ?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1016,7 +927,7 @@ class PersistEventsStore:
# we have no record of the fact the user *was* a member of the # we have no record of the fact the user *was* a member of the
# room but got, say, state reset out of it. # room but got, say, state reset out of it.
if to_delete or to_insert: if to_delete or to_insert:
txn.execute_batch( txn.executemany(
"DELETE FROM local_current_membership" "DELETE FROM local_current_membership"
" WHERE room_id = ? AND user_id = ?", " WHERE room_id = ? AND user_id = ?",
( (
@@ -1027,7 +938,7 @@ class PersistEventsStore:
) )
if to_insert: if to_insert:
txn.execute_batch( txn.executemany(
"""INSERT INTO local_current_membership """INSERT INTO local_current_membership
(room_id, user_id, event_id, membership) (room_id, user_id, event_id, membership)
VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?)) VALUES (?, ?, ?, (SELECT membership FROM room_memberships WHERE event_id = ?))
@@ -1827,7 +1738,7 @@ class PersistEventsStore:
""" """
if events_and_contexts: if events_and_contexts:
txn.execute_batch( txn.executemany(
sql, sql,
( (
( (
@@ -1856,7 +1767,7 @@ class PersistEventsStore:
# Now we delete the staging area for *all* events that were being # Now we delete the staging area for *all* events that were being
# persisted. # persisted.
txn.execute_batch( txn.executemany(
"DELETE FROM event_push_actions_staging WHERE event_id = ?", "DELETE FROM event_push_actions_staging WHERE event_id = ?",
((event.event_id,) for event, _ in all_events_and_contexts), ((event.event_id,) for event, _ in all_events_and_contexts),
) )
@@ -1975,7 +1886,7 @@ class PersistEventsStore:
" )" " )"
) )
txn.execute_batch( txn.executemany(
query, query,
[ [
(e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False) (e_id, ev.room_id, e_id, ev.room_id, e_id, ev.room_id, False)
@@ -1989,7 +1900,7 @@ class PersistEventsStore:
"DELETE FROM event_backward_extremities" "DELETE FROM event_backward_extremities"
" WHERE event_id = ? AND room_id = ?" " WHERE event_id = ? AND room_id = ?"
) )
txn.execute_batch( txn.executemany(
query, query,
[ [
(ev.event_id, ev.room_id) (ev.event_id, ev.room_id)

View File

@@ -139,6 +139,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
+        INSERT_CLUMP_SIZE = 1000
+
         def reindex_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id, json FROM events"
@@ -176,7 +178,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?"
 
-            txn.execute_batch(sql, update_rows)
+            for index in range(0, len(update_rows), INSERT_CLUMP_SIZE):
+                clump = update_rows[index : index + INSERT_CLUMP_SIZE]
+                txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,
@@ -206,6 +210,8 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
         max_stream_id = progress["max_stream_id_exclusive"]
         rows_inserted = progress.get("rows_inserted", 0)
 
+        INSERT_CLUMP_SIZE = 1000
+
        def reindex_search_txn(txn):
             sql = (
                 "SELECT stream_ordering, event_id FROM events"
@@ -250,7 +256,9 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             sql = "UPDATE events SET origin_server_ts = ? WHERE event_id = ?"
 
-            txn.execute_batch(sql, rows_to_update)
+            for index in range(0, len(rows_to_update), INSERT_CLUMP_SIZE):
+                clump = rows_to_update[index : index + INSERT_CLUMP_SIZE]
+                txn.executemany(sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,

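One side of the hunks above replaces a single `execute_batch` call with explicit clumping plus `executemany`. The clumping pattern itself is simple; a small illustrative sketch (not the store code):

```python
INSERT_CLUMP_SIZE = 1000

def clumps(rows, size=INSERT_CLUMP_SIZE):
    """Yield fixed-size slices so each executemany() call stays bounded."""
    for index in range(0, len(rows), size):
        yield rows[index : index + size]

rows = [(i,) for i in range(2500)]
print([len(c) for c in clumps(rows)])  # [1000, 1000, 500]
```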
View File

@@ -417,7 +417,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_origin = ? AND media_id = ?"
             )
 
-            txn.execute_batch(
+            txn.executemany(
                 sql,
                 (
                     (time_ms, media_origin, media_id)
@@ -430,7 +430,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
                 " WHERE media_id = ?"
             )
 
-            txn.execute_batch(sql, ((time_ms, media_id) for media_id in local_media))
+            txn.executemany(sql, ((time_ms, media_id) for media_id in local_media))
 
         return await self.db_pool.runInteraction(
             "update_cached_last_access_time", update_cache_txn
@@ -557,7 +557,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         sql = "DELETE FROM local_media_repository_url_cache WHERE media_id = ?"
 
         def _delete_url_cache_txn(txn):
-            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache", _delete_url_cache_txn
@@ -586,11 +586,11 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         def _delete_url_cache_media_txn(txn):
             sql = "DELETE FROM local_media_repository WHERE media_id = ?"
 
-            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
             sql = "DELETE FROM local_media_repository_thumbnails WHERE media_id = ?"
 
-            txn.execute_batch(sql, [(media_id,) for media_id in media_ids])
+            txn.executemany(sql, [(media_id,) for media_id in media_ids])
 
         return await self.db_pool.runInteraction(
             "delete_url_cache_media", _delete_url_cache_media_txn

View File

@@ -172,7 +172,7 @@ class PurgeEventsStore(StateGroupWorkerStore, SQLBaseStore):
         )
 
         # Update backward extremeties
-        txn.execute_batch(
+        txn.executemany(
             "INSERT INTO event_backward_extremities (room_id, event_id)"
             " VALUES (?, ?)",
             [(room_id, event_id) for event_id, in new_backwards_extrems],

View File

@@ -1104,7 +1104,7 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
                 FROM user_threepids
             """
 
-            txn.execute_batch(sql, [(id_server,) for id_server in id_servers])
+            txn.executemany(sql, [(id_server,) for id_server in id_servers])
 
         if id_servers:
             await self.db_pool.runInteraction(

View File

@@ -873,6 +873,8 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
             "max_stream_id_exclusive", self._stream_order_on_start + 1
         )
 
+        INSERT_CLUMP_SIZE = 1000
+
         def add_membership_profile_txn(txn):
             sql = """
                 SELECT stream_ordering, event_id, events.room_id, event_json.json
@@ -913,7 +915,9 @@ class RoomMemberBackgroundUpdateStore(SQLBaseStore):
                 UPDATE room_memberships SET display_name = ?, avatar_url = ?
                 WHERE event_id = ? AND room_id = ?
             """
-            txn.execute_batch(to_update_sql, to_update)
+            for index in range(0, len(to_update), INSERT_CLUMP_SIZE):
+                clump = to_update[index : index + INSERT_CLUMP_SIZE]
+                txn.executemany(to_update_sql, clump)
 
             progress = {
                 "target_min_stream_id_inclusive": target_min_stream_id,

View File

@@ -55,7 +55,7 @@ def run_create(cur: Cursor, database_engine: BaseDatabaseEngine, *args, **kwargs
         #  { "ignored_users": "@someone:example.org": {} }
         ignored_users = content.get("ignored_users", {})
         if isinstance(ignored_users, dict) and ignored_users:
-            cur.execute_batch(insert_sql, [(user_id, u) for u in ignored_users])
+            cur.executemany(insert_sql, [(user_id, u) for u in ignored_users])
 
     # Add indexes after inserting data for efficiency.
     logger.info("Adding constraints to ignored_users table")

View File

@@ -63,7 +63,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.execute_batch(sql, args)
+            txn.executemany(sql, args)
 
         elif isinstance(self.database_engine, Sqlite3Engine):
             sql = (
@@ -75,7 +75,7 @@ class SearchWorkerStore(SQLBaseStore):
                 for entry in entries
             )
 
-            txn.execute_batch(sql, args)
+            txn.executemany(sql, args)
 
         else:
             # This should be unreachable.
             raise Exception("Unrecognized database engine")

View File

@@ -565,11 +565,11 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             )
 
             logger.info("[purge] removing redundant state groups")
-            txn.execute_batch(
+            txn.executemany(
                 "DELETE FROM state_groups_state WHERE state_group = ?",
                 ((sg,) for sg in state_groups_to_delete),
             )
-            txn.execute_batch(
+            txn.executemany(
                 "DELETE FROM state_groups WHERE id = ?",
                 ((sg,) for sg in state_groups_to_delete),
             )

View File

@@ -15,11 +15,12 @@
import heapq
import logging
import threading
-from collections import OrderedDict
+from collections import deque
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple, Union
import attr
+from typing_extensions import Deque
from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.storage.database import DatabasePool, LoggingTransaction
@@ -100,13 +101,7 @@ class StreamIdGenerator:
self._current = (max if step > 0 else min)(
self._current, _load_current_id(db_conn, table, column, step)
)
+self._unfinished_ids = deque()  # type: Deque[int]
-# We use this as an ordered set, as we want to efficiently append items,
-# remove items and get the first item. Since we insert IDs in order, the
-# insertion ordering will ensure its in the correct ordering.
-#
-# The key and values are the same, but we never look at the values.
-self._unfinished_ids = OrderedDict()  # type: OrderedDict[int, int]
def get_next(self):
"""
@@ -118,7 +113,7 @@ class StreamIdGenerator:
self._current += self._step
next_id = self._current
-self._unfinished_ids[next_id] = next_id
+self._unfinished_ids.append(next_id)
@contextmanager
def manager():
@@ -126,7 +121,7 @@ class StreamIdGenerator:
yield next_id
finally:
with self._lock:
-self._unfinished_ids.pop(next_id)
+self._unfinished_ids.remove(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -145,7 +140,7 @@ class StreamIdGenerator:
self._current += n * self._step
for next_id in next_ids:
-self._unfinished_ids[next_id] = next_id
+self._unfinished_ids.append(next_id)
@contextmanager
def manager():
@@ -154,7 +149,7 @@ class StreamIdGenerator:
finally:
with self._lock:
for next_id in next_ids:
-self._unfinished_ids.pop(next_id)
+self._unfinished_ids.remove(next_id)
return _AsyncCtxManagerWrapper(manager())
@@ -167,7 +162,7 @@ class StreamIdGenerator:
"""
with self._lock:
if self._unfinished_ids:
-return next(iter(self._unfinished_ids)) - self._step
+return self._unfinished_ids[0] - self._step
return self._current
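Both versions of `StreamIdGenerator` keep the set of in-flight ("unfinished") stream IDs and report the current token as the position just before the oldest of them. The left-hand side uses an `OrderedDict` as an insertion-ordered set (O(1) removal of an ID that completes out of order); the right-hand side uses a `deque`, whose `remove` is a linear scan. A toy comparison, not the Synapse class:

```python
from collections import OrderedDict, deque

# OrderedDict as an ordered set: keys double as values, insertion order is
# ascending ID order, so the first key is always the oldest unfinished ID.
unfinished = OrderedDict()  # type: OrderedDict[int, int]
for stream_id in (7, 8, 9):
    unfinished[stream_id] = stream_id
unfinished.pop(8)                # an ID in the middle finishes first: O(1)
oldest = next(iter(unfinished))  # -> 7

# deque equivalent: append is O(1), but removing a middle entry scans the queue.
pending = deque()
for stream_id in (7, 8, 9):
    pending.append(stream_id)
pending.remove(8)
assert pending[0] == oldest == 7
```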

View File

@@ -69,11 +69,6 @@ class SequenceGenerator(metaclass=abc.ABCMeta):
"""Gets the next ID in the sequence""" """Gets the next ID in the sequence"""
... ...
@abc.abstractmethod
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
"""Get the next `n` IDs in the sequence"""
...
@abc.abstractmethod @abc.abstractmethod
def check_consistency( def check_consistency(
self, self,
@@ -224,17 +219,6 @@ class LocalSequenceGenerator(SequenceGenerator):
self._current_max_id += 1 self._current_max_id += 1
return self._current_max_id return self._current_max_id
def get_next_mult_txn(self, txn: Cursor, n: int) -> List[int]:
with self._lock:
if self._current_max_id is None:
assert self._callback is not None
self._current_max_id = self._callback(txn)
self._callback = None
first_id = self._current_max_id + 1
self._current_max_id += n
return [first_id + i for i in range(n)]
def check_consistency( def check_consistency(
self, self,
db_conn: Connection, db_conn: Connection,

View File

@@ -78,7 +78,7 @@ def sorted_topologically(
if node not in degree_map:
continue
-for edge in edges:
+for edge in set(edges):
if edge in degree_map:
degree_map[node] += 1
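The one-line change dedupes each node's edge list before counting in-degrees, so an event that lists the same auth event twice no longer ends up with an inflated count that can never drop to zero. A standalone Kahn-style sketch of the idea (not Synapse's implementation; names and signature are approximate):

```python
import heapq
from typing import Dict, Iterable, Iterator, List

def sorted_topologically(nodes: Iterable[int], graph: Dict[int, List[int]]) -> Iterator[int]:
    degree = {node: 0 for node in nodes}
    reverse = {node: set() for node in degree}  # dependency -> dependants

    for node, edges in graph.items():
        if node not in degree:
            continue
        for edge in set(edges):  # dedupe, so 2: [1, 1] only counts 1 once
            if edge in degree:
                degree[node] += 1
                reverse[edge].add(node)

    heap = [node for node, d in degree.items() if d == 0]
    heapq.heapify(heap)
    while heap:
        node = heapq.heappop(heap)
        yield node
        for dependant in reverse[node]:
            degree[dependant] -= 1
            if degree[dependant] == 0:
                heapq.heappush(heap, dependant)

# Mirrors the new test case below: duplicate edges still yield [1, 2, 3, 4].
assert list(sorted_topologically([4, 3, 2, 1], {1: [], 2: [1, 1], 3: [2, 2], 4: [3]})) == [1, 2, 3, 4]
```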

View File

@@ -28,7 +28,6 @@ from synapse.api.errors import Codes, HttpResponseException, ResourceLimitError
from synapse.api.room_versions import RoomVersions
from synapse.rest.client.v1 import login, logout, profile, room
from synapse.rest.client.v2_alpha import devices, sync
-from synapse.types import JsonDict
from tests import unittest
from tests.test_utils import make_awaitable
@@ -469,6 +468,13 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.admin_user = self.register_user("admin", "pass", admin=True)
self.admin_user_tok = self.login("admin", "pass")
+self.user1 = self.register_user(
+"user1", "pass1", admin=False, displayname="Name 1"
+)
+self.user2 = self.register_user(
+"user2", "pass2", admin=False, displayname="Name 2"
+)
def test_no_auth(self):
"""
Try to list users without authentication.
@@ -482,7 +488,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
If the user is not a server admin, an error is returned.
"""
-self._create_users(1)
other_user_token = self.login("user1", "pass1")
channel = self.make_request("GET", self.url, access_token=other_user_token)
@@ -494,8 +499,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
"""
List all users, including deactivated users.
"""
-self._create_users(2)
channel = self.make_request(
"GET",
self.url + "?deactivated=true",
@@ -508,7 +511,14 @@ class UsersListTestCase(unittest.HomeserverTestCase):
self.assertEqual(3, channel.json_body["total"])
# Check that all fields are available
-self._check_fields(channel.json_body["users"])
+for u in channel.json_body["users"]:
+self.assertIn("name", u)
+self.assertIn("is_guest", u)
+self.assertIn("admin", u)
+self.assertIn("user_type", u)
+self.assertIn("deactivated", u)
+self.assertIn("displayname", u)
+self.assertIn("avatar_url", u)
def test_search_term(self):
"""Test that searching for a users works correctly"""
@@ -539,7 +549,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
# Check that users were returned
self.assertTrue("users" in channel.json_body)
-self._check_fields(channel.json_body["users"])
users = channel.json_body["users"]
# Check that the expected number of users were returned
@@ -552,30 +561,25 @@ class UsersListTestCase(unittest.HomeserverTestCase):
u = users[0]
self.assertEqual(expected_user_id, u["name"])
-self._create_users(2)
-user1 = "@user1:test"
-user2 = "@user2:test"
# Perform search tests
-_search_test(user1, "er1")
-_search_test(user1, "me 1")
-_search_test(user2, "er2")
-_search_test(user2, "me 2")
-_search_test(user1, "er1", "user_id")
-_search_test(user2, "er2", "user_id")
+_search_test(self.user1, "er1")
+_search_test(self.user1, "me 1")
+_search_test(self.user2, "er2")
+_search_test(self.user2, "me 2")
+_search_test(self.user1, "er1", "user_id")
+_search_test(self.user2, "er2", "user_id")
# Test case insensitive
-_search_test(user1, "ER1")
-_search_test(user1, "NAME 1")
-_search_test(user2, "ER2")
-_search_test(user2, "NAME 2")
-_search_test(user1, "ER1", "user_id")
-_search_test(user2, "ER2", "user_id")
+_search_test(self.user1, "ER1")
+_search_test(self.user1, "NAME 1")
+_search_test(self.user2, "ER2")
+_search_test(self.user2, "NAME 2")
+_search_test(self.user1, "ER1", "user_id")
+_search_test(self.user2, "ER2", "user_id")
_search_test(None, "foo")
_search_test(None, "bar")
@@ -583,179 +587,6 @@ class UsersListTestCase(unittest.HomeserverTestCase):
_search_test(None, "foo", "user_id") _search_test(None, "foo", "user_id")
_search_test(None, "bar", "user_id") _search_test(None, "bar", "user_id")
def test_invalid_parameter(self):
"""
If parameters are invalid, an error is returned.
"""
# negative limit
channel = self.make_request(
"GET", self.url + "?limit=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# negative from
channel = self.make_request(
"GET", self.url + "?from=-5", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.INVALID_PARAM, channel.json_body["errcode"])
# invalid guests
channel = self.make_request(
"GET", self.url + "?guests=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
# invalid deactivated
channel = self.make_request(
"GET", self.url + "?deactivated=not_bool", access_token=self.admin_user_tok,
)
self.assertEqual(400, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(Codes.UNKNOWN, channel.json_body["errcode"])
def test_limit(self):
"""
Testing list of users with limit
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?limit=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 5)
self.assertEqual(channel.json_body["next_token"], "5")
self._check_fields(channel.json_body["users"])
def test_from(self):
"""
Testing list of users with a defined starting point (from)
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 15)
self.assertNotIn("next_token", channel.json_body)
self._check_fields(channel.json_body["users"])
def test_limit_and_from(self):
"""
Testing list of users with a defined starting point and limit
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
channel = self.make_request(
"GET", self.url + "?from=5&limit=10", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(channel.json_body["next_token"], "15")
self.assertEqual(len(channel.json_body["users"]), 10)
self._check_fields(channel.json_body["users"])
def test_next_token(self):
"""
Testing that `next_token` appears at the right place
"""
number_users = 20
# Create one less user (since there's already an admin user).
self._create_users(number_users - 1)
# `next_token` does not appear
# Number of results is the number of entries
channel = self.make_request(
"GET", self.url + "?limit=20", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does not appear
# Number of max results is larger than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=21", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), number_users)
self.assertNotIn("next_token", channel.json_body)
# `next_token` does appear
# Number of max results is smaller than the number of entries
channel = self.make_request(
"GET", self.url + "?limit=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 19)
self.assertEqual(channel.json_body["next_token"], "19")
# Check
# Set `from` to value of `next_token` for request remaining entries
# `next_token` does not appear
channel = self.make_request(
"GET", self.url + "?from=19", access_token=self.admin_user_tok,
)
self.assertEqual(200, int(channel.result["code"]), msg=channel.result["body"])
self.assertEqual(channel.json_body["total"], number_users)
self.assertEqual(len(channel.json_body["users"]), 1)
self.assertNotIn("next_token", channel.json_body)
def _check_fields(self, content: JsonDict):
"""Checks that the expected user attributes are present in content
Args:
content: List that is checked for content
"""
for u in content:
self.assertIn("name", u)
self.assertIn("is_guest", u)
self.assertIn("admin", u)
self.assertIn("user_type", u)
self.assertIn("deactivated", u)
self.assertIn("displayname", u)
self.assertIn("avatar_url", u)
def _create_users(self, number_users: int):
"""
Create a number of users
Args:
number_users: Number of users to be created
"""
for i in range(1, number_users + 1):
self.register_user(
"user%d" % i, "pass%d" % i, admin=False, displayname="Name %d" % i,
)
class DeactivateAccountTestCase(unittest.HomeserverTestCase):
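The removed tests exercise the admin "list accounts" API, `GET /_synapse/admin/v2/users`, and in particular its `from`/`limit`/`next_token` pagination: `next_token` is returned only while more entries remain, and feeding it back as `from` fetches the next page. A rough client-side illustration (the homeserver URL and access token are placeholders):

```python
import requests

BASE = "https://homeserver.example.org"
TOKEN = "admin_access_token_placeholder"

def list_all_users(limit: int = 10):
    users = []
    params = {"limit": str(limit), "deactivated": "true"}
    while True:
        resp = requests.get(
            f"{BASE}/_synapse/admin/v2/users",
            headers={"Authorization": f"Bearer {TOKEN}"},
            params=params,
        )
        resp.raise_for_status()
        body = resp.json()
        users.extend(body["users"])
        # `next_token` is only present while more entries remain; feed it
        # back in as `from` to request the next page.
        if "next_token" not in body:
            return users
        params["from"] = body["next_token"]
```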

View File

@@ -202,6 +202,7 @@ class MediaRepoTests(unittest.HomeserverTestCase):
config = self.default_config()
config["media_store_path"] = self.media_store_path
+config["thumbnail_requirements"] = {}
config["max_image_pixels"] = 2000000
provider_config = {
@@ -312,39 +313,15 @@ class MediaRepoTests(unittest.HomeserverTestCase):
self.assertEqual(headers.getRawHeaders(b"Content-Disposition"), None)
def test_thumbnail_crop(self):
-"""Test that a cropped remote thumbnail is available."""
self._test_thumbnail(
"crop", self.test_image.expected_cropped, self.test_image.expected_found
)
def test_thumbnail_scale(self):
-"""Test that a scaled remote thumbnail is available."""
self._test_thumbnail(
"scale", self.test_image.expected_scaled, self.test_image.expected_found
)
-def test_invalid_type(self):
-"""An invalid thumbnail type is never available."""
-self._test_thumbnail("invalid", None, False)
-@unittest.override_config(
-{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "scale"}]}
-)
-def test_no_thumbnail_crop(self):
-"""
-Override the config to generate only scaled thumbnails, but request a cropped one.
-"""
-self._test_thumbnail("crop", None, False)
-@unittest.override_config(
-{"thumbnail_sizes": [{"width": 32, "height": 32, "method": "crop"}]}
-)
-def test_no_thumbnail_scale(self):
-"""
-Override the config to generate only cropped thumbnails, but request a scaled one.
-"""
-self._test_thumbnail("scale", None, False)
def _test_thumbnail(self, method, expected_body, expected_found):
params = "?width=32&height=32&method=" + method
channel = make_request(
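`_test_thumbnail` drives the media thumbnail endpoint with `width`, `height` and `method` query parameters; the tests present only on one side of this hunk assert that asking for a method no thumbnail was generated for yields a clean 404 rather than a 500. Roughly, the request looks like the following (homeserver URL and media ID are placeholders):

```python
import requests

BASE = "https://homeserver.example.org"
MEDIA = "example.com/abcdefghijklmnop"  # <server_name>/<media_id>

def fetch_thumbnail(method: str):
    # method is "crop" or "scale"; a method no thumbnail exists for should
    # return 404 once the thumbnail fix is applied.
    return requests.get(
        f"{BASE}/_matrix/media/r0/thumbnail/{MEDIA}",
        params={"width": "32", "height": "32", "method": method},
    )

for m in ("crop", "scale", "invalid"):
    print(m, fetch_thumbnail(m).status_code)
```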

View File

@@ -92,3 +92,15 @@ class SortTopologically(TestCase):
# Valid orderings are `[1, 3, 2, 4]` or `[1, 2, 3, 4]`, but we should
# always get the same one.
self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+def test_duplicates(self):
+"Test that a graph with duplicate edges works"
+graph = {1: [], 2: [1, 1], 3: [2, 2], 4: [3]} # type: Dict[int, List[int]]
+self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])
+def test_multiple_paths(self):
+"Test that a graph with multiple paths between two nodes works"
+graph = {1: [], 2: [1], 3: [2], 4: [3, 2, 1]} # type: Dict[int, List[int]]
+self.assertEqual(list(sorted_topologically([4, 3, 2, 1], graph)), [1, 2, 3, 4])

tox.ini
View File

@@ -24,8 +24,7 @@ deps =
# install the "enum34" dependency of cryptography. # install the "enum34" dependency of cryptography.
pip>=10 pip>=10
# directories/files we run the linters on. # directories/files we run the linters on
# if you update this list, make sure to do the same in scripts-dev/lint.sh
lint_targets = lint_targets =
setup.py setup.py
synapse synapse
@@ -101,7 +100,7 @@ usedevelop=true
# A test suite for the oldest supported versions of Python libraries, to catch # A test suite for the oldest supported versions of Python libraries, to catch
# any uses of APIs not available in them. # any uses of APIs not available in them.
[testenv:py35-{old,old-postgres}] [testenv:py35-old]
skip_install=True skip_install=True
deps = deps =
# Ensure a version of setuptools that supports Python 3.5 is installed. # Ensure a version of setuptools that supports Python 3.5 is installed.
@@ -114,17 +113,11 @@ deps =
coverage coverage
coverage-enable-subprocess==1.0 coverage-enable-subprocess==1.0
setenv =
postgres: SYNAPSE_POSTGRES = 1
commands = commands =
# Make all greater-thans equals so we test the oldest version of our direct # Make all greater-thans equals so we test the oldest version of our direct
# dependencies, but make the pyopenssl 17.0, which can work against an # dependencies, but make the pyopenssl 17.0, which can work against an
# OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis). # OpenSSL 1.1 compiled cryptography (as older ones don't compile on Travis).
#
# Also strip out psycopg2 unless we need it.
/bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install' /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" -e "/psycopg2/d" -e "s/pyopenssl==16.0.0/pyopenssl==17.0.0/" | xargs -d"\n" pip install'
postgres: /bin/sh -c 'python -m synapse.python_dependencies | sed -e "s/>=/==/g" | grep -F "psycopg2" | xargs -d"\n" pip install'
# Install Synapse itself. This won't update any libraries. # Install Synapse itself. This won't update any libraries.
pip install -e ".[test]" pip install -e ".[test]"