Skip to content

Commit 38d246d

Browse files
author
Guido van Rossum
committed
Add SupportsInt etc. and Reversible.
1 parent e20621b commit 38d246d

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

prototyping/test_typing.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,31 @@ def get(self, key: str, default=None):
597597
return default
598598

599599

600+
class ProtocolTests(TestCase):
601+
602+
def test_supports_int(self):
603+
assert issubclass(int, typing.SupportsInt)
604+
assert not issubclass(str, typing.SupportsInt)
605+
606+
def test_supports_float(self):
607+
assert issubclass(float, typing.SupportsFloat)
608+
assert not issubclass(str, typing.SupportsFloat)
609+
610+
def test_supports_abs(self):
611+
assert issubclass(float, typing.SupportsAbs)
612+
assert issubclass(int, typing.SupportsAbs)
613+
assert not issubclass(str, typing.SupportsAbs)
614+
615+
def test_supports_round(self):
616+
assert issubclass(float, typing.SupportsRound)
617+
assert issubclass(int, typing.SupportsRound)
618+
assert not issubclass(str, typing.SupportsRound)
619+
620+
def test_reversible(self):
621+
assert issubclass(list, typing.Reversible)
622+
assert not issubclass(int, typing.Reversible)
623+
624+
600625
class GenericTests(TestCase):
601626

602627
def test_basics(self):

prototyping/typing.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class TypingMeta(type):
3838
__new__) and a nicer repr().
3939
"""
4040

41+
_is_protocol = False
42+
4143
def __new__(cls, name, bases, namespace, *, _root=False):
4244
if not _root:
4345
raise TypeError("Cannot subclass %s" %
@@ -1029,6 +1031,63 @@ def overload(func):
10291031
raise RuntimeError("Overloading is only supported in library stubs")
10301032

10311033

1034+
class _Protocol(Generic):
1035+
"""Internal base class for protocol classes.
1036+
1037+
This implements a simple-minded structural isinstance check
1038+
(similar but more general than the one-offs in collections.abc
1039+
such as Hashable).
1040+
"""
1041+
1042+
_is_protocol = True
1043+
1044+
@classmethod
1045+
def __subclasshook__(self, cls):
1046+
if not self._is_protocol:
1047+
# No structural checks since this isn't a protocol.
1048+
return NotImplemented
1049+
1050+
if self is _Protocol:
1051+
# Every class is a subclass of the empty protocol.
1052+
return True
1053+
1054+
# Find all attributes defined in the protocol.
1055+
attrs = self._get_protocol_attrs()
1056+
1057+
for attr in attrs:
1058+
if not any(attr in d.__dict__ for d in cls.__mro__):
1059+
return NotImplemented
1060+
return True
1061+
1062+
@classmethod
1063+
def _get_protocol_attrs(self):
1064+
# Get all Protocol base classes.
1065+
protocol_bases = []
1066+
for c in self.__mro__:
1067+
if getattr(c, '_is_protocol', False) and c.__name__ != '_Protocol':
1068+
protocol_bases.append(c)
1069+
1070+
# Get attributes included in protocol.
1071+
attrs = set()
1072+
for base in protocol_bases:
1073+
for attr in base.__dict__.keys():
1074+
# Include attributes not defined in any non-protocol bases.
1075+
for c in self.__mro__:
1076+
if (c is not base and attr in c.__dict__ and
1077+
not getattr(c, '_is_protocol', False)):
1078+
break
1079+
else:
1080+
if (not attr.startswith('_abc_') and
1081+
attr != '__abstractmethods__' and
1082+
attr != '_is_protocol' and
1083+
attr != '__dict__' and
1084+
attr != '_get_protocol_attrs' and
1085+
attr != '__module__'):
1086+
attrs.add(attr)
1087+
1088+
return attrs
1089+
1090+
10321091
# Various ABCs mimicking those in collections.abc.
10331092
# A few are simply re-exported for completeness.
10341093

@@ -1043,6 +1102,41 @@ class Iterator(Iterable, extra=collections.abc.Iterator):
10431102
pass
10441103

10451104

1105+
class SupportsInt(_Protocol):
1106+
1107+
@abstractmethod
1108+
def __int__(self) -> int:
1109+
pass
1110+
1111+
1112+
class SupportsFloat(_Protocol):
1113+
1114+
@abstractmethod
1115+
def __float__(self) -> float:
1116+
pass
1117+
1118+
1119+
class SupportsAbs(_Protocol[T]):
1120+
1121+
@abstractmethod
1122+
def __abs__(self) -> T:
1123+
pass
1124+
1125+
1126+
class SupportsRound(_Protocol[T]):
1127+
1128+
@abstractmethod
1129+
def __round__(self, ndigits: int = 0) -> T:
1130+
pass
1131+
1132+
1133+
class Reversible(_Protocol[T]):
1134+
1135+
@abstractmethod
1136+
def __reversed__(self) -> 'Iterator[T]':
1137+
pass
1138+
1139+
10461140
Sized = collections.abc.Sized # Not generic.
10471141

10481142

0 commit comments

Comments
 (0)