Skip to content

Commit 274af1c

Browse files
authored
Fix inference when class and instance match protocol (#18587)
Fixes #14688 The bug resulted from (accidentally) inferring against `Iterable` for both instance and class object. While working on this I noticed there are also couple flaws in direction handling in constrain inference, namely: * A protocol can never ever be a subtype of class object or a `Type[X]` * When matching against callback protocol, subtype check direction must match inference direction I also (conservatively) fix some unrelated issues uncovered by the fix (to avoid fallout): * Callable subtyping with trivial suffixes was broken for positional-only args * Join of `Parameters` could lead to meaningless results in case of incompatible arg kinds * Protocol inference was inconsistent with protocol subtyping w.r.t. metaclasses.
1 parent c8489a2 commit 274af1c

File tree

7 files changed

+132
-28
lines changed

7 files changed

+132
-28
lines changed

mypy/constraints.py

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -756,40 +756,40 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
756756
"__call__", template, actual, is_operator=True
757757
)
758758
assert call is not None
759-
if mypy.subtypes.is_subtype(actual, erase_typevars(call)):
760-
subres = infer_constraints(call, actual, self.direction)
761-
res.extend(subres)
759+
if (
760+
self.direction == SUPERTYPE_OF
761+
and mypy.subtypes.is_subtype(actual, erase_typevars(call))
762+
or self.direction == SUBTYPE_OF
763+
and mypy.subtypes.is_subtype(erase_typevars(call), actual)
764+
):
765+
res.extend(infer_constraints(call, actual, self.direction))
762766
template.type.inferring.pop()
763767
if isinstance(actual, CallableType) and actual.fallback is not None:
764-
if actual.is_type_obj() and template.type.is_protocol:
768+
if (
769+
actual.is_type_obj()
770+
and template.type.is_protocol
771+
and self.direction == SUPERTYPE_OF
772+
):
765773
ret_type = get_proper_type(actual.ret_type)
766774
if isinstance(ret_type, TupleType):
767775
ret_type = mypy.typeops.tuple_fallback(ret_type)
768776
if isinstance(ret_type, Instance):
769-
if self.direction == SUBTYPE_OF:
770-
subtype = template
771-
else:
772-
subtype = ret_type
773777
res.extend(
774778
self.infer_constraints_from_protocol_members(
775-
ret_type, template, subtype, template, class_obj=True
779+
ret_type, template, ret_type, template, class_obj=True
776780
)
777781
)
778782
actual = actual.fallback
779783
if isinstance(actual, TypeType) and template.type.is_protocol:
780-
if isinstance(actual.item, Instance):
781-
if self.direction == SUBTYPE_OF:
782-
subtype = template
783-
else:
784-
subtype = actual.item
785-
res.extend(
786-
self.infer_constraints_from_protocol_members(
787-
actual.item, template, subtype, template, class_obj=True
788-
)
789-
)
790784
if self.direction == SUPERTYPE_OF:
791-
# Infer constraints for Type[T] via metaclass of T when it makes sense.
792785
a_item = actual.item
786+
if isinstance(a_item, Instance):
787+
res.extend(
788+
self.infer_constraints_from_protocol_members(
789+
a_item, template, a_item, template, class_obj=True
790+
)
791+
)
792+
# Infer constraints for Type[T] via metaclass of T when it makes sense.
793793
if isinstance(a_item, TypeVarType):
794794
a_item = get_proper_type(a_item.upper_bound)
795795
if isinstance(a_item, Instance) and a_item.type.metaclass_type:
@@ -1043,6 +1043,17 @@ def infer_constraints_from_protocol_members(
10431043
return [] # See #11020
10441044
# The above is safe since at this point we know that 'instance' is a subtype
10451045
# of (erased) 'template', therefore it defines all protocol members
1046+
if class_obj:
1047+
# For class objects we must only infer constraints if possible, otherwise it
1048+
# can lead to confusion between class and instance, for example StrEnum is
1049+
# Iterable[str] for an instance, but Iterable[StrEnum] for a class object.
1050+
if not mypy.subtypes.is_subtype(
1051+
inst, erase_typevars(temp), ignore_pos_arg_names=True
1052+
):
1053+
continue
1054+
# This exception matches the one in subtypes.py, see PR #14121 for context.
1055+
if member == "__call__" and instance.type.is_metaclass():
1056+
continue
10461057
res.extend(infer_constraints(temp, inst, self.direction))
10471058
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol):
10481059
# Settable members are invariant, add opposite constraints

