Skip to content

Commit a272164

Browse files
bpo-44490: Improve typing module compatibility with types.Union (GH-27048) (#27222)
(cherry picked from commit bf89ff9) Co-authored-by: Yurii Karabas <[email protected]>
1 parent 37bdd22 commit a272164

File tree

5 files changed

+40
-7
lines changed

5 files changed

+40
-7
lines changed

Lib/test/ann_module.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,3 +58,5 @@ def dec(func):
5858
def wrapper(*args, **kwargs):
5959
return func(*args, **kwargs)
6060
return wrapper
61+
62+
u: int | float

Lib/test/test_grammar.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -473,7 +473,7 @@ class CC(metaclass=CMeta):
473473
def test_var_annot_module_semantics(self):
474474
self.assertEqual(test.__annotations__, {})
475475
self.assertEqual(ann_module.__annotations__,
476-
{1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int]})
476+
{1: 2, 'x': int, 'y': str, 'f': typing.Tuple[int, int], 'u': int | float})
477477
self.assertEqual(ann_module.M.__annotations__,
478478
{'123': 123, 'o': type})
479479
self.assertEqual(ann_module2.__annotations__, {})

Lib/test/test_typing.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,6 +315,8 @@ def test_repr(self):
315315
self.assertEqual(repr(u), 'typing.Union[typing.List[int], int]')
316316
u = Union[list[int], dict[str, float]]
317317
self.assertEqual(repr(u), 'typing.Union[list[int], dict[str, float]]')
318+
u = Union[int | float]
319+
self.assertEqual(repr(u), 'typing.Union[int, float]')
318320

319321
def test_cannot_subclass(self):
320322
with self.assertRaises(TypeError):
@@ -1449,6 +1451,8 @@ def test_basics(self):
14491451
with self.assertRaises(TypeError):
14501452
issubclass(SM1, SimpleMapping)
14511453
self.assertIsInstance(SM1(), SimpleMapping)
1454+
T = TypeVar("T")
1455+
self.assertEqual(List[list[T] | float].__parameters__, (T,))
14521456

14531457
def test_generic_errors(self):
14541458
T = TypeVar('T')
@@ -1785,6 +1789,7 @@ def test_extended_generic_rules_repr(self):
17851789
def test_generic_forward_ref(self):
17861790
def foobar(x: List[List['CC']]): ...
17871791
def foobar2(x: list[list[ForwardRef('CC')]]): ...
1792+
def foobar3(x: list[ForwardRef('CC | int')] | int): ...
17881793
class CC: ...
17891794
self.assertEqual(
17901795
get_type_hints(foobar, globals(), locals()),
@@ -1794,6 +1799,10 @@ class CC: ...
17941799
get_type_hints(foobar2, globals(), locals()),
17951800
{'x': list[list[CC]]}
17961801
)
1802+
self.assertEqual(
1803+
get_type_hints(foobar3, globals(), locals()),
1804+
{'x': list[CC | int] | int}
1805+
)
17971806

17981807
T = TypeVar('T')
17991808
AT = Tuple[T, ...]
@@ -2467,6 +2476,12 @@ def foo(a: Union['T']):
24672476
self.assertEqual(get_type_hints(foo, globals(), locals()),
24682477
{'a': Union[T]})
24692478

2479+
def foo(a: tuple[ForwardRef('T')] | int):
2480+
pass
2481+
2482+
self.assertEqual(get_type_hints(foo, globals(), locals()),
2483+
{'a': tuple[T] | int})
2484+
24702485
def test_tuple_forward(self):
24712486

24722487
def foo(a: Tuple['T']):
@@ -2851,7 +2866,7 @@ def test_get_type_hints_from_various_objects(self):
28512866
gth(None)
28522867

28532868
def test_get_type_hints_modules(self):
2854-
ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str}
2869+
ann_module_type_hints = {1: 2, 'f': Tuple[int, int], 'x': int, 'y': str, 'u': int | float}
28552870
self.assertEqual(gth(ann_module), ann_module_type_hints)
28562871
self.assertEqual(gth(ann_module2), {})
28572872
self.assertEqual(gth(ann_module3), {})
@@ -4393,6 +4408,9 @@ def test_no_paramspec_in__parameters__(self):
43934408
self.assertNotIn(P, list[P].__parameters__)
43944409
self.assertIn(T, tuple[T, P].__parameters__)
43954410

