Skip to content

Commit 29d619b

Browse files
committed
Add new max_cached_statement_use_count parameter to Connection.
The parameter allows asyncpg to refresh cached prepared statements periodically. See also issue #76.
1 parent f95cd86 commit 29d619b

File tree

2 files changed

+64
-20
lines changed

2 files changed

+64
-20
lines changed

asyncpg/connection.py

+46-17
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,12 @@ class Connection(metaclass=ConnectionMeta):
4141
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
4242
'_addr', '_opts', '_command_timeout', '_listeners',
4343
'_server_version', '_server_caps', '_intro_query',
44-
'_reset_query', '_proxy', '_stmt_exclusive_section')
44+
'_reset_query', '_proxy', '_stmt_exclusive_section',
45+
'_max_cached_statement_use_count')
4546

4647
def __init__(self, protocol, transport, loop, addr, opts, *,
47-
statement_cache_size, command_timeout):
48+
statement_cache_size, command_timeout,
49+
max_cached_statement_use_count):
4850
self._protocol = protocol
4951
self._transport = transport
5052
self._loop = loop
@@ -60,6 +62,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
6062
self._stmt_cache_max_size = statement_cache_size
6163
self._stmt_cache = collections.OrderedDict()
6264
self._stmts_to_close = set()
65+
self._max_cached_statement_use_count = max_cached_statement_use_count
6366

6467
if command_timeout is not None:
6568
try:
@@ -240,13 +243,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
240243
use_cache = self._stmt_cache_max_size > 0
241244
if use_cache:
242245
try:
243-
state = self._stmt_cache[query]
246+
holder = self._stmt_cache[query]
244247
except KeyError:
245248
pass
246249
else:
247-
self._stmt_cache.move_to_end(query, last=True)
248-
if not state.closed:
249-
return state
250+
if holder.use_count < self._max_cached_statement_use_count:
251+
holder.use_count += 1
252+
253+
if holder.statement.closed:
254+
self._stmt_cache.pop(query)
255+
else:
256+
self._stmt_cache.move_to_end(query, last=True)
257+
return holder.statement
258+
else:
259+
self._stmt_cache.pop(query)
250260

251261
protocol = self._protocol
252262

@@ -255,9 +265,9 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
255265
else:
256266
stmt_name = ''
257267

258-
state = await protocol.prepare(stmt_name, query, timeout)
268+
statement = await protocol.prepare(stmt_name, query, timeout)
259269

260-
ready = state._init_types()
270+
ready = statement._init_types()
261271
if ready is not True:
262272
if self._types_stmt is None:
263273
self._types_stmt = await self.prepare(self._intro_query)
@@ -267,16 +277,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
267277

268278
if use_cache:
269279
if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
270-
old_query, old_state = self._stmt_cache.popitem(last=False)
271-
self._maybe_gc_stmt(old_state)
272-
self._stmt_cache[query] = state
280+
old_query, old_holder = self._stmt_cache.popitem(last=False)
281+
self._maybe_gc_stmt(old_holder.statement)
282+
self._stmt_cache[query] = _StatementCacheHolder(statement)
273283

274284
# If we've just created a new statement object, check if there
275285
# are any statements for GC.
276286
if self._stmts_to_close:
277287
await self._cleanup_stmts()
278288

279-
return state
289+
return statement
280290

