
Implement support for pool connection rotation #307

Merged: 1 commit, Jun 5, 2018
24 changes: 18 additions & 6 deletions asyncpg/_testbase/__init__.py
@@ -328,17 +328,29 @@ def get_connection_spec(cls, kwargs={}):
conn_spec['user'] = 'postgres'
return conn_spec

def create_pool(self, pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection, **kwargs):
conn_spec = self.get_connection_spec(kwargs)
return create_pool(loop=self.loop, pool_class=pool_class,
connection_class=connection_class, **conn_spec)

@classmethod
def connect(cls, **kwargs):
conn_spec = cls.get_connection_spec(kwargs)
return pg_connection.connect(**conn_spec, loop=cls.loop)

def setUp(self):
super().setUp()
self._pools = []

def tearDown(self):
super().tearDown()
for pool in self._pools:
pool.terminate()
self._pools = []

def create_pool(self, pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection, **kwargs):
conn_spec = self.get_connection_spec(kwargs)
pool = create_pool(loop=self.loop, pool_class=pool_class,
connection_class=connection_class, **conn_spec)
self._pools.append(pool)
return pool


class ProxiedClusterTestCase(ClusterTestCase):
@classmethod
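
The test-base change above makes create_pool() register every pool it creates in self._pools so that tearDown() can terminate them even when a test forgets to. A minimal sketch of a test relying on that behaviour; the class and test names are illustrative, not part of this diff:

# Hypothetical test illustrating the pool-tracking behaviour added to
# _testbase above; the class and test names are examples only.
class PoolCleanupExample(tb.ConnectedTestCase):

    async def test_pool_terminated_by_teardown(self):
        # create_pool() appends the new pool to self._pools.
        pool = await self.create_pool(database='postgres',
                                      min_size=1, max_size=1)

        async with pool.acquire() as con:
            self.assertEqual(await con.fetchval('SELECT 1'), 1)

        # No explicit pool.close() or pool.terminate() is needed here:
        # tearDown() iterates self._pools and terminates each pool.
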
157 changes: 112 additions & 45 deletions asyncpg/pool.py
@@ -9,6 +9,7 @@
import functools
import inspect
import time
import warnings

from . import connection
from . import connect_utils
@@ -92,67 +92,46 @@ def __repr__(self):
class PoolConnectionHolder:

__slots__ = ('_con', '_pool', '_loop', '_proxy',
'_connect_args', '_connect_kwargs',
'_max_queries', '_setup', '_init',
'_max_queries', '_setup',
'_max_inactive_time', '_in_use',
'_inactive_callback', '_timeout')
'_inactive_callback', '_timeout',
'_generation')

def __init__(self, pool, *, connect_args, connect_kwargs,
max_queries, setup, init, max_inactive_time):
def __init__(self, pool, *, max_queries, setup, max_inactive_time):

self._pool = pool
self._con = None
self._proxy = None

self._connect_args = connect_args
self._connect_kwargs = connect_kwargs
self._max_queries = max_queries
self._max_inactive_time = max_inactive_time
self._setup = setup
self._init = init
self._inactive_callback = None
self._in_use = None # type: asyncio.Future
self._timeout = None
self._generation = None

async def connect(self):
if self._con is not None:
raise exceptions.InternalClientError(
'PoolConnectionHolder.connect() called while another '
'connection already exists')

if self._pool._working_addr is None:
# First connection attempt on this pool.
con = await connection.connect(
*self._connect_args,
loop=self._pool._loop,
connection_class=self._pool._connection_class,
**self._connect_kwargs)

self._pool._working_addr = con._addr
self._pool._working_config = con._config
self._pool._working_params = con._params

else:
# We've connected before and have a resolved address,
# and parsed options and config.
con = await connect_utils._connect_addr(
loop=self._pool._loop,
addr=self._pool._working_addr,
timeout=self._pool._working_params.connect_timeout,
config=self._pool._working_config,
params=self._pool._working_params,
connection_class=self._pool._connection_class)

if self._init is not None:
await self._init(con)

self._con = con
self._con = await self._pool._get_new_connection()
self._generation = self._pool._generation

async def acquire(self) -> PoolConnectionProxy:
if self._con is None or self._con.is_closed():
self._con = None
await self.connect()

elif self._generation != self._pool._generation:
# Connections have been expired, re-connect the holder.
self._pool._loop.create_task(
self._con.close(timeout=self._timeout))
self._con = None
await self.connect()

self._maybe_cancel_inactive_callback()

self._proxy = proxy = PoolConnectionProxy(self, self._con)
@@ -197,6 +177,13 @@ async def release(self, timeout):
await self._con.close(timeout=timeout)
return

if self._generation != self._pool._generation:
# The connection has expired because it belongs to
# an older generation (Pool.expire_connections() has
# been called.)
await self._con.close(timeout=timeout)
return

try:
budget = timeout

@@ -312,9 +299,10 @@ class Pool:
"""

__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
'_init', '_connect_args', '_connect_kwargs',
'_working_addr', '_working_config', '_working_params',
'_holders', '_initialized', '_closing', '_closed',
'_connection_class')
'_connection_class', '_generation')

def __init__(self, *connect_args,
Review comment: Could you please add a DeprecationWarning if len(connect_args) > 1?

min_size,
@@ -327,6 +315,14 @@ def __init__(self, *connect_args,
connection_class,
**connect_kwargs):

if len(connect_args) > 1:
warnings.warn(
"Passing multiple positional arguments to asyncpg.Pool "
"constructor is deprecated and will be removed in "
"asyncpg 0.17.0. The non-deprecated form is "
"asyncpg.Pool(<dsn>, **kwargs)",
DeprecationWarning, stacklevel=2)

if loop is None:
loop = asyncio.get_event_loop()
self._loop = loop
@@ -349,6 +345,11 @@ def __init__(self, *connect_args,
'max_inactive_connection_lifetime is expected to be greater '
'or equal to zero')

if not issubclass(connection_class, connection.Connection):
raise TypeError(
'connection_class is expected to be a subclass of '
'asyncpg.Connection, got {!r}'.format(connection_class))

self._minsize = min_size
self._maxsize = max_size

@@ -364,16 +365,17 @@ def __init__(self, *connect_args,

self._closing = False
self._closed = False
self._generation = 0
self._init = init
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

for _ in range(max_size):
ch = PoolConnectionHolder(
self,
connect_args=connect_args,
connect_kwargs=connect_kwargs,
max_queries=max_queries,
max_inactive_time=max_inactive_connection_lifetime,
setup=setup,
init=init)
setup=setup)

self._holders.append(ch)
self._queue.put_nowait(ch)
@@ -409,6 +411,62 @@ async def _async__init__(self):
self._initialized = True
return self

def set_connect_args(self, dsn=None, **connect_kwargs):
r"""Set the new connection arguments for this pool.

The new connection arguments will be used for all subsequent
new connection attempts. Existing connections will remain until
they expire. Use :meth:`Pool.expire_connections()
<asyncpg.pool.Pool.expire_connections>` to expedite the connection
expiry.

:param str dsn:
Connection arguments specified as a single string in
the following format:
``postgres://user:pass@host:port/database?option=value``.

:param \*\*connect_kwargs:
Keyword arguments for the :func:`~asyncpg.connection.connect`
function.

.. versionadded:: 0.16.0
"""

self._connect_args = [dsn]
self._connect_kwargs = connect_kwargs
self._working_addr = None
self._working_config = None
self._working_params = None

async def _get_new_connection(self):
if self._working_addr is None:
# First connection attempt on this pool.
con = await connection.connect(
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
**self._connect_kwargs)

self._working_addr = con._addr
self._working_config = con._config
self._working_params = con._params

else:
# We've connected before and have a resolved address,
# and parsed options and config.
con = await connect_utils._connect_addr(
loop=self._loop,
addr=self._working_addr,
timeout=self._working_params.connect_timeout,
config=self._working_config,
params=self._working_params,
connection_class=self._connection_class)

if self._init is not None:
await self._init(con)

return con

async def execute(self, query: str, *args, timeout: float=None) -> str:
"""Execute an SQL command (or commands).

@@ -602,6 +660,16 @@ def terminate(self):
ch.terminate()
self._closed = True

async def expire_connections(self):
"""Expire all currently open connections.

Cause all currently open connections to get replaced on the
next :meth:`~asyncpg.pool.Pool.acquire()` call.

.. versionadded:: 0.16.0
"""
self._generation += 1

def _check_init(self):
if not self._initialized:
raise exceptions.InterfaceError('pool is not initialized')
@@ -708,6 +776,10 @@ def create_pool(dsn=None, *,
Keyword arguments for the :func:`~asyncpg.connection.connect`
function.

:param Connection connection_class:
The class to use for connections. Must be a subclass of
:class:`~asyncpg.connection.Connection`.

:param int min_size:
Number of connections the pool will be initialized with.

@@ -759,11 +831,6 @@ def create_pool(dsn=None, *,
<connection.Connection.add_log_listener>`) present on the connection
at the moment of its release to the pool.
"""
if not issubclass(connection_class, connection.Connection):
raise TypeError(
'connection_class is expected to be a subclass of '
'asyncpg.Connection, got {!r}'.format(connection_class))

return Pool(
dsn,
connection_class=connection_class,
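
Taken together, set_connect_args() and expire_connections() let a running pool be rotated to a new server or new credentials without recreating it. A minimal usage sketch, assuming asyncpg with this change applied; the DSNs and database names are placeholders:

import asyncio
import asyncpg

async def main():
    # Placeholder DSNs -- substitute real primary/standby addresses.
    pool = await asyncpg.create_pool(
        'postgresql://app@primary.example.com/appdb',
        min_size=1, max_size=10)

    # Point the pool at a different server (or new credentials).
    # Existing connections keep serving until they are expired.
    pool.set_connect_args('postgresql://app@standby.example.com/appdb')

    # Bump the pool generation: every currently open connection is
    # replaced on its next acquire/release instead of being reused.
    await pool.expire_connections()

    async with pool.acquire() as con:
        # This connection was established with the new connect args.
        print(await con.fetchval('SELECT 1'))

    await pool.close()

asyncio.get_event_loop().run_until_complete(main())
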
53 changes: 53 additions & 0 deletions tests/test_pool.py
@@ -794,6 +794,59 @@ async def worker():

await task

async def test_pool_expire_connections(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)

con = await pool.acquire()
try:
await pool.expire_connections()
finally:
await pool.release(con)

self.assertIsNone(pool._holders[0]._con)

async def test_pool_set_connection_args(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)

# Test that connection is expired on release.
con = await pool.acquire()
connspec = self.get_connection_spec()
try:
connspec['server_settings']['application_name'] = \
'set_conn_args_test'
except KeyError:
connspec['server_settings'] = {
'application_name': 'set_conn_args_test'
}

pool.set_connect_args(**connspec)
await pool.expire_connections()
await pool.release(con)

con = await pool.acquire()
self.assertEqual(con.get_settings().application_name,
'set_conn_args_test')
await pool.release(con)

# Test that connection is expired before acquire.
connspec = self.get_connection_spec()
try:
connspec['server_settings']['application_name'] = \
'set_conn_args_test_2'
except KeyError:
connspec['server_settings'] = {
'application_name': 'set_conn_args_test_2'
}

pool.set_connect_args(**connspec)
await pool.expire_connections()

con = await pool.acquire()
self.assertEqual(con.get_settings().application_name,
'set_conn_args_test_2')


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
class TestHotStandby(tb.ClusterTestCase):
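
The expiry mechanism itself is just a generation counter: expire_connections() increments Pool._generation, and each PoolConnectionHolder compares its recorded generation on acquire/release and reconnects (or closes) on a mismatch. A stripped-down, non-asyncpg sketch of the same technique; all names here are illustrative:

class GenerationalPool:
    """Toy illustration of generation-based expiry (not asyncpg code)."""

    def __init__(self, factory):
        self._factory = factory          # callable creating a new resource
        self._generation = 0             # bumped by expire()
        self._resource = None
        self._resource_generation = None

    def expire(self):
        # O(1) invalidation: nothing is torn down eagerly; holders notice
        # the generation mismatch the next time the resource is requested.
        self._generation += 1

    def acquire(self):
        stale = (self._resource is None or
                 self._resource_generation != self._generation)
        if stale:
            self._resource = self._factory()
            self._resource_generation = self._generation
        return self._resource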