diff --git a/asyncpg/_testbase.py b/asyncpg/_testbase.py
index ee4a1259..b26d766c 100644
--- a/asyncpg/_testbase.py
+++ b/asyncpg/_testbase.py
@@ -190,6 +190,17 @@ def start_cluster(cls, ClusterCls, *,
         return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
 
 
+def with_connection_options(**options):
+    if not options:
+        raise ValueError('no connection options were specified')
+
+    def wrap(func):
+        func.__connect_options__ = options
+        return func
+
+    return wrap
+
+
 class ConnectedTestCase(ClusterTestCase):
 
     def getExtraConnectOptions(self):
@@ -197,9 +208,14 @@ def getExtraConnectOptions(self):
 
     def setUp(self):
         super().setUp()
-        opts = self.getExtraConnectOptions()
+
+        # Extract options set up with `with_connection_options`.
+        test_func = getattr(self, self._testMethodName).__func__
+        opts = getattr(test_func, '__connect_options__', {})
+
         self.con = self.loop.run_until_complete(
             self.cluster.connect(database='postgres', loop=self.loop, **opts))
+
         self.server_version = self.con.get_server_version()
 
     def tearDown(self):
diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index b34bcf00..9c16f5e5 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -39,13 +39,14 @@ class Connection(metaclass=ConnectionMeta):
 
     __slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
                  '_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
-                 '_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
+                 '_stmt_cache', '_stmts_to_close',
                  '_addr', '_opts', '_command_timeout', '_listeners',
                  '_server_version', '_server_caps', '_intro_query',
                  '_reset_query', '_proxy', '_stmt_exclusive_section')
 
     def __init__(self, protocol, transport, loop, addr, opts, *,
-                 statement_cache_size, command_timeout):
+                 statement_cache_size, command_timeout,
+                 max_cached_statement_lifetime):
         self._protocol = protocol
         self._transport = transport
         self._loop = loop
@@ -58,8 +59,12 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
         self._addr = addr
         self._opts = opts
 
-        self._stmt_cache_max_size = statement_cache_size
-        self._stmt_cache = collections.OrderedDict()
+        self._stmt_cache = _StatementCache(
+            loop=loop,
+            max_size=statement_cache_size,
+            on_remove=self._maybe_gc_stmt,
+            max_lifetime=max_cached_statement_lifetime)
+
         self._stmts_to_close = set()
 
         if command_timeout is not None:
@@ -126,6 +131,8 @@ async def add_listener(self, channel, callback):
 
     async def remove_listener(self, channel, callback):
         """Remove a listening callback on the specified channel."""
+        if self.is_closed():
+            return
         if channel not in self._listeners:
             return
         if callback not in self._listeners[channel]:
@@ -266,46 +273,33 @@ async def executemany(self, command: str, args,
         return await self._executemany(command, args, timeout)
 
     async def _get_statement(self, query, timeout, *, named: bool=False):
-        use_cache = self._stmt_cache_max_size > 0
-        if use_cache:
-            try:
-                state = self._stmt_cache[query]
-            except KeyError:
-                pass
-            else:
-                self._stmt_cache.move_to_end(query, last=True)
-                if not state.closed:
-                    return state
-
-        protocol = self._protocol
+        statement = self._stmt_cache.get(query)
+        if statement is not None:
+            return statement
 
-        if use_cache or named:
+        if self._stmt_cache.get_max_size() or named:
             stmt_name = self._get_unique_id('stmt')
         else:
             stmt_name = ''
 
-        state = await protocol.prepare(stmt_name, query, timeout)
+        statement = await self._protocol.prepare(stmt_name, query, timeout)
 
-        ready = state._init_types()
+        ready = statement._init_types()
         if ready is not True:
             if self._types_stmt is None:
                 self._types_stmt = await self.prepare(self._intro_query)
 
             types = await self._types_stmt.fetch(list(ready))
-            protocol.get_settings().register_data_types(types)
+            self._protocol.get_settings().register_data_types(types)
 
-        if use_cache:
-            if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
-                old_query, old_state = self._stmt_cache.popitem(last=False)
-                self._maybe_gc_stmt(old_state)
-            self._stmt_cache[query] = state
+        self._stmt_cache.put(query, statement)
 
         # If we've just created a new statement object, check if there
         # are any statements for GC.
         if self._stmts_to_close:
             await self._cleanup_stmts()
 
-        return state
+        return statement
 
     def cursor(self, query, *args, prefetch=None, timeout=None):
         """Return a *cursor factory* for the specified query.
@@ -457,14 +451,14 @@ async def close(self):
         """Close the connection gracefully."""
         if self.is_closed():
             return
-        self._close_stmts()
+        self._mark_stmts_as_closed()
         self._listeners = {}
         self._aborted = True
         await self._protocol.close()
 
     def terminate(self):
         """Terminate the connection without waiting for pending data."""
-        self._close_stmts()
+        self._mark_stmts_as_closed()
         self._listeners = {}
         self._aborted = True
         self._protocol.abort()
@@ -484,8 +478,8 @@ def _get_unique_id(self, prefix):
         self._uid += 1
         return '__asyncpg_{}_{}__'.format(prefix, self._uid)
 
-    def _close_stmts(self):
-        for stmt in self._stmt_cache.values():
+    def _mark_stmts_as_closed(self):
+        for stmt in self._stmt_cache.iter_statements():
             stmt.mark_closed()
 
         for stmt in self._stmts_to_close:
@@ -495,11 +489,22 @@ def _close_stmts(self):
         self._stmts_to_close.clear()
 
     def _maybe_gc_stmt(self, stmt):
-        if stmt.refs == 0 and stmt.query not in self._stmt_cache:
+        if stmt.refs == 0 and not self._stmt_cache.has(stmt.query):
+            # If low-level `stmt` isn't referenced from any high-level
+            # `PreparedStatement` object and is not in the `_stmt_cache`:
+            #
+            # * mark it as closed, which will make it non-usable
+            #   for any `PreparedStatement` or for methods like
+            #   `Connection.fetch()`.
+            #
+            # * schedule it to be formally closed on the server.
             stmt.mark_closed()
             self._stmts_to_close.add(stmt)
 
     async def _cleanup_stmts(self):
+        # Called whenever we create a new prepared statement in
+        # `Connection._get_statement()` and `_stmts_to_close` is
+        # not empty.
         to_close = self._stmts_to_close
         self._stmts_to_close = set()
         for stmt in to_close:
@@ -700,6 +705,7 @@ async def connect(dsn=None, *,
                   loop=None,
                   timeout=60,
                   statement_cache_size=100,
+                  max_cached_statement_lifetime=300,
                   command_timeout=None,
                   __connection_class__=Connection,
                   **opts):
@@ -735,6 +741,12 @@
     :param float timeout: connection timeout in seconds.
 
     :param int statement_cache_size: the size of prepared statement LRU cache.
+        Pass ``0`` to disable the cache.
+
+    :param int max_cached_statement_lifetime:
+        the maximum time in seconds a prepared statement will stay
+        in the cache.  Pass ``0`` to allow statements to be cached
+        indefinitely.
 
     :param float command_timeout: the default timeout for operations on
         this connection (the default is no timeout).
@@ -753,6 +765,9 @@
         ...
         ...     print(types)
         >>> asyncio.get_event_loop().run_until_complete(run())
         [
+        assert new_size >= 0
+        self._max_size = new_size
+        self._maybe_cleanup()
+
+    def get_max_lifetime(self):
+        return self._max_lifetime
+
+    def set_max_lifetime(self, new_lifetime):
+        assert new_lifetime >= 0
+        self._max_lifetime = new_lifetime
+        for entry in self._entries.values():
+            # For every entry cancel the existing callback
+            # and set up a new one if necessary.
+            self._set_entry_timeout(entry)
+
+    def get(self, query, *, promote=True):
+        if not self._max_size:
+            # The cache is disabled.
+            return
+
+        entry = self._entries.get(query)  # type: _StatementCacheEntry
+        if entry is None:
+            return
+
+        if entry._statement.closed:
+            # Happens in unittests when we call `stmt._state.mark_closed()`
+            # manually.
+            self._entries.pop(query)
+            self._clear_entry_callback(entry)
+            return
+
+        if promote:
+            # `promote` is `False` when `get()` is called by `has()`.
+            self._entries.move_to_end(query, last=True)
+
+        return entry._statement
+
+    def has(self, query):
+        return self.get(query, promote=False) is not None
+
+    def put(self, query, statement):
+        if not self._max_size:
+            # The cache is disabled.
+            return
+
+        self._entries[query] = self._new_entry(query, statement)
+
+        # Check if the cache is bigger than max_size and trim it
+        # if necessary.
+        self._maybe_cleanup()
+
+    def iter_statements(self):
+        return (e._statement for e in self._entries.values())
+
+    def clear(self):
+        # First, make sure that we cancel all scheduled callbacks.
+        for entry in self._entries.values():
+            self._clear_entry_callback(entry)
+
+        # Clear the entries dict.
+        self._entries.clear()
+
+    def _set_entry_timeout(self, entry):
+        # Clear the existing timeout.
+        self._clear_entry_callback(entry)
+
+        # Set the new timeout if it's not 0.
+        if self._max_lifetime:
+            entry._cleanup_cb = self._loop.call_later(
+                self._max_lifetime, self._on_entry_expired, entry)
+
+    def _new_entry(self, query, statement):
+        entry = _StatementCacheEntry(self, query, statement)
+        self._set_entry_timeout(entry)
+        return entry
+
+    def _on_entry_expired(self, entry):
+        # `call_later` callback, called when an entry stayed longer
+        # than `self._max_lifetime`.
+        if self._entries.get(entry._query) is entry:
+            self._entries.pop(entry._query)
+            self._on_remove(entry._statement)
+
+    def _clear_entry_callback(self, entry):
+        if entry._cleanup_cb is not None:
+            entry._cleanup_cb.cancel()
+
+    def _maybe_cleanup(self):
+        # Delete cache entries until the size of the cache is `max_size`.
+        while len(self._entries) > self._max_size:
+            old_query, old_entry = self._entries.popitem(last=False)
+            self._clear_entry_callback(old_entry)
+
+            # Let the connection know that the statement was removed
+            # from the cache.
+            self._on_remove(old_entry._statement)
+
+
 class _Atomic:
 
     __slots__ = ('_acquired',)
diff --git a/tests/test_prepare.py b/tests/test_prepare.py
index d602efa0..98e728b8 100644
--- a/tests/test_prepare.py
+++ b/tests/test_prepare.py
@@ -142,12 +142,14 @@ async def test_prepare_09_raise_error(self):
             await stmt.fetchval()
 
     async def test_prepare_10_stmt_lru(self):
+        cache = self.con._stmt_cache
+
         query = 'select {}'
-        cache_max = self.con._stmt_cache_max_size
+        cache_max = cache.get_max_size()
         iter_max = cache_max * 2 + 11
 
         # First, we have no cached statements.
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
 
         stmts = []
         for i in range(iter_max):
@@ -156,9 +158,8 @@ async def test_prepare_10_stmt_lru(self):
             stmts.append(s)
 
         # At this point our cache should be full.
-        self.assertEqual(len(self.con._stmt_cache), cache_max)
-        self.assertTrue(
-            all(not s.closed for s in self.con._stmt_cache.values()))
+        self.assertEqual(len(cache), cache_max)
+        self.assertTrue(all(not s.closed for s in cache.iter_statements()))
 
         # Since there are references to the statements (`stmts` list),
         # no statements are scheduled to be closed.
@@ -173,22 +174,20 @@ async def test_prepare_10_stmt_lru(self):
         # scheduled to be closed.
         self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
         self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
-        self.assertTrue(
-            all(not s.closed for s in self.con._stmt_cache.values()))
+        self.assertTrue(all(not s.closed for s in cache.iter_statements()))
 
         zero = await self.con.prepare(query.format(0))
         # Hence, all stale statements should be closed now.
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         # The number of cached statements will stay the same though.
-        self.assertEqual(len(self.con._stmt_cache), cache_max)
-        self.assertTrue(
-            all(not s.closed for s in self.con._stmt_cache.values()))
+        self.assertEqual(len(cache), cache_max)
+        self.assertTrue(all(not s.closed for s in cache.iter_statements()))
 
         # After closing all statements will be closed.
         await self.con.close()
         self.assertEqual(len(self.con._stmts_to_close), 0)
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
 
         # An attempt to perform an operation on a closed statement
         # will trigger an error.
@@ -199,8 +198,10 @@ async def test_prepare_11_stmt_gc(self):
         # Test that prepared statements should stay in the cache after
         # they are GCed.
 
+        cache = self.con._stmt_cache
+
         # First, we have no cached statements.
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         # The prepared statement that we'll create will be GCed
@@ -209,33 +210,34 @@ async def test_prepare_11_stmt_gc(self):
         await self.con.prepare('select 1')
         gc.collect()
 
-        self.assertEqual(len(self.con._stmt_cache), 1)
+        self.assertEqual(len(cache), 1)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
     async def test_prepare_12_stmt_gc(self):
         # Test that prepared statements are closed when there is no space
         # for them in the LRU cache and there are no references to them.
 
+        cache = self.con._stmt_cache
+        cache_max = cache.get_max_size()
+
         # First, we have no cached statements.
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
-        cache_max = self.con._stmt_cache_max_size
-
         stmt = await self.con.prepare('select 100000000')
-        self.assertEqual(len(self.con._stmt_cache), 1)
+        self.assertEqual(len(cache), 1)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         for i in range(cache_max):
             await self.con.prepare('select {}'.format(i))
 
-        self.assertEqual(len(self.con._stmt_cache), cache_max)
+        self.assertEqual(len(cache), cache_max)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         del stmt
         gc.collect()
 
-        self.assertEqual(len(self.con._stmt_cache), cache_max)
+        self.assertEqual(len(cache), cache_max)
         self.assertEqual(len(self.con._stmts_to_close), 1)
 
     async def test_prepare_13_connect(self):
@@ -283,25 +285,28 @@ async def test_prepare_15_stmt_gc_cache_disabled(self):
         # Test that even if the statements cache is off, we're still
         # cleaning up GCed statements.
 
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        cache = self.con._stmt_cache
+
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 0)
+
         # Disable cache
-        self.con._stmt_cache_max_size = 0
+        cache.set_max_size(0)
 
         stmt = await self.con.prepare('select 100000000')
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         del stmt
         gc.collect()
 
         # After GC, _stmts_to_close should contain stmt's state
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 1)
 
         # Next "prepare" call will trigger a cleanup
         stmt = await self.con.prepare('select 1')
-        self.assertEqual(len(self.con._stmt_cache), 0)
+        self.assertEqual(len(cache), 0)
         self.assertEqual(len(self.con._stmts_to_close), 0)
 
         del stmt
@@ -425,10 +430,9 @@ async def test_prepare_statement_invalid(self):
         finally:
             await self.con.execute('DROP TABLE tab1')
 
+    @tb.with_connection_options(statement_cache_size=0)
     async def test_prepare_23_no_stmt_cache_seq(self):
-        # Disable cache, which will force connections to use
-        # anonymous prepared statements.
-        self.con._stmt_cache_max_size = 0
+        self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
 
         async def check_simple():
             # Run a simple query a few times.
@@ -456,3 +460,56 @@ async def check_simple():
         # Check that we can run queries after a failed cursor
         # operation.
         await check_simple()
+
+    @tb.with_connection_options(max_cached_statement_lifetime=142)
+    async def test_prepare_24_max_lifetime(self):
+        cache = self.con._stmt_cache
+
+        self.assertEqual(cache.get_max_lifetime(), 142)
+        cache.set_max_lifetime(1)
+
+        s = await self.con.prepare('SELECT 1')
+        state = s._state
+
+        s = await self.con.prepare('SELECT 1')
+        self.assertIs(s._state, state)
+
+        s = await self.con.prepare('SELECT 1')
+        self.assertIs(s._state, state)
+
+        await asyncio.sleep(1, loop=self.loop)
+
+        s = await self.con.prepare('SELECT 1')
+        self.assertIsNot(s._state, state)
+
+    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
+    async def test_prepare_25_max_lifetime_reset(self):
+        cache = self.con._stmt_cache
+
+        s = await self.con.prepare('SELECT 1')
+        state = s._state
+
+        # Disable max_lifetime
+        cache.set_max_lifetime(0)
+
+        await asyncio.sleep(1, loop=self.loop)
+
+        # The statement should still be cached (as we disabled the timeout).
+        s = await self.con.prepare('SELECT 1')
+        self.assertIs(s._state, state)
+
+    @tb.with_connection_options(max_cached_statement_lifetime=0.5)
+    async def test_prepare_26_max_lifetime_max_size(self):
+        cache = self.con._stmt_cache
+
+        s = await self.con.prepare('SELECT 1')
+        state = s._state
+
+        # Disable the cache
+        cache.set_max_size(0)
+
+        s = await self.con.prepare('SELECT 1')
+        self.assertIsNot(s._state, state)
+
+        # Check that nothing crashes after the initial timeout
+        await asyncio.sleep(1, loop=self.loop)
diff --git a/tests/test_timeout.py b/tests/test_timeout.py
index b9bf6a21..6ca5d63f 100644
--- a/tests/test_timeout.py
+++ b/tests/test_timeout.py
@@ -128,11 +128,7 @@ async def test_invalid_timeout(self):
 
 class TestConnectionCommandTimeout(tb.ConnectedTestCase):
 
-    def getExtraConnectOptions(self):
-        return {
-            'command_timeout': 0.2
-        }
-
+    @tb.with_connection_options(command_timeout=0.2)
     async def test_command_timeout_01(self):
         for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
             with self.assertRaises(asyncio.TimeoutError), \
@@ -151,12 +147,8 @@ async def _get_statement(self, query, timeout):
 
 class TestTimeoutCoversPrepare(tb.ConnectedTestCase):
 
-    def getExtraConnectOptions(self):
-        return {
-            '__connection_class__': SlowPrepareConnection,
-            'command_timeout': 0.3
-        }
-
+    @tb.with_connection_options(__connection_class__=SlowPrepareConnection,
+                                command_timeout=0.3)
     async def test_timeout_covers_prepare_01(self):
         for methname in {'fetch', 'fetchrow', 'fetchval', 'execute'}:
             with self.assertRaises(asyncio.TimeoutError):
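
For reference, a minimal usage sketch of the connection options this patch adds and documents. The user name, the query and the running PostgreSQL instance are assumptions for illustration; only the `statement_cache_size`, `max_cached_statement_lifetime` and `command_timeout` keyword arguments come from the diff above.

    import asyncio
    import asyncpg

    async def main():
        con = await asyncpg.connect(
            user='postgres',                    # placeholder credentials
            statement_cache_size=100,           # LRU cache size; 0 disables caching
            max_cached_statement_lifetime=300,  # seconds; 0 caches indefinitely
            command_timeout=60)
        try:
            print(await con.fetchval('SELECT 1'))
        finally:
            await con.close()

    asyncio.get_event_loop().run_until_complete(main())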
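Similarly, a sketch of how the new `with_connection_options` helper from `_testbase.py` is meant to be used in the test suite. The test class and method names here are hypothetical; the decorator, `self.con` and `_stmt_cache.get_max_size()` are the pieces introduced or touched by the diff above.

    from asyncpg import _testbase as tb

    class ConnectionOptionsTest(tb.ConnectedTestCase):

        @tb.with_connection_options(statement_cache_size=0,
                                    command_timeout=0.5)
        async def test_options_applied(self):
            # ConnectedTestCase.setUp() reads `__connect_options__` from the
            # test method (set by the decorator) and forwards the options to
            # cluster.connect(), so `self.con` was created with them.
            self.assertEqual(self.con._stmt_cache.get_max_size(), 0)

Attaching the options to the test function object (rather than overriding `getExtraConnectOptions()`) lets each test declare its own connection settings, which is exactly how the updated `test_prepare.py` and `test_timeout.py` tests use it.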