Skip to content

Commit 4e34fec

Browse files
authored
Now TypeInfo.get_method also returns Decorator nodes (#11150)
Support decorators properly in additional contexts. Closes #10409
1 parent 872bc86 commit 4e34fec

10 files changed

+505
-66
lines changed

mypy/checker.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@
4646
)
4747
import mypy.checkexpr
4848
from mypy.checkmember import (
49-
analyze_member_access, analyze_descriptor_access, type_object_type,
49+
MemberContext, analyze_member_access, analyze_descriptor_access, analyze_var,
50+
type_object_type,
5051
)
5152
from mypy.typeops import (
5253
map_type_from_supertype, bind_self, erase_to_bound, make_simplified_union,
@@ -3205,9 +3206,12 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
32053206
code=codes.ASSIGNMENT)
32063207
return rvalue_type, attribute_type, True
32073208

3208-
get_type = analyze_descriptor_access(
3209-
instance_type, attribute_type, self.named_type,
3210-
self.msg, context, chk=self)
3209+
mx = MemberContext(
3210+
is_lvalue=False, is_super=False, is_operator=False,
3211+
original_type=instance_type, context=context, self_type=None,
3212+
msg=self.msg, chk=self,
3213+
)
3214+
get_type = analyze_descriptor_access(attribute_type, mx)
32113215
if not attribute_type.type.has_readable_member('__set__'):
32123216
# If there is no __set__, we type-check that the assigned value matches
32133217
# the return type of __get__. This doesn't match the python semantics,
@@ -3221,9 +3225,15 @@ def check_member_assignment(self, instance_type: Type, attribute_type: Type,
32213225
if dunder_set is None:
32223226
self.fail(message_registry.DESCRIPTOR_SET_NOT_CALLABLE.format(attribute_type), context)
32233227
return AnyType(TypeOfAny.from_error), get_type, False
3224-
3225-
function = function_type(dunder_set, self.named_type('builtins.function'))
3226-
bound_method = bind_self(function, attribute_type)
3228+
if isinstance(dunder_set, Decorator):
3229+
bound_method = analyze_var(
3230+
'__set__', dunder_set.var, attribute_type, attribute_type.type, mx,
3231+
)
3232+
else:
3233+
bound_method = bind_self(
3234+
function_type(dunder_set, self.named_type('builtins.function')),
3235+
attribute_type,
3236+
)
32273237
typ = map_instance_to_supertype(attribute_type, dunder_set.info)
32283238
dunder_set_type = expand_type_by_instance(bound_method, typ)
32293239

@@ -6214,6 +6224,12 @@ def is_untyped_decorator(typ: Optional[Type]) -> bool:
62146224
elif isinstance(typ, Instance):
62156225
method = typ.type.get_method('__call__')
62166226
if method:
6227+
if isinstance(method, Decorator):
6228+
return (
6229+
is_untyped_decorator(method.func.type)
6230+
or is_untyped_decorator(method.var.type)
6231+
)
6232+
62176233
if isinstance(method.type, Overloaded):
62186234
return any(is_untyped_decorator(item) for item in method.type.items)
62196235
else:

mypy/checkmember.py

Lines changed: 64 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -67,14 +67,17 @@ def not_ready_callback(self, name: str, context: Context) -> None:
6767
self.chk.handle_cannot_determine_type(name, context)
6868

6969
def copy_modified(self, *, messages: Optional[MessageBuilder] = None,
70-
self_type: Optional[Type] = None) -> 'MemberContext':
70+
self_type: Optional[Type] = None,
71+
is_lvalue: Optional[bool] = None) -> 'MemberContext':
7172
mx = MemberContext(self.is_lvalue, self.is_super, self.is_operator,
7273
self.original_type, self.context, self.msg, self.chk,
7374
self.self_type, self.module_symbol_table)
7475
if messages is not None:
7576
mx.msg = messages
7677
if self_type is not None:
7778
mx.self_type = self_type
79+
if is_lvalue is not None:
80+
mx.is_lvalue = is_lvalue
7881
return mx
7982

8083

