Skip to content

Add new max_cached_statement_lifetime parameter to Connection. #109

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 3, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion asyncpg/_testbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,16 +190,32 @@ def start_cluster(cls, ClusterCls, *,
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)


def with_connection_options(**options):
    """Build a decorator that tags a function with connection options.

    The keyword arguments are stored, as a dict, on the decorated
    function under the ``__connect_options__`` attribute.  The function
    itself is returned unchanged otherwise.

    Raises:
        ValueError: if called without any keyword arguments.
    """
    if len(options) == 0:
        raise ValueError('no connection options were specified')

    def _attach_options(func):
        # Annotate in place; the very same function object is handed back.
        setattr(func, '__connect_options__', options)
        return func

    return _attach_options


class ConnectedTestCase(ClusterTestCase):

def getExtraConnectOptions(self):
return {}

def setUp(self):
super().setUp()
opts = self.getExtraConnectOptions()

# Extract options set up with `with_connection_options`.
test_func = getattr(self, self._testMethodName).__func__
opts = getattr(test_func, '__connect_options__', {})

self.con = self.loop.run_until_complete(
self.cluster.connect(database='postgres', loop=self.loop, **opts))

self.server_version = self.con.get_server_version()

def tearDown(self):
Expand Down
232 changes: 198 additions & 34 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,14 @@ class Connection(metaclass=ConnectionMeta):

__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
'_stmt_cache_max_size', '_stmt_cache', '_stmts_to_close',
'_stmt_cache', '_stmts_to_close',
'_addr', '_opts', '_command_timeout', '_listeners',
'_server_version', '_server_caps', '_intro_query',
'_reset_query', '_proxy', '_stmt_exclusive_section')

