1
0

Merge remote-tracking branch 'element/develop' into kegan/sticky-events

This commit is contained in:
Will Hunt
2025-11-11 11:09:11 +00:00
739 changed files with 12240 additions and 11243 deletions

View File

@@ -18,16 +18,15 @@ import sys
import threading
from concurrent.futures import ThreadPoolExecutor
from types import FrameType
from typing import Collection, Optional, Sequence, Set
from typing import Collection, Sequence
# These are expanded inside the dockerfile to be a fully qualified image name.
# e.g. docker.io/library/debian:bullseye
# e.g. docker.io/library/debian:bookworm
#
# If an EOL is forced by a Python version and we're dropping support for it, make sure
# to remove references to the distibution across Synapse (search for "bullseye" for
# to remove references to the distibution across Synapse (search for "bookworm" for
# example)
DISTS = (
"debian:bullseye", # (EOL ~2024-07) (our EOL forced by Python 3.9 is 2025-10-05)
"debian:bookworm", # (EOL 2026-06) (our EOL forced by Python 3.11 is 2027-10-24)
"debian:sid", # (rolling distro, no EOL)
"ubuntu:jammy", # 22.04 LTS (EOL 2027-04) (our EOL forced by Python 3.10 is 2026-10-04)
@@ -50,11 +49,11 @@ class Builder:
def __init__(
self,
redirect_stdout: bool = False,
docker_build_args: Optional[Sequence[str]] = None,
docker_build_args: Sequence[str] | None = None,
):
self.redirect_stdout = redirect_stdout
self._docker_build_args = tuple(docker_build_args or ())
self.active_containers: Set[str] = set()
self.active_containers: set[str] = set()
self._lock = threading.Lock()
self._failed = False
@@ -168,7 +167,7 @@ class Builder:
def run_builds(
builder: Builder, dists: Collection[str], jobs: int = 1, skip_tests: bool = False
) -> None:
def sig(signum: int, _frame: Optional[FrameType]) -> None:
def sig(signum: int, _frame: FrameType | None) -> None:
print("Caught SIGINT")
builder.kill_containers()

View File

@@ -21,7 +21,6 @@
#
import sys
from pathlib import Path
from typing import Dict, List
import tomli
@@ -33,7 +32,7 @@ def main() -> None:
# Poetry 1.3+ lockfile format:
# There's a `files` inline table in each [[package]]
packages_to_assets: Dict[str, List[Dict[str, str]]] = {
packages_to_assets: dict[str, list[dict[str, str]]] = {
package["name"]: package["files"] for package in lockfile_content["package"]
}

View File

@@ -1,478 +0,0 @@
#! /usr/bin/env python
#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2022 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
"""
A script which enforces that Synapse always uses strict types when defining a Pydantic
model.
Pydantic does not yet offer a strict mode, but it is planned for pydantic v2. See
https://github.com/pydantic/pydantic/issues/1098
https://pydantic-docs.helpmanual.io/blog/pydantic-v2/#strict-mode
until then, this script is a best effort to stop us from introducing type coersion bugs
(like the infamous stringy power levels fixed in room version 10).
"""
import argparse
import contextlib
import functools
import importlib
import logging
import os
import pkgutil
import sys
import textwrap
import traceback
import unittest.mock
from contextlib import contextmanager
from typing import (
Any,
Callable,
Dict,
Generator,
List,
Set,
Type,
TypeVar,
)
from parameterized import parameterized
from typing_extensions import ParamSpec
from synapse._pydantic_compat import (
BaseModel as PydanticBaseModel,
conbytes,
confloat,
conint,
constr,
get_args,
)
logger = logging.getLogger(__name__)
CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG: List[Callable] = [
constr,
conbytes,
conint,
confloat,
]
TYPES_THAT_PYDANTIC_WILL_COERCE_TO = [
str,
bytes,
int,
float,
bool,
]
P = ParamSpec("P")
R = TypeVar("R")
class ModelCheckerException(Exception):
"""Dummy exception. Allows us to detect unwanted types during a module import."""
class MissingStrictInConstrainedTypeException(ModelCheckerException):
factory_name: str
def __init__(self, factory_name: str):
self.factory_name = factory_name
class FieldHasUnwantedTypeException(ModelCheckerException):
message: str
def __init__(self, message: str):
self.message = message
def make_wrapper(factory: Callable[P, R]) -> Callable[P, R]:
"""We patch `constr` and friends with wrappers that enforce strict=True."""
@functools.wraps(factory)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
if "strict" not in kwargs:
raise MissingStrictInConstrainedTypeException(factory.__name__)
if not kwargs["strict"]:
raise MissingStrictInConstrainedTypeException(factory.__name__)
return factory(*args, **kwargs)
return wrapper
def field_type_unwanted(type_: Any) -> bool:
"""Very rough attempt to detect if a type is unwanted as a Pydantic annotation.
At present, we exclude types which will coerce, or any generic type involving types
which will coerce."""
logger.debug("Is %s unwanted?")
if type_ in TYPES_THAT_PYDANTIC_WILL_COERCE_TO:
logger.debug("yes")
return True
logger.debug("Maybe. Subargs are %s", get_args(type_))
rv = any(field_type_unwanted(t) for t in get_args(type_))
logger.debug("Conclusion: %s %s unwanted", type_, "is" if rv else "is not")
return rv
class PatchedBaseModel(PydanticBaseModel):
"""A patched version of BaseModel that inspects fields after models are defined.
We complain loudly if we see an unwanted type.
Beware: ModelField.type_ is presumably private; this is likely to be very brittle.
"""
@classmethod
def __init_subclass__(cls: Type[PydanticBaseModel], **kwargs: object):
for field in cls.__fields__.values():
# Note that field.type_ and field.outer_type are computed based on the
# annotation type, see pydantic.fields.ModelField._type_analysis
if field_type_unwanted(field.outer_type_):
# TODO: this only reports the first bad field. Can we find all bad ones
# and report them all?
raise FieldHasUnwantedTypeException(
f"{cls.__module__}.{cls.__qualname__} has field '{field.name}' "
f"with unwanted type `{field.outer_type_}`"
)
@contextmanager
def monkeypatch_pydantic() -> Generator[None, None, None]:
"""Patch pydantic with our snooping versions of BaseModel and the con* functions.
If the snooping functions see something they don't like, they'll raise a
ModelCheckingException instance.
"""
with contextlib.ExitStack() as patches:
# Most Synapse code ought to import the patched objects directly from
# `pydantic`. But we also patch their containing modules `pydantic.main` and
# `pydantic.types` for completeness.
patch_basemodel = unittest.mock.patch(
"synapse._pydantic_compat.BaseModel", new=PatchedBaseModel
)
patches.enter_context(patch_basemodel)
for factory in CONSTRAINED_TYPE_FACTORIES_WITH_STRICT_FLAG:
wrapper: Callable = make_wrapper(factory)
patch = unittest.mock.patch(
f"synapse._pydantic_compat.{factory.__name__}", new=wrapper
)
patches.enter_context(patch)
yield
def format_model_checker_exception(e: ModelCheckerException) -> str:
"""Work out which line of code caused e. Format the line in a human-friendly way."""
# TODO. FieldHasUnwantedTypeException gives better error messages. Can we ditch the
# patches of constr() etc, and instead inspect fields to look for ConstrainedStr
# with strict=False? There is some difficulty with the inheritance hierarchy
# because StrictStr < ConstrainedStr < str.
if isinstance(e, FieldHasUnwantedTypeException):
return e.message
elif isinstance(e, MissingStrictInConstrainedTypeException):
frame_summary = traceback.extract_tb(e.__traceback__)[-2]
return (
f"Missing `strict=True` from {e.factory_name}() call \n"
+ traceback.format_list([frame_summary])[0].lstrip()
)
else:
raise ValueError(f"Unknown exception {e}") from e
def lint() -> int:
"""Try to import all of Synapse and see if we spot any Pydantic type coercions.
Print any problems, then return a status code suitable for sys.exit."""
failures = do_lint()
if failures:
print(f"Found {len(failures)} problem(s)")
for failure in sorted(failures):
print(failure)
return os.EX_DATAERR if failures else os.EX_OK
def do_lint() -> Set[str]:
"""Try to import all of Synapse and see if we spot any Pydantic type coercions."""
failures = set()
with monkeypatch_pydantic():
logger.debug("Importing synapse")
try:
# TODO: make "synapse" an argument so we can target this script at
# a subpackage
module = importlib.import_module("synapse")
except ModelCheckerException as e:
logger.warning("Bad annotation found when importing synapse")
failures.add(format_model_checker_exception(e))
return failures
try:
logger.debug("Fetching subpackages")
module_infos = list(
pkgutil.walk_packages(module.__path__, f"{module.__name__}.")
)
except ModelCheckerException as e:
logger.warning("Bad annotation found when looking for modules to import")
failures.add(format_model_checker_exception(e))
return failures
for module_info in module_infos:
logger.debug("Importing %s", module_info.name)
try:
importlib.import_module(module_info.name)
except ModelCheckerException as e:
logger.warning(
"Bad annotation found when importing %s", module_info.name
)
failures.add(format_model_checker_exception(e))
return failures
def run_test_snippet(source: str) -> None:
"""Exec a snippet of source code in an isolated environment."""
# To emulate `source` being called at the top level of the module,
# the globals and locals we provide apparently have to be the same mapping.
#
# > Remember that at the module level, globals and locals are the same dictionary.
# > If exec gets two separate objects as globals and locals, the code will be
# > executed as if it were embedded in a class definition.
globals_: Dict[str, object]
locals_: Dict[str, object]
globals_ = locals_ = {}
exec(textwrap.dedent(source), globals_, locals_)
class TestConstrainedTypesPatch(unittest.TestCase):
def test_expression_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import constr
except ImportError:
from pydantic import constr
constr()
"""
)
def test_called_as_module_attribute_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
import pydantic
pydantic.constr()
"""
)
def test_wildcard_import_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import *
except ImportError:
from pydantic import *
constr()
"""
)
def test_alternative_import_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1.types import constr
except ImportError:
from pydantic.types import constr
constr()
"""
)
def test_alternative_import_attribute_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import types as pydantic_types
except ImportError:
from pydantic import types as pydantic_types
pydantic_types.constr()
"""
)
def test_kwarg_but_no_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import constr
except ImportError:
from pydantic import constr
constr(min_length=10)
"""
)
def test_kwarg_strict_False_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import constr
except ImportError:
from pydantic import constr
constr(strict=False)
"""
)
def test_kwarg_strict_True_doesnt_raise(self) -> None:
with monkeypatch_pydantic():
run_test_snippet(
"""
try:
from pydantic.v1 import constr
except ImportError:
from pydantic import constr
constr(strict=True)
"""
)
def test_annotation_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import constr
except ImportError:
from pydantic import constr
x: constr()
"""
)
def test_field_annotation_without_strict_raises(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1 import BaseModel, conint
except ImportError:
from pydantic import BaseModel, conint
class C:
x: conint()
"""
)
class TestFieldTypeInspection(unittest.TestCase):
@parameterized.expand(
[
("str",),
("bytes"),
("int",),
("float",),
("bool"),
("Optional[str]",),
("Union[None, str]",),
("List[str]",),
("List[List[str]]",),
("Dict[StrictStr, str]",),
("Dict[str, StrictStr]",),
("TypedDict('D', x=int)",),
]
)
def test_field_holding_unwanted_type_raises(self, annotation: str) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
f"""
from typing import *
try:
from pydantic.v1 import *
except ImportError:
from pydantic import *
class C(BaseModel):
f: {annotation}
"""
)
@parameterized.expand(
[
("StrictStr",),
("StrictBytes"),
("StrictInt",),
("StrictFloat",),
("StrictBool"),
("constr(strict=True, min_length=10)",),
("Optional[StrictStr]",),
("Union[None, StrictStr]",),
("List[StrictStr]",),
("List[List[StrictStr]]",),
("Dict[StrictStr, StrictStr]",),
("TypedDict('D', x=StrictInt)",),
]
)
def test_field_holding_accepted_type_doesnt_raise(self, annotation: str) -> None:
with monkeypatch_pydantic():
run_test_snippet(
f"""
from typing import *
try:
from pydantic.v1 import *
except ImportError:
from pydantic import *
class C(BaseModel):
f: {annotation}
"""
)
def test_field_holding_str_raises_with_alternative_import(self) -> None:
with monkeypatch_pydantic(), self.assertRaises(ModelCheckerException):
run_test_snippet(
"""
try:
from pydantic.v1.main import BaseModel
except ImportError:
from pydantic.main import BaseModel
class C(BaseModel):
f: str
"""
)
parser = argparse.ArgumentParser()
parser.add_argument("mode", choices=["lint", "test"], default="lint", nargs="?")
parser.add_argument("-v", "--verbose", action="store_true")
if __name__ == "__main__":
args = parser.parse_args(sys.argv[1:])
logging.basicConfig(
format="%(asctime)s %(name)s:%(lineno)d %(levelname)s %(message)s",
level=logging.DEBUG if args.verbose else logging.INFO,
)
# suppress logs we don't care about
logging.getLogger("xmlschema").setLevel(logging.WARNING)
if args.mode == "lint":
sys.exit(lint())
elif args.mode == "test":
unittest.main(argv=sys.argv[:1])

View File

@@ -5,15 +5,19 @@
# Also checks that schema deltas do not try and create or drop indices.
import re
from typing import Any, Dict, List
from typing import Any
import click
import git
SCHEMA_FILE_REGEX = re.compile(r"^synapse/storage/schema/(.*)/delta/(.*)/(.*)$")
INDEX_CREATION_REGEX = re.compile(r"CREATE .*INDEX .*ON ([a-z_]+)", flags=re.IGNORECASE)
INDEX_DELETION_REGEX = re.compile(r"DROP .*INDEX ([a-z_]+)", flags=re.IGNORECASE)
TABLE_CREATION_REGEX = re.compile(r"CREATE .*TABLE ([a-z_]+)", flags=re.IGNORECASE)
INDEX_CREATION_REGEX = re.compile(
r"CREATE .*INDEX .*ON ([a-z_0-9]+)", flags=re.IGNORECASE
)
INDEX_DELETION_REGEX = re.compile(r"DROP .*INDEX ([a-z_0-9]+)", flags=re.IGNORECASE)
TABLE_CREATION_REGEX = re.compile(
r"CREATE .*TABLE.* ([a-z_0-9]+)\s*\(", flags=re.IGNORECASE
)
# The base branch we want to check against. We use the main development branch
# on the assumption that is what we are developing against.
@@ -48,16 +52,16 @@ def main(force_colors: bool) -> None:
r = repo.git.show(f"origin/{DEVELOP_BRANCH}:synapse/storage/schema/__init__.py")
locals: Dict[str, Any] = {}
locals: dict[str, Any] = {}
exec(r, locals)
current_schema_version = locals["SCHEMA_VERSION"]
diffs: List[git.Diff] = repo.remote().refs[DEVELOP_BRANCH].commit.diff(None)
diffs: list[git.Diff] = repo.remote().refs[DEVELOP_BRANCH].commit.diff(None)
# Get the schema version of the local file to check against current schema on develop
with open("synapse/storage/schema/__init__.py") as file:
local_schema = file.read()
new_locals: Dict[str, Any] = {}
new_locals: dict[str, Any] = {}
exec(local_schema, new_locals)
local_schema_version = new_locals["SCHEMA_VERSION"]
@@ -173,11 +177,14 @@ def main(force_colors: bool) -> None:
clause = match.group()
click.secho(
f"Found delta with index deletion: '{clause}' in {delta_file}\nThese should be in background updates.",
f"Found delta with index deletion: '{clause}' in {delta_file}",
fg="red",
bold=True,
color=force_colors,
)
click.secho(
" ↪ These should be in background updates.",
)
return_code = 1
# Check for index creation, which is only allowed for tables we've
@@ -188,11 +195,14 @@ def main(force_colors: bool) -> None:
table_name = match.group(1)
if table_name not in created_tables:
click.secho(
f"Found delta with index creation: '{clause}' in {delta_file}\nThese should be in background updates.",
f"Found delta with index creation for existing table: '{clause}' in {delta_file}",
fg="red",
bold=True,
color=force_colors,
)
click.secho(
" ↪ These should be in background updates (or the table should be created in the same delta).",
)
return_code = 1
click.get_current_context().exit(return_code)

View File

@@ -43,7 +43,7 @@ import argparse
import base64
import json
import sys
from typing import Any, Dict, Mapping, Optional, Tuple, Union
from typing import Any, Mapping
from urllib import parse as urlparse
import requests
@@ -103,12 +103,12 @@ def sign_json(
def request(
method: Optional[str],
method: str | None,
origin_name: str,
origin_key: signedjson.types.SigningKey,
destination: str,
path: str,
content: Optional[str],
content: str | None,
verify_tls: bool,
) -> requests.Response:
if method is None:
@@ -147,7 +147,7 @@ def request(
s = requests.Session()
s.mount("matrix-federation://", MatrixConnectionAdapter())
headers: Dict[str, str] = {
headers: dict[str, str] = {
"Authorization": authorization_headers[0],
}
@@ -301,9 +301,9 @@ class MatrixConnectionAdapter(HTTPAdapter):
def get_connection_with_tls_context(
self,
request: PreparedRequest,
verify: Optional[Union[bool, str]],
proxies: Optional[Mapping[str, str]] = None,
cert: Optional[Union[Tuple[str, str], str]] = None,
verify: bool | str | None,
proxies: Mapping[str, str] | None = None,
cert: tuple[str, str] | str | None = None,
) -> HTTPConnectionPool:
# overrides the get_connection_with_tls_context() method in the base class
parsed = urlparse.urlsplit(request.url)
@@ -326,7 +326,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
)
@staticmethod
def _lookup(server_name: str) -> Tuple[str, int, str]:
def _lookup(server_name: str) -> tuple[str, int, str]:
"""
Do an SRV lookup on a server name and return the host:port to connect to
Given the server_name (after any .well-known lookup), return the host, port and
@@ -368,7 +368,7 @@ class MatrixConnectionAdapter(HTTPAdapter):
return server_name, 8448, server_name
@staticmethod
def _get_well_known(server_name: str) -> Optional[str]:
def _get_well_known(server_name: str) -> str | None:
if ":" in server_name:
# explicit port, or ipv6 literal. Either way, no .well-known
return None

View File

@@ -4,7 +4,7 @@
import json
import re
import sys
from typing import Any, Optional
from typing import Any
import yaml
@@ -259,17 +259,17 @@ def indent(text: str, first_line: bool = True) -> str:
return text
def em(s: Optional[str]) -> str:
def em(s: str | None) -> str:
"""Add emphasis to text."""
return f"*{s}*" if s else ""
def a(s: Optional[str], suffix: str = " ") -> str:
def a(s: str | None, suffix: str = " ") -> str:
"""Appends a space if the given string is not empty."""
return s + suffix if s else ""
def p(s: Optional[str], prefix: str = " ") -> str:
def p(s: str | None, prefix: str = " ") -> str:
"""Prepend a space if the given string is not empty."""
return prefix + s if s else ""

View File

@@ -134,9 +134,6 @@ fi
# Ensure the formatting of Rust code.
cargo-fmt
# Ensure all Pydantic models use strict types.
./scripts-dev/check_pydantic_models.py lint
# Ensure type hints are correct.
mypy

View File

@@ -24,7 +24,7 @@ can crop up, e.g the cache descriptors.
"""
import enum
from typing import Callable, Mapping, Optional, Tuple, Type, Union
from typing import Callable, Mapping
import attr
import mypy.types
@@ -123,7 +123,7 @@ class ArgLocation:
"""
prometheus_metric_fullname_to_label_arg_map: Mapping[str, Optional[ArgLocation]] = {
prometheus_metric_fullname_to_label_arg_map: Mapping[str, ArgLocation | None] = {
# `Collector` subclasses:
"prometheus_client.metrics.MetricWrapperBase": ArgLocation("labelnames", 2),
"prometheus_client.metrics.Counter": ArgLocation("labelnames", 2),
@@ -184,8 +184,8 @@ should be in the source code.
# Unbound at this point because we don't know the mypy version yet.
# This is set in the `plugin(...)` function below.
MypyPydanticPluginClass: Type[Plugin]
MypyZopePluginClass: Type[Plugin]
MypyPydanticPluginClass: type[Plugin]
MypyZopePluginClass: type[Plugin]
class SynapsePlugin(Plugin):
@@ -211,7 +211,7 @@ class SynapsePlugin(Plugin):
def get_base_class_hook(
self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
) -> Callable[[ClassDefContext], None] | None:
def _get_base_class_hook(ctx: ClassDefContext) -> None:
# Run any `get_base_class_hook` checks from other plugins first.
#
@@ -232,7 +232,7 @@ class SynapsePlugin(Plugin):
def get_function_signature_hook(
self, fullname: str
) -> Optional[Callable[[FunctionSigContext], FunctionLike]]:
) -> Callable[[FunctionSigContext], FunctionLike] | None:
# Strip off the unique identifier for classes that are dynamically created inside
# functions. ex. `synapse.metrics.jemalloc.JemallocCollector@185` (this is the line
# number)
@@ -262,7 +262,7 @@ class SynapsePlugin(Plugin):
def get_method_signature_hook(
self, fullname: str
) -> Optional[Callable[[MethodSigContext], CallableType]]:
) -> Callable[[MethodSigContext], CallableType] | None:
if fullname.startswith(
(
"synapse.util.caches.descriptors.CachedFunction.__call__",
@@ -721,7 +721,7 @@ def check_is_cacheable_wrapper(ctx: MethodSigContext) -> CallableType:
def check_is_cacheable(
signature: CallableType,
ctx: Union[MethodSigContext, FunctionSigContext],
ctx: MethodSigContext | FunctionSigContext,
) -> None:
"""
Check if a callable returns a type which can be cached.
@@ -795,7 +795,7 @@ AT_CACHED_MUTABLE_RETURN = ErrorCode(
def is_cacheable(
rt: mypy.types.Type, signature: CallableType, verbose: bool
) -> Tuple[bool, Optional[str]]:
) -> tuple[bool, str | None]:
"""
Check if a particular type is cachable.
@@ -905,7 +905,7 @@ def is_cacheable(
return False, f"Don't know how to handle {type(rt).__qualname__} return type"
def plugin(version: str) -> Type[SynapsePlugin]:
def plugin(version: str) -> type[SynapsePlugin]:
global MypyPydanticPluginClass, MypyZopePluginClass
# This is the entry point of the plugin, and lets us deal with the fact
# that the mypy plugin interface is *not* stable by looking at the version

View File

@@ -32,11 +32,13 @@ import time
import urllib.request
from os import path
from tempfile import TemporaryDirectory
from typing import Any, List, Match, Optional, Union
from typing import Any, Match
import attr
import click
import git
import github
import github.Auth
from click.exceptions import ClickException
from git import GitCommandError, Repo
from github import BadCredentialsException, Github
@@ -314,7 +316,10 @@ def _prepare() -> None:
)
print("Opening the changelog in your browser...")
print("Please ask #synapse-dev to give it a check.")
print(
"Please review it using the release notes review checklist: https://element-hq.github.io/synapse/develop/development/internal_documentation/release_notes_review_checklist.html"
)
print("And post it in #synapse-dev for cursory review from the team.")
click.launch(
f"https://github.com/element-hq/synapse/blob/{synapse_repo.active_branch.name}/CHANGES.md"
)
@@ -322,11 +327,11 @@ def _prepare() -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"])
def tag(gh_token: Optional[str]) -> None:
def tag(gh_token: str | None) -> None:
_tag(gh_token)
def _tag(gh_token: Optional[str]) -> None:
def _tag(gh_token: str | None) -> None:
"""Tags the release and generates a draft GitHub release"""
# Test that the GH Token is valid before continuing.
@@ -397,7 +402,7 @@ def _tag(gh_token: Optional[str]) -> None:
return
# Create a new draft release
gh = Github(gh_token)
gh = Github(auth=github.Auth.Token(token=gh_token))
gh_repo = gh.get_repo("element-hq/synapse")
release = gh_repo.create_git_release(
tag=tag_name,
@@ -428,7 +433,7 @@ def _publish(gh_token: str) -> None:
if gh_token:
# Test that the GH Token is valid before continuing.
gh = Github(gh_token)
gh = Github(auth=github.Auth.Token(token=gh_token))
gh.get_user()
# Make sure we're in a git repo.
@@ -441,7 +446,7 @@ def _publish(gh_token: str) -> None:
return
# Publish the draft release
gh = Github(gh_token)
gh = Github(auth=github.Auth.Token(token=gh_token))
gh_repo = gh.get_repo("element-hq/synapse")
for release in gh_repo.get_releases():
if release.title == tag_name:
@@ -466,11 +471,11 @@ def _publish(gh_token: str) -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
def upload(gh_token: Optional[str]) -> None:
def upload(gh_token: str | None) -> None:
_upload(gh_token)
def _upload(gh_token: Optional[str]) -> None:
def _upload(gh_token: str | None) -> None:
"""Upload release to pypi."""
# Test that the GH Token is valid before continuing.
@@ -486,8 +491,13 @@ def _upload(gh_token: Optional[str]) -> None:
click.echo(f"Tag {tag_name} ({tag.commit}) is not currently checked out!")
click.get_current_context().abort()
if gh_token:
gh = Github(auth=github.Auth.Token(token=gh_token))
else:
# Use github anonymously.
gh = Github()
# Query all the assets corresponding to this release.
gh = Github(gh_token)
gh_repo = gh.get_repo("element-hq/synapse")
gh_release = gh_repo.get_release(tag_name)
@@ -566,11 +576,11 @@ def _merge_into(repo: Repo, source: str, target: str) -> None:
@cli.command()
@click.option("--gh-token", envvar=["GH_TOKEN", "GITHUB_TOKEN"], required=False)
def wait_for_actions(gh_token: Optional[str]) -> None:
def wait_for_actions(gh_token: str | None) -> None:
_wait_for_actions(gh_token)
def _wait_for_actions(gh_token: Optional[str]) -> None:
def _wait_for_actions(gh_token: str | None) -> None:
# Test that the GH Token is valid before continuing.
check_valid_gh_token(gh_token)
@@ -639,7 +649,16 @@ def _notify(message: str) -> None:
@cli.command()
def merge_back() -> None:
# Although this option is not used, allow it anyways. Otherwise the user will
# receive an error when providing it, which is annoying as other commands accept
# it.
@click.option(
"--gh-token",
"_gh_token",
envvar=["GH_TOKEN", "GITHUB_TOKEN"],
required=False,
)
def merge_back(_gh_token: str | None) -> None:
_merge_back()
@@ -687,7 +706,16 @@ def _merge_back() -> None:
@cli.command()
def announce() -> None:
# Although this option is not used, allow it anyways. Otherwise the user will
# receive an error when providing it, which is annoying as other commands accept
# it.
@click.option(
"--gh-token",
"_gh_token",
envvar=["GH_TOKEN", "GITHUB_TOKEN"],
required=False,
)
def announce(_gh_token: str | None) -> None:
_announce()
@@ -696,18 +724,31 @@ def _announce() -> None:
current_version = get_package_version()
tag_name = f"v{current_version}"
is_rc = "rc" in tag_name
release_text = f"""
### Synapse {current_version} {"🧪" if is_rc else "🚀"}
click.echo(
f"""
Hi everyone. Synapse {current_version} has just been released.
"""
if "rc" in tag_name:
release_text += (
"\nThis is a release candidate. Please help us test it out "
"before the final release by deploying it to non-production environments, "
"and reporting any issues you find to "
"[the issue tracker](https://github.com/element-hq/synapse/issues). Thanks!\n"
)
release_text += f"""
[notes](https://github.com/element-hq/synapse/releases/tag/{tag_name}) | \
[docker](https://hub.docker.com/r/matrixdotorg/synapse/tags?name={tag_name}) | \
[debs](https://packages.matrix.org/debian/) | \
[pypi](https://pypi.org/project/matrix-synapse/{current_version}/)"""
)
if "rc" in tag_name:
click.echo(release_text)
if is_rc:
click.echo(
"""
Announce the RC in
@@ -732,7 +773,7 @@ Ask the designated people to do the blog and tweets."""
def full(gh_token: str) -> None:
if gh_token:
# Test that the GH Token is valid before continuing.
gh = Github(gh_token)
gh = Github(auth=github.Auth.Token(token=gh_token))
gh.get_user()
click.echo("1. If this is a security release, read the security wiki page.")
@@ -801,12 +842,16 @@ def get_repo_and_check_clean_checkout(
raise click.ClickException(
f"{path} is not a git repository (expecting a {name} repository)."
)
if repo.is_dirty():
raise click.ClickException(f"Uncommitted changes exist in {path}.")
while repo.is_dirty():
if not click.confirm(
f"Uncommitted changes exist in {path}. Commit or stash them. Ready to continue?"
):
raise click.ClickException("Aborted.")
return repo
def check_valid_gh_token(gh_token: Optional[str]) -> None:
def check_valid_gh_token(gh_token: str | None) -> None:
"""Check that a github token is valid, if supplied"""
if not gh_token:
@@ -814,7 +859,7 @@ def check_valid_gh_token(gh_token: Optional[str]) -> None:
return
try:
gh = Github(gh_token)
gh = Github(auth=github.Auth.Token(token=gh_token))
# We need to lookup name to trigger a request.
_name = gh.get_user().name
@@ -822,7 +867,7 @@ def check_valid_gh_token(gh_token: Optional[str]) -> None:
raise click.ClickException(f"Github credentials are bad: {e}")
def find_ref(repo: git.Repo, ref_name: str) -> Optional[git.HEAD]:
def find_ref(repo: git.Repo, ref_name: str) -> git.HEAD | None:
"""Find the branch/ref, looking first locally then in the remote."""
if ref_name in repo.references:
return repo.references[ref_name]
@@ -859,9 +904,9 @@ def get_changes_for_version(wanted_version: version.Version) -> str:
# These are 0-based.
start_line: int
end_line: Optional[int] = None # Is none if its the last entry
end_line: int | None = None # Is none if its the last entry
headings: List[VersionSection] = []
headings: list[VersionSection] = []
for i, token in enumerate(tokens):
# We look for level 1 headings (h1 tags).
if token.type != "heading_open" or token.tag != "h1":
@@ -946,7 +991,7 @@ def build_dependabot_changelog(repo: Repo, current_version: version.Version) ->
messages = []
for commit in reversed(commits):
if commit.author.name == "dependabot[bot]":
message: Union[str, bytes] = commit.message
message: str | bytes = commit.message
if isinstance(message, bytes):
message = message.decode("utf-8")
messages.append(message.split("\n", maxsplit=1)[0])

View File

@@ -38,7 +38,7 @@ import io
import json
import sys
from collections import defaultdict
from typing import Any, Dict, Iterator, Optional, Tuple
from typing import Any, Iterator
import git
from packaging import version
@@ -57,7 +57,7 @@ SCHEMA_VERSION_FILES = (
OLDEST_SHOWN_VERSION = version.parse("v1.0")
def get_schema_versions(tag: git.Tag) -> Tuple[Optional[int], Optional[int]]:
def get_schema_versions(tag: git.Tag) -> tuple[int | None, int | None]:
"""Get the schema and schema compat versions for a tag."""
schema_version = None
schema_compat_version = None
@@ -81,7 +81,7 @@ def get_schema_versions(tag: git.Tag) -> Tuple[Optional[int], Optional[int]]:
# SCHEMA_COMPAT_VERSION is sometimes across multiple lines, the easist
# thing to do is exec the code. Luckily it has only ever existed in
# a file which imports nothing else from Synapse.
locals: Dict[str, Any] = {}
locals: dict[str, Any] = {}
exec(schema_file.data_stream.read().decode("utf-8"), {}, locals)
schema_version = locals["SCHEMA_VERSION"]
schema_compat_version = locals.get("SCHEMA_COMPAT_VERSION")