Improve type hints for cached decorator. (#15658)

The cached decorators always return a Deferred, which was not
properly propagated. It was close enough when wrapping coroutines,
but failed if a bare function was wrapped.
This commit is contained in:
Patrick Cloke
2023-05-24 08:59:31 -04:00
committed by GitHub
parent 379eb2d7ab
commit 1f55c04cbc
6 changed files with 73 additions and 63 deletions

View File

@@ -18,10 +18,11 @@ can crop up, e.g the cache descriptors.
from typing import Callable, Optional, Type
from mypy.erasetype import remove_instance_last_known_values
from mypy.nodes import ARG_NAMED_OPT
from mypy.plugin import MethodSigContext, Plugin
from mypy.typeops import bind_self
from mypy.types import CallableType, NoneType, UnionType
from mypy.types import CallableType, Instance, NoneType, UnionType
class SynapsePlugin(Plugin):
@@ -92,10 +93,41 @@ def cached_function_method_signature(ctx: MethodSigContext) -> CallableType:
arg_names.append("on_invalidate")
arg_kinds.append(ARG_NAMED_OPT) # Arg is an optional kwarg.
# Finally we ensure the return type is a Deferred.
if (
isinstance(signature.ret_type, Instance)
and signature.ret_type.type.fullname == "twisted.internet.defer.Deferred"
):
# If it is already a Deferred, nothing to do.
ret_type = signature.ret_type
else:
ret_arg = None
if isinstance(signature.ret_type, Instance):
# If a coroutine, wrap the coroutine's return type in a Deferred.
if signature.ret_type.type.fullname == "typing.Coroutine":
ret_arg = signature.ret_type.args[2]
# If an awaitable, wrap the awaitable's final value in a Deferred.
elif signature.ret_type.type.fullname == "typing.Awaitable":
ret_arg = signature.ret_type.args[0]
# Otherwise, wrap the return value in a Deferred.
if ret_arg is None:
ret_arg = signature.ret_type
# This should be able to use ctx.api.named_generic_type, but that doesn't seem
# to find the correct symbol for anything more than 1 module deep.
#
# modules is not part of CheckerPluginInterface. The following is a combination
# of TypeChecker.named_generic_type and TypeChecker.lookup_typeinfo.
sym = ctx.api.modules["twisted.internet.defer"].names.get("Deferred") # type: ignore[attr-defined]
ret_type = Instance(sym.node, [remove_instance_last_known_values(ret_arg)])
signature = signature.copy_modified(
arg_types=arg_types,
arg_names=arg_names,
arg_kinds=arg_kinds,
ret_type=ret_type,
)
return signature