Skip to content

Commit 0df8cf5

Browse files
Support typing_extensions.overload (#12602)
This always existed in typing_extensions, but was an alias for typing.overload. With python/typing#1140, it will actually make a difference at runtime which one you use. Note that this shouldn't change mypy's behaviour, since we alias typing_extensions.overload to typing.overload in typeshed, but this makes the logic less fragile.
1 parent 10ba5c1 commit 0df8cf5

File tree

8 files changed

+122
-17
lines changed

8 files changed

+122
-17
lines changed

mypy/checker.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,8 @@
3737
UnionType, TypeVarId, TypeVarType, PartialType, DeletedType, UninhabitedType,
3838
is_named_instance, union_items, TypeQuery, LiteralType,
3939
is_optional, remove_optional, TypeTranslator, StarType, get_proper_type, ProperType,
40-
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType
40+
get_proper_types, is_literal_type, TypeAliasType, TypeGuardedType, ParamSpecType,
41+
OVERLOAD_NAMES,
4142
)
4243
from mypy.sametypes import is_same_type
4344
from mypy.messages import (
@@ -3981,7 +3982,7 @@ def visit_decorator(self, e: Decorator) -> None:
39813982
# may be different from the declared signature.
39823983
sig: Type = self.function_type(e.func)
39833984
for d in reversed(e.decorators):
3984-
if refers_to_fullname(d, 'typing.overload'):
3985+
if refers_to_fullname(d, OVERLOAD_NAMES):
39853986
self.fail(message_registry.MULTIPLE_OVERLOADS_REQUIRED, e)
39863987
continue
39873988
dec = self.expr_checker.accept(d)

mypy/semanal.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@
9999
TypeTranslator, TypeOfAny, TypeType, NoneType, PlaceholderType, TPDICT_NAMES, ProperType,
100100
get_proper_type, get_proper_types, TypeAliasType, TypeVarLikeType, Parameters, ParamSpecType,
101101
PROTOCOL_NAMES, TYPE_ALIAS_NAMES, FINAL_TYPE_NAMES, FINAL_DECORATOR_NAMES, REVEAL_TYPE_NAMES,
102-
ASSERT_TYPE_NAMES, is_named_instance,
102+
ASSERT_TYPE_NAMES, OVERLOAD_NAMES, is_named_instance,
103103
)
104104
from mypy.typeops import function_type, get_type_vars
105105
from mypy.type_visitor import TypeQuery
@@ -835,7 +835,7 @@ def analyze_overload_sigs_and_impl(
835835
if isinstance(item, Decorator):
836836
callable = function_type(item.func, self.named_type('builtins.function'))
837837
assert isinstance(callable, CallableType)
838-
if not any(refers_to_fullname(dec, 'typing.overload')
838+
if not any(refers_to_fullname(dec, OVERLOAD_NAMES)
839839
for dec in item.decorators):
840840
if i == len(defn.items) - 1 and not self.is_stub_file:
841841
# Last item outside a stub is impl

mypy/stubgen.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
from collections import defaultdict
5555

5656
from typing import (
57-
List, Dict, Tuple, Iterable, Mapping, Optional, Set, cast,
57+
List, Dict, Tuple, Iterable, Mapping, Optional, Set, Union, cast,
5858
)
5959
from typing_extensions import Final
6060

@@ -84,7 +84,7 @@
8484
from mypy.options import Options as MypyOptions
8585
from mypy.types import (
8686
Type, TypeStrVisitor, CallableType, UnboundType, NoneType, TupleType, TypeList, Instance,
87-
AnyType, get_proper_type
87+
AnyType, get_proper_type, OVERLOAD_NAMES
8888
)
8989
from mypy.visitor import NodeVisitor
9090
from mypy.find_sources import create_source_list, InvalidSourceList
@@ -93,6 +93,10 @@
9393
from mypy.traverser import all_yield_expressions, has_return_statement, has_yield_expression
9494
from mypy.moduleinspect import ModuleInspect
9595

96+
TYPING_MODULE_NAMES: Final = (
97+
'typing',
98+
'typing_extensions',
99+
)
96100

97101
# Common ways of naming package containing vendored modules.
98102
VENDOR_PACKAGES: Final = [
@@ -768,13 +772,15 @@ def process_name_expr_decorator(self, expr: NameExpr, context: Decorator) -> Tup
768772
self.add_decorator('property')
769773
self.add_decorator('abc.abstractmethod')
770774
is_abstract = True
771-
elif self.refers_to_fullname(name, 'typing.overload'):
775+
elif self.refers_to_fullname(name, OVERLOAD_NAMES):
772776
self.add_decorator(name)
773777
self.add_typing_import('overload')
774778
is_overload = True
775779
return is_abstract, is_overload
776780

777-
def refers_to_fullname(self, name: str, fullname: str) -> bool:
781+
def refers_to_fullname(self, name: str, fullname: Union[str, Tuple[str, ...]]) -> bool:
782+
if isinstance(fullname, tuple):
783+
return any(self.refers_to_fullname(name, fname) for fname in fullname)
778784
module, short = fullname.rsplit('.', 1)
779785
return (self.import_tracker.module_for.get(name) == module and
780786
(name == short or
@@ -825,8 +831,8 @@ def process_member_expr_decorator(self, expr: MemberExpr, context: Decorator) ->
825831
expr.expr.name + '.coroutine',
826832
expr.expr.name)
827833
elif (isinstance(expr.expr, NameExpr) and
828-
(expr.expr.name == 'typing' or
829-
self.import_tracker.reverse_alias.get(expr.expr.name) == 'typing') and
834+
(expr.expr.name in TYPING_MODULE_NAMES or
835+
self.import_tracker.reverse_alias.get(expr.expr.name) in TYPING_MODULE_NAMES) and
830836
expr.name == 'overload'):
831837
self.import_tracker.require_name(expr.expr.name)
832838
self.add_decorator('%s.%s' % (expr.expr.name, 'overload'))
@@ -1060,7 +1066,7 @@ def visit_import_from(self, o: ImportFrom) -> None:
10601066
and name not in self.referenced_names
10611067
and (not self._all_ or name in IGNORED_DUNDERS)
10621068
and not is_private
1063-
and module not in ('abc', 'typing', 'asyncio')):
1069+
and module not in ('abc', 'asyncio') + TYPING_MODULE_NAMES):
10641070
# An imported name that is never referenced in the module is assumed to be
10651071
# exported, unless there is an explicit __all__. Note that we need to special
10661072
# case 'abc' since some references are deleted during semantic analysis.
@@ -1118,8 +1124,7 @@ def get_init(self, lvalue: str, rvalue: Expression,
11181124
typename = self.print_annotation(annotation)
11191125
if (isinstance(annotation, UnboundType) and not annotation.args and
11201126
annotation.name == 'Final' and
1121-
self.import_tracker.module_for.get('Final') in ('typing',
1122-
'typing_extensions')):
1127+
self.import_tracker.module_for.get('Final') in TYPING_MODULE_NAMES):
11231128
# Final without type argument is invalid in stubs.
11241129
final_arg = self.get_str_type_of_node(rvalue)
11251130
typename += '[{}]'.format(final_arg)

mypy/stubtest.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -912,9 +912,8 @@ def apply_decorator_to_funcitem(
912912
return None
913913
if decorator.fullname in (
914914
"builtins.staticmethod",
915-
"typing.overload",
916915
"abc.abstractmethod",
917-
):
916+
) or decorator.fullname in mypy.types.OVERLOAD_NAMES:
918917
return func
919918
if decorator.fullname == "builtins.classmethod":
920919
assert func.arguments[0].variable.name in ("cls", "metacls")

mypy/types.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,11 @@
137137
'typing_extensions.assert_type',
138138
)
139139

140+
OVERLOAD_NAMES: Final = (
141+
'typing.overload',
142+
'typing_extensions.overload',
143+
)
144+
140145
# Attributes that can optionally be defined in the body of a subclass of
141146
# enum.Enum but are removed from the class __dict__ by EnumMeta.
142147
ENUM_REMOVED_PROPS: Final = (

test-data/unit/check-overloading.test

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,24 @@ class A: pass
4040
class B: pass
4141
[builtins fixtures/isinstance.pyi]
4242

43+
[case testTypingExtensionsOverload]
44+
from typing import Any
45+
from typing_extensions import overload
46+
@overload
47+
def f(x: 'A') -> 'B': ...
48+
@overload
49+
def f(x: 'B') -> 'A': ...
50+
51+
def f(x: Any) -> Any:
52+
pass
53+
54+
reveal_type(f(A())) # N: Revealed type is "__main__.B"
55+
reveal_type(f(B())) # N: Revealed type is "__main__.A"
56+
57+
class A: pass
58+
class B: pass
59+
[builtins fixtures/isinstance.pyi]
60+
4361
[case testOverloadNeedsImplementation]
4462
from typing import overload, Any
4563
@overload # E: An overloaded function outside a stub file must have an implementation

test-data/unit/lib-stub/typing_extensions.pyi

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from typing import TypeVar, Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type
22
from typing import TYPE_CHECKING as TYPE_CHECKING
3-
from typing import NewType as NewType
3+
from typing import NewType as NewType, overload as overload
44

55
import sys
66

test-data/unit/stubgen.test

Lines changed: 78 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2461,13 +2461,58 @@ class A:
24612461
def f(self, x: Tuple[int, int]) -> int: ...
24622462

24632463

2464+
@overload
2465+
def f(x: int, y: int) -> int: ...
2466+
@overload
2467+
def f(x: Tuple[int, int]) -> int: ...
2468+
2469+
[case testOverload_fromTypingExtensionsImport]
2470+
from typing import Tuple, Union
2471+
from typing_extensions import overload
2472+
2473+
class A:
2474+
@overload
2475+
def f(self, x: int, y: int) -> int:
2476+
...
2477+
2478+
@overload
2479+
def f(self, x: Tuple[int, int]) -> int:
2480+
...
2481+
2482+
def f(self, *args: Union[int, Tuple[int, int]]) -> int:
2483+
pass
2484+
2485+
@overload
2486+
def f(x: int, y: int) -> int:
2487+
...
2488+
2489+
@overload
2490+
def f(x: Tuple[int, int]) -> int:
2491+
...
2492+
2493+
def f(*args: Union[int, Tuple[int, int]]) -> int:
2494+
pass
2495+
2496+
2497+
[out]
2498+
from typing import Tuple
2499+
from typing_extensions import overload
2500+
2501+
class A:
2502+
@overload
2503+
def f(self, x: int, y: int) -> int: ...
2504+
@overload
2505+
def f(self, x: Tuple[int, int]) -> int: ...
2506+
2507+
24642508
@overload
24652509
def f(x: int, y: int) -> int: ...
24662510
@overload
24672511
def f(x: Tuple[int, int]) -> int: ...
24682512

24692513
[case testOverload_importTyping]
24702514
import typing
2515+
import typing_extensions
24712516

24722517
class A:
24732518
@typing.overload
@@ -2506,9 +2551,21 @@ def f(x: typing.Tuple[int, int]) -> int:
25062551
def f(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
25072552
pass
25082553

2554+
@typing_extensions.overload
2555+
def g(x: int, y: int) -> int:
2556+
...
2557+
2558+
@typing_extensions.overload
2559+
def g(x: typing.Tuple[int, int]) -> int:
2560+
...
2561+
2562+
def g(*args: typing.Union[int, typing.Tuple[int, int]]) -> int:
2563+
pass
2564+
25092565

25102566
[out]
25112567
import typing
2568+
import typing_extensions
25122569

25132570
class A:
25142571
@typing.overload
@@ -2527,10 +2584,14 @@ class A:
25272584
def f(x: int, y: int) -> int: ...
25282585
@typing.overload
25292586
def f(x: typing.Tuple[int, int]) -> int: ...
2530-
2587+
@typing_extensions.overload
2588+
def g(x: int, y: int) -> int: ...
2589+
@typing_extensions.overload
2590+
def g(x: typing.Tuple[int, int]) -> int: ...
25312591

25322592
[case testOverload_importTypingAs]
25332593
import typing as t
2594+
import typing_extensions as te
25342595

25352596
class A:
25362597
@t.overload
@@ -2570,8 +2631,20 @@ def f(*args: t.Union[int, t.Tuple[int, int]]) -> int:
25702631
pass
25712632

25722633

2634+
@te.overload
2635+
def g(x: int, y: int) -> int:
2636+
...
2637+
2638+
@te.overload
2639+
def g(x: t.Tuple[int, int]) -> int:
2640+
...
2641+
2642+
def g(*args: t.Union[int, t.Tuple[int, int]]) -> int:
2643+
pass
2644+
25732645
[out]
25742646
import typing as t
2647+
import typing_extensions as te
25752648

25762649
class A:
25772650
@t.overload
@@ -2590,6 +2663,10 @@ class A:
25902663
def f(x: int, y: int) -> int: ...
25912664
@t.overload
25922665
def f(x: t.Tuple[int, int]) -> int: ...
2666+
@te.overload
2667+
def g(x: int, y: int) -> int: ...
2668+
@te.overload
2669+
def g(x: t.Tuple[int, int]) -> int: ...
25932670

25942671
[case testProtocol_semanal]
25952672
from typing import Protocol, TypeVar

0 commit comments

Comments
 (0)