@@ -41,10 +41,12 @@ class Connection(metaclass=ConnectionMeta):
41
41
'_stmt_cache_max_size' , '_stmt_cache' , '_stmts_to_close' ,
42
42
'_addr' , '_opts' , '_command_timeout' , '_listeners' ,
43
43
'_server_version' , '_server_caps' , '_intro_query' ,
44
- '_reset_query' , '_proxy' , '_stmt_exclusive_section' )
44
+ '_reset_query' , '_proxy' , '_stmt_exclusive_section' ,
45
+ '_max_cached_statement_use_count' )
45
46
46
47
def __init__ (self , protocol , transport , loop , addr , opts , * ,
47
- statement_cache_size , command_timeout ):
48
+ statement_cache_size , command_timeout ,
49
+ max_cached_statement_use_count ):
48
50
self ._protocol = protocol
49
51
self ._transport = transport
50
52
self ._loop = loop
@@ -60,6 +62,7 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
60
62
self ._stmt_cache_max_size = statement_cache_size
61
63
self ._stmt_cache = collections .OrderedDict ()
62
64
self ._stmts_to_close = set ()
65
+ self ._max_cached_statement_use_count = max_cached_statement_use_count
63
66
64
67
if command_timeout is not None :
65
68
try :
@@ -240,13 +243,20 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
240
243
use_cache = self ._stmt_cache_max_size > 0
241
244
if use_cache :
242
245
try :
243
- state = self ._stmt_cache [query ]
246
+ holder = self ._stmt_cache [query ]
244
247
except KeyError :
245
248
pass
246
249
else :
247
- self ._stmt_cache .move_to_end (query , last = True )
248
- if not state .closed :
249
- return state
250
+ if holder .use_count < self ._max_cached_statement_use_count :
251
+ holder .use_count += 1
252
+
253
+ if holder .statement .closed :
254
+ self ._stmt_cache .pop (query )
255
+ else :
256
+ self ._stmt_cache .move_to_end (query , last = True )
257
+ return holder .statement
258
+ else :
259
+ self ._stmt_cache .pop (query )
250
260
251
261
protocol = self ._protocol
252
262
@@ -255,9 +265,9 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
255
265
else :
256
266
stmt_name = ''
257
267
258
- state = await protocol .prepare (stmt_name , query , timeout )
268
+ statement = await protocol .prepare (stmt_name , query , timeout )
259
269
260
- ready = state ._init_types ()
270
+ ready = statement ._init_types ()
261
271
if ready is not True :
262
272
if self ._types_stmt is None :
263
273
self ._types_stmt = await self .prepare (self ._intro_query )
@@ -267,16 +277,16 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
267
277
268
278
if use_cache :
269
279
if len (self ._stmt_cache ) > self ._stmt_cache_max_size - 1 :
270
- old_query , old_state = self ._stmt_cache .popitem (last = False )
271
- self ._maybe_gc_stmt (old_state )
272
- self ._stmt_cache [query ] = state
280
+ old_query , old_holder = self ._stmt_cache .popitem (last = False )
281
+ self ._maybe_gc_stmt (old_holder . statement )
282
+ self ._stmt_cache [query ] = _StatementCacheHolder ( statement )
273
283
274
284
# If we've just created a new statement object, check if there
275
285
# are any statements for GC.
276
286
if self ._stmts_to_close :
277
287
await self ._cleanup_stmts ()
278
288
279
- return state
289
+ return statement
280
290
281
291
def cursor (self , query , * args , prefetch = None , timeout = None ):
282
292
"""Return a *cursor factory* for the specified query.
@@ -442,8 +452,8 @@ def _get_unique_id(self, prefix):
442
452
return '__asyncpg_{}_{}__' .format (prefix , self ._uid )
443
453
444
454
def _close_stmts (self ):
445
- for stmt in self ._stmt_cache .values ():
446
- stmt .mark_closed ()
455
+ for holder in self ._stmt_cache .values ():
456
+ holder . statement .mark_closed ()
447
457
448
458
for stmt in self ._stmts_to_close :
449
459
stmt .mark_closed ()
@@ -657,6 +667,7 @@ async def connect(dsn=None, *,
657
667
loop = None ,
658
668
timeout = 60 ,
659
669
statement_cache_size = 100 ,
670
+ max_cached_statement_use_count = 100 ,
660
671
command_timeout = None ,
661
672
__connection_class__ = Connection ,
662
673
** opts ):
@@ -692,6 +703,10 @@ async def connect(dsn=None, *,
692
703
:param float timeout: connection timeout in seconds.
693
704
694
705
:param int statement_cache_size: the size of prepared statement LRU cache.
706
+ Pass ``0`` to disable the cache.
707
+
708
+ :param int max_cached_statement_use_count: max number of uses for a cached
709
+ prepared statement.
695
710
696
711
:param float command_timeout: the default timeout for operations on
697
712
this connection (the default is no timeout).
@@ -710,6 +725,9 @@ async def connect(dsn=None, *,
710
725
... print(types)
711
726
>>> asyncio.get_event_loop().run_until_complete(run())
712
727
[<Record typname='bool' typnamespace=11 ...
728
+
729
+ .. versionchanged:: 0.10.0
730
+ Added ``max_cached_statement_use_count`` parameter.
713
731
"""
714
732
if loop is None :
715
733
loop = asyncio .get_event_loop ()
@@ -753,13 +771,24 @@ async def connect(dsn=None, *,
753
771
tr .close ()
754
772
raise
755
773
756
- con = __connection_class__ (pr , tr , loop , addr , opts ,
757
- statement_cache_size = statement_cache_size ,
758
- command_timeout = command_timeout )
774
+ con = __connection_class__ (
775
+ pr , tr , loop , addr , opts ,
776
+ statement_cache_size = statement_cache_size ,
777
+ max_cached_statement_use_count = max_cached_statement_use_count ,
778
+ command_timeout = command_timeout )
779
+
759
780
pr .set_connection (con )
760
781
return con
761
782
762
783
784
+ class _StatementCacheHolder :
785
+ __slots__ = ('statement' , 'use_count' )
786
+
787
+ def __init__ (self , statement ):
788
+ self .use_count = 1
789
+ self .statement = statement
790
+
791
+
763
792
class _Atomic :
764
793
__slots__ = ('_acquired' ,)
765
794
0 commit comments