diff --git a/extmod/asyncio/stream.py b/extmod/asyncio/stream.py index c47c48cf09..5547bfbd51 100644 --- a/extmod/asyncio/stream.py +++ b/extmod/asyncio/stream.py @@ -127,20 +127,30 @@ class Server: await self.wait_closed() def close(self): + # Note: the _serve task must have already started by now due to the sleep + # in start_server, so `state` won't be clobbered at the start of _serve. + self.state = True self.task.cancel() async def wait_closed(self): await self.task async def _serve(self, s, cb): + self.state = False # Accept incoming connections while True: try: yield core._io_queue.queue_read(s) - except core.CancelledError: - # Shutdown server + except core.CancelledError as er: + # The server task was cancelled, shutdown server and close socket. s.close() - return + if self.state: + # If the server was explicitly closed, ignore the cancellation. + return + else: + # Otherwise e.g. the parent task was cancelled, propagate + # cancellation. + raise er try: s2, addr = s.accept() except: @@ -167,6 +177,16 @@ async def start_server(cb, host, port, backlog=5): # Create and return server object and task. srv = Server() srv.task = core.create_task(srv._serve(s, cb)) + try: + # Ensure that the _serve task has been scheduled so that it gets to + # handle cancellation. + await core.sleep_ms(0) + except core.CancelledError as er: + # If the parent task is cancelled during this first sleep, then + # we will leak the task and it will sit waiting for the socket, so + # cancel it. + srv.task.cancel() + raise er return srv diff --git a/tests/net_hosted/asyncio_start_server.py b/tests/net_hosted/asyncio_start_server.py index 3162218981..e76faf7edb 100644 --- a/tests/net_hosted/asyncio_start_server.py +++ b/tests/net_hosted/asyncio_start_server.py @@ -22,6 +22,44 @@ async def test(): print("sleep") await asyncio.sleep(0) + # Test that cancellation works before the server starts if + # the subsequent code raises. + print("create server3") + server3 = await asyncio.start_server(None, "0.0.0.0", 8000) + try: + async with server3: + raise OSError + except OSError as er: + print("OSError") + + # Test that closing doesn't raise CancelledError. + print("create server4") + server4 = await asyncio.start_server(None, "0.0.0.0", 8000) + server4.close() + await server4.wait_closed() + print("server4 closed") + + # Test that cancelling the task will still raise CancelledError, checking + # edge cases around how many times the tasks have been re-scheduled by + # sleep. + async def task(n): + print("create task server", n) + srv = await asyncio.start_server(None, "0.0.0.0", 8000) + await srv.wait_closed() + # This should be unreachable. + print("task finished") + + for num_sleep in range(0, 5): + print("sleep", num_sleep) + t = asyncio.create_task(task(num_sleep)) + for _ in range(num_sleep): + await asyncio.sleep(0) + t.cancel() + try: + await t + except asyncio.CancelledError: + print("CancelledError") + print("done") diff --git a/tests/net_hosted/asyncio_start_server.py.exp b/tests/net_hosted/asyncio_start_server.py.exp index 0fb8e6a63b..58982a108c 100644 --- a/tests/net_hosted/asyncio_start_server.py.exp +++ b/tests/net_hosted/asyncio_start_server.py.exp @@ -2,4 +2,22 @@ create server1 create server2 OSError sleep +create server3 +OSError +create server4 +server4 closed +sleep 0 +CancelledError +sleep 1 +create task server 1 +CancelledError +sleep 2 +create task server 2 +CancelledError +sleep 3 +create task server 3 +CancelledError +sleep 4 +create task server 4 +CancelledError done