mypy/join.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ def visit_unpack_type(self, t: UnpackType) -> UnpackType:
355355

356356
def visit_parameters(self, t: Parameters) -> ProperType:
357357
if isinstance(self.s, Parameters):
358-
if len(t.arg_types) != len(self.s.arg_types):
358+
if not is_similar_params(t, self.s):
359+
# TODO: it would be prudent to return [*object, **object] instead of Any.
359360
return self.default(self.s)
360361
from mypy.meet import meet_types
361362

@@ -724,6 +725,15 @@ def is_similar_callables(t: CallableType, s: CallableType) -> bool:
724725
)
725726

726727

728+
def is_similar_params(t: Parameters, s: Parameters) -> bool:
729+
# This matches the logic in is_similar_callables() above.
730+
return (
731+
len(t.arg_types) == len(s.arg_types)
732+
and t.min_args == s.min_args
733+
and (t.var_arg() is not None) == (s.var_arg() is not None)
734+
)
735+
736+
727737
def update_callable_ids(c: CallableType, ids: list[TypeVarId]) -> CallableType:
728738
tv_map = {}
729739
tvs = []

mypy/subtypes.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1719,11 +1719,16 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
17191719
):
17201720
return False
17211721

1722+
if trivial_suffix:
1723+
# For trivial right suffix we *only* check that every non-star right argument
1724+
# has a valid match on the left.
1725+
return True
1726+
17221727
# Phase 1c: Check var args. Right has an infinite series of optional positional
17231728
# arguments. Get all further positional args of left, and make sure
17241729
# they're more general than the corresponding member in right.
17251730
# TODO: are we handling UnpackType correctly here?
1726-
if right_star is not None and not trivial_suffix:
1731+
if right_star is not None:
17271732
# Synthesize an anonymous formal argument for the right
17281733
right_by_position = right.try_synthesizing_arg_from_vararg(None)
17291734
assert right_by_position is not None
@@ -1750,7 +1755,7 @@ def _incompatible(left_arg: FormalArgument | None, right_arg: FormalArgument | N
17501755
# Phase 1d: Check kw args. Right has an infinite series of optional named
17511756
# arguments. Get all further named args of left, and make sure
17521757
# they're more general than the corresponding member in right.
1753-
if right_star2 is not None and not trivial_suffix:
1758+
if right_star2 is not None:
17541759
right_names = {name for name in right.arg_names if name is not None}
17551760
left_only_names = set()
17561761
for name, kind in zip(left.arg_names, left.arg_kinds):

test-data/unit/check-enum.test

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2394,3 +2394,25 @@ def do_check(value: E) -> None:
23942394

23952395
[builtins fixtures/primitives.pyi]
23962396
[typing fixtures/typing-full.pyi]
2397+
2398+
[case testStrEnumClassCorrectIterable]
2399+
from enum import StrEnum
2400+
from typing import Type, TypeVar
2401+
2402+
class Choices(StrEnum):
2403+
LOREM = "lorem"
2404+
IPSUM = "ipsum"
2405+
2406+
var = list(Choices)
2407+
reveal_type(var) # N: Revealed type is "builtins.list[__main__.Choices]"
2408+
2409+
e: type[StrEnum]
2410+
reveal_type(list(e)) # N: Revealed type is "builtins.list[enum.StrEnum]"
2411+
2412+
T = TypeVar("T", bound=StrEnum)
2413+
def list_vals(e: Type[T]) -> list[T]:
2414+
reveal_type(list(e)) # N: Revealed type is "builtins.list[T`-1]"
2415+
return list(e)
2416+
2417+
reveal_type(list_vals(Choices)) # N: Revealed type is "builtins.list[__main__.Choices]"
2418+
[builtins fixtures/enum.pyi]

test-data/unit/check-functions.test

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -105,16 +105,38 @@ if int():
105105
h = h
106106

107107
[case testSubtypingFunctionsDoubleCorrespondence]
108+
def l(x) -> None: ...
109+
def r(__x, *, x) -> None: ...
110+
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")
108111

112+
[case testSubtypingFunctionsDoubleCorrespondenceNamedOptional]
109113
def l(x) -> None: ...
110-
def r(__, *, x) -> None: ...
111-
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, NamedArg(Any, 'x')], None]")
114+
def r(__x, *, x = 1) -> None: ...
115+
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")
112116

