Skip to content

Commit 4edaacb

Browse files
committed
Correctly handle generics when doing overload union math
1 parent 0da38c4 commit 4edaacb

File tree

2 files changed

+106
-27
lines changed

2 files changed

+106
-27
lines changed

mypy/checkexpr.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -659,7 +659,8 @@ def check_call(self, callee: Type, args: List[Expression],
659659
# and return type.
660660

661661
targets = cast(List[CallableType], targets)
662-
unioned_callable = self.union_overload_matches(targets)
662+
unioned_callable = self.union_overload_matches(
663+
targets, args, arg_kinds, arg_names, context)
663664
if unioned_callable is None:
664665
# If it was not possible to actually combine together the
665666
# callables in a sound way, we give up and return the original
@@ -1204,14 +1205,18 @@ def overload_call_targets(self, arg_types: List[Type], arg_kinds: List[int],
12041205
arg_types, arg_kinds, arg_names, m, context=context)]
12051206
return out if len(out) >= 1 else match
12061207

1207-
def union_overload_matches(self, callables: List[CallableType]) -> Optional[CallableType]:
1208+
def union_overload_matches(self, callables: List[CallableType],
1209+
args: List[Expression],
1210+
arg_kinds: List[int],
1211+
arg_names: Optional[Sequence[Optional[str]]],
1212+
context: Context) -> Optional[CallableType]:
12081213
"""Accepts a list of overload signatures and attempts to combine them together into a
12091214
new CallableType consisting of the union of all of the given arguments and return types.
12101215
12111216
Returns None if it is not possible to combine the different callables together in a
12121217
sound manner."""
1213-
12141218
new_args = [[] for _ in range(len(callables[0].arg_types))] # type: List[List[Type]]
1219+
new_returns = [] # type: List[Type]
12151220

12161221
expected_names = callables[0].arg_names
12171222
expected_kinds = callables[0].arg_kinds
@@ -1222,13 +1227,25 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
12221227
# TODO: Enhance the union overload logic to handle a wider variety of signatures.
12231228
return None
12241229

1230+
if target.is_generic():
1231+
formal_to_actual = map_actuals_to_formals(
1232+
arg_kinds, arg_names,
1233+
target.arg_kinds, target.arg_names,
1234+
lambda i: self.accept(args[i]))
1235+
1236+
target = freshen_function_type_vars(target)
1237+
target = self.infer_function_type_arguments_using_context(target, context)
1238+
target = self.infer_function_type_arguments(
1239+
target, args, arg_kinds, formal_to_actual, context)
1240+
12251241
for i, arg in enumerate(target.arg_types):
12261242
new_args[i].append(arg)
1243+
new_returns.append(target.ret_type)
12271244

12281245
union_count = 0
12291246
final_args = []
1230-
for args in new_args:
1231-
new_type = UnionType.make_simplified_union(args)
1247+
for args_list in new_args:
1248+
new_type = UnionType.make_simplified_union(args_list)
12321249
union_count += 1 if isinstance(new_type, UnionType) else 0
12331250
final_args.append(new_type)
12341251

