Skip to content

Fix decorators on class __init__ methods #2

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 19 additions & 8 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
Context, Decorator, PrintStmt, BreakStmt, PassStmt, ContinueStmt,
ComparisonExpr, StarExpr, EllipsisExpr, RefExpr, PromoteExpr,
Import, ImportFrom, ImportAll, ImportBase, TypeAlias,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF,
CONTRAVARIANT, COVARIANT, INVARIANT,
ARG_POS, ARG_STAR, LITERAL_TYPE, MDEF, GDEF, CallableDecorator,
CONTRAVARIANT, COVARIANT, INVARIANT, get_callable
)
from mypy import nodes
from mypy.literals import literal, literal_hash
Expand Down Expand Up @@ -1619,11 +1619,15 @@ def check_compatibility(self, name: str, base1: TypeInfo,
first = base1[name]
second = base2[name]
first_type = first.type
if first_type is None and isinstance(first.node, FuncBase):
first_type = self.function_type(first.node)
if first_type is None:
method = get_callable(first.node)
if method:
first_type = self.function_type(method)
second_type = second.type
if second_type is None and isinstance(second.node, FuncBase):
second_type = self.function_type(second.node)
if second_type is None:
method = get_callable(second.node)
if method:
second_type = self.function_type(method)
# TODO: What if some classes are generic?
if (isinstance(first_type, FunctionLike) and
isinstance(second_type, FunctionLike)):
Expand Down Expand Up @@ -3019,10 +3023,17 @@ def visit_decorator(self, e: Decorator) -> None:
callable_name=fullname)
self.check_untyped_after_decorator(sig, e.func)
sig = set_callable_name(sig, e.func)
e.var.type = sig
e.var.is_ready = True
if e.func.is_property:
self.check_incompatible_property_override(e)
e.var.type = sig
e.var.is_ready = True
if isinstance(sig, CallableType):
if e.func.is_property:
assert isinstance(sig, CallableType)
if isinstance(sig.ret_type, CallableType):
e.callable_decorator = CallableDecorator(e)
else:
e.callable_decorator = CallableDecorator(e)
if e.func.info and not e.func.is_dynamic():
self.check_method_override(e)

