Skip to content

Commit 3efbc5c

Browse files
authored
Allow using modules as subtypes of protocols (#13513)
Fixes #5018 Fixes #5439 Fixes #10850 The implementation is simple but not the most beautiful one. I simply add a new slot to the `Instance` class that represents content of the module. This new attribute is short lived (it is not serialized, and not even stored on variables etc., because we erase it in `copy_modified()`). We don't need to store it, because all the information we need is already available in `MypyFile` node. We just need the new attribute to communicate between the checker and `subtypes.py`. Other possible alternatives like introducing new dedicated `ModuleType`, or passing the symbol tables to `subtypes.py` both look way to complicated. Another argument in favor of this new slot is it could be useful for other things, like `hasattr()` support and ad hoc callable attributes (btw I am already working on the former). Note there is one important limitation: since we don't store the module information, we can't support module objects stored in nested positions, like `self.mods = (foo, bar)` and then `accepts_protocol(self.mods[0])`. We only support variables (name expressions) and direct instance, class, or module attributes (see tests). I think this will cover 99% of possible use-cases.
1 parent be495c7 commit 3efbc5c

13 files changed

+385
-44
lines changed

misc/proper_plugin.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,8 @@ def isinstance_proper_hook(ctx: FunctionContext) -> Type:
5050
right = get_proper_type(ctx.arg_types[1][0])
5151
for arg in ctx.arg_types[0]:
5252
if (
53-
is_improper_type(arg)
54-
or isinstance(get_proper_type(arg), AnyType)
55-
and is_dangerous_target(right)
56-
):
53+
is_improper_type(arg) or isinstance(get_proper_type(arg), AnyType)
54+
) and is_dangerous_target(right):
5755
if is_special_target(right):
5856
return ctx.default_return_type
5957
ctx.api.fail(

mypy/checker.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2285,18 +2285,29 @@ def check_multiple_inheritance(self, typ: TypeInfo) -> None:
22852285
if name in base2.names and base2 not in base.mro:
22862286
self.check_compatibility(name, base, base2, typ)
22872287

2288-
def determine_type_of_class_member(self, sym: SymbolTableNode) -> Type | None:
2288+
def determine_type_of_member(self, sym: SymbolTableNode) -> Type | None:
22892289
if sym.type is not None:
22902290
return sym.type
22912291
if isinstance(sym.node, FuncBase):
22922292
return self.function_type(sym.node)
22932293
if isinstance(sym.node, TypeInfo):
2294-
# nested class
2295-
return type_object_type(sym.node, self.named_type)
2294+
if sym.node.typeddict_type:
2295+
# We special-case TypedDict, because they don't define any constructor.
2296+
return self.expr_checker.typeddict_callable(sym.node)
2297+
else:
2298+
return type_object_type(sym.node, self.named_type)
22962299
if isinstance(sym.node, TypeVarExpr):
22972300
# Use of TypeVars is rejected in an expression/runtime context, so
22982301
# we don't need to check supertype compatibility for them.
22992302
return AnyType(TypeOfAny.special_form)
2303+
if isinstance(sym.node, TypeAlias):
2304+
with self.msg.filter_errors():
2305+
# Suppress any errors, they will be given when analyzing the corresponding node.
2306+
# Here we may have incorrect options and location context.
2307+
return self.expr_checker.alias_type_in_runtime_context(
2308+
sym.node, sym.node.no_args, sym.node
2309+
)
2310+
# TODO: handle more node kinds here.
23002311
return None
23012312

23022313
def check_compatibility(
@@ -2327,8 +2338,8 @@ class C(B, A[int]): ... # this is unsafe because...
23272338
return
23282339
first = base1.names[name]
23292340
second = base2.names[name]
2330-
first_type = get_proper_type(self.determine_type_of_class_member(first))
2331-
second_type = get_proper_type(self.determine_type_of_class_member(second))
2341+
first_type = get_proper_type(self.determine_type_of_member(first))
2342+
second_type = get_proper_type(self.determine_type_of_member(second))
23322343

23332344
if isinstance(first_type, FunctionLike) and isinstance(second_type, FunctionLike):
23342345
if first_type.is_type_obj() and second_type.is_type_obj():

mypy/checkexpr.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
CallableType,
128128
DeletedType,
129129
ErasedType,
130+
ExtraAttrs,
130131
FunctionLike,
131132
Instance,
132133
LiteralType,
@@ -332,13 +333,7 @@ def analyze_ref_expr(self, e: RefExpr, lvalue: bool = False) -> Type:
332333
result = erasetype.erase_typevars(result)
333334
elif isinstance(node, MypyFile):
334335
# Reference to a module object.
335-
try:
336-
result = self.named_type("types.ModuleType")
337-
except KeyError:
338-
# In test cases might 'types' may not be available.
339-
# Fall back to a dummy 'object' type instead to
340-
# avoid a crash.
341-
result = self.named_type("builtins.object")
336+
result = self.module_type(node)
342337
elif isinstance(node, Decorator):
343338
result = self.analyze_var_ref(node.var, e)
344339
elif isinstance(node, TypeAlias):
@@ -374,6 +369,29 @@ def analyze_var_ref(self, var: Var, context: Context) -> Type:
374369
# Implicit 'Any' type.
375370
return AnyType(TypeOfAny.special_form)
376371

372+
def module_type(self, node: MypyFile) -> Instance:
373+
try:
374+
result = self.named_type("types.ModuleType")
375+
except KeyError:
376+
# In test cases might 'types' may not be available.
377+
# Fall back to a dummy 'object' type instead to
378+
# avoid a crash.
379+
result = self.named_type("builtins.object")
380+
module_attrs = {}
381+
immutable = set()
382+
for name, n in node.names.items():
383+
if isinstance(n.node, Var) and n.node.is_final:
384+
immutable.add(name)
385+
typ = self.chk.determine_type_of_member(n)
386+
if typ:
387+
module_attrs[name] = typ
388+
else:
389+
# TODO: what to do about nested module references?
390+
# They are non-trivial because there may be import cycles.
391+
module_attrs[name] = AnyType(TypeOfAny.special_form)
392+
result.extra_attrs = ExtraAttrs(module_attrs, immutable, node.fullname)
393+
return result
394+
377395
def visit_call_expr(self, e: CallExpr, allow_none_return: bool = False) -> Type:
378396
"""Type check a call expression."""
379397
if e.analyzed:

mypy/checkmember.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,9 @@ def analyze_member_var_access(
475475
return analyze_var(name, v, itype, info, mx, implicit=implicit)
476476
elif isinstance(v, FuncDef):
477477
assert False, "Did not expect a function"
478+
elif isinstance(v, MypyFile):
479+
mx.chk.module_refs.add(v.fullname)
480+
return mx.chk.expr_checker.module_type(v)
478481
elif (
479482
not v
480483
and name not in ["__getattr__", "__setattr__", "__getattribute__"]

mypy/constraints.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -791,7 +791,7 @@ def infer_constraints_from_protocol_members(
791791
# The above is safe since at this point we know that 'instance' is a subtype
792792
# of (erased) 'template', therefore it defines all protocol members
793793
res.extend(infer_constraints(temp, inst, self.direction))
794-
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol.type):
794+
if mypy.subtypes.IS_SETTABLE in mypy.subtypes.get_member_flags(member, protocol):
795795
# Settable members are invariant, add opposite constraints
796796
res.extend(infer_constraints(temp, inst, neg_op(self.direction)))
797797
return res

mypy/messages.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1822,6 +1822,7 @@ def report_protocol_problems(
18221822
return
18231823

18241824
class_obj = False
1825+
is_module = False
18251826
if isinstance(subtype, TupleType):
18261827
if not isinstance(subtype.partial_fallback, Instance):
18271828
return
@@ -1845,6 +1846,8 @@ def report_protocol_problems(
18451846
return
18461847
class_obj = True
18471848
subtype = ret_type
1849+
if subtype.extra_attrs and subtype.extra_attrs.mod_name:
1850+
is_module = True
18481851

18491852
# Report missing members
18501853
missing = get_missing_protocol_members(subtype, supertype)
@@ -1881,11 +1884,8 @@ def report_protocol_problems(
18811884
or not subtype.type.defn.type_vars
18821885
or not supertype.type.defn.type_vars
18831886
):
1884-
self.note(
1885-
f"Following member(s) of {format_type(subtype)} have conflicts:",
1886-
context,
1887-
code=code,
1888-
)
1887+
type_name = format_type(subtype, module_names=True)
1888+
self.note(f"Following member(s) of {type_name} have conflicts:", context, code=code)
18891889
for name, got, exp in conflict_types[:MAX_ITEMS]:
18901890
exp = get_proper_type(exp)
18911891
got = get_proper_type(got)
@@ -1902,28 +1902,28 @@ def report_protocol_problems(
19021902
self.note("Expected:", context, offset=OFFSET, code=code)
19031903
if isinstance(exp, CallableType):
19041904
self.note(
1905-
pretty_callable(exp, skip_self=class_obj),
1905+
pretty_callable(exp, skip_self=class_obj or is_module),
19061906
context,
19071907
offset=2 * OFFSET,
19081908
code=code,
19091909
)
19101910
else:
19111911
assert isinstance(exp, Overloaded)
19121912
self.pretty_overload(
1913-
exp, context, 2 * OFFSET, code=code, skip_self=class_obj
1913+
exp, context, 2 * OFFSET, code=code, skip_self=class_obj or is_module
19141914
)
19151915
self.note("Got:", context, offset=OFFSET, code=code)
19161916
if isinstance(got, CallableType):
19171917
self.note(
1918-
pretty_callable(got, skip_self=class_obj),
1918+
pretty_callable(got, skip_self=class_obj or is_module),
19191919
context,
19201920
offset=2 * OFFSET,
19211921
code=code,
19221922
)
19231923
else:
19241924
assert isinstance(got, Overloaded)
19251925
self.pretty_overload(
1926-
got, context, 2 * OFFSET, code=code, skip_self=class_obj
1926+
got, context, 2 * OFFSET, code=code, skip_self=class_obj or is_module
19271927
)
19281928
self.print_more(conflict_types, context, OFFSET, MAX_ITEMS, code=code)
19291929

@@ -2147,7 +2147,9 @@ def format_callable_args(
21472147
return ", ".join(arg_strings)
21482148

21492149

2150-
def format_type_inner(typ: Type, verbosity: int, fullnames: set[str] | None) -> str:
2150+
def format_type_inner(
2151+
typ: Type, verbosity: int, fullnames: set[str] | None, module_names: bool = False
2152+
) -> str:
21512153
"""
21522154
Convert a type to a relatively short string suitable for error messages.
21532155
@@ -2187,7 +2189,10 @@ def format_literal_value(typ: LiteralType) -> str:
21872189
# Get the short name of the type.
21882190
if itype.type.fullname in ("types.ModuleType", "_importlib_modulespec.ModuleType"):
21892191
# Make some common error messages simpler and tidier.
2190-
return "Module"
2192+
base_str = "Module"
2193+
if itype.extra_attrs and itype.extra_attrs.mod_name and module_names:
2194+
return f"{base_str} {itype.extra_attrs.mod_name}"
2195+
return base_str
21912196
if verbosity >= 2 or (fullnames and itype.type.fullname in fullnames):
21922197
base_str = itype.type.fullname
21932198
else:
@@ -2361,7 +2366,7 @@ def find_type_overlaps(*types: Type) -> set[str]:
23612366
return overlaps
23622367

23632368

2364-
def format_type(typ: Type, verbosity: int = 0) -> str:
2369+
def format_type(typ: Type, verbosity: int = 0, module_names: bool = False) -> str:
23652370
"""
23662371
Convert a type to a relatively short string suitable for error messages.
23672372
@@ -2372,10 +2377,10 @@ def format_type(typ: Type, verbosity: int = 0) -> str:
23722377
modification of the formatted string is required, callers should use
23732378
format_type_bare.
23742379
"""
2375-
return quote_type_string(format_type_bare(typ, verbosity))
2380+
return quote_type_string(format_type_bare(typ, verbosity, module_names))
23762381

23772382

2378-
def format_type_bare(typ: Type, verbosity: int = 0) -> str:
2383+
def format_type_bare(typ: Type, verbosity: int = 0, module_names: bool = False) -> str:
23792384
"""
23802385
Convert a type to a relatively short string suitable for error messages.
23812386
@@ -2387,7 +2392,7 @@ def format_type_bare(typ: Type, verbosity: int = 0) -> str:
23872392
instead. (The caller may want to use quote_type_string after
23882393
processing has happened, to maintain consistent quoting in messages.)
23892394
"""
2390-
return format_type_inner(typ, verbosity, find_type_overlaps(typ))
2395+
return format_type_inner(typ, verbosity, find_type_overlaps(typ), module_names)
23912396

23922397

23932398
def format_type_distinctly(*types: Type, bare: bool = False) -> tuple[str, ...]:
@@ -2564,7 +2569,7 @@ def get_conflict_protocol_types(
25642569
if not subtype:
25652570
continue
25662571
is_compat = is_subtype(subtype, supertype, ignore_pos_arg_names=True)
2567-
if IS_SETTABLE in get_member_flags(member, right.type):
2572+
if IS_SETTABLE in get_member_flags(member, right):
25682573
is_compat = is_compat and is_subtype(supertype, subtype)
25692574
if not is_compat:
25702575
conflicts.append((member, subtype, supertype))
@@ -2581,11 +2586,7 @@ def get_bad_protocol_flags(
25812586
all_flags: list[tuple[str, set[int], set[int]]] = []
25822587
for member in right.type.protocol_members:
25832588
if find_member(member, left, left):
2584-
item = (
2585-
member,
2586-
get_member_flags(member, left.type),
2587-
get_member_flags(member, right.type),
2588-
)
2589+
item = (member, get_member_flags(member, left), get_member_flags(member, right))
25892590
all_flags.append(item)
25902591
bad_flags = []
25912592
for name, subflags, superflags in all_flags:

mypy/server/deps.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,9 @@ def visit_instance(self, typ: Instance) -> list[str]:
969969
triggers.extend(self.get_type_triggers(arg))
970970
if typ.last_known_value:
971971
triggers.extend(self.get_type_triggers(typ.last_known_value))
972+
if typ.extra_attrs and typ.extra_attrs.mod_name:
973+
# Module as type effectively depends on all module attributes, use wildcard.
974+
triggers.append(make_wildcard_trigger(typ.extra_attrs.mod_name))
972975
return triggers
973976

974977
def visit_type_alias_type(self, typ: TypeAliasType) -> list[str]:

mypy/subtypes.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1010,8 +1010,8 @@ def named_type(fullname: str) -> Instance:
10101010
if isinstance(subtype, NoneType) and isinstance(supertype, CallableType):
10111011
# We want __hash__ = None idiom to work even without --strict-optional
10121012
return False
1013-
subflags = get_member_flags(member, left.type, class_obj=class_obj)
1014-
superflags = get_member_flags(member, right.type)
1013+
subflags = get_member_flags(member, left, class_obj=class_obj)
1014+
superflags = get_member_flags(member, right)
10151015
if IS_SETTABLE in superflags:
10161016
# Check opposite direction for settable attributes.
10171017
if not is_subtype(supertype, subtype):
@@ -1095,10 +1095,12 @@ def find_member(
10951095
# PEP 544 doesn't specify anything about such use cases. So we just try
10961096
# to do something meaningful (at least we should not crash).
10971097
return TypeType(fill_typevars_with_any(v))
1098+
if itype.extra_attrs and name in itype.extra_attrs.attrs:
1099+
return itype.extra_attrs.attrs[name]
10981100
return None
10991101

11001102

1101-
def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[int]:
1103+
def get_member_flags(name: str, itype: Instance, class_obj: bool = False) -> set[int]:
11021104
"""Detect whether a member 'name' is settable, whether it is an
11031105
instance or class variable, and whether it is class or static method.
11041106
@@ -1109,6 +1111,7 @@ def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[
11091111
* IS_CLASS_OR_STATIC: set for methods decorated with @classmethod or
11101112
with @staticmethod.
11111113
"""
1114+
info = itype.type
11121115
method = info.get_method(name)
11131116
setattr_meth = info.get_method("__setattr__")
11141117
if method:
@@ -1126,11 +1129,18 @@ def get_member_flags(name: str, info: TypeInfo, class_obj: bool = False) -> set[
11261129
if not node:
11271130
if setattr_meth:
11281131
return {IS_SETTABLE}
1132+
if itype.extra_attrs and name in itype.extra_attrs.attrs:
1133+
flags = set()
1134+
if name not in itype.extra_attrs.immutable:
1135+
flags.add(IS_SETTABLE)
1136+
return flags
11291137
return set()
11301138
v = node.node
11311139
# just a variable
11321140
if isinstance(v, Var) and not v.is_property:
1133-
flags = {IS_SETTABLE}
1141+
flags = set()
1142+
if not v.is_final:
1143+
flags.add(IS_SETTABLE)
11341144
if v.is_classvar:
11351145
flags.add(IS_CLASSVAR)
11361146
if class_obj and v.is_inferred:

0 commit comments

Comments
 (0)