@@ -1256,7 +1273,7 @@ def union_overload_matches(self, callables: List[CallableType]) -> Optional[Call
12561273

12571274
return callables[0].copy_modified(
12581275
arg_types=final_args,
1259-
ret_type=UnionType.make_simplified_union([t.ret_type for t in callables]),
1276+
ret_type=UnionType.make_simplified_union(new_returns),
12601277
implicit=True,
12611278
from_overloads=True)
12621279

test-data/unit/check-overloading.test

Lines changed: 83 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1444,52 +1444,52 @@ class Child4(Parent):
14441444
[case testOverloadWithIncompatibleMethodOverrideAndImplementation]
14451445
from typing import overload, Union, Any
14461446

1447-
class StrSub: pass
1447+
class Sub: pass
1448+
class A: pass
1449+
class B: pass
14481450

14491451
class ParentWithTypedImpl:
14501452
@overload
1451-
def f(self, arg: int) -> int: ...
1453+
def f(self, arg: A) -> A: ...
14521454
@overload
1453-
def f(self, arg: str) -> str: ...
1454-
def f(self, arg: Union[int, str]) -> Union[int, str]: ...
1455+
def f(self, arg: B) -> B: ...
1456+
def f(self, arg: Union[A, B]) -> Union[A, B]: ...
14551457

14561458
class Child1(ParentWithTypedImpl):
14571459
@overload # E: Signature of "f" incompatible with supertype "ParentWithTypedImpl"
1458-
def f(self, arg: int) -> int: ...
1460+
def f(self, arg: A) -> A: ...
14591461
@overload
1460-
def f(self, arg: StrSub) -> str: ...
1461-
def f(self, arg: Union[int, StrSub]) -> Union[int, str]: ...
1462+
def f(self, arg: Sub) -> B: ...
1463+
def f(self, arg: Union[A, Sub]) -> Union[A, B]: ...
14621464

14631465
class Child2(ParentWithTypedImpl):
14641466
@overload # E: Signature of "f" incompatible with supertype "ParentWithTypedImpl"
1465-
def f(self, arg: int) -> int: ...
1467+
def f(self, arg: A) -> A: ...
14661468
@overload
1467-
def f(self, arg: StrSub) -> str: ...
1469+
def f(self, arg: Sub) -> B: ...
14681470
def f(self, arg: Any) -> Any: ...
14691471

14701472
class ParentWithDynamicImpl:
14711473
@overload
1472-
def f(self, arg: int) -> int: ...
1474+
def f(self, arg: A) -> A: ...
14731475
@overload
1474-
def f(self, arg: str) -> str: ...
1476+
def f(self, arg: B) -> B: ...
14751477
def f(self, arg: Any) -> Any: ...
14761478

14771479
class Child3(ParentWithDynamicImpl):
14781480
@overload # E: Signature of "f" incompatible with supertype "ParentWithDynamicImpl"
1479-
def f(self, arg: int) -> int: ...
1481+
def f(self, arg: A) -> A: ...
14801482
@overload
1481-
def f(self, arg: StrSub) -> str: ...
1482-
def f(self, arg: Union[int, StrSub]) -> Union[int, str]: ...
1483+
def f(self, arg: Sub) -> B: ...
1484+
def f(self, arg: Union[A, Sub]) -> Union[A, B]: ...
14831485

14841486
class Child4(ParentWithDynamicImpl):
14851487
@overload # E: Signature of "f" incompatible with supertype "ParentWithDynamicImpl"
1486-
def f(self, arg: int) -> int: ...
1488+
def f(self, arg: A) -> A: ...
14871489
@overload
1488-
def f(self, arg: StrSub) -> str: ...
1490+
def f(self, arg: Sub) -> B: ...
14891491
def f(self, arg: Any) -> Any: ...
14901492

1491-
[builtins fixtures/tuple.pyi]
1492-
14931493
[case testOverloadInferUnionReturnBasic]
14941494
from typing import overload, Union
14951495

@@ -1515,8 +1515,6 @@ def f2(x): ...
15151515

15161516
reveal_type(f2(arg1)) # E: Revealed type is '__main__.B'
15171517

1518-
[builtins fixtures/tuple.pyi]
1519-
15201518
[case testOverloadInferUnionReturnMultipleArguments]
15211519
from typing import overload, Union
15221520

@@ -1544,7 +1542,6 @@ def f2(x, y): ...
15441542
reveal_type(f2(arg1, arg1))
15451543
reveal_type(f2(arg1, C()))
15461544

1547-
[builtins fixtures/tuple.pyi]
15481545
[out]
15491546
main:16: error: Revealed type is '__main__.B'
15501547
main:16: error: Argument 1 to "f1" has incompatible type "Union[A, C]"; expected "A"
@@ -1553,6 +1550,24 @@ main:24: error: Revealed type is 'Union[__main__.B, __main__.D]'
15531550
main:24: error: Argument 2 to "f2" has incompatible type "Union[A, C]"; expected "C"
15541551
main:25: error: Revealed type is 'Union[__main__.B, __main__.D]'
15551552

1553+
[case testOverloadInferUnionSkipIfParameterNamesAreDifferent]
1554+
from typing import overload, Union
1555+
1556+
class A: ...
1557+
class B: ...
1558+
class C: ...
1559+
1560+
@overload
1561+
def f(x: A) -> B: ...
1562+
@overload
1563+
def f(y: B) -> C: ...
1564+
def f(x): ...
1565+
1566+
x: Union[A, B]
1567+
reveal_type(f(A())) # E: Revealed type is '__main__.B'
1568+
reveal_type(f(B())) # E: Revealed type is '__main__.C'
1569+
f(x) # E: Argument 1 to "f" has incompatible type "Union[A, B]"; expected "A"
1570+
15561571
[case testOverloadInferUnionReturnFunctionsWithKwargs]
15571572
from typing import overload, Union, Optional
15581573

@@ -1584,3 +1599,50 @@ main:19: error: Revealed type is '__main__.C'
15841599
main:19: error: Argument 2 to "f" has incompatible type "Union[B, C]"; expected "Optional[B]"
15851600
main:21: error: Revealed type is '__main__.A'
15861601

1602+
[case testOverloadingInferUnionReturnWithTypevarWithValueRestriction]
1603+
from typing import overload, Union, TypeVar, Generic
1604+
1605+
class A: pass
1606+
class B: pass
1607+
class C: pass
1608+
1609+
T = TypeVar('T', B, C)
1610+
1611+
class Wrapper(Generic[T]):
1612+
@overload
1613+
def f(self, x: T) -> B: ...
1614+
1615+
@overload
1616+
def f(self, x: A) -> C: ...
1617+
1618+
def f(self, x): ...
1619+
1620+
obj: Wrapper[B] = Wrapper()
1621+
x: Union[A, B]
1622+
1623+
reveal_type(obj.f(A())) # E: Revealed type is '__main__.C'
1624+
reveal_type(obj.f(B())) # E: Revealed type is '__main__.B'
1625+
reveal_type(obj.f(x)) # E: Revealed type is 'Union[__main__.B, __main__.C]'
1626+
1627+
[case testOverloadingInferUnionReturnWithTypevarReturn]
1628+
from typing import overload, Union, TypeVar, Generic
1629+
1630+
T = TypeVar('T')
1631+
1632+
class Wrapper1(Generic[T]): pass
1633+
class Wrapper2(Generic[T]): pass
1634+
class A: pass
1635+
class B: pass
1636+
1637+
@overload
1638+
def f(x: Wrapper1[T]) -> T: ...
1639+
@overload
1640+
def f(x: Wrapper2[T]) -> T: ...
1641+
def f(x): ...
1642+
1643+
obj1: Union[Wrapper1[A], Wrapper2[A]]
1644+
reveal_type(f(obj1)) # E: Revealed type is '__main__.A'
1645+
1646+
obj2: Union[Wrapper1[A], Wrapper2[B]]
1647+
reveal_type(f(obj2)) # E: Revealed type is 'Union[__main__.A, __main__.B]'
1648+

0 commit comments

Comments
 (0)