Skip to content

Commit 23146c4

Browse files
authored
[mypyc] Implement async for as a statement and in comprehensions (#13444)
Progress on mypyc/mypyc#868.
1 parent 23ee1e7 commit 23146c4

File tree

11 files changed

+391
-50
lines changed

11 files changed

+391
-50
lines changed

mypyc/irbuild/expression.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -873,31 +873,24 @@ def _visit_display(
873873

874874

875875
def transform_list_comprehension(builder: IRBuilder, o: ListComprehension) -> Value:
876-
if any(o.generator.is_async):
877-
builder.error("async comprehensions are unimplemented", o.line)
878876
return translate_list_comprehension(builder, o.generator)
879877

880878

881879
def transform_set_comprehension(builder: IRBuilder, o: SetComprehension) -> Value:
882-
if any(o.generator.is_async):
883-
builder.error("async comprehensions are unimplemented", o.line)
884880
return translate_set_comprehension(builder, o.generator)
885881

886882

887883
def transform_dictionary_comprehension(builder: IRBuilder, o: DictionaryComprehension) -> Value:
888-
if any(o.is_async):
889-
builder.error("async comprehensions are unimplemented", o.line)
890-
891-
d = builder.call_c(dict_new_op, [], o.line)
892-
loop_params = list(zip(o.indices, o.sequences, o.condlists))
884+
d = builder.maybe_spill(builder.call_c(dict_new_op, [], o.line))
885+
loop_params = list(zip(o.indices, o.sequences, o.condlists, o.is_async))
893886

894887
def gen_inner_stmts() -> None:
895888
k = builder.accept(o.key)
896889
v = builder.accept(o.value)
897-
builder.call_c(dict_set_item_op, [d, k, v], o.line)
890+
builder.call_c(dict_set_item_op, [builder.read(d), k, v], o.line)
898891

899892
comprehension_helper(builder, loop_params, gen_inner_stmts, o.line)
900-
return d
893+
return builder.read(d)
901894

902895

903896
# Misc
@@ -915,9 +908,6 @@ def get_arg(arg: Expression | None) -> Value:
915908

916909

917910
def transform_generator_expr(builder: IRBuilder, o: GeneratorExpr) -> Value:
918-
if any(o.is_async):
919-
builder.error("async comprehensions are unimplemented", o.line)
920-
921911
builder.warning("Treating generator comprehension as list", o.line)
922912
return builder.call_c(iter_op, [translate_list_comprehension(builder, o)], o.line)
923913

mypyc/irbuild/for_helpers.py

Lines changed: 115 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,30 @@
2020
TupleExpr,
2121
TypeAlias,
2222
)
23-
from mypyc.ir.ops import BasicBlock, Branch, Integer, IntOp, Register, TupleGet, TupleSet, Value
23+
from mypyc.ir.ops import (
24+
BasicBlock,
25+
Branch,
26+
Integer,
27+
IntOp,
28+
LoadAddress,
29+
LoadMem,
30+
Register,
31+
TupleGet,
32+
TupleSet,
33+
Value,
34+
)
2435
from mypyc.ir.rtypes import (
2536
RTuple,
2637
RType,
38+
bool_rprimitive,
2739
int_rprimitive,
2840
is_dict_rprimitive,
2941
is_list_rprimitive,
3042
is_sequence_rprimitive,
3143
is_short_int_rprimitive,
3244
is_str_rprimitive,
3345
is_tuple_rprimitive,
46+
pointer_rprimitive,
3447
short_int_rprimitive,
3548
)
3649
from mypyc.irbuild.builder import IRBuilder
@@ -45,8 +58,9 @@
4558
dict_value_iter_op,
4659
)
4760
from mypyc.primitives.exc_ops import no_err_occurred_op
48-
from mypyc.primitives.generic_ops import iter_op, next_op
61+
from mypyc.primitives.generic_ops import aiter_op, anext_op, iter_op, next_op
4962
from mypyc.primitives.list_ops import list_append_op, list_get_item_unsafe_op, new_list_set_item_op
63+
from mypyc.primitives.misc_ops import stop_async_iteration_op
5064
from mypyc.primitives.registry import CFunctionDescription
5165
from mypyc.primitives.set_ops import set_add_op
5266