4411+
self.assertNotIn(P, (list[P] | int).__parameters__)
4412+
self.assertIn(T, (tuple[T, P] | int).__parameters__)
4413+
43964414
def test_paramspec_in_nested_generics(self):
43974415
# Although ParamSpec should not be found in __parameters__ of most
43984416
# generics, they probably should be found when nested in
@@ -4402,8 +4420,10 @@ def test_paramspec_in_nested_generics(self):
44024420
C1 = Callable[P, T]
44034421
G1 = List[C1]
44044422
G2 = list[C1]
4423+
G3 = list[C1] | int
44054424
self.assertEqual(G1.__parameters__, (P, T))
44064425
self.assertEqual(G2.__parameters__, (P, T))
4426+
self.assertEqual(G3.__parameters__, (P, T))
44074427

44084428

44094429
class ConcatenateTests(BaseTestCase):

Lib/typing.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def _type_repr(obj):
195195
return repr(obj)
196196

197197

198-
def _collect_type_vars(types, typevar_types=None):
198+
def _collect_type_vars(types_, typevar_types=None):
199199
"""Collect all type variable contained
200200
in types in order of first appearance (lexicographic order). For example::
201201
@@ -204,10 +204,10 @@ def _collect_type_vars(types, typevar_types=None):
204204
if typevar_types is None:
205205
typevar_types = TypeVar
206206
tvars = []
207-
for t in types:
207+
for t in types_:
208208
if isinstance(t, typevar_types) and t not in tvars:
209209
tvars.append(t)
210-
if isinstance(t, (_GenericAlias, GenericAlias)):
210+
if isinstance(t, (_GenericAlias, GenericAlias, types.Union)):
211211
tvars.extend([t for t in t.__parameters__ if t not in tvars])
212212
return tuple(tvars)
213213

@@ -314,12 +314,14 @@ def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
314314
"""
315315
if isinstance(t, ForwardRef):
316316
return t._evaluate(globalns, localns, recursive_guard)
317-
if isinstance(t, (_GenericAlias, GenericAlias)):
317+
if isinstance(t, (_GenericAlias, GenericAlias, types.Union)):
318318
ev_args = tuple(_eval_type(a, globalns, localns, recursive_guard) for a in t.__args__)
319319
if ev_args == t.__args__:
320320
return t
321321
if isinstance(t, GenericAlias):
322322
return GenericAlias(t.__origin__, ev_args)
323+
if isinstance(t, types.Union):
324+
return functools.reduce(operator.or_, ev_args)
323325
else:
324326
return t.copy_with(ev_args)
325327
return t
@@ -1013,7 +1015,7 @@ def __getitem__(self, params):
10131015
for arg in self.__args__:
10141016
if isinstance(arg, self._typevar_types):
10151017
arg = subst[arg]
1016-
elif isinstance(arg, (_GenericAlias, GenericAlias)):
1018+
elif isinstance(arg, (_GenericAlias, GenericAlias, types.Union)):
10171019
subparams = arg.__parameters__
10181020
if subparams:
10191021
subargs = tuple(subst[x] for x in subparams)
@@ -1779,6 +1781,12 @@ def _strip_annotations(t):
17791781
if stripped_args == t.__args__:
17801782
return t
17811783
return GenericAlias(t.__origin__, stripped_args)
1784+
if isinstance(t, types.Union):
1785+
stripped_args = tuple(_strip_annotations(a) for a in t.__args__)
1786+
if stripped_args == t.__args__:
1787+
return t
1788+
return functools.reduce(operator.or_, stripped_args)
1789+
17821790
return t
17831791

17841792

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
:mod:`typing` now searches for type parameters in ``types.Union`` objects.
2+
``get_type_hints`` will also properly resolve annotations with nested
3+
``types.Union`` objects. Patch provided by Yurii Karabas.

0 commit comments

Comments
 (0)