Expand Down
4 changes: 2 additions & 2 deletions mypy/checkmember.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def analyze_member_access(name: str,

# Look up the member. First look up the method dictionary.
method = info.get_method(name)
if method:
if method and not method.is_class:
if method.is_property:
assert isinstance(method, OverloadedFuncDef)
first_item = cast(Decorator, method.items[0])
Expand All @@ -87,7 +87,7 @@ def analyze_member_access(name: str,
msg.cant_assign_to_method(node)
signature = function_type(method, builtin_type('builtins.function'))
signature = freshen_function_type_vars(signature)
if name == '__new__':
if name == '__new__' or method.is_static:
# __new__ is special and behaves like a static method -- don't strip
# the first argument.
pass
Expand Down
39 changes: 34 additions & 5 deletions mypy/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,7 @@ class Decorator(SymbolNode, Statement):
# TODO: This is mostly used for the type; consider replacing with a 'type' attribute
var = None # type: Var # Represents the decorated function obj
is_overload = False
callable_decorator = None # type: Optional[CallableDecorator]

def __init__(self, func: FuncDef, decorators: List[Expression],
var: 'Var') -> None:
Expand Down Expand Up @@ -704,6 +705,28 @@ def deserialize(cls, data: JsonDict) -> 'Decorator':
return dec


class CallableDecorator(FuncItem):
"""A wrapper around a Decorator that allows it to be treated as a callable function"""
def __init__(self, decorator: Decorator) -> None:
super().__init__(decorator.func.arguments, decorator.func.body,
cast('mypy.types.CallableType', decorator.type))
self.is_final = decorator.is_final
self.is_class = decorator.func.is_class
self.is_property = decorator.func.is_property
self.is_static = decorator.func.is_static
self.is_overload = decorator.func.is_overload
self.is_generator = decorator.func.is_generator
self.is_async_generator = decorator.func.is_async_generator
self.is_awaitable_coroutine = decorator.func.is_awaitable_coroutine
self.expanded = decorator.func.expanded
self.info = decorator.info
self._name = decorator.func.name()
self._fullname = decorator.func._fullname

def name(self) -> str:
return self._name


VAR_FLAGS = [
'is_self', 'is_initialized_in_class', 'is_staticmethod',
'is_classmethod', 'is_property', 'is_settable_property', 'is_suppressed_import',
Expand Down Expand Up @@ -2308,11 +2331,7 @@ def has_readable_member(self, name: str) -> bool:
def get_method(self, name: str) -> Optional[FuncBase]:
for cls in self.mro:
if name in cls.names:
node = cls.names[name].node
if isinstance(node, FuncBase):
return node
else:
return None
return get_callable(cls.names[name].node)
return None

def calculate_metaclass_type(self) -> 'Optional[mypy.types.Instance]':
Expand Down Expand Up @@ -2935,3 +2954,13 @@ def is_class_var(expr: NameExpr) -> bool:
if isinstance(expr.node, Var):
return expr.node.is_classvar
return False


def get_callable(node: Optional[Node]) -> Optional[FuncBase]:
"""Check if the passed node represents a callable function or funcion-like object"""
if isinstance(node, FuncBase):
return node
elif isinstance(node, Decorator):
return node.callable_decorator
else:
return None
11 changes: 5 additions & 6 deletions mypy/plugins/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@
from mypy.fixup import lookup_qualified_stnode
from mypy.nodes import (
Context, Argument, Var, ARG_OPT, ARG_POS, TypeInfo, AssignmentStmt,
TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, FuncBase,
is_class_var, TempNode, Decorator, MemberExpr, Expression, FuncDef, Block,
PassStmt, SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef
TupleExpr, ListExpr, NameExpr, CallExpr, RefExpr, is_class_var,
TempNode, Decorator, MemberExpr, Expression, FuncDef, Block,
PassStmt, SymbolTableNode, MDEF, JsonDict, OverloadedFuncDef, get_callable
)
from mypy.plugins.common import (
_get_argument, _get_bool_argument, _get_decorator_bool_argument
Expand Down Expand Up @@ -405,9 +405,8 @@ def _parse_converter(ctx: 'mypy.plugin.ClassDefContext',
# TODO: Support complex converters, e.g. lambdas, calls, etc.
if converter:
if isinstance(converter, RefExpr) and converter.node:
if (isinstance(converter.node, FuncBase)
and converter.node.type
and isinstance(converter.node.type, FunctionLike)):
method = get_callable(converter.node)
if method and method.type and isinstance(method.type, FunctionLike):
return Converter(converter.node.fullname())
elif isinstance(converter.node, TypeInfo):
return Converter(converter.node.fullname())
Expand Down
9 changes: 5 additions & 4 deletions mypy/plugins/common.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import List, Optional, Any

from mypy.nodes import (
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncBase,
FuncDef, PassStmt, RefExpr, SymbolTableNode, Var
ARG_POS, MDEF, Argument, Block, CallExpr, Expression, FuncDef,
PassStmt, RefExpr, SymbolTableNode, Var, get_callable
)
from mypy.plugin import ClassDefContext
from mypy.semanal import set_callable_name
Expand Down Expand Up @@ -53,8 +53,9 @@ def _get_argument(call: CallExpr, name: str) -> Optional[Expression]:
callee_type = None
# mypyc hack to workaround mypy misunderstanding multiple inheritance (#3603)
callee_node = call.callee.node # type: Any
if (isinstance(callee_node, (Var, FuncBase))
and callee_node.type):
if not isinstance(callee_node, Var):
callee_node = get_callable(callee_node)
if callee_node and callee_node.type:
callee_node_type = callee_node.type
if isinstance(callee_node_type, Overloaded):
# We take the last overload.
Expand Down
15 changes: 9 additions & 6 deletions mypy/server/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,9 @@ class 'mod.Cls'. This can also refer to an attribute inherited from a
ImportFrom, CallExpr, CastExpr, TypeVarExpr, TypeApplication, IndexExpr, UnaryExpr, OpExpr,
ComparisonExpr, GeneratorExpr, DictionaryComprehension, StarExpr, PrintStmt, ForStmt, WithStmt,
TupleExpr, OperatorAssignmentStmt, DelStmt, YieldFromExpr, Decorator, Block,
TypeInfo, FuncBase, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
TypeInfo, OverloadedFuncDef, RefExpr, SuperExpr, Var, NamedTupleExpr, TypedDictExpr,
LDEF, MDEF, GDEF, TypeAliasExpr, NewTypeExpr, ImportAll, EnumCallExpr, AwaitExpr,
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods
op_methods, reverse_op_methods, ops_with_inplace_method, unary_op_methods, get_callable
)
from mypy.traverser import TraverserVisitor
from mypy.types import (
Expand Down Expand Up @@ -127,17 +127,18 @@ def get_dependencies_of_target(module_id: str,
# TODO: Add tests for this function.
visitor = DependencyVisitor(type_map, python_version, module_tree.alias_deps)
visitor.scope.enter_file(module_id)
method = get_callable(target)
if isinstance(target, MypyFile):
# Only get dependencies of the top-level of the module. Don't recurse into
# functions.
for defn in target.defs:
# TODO: Recurse into top-level statements and class bodies but skip functions.
if not isinstance(defn, (ClassDef, Decorator, FuncDef, OverloadedFuncDef)):
defn.accept(visitor)
elif isinstance(target, FuncBase) and target.info:
elif method and method.info:
# It's a method.
# TODO: Methods in nested classes.
visitor.scope.enter_class(target.info)
visitor.scope.enter_class(method.info)
target.accept(visitor)
visitor.scope.leave()
else:
Expand Down Expand Up @@ -425,8 +426,10 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
if isinstance(rvalue.callee.node, TypeInfo):
# use actual __init__ as a dependency source
init = rvalue.callee.node.get('__init__')
if init and isinstance(init.node, FuncBase):
fname = init.node.fullname()
if init:
method = get_callable(init.node)
if method:
fname = method.fullname()
else:
fname = rvalue.callee.fullname
if fname is None:
Expand Down
12 changes: 10 additions & 2 deletions mypy/subtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,10 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
assert isinstance(dec, Decorator)
if dec.var.is_settable_property or setattr_meth:
return {IS_SETTABLE}
return set()
if method.is_static or method.is_class:
return {IS_CLASS_OR_STATIC}
else:
return set()
node = info.get(name)
if not node:
if setattr_meth:
Expand Down Expand Up @@ -604,7 +607,12 @@ def find_node_type(node: Union[Var, FuncBase], itype: Instance, subtype: Type) -
if typ is None:
return AnyType(TypeOfAny.from_error)
# We don't need to bind 'self' for static methods, since there is no 'self'.
if isinstance(node, FuncBase) or isinstance(typ, FunctionLike) and not node.is_staticmethod:
need_bind = False
if isinstance(node, FuncBase):
need_bind = not node.is_static
elif isinstance(typ, FunctionLike):
need_bind = not node.is_staticmethod
if need_bind:
assert isinstance(typ, FunctionLike)
signature = bind_self(typ, subtype)
if node.is_property:
Expand Down
26 changes: 26 additions & 0 deletions test-data/unit/check-classes.test
Original file line number Diff line number Diff line change
Expand Up @@ -5610,3 +5610,29 @@ from typing import TypeVar, Tuple, Callable
T = TypeVar('T')
def deco(f: Callable[..., T]) -> Callable[..., Tuple[T, int]]: ...
[out]

[case testDecoratedInit]
from typing import Callable, Any
def dec(func: Callable[[Any], None]) -> Callable[[Any], None]:
return func

class A:
@dec
def __init__(self):
pass

reveal_type(A()) # E: Revealed type is '__main__.A'
[out]

[case testAbstractInit]
from abc import abstractmethod
class A:
@abstractmethod
def __init__(self): ...

class B(A):
def __init__(self):
pass

reveal_type(B()) # E: Revealed type is '__main__.B'
[out]
13 changes: 13 additions & 0 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2383,3 +2383,16 @@ def foo() -> None:

def lol():
x = foo()


[case testNonCallableDecorator]
def dec(func) -> int:
return 1

@dec
def f():
pass

reveal_type(f) # E: Revealed type is 'builtins.int'
f() # E: "int" not callable
[out]