@@ -59,6 +73,7 @@ def for_loop_helper(
5973
expr: Expression,
6074
body_insts: GenFunc,
6175
else_insts: GenFunc | None,
76+
is_async: bool,
6277
line: int,
6378
) -> None:
6479
"""Generate IR for a loop.
@@ -81,7 +96,9 @@ def for_loop_helper(
8196
# Determine where we want to exit, if our condition check fails.
8297
normal_loop_exit = else_block if else_insts is not None else exit_block
8398

84-
for_gen = make_for_loop_generator(builder, index, expr, body_block, normal_loop_exit, line)
99+
for_gen = make_for_loop_generator(
100+
builder, index, expr, body_block, normal_loop_exit, line, is_async=is_async
101+
)
85102

86103
builder.push_loop_stack(step_block, exit_block)
87104
condition_block = BasicBlock()
@@ -220,32 +237,33 @@ def translate_list_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Valu
220237
if val is not None:
221238
return val
222239

223-
list_ops = builder.new_list_op([], gen.line)
224-
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
240+
list_ops = builder.maybe_spill(builder.new_list_op([], gen.line))
241+
242+
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
225243

226244
def gen_inner_stmts() -> None:
227245
e = builder.accept(gen.left_expr)
228-
builder.call_c(list_append_op, [list_ops, e], gen.line)
246+
builder.call_c(list_append_op, [builder.read(list_ops), e], gen.line)
229247

230248
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
231-
return list_ops
249+
return builder.read(list_ops)
232250

233251

234252
def translate_set_comprehension(builder: IRBuilder, gen: GeneratorExpr) -> Value:
235-
set_ops = builder.new_set_op([], gen.line)
236-
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
253+
set_ops = builder.maybe_spill(builder.new_set_op([], gen.line))
254+
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
237255

238256
def gen_inner_stmts() -> None:
239257
e = builder.accept(gen.left_expr)
240-
builder.call_c(set_add_op, [set_ops, e], gen.line)
258+
builder.call_c(set_add_op, [builder.read(set_ops), e], gen.line)
241259

242260
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
243-
return set_ops
261+
return builder.read(set_ops)
244262

245263

246264
def comprehension_helper(
247265
builder: IRBuilder,
248-
loop_params: list[tuple[Lvalue, Expression, list[Expression]]],
266+
loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]],
249267
gen_inner_stmts: Callable[[], None],
250268
line: int,
251269
) -> None:
@@ -260,20 +278,26 @@ def comprehension_helper(
260278
gen_inner_stmts: function to generate the IR for the body of the innermost loop
261279
"""
262280

263-
def handle_loop(loop_params: list[tuple[Lvalue, Expression, list[Expression]]]) -> None:
281+
def handle_loop(loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]]) -> None:
264282
"""Generate IR for a loop.
265283
266284
Given a list of (index, expression, [conditions]) tuples, generate IR
267285
for the nested loops the list defines.
268286
"""
269-
index, expr, conds = loop_params[0]
287+
index, expr, conds, is_async = loop_params[0]
270288
for_loop_helper(
271-
builder, index, expr, lambda: loop_contents(conds, loop_params[1:]), None, line
289+
builder,
290+
index,
291+
expr,
292+
lambda: loop_contents(conds, loop_params[1:]),
293+
None,
294+
is_async=is_async,
295+
line=line,
272296
)
273297