def __init__(self, protocol, transport, loop, addr, opts, *,
statement_cache_size, command_timeout):
statement_cache_size, command_timeout,
max_cached_statement_lifetime):
self._protocol = protocol
self._transport = transport
self._loop = loop
Expand All @@ -58,8 +59,12 @@ def __init__(self, protocol, transport, loop, addr, opts, *,
self._addr = addr
self._opts = opts

self._stmt_cache_max_size = statement_cache_size
self._stmt_cache = collections.OrderedDict()
self._stmt_cache = _StatementCache(
loop=loop,
max_size=statement_cache_size,
on_remove=self._maybe_gc_stmt,
max_lifetime=max_cached_statement_lifetime)

self._stmts_to_close = set()

if command_timeout is not None:
Expand Down Expand Up @@ -126,6 +131,8 @@ async def add_listener(self, channel, callback):

async def remove_listener(self, channel, callback):
"""Remove a listening callback on the specified channel."""
if self.is_closed():
return
if channel not in self._listeners:
return
if callback not in self._listeners[channel]:
Expand Down Expand Up @@ -266,46 +273,33 @@ async def executemany(self, command: str, args,
return await self._executemany(command, args, timeout)

async def _get_statement(self, query, timeout, *, named: bool=False):
use_cache = self._stmt_cache_max_size > 0
if use_cache:
try:
state = self._stmt_cache[query]
except KeyError:
pass
else:
self._stmt_cache.move_to_end(query, last=True)
if not state.closed:
return state

protocol = self._protocol
statement = self._stmt_cache.get(query)
if statement is not None:
return statement

if use_cache or named:
if self._stmt_cache.get_max_size() or named:
stmt_name = self._get_unique_id('stmt')
else:
stmt_name = ''

state = await protocol.prepare(stmt_name, query, timeout)
statement = await self._protocol.prepare(stmt_name, query, timeout)

ready = state._init_types()
ready = statement._init_types()
if ready is not True:
if self._types_stmt is None:
self._types_stmt = await self.prepare(self._intro_query)

types = await self._types_stmt.fetch(list(ready))
protocol.get_settings().register_data_types(types)
self._protocol.get_settings().register_data_types(types)

if use_cache:
if len(self._stmt_cache) > self._stmt_cache_max_size - 1:
old_query, old_state = self._stmt_cache.popitem(last=False)
self._maybe_gc_stmt(old_state)
self._stmt_cache[query] = state
self._stmt_cache.put(query, statement)

# If we've just created a new statement object, check if there
# are any statements for GC.
if self._stmts_to_close:
await self._cleanup_stmts()

return state
return statement

def cursor(self, query, *args, prefetch=None, timeout=None):
"""Return a *cursor factory* for the specified query.
Expand Down Expand Up @@ -457,14 +451,14 @@ async def close(self):
"""Close the connection gracefully."""
if self.is_closed():
return
self._close_stmts()
self._mark_stmts_as_closed()
self._listeners = {}
self._aborted = True
await self._protocol.close()

def terminate(self):
"""Terminate the connection without waiting for pending data."""
self._close_stmts()
self._mark_stmts_as_closed()
self._listeners = {}
self._aborted = True
self._protocol.abort()
Expand All @@ -484,8 +478,8 @@ def _get_unique_id(self, prefix):
self._uid += 1
return '__asyncpg_{}_{}__'.format(prefix, self._uid)

def _close_stmts(self):
for stmt in self._stmt_cache.values():
def _mark_stmts_as_closed(self):
for stmt in self._stmt_cache.iter_statements():
stmt.mark_closed()

for stmt in self._stmts_to_close:
Expand All @@ -495,11 +489,22 @@ def _close_stmts(self):
self._stmts_to_close.clear()

def _maybe_gc_stmt(self, stmt):
if stmt.refs == 0 and stmt.query not in self._stmt_cache:
if stmt.refs == 0 and not self._stmt_cache.has(stmt.query):
# If low-level `stmt` isn't referenced from any high-level
# `PreparedStatement` object and is not in the `_stmt_cache`:
#
# * mark it as closed, which will make it non-usable
# for any `PreparedStatement` or for methods like
# `Connection.fetch()`.
#
# * schedule it to be formally closed on the server.
stmt.mark_closed()
self._stmts_to_close.add(stmt)

async def _cleanup_stmts(self):
# Called whenever we create a new prepared statement in
# `Connection._get_statement()` and `_stmts_to_close` is
# not empty.
to_close = self._stmts_to_close
self._stmts_to_close = set()
for stmt in to_close:
Expand Down Expand Up @@ -700,6 +705,7 @@ async def connect(dsn=None, *,
loop=None,
timeout=60,
statement_cache_size=100,
max_cached_statement_lifetime=300,
command_timeout=None,
__connection_class__=Connection,
**opts):
Expand Down Expand Up @@ -735,6 +741,12 @@ async def connect(dsn=None, *,
:param float timeout: connection timeout in seconds.

:param int statement_cache_size: the size of prepared statement LRU cache.
Pass ``0`` to disable the cache.

:param int max_cached_statement_lifetime:
the maximum time in seconds a prepared statement will stay
in the cache. Pass ``0`` to allow statements to be cached
indefinitely.

:param float command_timeout: the default timeout for operations on
this connection (the default is no timeout).
Expand All @@ -753,6 +765,9 @@ async def connect(dsn=None, *,
... print(types)
>>> asyncio.get_event_loop().run_until_complete(run())
[<Record typname='bool' typnamespace=11 ...

.. versionchanged:: 0.10.0
Added ``max_cached_statement_lifetime`` parameter.
"""
if loop is None:
loop = asyncio.get_event_loop()
Expand Down Expand Up @@ -796,13 +811,162 @@ async def connect(dsn=None, *,
tr.close()
raise

con = __connection_class__(pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
command_timeout=command_timeout)
con = __connection_class__(
pr, tr, loop, addr, opts,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
command_timeout=command_timeout)

pr.set_connection(con)
return con


class _StatementCacheEntry:
    """One record of the statement cache.

    Keeps together the cached statement, the query text it is stored
    under, the owning cache, and a slot for the handle of a scheduled
    cleanup callback (``None`` until one is scheduled).
    """

    __slots__ = ('_query', '_statement', '_cache', '_cleanup_cb')

    def __init__(self, cache, query, statement):
        # A freshly created entry has no cleanup callback attached.
        self._cleanup_cb = None
        self._statement = statement
        self._query = query
        self._cache = cache


class _StatementCache:

__slots__ = ('_loop', '_entries', '_max_size', '_on_remove',
'_max_lifetime')

def __init__(self, *, loop, max_size, on_remove, max_lifetime):
self._loop = loop
self._max_size = max_size
self._on_remove = on_remove
self._max_lifetime = max_lifetime

# We use an OrderedDict for LRU implementation. Operations:
#
# * We use a simple `__setitem__` to push a new entry:
# `entries[key] = new_entry`
# That will push `new_entry` to the *end* of the entries dict.
#
# * When we have a cache hit, we call
# `entries.move_to_end(key, last=True)`
# to move the entry to the *end* of the entries dict.
#
# * When we need to remove entries to maintain `max_size`, we call
# `entries.popitem(last=False)`
# to remove an entry from the *beginning* of the entries dict.
#
# So new entries and hits are always promoted to the end of the
# entries dict, whereas the unused one will group in the
# beginning of it.
self._entries = collections.OrderedDict()

def __len__(self):
return len(self._entries)

def get_max_size(self):
return self._max_size

def set_max_size(self, new_size):
assert new_size >= 0
self._max_size = new_size
self._maybe_cleanup()

def get_max_lifetime(self):
return self._max_lifetime

def set_max_lifetime(self, new_lifetime):
assert new_lifetime >= 0
self._max_lifetime = new_lifetime
for entry in self._entries.values():
# For every entry cancel the existing callback
# and setup a new one if necessary.
self._set_entry_timeout(entry)

def get(self, query, *, promote=True):
if not self._max_size:
# The cache is disabled.
return

entry = self._entries.get(query) # type: _StatementCacheEntry
if entry is None:
return

if entry._statement.closed:
# Happens in unittests when we call `stmt._state.mark_closed()`
# manually.
self._entries.pop(query)
self._clear_entry_callback(entry)
return

if promote:
# `promote` is `False` when `get()` is called by `has()`.
self._entries.move_to_end(query, last=True)

return entry._statement

def has(self, query):
return self.get(query, promote=False) is not None

def put(self, query, statement):
if not self._max_size:
# The cache is disabled.
return

self._entries[query] = self._new_entry(query, statement)

# Check if the cache is bigger than max_size and trim it
# if necessary.
self._maybe_cleanup()

def iter_statements(self):
return (e._statement for e in self._entries.values())

def clear(self):
# First, make sure that we cancel all scheduled callbacks.
for entry in self._entries.values():
self._clear_entry_callback(entry)

# Clear the entries dict.
self._entries.clear()

def _set_entry_timeout(self, entry):
# Clear the existing timeout.
self._clear_entry_callback(entry)

# Set the new timeout if it's not 0.
if self._max_lifetime:
entry._cleanup_cb = self._loop.call_later(
self._max_lifetime, self._on_entry_expired, entry)

def _new_entry(self, query, statement):
entry = _StatementCacheEntry(self, query, statement)
self._set_entry_timeout(entry)
return entry

def _on_entry_expired(self, entry):
# `call_later` callback, called when an entry stayed longer
# than `self._max_lifetime`.
if self._entries.get(entry._query) is entry:
self._entries.pop(entry._query)
self._on_remove(entry._statement)

def _clear_entry_callback(self, entry):
if entry._cleanup_cb is not None:
entry._cleanup_cb.cancel()

def _maybe_cleanup(self):
# Delete cache entries until the size of the cache is `max_size`.
while len(self._entries) > self._max_size:
old_query, old_entry = self._entries.popitem(last=False)
self._clear_entry_callback(old_entry)

# Let the connection know that the statement was removed
# from the cache.
self._on_remove(old_entry._statement)


class _Atomic:
__slots__ = ('_acquired',)

Expand Down
Loading