@@ -197,7 +200,7 @@ def analyze_instance_member_access(name: str,
197200

198201
# Look up the member. First look up the method dictionary.
199202
method = info.get_method(name)
200-
if method:
203+
if method and not isinstance(method, Decorator):
201204
if method.is_property:
202205
assert isinstance(method, OverloadedFuncDef)
203206
first_item = cast(Decorator, method.items[0])
@@ -390,29 +393,46 @@ def analyze_member_var_access(name: str,
390393
if not mx.is_lvalue:
391394
for method_name in ('__getattribute__', '__getattr__'):
392395
method = info.get_method(method_name)
396+
393397
# __getattribute__ is defined on builtins.object and returns Any, so without
394398
# the guard this search will always find object.__getattribute__ and conclude
395399
# that the attribute exists
396400
if method and method.info.fullname != 'builtins.object':
397-
function = function_type(method, mx.named_type('builtins.function'))
398-
bound_method = bind_self(function, mx.self_type)
401+
if isinstance(method, Decorator):
402+
# https://github.com/python/mypy/issues/10409
403+
bound_method = analyze_var(method_name, method.var, itype, info, mx)
404+
else:
405+
bound_method = bind_self(
406+
function_type(method, mx.named_type('builtins.function')),
407+
mx.self_type,
408+
)
399409
typ = map_instance_to_supertype(itype, method.info)
400410
getattr_type = get_proper_type(expand_type_by_instance(bound_method, typ))
401411
if isinstance(getattr_type, CallableType):
402412
result = getattr_type.ret_type
403-
404-
# Call the attribute hook before returning.
405-
fullname = '{}.{}'.format(method.info.fullname, name)
406-
hook = mx.chk.plugin.get_attribute_hook(fullname)
407-
if hook:
408-
result = hook(AttributeContext(get_proper_type(mx.original_type),
409-
result, mx.context, mx.chk))
410-
return result
413+
else:
414+
result = getattr_type
415+
416+
# Call the attribute hook before returning.
417+
fullname = '{}.{}'.format(method.info.fullname, name)
418+
hook = mx.chk.plugin.get_attribute_hook(fullname)
419+
if hook:
420+
result = hook(AttributeContext(get_proper_type(mx.original_type),
421+
result, mx.context, mx.chk))
422+
return result
411423
else:
412424
setattr_meth = info.get_method('__setattr__')
413425
if setattr_meth and setattr_meth.info.fullname != 'builtins.object':
414-
setattr_func = function_type(setattr_meth, mx.named_type('builtins.function'))
415-
bound_type = bind_self(setattr_func, mx.self_type)
426+
if isinstance(setattr_meth, Decorator):
427+
bound_type = analyze_var(
428+
name, setattr_meth.var, itype, info,
429+
mx.copy_modified(is_lvalue=False),
430+
)
431+
else:
432+
bound_type = bind_self(
433+
function_type(setattr_meth, mx.named_type('builtins.function')),
434+
mx.self_type,
435+
)
416436
typ = map_instance_to_supertype(itype, setattr_meth.info)
417437
setattr_type = get_proper_type(expand_type_by_instance(bound_type, typ))
418438
if isinstance(setattr_type, CallableType) and len(setattr_type.arg_types) > 0:
@@ -441,32 +461,24 @@ def check_final_member(name: str, info: TypeInfo, msg: MessageBuilder, ctx: Cont
441461
msg.cant_assign_to_final(name, attr_assign=True, ctx=ctx)
442462

443463

444-
def analyze_descriptor_access(instance_type: Type,
445-
descriptor_type: Type,
446-
named_type: Callable[[str], Instance],
447-
msg: MessageBuilder,
448-
context: Context, *,
449-
chk: 'mypy.checker.TypeChecker') -> Type:
464+
def analyze_descriptor_access(descriptor_type: Type,
465+
mx: MemberContext) -> Type:
450466
"""Type check descriptor access.
451467
452468
Arguments:
453-
instance_type: The type of the instance on which the descriptor
454-
attribute is being accessed (the type of ``a`` in ``a.f`` when
455-
``f`` is a descriptor).
456469
descriptor_type: The type of the descriptor attribute being accessed
457470
(the type of ``f`` in ``a.f`` when ``f`` is a descriptor).
458-
context: The node defining the context of this inference.
471+
mx: The current member access context.
459472
Return:
460473
The return type of the appropriate ``__get__`` overload for the descriptor.
461474
"""
462-
instance_type = get_proper_type(instance_type)
475+
instance_type = get_proper_type(mx.original_type)
463476
descriptor_type = get_proper_type(descriptor_type)
464477

465478
if isinstance(descriptor_type, UnionType):
466479
# Map the access over union types
467480
return make_simplified_union([
468-
analyze_descriptor_access(instance_type, typ, named_type,
469-
msg, context, chk=chk)
481+
analyze_descriptor_access(typ, mx)
470482
for typ in descriptor_type.items
471483
])
472484
elif not isinstance(descriptor_type, Instance):
@@ -476,13 +488,21 @@ def analyze_descriptor_access(instance_type: Type,
476488
return descriptor_type
477489

478490
dunder_get = descriptor_type.type.get_method('__get__')
479-
480491
if dunder_get is None:
481-
msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), context)
492+
mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type),
493+
mx.context)
482494
return AnyType(TypeOfAny.from_error)
483495

