Skip to content

Commit 9a92899

Browse files
authored
Make import cycles more predictable by prioritizing different import forms (#1736)
This partially addresses #1530 (but does not fully fix it). The new algorithm always processes x before y, while the old one gives an error when y is processed before x: ``` x.py:4: note: In module imported here: y.py: note: In class "Sub": y.py:3: error: Cannot determine type of 'attr' ``` I am working on additional heuristics in https://github.com/python/mypy/tree/class_attrs.
1 parent 53d6e24 commit 9a92899

File tree

5 files changed

+179
-25
lines changed

5 files changed

+179
-25
lines changed

mypy/build.py

Lines changed: 116 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -329,12 +329,22 @@ def default_lib_path(data_dir: str, pyversion: Tuple[int, int],
329329
('data_json', str), # path of <id>.data.json
330330
('suppressed', List[str]), # dependencies that weren't imported
331331
('flags', Optional[List[str]]), # build flags
332+
('dep_prios', List[int]),
332333
])
333-
# NOTE: dependencies + suppressed == all unreachable imports;
334+
# NOTE: dependencies + suppressed == all reachable imports;
334335
# suppressed contains those reachable imports that were prevented by
335336
# --silent-imports or simply not found.
336337

337338

339+
# Priorities used for imports. (Here, top-level includes inside a class.)
340+
# These are used to determine a more predictable order in which the
341+
# nodes in an import cycle are processed.
342+
PRI_HIGH = 5 # top-level "from X import blah"
343+
PRI_MED = 10 # top-level "import X"
344+
PRI_LOW = 20 # either form inside a function
345+
PRI_ALL = 99 # include all priorities
346+
347+
338348
class BuildManager:
339349
"""This class holds shared state for building a mypy program.
340350
@@ -395,12 +405,13 @@ def __init__(self, data_dir: str,
395405
self.missing_modules = set() # type: Set[str]
396406

397407
def all_imported_modules_in_file(self,
398-
file: MypyFile) -> List[Tuple[str, int]]:
408+
file: MypyFile) -> List[Tuple[int, str, int]]:
399409
"""Find all reachable import statements in a file.
400410
401-
Return list of tuples (module id, import line number) for all modules
402-
imported in file.
411+
Return list of tuples (priority, module id, import line number)
412+
for all modules imported in file; lower numbers == higher priority.
403413
"""
414+
404415
def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
405416
"""Function to correct for relative imports."""
406417
file_id = file.fullname()
@@ -415,21 +426,23 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
415426

416427
return new_id
417428

418-
res = [] # type: List[Tuple[str, int]]
429+
res = [] # type: List[Tuple[int, str, int]]
419430
for imp in file.imports:
420431
if not imp.is_unreachable:
421432
if isinstance(imp, Import):
433+
pri = PRI_MED if imp.is_top_level else PRI_LOW
422434
for id, _ in imp.ids:
423-
res.append((id, imp.line))
435+
res.append((pri, id, imp.line))
424436
elif isinstance(imp, ImportFrom):
425437
cur_id = correct_rel_imp(imp)
426438
pos = len(res)
427439
all_are_submodules = True
428440
# Also add any imported names that are submodules.
441+
pri = PRI_MED if imp.is_top_level else PRI_LOW
429442
for name, __ in imp.names:
430443
sub_id = cur_id + '.' + name
431444
if self.is_module(sub_id):
432-
res.append((sub_id, imp.line))
445+
res.append((pri, sub_id, imp.line))
433446
else:
434447
all_are_submodules = False
435448
# If all imported names are submodules, don't add
@@ -438,9 +451,12 @@ def correct_rel_imp(imp: Union[ImportFrom, ImportAll]) -> str:
438451
# cur_id is also a dependency, and we should
439452
# insert it *before* any submodules.
440453
if not all_are_submodules:
441-
res.insert(pos, ((cur_id, imp.line)))
454+
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
455+
res.insert(pos, ((pri, cur_id, imp.line)))
442456
elif isinstance(imp, ImportAll):
443-
res.append((correct_rel_imp(imp), imp.line))
457+
pri = PRI_HIGH if imp.is_top_level else PRI_LOW
458+
res.append((pri, correct_rel_imp(imp), imp.line))
459+
444460
return res
445461

446462
def is_module(self, id: str) -> bool:
@@ -773,16 +789,18 @@ def find_cache_meta(id: str, path: str, manager: BuildManager) -> Optional[Cache
773789
data_json,
774790
meta.get('suppressed', []),
775791
meta.get('flags'),
792+
meta.get('dep_prios', []),
776793
)
777794
if (m.id != id or m.path != path or
778795
m.mtime is None or m.size is None or
779796
m.dependencies is None or m.data_mtime is None):
780797
return None
781798

782-
# Metadata generated by older mypy version and no flags were saved
783-
if m.flags is None:
799+
# Ignore cache if generated by an older mypy version.
800+
if m.flags is None or len(m.dependencies) != len(m.dep_prios):
784801
return None
785802

803+
# Ignore cache if (relevant) flags aren't the same.
786804
cached_flags = select_flags_affecting_cache(m.flags)
787805
current_flags = select_flags_affecting_cache(manager.flags)
788806
if cached_flags != current_flags:
@@ -821,6 +839,7 @@ def random_string():
821839

822840
def write_cache(id: str, path: str, tree: MypyFile,
823841
dependencies: List[str], suppressed: List[str],
842+
dep_prios: List[int],
824843
manager: BuildManager) -> None:
825844
"""Write cache files for a module.
826845
@@ -830,6 +849,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
830849
tree: the fully checked module data
831850
dependencies: module IDs on which this module depends
832851
suppressed: module IDs which were suppressed as dependencies
852+
dep_prios: priorities (parallel array to dependencies)
833853
manager: the build manager (for pyversion, log/trace)
834854
"""
835855
path = os.path.abspath(path)
@@ -859,6 +879,7 @@ def write_cache(id: str, path: str, tree: MypyFile,
859879
'dependencies': dependencies,
860880
'suppressed': suppressed,
861881
'flags': manager.flags,
882+
'dep_prios': dep_prios,
862883
}
863884
with open(meta_json_tmp, 'w') as f:
864885
json.dump(meta, f, sort_keys=True)
@@ -1031,6 +1052,7 @@ class State:
10311052
tree = None # type: Optional[MypyFile]
10321053
dependencies = None # type: List[str]
10331054
suppressed = None # type: List[str] # Suppressed/missing dependencies
1055+
priorities = None # type: Dict[str, int]
10341056

10351057
# Map each dependency to the line number where it is first imported
10361058
dep_line_map = None # type: Dict[str, int]
@@ -1132,6 +1154,9 @@ def __init__(self,
11321154
# compare them to the originals later.
11331155
self.dependencies = list(self.meta.dependencies)
11341156
self.suppressed = list(self.meta.suppressed)
1157+
assert len(self.meta.dependencies) == len(self.meta.dep_prios)
1158+
self.priorities = {id: pri
1159+
for id, pri in zip(self.meta.dependencies, self.meta.dep_prios)}
11351160
self.dep_line_map = {}
11361161
else:
11371162
# Parse the file (and then some) to get the dependencies.
@@ -1267,8 +1292,10 @@ def parse_file(self) -> None:
12671292
# Also keep track of each dependency's source line.
12681293
dependencies = []
12691294
suppressed = []
1295+
priorities = {} # type: Dict[str, int] # id -> priority
12701296
dep_line_map = {} # type: Dict[str, int] # id -> line
1271-
for id, line in manager.all_imported_modules_in_file(self.tree):
1297+
for pri, id, line in manager.all_imported_modules_in_file(self.tree):
1298+
priorities[id] = min(pri, priorities.get(id, PRI_ALL))
12721299
if id == self.id:
12731300
continue
12741301
# Omit missing modules, as otherwise we could not type-check
@@ -1299,6 +1326,7 @@ def parse_file(self) -> None:
12991326
# for differences (e.g. --silent-imports).
13001327
self.dependencies = dependencies
13011328
self.suppressed = suppressed
1329+
self.priorities = priorities
13021330
self.dep_line_map = dep_line_map
13031331
self.check_blockers()
13041332

@@ -1338,8 +1366,10 @@ def type_check(self) -> None:
13381366

13391367
def write_cache(self) -> None:
13401368
if self.path and INCREMENTAL in self.manager.flags and not self.manager.errors.is_errors():
1369+
dep_prios = [self.priorities.get(dep, PRI_HIGH) for dep in self.dependencies]
13411370
write_cache(self.id, self.path, self.tree,
13421371
list(self.dependencies), list(self.suppressed),
1372+
dep_prios,
13431373
self.manager)
13441374

13451375

@@ -1408,10 +1438,9 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
14081438
# dependencies) to roots (those from which everything else can be
14091439
# reached).
14101440
for ascc in sccs:
1411-
# Sort the SCC's nodes in *reverse* order or encounter.
1412-
# This is a heuristic for handling import cycles.
1441+
# Order the SCC's nodes using a heuristic.
14131442
# Note that ascc is a set, and scc is a list.
1414-
scc = sorted(ascc, key=lambda id: -graph[id].order)
1443+
scc = order_ascc(graph, ascc)
14151444
# If builtins is in the list, move it last. (This is a bit of
14161445
# a hack, but it's necessary because the builtins module is
14171446
# part of a small cycle involving at least {builtins, abc,
@@ -1420,6 +1449,12 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
14201449
if 'builtins' in ascc:
14211450
scc.remove('builtins')
14221451
scc.append('builtins')
1452+
if manager.flags.count(VERBOSE) >= 2:
1453+
for id in scc:
1454+
manager.trace("Priorities for %s:" % id,
1455+
" ".join("%s:%d" % (x, graph[id].priorities[x])
1456+
for x in graph[id].dependencies
1457+
if x in ascc and x in graph[id].priorities))
14231458
# Because the SCCs are presented in topological sort order, we
14241459
# don't need to look at dependencies recursively for staleness
14251460
# -- the immediate dependencies are sufficient.
@@ -1446,7 +1481,7 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
14461481
# cache file is newer than any scc node's cache file.
14471482
oldest_in_scc = min(graph[id].meta.data_mtime for id in scc)
14481483
newest_in_deps = 0 if not deps else max(graph[dep].meta.data_mtime for dep in deps)
1449-
if manager.flags.count(VERBOSE) >= 2: # Dump all mtimes for extreme debugging.
1484+
if manager.flags.count(VERBOSE) >= 3: # Dump all mtimes for extreme debugging.
14501485
all_ids = sorted(ascc | deps, key=lambda id: graph[id].meta.data_mtime)
14511486
for id in all_ids:
14521487
if id in scc:
@@ -1486,6 +1521,53 @@ def process_graph(graph: Graph, manager: BuildManager) -> None:
14861521
process_stale_scc(graph, scc)
14871522

14881523

1524+
def order_ascc(graph: Graph, ascc: AbstractSet[str], pri_max: int = PRI_ALL) -> List[str]:
1525+
"""Come up with the ideal processing order within an SCC.
1526+
1527+
Using the priorities assigned by all_imported_modules_in_file(),
1528+
try to reduce the cycle to a DAG, by omitting arcs representing
1529+
dependencies of lower priority.
1530+
1531+
In the simplest case, if we have A <--> B where A has a top-level
1532+
"import B" (medium priority) but B only has the reverse "import A"
1533+
inside a function (low priority), we turn the cycle into a DAG by
1534+
dropping the B --> A arc, which leaves only A --> B.
1535+
1536+
If all arcs have the same priority, we fall back to sorting by
1537+
reverse global order (the order in which modules were first
1538+
encountered).
1539+
1540+
The algorithm is recursive, as follows: when as arcs of different
1541+
priorities are present, drop all arcs of the lowest priority,
1542+
identify SCCs in the resulting graph, and apply the algorithm to
1543+
each SCC thus found. The recursion is bounded because at each
1544+
recursion the spread in priorities is (at least) one less.
1545+
1546+
In practice there are only a few priority levels (currently
1547+
N=3) and in the worst case we just carry out the same algorithm
1548+
for finding SCCs N times. Thus the complexity is no worse than
1549+
the complexity of the original SCC-finding algorithm -- see
1550+
strongly_connected_components() below for a reference.
1551+
"""
1552+
if len(ascc) == 1:
1553+
return [s for s in ascc]
1554+
pri_spread = set()
1555+
for id in ascc:
1556+
state = graph[id]
1557+
for dep in state.dependencies:
1558+
if dep in ascc:
1559+
pri = state.priorities.get(dep, PRI_HIGH)
1560+
if pri < pri_max:
1561+
pri_spread.add(pri)
1562+
if len(pri_spread) == 1:
1563+
# Filtered dependencies are uniform -- order by global order.
1564+
return sorted(ascc, key=lambda id: -graph[id].order)
1565+
pri_max = max(pri_spread)
1566+
sccs = sorted_components(graph, ascc, pri_max)
1567+
# The recursion is bounded by the len(pri_spread) check above.
1568+
return [s for ss in sccs for s in order_ascc(graph, ss, pri_max)]
1569+
1570+
14891571
def process_fresh_scc(graph: Graph, scc: List[str]) -> None:
14901572
"""Process the modules in one SCC from their cached data."""
14911573
for id in scc:
@@ -1517,7 +1599,9 @@ def process_stale_scc(graph: Graph, scc: List[str]) -> None:
15171599
graph[id].write_cache()
15181600

15191601

1520-
def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
1602+
def sorted_components(graph: Graph,
1603+
vertices: Optional[AbstractSet[str]] = None,
1604+
pri_max: int = PRI_ALL) -> List[AbstractSet[str]]:
15211605
"""Return the graph's SCCs, topologically sorted by dependencies.
15221606
15231607
The sort order is from leaves (nodes without dependencies) to
@@ -1527,17 +1611,17 @@ def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
15271611
dependencies that aren't present in graph.keys() are ignored.
15281612
"""
15291613
# Compute SCCs.
1530-
vertices = set(graph)
1531-
edges = {id: [dep for dep in st.dependencies if dep in graph]
1532-
for id, st in graph.items()}
1614+
if vertices is None:
1615+
vertices = set(graph)
1616+
edges = {id: deps_filtered(graph, vertices, id, pri_max) for id in vertices}
15331617
sccs = list(strongly_connected_components(vertices, edges))
15341618
# Topsort.
15351619
sccsmap = {id: frozenset(scc) for scc in sccs for id in scc}
15361620
data = {} # type: Dict[AbstractSet[str], Set[AbstractSet[str]]]
15371621
for scc in sccs:
15381622
deps = set() # type: Set[AbstractSet[str]]
15391623
for id in scc:
1540-
deps.update(sccsmap[x] for x in graph[id].dependencies if x in graph)
1624+
deps.update(sccsmap[x] for x in deps_filtered(graph, vertices, id, pri_max))
15411625
data[frozenset(scc)] = deps
15421626
res = []
15431627
for ready in topsort(data):
@@ -1554,7 +1638,17 @@ def sorted_components(graph: Graph) -> List[AbstractSet[str]]:
15541638
return res
15551639

15561640

1557-
def strongly_connected_components(vertices: Set[str],
1641+
def deps_filtered(graph: Graph, vertices: AbstractSet[str], id: str, pri_max: int) -> List[str]:
1642+
"""Filter dependencies for id with pri < pri_max."""
1643+
if id not in vertices:
1644+
return []
1645+
state = graph[id]
1646+
return [dep
1647+
for dep in state.dependencies
1648+
if dep in vertices and state.priorities.get(dep, PRI_HIGH) < pri_max]
1649+
1650+
1651+
def strongly_connected_components(vertices: AbstractSet[str],
15581652
edges: Dict[str, List[str]]) -> Iterator[Set[str]]:
15591653
"""Compute Strongly Connected Components of a directed graph.
15601654

