@@ -42,10 +42,12 @@ class Connection(metaclass=ConnectionMeta):
42
42
'_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
43
43
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
44
44
'_server_version' , '_server_caps' , '_intro_query' ,
45
- '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
45
+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
46
+ '_max_cached_statement_use_count' )
46
47
47
48
def __init__ (self , protocol , transport , loop , addr , opts , * ,
48
- statement_cache_size , command_timeout ):
49
+ statement_cache_size , command_timeout ,
50
+ max_cached_statement_use_count ):
49
51
self ._protocol = protocol
50
52
self ._transport = transport
51
53
self ._loop = loop
@@ -61,6 +63,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
61
63
self ._stmt_cache_max_size = statement_cache_size
62
64
self ._stmt_cache = collections .OrderedDict ()
63
65
self ._stmts_to_close = set ()
66
+ self ._max_cached_statement_use_count = max_cached_statement_use_count
64
67
65
68
if command_timeout is not None :
66
69
try :
@@ -263,13 +266,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
263
266
use_cache = self ._stmt_cache_max_size > 0
264
267
if use_cache :
265
268
try :
266
- state = self ._stmt_cache [query ]
269
+ holder = self ._stmt_cache [query ]
267
270
except KeyError :
268
271
pass
269
272
else :
270
- self ._stmt_cache .move_to_end (query , last = True )
271
- if not state .closed :
272
- return state
273
+ if holder .use_count < self ._max_cached_statement_use_count :
274
+ holder .use_count += 1
275
+
276
+ if holder .statement .closed :
277
+ self ._stmt_cache .pop (query )
278
+ else :
279
+ self ._stmt_cache .move_to_end (query , last = True )
280
+ return holder .statement
281
+ else :
282
+ self ._stmt_cache .pop (query )
273
283
274
284
protocol = self ._protocol
275
285
@@ -278,9 +288,9 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
278
288
else :
279
289
stmt_name = ''
280
290
281
- state = await protocol .prepare (stmt_name , query , timeout )
291
+ statement = await protocol .prepare (stmt_name , query , timeout )
282
292
283
- ready = state ._init_types ()
293
+ ready = statement ._init_types ()
284
294
if ready is not True :
285
295
if self ._types_stmt is None :
286
296
self ._types_stmt = await self .prepare (self ._intro_query )
@@ -290,16 +300,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
290
300
291
301
if use_cache :
292
302
if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
293
- old_query , old_state = self ._stmt_cache .popitem (last = False )
294
- self ._maybe_gc_stmt (old_state )
295
- self ._stmt_cache [query ] = state
303
+ old_query , old_holder = self ._stmt_cache .popitem (last = False )
304
+ self ._maybe_gc_stmt (old_holder . statement )
305
+ self ._stmt_cache [query ] = _StatementCacheHolder ( statement )
296
306
297
307
# If we've just created a new statement object, check if there
298
308
# are any statements for GC.
299
309
if self ._stmts_to_close :
300
310
await self ._cleanup_stmts ()
301
311
302
- return state
312
+ return statement
303
313
304
314
def cursor (self , query , * args , prefetch = None , timeout = None ):
305
315
"""Return a *cursor factory* for the specified query.
@@ -465,8 +475,8 @@ def _get_unique_id(self, prefix):
465
475
return '__asyncpg_{}_{}__' .format (prefix , self ._uid )
466
476
467
477
def _close_stmts (self ):
468
- for stmt in self ._stmt_cache .values ():
469
- stmt .mark_closed ()
478
+ for holder in self ._stmt_cache .values ():
479
+ holder . statement .mark_closed ()
470
480
471
481
for stmt in self ._stmts_to_close :
472
482
stmt .mark_closed ()
@@ -680,6 +690,7 @@ async def connect(dsn=None, *,
680
690
loop = None ,
681
691
timeout = 60 ,
682
692
statement_cache_size = 100 ,
693
+ max_cached_statement_use_count = 100 ,
683
694
command_timeout = None ,
684
695
__connection_class__ = Connection ,
685
696
** opts ):
@@ -715,6 +726,10 @@ async def connect(dsn=None, *,
715
726
:param float timeout: connection timeout in seconds.
716
727
717
728
:param int statement_cache_size: the size of prepared statement LRU cache.
729
+ Pass ``0`` to disable the cache.
730
+
731
+ :param int max_cached_statement_use_count: max number of uses for a cached
732
+ prepared statement.
718
733
719
734
:param float command_timeout: the default timeout for operations on
720
735
this connection (the default is no timeout).
@@ -733,6 +748,9 @@ async def connect(dsn=None, *,
733
748
... print(types)
734
749
>>> asyncio.get_event_loop().run_until_complete(run())
735
750
[<Record typname='bool' typnamespace=11 ...
751
+
752
+ .. versionchanged:: 0.10.0
753
+ Added ``max_cached_statement_use_count`` parameter.
736
754
"""
737
755
if loop is None :
738
756
loop = asyncio .get_event_loop ()
@@ -776,13 +794,24 @@ async def connect(dsn=None, *,
776
794
tr .close ()
777
795
raise
778
796
779
- con = __connection_class__ (pr , tr , loop , addr , opts ,
780
- statement_cache_size = statement_cache_size ,
781
- command_timeout = command_timeout )
797
+ con = __connection_class__ (
798
+ pr , tr , loop , addr , opts ,
799
+ statement_cache_size = statement_cache_size ,
800
+ max_cached_statement_use_count = max_cached_statement_use_count ,
801
+ command_timeout = command_timeout )
802
+
782
803
pr .set_connection (con )
783
804
return con
784
805
785
806
807
+ class _StatementCacheHolder :
808
+ __slots__ = ('statement' , 'use_count' )
809
+
810
+ def __init__ (self , statement ):
811
+ self .use_count = 1
812
+ self .statement = statement
813
+
814
+
786
815
class _Atomic :
787
816
__slots__ = ('_acquired' ,)
788
817
0 commit comments