Skip to content

Commit 101311b

Browse files
committed
Add new max_cached_statement_use_count parameter to Connection.
The parameter allows asyncpg to refresh cached prepared statements periodically. See also issue #76.
1 parent 12cce92 commit 101311b

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

asyncpg/connection.py

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,12 @@ class Connection(metaclass=ConnectionMeta):
4242
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
4343
'_addr', '_opts', '_command_timeout', '_listeners',
4444
'_server_version', '_server_caps', '_intro_query',
45-
'_reset_query', '_proxy', '_stmt_exclusive_section')
45+
'_reset_query', '_proxy', '_stmt_exclusive_section',
46+
'_max_cached_statement_use_count')
4647

4748
def __init__(self, protocol, transport, loop, addr, opts, *,
48-
statement_cache_size, command_timeout):
49+
statement_cache_size, command_timeout,
50+
max_cached_statement_use_count):
4951
self._protocol = protocol
5052
self._transport = transport
5153
self._loop = loop
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6163
self._stmt_cache_max_size = statement_cache_size
6264
self._stmt_cache = collections.OrderedDict()
6365
self._stmts_to_close = set()
66+
self._max_cached_statement_use_count = max_cached_statement_use_count
6467

6568
if command_timeout is not None:
6669
try:
@@ -263,13 +266,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
263266
use_cache = self._stmt_cache_max_size > 0
264267
if use_cache:
265268
try:
266-
state = self._stmt_cache[query]
269+
holder = self._stmt_cache[query]
267270
except KeyError:
268271
pass
269272
else:
270-
self._stmt_cache.move_to_end(query, last=True)
271-
if not state.closed:
272-
return state
273+
if holder.use_count < self._max_cached_statement_use_count:
274+
holder.use_count += 1
275+
276+
if holder.statement.closed:
277+
self._stmt_cache.pop(query)
278+
else:
279+
self._stmt_cache.move_to_end(query, last=True)
280+
return holder.statement
281+
else:
282+
self._stmt_cache.pop(query)
273283

274284
protocol = self._protocol
275285

@@ -278,9 +288,9 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
278288
else:
279289
stmt_name = ''
280290

281-
state = await protocol.prepare(stmt_name, query, timeout)
291+
statement = await protocol.prepare(stmt_name, query, timeout)
282292

283-
ready = state._init_types()
293+
ready = statement._init_types()
284294
if ready is not True:
285295
if self._types_stmt is None:
286296
self._types_stmt = await self.prepare(self._intro_query)
@@ -290,16 +300,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
290300

291301
if use_cache:
292302
if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
293-
old_query, old_state = self._stmt_cache.popitem(last=False)
294-
self._maybe_gc_stmt(old_state)
295-
self._stmt_cache[query] = state
303+
old_query, old_holder = self._stmt_cache.popitem(last=False)
304+
self._maybe_gc_stmt(old_holder.statement)
305+
self._stmt_cache[query] = _StatementCacheHolder(statement)
296306

297307
# If we've just created a new statement object, check if there
298308
# are any statements for GC.
299309
if self._stmts_to_close:
300310
await self._cleanup_stmts()
301311

302-
return state
312+
return statement
303313

304314
def cursor(self, query, *args, prefetch=None, timeout=None):
305315
"""Return a *cursor factory* for the specified query.
@@ -465,8 +475,8 @@ def _get_unique_id(self, prefix):
465475
return '__asyncpg_{}_{}__'.format(prefix, self._uid)
466476

467477
def _close_stmts(self):
468-
for stmt in self._stmt_cache.values():
469-
stmt.mark_closed()
478+
for holder in self._stmt_cache.values():
479+
holder.statement.mark_closed()
470480

471481
for stmt in self._stmts_to_close:
472482
stmt.mark_closed()
@@ -680,6 +690,7 @@ async def connect(dsn=None, *,
680690
loop=None,
681691
timeout=60,
682692
statement_cache_size=100,
693+
max_cached_statement_use_count=100,
683694
command_timeout=None,
684695
__connection_class__=Connection,
685696
**opts):
@@ -715,6 +726,10 @@ async def connect(dsn=None, *,
715726
:param float timeout: connection timeout in seconds.
716727
717728
:param int statement_cache_size: the size of prepared statement LRU cache.
729+
Pass ``0`` to disable the cache.
730+
731+
:param int max_cached_statement_use_count: max number of uses for a cached
732+
prepared statement.
718733
719734
:param float command_timeout: the default timeout for operations on
720735
this connection (the default is no timeout).
@@ -733,6 +748,9 @@ async def connect(dsn=None, *,
733748
... print(types)
734749
>>> asyncio.get_event_loop().run_until_complete(run())
735750
[<Record typname='bool' typnamespace=11 ...
751+
752+
.. versionchanged:: 0.10.0
753+
Added ``max_cached_statement_use_count`` parameter.
736754
"""
737755
if loop is None:
738756
loop = asyncio.get_event_loop()
@@ -776,13 +794,24 @@ async def connect(dsn=None, *,
776794
tr.close()
777795
raise
778796

779-
con = __connection_class__(pr, tr, loop, addr, opts,
780-
statement_cache_size=statement_cache_size,
781-
command_timeout=command_timeout)
797+
con = __connection_class__(
798+
pr, tr, loop, addr, opts,
799+
statement_cache_size=statement_cache_size,
800+
max_cached_statement_use_count=max_cached_statement_use_count,
801+
command_timeout=command_timeout)
802+
782803
pr.set_connection(con)
783804
return con
784805

785806

807+
class _StatementCacheHolder:
808+
__slots__ = ('statement', 'use_count')
809+
810+
def __init__(self, statement):
811+
self.use_count = 1
812+
self.statement = statement
813+
814+
786815
class _Atomic:
787816
__slots__ = ('_acquired',)
788817

tests/test_prepare.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ async def test_prepare_10_stmt_lru(self):
158158
# At this point our cache should be full.
159159
self.assertEqual(len(self.con._stmt_cache), cache_max)
160160
self.assertTrue(
161-
all(not s.closed for s in self.con._stmt_cache.values()))
161+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
162162

163163
# Since there are references to the statements (`stmts` list),
164164
# no statements are scheduled to be closed.
@@ -174,7 +174,7 @@ async def test_prepare_10_stmt_lru(self):
174174
self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
175175
self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
176176
self.assertTrue(
177-
all(not s.closed for s in self.con._stmt_cache.values()))
177+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
178178

179179
zero = await self.con.prepare(query.format(0))
180180
# Hence, all stale statements should be closed now.
@@ -183,7 +183,7 @@ async def test_prepare_10_stmt_lru(self):
183183
# The number of cached statements will stay the same though.
184184
self.assertEqual(len(self.con._stmt_cache), cache_max)
185185
self.assertTrue(
186-
all(not s.closed for s in self.con._stmt_cache.values()))
186+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
187187

188188
# After closing all statements will be closed.
189189
await self.con.close()
@@ -456,3 +456,18 @@ async def check_simple():
456456
# Check that we can run queries after a failed cursor
457457
# operation.
458458
await check_simple()
459+
460+
async def test_prepare_24_max_use_count(self):
461+
self.con._max_cached_statement_use_count = 3
462+
463+
s = await self.con.prepare('SELECT 1')
464+
state = s._state
465+
466+
s = await self.con.prepare('SELECT 1')
467+
self.assertIs(s._state, state)
468+
469+
s = await self.con.prepare('SELECT 1')
470+
self.assertIs(s._state, state)
471+
472+
s = await self.con.prepare('SELECT 1')
473+
self.assertIsNot(s._state, state)

0 commit comments

Comments
 (0)