mypy/nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,7 @@ def deserialize(cls, data: JsonDict) -> 'MypyFile':
224224
class ImportBase(Node):
225225
"""Base class for all import statements."""
226226
is_unreachable = False
227+
is_top_level = False # Set by semanal.FirstPass
227228
# If an import replaces existing definitions, we construct dummy assignment
228229
# statements that assign the imported names to the names in the current scope,
229230
# for type checking purposes. Example:

mypy/semanal.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2411,12 +2411,14 @@ def visit_import_from(self, node: ImportFrom) -> None:
24112411
# We can't bind module names during the first pass, as the target module might be
24122412
# unprocessed. However, we add dummy unbound imported names to the symbol table so
24132413
# that we at least know that the name refers to a module.
2414+
node.is_top_level = True
24142415
for name, as_name in node.names:
24152416
imported_name = as_name or name
24162417
if imported_name not in self.sem.globals:
24172418
self.sem.add_symbol(imported_name, SymbolTableNode(UNBOUND_IMPORTED, None), node)
24182419

24192420
def visit_import(self, node: Import) -> None:
2421+
node.is_top_level = True
24202422
# This is similar to visit_import_from -- see the comment there.
24212423
for id, as_id in node.ids:
24222424
imported_id = as_id or id
@@ -2426,6 +2428,9 @@ def visit_import(self, node: Import) -> None:
24262428
# If the previous symbol is a variable, this should take precedence.
24272429
self.sem.globals[imported_id] = SymbolTableNode(UNBOUND_IMPORTED, None)
24282430

