Skip to content

Commit 4badb33

Browse files
committed
Implement support for pool connection rotation
The new `Pool.cycle()` method expires all currently open connections, so they would be replaced with fresh ones on the next `acquire()` attempt. The new `Pool.set_connect_args()` allows changing the connection arguments for an existing pool instance. Coupled with `cycle()`, it allows adapting the pool to the new environment conditions without having to replace the pool instance. Fixes: #291
1 parent 65c0caa commit 4badb33

File tree

3 files changed

+169
-51
lines changed

3 files changed

+169
-51
lines changed

asyncpg/_testbase/__init__.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -328,17 +328,29 @@ def get_connection_spec(cls, kwargs={}):
328328
conn_spec['user'] = 'postgres'
329329
return conn_spec
330330

331-
def create_pool(self, pool_class=pg_pool.Pool,
332-
connection_class=pg_connection.Connection, **kwargs):
333-
conn_spec = self.get_connection_spec(kwargs)
334-
return create_pool(loop=self.loop, pool_class=pool_class,
335-
connection_class=connection_class, **conn_spec)
336-
337331
@classmethod
338332
def connect(cls, **kwargs):
339333
conn_spec = cls.get_connection_spec(kwargs)
340334
return pg_connection.connect(**conn_spec, loop=cls.loop)
341335

336+
def setUp(self):
337+
super().setUp()
338+
self._pools = []
339+
340+
def tearDown(self):
341+
super().tearDown()
342+
for pool in self._pools:
343+
pool.terminate()
344+
self._pools = []
345+
346+
def create_pool(self, pool_class=pg_pool.Pool,
347+
connection_class=pg_connection.Connection, **kwargs):
348+
conn_spec = self.get_connection_spec(kwargs)
349+
pool = create_pool(loop=self.loop, pool_class=pool_class,
350+
connection_class=connection_class, **conn_spec)
351+
self._pools.append(pool)
352+
return pool
353+
342354

343355
class ProxiedClusterTestCase(ClusterTestCase):
344356
@classmethod

asyncpg/pool.py

Lines changed: 114 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -92,67 +92,46 @@ def __repr__(self):
9292
class PoolConnectionHolder:
9393

9494
__slots__ = ('_con', '_pool', '_loop', '_proxy',
95-
'_connect_args', '_connect_kwargs',
96-
'_max_queries', '_setup', '_init',
95+
'_max_queries', '_setup',
9796
'_max_inactive_time', '_in_use',
98-
'_inactive_callback', '_timeout')
97+
'_inactive_callback', '_timeout',
98+
'_generation')
9999

100-
def __init__(self, pool, *, connect_args, connect_kwargs,
101-
max_queries, setup, init, max_inactive_time):
100+
def __init__(self, pool, *, max_queries, setup, max_inactive_time):
102101

103102
self._pool = pool
104103
self._con = None
105104
self._proxy = None
106105

107-
self._connect_args = connect_args
108-
self._connect_kwargs = connect_kwargs
109106
self._max_queries = max_queries
110107
self._max_inactive_time = max_inactive_time
111108
self._setup = setup
112-
self._init = init
113109
self._inactive_callback = None
114110
self._in_use = None # type: asyncio.Future
115111
self._timeout = None
112+
self._generation = None
116113

117114
async def connect(self):
118115
if self._con is not None:
119116
raise exceptions.InternalClientError(
120117
'PoolConnectionHolder.connect() called while another '
121118
'connection already exists')
122119

123-
if self._pool._working_addr is None:
124-
# First connection attempt on this pool.
125-
con = await connection.connect(
126-
*self._connect_args,
127-
loop=self._pool._loop,
128-
connection_class=self._pool._connection_class,
129-
**self._connect_kwargs)
130-
131-
self._pool._working_addr = con._addr
132-
self._pool._working_config = con._config
133-
self._pool._working_params = con._params
134-
135-
else:
136-
# We've connected before and have a resolved address,
137-
# and parsed options and config.
138-
con = await connect_utils._connect_addr(
139-
loop=self._pool._loop,
140-
addr=self._pool._working_addr,
141-
timeout=self._pool._working_params.connect_timeout,
142-
config=self._pool._working_config,
143-
params=self._pool._working_params,
144-
connection_class=self._pool._connection_class)
145-
146-
if self._init is not None:
147-
await self._init(con)
148-
149-
self._con = con
120+
self._con = await self._pool._get_new_connection()
121+
self._generation = self._pool._generation
150122