484-
function = function_type(dunder_get, named_type('builtins.function'))
485-
bound_method = bind_self(function, descriptor_type)
496+
if isinstance(dunder_get, Decorator):
497+
bound_method = analyze_var(
498+
'__set__', dunder_get.var, descriptor_type, descriptor_type.type, mx,
499+
)
500+
else:
501+
bound_method = bind_self(
502+
function_type(dunder_get, mx.named_type('builtins.function')),
503+
descriptor_type,
504+
)
505+
486506
typ = map_instance_to_supertype(descriptor_type, dunder_get.info)
487507
dunder_get_type = expand_type_by_instance(bound_method, typ)
488508

@@ -495,19 +515,19 @@ def analyze_descriptor_access(instance_type: Type,
495515
else:
496516
owner_type = instance_type
497517

498-
callable_name = chk.expr_checker.method_fullname(descriptor_type, "__get__")
499-
dunder_get_type = chk.expr_checker.transform_callee_type(
518+
callable_name = mx.chk.expr_checker.method_fullname(descriptor_type, "__get__")
519+
dunder_get_type = mx.chk.expr_checker.transform_callee_type(
500520
callable_name, dunder_get_type,
501-
[TempNode(instance_type, context=context),
502-
TempNode(TypeType.make_normalized(owner_type), context=context)],
503-
[ARG_POS, ARG_POS], context, object_type=descriptor_type,
521+
[TempNode(instance_type, context=mx.context),
522+
TempNode(TypeType.make_normalized(owner_type), context=mx.context)],
523+
[ARG_POS, ARG_POS], mx.context, object_type=descriptor_type,
504524
)
505525

506-
_, inferred_dunder_get_type = chk.expr_checker.check_call(
526+
_, inferred_dunder_get_type = mx.chk.expr_checker.check_call(
507527
dunder_get_type,
508-
[TempNode(instance_type, context=context),
509-
TempNode(TypeType.make_normalized(owner_type), context=context)],
510-
[ARG_POS, ARG_POS], context, object_type=descriptor_type,
528+
[TempNode(instance_type, context=mx.context),
529+
TempNode(TypeType.make_normalized(owner_type), context=mx.context)],
530+
[ARG_POS, ARG_POS], mx.context, object_type=descriptor_type,
511531
callable_name=callable_name)
512532

513533
inferred_dunder_get_type = get_proper_type(inferred_dunder_get_type)
@@ -516,7 +536,8 @@ def analyze_descriptor_access(instance_type: Type,
516536
return inferred_dunder_get_type
517537

518538
if not isinstance(inferred_dunder_get_type, CallableType):
519-
msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type), context)
539+
mx.msg.fail(message_registry.DESCRIPTOR_GET_NOT_CALLABLE.format(descriptor_type),
540+
mx.context)
520541
return AnyType(TypeOfAny.from_error)
521542

522543
return inferred_dunder_get_type.ret_type
@@ -605,8 +626,7 @@ def analyze_var(name: str,
605626
fullname = '{}.{}'.format(var.info.fullname, name)
606627
hook = mx.chk.plugin.get_attribute_hook(fullname)
607628
if result and not mx.is_lvalue and not implicit:
608-
result = analyze_descriptor_access(mx.original_type, result, mx.named_type,
609-
mx.msg, mx.context, chk=mx.chk)
629+
result = analyze_descriptor_access(result, mx)
610630
if hook:
611631
result = hook(AttributeContext(get_proper_type(mx.original_type),
612632
result, mx.context, mx.chk))
@@ -785,8 +805,7 @@ def analyze_class_attribute_access(itype: Instance,
785805
result = add_class_tvars(t, isuper, is_classmethod,
786806
mx.self_type, original_vars=original_vars)
787807
if not mx.is_lvalue:
788-
result = analyze_descriptor_access(mx.original_type, result, mx.named_type,
789-
mx.msg, mx.context, chk=mx.chk)
808+
result = analyze_descriptor_access(result, mx)
790809
return result
791810
elif isinstance(node.node, Var):
792811
mx.not_ready_callback(name, mx.context)

mypy/nodes.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2687,12 +2687,14 @@ def __bool__(self) -> bool:
26872687
def has_readable_member(self, name: str) -> bool:
26882688
return self.get(name) is not None
26892689

2690-
def get_method(self, name: str) -> Optional[FuncBase]:
2690+
def get_method(self, name: str) -> Union[FuncBase, Decorator, None]:
26912691
for cls in self.mro:
26922692
if name in cls.names:
26932693
node = cls.names[name].node
26942694
if isinstance(node, FuncBase):
26952695
return node
2696+
elif isinstance(node, Decorator): # Two `if`s make `mypyc` happy
2697+
return node
26962698
else:
26972699
return None
26982700
return None

mypy/subtypes.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,8 @@ def find_member(name: str,
650650
info = itype.type
651651
method = info.get_method(name)
652652
if method:
653+
if isinstance(method, Decorator):
654+
return find_node_type(method.var, itype, subtype)
653655
if method.is_property:
654656
assert isinstance(method, OverloadedFuncDef)
655657
dec = method.items[0]
@@ -659,12 +661,7 @@ def find_member(name: str,
659661
else:
660662
# don't have such method, maybe variable or decorator?
661663
node = info.get(name)
662-
if not node:
663-
v = None
664-
else:
665-
v = node.node
666-
if isinstance(v, Decorator):
667-
v = v.var
664+
v = node.node if node else None
668665
if isinstance(v, Var):
669666
return find_node_type(v, itype, subtype)
670667
if (not v and name not in ['__getattr__', '__setattr__', '__getattribute__'] and
@@ -676,9 +673,13 @@ def find_member(name: str,
676673
# structural subtyping.
677674
method = info.get_method(method_name)
678675
if method and method.info.fullname != 'builtins.object':
679-
getattr_type = get_proper_type(find_node_type(method, itype, subtype))
676+
if isinstance(method, Decorator):
677+
getattr_type = get_proper_type(find_node_type(method.var, itype, subtype))
678+
else:
679+
getattr_type = get_proper_type(find_node_type(method, itype, subtype))
680680
if isinstance(getattr_type, CallableType):
681681
return getattr_type.ret_type
682+
return getattr_type
682683
if itype.type.fallback_to_any:
683684
return AnyType(TypeOfAny.special_form)
684685
return None
@@ -698,8 +699,10 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
698699
method = info.get_method(name)
699700
setattr_meth = info.get_method('__setattr__')
700701
if method:
701-
# this could be settable property
702-
if method.is_property:
702+
if isinstance(method, Decorator):
703+
if method.var.is_staticmethod or method.var.is_classmethod:
704+
return {IS_CLASS_OR_STATIC}
705+
elif method.is_property: # this could be settable property
703706
assert isinstance(method, OverloadedFuncDef)
704707
dec = method.items[0]
705708
assert isinstance(dec, Decorator)
@@ -712,9 +715,6 @@ def get_member_flags(name: str, info: TypeInfo) -> Set[int]:
712715
return {IS_SETTABLE}
713716
return set()
714717
v = node.node
715-
if isinstance(v, Decorator):
716-
if v.var.is_staticmethod or v.var.is_classmethod:
717-
return {IS_CLASS_OR_STATIC}
718718
# just a variable
719719
if isinstance(v, Var) and not v.is_property:
720720
flags = {IS_SETTABLE}

0 commit comments

Comments
 (0)