2431+
def visit_import_all(self, node: ImportAll) -> None:
2432+
node.is_top_level = True
2433+
24292434
def visit_while_stmt(self, s: WhileStmt) -> None:
24302435
s.body.accept(self)
24312436
if s.else_body:

mypy/test/testgraph.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypy.myunit import Suite, assert_equal
66
from mypy.build import BuildManager, State, TYPE_CHECK
7-
from mypy.build import topsort, strongly_connected_components, sorted_components
7+
from mypy.build import topsort, strongly_connected_components, sorted_components, order_ascc
88

99

1010
class GraphSuite(Suite):
@@ -30,7 +30,7 @@ def test_scc(self) -> None:
3030
frozenset({'B', 'C'}),
3131
frozenset({'D'})})
3232

33-
def test_sorted_components(self) -> None:
33+
def _make_manager(self):
3434
manager = BuildManager(
3535
data_dir='',
3636
lib_path=[],
@@ -41,9 +41,27 @@ def test_sorted_components(self) -> None:
4141
custom_typing_module='',
4242
source_set=None,
4343
reports=None)
44+
return manager
45+
46+
def test_sorted_components(self) -> None:
47+
manager = self._make_manager()
4448
graph = {'a': State('a', None, 'import b, c', manager),
49+
'd': State('d', None, 'pass', manager),
4550
'b': State('b', None, 'import c', manager),
4651
'c': State('c', None, 'import b, d', manager),
47-
'd': State('d', None, 'pass', manager)}
52+
}
4853
res = sorted_components(graph)
4954
assert_equal(res, [frozenset({'d'}), frozenset({'c', 'b'}), frozenset({'a'})])
55+
56+
def test_order_ascc(self) -> None:
57+
manager = self._make_manager()
58+
graph = {'a': State('a', None, 'import b, c', manager),
59+
'd': State('d', None, 'def f(): import a', manager),
60+
'b': State('b', None, 'import c', manager),
61+
'c': State('c', None, 'import b, d', manager),
62+
}
63+
res = sorted_components(graph)
64+
assert_equal(res, [frozenset({'a', 'd', 'c', 'b'})])
65+
ascc = res[0]
66+
scc = order_ascc(graph, ascc)
67+
assert_equal(scc, ['d', 'c', 'b', 'a'])

0 commit comments

Comments
 (0)