151123
async def acquire(self) -> PoolConnectionProxy:
152124
if self._con is None or self._con.is_closed():
153125
self._con = None
154126
await self.connect()
155127

128+
elif self._generation != self._pool._generation:
129+
# Connections have been expired, re-connect the holder.
130+
self._pool._loop.create_task(
131+
self._con.close(timeout=self._timeout))
132+
self._con = None
133+
await self.connect()
134+
156135
self._maybe_cancel_inactive_callback()
157136

158137
self._proxy = proxy = PoolConnectionProxy(self, self._con)
@@ -197,6 +176,13 @@ async def release(self, timeout):
197176
await self._con.close(timeout=timeout)
198177
return
199178

179+
if self._generation != self._pool._generation:
180+
# The connection has expired because it belongs to
181+
# an older generation (Pool.expire_connections() has
182+
# been called.)
183+
await self._con.close(timeout=timeout)
184+
return
185+
200186
try:
201187
budget = timeout
202188

@@ -312,9 +298,10 @@ class Pool:
312298
"""
313299

314300
__slots__ = ('_queue', '_loop', '_minsize', '_maxsize',
301+
'_init', '_connect_args', '_connect_kwargs',
315302
'_working_addr', '_working_config', '_working_params',
316303
'_holders', '_initialized', '_closing', '_closed',
317-
'_connection_class')
304+
'_connection_class', '_generation')
318305

319306
def __init__(self, *connect_args,
320307
min_size,
@@ -349,6 +336,11 @@ def __init__(self, *connect_args,
349336
'max_inactive_connection_lifetime is expected to be greater '
350337
'or equal to zero')
351338

339+
if not issubclass(connection_class, connection.Connection):
340+
raise TypeError(
341+
'connection_class is expected to be a subclass of '
342+
'asyncpg.Connection, got {!r}'.format(connection_class))
343+
352344
self._minsize = min_size
353345
self._maxsize = max_size
354346

@@ -364,16 +356,17 @@ def __init__(self, *connect_args,
364356

365357
self._closing = False
366358
self._closed = False
359+
self._generation = 0
360+
self._init = init
361+
self._connect_args = connect_args
362+
self._connect_kwargs = connect_kwargs
367363

368364
for _ in range(max_size):
369365
ch = PoolConnectionHolder(
370366
self,
371-
connect_args=connect_args,
372-
connect_kwargs=connect_kwargs,
373367
max_queries=max_queries,
374368
max_inactive_time=max_inactive_connection_lifetime,
375-
setup=setup,
376-
init=init)
369+
setup=setup)
377370

378371
self._holders.append(ch)
379372
self._queue.put_nowait(ch)
@@ -409,6 +402,73 @@ async def _async__init__(self):
409402
self._initialized = True
410403
return self
411404

405+
def set_connect_args(self, dsn=None, *,
406+
connection_class=connection.Connection, **kwargs):
407+
"""Set the new connection arguments for this pool.
408+
409+
The new connection arguments will be used for all subsequent
410+
new connection attempts. Existing connections will remain until
411+
they expire. Use :meth:`Pool.expire_connections()
412+
<asyncpg.pool.Pool.expire_connections>` to expedite the connection
413+
expiry.
414+
415+
:param str dsn:
416+
Connection arguments specified using as a single string in
417+
the following format:
418+
``postgres://user:pass@host:port/database?option=value``.
419+
420+
:param \*\*connect_kwargs:
421+
Keyword arguments for the :func:`~asyncpg.connection.connect`
422+
function.
423+
424+
:param Connection connection_class:
425+
The class to use for connections. Must be a subclass of
426+
:class:`~asyncpg.connection.Connection`.
427+
428+
.. versionadded:: 0.16.0
429+
"""
430+
431+
if not issubclass(connection_class, connection.Connection):
432+
raise TypeError(
433+
'connection_class is expected to be a subclass of '
434+
'asyncpg.Connection, got {!r}'.format(connection_class))
435+
436+
self._connect_args = [dsn]
437+
self._connect_kwargs = kwargs
438+
self._connection_class = connection_class
439+
self._working_addr = None
440+
self._working_config = None
441+
self._working_params = None
442+
443+
async def _get_new_connection(self):
444+
if self._working_addr is None:
445+
# First connection attempt on this pool.
446+
con = await connection.connect(
447+
*self._connect_args,
448+
loop=self._loop,
449+
connection_class=self._connection_class,
450+
**self._connect_kwargs)
451+
452+
self._working_addr = con._addr
453+
self._working_config = con._config
454+
self._working_params = con._params
455+
456+
else:
457+
# We've connected before and have a resolved address,
458+
# and parsed options and config.
459+
con = await connect_utils._connect_addr(
460+
loop=self._loop,
461+
addr=self._working_addr,
462+
timeout=self._working_params.connect_timeout,
463+
config=self._working_config,
464+
params=self._working_params,
465+
connection_class=self._connection_class)
466+
467+
if self._init is not None:
468+
await self._init(con)
469+
470+
return con
471+
412472
async def execute(self, query: str, *args, timeout: float=None) -> str:
413473
"""Execute an SQL command (or commands).
414474
@@ -602,6 +662,16 @@ def terminate(self):
602662
ch.terminate()
603663
self._closed = True
604664