281291
def cursor(self, query, *args, prefetch=None, timeout=None):
282292
"""Return a *cursor factory* for the specified query.
@@ -442,8 +452,8 @@ def _get_unique_id(self, prefix):
442452
return '__asyncpg_{}_{}__'.format(prefix, self._uid)
443453

444454
def _close_stmts(self):
445-
for stmt in self._stmt_cache.values():
446-
stmt.mark_closed()
455+
for holder in self._stmt_cache.values():
456+
holder.statement.mark_closed()
447457

448458
for stmt in self._stmts_to_close:
449459
stmt.mark_closed()
@@ -657,6 +667,7 @@ async def connect(dsn=None, *,
657667
loop=None,
658668
timeout=60,
659669
statement_cache_size=100,
670+
max_cached_statement_use_count=100,
660671
command_timeout=None,
661672
__connection_class__=Connection,
662673
**opts):
@@ -692,6 +703,10 @@ async def connect(dsn=None, *,
692703
:param float timeout: connection timeout in seconds.
693704
694705
:param int statement_cache_size: the size of prepared statement LRU cache.
706+
Pass ``0`` to disable the cache.
707+
708+
:param int max_cached_statement_use_count: max number of uses for a cached
709+
prepared statement.
695710
696711
:param float command_timeout: the default timeout for operations on
697712
this connection (the default is no timeout).
@@ -710,6 +725,9 @@ async def connect(dsn=None, *,
710725
... print(types)
711726
>>> asyncio.get_event_loop().run_until_complete(run())
712727
[<Record typname='bool' typnamespace=11 ...
728+
729+
.. versionchanged:: 0.10.0
730+
Added ``max_cached_statement_use_count`` parameter.
713731
"""
714732
if loop is None:
715733
loop = asyncio.get_event_loop()
@@ -753,13 +771,24 @@ async def connect(dsn=None, *,
753771
tr.close()
754772
raise
755773

756-
con = __connection_class__(pr, tr, loop, addr, opts,
757-
statement_cache_size=statement_cache_size,
758-
command_timeout=command_timeout)
774+
con = __connection_class__(
775+
pr, tr, loop, addr, opts,
776+
statement_cache_size=statement_cache_size,
777+
max_cached_statement_use_count=max_cached_statement_use_count,
778+
command_timeout=command_timeout)
779+
759780
pr.set_connection(con)
760781
return con
761782

762783

784+
class _StatementCacheHolder:
785+
__slots__ = ('statement', 'use_count')
786+
787+
def __init__(self, statement):
788+
self.use_count = 1
789+
self.statement = statement
790+
791+
763792
class _Atomic:
764793
__slots__ = ('_acquired',)
765794

tests/test_prepare.py

+18-3
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,7 @@ async def test_prepare_10_stmt_lru(self):
158158
# At this point our cache should be full.
159159
self.assertEqual(len(self.con._stmt_cache), cache_max)
160160
self.assertTrue(
161-
all(not s.closed for s in self.con._stmt_cache.values()))
161+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
162162

163163
# Since there are references to the statements (`stmts` list),
164164
# no statements are scheduled to be closed.
@@ -174,7 +174,7 @@ async def test_prepare_10_stmt_lru(self):
174174
self.assertEqual(len(self.con._stmts_to_close), iter_max - cache_max)
175175
self.assertTrue(all(s.closed for s in self.con._stmts_to_close))
176176
self.assertTrue(
177-
all(not s.closed for s in self.con._stmt_cache.values()))
177+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
178178

179179
zero = await self.con.prepare(query.format(0))
180180
# Hence, all stale statements should be closed now.
@@ -183,7 +183,7 @@ async def test_prepare_10_stmt_lru(self):
183183
# The number of cached statements will stay the same though.
184184
self.assertEqual(len(self.con._stmt_cache), cache_max)
185185
self.assertTrue(
186-
all(not s.closed for s in self.con._stmt_cache.values()))
186+
all(not s.statement.closed for s in self.con._stmt_cache.values()))
187187

188188
# After closing all statements will be closed.
189189
await self.con.close()
@@ -456,3 +456,18 @@ async def check_simple():
456456
# Check that we can run queries after a failed cursor
457457
# operation.
458458
await check_simple()
459+
460+
async def test_prepare_24_max_use_count(self):
461+
self.con._max_cached_statement_use_count = 3
462+
463+
s = await self.con.prepare('SELECT 1')
464+
state = s._state
465+
466+
s = await self.con.prepare('SELECT 1')
467+
self.assertIs(s._state, state)
468+
469+
s = await self.con.prepare('SELECT 1')
470+
self.assertIs(s._state, state)
471+
472+
s = await self.con.prepare('SELECT 1')
473+
self.assertIsNot(s._state, state)

0 commit comments

Comments
 (0)