Skip to content

Commit 57ce73d

Browse files
authored
Support additinal attributes in callback protocols (#14084)
Fixes #10976 Fixes #10403 This is quite straightforward. Note that we will not allow _arbitrary_ attributes on functions, only those that are defined in `types.FunctionType` (or more precisely `builtins.function` that is identical). We have a separate issue for arbitrary attributes #2087
1 parent 47a435f commit 57ce73d

File tree

9 files changed

+100
-29
lines changed

9 files changed

+100
-29
lines changed

mypy/checker.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5883,18 +5883,19 @@ def check_subtype(
58835883
if (
58845884
isinstance(supertype, Instance)
58855885
and supertype.type.is_protocol
5886-
and isinstance(subtype, (Instance, TupleType, TypedDictType))
5886+
and isinstance(subtype, (CallableType, Instance, TupleType, TypedDictType))
58875887
):
58885888
self.msg.report_protocol_problems(subtype, supertype, context, code=msg.code)
58895889
if isinstance(supertype, CallableType) and isinstance(subtype, Instance):
58905890
call = find_member("__call__", subtype, subtype, is_operator=True)
58915891
if call:
58925892
self.msg.note_call(subtype, call, context, code=msg.code)
58935893
if isinstance(subtype, (CallableType, Overloaded)) and isinstance(supertype, Instance):
5894-
if supertype.type.is_protocol and supertype.type.protocol_members == ["__call__"]:
5894+
if supertype.type.is_protocol and "__call__" in supertype.type.protocol_members:
58955895
call = find_member("__call__", supertype, subtype, is_operator=True)
58965896
assert call is not None
5897-
self.msg.note_call(supertype, call, context, code=msg.code)
5897+
if not is_subtype(subtype, call, options=self.options):
5898+
self.msg.note_call(supertype, call, context, code=msg.code)
58985899
self.check_possible_missing_await(subtype, supertype, context)
58995900
return False
59005901

mypy/constraints.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,7 +553,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
553553
original_actual = actual = self.actual
554554
res: list[Constraint] = []
555555
if isinstance(actual, (CallableType, Overloaded)) and template.type.is_protocol:
556-
if template.type.protocol_members == ["__call__"]:
556+
if "__call__" in template.type.protocol_members:
557557
# Special case: a generic callback protocol
558558
if not any(template == t for t in template.type.inferring):
559559
template.type.inferring.append(template)
@@ -565,7 +565,6 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
565565
subres = infer_constraints(call, actual, self.direction)
566566
res.extend(subres)
567567
template.type.inferring.pop()
568-
return res
569568
if isinstance(actual, CallableType) and actual.fallback is not None:
570569
if actual.is_type_obj() and template.type.is_protocol:
571570
ret_type = get_proper_type(actual.ret_type)
@@ -815,7 +814,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
815814
# because some type may be considered a subtype of a protocol
816815
# due to _promote, but still not implement the protocol.
817816
not any(template == t for t in reversed(template.type.inferring))
818-
and mypy.subtypes.is_protocol_implementation(instance, erased)
817+
and mypy.subtypes.is_protocol_implementation(instance, erased, skip=["__call__"])
819818
):
820819
template.type.inferring.append(template)
821820
res.extend(
@@ -831,7 +830,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
831830
and
832831
# We avoid infinite recursion for structural subtypes also here.
833832
not any(instance == i for i in reversed(instance.type.inferring))
834-
and mypy.subtypes.is_protocol_implementation(erased, instance)
833+
and mypy.subtypes.is_protocol_implementation(erased, instance, skip=["__call__"])
835834
):
836835
instance.type.inferring.append(instance)
837836
res.extend(
@@ -887,6 +886,8 @@ def infer_constraints_from_protocol_members(
887886
inst = mypy.subtypes.find_member(member, instance, subtype, class_obj=class_obj)
888887
temp = mypy.subtypes.find_member(member, template, subtype)
889888
if inst is None or temp is None:
889+
if member == "__call__":
890+
continue
890891
return [] # See #11020
891892
# The above is safe since at this point we know that 'instance' is a subtype
892893
# of (erased) 'template', therefore it defines all protocol members

mypy/messages.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1866,6 +1866,7 @@ def report_protocol_problems(
18661866

18671867
class_obj = False
18681868
is_module = False
1869+
skip = []
18691870
if isinstance(subtype, TupleType):
18701871
if not isinstance(subtype.partial_fallback, Instance):
18711872
return
@@ -1880,20 +1881,22 @@ def report_protocol_problems(
18801881
class_obj = True
18811882
subtype = subtype.item
18821883
elif isinstance(subtype, CallableType):
1883-
if not subtype.is_type_obj():
1884-
return
1885-
ret_type = get_proper_type(subtype.ret_type)
1886-
if isinstance(ret_type, TupleType):
1887-
ret_type = ret_type.partial_fallback
1888-
if not isinstance(ret_type, Instance):
1889-
return
1890-
class_obj = True
1891-
subtype = ret_type
1884+
if subtype.is_type_obj():
1885+
ret_type = get_proper_type(subtype.ret_type)
1886+
if isinstance(ret_type, TupleType):
1887+
ret_type = ret_type.partial_fallback
1888+
if not isinstance(ret_type, Instance):
1889+
return
1890+
class_obj = True
1891+
subtype = ret_type
1892+
else:
1893+
subtype = subtype.fallback
1894+
skip = ["__call__"]
18921895
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
18931896
is_module = True
18941897

18951898
# Report missing members
1896-
missing = get_missing_protocol_members(subtype, supertype)
1899+
missing = get_missing_protocol_members(subtype, supertype, skip=skip)
18971900
if (
18981901
missing
18991902
and len(missing) < len(supertype.type.protocol_members)
@@ -2605,13 +2608,15 @@ def variance_string(variance: int) -> str:
26052608
return "invariant"
26062609

26072610

2608-
def get_missing_protocol_members(left: Instance, right: Instance) -> list[str]:
2611+
def get_missing_protocol_members(left: Instance, right: Instance, skip: list[str]) -> list[str]:
26092612
"""Find all protocol members of 'right' that are not implemented
26102613
(i.e. completely missing) in 'left'.
26112614
"""
26122615
assert right.type.is_protocol
26132616
missing: list[str] = []
26142617
for member in right.type.protocol_members:
2618+
if member in skip:
2619+
continue
26152620
if not find_member(member, left, left):
26162621
missing.append(member)
26172622
return missing

mypy/subtypes.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -678,13 +678,16 @@ def visit_callable_type(self, left: CallableType) -> bool:
678678
elif isinstance(right, Overloaded):
679679
return all(self._is_subtype(left, item) for item in right.items)
680680
elif isinstance(right, Instance):
681-
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
682-
# OK, a callable can implement a protocol with a single `__call__` member.
681+
if right.type.is_protocol and "__call__" in right.type.protocol_members:
682+
# OK, a callable can implement a protocol with a `__call__` member.
683683
# TODO: we should probably explicitly exclude self-types in this case.
684684
call = find_member("__call__", right, left, is_operator=True)
685685
assert call is not None
686686
if self._is_subtype(left, call):
687-
return True
687+
if len(right.type.protocol_members) == 1:
688+
return True
689+
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
690+
return True
688691
if right.type.is_protocol and left.is_type_obj():
689692
ret_type = get_proper_type(left.ret_type)
690693
if isinstance(ret_type, TupleType):
@@ -792,12 +795,15 @@ def visit_literal_type(self, left: LiteralType) -> bool:
792795
def visit_overloaded(self, left: Overloaded) -> bool:
793796
right = self.right
794797
if isinstance(right, Instance):
795-
if right.type.is_protocol and right.type.protocol_members == ["__call__"]:
798+
if right.type.is_protocol and "__call__" in right.type.protocol_members:
796799
# same as for CallableType
797800
call = find_member("__call__", right, left, is_operator=True)
798801
assert call is not None
799802
if self._is_subtype(left, call):
800-
return True
803+
if len(right.type.protocol_members) == 1:
804+
return True
805+
if is_protocol_implementation(left.fallback, right, skip=["__call__"]):
806+
return True
801807
return self._is_subtype(left.fallback, right)
802808
elif isinstance(right, CallableType):
803809
for item in left.items:
@@ -938,7 +944,11 @@ def pop_on_exit(stack: list[tuple[T, T]], left: T, right: T) -> Iterator[None]:
938944

939945

940946
def is_protocol_implementation(
941-
left: Instance, right: Instance, proper_subtype: bool = False, class_obj: bool = False
947+
left: Instance,
948+
right: Instance,
949+
proper_subtype: bool = False,
950+
class_obj: bool = False,
951+
skip: list[str] | None = None,
942952
) -> bool:
943953
"""Check whether 'left' implements the protocol 'right'.
944954
@@ -958,10 +968,13 @@ def f(self) -> A: ...
958968
as well.
959969
"""
960970
assert right.type.is_protocol
971+
if skip is None:
972+
skip = []
961973
# We need to record this check to generate protocol fine-grained dependencies.
962974
TypeState.record_protocol_subtype_check(left.type, right.type)
963975
# nominal subtyping currently ignores '__init__' and '__new__' signatures
964976
members_not_to_check = {"__init__", "__new__"}
977+
members_not_to_check.update(skip)
965978
# Trivial check that circumvents the bug described in issue 9771:
966979
if left.type.is_protocol:
967980
members_right = set(right.type.protocol_members) - members_not_to_check

mypy/test/testtypegen.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from mypy import build
88
from mypy.errors import CompileError
99
from mypy.modulefinder import BuildSource
10-
from mypy.nodes import NameExpr
10+
from mypy.nodes import NameExpr, TempNode
1111
from mypy.options import Options
1212
from mypy.test.config import test_temp_dir
1313
from mypy.test.data import DataDrivenTestCase, DataSuite
@@ -54,6 +54,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
5454
# Filter nodes that should be included in the output.
5555
keys = []
5656
for node in nodes:
57+
if isinstance(node, TempNode):
58+
continue
5759
if node.line != -1 and map[node]:
5860
if ignore_node(node) or node in ignored:
5961
continue

test-data/unit/check-protocols.test

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2642,6 +2642,53 @@ reveal_type([b, a]) # N: Revealed type is "builtins.list[def (x: def (__main__.
26422642
[builtins fixtures/list.pyi]
26432643
[out]
26442644

2645+
[case testCallbackProtocolFunctionAttributesSubtyping]
2646+
from typing import Protocol
2647+
2648+
class A(Protocol):
2649+
__name__: str
2650+
def __call__(self) -> str: ...
2651+
2652+
class B1(Protocol):
2653+
__name__: int
2654+
def __call__(self) -> str: ...
2655+
2656+
class B2(Protocol):
2657+
__name__: str
2658+
def __call__(self) -> int: ...
2659+
2660+
class B3(Protocol):
2661+
__name__: str
2662+
extra_stuff: int
2663+
def __call__(self) -> str: ...
2664+
2665+
def f() -> str: ...
2666+
2667+
reveal_type(f.__name__) # N: Revealed type is "builtins.str"
2668+
a: A = f # OK
2669+
b1: B1 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B1") \
2670+
# N: Following member(s) of "function" have conflicts: \
2671+
# N: __name__: expected "int", got "str"
2672+
b2: B2 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B2") \
2673+
# N: "B2.__call__" has type "Callable[[], int]"
2674+
b3: B3 = f # E: Incompatible types in assignment (expression has type "Callable[[], str]", variable has type "B3") \
2675+
# N: "function" is missing following "B3" protocol member: \
2676+
# N: extra_stuff
2677+
2678+
[case testCallbackProtocolFunctionAttributesInference]
2679+
from typing import Protocol, TypeVar, Generic, Tuple
2680+
2681+
T = TypeVar("T")
2682+
S = TypeVar("S", covariant=True)
2683+
class A(Protocol[T, S]):
2684+
__name__: T
2685+
def __call__(self) -> S: ...
2686+
2687+
def f() -> int: ...
2688+
def test(func: A[T, S]) -> Tuple[T, S]: ...
2689+
reveal_type(test(f)) # N: Revealed type is "Tuple[builtins.str, builtins.int]"
2690+
[builtins fixtures/tuple.pyi]
2691+
26452692
[case testProtocolsAlwaysABCs]
26462693
from typing import Protocol
26472694

test-data/unit/fine-grained-inspect.test

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ class Meta(type):
5252
==
5353
{"C": ["meth", "x"]}
5454
{"C": ["meth", "x"], "Meta": ["y"], "type": ["__init__"]}
55-
{}
56-
{"object": ["__init__"]}
55+
{"function": ["__name__"]}
56+
{"function": ["__name__"], "object": ["__init__"]}
5757

5858
[case testInspectDefBasic]
5959
# inspect2: --show=definition foo.py:5:5

test-data/unit/fixtures/tuple.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ class tuple(Sequence[Tco], Generic[Tco]):
2323
def __rmul__(self, n: int) -> Tuple[Tco, ...]: pass
2424
def __add__(self, x: Tuple[Tco, ...]) -> Tuple[Tco, ...]: pass
2525
def count(self, obj: object) -> int: pass
26-
class function: pass
26+
class function:
27+
__name__: str
2728
class ellipsis: pass
2829
class classmethod: pass
2930

test-data/unit/lib-stub/builtins.pyi

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ class float: pass
1717
class str: pass
1818
class bytes: pass
1919

20-
class function: pass
20+
class function:
21+
__name__: str
2122
class ellipsis: pass
2223

2324
from typing import Generic, Sequence, TypeVar

0 commit comments

Comments
 (0)