From 606ec9bfb1cca262cf5b936db78924a609f9e73d Mon Sep 17 00:00:00 2001 From: Damien George Date: Fri, 19 May 2023 17:00:53 +1000 Subject: [PATCH] py/compile: Fix async for's stack handling of iterator expression. Prior to this fix, async for assumed the iterator expression was a simple identifier, and used that identifier as a local to store the intermediate iterator object. This is incorrect behaviour. This commit fixes the issue by keeping the iterator object on the stack as an anonymous local variable. Fixes issue #11511. Signed-off-by: Damien George --- py/compile.c | 25 ++++++++++--- tests/basics/async_for.py | 70 +++++++++++++++++++++++++++++------ tests/basics/async_for.py.exp | 43 ++++++++++++++++++++- 3 files changed, 120 insertions(+), 18 deletions(-) diff --git a/py/compile.c b/py/compile.c index bb7c1117fa..4f91ca49b9 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1768,18 +1768,21 @@ STATIC void compile_await_object_method(compiler_t *comp, qstr method) { } STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { - // comp->break_label |= MP_EMIT_BREAK_FROM_FOR; - - qstr context = MP_PARSE_NODE_LEAF_ARG(pns->nodes[1]); + // Allocate labels. uint while_else_label = comp_next_label(comp); uint try_exception_label = comp_next_label(comp); uint try_else_label = comp_next_label(comp); uint try_finally_label = comp_next_label(comp); + // Stack: (...) + + // Compile the iterator expression and load and call its __aiter__ method. compile_node(comp, pns->nodes[1]); // iterator + // Stack: (..., iterator) EMIT_ARG(load_method, MP_QSTR___aiter__, false); + // Stack: (..., iterator, __aiter__) EMIT_ARG(call_method, 0, 0, 0); - compile_store_id(comp, context); + // Stack: (..., iterable) START_BREAK_CONTINUE_BLOCK @@ -1787,9 +1790,15 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_increase_except_level(comp, try_exception_label, MP_EMIT_SETUP_BLOCK_EXCEPT); - compile_load_id(comp, context); + EMIT(dup_top); + // Stack: (..., iterable, iterable) + + // Compile: yield from iterable.__anext__() compile_await_object_method(comp, MP_QSTR___anext__); + // Stack: (..., iterable, yielded_value) + c_assign(comp, pns->nodes[0], ASSIGN_STORE); // variable + // Stack: (..., iterable) EMIT_ARG(pop_except_jump, try_else_label, false); EMIT_ARG(label_assign, try_exception_label); @@ -1806,6 +1815,8 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_decrease_except_level(comp); EMIT(end_except_handler); + // Stack: (..., iterable) + EMIT_ARG(label_assign, try_else_label); compile_node(comp, pns->nodes[2]); // body @@ -1817,6 +1828,10 @@ STATIC void compile_async_for_stmt(compiler_t *comp, mp_parse_node_struct_t *pns compile_node(comp, pns->nodes[3]); // else EMIT_ARG(label_assign, break_label); + // Stack: (..., iterable) + + EMIT(pop_top); + // Stack: (...) } STATIC void compile_async_with_stmt_helper(compiler_t *comp, size_t n, mp_parse_node_t *nodes, mp_parse_node_t body) { diff --git a/tests/basics/async_for.py b/tests/basics/async_for.py index 5fd0540828..f54f70238c 100644 --- a/tests/basics/async_for.py +++ b/tests/basics/async_for.py @@ -1,29 +1,75 @@ # test basic async for execution # example taken from PEP0492 + class AsyncIteratorWrapper: def __init__(self, obj): - print('init') - self._it = iter(obj) + print("init") + self._obj = obj + + def __repr__(self): + return "AsyncIteratorWrapper-" + self._obj def __aiter__(self): - print('aiter') - return self + print("aiter") + return AsyncIteratorWrapperIterator(self._obj) + + +class AsyncIteratorWrapperIterator: + def __init__(self, obj): + print("init") + self._it = iter(obj) async def __anext__(self): - print('anext') + print("anext") try: value = next(self._it) except StopIteration: raise StopAsyncIteration return value -async def coro(): - async for letter in AsyncIteratorWrapper('abc'): + +def run_coro(c): + print("== start ==") + try: + c.send(None) + except StopIteration: + print("== finish ==") + + +async def coro0(): + async for letter in AsyncIteratorWrapper("abc"): print(letter) -o = coro() -try: - o.send(None) -except StopIteration: - print('finished') + +run_coro(coro0()) + + +async def coro1(): + a = AsyncIteratorWrapper("def") + async for letter in a: + print(letter) + print(a) + + +run_coro(coro1()) + +a_global = AsyncIteratorWrapper("ghi") + + +async def coro2(): + async for letter in a_global: + print(letter) + print(a_global) + + +run_coro(coro2()) + + +async def coro3(a): + async for letter in a: + print(letter) + print(a) + + +run_coro(coro3(AsyncIteratorWrapper("jkl"))) diff --git a/tests/basics/async_for.py.exp b/tests/basics/async_for.py.exp index 1f728a66c8..6f59979c06 100644 --- a/tests/basics/async_for.py.exp +++ b/tests/basics/async_for.py.exp @@ -1,5 +1,7 @@ +== start == init aiter +init anext a anext @@ -7,4 +9,43 @@ b anext c anext -finished +== finish == +== start == +init +aiter +init +anext +d +anext +e +anext +f +anext +AsyncIteratorWrapper-def +== finish == +init +== start == +aiter +init +anext +g +anext +h +anext +i +anext +AsyncIteratorWrapper-ghi +== finish == +init +== start == +aiter +init +anext +j +anext +k +anext +l +anext +AsyncIteratorWrapper-jkl +== finish ==