665+
async def expire_connections(self):
666+
"""Expire all currently open connections.
667+
668+
Cause all currently open connections to get replaced on the
669+
next :meth:`~asyncpg.pool.Pool.acquire()` call.
670+
671+
.. versionadded:: 0.16.0
672+
"""
673+
self._generation += 1
674+
605675
def _check_init(self):
606676
if not self._initialized:
607677
raise exceptions.InterfaceError('pool is not initialized')
@@ -708,6 +778,10 @@ def create_pool(dsn=None, *,
708778
Keyword arguments for the :func:`~asyncpg.connection.connect`
709779
function.
710780
781+
:param Connection connection_class:
782+
The class to use for connections. Must be a subclass of
783+
:class:`~asyncpg.connection.Connection`.
784+
711785
:param int min_size:
712786
Number of connection the pool will be initialized with.
713787
@@ -759,11 +833,6 @@ def create_pool(dsn=None, *,
759833
<connection.Connection.add_log_listener>`) present on the connection
760834
at the moment of its release to the pool.
761835
"""
762-
if not issubclass(connection_class, connection.Connection):
763-
raise TypeError(
764-
'connection_class is expected to be a subclass of '
765-
'asyncpg.Connection, got {!r}'.format(connection_class))
766-
767836
return Pool(
768837
dsn,
769838
connection_class=connection_class,

tests/test_pool.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -794,6 +794,43 @@ async def worker():
794794

795795
await task
796796

797+
async def test_pool_expire_connections(self):
798+
pool = await self.create_pool(database='postgres',
799+
min_size=1, max_size=1)
800+
801+
con = await pool.acquire()
802+
try:
803+
await pool.expire_connections()
804+
finally:
805+
await pool.release(con)
806+
807+
self.assertIsNone(pool._holders[0]._con)
808+
809+
async def test_pool_set_connection_args(self):
810+
pool = await self.create_pool(database='postgres',
811+
min_size=1, max_size=1)
812+
813+
# Test that connection is expired on release.
814+
con = await pool.acquire()
815+
pool.set_connect_args(
816+
server_settings={'application_name': 'set_conn_args_test'})
817+
await pool.expire_connections()
818+
await pool.release(con)
819+
820+
con = await pool.acquire()
821+
self.assertEqual(con.get_settings().application_name,
822+
'set_conn_args_test')
823+
await pool.release(con)
824+
825+
# Test that connection is expired before acquire.
826+
pool.set_connect_args(
827+
server_settings={'application_name': 'set_conn_args_test_2'})
828+
await pool.expire_connections()
829+
830+
con = await pool.acquire()
831+
self.assertEqual(con.get_settings().application_name,
832+
'set_conn_args_test_2')
833+
797834

798835
@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
799836
class TestHotStandby(tb.ClusterTestCase):

0 commit comments

Comments
 (0)