274298
def loop_contents(
275299
conds: list[Expression],
276-
remaining_loop_params: list[tuple[Lvalue, Expression, list[Expression]]],
300+
remaining_loop_params: list[tuple[Lvalue, Expression, list[Expression], bool]],
277301
) -> None:
278302
"""Generate the body of the loop.
279303
@@ -319,13 +343,23 @@ def make_for_loop_generator(
319343
body_block: BasicBlock,
320344
loop_exit: BasicBlock,
321345
line: int,
346+
is_async: bool = False,
322347
nested: bool = False,
323348
) -> ForGenerator:
324349
"""Return helper object for generating a for loop over an iterable.
325350
326351
If "nested" is True, this is a nested iterator such as "e" in "enumerate(e)".
327352
"""
328353

354+
# Do an async loop if needed. async is always generic
355+
if is_async:
356+
expr_reg = builder.accept(expr)
357+
async_obj = ForAsyncIterable(builder, index, body_block, loop_exit, line, nested)
358+
item_type = builder._analyze_iterable_item_type(expr)
359+
item_rtype = builder.type_to_rtype(item_type)
360+
async_obj.init(expr_reg, item_rtype)
361+
return async_obj
362+
329363
rtyp = builder.node_type(expr)
330364
if is_sequence_rprimitive(rtyp):
331365
# Special case "for x in <list>".
@@ -500,7 +534,7 @@ def load_len(self, expr: Value | AssignmentTarget) -> Value:
500534

501535

502536
class ForIterable(ForGenerator):
503-
"""Generate IR for a for loop over an arbitrary iterable (the normal case)."""
537+
"""Generate IR for a for loop over an arbitrary iterable (the general case)."""
504538

505539
def need_cleanup(self) -> bool:
506540
# Create a new cleanup block for when the loop is finished.
@@ -548,6 +582,70 @@ def gen_cleanup(self) -> None:
548582
self.builder.call_c(no_err_occurred_op, [], self.line)
549583

550584

585+
class ForAsyncIterable(ForGenerator):
586+
"""Generate IR for an async for loop."""
587+
588+
def init(self, expr_reg: Value, target_type: RType) -> None:
589+
# Define targets to contain the expression, along with the
590+
# iterator that will be used for the for-loop. We are inside
591+
# of a generator function, so we will spill these into
592+
# environment class.
593+
builder = self.builder
594+
iter_reg = builder.call_c(aiter_op, [expr_reg], self.line)
595+
builder.maybe_spill(expr_reg)
596+
self.iter_target = builder.maybe_spill(iter_reg)
597+
self.target_type = target_type
598+
self.stop_reg = Register(bool_rprimitive)
599+
600+
def gen_condition(self) -> None:
601+
# This does the test and fetches the next value
602+
# try:
603+
# TARGET = await type(iter).__anext__(iter)
604+
# stop = False
605+
# except StopAsyncIteration:
606+
# stop = True
607+
#
608+
# What a pain.
609+
# There are optimizations available here if we punch through some abstractions.
610+
611+
from mypyc.irbuild.statement import emit_await, transform_try_except
612+
613+
builder = self.builder
614+
line = self.line
615+
616+
def except_match() -> Value:
617+
addr = builder.add(LoadAddress(pointer_rprimitive, stop_async_iteration_op.src, line))
618+
return builder.add(LoadMem(stop_async_iteration_op.type, addr))
619+
620+
def try_body() -> None:
621+
awaitable = builder.call_c(anext_op, [builder.read(self.iter_target)], line)
622+
self.next_reg = emit_await(builder, awaitable, line)
623+
builder.assign(self.stop_reg, builder.false(), -1)
624+
625+
def except_body() -> None:
626+
builder.assign(self.stop_reg, builder.true(), line)
627+
628+
transform_try_except(
629+
builder, try_body, [((except_match, line), None, except_body)], None, line
630+
)
631+
632+
builder.add(Branch(self.stop_reg, self.loop_exit, self.body_block, Branch.BOOL))
633+
634+
def begin_body(self) -> None:
635+
# Assign the value obtained from await __anext__ to the
636+
# lvalue so that it can be referenced by code in the body of the loop.
637+
builder = self.builder
638+
line = self.line
639+
# We unbox here so that iterating with tuple unpacking generates a tuple based
640+
# unpack instead of an iterator based one.
641+
next_reg = builder.coerce(self.next_reg, self.target_type, line)
642+
builder.assign(builder.get_assignment_target(self.index), next_reg, line)
643+
644+
def gen_step(self) -> None:
645+
# Nothing to do here, since we get the next item as part of gen_condition().
646+
pass
647+
648+
551649
def unsafe_index(builder: IRBuilder, target: Value, index: Value, line: int) -> Value:
552650
"""Emit a potentially unsafe index into a target."""
553651
# This doesn't really fit nicely into any of our data-driven frameworks

mypyc/irbuild/specialize.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def any_all_helper(
369369
) -> Value:
370370
retval = Register(bool_rprimitive)
371371
builder.assign(retval, initial_value(), -1)
372-
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
372+
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
373373
true_block, false_block, exit_block = BasicBlock(), BasicBlock(), BasicBlock()
374374

375375
def gen_inner_stmts() -> None:
@@ -417,7 +417,9 @@ def gen_inner_stmts() -> None:
417417
call_expr = builder.accept(gen_expr.left_expr)
418418
builder.assign(retval, builder.binary_op(retval, call_expr, "+", -1), -1)
419419

420-
loop_params = list(zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists))
420+
loop_params = list(
421+
zip(gen_expr.indices, gen_expr.sequences, gen_expr.condlists, gen_expr.is_async)
422+
)
421423
comprehension_helper(builder, loop_params, gen_inner_stmts, gen_expr.line)
422424

423425
return retval
@@ -467,7 +469,7 @@ def gen_inner_stmts() -> None:
467469
builder.assign(retval, builder.accept(gen.left_expr), gen.left_expr.line)
468470
builder.goto(exit_block)
469471

470-
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists))
472+
loop_params = list(zip(gen.indices, gen.sequences, gen.condlists, gen.is_async))
471473
comprehension_helper(builder, loop_params, gen_inner_stmts, gen.line)
472474

473475
# Now we need the case for when nothing got hit. If there was

mypyc/irbuild/statement.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@
100100
)
101101

102102
GenFunc = Callable[[], None]
103+
ValueGenFunc = Callable[[], Value]
103104

104105

105106
def transform_block(builder: IRBuilder, block: Block) -> None:
@@ -327,17 +328,16 @@ def transform_while_stmt(builder: IRBuilder, s: WhileStmt) -> None:
327328

328329

329330
def transform_for_stmt(builder: IRBuilder, s: ForStmt) -> None:
330-
if s.is_async:
331-
builder.error("async for is unimplemented", s.line)
332-
333331
def body() -> None:
334332
builder.accept(s.body)
335333

336334
def else_block() -> None:
337335
assert s.else_body is not None
338336
builder.accept(s.else_body)
339337

340-
for_loop_helper(builder, s.index, s.expr, body, else_block if s.else_body else None, s.line)
338+
for_loop_helper(
339+
builder, s.index, s.expr, body, else_block if s.else_body else None, s.is_async, s.line
340+
)
341341

342342

343343
def transform_break_stmt(builder: IRBuilder, node: BreakStmt) -> None:
@@ -362,7 +362,7 @@ def transform_raise_stmt(builder: IRBuilder, s: RaiseStmt) -> None:
362362
def transform_try_except(
363363
builder: IRBuilder,
364364
body: GenFunc,
365-
handlers: Sequence[tuple[Expression | None, Expression | None, GenFunc]],
365+
handlers: Sequence[tuple[tuple[ValueGenFunc, int] | None, Expression | None, GenFunc]],
366366
else_body: GenFunc | None,
367367
line: int,
368368
) -> None:
@@ -399,8 +399,9 @@ def transform_try_except(
399399
for type, var, handler_body in handlers:
400400
next_block = None
401401
if type:
402+
type_f, type_line = type
402403
next_block, body_block = BasicBlock(), BasicBlock()
403-
matches = builder.call_c(exc_matches_op, [builder.accept(type)], type.line)
404+
matches = builder.call_c(exc_matches_op, [type_f()], type_line)
404405
builder.add(Branch(matches, body_block, next_block, Branch.BOOL))
405406
builder.activate_block(body_block)
406407
if var:
@@ -451,8 +452,12 @@ def body() -> None:
451452
def make_handler(body: Block) -> GenFunc:
452453
return lambda: builder.accept(body)
453454

455+
def make_entry(type: Expression) -> tuple[ValueGenFunc, int]:
456+
return (lambda: builder.accept(type), type.line)
457+
454458
handlers = [
455-
(type, var, make_handler(body)) for type, var, body in zip(t.types, t.vars, t.handlers)
459+
(make_entry(type) if type else None, var, make_handler(body))
460+
for type, var, body in zip(t.types, t.vars, t.handlers)
456461
]
457462
else_body = (lambda: builder.accept(t.else_body)) if t.else_body else None
458463
transform_try_except(builder, body, handlers, else_body, t.line)

mypyc/lib-rt/CPy.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -610,6 +610,10 @@ PyObject *CPyImport_ImportFrom(PyObject *module, PyObject *package_name,
610610

611611
PyObject *CPySingledispatch_RegisterFunction(PyObject *singledispatch_func, PyObject *cls,
612612
PyObject *func);
613+
614+
PyObject *CPy_GetAIter(PyObject *obj);
615+
PyObject *CPy_GetANext(PyObject *aiter);
616+
613617
#ifdef __cplusplus
614618
}
615619
#endif

0 commit comments

Comments
 (0)