@@ -38,6 +38,8 @@ class TypingMeta(type):
38
38
__new__) and a nicer repr().
39
39
"""
40
40
41
+ _is_protocol = False
42
+
41
43
def __new__ (cls , name , bases , namespace , * , _root = False ):
42
44
if not _root :
43
45
raise TypeError ("Cannot subclass %s" %
@@ -1029,6 +1031,63 @@ def overload(func):
1029
1031
raise RuntimeError ("Overloading is only supported in library stubs" )
1030
1032
1031
1033
1034
+ class _Protocol (Generic ):
1035
+ """Internal base class for protocol classes.
1036
+
1037
+ This implements a simple-minded structural isinstance check
1038
+ (similar but more general than the one-offs in collections.abc
1039
+ such as Hashable).
1040
+ """
1041
+
1042
+ _is_protocol = True
1043
+
1044
+ @classmethod
1045
+ def __subclasshook__ (self , cls ):
1046
+ if not self ._is_protocol :
1047
+ # No structural checks since this isn't a protocol.
1048
+ return NotImplemented
1049
+
1050
+ if self is _Protocol :
1051
+ # Every class is a subclass of the empty protocol.
1052
+ return True
1053
+
1054
+ # Find all attributes defined in the protocol.
1055
+ attrs = self ._get_protocol_attrs ()
1056
+
1057
+ for attr in attrs :
1058
+ if not any (attr in d .__dict__ for d in cls .__mro__ ):
1059
+ return NotImplemented
1060
+ return True
1061
+
1062
+ @classmethod
1063
+ def _get_protocol_attrs (self ):
1064
+ # Get all Protocol base classes.
1065
+ protocol_bases = []
1066
+ for c in self .__mro__ :
1067
+ if getattr (c , '_is_protocol' , False ) and c .__name__ != '_Protocol' :
1068
+ protocol_bases .append (c )
1069
+
1070
+ # Get attributes included in protocol.
1071
+ attrs = set ()
1072
+ for base in protocol_bases :
1073
+ for attr in base .__dict__ .keys ():
1074
+ # Include attributes not defined in any non-protocol bases.
1075
+ for c in self .__mro__ :
1076
+ if (c is not base and attr in c .__dict__ and
1077
+ not getattr (c , '_is_protocol' , False )):
1078
+ break
1079
+ else :
1080
+ if (not attr .startswith ('_abc_' ) and
1081
+ attr != '__abstractmethods__' and
1082
+ attr != '_is_protocol' and
1083
+ attr != '__dict__' and
1084
+ attr != '_get_protocol_attrs' and
1085
+ attr != '__module__' ):
1086
+ attrs .add (attr )
1087
+
1088
+ return attrs
1089
+
1090
+
1032
1091
# Various ABCs mimicking those in collections.abc.
1033
1092
# A few are simply re-exported for completeness.
1034
1093
@@ -1043,6 +1102,41 @@ class Iterator(Iterable, extra=collections.abc.Iterator):
1043
1102
pass
1044
1103
1045
1104
1105
+ class SupportsInt (_Protocol ):
1106
+
1107
+ @abstractmethod
1108
+ def __int__ (self ) -> int :
1109
+ pass
1110
+
1111
+
1112
+ class SupportsFloat (_Protocol ):
1113
+
1114
+ @abstractmethod
1115
+ def __float__ (self ) -> float :
1116
+ pass
1117
+
1118
+
1119
+ class SupportsAbs (_Protocol [T ]):
1120
+
1121
+ @abstractmethod
1122
+ def __abs__ (self ) -> T :
1123
+ pass
1124
+
1125
+
1126
+ class SupportsRound (_Protocol [T ]):
1127
+
1128
+ @abstractmethod
1129
+ def __round__ (self , ndigits : int = 0 ) -> T :
1130
+ pass
1131
+
1132
+
1133
+ class Reversible (_Protocol [T ]):
1134
+
1135
+ @abstractmethod
1136
+ def __reversed__ (self ) -> 'Iterator[T]' :
1137
+ pass
1138
+
1139
+
1046
1140
Sized = collections .abc .Sized # Not generic.
1047
1141
1048
1142
0 commit comments