Skip to content

Commit 25af77d

Browse files
author
Guido van Rossum
committed
Ensure that Node(), Node[T]() and Node[int]() all have the same __class__. Fixes #79.
1 parent 90de1f6 commit 25af77d

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

prototyping/test_typing.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,25 @@ def append(self, x: int):
784784
a.append(42)
785785
assert a.get() == [1, 42]
786786

787+
def test_type_erasure(self):
788+
T = TypeVar('T')
789+
790+
class Node(Generic[T]):
791+
def __init__(self, label: T, left: 'Node[T]' = None, right: 'Node[T]' = None):
792+
self.label = label # type: T
793+
self.left = left # type: Optional[Node[T]]
794+
self.right = right # type: Optional[Node[T]]
795+
796+
def foo(x: T):
797+
a = Node(x)
798+
b = Node[T](x)
799+
c = Node[Any](x)
800+
assert type(a) is Node
801+
assert type(b) is Node
802+
assert type(c) is Node
803+
804+
foo(42)
805+
787806

788807
class VarianceTests(TestCase):
789808

prototyping/typing.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -882,6 +882,13 @@ class Callable(Final, metaclass=CallableMeta, _root=True):
882882
"""
883883

884884

885+
def _gorg(a):
886+
"""Return the farthest origin of a generic class."""
887+
while a.__origin__ is not None:
888+
a = a.__origin__
889+
return a
890+
891+
885892
def _geqv(a, b):
886893
"""Return whether two generic classes are equivalent.
887894
@@ -892,9 +899,9 @@ def _geqv(a, b):
892899
893900
The relation is reflexive, symmetric and transitive.
894901
"""
895-
# TODO: Don't depend on name/module being equal.
896902
assert isinstance(a, GenericMeta) and isinstance(b, GenericMeta)
897-
return a.__name__ == b.__name__ and a.__module__ == b.__module__
903+
# Reduce each to its origin.
904+
return _gorg(a) is _gorg(b)
898905

899906

900907
class GenericMeta(TypingMeta, abc.ABCMeta):
@@ -1082,6 +1089,16 @@ def lookup_name(mapping: Mapping[X, Y], key: X, default: Y) -> Y:
10821089
# Same body as above.
10831090
"""
10841091

1092+
def __new__(cls, *args, **kwds):
1093+
next_in_mro = None
1094+
for i, c in enumerate(cls.__mro__):
1095+
if _gorg(c) is Generic:
1096+
next_in_mro = cls.__mro__[i+1]
1097+
break
1098+
else:
1099+
next_in_mro = object
1100+
return next_in_mro.__new__(_gorg(cls))
1101+
10851102

10861103
def cast(typ, val):
10871104
"""Cast a value to a type.

0 commit comments

Comments
 (0)