From c43bd6f8f2c3fab95ae1fa8ea5f6688e88aa6775 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Wed, 30 May 2018 17:59:19 -0400 Subject: [PATCH] Make Pool.close() wait until all checked out connections are released Currently, `pool.close()`, despite the "graceful" designation, closes all connections immediately regardless of whether they are acquired. With this change, pool will wait for connections to actually be released before closing. WARNING: This is a potentially incompatible behavior change, as sloppily written code which does not release acquired connections will now cause `pool.close()` to hang forever. Also, when `conn.close()` or `conn.terminate()` are called directly on an acquired connection, the associated pool item is released immediately. Closes: #290 --- asyncpg/connection.py | 45 +++--- asyncpg/pool.py | 247 +++++++++++++++++++------------ tests/test_adversity.py | 7 +- tests/test_cache_invalidation.py | 5 + tests/test_pool.py | 91 +++++++++--- 5 files changed, 266 insertions(+), 129 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ea1a86fa..5e00a18f 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -984,7 +984,7 @@ def is_closed(self): :return bool: ``True`` if the connection is closed, ``False`` otherwise. """ - return not self._protocol.is_connected() or self._aborted + return self._aborted or not self._protocol.is_connected() async def close(self, *, timeout=None): """Close the connection gracefully. @@ -995,30 +995,21 @@ async def close(self, *, timeout=None): .. versionchanged:: 0.14.0 Added the *timeout* parameter. """ - if self.is_closed(): - return - self._mark_stmts_as_closed() - self._listeners.clear() - self._log_listeners.clear() - self._aborted = True try: - await self._protocol.close(timeout) + if not self.is_closed(): + await self._protocol.close(timeout) except Exception: # If we fail to close gracefully, abort the connection. 
- self._aborted = True - self._protocol.abort() + self._abort() raise finally: - self._clean_tasks() + self._cleanup() def terminate(self): """Terminate the connection without waiting for pending data.""" - self._mark_stmts_as_closed() - self._listeners.clear() - self._log_listeners.clear() - self._aborted = True - self._protocol.abort() - self._clean_tasks() + if not self.is_closed(): + self._abort() + self._cleanup() async def reset(self, *, timeout=None): self._check_open() @@ -1041,6 +1032,26 @@ async def reset(self, *, timeout=None): if reset_query: await self.execute(reset_query, timeout=timeout) + def _abort(self): + # Put the connection into the aborted state. + self._aborted = True + self._protocol.abort() + self._protocol = None + + def _cleanup(self): + # Free the resources associated with this connection. + # This must be called when a connection is terminated. + + if self._proxy is not None: + # Connection is a member of a pool, so let the pool + # know that this connection is dead. + self._proxy._holder._release_on_close() + + self._mark_stmts_as_closed() + self._listeners.clear() + self._log_listeners.clear() + self._clean_tasks() + def _clean_tasks(self): # Wrap-up any remaining tasks associated with this connection. 
if self._cancellations: diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 9c8d581e..f60de047 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -74,8 +74,7 @@ def __getattr__(self, attr): def _detach(self) -> connection.Connection: if self._con is None: - raise exceptions.InterfaceError( - 'cannot detach PoolConnectionProxy: already detached') + return con, self._con = self._con, None con._set_proxy(None) @@ -92,7 +91,7 @@ def __repr__(self): class PoolConnectionHolder: - __slots__ = ('_con', '_pool', '_loop', + __slots__ = ('_con', '_pool', '_loop', '_proxy', '_connect_args', '_connect_kwargs', '_max_queries', '_setup', '_init', '_max_inactive_time', '_in_use', @@ -103,6 +102,7 @@ def __init__(self, pool, *, connect_args, connect_kwargs, self._pool = pool self._con = None + self._proxy = None self._connect_args = connect_args self._connect_kwargs = connect_kwargs @@ -111,11 +111,14 @@ def __init__(self, pool, *, connect_args, connect_kwargs, self._setup = setup self._init = init self._inactive_callback = None - self._in_use = False + self._in_use = None # type: asyncio.Future self._timeout = None async def connect(self): - assert self._con is None + if self._con is not None: + raise exceptions.InternalClientError( + 'PoolConnectionHolder.connect() called while another ' + 'connection already exists') if self._pool._working_addr is None: # First connection attempt on this pool. @@ -152,7 +155,7 @@ async def acquire(self) -> PoolConnectionProxy: self._maybe_cancel_inactive_callback() - proxy = PoolConnectionProxy(self, self._con) + self._proxy = proxy = PoolConnectionProxy(self, self._con) if self._setup is not None: try: @@ -163,102 +166,139 @@ async def acquire(self) -> PoolConnectionProxy: # we close it. A new connection will be created # when `acquire` is called again. try: - proxy._detach() - # Use `close` to close the connection gracefully. + # Use `close()` to close the connection gracefully. 
# An exception in `setup` isn't necessarily caused - # by an IO or a protocol error. + # by an IO or a protocol error. close() will + # do the necessary cleanup via _release_on_close(). await self._con.close() finally: - self._con = None raise ex - self._in_use = True + self._in_use = self._pool._loop.create_future() + return proxy async def release(self, timeout): - assert self._in_use - self._in_use = False - self._timeout = None + if self._in_use is None: + raise exceptions.InternalClientError( + 'PoolConnectionHolder.release() called on ' + 'a free connection holder') if self._con.is_closed(): - self._con = None + # When closing, pool connections perform the necessary + # cleanup, so we don't have to do anything else here. + return + + self._timeout = None + + if self._con._protocol.queries_count >= self._max_queries: + # The connection has reached its maximum utilization limit, + # so close it. Connection.close() will call _release(). + await self._con.close(timeout=timeout) + return - elif self._con._protocol.queries_count >= self._max_queries: + try: + budget = timeout + + if self._con._protocol._is_cancelling(): + # If the connection is in cancellation state, + # wait for the cancellation + started = time.monotonic() + await asyncio.wait_for( + self._con._protocol._wait_for_cancellation(), + budget, loop=self._pool._loop) + if budget is not None: + budget -= time.monotonic() - started + + await self._con.reset(timeout=budget) + except Exception as ex: + # If the `reset` call failed, terminate the connection. + # A new one will be created when `acquire` is called + # again. try: - await self._con.close(timeout=timeout) + # An exception in `reset` is most likely caused by + # an IO error, so terminate the connection. 
+ self._con.terminate() finally: - self._con = None + raise ex - else: - try: - budget = timeout - - if self._con._protocol._is_cancelling(): - # If the connection is in cancellation state, - # wait for the cancellation - started = time.monotonic() - await asyncio.wait_for( - self._con._protocol._wait_for_cancellation(), - budget, loop=self._pool._loop) - if budget is not None: - budget -= time.monotonic() - started - - await self._con.reset(timeout=budget) - except Exception as ex: - # If the `reset` call failed, terminate the connection. - # A new one will be created when `acquire` is called - # again. - try: - # An exception in `reset` is most likely caused by - # an IO error, so terminate the connection. - self._con.terminate() - finally: - self._con = None - raise ex + # Free this connection holder and invalidate the + # connection proxy. + self._release() - assert self._inactive_callback is None - if self._max_inactive_time and self._con is not None: - self._inactive_callback = self._pool._loop.call_later( - self._max_inactive_time, self._deactivate_connection) + # Rearm the connection inactivity timer. + self._setup_inactive_callback() - async def close(self): - self._maybe_cancel_inactive_callback() - if self._con is None: - return - if self._con.is_closed(): - self._con = None + async def wait_until_released(self): + if self._in_use is None: return + else: + await self._in_use - try: + async def close(self): + if self._con is not None: + # Connection.close() will call _release_on_close() to + # finish holder cleanup. await self._con.close() - finally: - self._con = None def terminate(self): - self._maybe_cancel_inactive_callback() - if self._con is None: - return - if self._con.is_closed(): - self._con = None - return - - try: + if self._con is not None: + # Connection.terminate() will call _release_on_close() to + # finish holder cleanup. 
self._con.terminate() - finally: - self._con = None + + def _setup_inactive_callback(self): + if self._inactive_callback is not None: + raise exceptions.InternalClientError( + 'pool connection inactivity timer already exists') + + if self._max_inactive_time: + self._inactive_callback = self._pool._loop.call_later( + self._max_inactive_time, self._deactivate_inactive_connection) def _maybe_cancel_inactive_callback(self): if self._inactive_callback is not None: self._inactive_callback.cancel() self._inactive_callback = None - def _deactivate_connection(self): - assert not self._in_use - if self._con is None or self._con.is_closed(): - return - self._con.terminate() + def _deactivate_inactive_connection(self): + if self._in_use is not None: + raise exceptions.InternalClientError( + 'attempting to deactivate an acquired connection') + + if self._con is not None: + # The connection is idle and not in use, so it's fine to + # use terminate() instead of close(). + self._con.terminate() + # Must call _release_on_close() here, because + # _deactivate_inactive_connection() runs when the connection is + # *not* checked out, so terminate() above will not call it. + self._release_on_close() + + def _release_on_close(self): + self._maybe_cancel_inactive_callback() + self._release() self._con = None + def _release(self): + """Release this connection holder.""" + if self._in_use is None: + # The holder is not checked out. + return + + if not self._in_use.done(): + self._in_use.set_result(None) + self._in_use = None + + # Deinitialize the connection proxy. All subsequent + # operations on it will fail. + if self._proxy is not None: + self._proxy._detach() + self._proxy = None + + # Put ourselves back to the pool queue. + self._pool._queue.put_nowait(self) + class Pool: """A connection pool. 
@@ -273,7 +313,7 @@ class Pool: __slots__ = ('_queue', '_loop', '_minsize', '_maxsize', '_working_addr', '_working_config', '_working_params', - '_holders', '_initialized', '_closed', + '_holders', '_initialized', '_closing', '_closed', '_connection_class') def __init__(self, *connect_args, @@ -322,6 +362,7 @@ def __init__(self, *connect_args, self._connection_class = connection_class + self._closing = False self._closed = False for _ in range(max_size): @@ -468,7 +509,10 @@ async def _acquire_impl(): ch._timeout = timeout return proxy + if self._closing: + raise exceptions.InterfaceError('pool is closing') self._check_init() + if timeout is None: return await _acquire_impl() else: @@ -488,14 +532,6 @@ async def release(self, connection, *, timeout=None): .. versionchanged:: 0.14.0 Added the *timeout* parameter. """ - async def _release_impl(ch: PoolConnectionHolder, timeout: float): - try: - await ch.release(timeout) - finally: - self._queue.put_nowait(ch) - - self._check_init() - if (type(connection) is not PoolConnectionProxy or connection._holder._pool is not self): raise exceptions.InterfaceError( @@ -507,35 +543,64 @@ async def _release_impl(ch: PoolConnectionHolder, timeout: float): # Already released, do nothing. return - con = connection._detach() - con._on_release() + self._check_init() + + # Let the connection do its internal housekeeping when its released. + connection._con._on_release() + ch = connection._holder if timeout is None: - timeout = connection._holder._timeout + timeout = ch._timeout # Use asyncio.shield() to guarantee that task cancellation # does not prevent the connection from being returned to the # pool properly. - return await asyncio.shield( - _release_impl(connection._holder, timeout), loop=self._loop) + return await asyncio.shield(ch.release(timeout), loop=self._loop) async def close(self): - """Gracefully close all connections in the pool.""" + """Attempt to gracefully close all connections in the pool. 
+ + Wait until all pool connections are released, close them and + shut down the pool. If any error (including cancellation) occurs + in ``close()`` the pool will terminate by calling + :meth:`Pool.terminate`. + + .. versionchanged:: 0.16.0 + ``close()`` now waits until all pool connections are released + before closing them and the pool. Errors raised in ``close()`` + will cause immediate pool termination. + """ if self._closed: return self._check_init() - self._closed = True - coros = [ch.close() for ch in self._holders] - await asyncio.gather(*coros, loop=self._loop) + + self._closing = True + + try: + release_coros = [ + ch.wait_until_released() for ch in self._holders] + await asyncio.gather(*release_coros, loop=self._loop) + + close_coros = [ + ch.close() for ch in self._holders] + await asyncio.gather(*close_coros, loop=self._loop) + + except Exception: + self.terminate() + raise + + finally: + self._closed = True + self._closing = False def terminate(self): """Terminate all connections in the pool.""" if self._closed: return self._check_init() - self._closed = True for ch in self._holders: ch.terminate() + self._closed = True def _check_init(self): if not self._initialized: diff --git a/tests/test_adversity.py b/tests/test_adversity.py index a9032b54..a0d153eb 100644 --- a/tests/test_adversity.py +++ b/tests/test_adversity.py @@ -32,7 +32,7 @@ async def test_pool_release_timeout(self): self.proxy.trigger_connectivity_loss() finally: self.proxy.restore_connectivity() - await pool.close() + pool.terminate() @tb.with_timeout(30.0) async def test_pool_handles_abrupt_connection_loss(self): @@ -57,8 +57,11 @@ def kill_connectivity(): timeout=cmd_timeout, command_timeout=cmd_timeout) with self.assertRunUnder(worst_runtime): - async with new_pool as pool: + pool = await new_pool + try: workers = [worker(pool) for _ in range(concurrency)] self.loop.call_later(1, kill_connectivity) await asyncio.gather( *workers, loop=self.loop, return_exceptions=True) + finally: 
+ pool.terminate() diff --git a/tests/test_cache_invalidation.py b/tests/test_cache_invalidation.py index 13995f08..96cfd58f 100644 --- a/tests/test_cache_invalidation.py +++ b/tests/test_cache_invalidation.py @@ -85,6 +85,8 @@ async def test_prepare_cache_invalidation_in_pool(self): finally: await self.con.execute('DROP TABLE tab1') + await pool.release(con2) + await pool.release(con1) await pool.close() async def test_type_cache_invalidation_in_transaction(self): @@ -303,6 +305,9 @@ async def test_type_cache_invalidation_in_pool(self): finally: await self.con.execute('DROP TABLE tab1') await self.con.execute('DROP TYPE typ1') + await pool.release(con2) + await pool.release(con1) await pool.close() + await pool_chk.release(con_chk) await pool_chk.close() await self.con.execute('DROP DATABASE testdb') diff --git a/tests/test_pool.py b/tests/test_pool.py index eba49f7d..02570394 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -6,7 +6,6 @@ import asyncio -import asyncpg import inspect import os import platform @@ -16,6 +15,7 @@ import time import unittest +import asyncpg from asyncpg import _testbase as tb from asyncpg import connection as pg_connection from asyncpg import cluster as pg_cluster @@ -92,16 +92,23 @@ async def test_pool_04(self): min_size=1, max_size=1) con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) - con.terminate() - await pool.release(con) - async with pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) as con: - con.terminate() + # Manual termination of pool connections releases the + # pool item immediately. + con.terminate() + self.assertIsNone(pool._holders[0]._con) + self.assertIsNone(pool._holders[0]._in_use) con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) self.assertEqual(await con.fetchval('SELECT 1'), 1) - await pool.close() + await con.close() + self.assertIsNone(pool._holders[0]._con) + self.assertIsNone(pool._holders[0]._in_use) + # Calling release should not hurt. 
+ await pool.release(con) + + pool.terminate() async def test_pool_05(self): for n in {1, 3, 5, 10, 20, 100}: @@ -126,7 +133,8 @@ async def setup(con): async with self.create_pool(database='postgres', min_size=5, max_size=5, setup=setup) as pool: - con = await pool.acquire() + async with pool.acquire() as con: + pass self.assertIs(con, await fut) @@ -172,9 +180,13 @@ async def test_pool_09(self): pool2 = await self.create_pool(database='postgres', min_size=1, max_size=1) - con = await pool1.acquire(timeout=POOL_NOMINAL_TIMEOUT) - with self.assertRaisesRegex(asyncpg.InterfaceError, 'is not a member'): - await pool2.release(con) + try: + con = await pool1.acquire(timeout=POOL_NOMINAL_TIMEOUT) + with self.assertRaisesRegex(asyncpg.InterfaceError, + 'is not a member'): + await pool2.release(con) + finally: + await pool1.release(con) await pool1.close() await pool2.close() @@ -297,8 +309,8 @@ async def setup(con): with self.assertRaises(Error): await pool.acquire() - con = await pool.acquire() - self.assertEqual(cons, ['error', con]) + async with pool.acquire() as con: + self.assertEqual(cons, ['error', con]) with self.subTest(method='init'): setup_calls = 0 @@ -309,9 +321,9 @@ async def setup(con): with self.assertRaises(Error): await pool.acquire() - con = await pool.acquire() - self.assertEqual(await con.fetchval('select 1::int'), 1) - self.assertEqual(cons, ['error', con._con]) + async with pool.acquire() as con: + self.assertEqual(await con.fetchval('select 1::int'), 1) + self.assertEqual(cons, ['error', con._con]) async def test_pool_auth(self): if not self.cluster.is_managed(): @@ -665,11 +677,10 @@ async def test_pool_handles_inactive_connection_errors(self): true_con.terminate() # now pool should reopen terminated connection - con = await pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) - - self.assertEqual(await con.fetchval('SELECT 1'), 1) + async with pool.acquire(timeout=POOL_NOMINAL_TIMEOUT) as con: + self.assertEqual(await con.fetchval('SELECT 1'), 1) + await 
con.close() - await con.close() await pool.close() @unittest.skipIf(sys.version_info[:2] < (3, 6), 'no asyncgen support') @@ -741,6 +752,48 @@ class MyException(Exception): async for _ in iterate(con): # noqa raise MyException() + async def test_pool_close_waits_for_release(self): + pool = await self.create_pool(database='postgres', + min_size=1, max_size=1) + + flag = self.loop.create_future() + conn_released = False + + async def worker(): + nonlocal conn_released + + async with pool.acquire() as connection: + async with connection.transaction(): + flag.set_result(True) + await asyncio.sleep(0.1, loop=self.loop) + + conn_released = True + + self.loop.create_task(worker()) + + await flag + await pool.close() + self.assertTrue(conn_released) + + async def test_pool_close_timeout(self): + pool = await self.create_pool(database='postgres', + min_size=1, max_size=1) + + flag = self.loop.create_future() + + async def worker(): + async with pool.acquire(): + flag.set_result(True) + await asyncio.sleep(0.5, loop=self.loop) + + task = self.loop.create_task(worker()) + + with self.assertRaises(asyncio.TimeoutError): + await flag + await asyncio.wait_for(pool.close(), timeout=0.1) + + await task + @unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') class TestHotStandby(tb.ConnectedTestCase):