Skip to content

Commit 72734f2

Browse files
authored
Support method plugin hooks on unions (#6560)
Fixes #6117 Fixes #5930 Currently both our plugin method hooks don't work with unions. This PR fixes this with three things: * Moves a bit of logic from `visit_call_expr_inner()` (which is a long method already) to `check_call_expr_with_callee_type()` (which is a short method). * Special-cases unions in `check_call_expr_with_callee_type()` (normal method calls) and `check_method_call_by_name()` (dunder/operator method calls). * Adds some clarifying comments and a docstring. The week point is interaction with binder, but IMO this is the best we can have for now. I left a comment mentioning that check for overlap should be consistent in two functions. In general, I don't like special-casing, but I spent several days thinking of other solutions, and it looks like special-casing unions in couple more places is the only reasonable way to fix unions-vs-plugins interactions. This PR may interfere with #6558 that fixes an "opposite" problem, hopefully they will work together unmodified, so that accessing union of literals on union of typed dicts works. Whatever PR lands second, should add a test for this.
1 parent b724cca commit 72734f2

File tree

5 files changed

+273
-21
lines changed

5 files changed

+273
-21
lines changed

mypy/checkexpr.py

Lines changed: 109 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import (
66
cast, Dict, Set, List, Tuple, Callable, Union, Optional, Sequence, Iterator
77
)
8-
from typing_extensions import ClassVar, Final
8+
from typing_extensions import ClassVar, Final, overload
99

1010
from mypy.errors import report_internal_error
1111
from mypy.typeanal import (
@@ -284,24 +284,27 @@ def visit_call_expr_inner(self, e: CallExpr, allow_none_return: bool = False) ->
284284
return self.msg.untyped_function_call(callee_type, e)
285285
# Figure out the full name of the callee for plugin lookup.
286286
object_type = None
287-
if not isinstance(e.callee, RefExpr):
288-
fullname = None
289-
else:
287+
member = None
288+
fullname = None
289+
if isinstance(e.callee, RefExpr):
290+
# There are two special cases where plugins might act:
291+
# * A "static" reference/alias to a class or function;
292+
# get_function_hook() will be invoked for these.
290293
fullname = e.callee.fullname
291294
if (isinstance(e.callee.node, TypeAlias) and
292295
isinstance(e.callee.node.target, Instance)):
293296
fullname = e.callee.node.target.type.fullname()
297+
# * Call to a method on object that has a full name (see
298+
# method_fullname() for details on supported objects);
299+
# get_method_hook() and get_method_signature_hook() will
300+
# be invoked for these.
294301
if (fullname is None
295302
and isinstance(e.callee, MemberExpr)
296-
and e.callee.expr in self.chk.type_map
297-
and isinstance(callee_type, FunctionLike)):
298-
# For method calls we include the defining class for the method
299-
# in the full name (example: 'typing.Mapping.get').
300-
callee_expr_type = self.chk.type_map[e.callee.expr]
301-
fullname = self.method_fullname(callee_expr_type, e.callee.name)
302-
if fullname is not None:
303-
object_type = callee_expr_type
304-
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname, object_type)
303+
and e.callee.expr in self.chk.type_map):
304+
member = e.callee.name
305+
object_type = self.chk.type_map[e.callee.expr]
306+
ret_type = self.check_call_expr_with_callee_type(callee_type, e, fullname,
307+
object_type, member)
305308
if isinstance(e.callee, RefExpr) and len(e.args) == 2:
306309
if e.callee.fullname in ('builtins.isinstance', 'builtins.issubclass'):
307310
self.check_runtime_protocol_test(e)
@@ -632,21 +635,53 @@ def check_call_expr_with_callee_type(self,
632635
callee_type: Type,
633636
e: CallExpr,
634637
callable_name: Optional[str],
635-
object_type: Optional[Type]) -> Type:
638+
object_type: Optional[Type],
639+
member: Optional[str] = None) -> Type:
636640
"""Type check call expression.
637641
638-
The given callee type overrides the type of the callee
639-
expression.
640-
"""
641-
# Try to refine the call signature using plugin hooks before checking the call.
642-
callee_type = self.transform_callee_type(
643-
callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type)
642+
The callee_type should be used as the type of callee expression. In particular,
643+
in case of a union type this can be a particular item of the union, so that we can
644+
apply plugin hooks to each item.
644645
646+
The 'member', 'callable_name' and 'object_type' are only used to call plugin hooks.
647+
If 'callable_name' is None but 'member' is not None (member call), try constructing
648+
'callable_name' using 'object_type' (the base type on which the method is called),
649+
for example 'typing.Mapping.get'.
650+
"""
651+
if callable_name is None and member is not None:
652+
assert object_type is not None
653+
callable_name = self.method_fullname(object_type, member)
654+
if callable_name:
655+
# Try to refine the call signature using plugin hooks before checking the call.
656+
callee_type = self.transform_callee_type(
657+
callable_name, callee_type, e.args, e.arg_kinds, e, e.arg_names, object_type)
658+
# Unions are special-cased to allow plugins to act on each item in the union.
659+
elif member is not None and isinstance(object_type, UnionType):
660+
return self.check_union_call_expr(e, object_type, member)
645661
return self.check_call(callee_type, e.args, e.arg_kinds, e,
646662
e.arg_names, callable_node=e.callee,
647663
callable_name=callable_name,
648664
object_type=object_type)[0]
649665

666+
def check_union_call_expr(self, e: CallExpr, object_type: UnionType, member: str) -> Type:
667+
""""Type check calling a member expression where the base type is a union."""
668+
res = [] # type: List[Type]
669+
for typ in object_type.relevant_items():
670+
# Member access errors are already reported when visiting the member expression.
671+
self.msg.disable_errors()
672+
item = analyze_member_access(member, typ, e, False, False, False,
673+
self.msg, original_type=object_type, chk=self.chk,
674+
in_literal_context=self.is_literal_context())
675+
self.msg.enable_errors()
676+
narrowed = self.narrow_type_from_binder(e.callee, item, skip_non_overlapping=True)
677+
if narrowed is None:
678+
continue
679+
callable_name = self.method_fullname(typ, member)
680+
item_object_type = typ if callable_name else None
681+
res.append(self.check_call_expr_with_callee_type(narrowed, e, callable_name,
682+
item_object_type))
683+
return UnionType.make_simplified_union(res)
684+
650685
def check_call(self,
651686
callee: Type,
652687
args: List[Expression],
@@ -2018,13 +2053,48 @@ def check_method_call_by_name(self,
20182053
"""
20192054
local_errors = local_errors or self.msg
20202055
original_type = original_type or base_type
2056+
# Unions are special-cased to allow plugins to act on each element of the union.
2057+
if isinstance(base_type, UnionType):
2058+
return self.check_union_method_call_by_name(method, base_type,
2059+
args, arg_kinds,
2060+
context, local_errors, original_type)
2061+
20212062
method_type = analyze_member_access(method, base_type, context, False, False, True,
20222063
local_errors, original_type=original_type,
20232064
chk=self.chk,
20242065
in_literal_context=self.is_literal_context())
20252066
return self.check_method_call(
20262067
method, base_type, method_type, args, arg_kinds, context, local_errors)
20272068

2069+
def check_union_method_call_by_name(self,
2070+
method: str,
2071+
base_type: UnionType,
2072+
args: List[Expression],
2073+
arg_kinds: List[int],
2074+
context: Context,
2075+
local_errors: MessageBuilder,
2076+
original_type: Optional[Type] = None
2077+
) -> Tuple[Type, Type]:
2078+
"""Type check a call to a named method on an object with union type.
2079+
2080+
This essentially checks the call using check_method_call_by_name() for each
2081+
union item and unions the result. We do this to allow plugins to act on
2082+
individual union items.
2083+
"""
2084+
res = [] # type: List[Type]
2085+
meth_res = [] # type: List[Type]
2086+
for typ in base_type.relevant_items():
2087+
# Format error messages consistently with
2088+
# mypy.checkmember.analyze_union_member_access().
2089+
local_errors.disable_type_names += 1
2090+
item, meth_item = self.check_method_call_by_name(method, typ, args, arg_kinds,
2091+
context, local_errors,
2092+
original_type)
2093+
local_errors.disable_type_names -= 1
2094+
res.append(item)
2095+
meth_res.append(meth_item)
2096+
return UnionType.make_simplified_union(res), UnionType.make_simplified_union(meth_res)
2097+
20282098
def check_method_call(self,
20292099
method_name: str,
20302100
base_type: Type,
@@ -3524,14 +3594,32 @@ def bool_type(self) -> Instance:
35243594
"""Return instance type 'bool'."""
35253595
return self.named_type('builtins.bool')
35263596

3527-
def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type:
3597+
@overload
3598+
def narrow_type_from_binder(self, expr: Expression, known_type: Type) -> Type: ...
3599+
3600+
@overload # noqa
3601+
def narrow_type_from_binder(self, expr: Expression, known_type: Type,
3602+
skip_non_overlapping: bool) -> Optional[Type]: ...
3603+
3604+
def narrow_type_from_binder(self, expr: Expression, known_type: Type, # noqa
3605+
skip_non_overlapping: bool = False) -> Optional[Type]:
3606+
"""Narrow down a known type of expression using information in conditional type binder.
3607+
3608+
If 'skip_non_overlapping' is True, return None if the type and restriction are
3609+
non-overlapping.
3610+
"""
35283611
if literal(expr) >= LITERAL_TYPE:
35293612
restriction = self.chk.binder.get(expr)
35303613
# If the current node is deferred, some variables may get Any types that they
35313614
# otherwise wouldn't have. We don't want to narrow down these since it may
35323615
# produce invalid inferred Optional[Any] types, at least.
35333616
if restriction and not (isinstance(known_type, AnyType)
35343617
and self.chk.current_node_deferred):
3618+
# Note: this call should match the one in narrow_declared_type().
3619+
if (skip_non_overlapping and
3620+
not is_overlapping_types(known_type, restriction,
3621+
prohibit_none_typevar_overlap=True)):
3622+
return None
35353623
ans = narrow_declared_type(known_type, restriction)
35363624
return ans
35373625
return known_type

test-data/unit/check-ctypes.test

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ for x in a:
2323

2424
[case testCtypesArrayCustomElementType]
2525
import ctypes
26+
from typing import Union, List
2627

2728
class MyCInt(ctypes.c_int):
2829
pass
@@ -46,6 +47,10 @@ mya[3] = b"bytes" # E: No overload variant of "__setitem__" of "Array" matches
4647
# N: def __setitem__(self, slice, List[Union[MyCInt, int]]) -> None
4748
for myx in mya:
4849
reveal_type(myx) # N: Revealed type is '__main__.MyCInt*'
50+
51+
myu: Union[ctypes.Array[ctypes.c_int], List[str]]
52+
for myi in myu:
53+
reveal_type(myi) # N: Revealed type is 'Union[builtins.int*, builtins.str*]'
4954
[builtins fixtures/floatdict.pyi]
5055

5156
[case testCtypesArrayUnionElementType]

test-data/unit/check-custom-plugin.test

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,67 @@ reveal_type(instance(2)) # N: Revealed type is 'builtins.float*'
585585
[[mypy]
586586
plugins=<ROOT>/test-data/unit/plugins/callable_instance.py
587587

588+
[case testGetMethodHooksOnUnions]
589+
# flags: --config-file tmp/mypy.ini --no-strict-optional
590+
from typing import Union
591+
592+
class Foo:
593+
def meth(self, x: str) -> str: ...
594+
class Bar:
595+
def meth(self, x: int) -> float: ...
596+
class Other:
597+
meth: int
598+
599+
x: Union[Foo, Bar, Other]
600+
if isinstance(x.meth, int):
601+
reveal_type(x.meth) # N: Revealed type is 'builtins.int'
602+
else:
603+
reveal_type(x.meth(int())) # N: Revealed type is 'builtins.int'
604+
605+
[builtins fixtures/isinstancelist.pyi]
606+
[file mypy.ini]
607+
[[mypy]
608+
plugins=<ROOT>/test-data/unit/plugins/union_method.py
609+
610+
[case testGetMethodHooksOnUnionsStrictOptional]
611+
# flags: --config-file tmp/mypy.ini --strict-optional
612+
from typing import Union
613+
614+
class Foo:
615+
def meth(self, x: str) -> str: ...
616+
class Bar:
617+
def meth(self, x: int) -> float: ...
618+
class Other:
619+
meth: int
620+
621+
x: Union[Foo, Bar, Other]
622+
if isinstance(x.meth, int):
623+
reveal_type(x.meth) # N: Revealed type is 'builtins.int'
624+
else:
625+
reveal_type(x.meth(int())) # N: Revealed type is 'builtins.int'
626+
627+
[builtins fixtures/isinstancelist.pyi]
628+
[file mypy.ini]
629+
[[mypy]
630+
plugins=<ROOT>/test-data/unit/plugins/union_method.py
631+
632+
[case testGetMethodHooksOnUnionsSpecial]
633+
# flags: --config-file tmp/mypy.ini
634+
from typing import Union
635+
636+
class Foo:
637+
def __getitem__(self, x: str) -> str: ...
638+
class Bar:
639+
def __getitem__(self, x: int) -> float: ...
640+
641+
x: Union[Foo, Bar]
642+
reveal_type(x[int()]) # N: Revealed type is 'builtins.int'
643+
644+
[builtins fixtures/isinstancelist.pyi]
645+
[file mypy.ini]
646+
[[mypy]
647+
plugins=<ROOT>/test-data/unit/plugins/union_method.py
648+
588649
[case testPluginDependencies]
589650
# flags: --config-file tmp/mypy.ini
590651

test-data/unit/check-typeddict.test

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,6 +1589,55 @@ alias('x') # E: Argument 1 has incompatible type "str"; expected "NoReturn"
15891589
alias(s) # E: Argument 1 has incompatible type "str"; expected "NoReturn"
15901590
[builtins fixtures/dict.pyi]
15911591

1592+
[case testPluginUnionsOfTypedDicts]
1593+
from typing import Union
1594+
from mypy_extensions import TypedDict
1595+
1596+
class TDA(TypedDict):
1597+
a: int
1598+
b: str
1599+
1600+
class TDB(TypedDict):
1601+
a: int
1602+
b: int
1603+
c: int
1604+
1605+
td: Union[TDA, TDB]
1606+
1607+
reveal_type(td.get('a')) # N: Revealed type is 'builtins.int'
1608+
reveal_type(td.get('b')) # N: Revealed type is 'Union[builtins.str, builtins.int]'
1609+
reveal_type(td.get('c')) # N: Revealed type is 'Union[Any, builtins.int]' \
1610+
# E: TypedDict "TDA" has no key 'c'
1611+
1612+
reveal_type(td['a']) # N: Revealed type is 'builtins.int'
1613+
reveal_type(td['b']) # N: Revealed type is 'Union[builtins.str, builtins.int]'
1614+
reveal_type(td['c']) # N: Revealed type is 'Union[Any, builtins.int]' \
1615+
# E: TypedDict "TDA" has no key 'c'
1616+
[builtins fixtures/dict.pyi]
1617+
[typing fixtures/typing-full.pyi]
1618+
1619+
[case testPluginUnionsOfTypedDictsNonTotal]
1620+
from typing import Union
1621+
from mypy_extensions import TypedDict
1622+
1623+
class TDA(TypedDict, total=False):
1624+
a: int
1625+
b: str
1626+
1627+
class TDB(TypedDict, total=False):
1628+
a: int
1629+
b: int
1630+
c: int
1631+
1632+
td: Union[TDA, TDB]
1633+
1634+
reveal_type(td.pop('a')) # N: Revealed type is 'builtins.int'
1635+
reveal_type(td.pop('b')) # N: Revealed type is 'Union[builtins.str, builtins.int]'
1636+
reveal_type(td.pop('c')) # N: Revealed type is 'Union[Any, builtins.int]' \
1637+
# E: TypedDict "TDA" has no key 'c'
1638+
[builtins fixtures/dict.pyi]
1639+
[typing fixtures/typing-full.pyi]
1640+
15921641
[case testCanCreateTypedDictWithTypingExtensions]
15931642
# flags: --python-version 3.6
15941643
from typing_extensions import TypedDict
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from mypy.plugin import (
2+
CallableType, CheckerPluginInterface, MethodSigContext, MethodContext, Plugin
3+
)
4+
from mypy.types import Instance, Type
5+
6+
7+
class MethodPlugin(Plugin):
8+
def get_method_signature_hook(self, fullname):
9+
if fullname.startswith('__main__.Foo.'):
10+
return my_meth_sig_hook
11+
return None
12+
13+
def get_method_hook(self, fullname):
14+
if fullname.startswith('__main__.Bar.'):
15+
return my_meth_hook
16+
return None
17+
18+
19+
def _str_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
20+
if isinstance(typ, Instance):
21+
if typ.type.fullname() == 'builtins.str':
22+
return api.named_generic_type('builtins.int', [])
23+
elif typ.args:
24+
return typ.copy_modified(args=[_str_to_int(api, t) for t in typ.args])
25+
return typ
26+
27+
28+
def _float_to_int(api: CheckerPluginInterface, typ: Type) -> Type:
29+
if isinstance(typ, Instance):
30+
if typ.type.fullname() == 'builtins.float':
31+
return api.named_generic_type('builtins.int', [])
32+
elif typ.args:
33+
return typ.copy_modified(args=[_float_to_int(api, t) for t in typ.args])
34+
return typ
35+
36+
37+
def my_meth_sig_hook(ctx: MethodSigContext) -> CallableType:
38+
return ctx.default_signature.copy_modified(
39+
arg_types=[_str_to_int(ctx.api, t) for t in ctx.default_signature.arg_types],
40+
ret_type=_str_to_int(ctx.api, ctx.default_signature.ret_type),
41+
)
42+
43+
44+
def my_meth_hook(ctx: MethodContext) -> Type:
45+
return _float_to_int(ctx.api, ctx.default_return_type)
46+
47+
48+
def plugin(version):
49+
return MethodPlugin

0 commit comments

Comments
 (0)