113-
[case testSubtypingFunctionsRequiredLeftArgNotPresent]
117+
[case testSubtypingFunctionsDoubleCorrespondenceBothNamedOptional]
118+
def l(x = 1) -> None: ...
119+
def r(__x, *, x = 1) -> None: ...
120+
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Any, DefaultNamedArg(Any, 'x')], None]")
121+
122+
[case testSubtypingFunctionsTrivialSuffixRequired]
123+
def l(__x) -> None: ...
124+
def r(x, *args, **kwargs) -> None: ...
125+
126+
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any], None]", variable has type "Callable[[Arg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
127+
[builtins fixtures/dict.pyi]
114128

129+
[case testSubtypingFunctionsTrivialSuffixOptional]
130+
def l(__x = 1) -> None: ...
131+
def r(x = 1, *args, **kwargs) -> None: ...
132+
133+
r = l # E: Incompatible types in assignment (expression has type "Callable[[DefaultArg(Any)], None]", variable has type "Callable[[DefaultArg(Any, 'x'), VarArg(Any), KwArg(Any)], None]")
134+
[builtins fixtures/dict.pyi]
135+
136+
[case testSubtypingFunctionsRequiredLeftArgNotPresent]
115137
def l(x, y) -> None: ...
116138
def r(x) -> None: ...
117-
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")
139+
r = l # E: Incompatible types in assignment (expression has type "Callable[[Any, Any], None]", variable has type "Callable[[Any], None]")
118140

119141
[case testSubtypingFunctionsImplicitNames]
120142
from typing import Any

test-data/unit/check-parameter-specification.test

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2532,3 +2532,30 @@ class GenericWrapper(Generic[P]):
25322532
def contains(c: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: ...
25332533
def inherits(*args: P.args, **kwargs: P.kwargs) -> None: ...
25342534
[builtins fixtures/paramspec.pyi]
2535+
2536+
[case testCallbackProtocolClassObjectParamSpec]
2537+
from typing import Any, Callable, Protocol, Optional, Generic
2538+
from typing_extensions import ParamSpec
2539+
2540+
P = ParamSpec("P")
2541+
2542+
class App: ...
2543+
2544+
class MiddlewareFactory(Protocol[P]):
2545+
def __call__(self, app: App, /, *args: P.args, **kwargs: P.kwargs) -> App:
2546+
...
2547+
2548+
class Capture(Generic[P]): ...
2549+
2550+
class ServerErrorMiddleware(App):
2551+
def __init__(
2552+
self,
2553+
app: App,
2554+
handler: Optional[str] = None,
2555+
debug: bool = False,
2556+
) -> None: ...
2557+
2558+
def fn(f: MiddlewareFactory[P]) -> Capture[P]: ...
2559+
2560+
reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]"
2561+
[builtins fixtures/paramspec.pyi]

test-data/unit/fixtures/enum.pyi

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# Minimal set of builtins required to work with Enums
2-
from typing import TypeVar, Generic
2+
from typing import TypeVar, Generic, Iterator, Sequence, overload, Iterable
33

44
T = TypeVar('T')
55

@@ -13,6 +13,13 @@ class tuple(Generic[T]):
1313
class int: pass
1414
class str:
1515
def __len__(self) -> int: pass
16+
def __iter__(self) -> Iterator[str]: pass
1617

1718
class dict: pass
1819
class ellipsis: pass
20+
21+
class list(Sequence[T]):
22+
@overload
23+
def __init__(self) -> None: pass
24+
@overload
25+
def __init__(self, x: Iterable[T]) -> None: pass

0 commit comments

Comments
 (0)