diff --git a/asyncpg/compat.py b/asyncpg/compat.py index 99a561d0..6dbce3c9 100644 --- a/asyncpg/compat.py +++ b/asyncpg/compat.py @@ -90,3 +90,19 @@ async def wait_closed(stream): # On Windows wait_closed() sometimes propagates # ConnectionResetError which is totally unnecessary. pass + + +# Workaround for https://bugs.python.org/issue37658 +async def wait_for(fut, timeout): + if timeout is None: + return await fut + + fut = asyncio.ensure_future(fut) + + try: + return await asyncio.wait_for(fut, timeout) + except asyncio.CancelledError: + if fut.done(): + return fut.result() + else: + raise diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index e5feebc2..65261664 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -636,18 +636,13 @@ async def _connect_addr( connector = asyncio.ensure_future(connector) before = time.monotonic() - try: - tr, pr = await asyncio.wait_for( - connector, timeout=timeout) - except asyncio.CancelledError: - connector.add_done_callback(_close_leaked_connection) - raise + tr, pr = await compat.wait_for(connector, timeout=timeout) timeout -= time.monotonic() - before try: if timeout <= 0: raise asyncio.TimeoutError - await asyncio.wait_for(connected, timeout=timeout) + await compat.wait_for(connected, timeout=timeout) except (Exception, asyncio.CancelledError): tr.close() raise @@ -745,12 +740,3 @@ def _create_future(loop): return asyncio.Future(loop=loop) else: return create_future() - - -def _close_leaked_connection(fut): - try: - tr, pr = fut.result() - if tr: - tr.close() - except asyncio.CancelledError: - pass # hide the exception diff --git a/asyncpg/pool.py b/asyncpg/pool.py index b3947451..c4321a2f 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -12,6 +12,7 @@ import time import warnings +from . import compat from . import connection from . import connect_utils from . import exceptions @@ -198,7 +199,7 @@ async def release(self, timeout): # If the connection is in cancellation state, # wait for the cancellation started = time.monotonic() - await asyncio.wait_for( + await compat.wait_for( self._con._protocol._wait_for_cancellation(), budget) if budget is not None: @@ -623,7 +624,7 @@ async def _acquire_impl(): if timeout is None: return await _acquire_impl() else: - return await asyncio.wait_for( + return await compat.wait_for( _acquire_impl(), timeout=timeout) async def release(self, connection, *, timeout=None): diff --git a/tests/test_pool.py b/tests/test_pool.py index e51923e4..9857dceb 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -379,6 +379,26 @@ async def worker(): self.cluster.trust_local_connections() self.cluster.reload() + async def test_pool_handles_task_cancel_in_acquire_with_timeout(self): + # See https://github.com/MagicStack/asyncpg/issues/547 + pool = await self.create_pool(database='postgres', + min_size=1, max_size=1) + + async def worker(): + async with pool.acquire(timeout=100): + pass + + # Schedule task + task = self.loop.create_task(worker()) + # Yield to task, but cancel almost immediately + await asyncio.sleep(0.00000000001) + # Cancel the worker. + task.cancel() + # Wait to make sure the cleanup has completed. + await asyncio.sleep(0.4) + # Check that the connection has been returned to the pool. + self.assertEqual(pool._queue.qsize(), 1) + async def test_pool_handles_task_cancel_in_release(self): # Use SlowResetConnectionPool to simulate # the Task.cancel() and __aexit__ race.