Skip to content

Commit 64661ab

Browse files
author
rony batista
committed
Add target session attribute connection param
1 parent bd19262 commit 64661ab

File tree

6 files changed

+261
-58
lines changed

6 files changed

+261
-58
lines changed

asyncpg/_testbase/__init__.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,3 +435,92 @@ def tearDown(self):
435435
self.con = None
436436
finally:
437437
super().tearDown()
438+
439+
440+
class HotStandbyTestCase(ClusterTestCase):
441+
@classmethod
442+
def setup_cluster(cls):
443+
cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
444+
cls.start_cluster(
445+
cls.master_cluster,
446+
server_settings={
447+
'max_wal_senders': 10,
448+
'wal_level': 'hot_standby'
449+
}
450+
)
451+
452+
con = None
453+
454+
try:
455+
con = cls.loop.run_until_complete(
456+
cls.master_cluster.connect(
457+
database='postgres', user='postgres', loop=cls.loop))
458+
459+
cls.loop.run_until_complete(
460+
con.execute('''
461+
CREATE ROLE replication WITH LOGIN REPLICATION
462+
'''))
463+
464+
cls.master_cluster.trust_local_replication_by('replication')
465+
466+
conn_spec = cls.master_cluster.get_connection_spec()
467+
468+
cls.standby_cluster = cls.new_cluster(
469+
pg_cluster.HotStandbyCluster,
470+
cluster_kwargs={
471+
'master': conn_spec,
472+
'replication_user': 'replication'
473+
}
474+
)
475+
cls.start_cluster(
476+
cls.standby_cluster,
477+
server_settings={
478+
'hot_standby': True
479+
}
480+
)
481+
482+
finally:
483+
if con is not None:
484+
cls.loop.run_until_complete(con.close())
485+
486+
@classmethod
487+
def get_cluster_connection_spec(cls, cluster, kwargs={}):
488+
conn_spec = cluster.get_connection_spec()
489+
if kwargs.get('dsn'):
490+
conn_spec.pop('host')
491+
conn_spec.update(kwargs)
492+
if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
493+
if 'database' not in conn_spec:
494+
conn_spec['database'] = 'postgres'
495+
if 'user' not in conn_spec:
496+
conn_spec['user'] = 'postgres'
497+
return conn_spec
498+
499+
@classmethod
500+
def get_connection_spec(cls, kwargs={}):
501+
primary_spec = cls.get_cluster_connection_spec(
502+
cls.master_cluster, kwargs
503+
)
504+
standby_spec = cls.get_cluster_connection_spec(
505+
cls.standby_cluster, kwargs
506+
)
507+
return {
508+
'host': [primary_spec['host'], standby_spec['host']],
509+
'port': [primary_spec['port'], standby_spec['port']],
510+
'database': primary_spec['database'],
511+
'user': primary_spec['user'],
512+
**kwargs
513+
}
514+
515+
@classmethod
516+
def connect_primary(cls, **kwargs):
517+
conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
518+
return pg_connection.connect(**conn_spec, loop=cls.loop)
519+
520+
@classmethod
521+
def connect_standby(cls, **kwargs):
522+
conn_spec = cls.get_cluster_connection_spec(
523+
cls.standby_cluster,
524+
kwargs
525+
)
526+
return pg_connection.connect(**conn_spec, loop=cls.loop)

asyncpg/connect_utils.py

Lines changed: 77 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import os
1414
import pathlib
1515
import platform
16+
import random
1617
import re
1718
import socket
1819
import ssl as ssl_module
@@ -55,6 +56,7 @@ def parse(cls, sslmode):
5556
'sslmode',
5657
'connect_timeout',
5758
'server_settings',
59+
'target_session_attribute',
5860
])
5961

6062

@@ -258,7 +260,8 @@ def _dot_postgresql_path(filename) -> pathlib.Path:
258260

259261
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
260262
password, passfile, database, ssl,
261-
connect_timeout, server_settings):
263+
connect_timeout, server_settings,
264+
target_session_attribute):
262265
# `auth_hosts` is the version of host information for the purposes
263266
# of reading the pgpass file.
264267
auth_hosts = None
@@ -602,7 +605,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
602605
params = _ConnectionParameters(
603606
user=user, password=password, database=database, ssl=ssl,
604607
sslmode=sslmode, connect_timeout=connect_timeout,
605-
server_settings=server_settings)
608+
server_settings=server_settings,
609+
target_session_attribute=target_session_attribute)
606610

607611
return addrs, params
608612

@@ -612,8 +616,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
612616
statement_cache_size,
613617
max_cached_statement_lifetime,
614618
max_cacheable_statement_size,
615-
ssl, server_settings):
616-
619+
ssl, server_settings,
620+
target_session_attribute):
617621
local_vars = locals()
618622
for var_name in {'max_cacheable_statement_size',
619623
'max_cached_statement_lifetime',
@@ -641,7 +645,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
641645
dsn=dsn, host=host, port=port, user=user,
642646
password=password, passfile=passfile, ssl=ssl,
643647
database=database, connect_timeout=timeout,
644-
server_settings=server_settings)
648+
server_settings=server_settings,
649+
target_session_attribute=target_session_attribute)
645650

646651
config = _ClientConfiguration(
647652
command_timeout=command_timeout,
@@ -866,18 +871,64 @@ async def __connect_addr(
866871
return con
867872

868873

874+
class SessionAttribute(str, enum.Enum):
875+
any = 'any'
876+
primary = 'primary'
877+
standby = 'standby'
878+
prefer_standby = 'prefer-standby'
879+
880+
881+
def _accept_in_hot_standby(should_be_in_hot_standby: bool):
882+
"""
883+
If the server didn't report "in_hot_standby" at startup, we must determine
884+
the state by checking "SELECT pg_catalog.pg_is_in_recovery()".
885+
"""
886+
async def can_be_used(connection):
887+
settings = connection.get_settings()
888+
hot_standby_status = getattr(settings, 'in_hot_standby', None)
889+
if hot_standby_status is not None:
890+
is_in_hot_standby = hot_standby_status == 'on'
891+
else:
892+
is_in_hot_standby = await connection.fetchval(
893+
"SELECT pg_catalog.pg_is_in_recovery()"
894+
)
895+
896+
return is_in_hot_standby == should_be_in_hot_standby
897+
898+
return can_be_used
899+
900+
901+
async def _accept_any(_):
902+
return True
903+
904+
905+
target_attrs_check = {
906+
SessionAttribute.any: _accept_any,
907+
SessionAttribute.primary: _accept_in_hot_standby(False),
908+
SessionAttribute.standby: _accept_in_hot_standby(True),
909+
SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
910+
}
911+
912+
913+
async def _can_use_connection(connection, attr: SessionAttribute):
914+
can_use = target_attrs_check[attr]
915+
return await can_use(connection)
916+
917+
869918
async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
870919
if loop is None:
871920
loop = asyncio.get_event_loop()
872921

873922
addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
923+
target_attr = params.target_session_attribute
874924

925+
candidates = []
926+
chosen_connection = None
875927
last_error = None
876-
addr = None
877928
for addr in addrs:
878929
before = time.monotonic()
879930
try:
880-
return await _connect_addr(
931+
conn = await _connect_addr(
881932
addr=addr,
882933
loop=loop,
883934
timeout=timeout,
@@ -886,12 +937,30 @@ async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
886937
connection_class=connection_class,
887938
record_class=record_class,
888939
)
940+
candidates.append(conn)
941+
if await _can_use_connection(conn, target_attr):
942+
chosen_connection = conn
943+
break
889944
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
890945
last_error = ex
891946
finally:
892947
timeout -= time.monotonic() - before
948+
else:
949+
if target_attr == SessionAttribute.prefer_standby and candidates:
950+
chosen_connection = random.choice(candidates)
951+
952+
await asyncio.gather(
953+
(c.close() for c in candidates if c is not chosen_connection),
954+
return_exceptions=True
955+
)
956+
957+
if chosen_connection:
958+
return chosen_connection
893959

894-
raise last_error
960+
raise last_error or exceptions.TargetServerAttributeNotMatched(
961+
'None of the hosts match the target attribute requirement '
962+
'{!r}'.format(target_attr)
963+
)
895964

896965

897966
async def _cancel(*, loop, addr, params: _ConnectionParameters,

asyncpg/connection.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from . import serverversion
3131
from . import transaction
3232
from . import utils
33+
from .connect_utils import SessionAttribute
3334

3435

3536
class ConnectionMeta(type):
@@ -1791,7 +1792,8 @@ async def connect(dsn=None, *,
17911792
ssl=None,
17921793
connection_class=Connection,
17931794
record_class=protocol.Record,
1794-
server_settings=None):
1795+
server_settings=None,
1796+
target_session_attribute=SessionAttribute.any):
17951797
r"""A coroutine to establish a connection to a PostgreSQL server.
17961798
17971799
The connection parameters may be specified either as a connection
@@ -1998,6 +2000,16 @@ async def connect(dsn=None, *,
19982000
this connection object. Must be a subclass of
19992001
:class:`~asyncpg.Record`.
20002002
2003+
:param SessionAttribute target_session_attribute:
2004+
If specified, check that the host has the correct attribute.
2005+
Can be one of:
2006+
"any": the first successfully connected host
2007+
"primary": the host must NOT be in hot standby mode
2008+
"standby": the host must be in hot standby mode
2009+
"prefer-standby": first try to find a standby host, but if
2010+
none of the listed hosts is a standby server,
2011+
return any of them.
2012+
20012013
:return: A :class:`~asyncpg.connection.Connection` instance.
20022014
20032015
Example:
@@ -2079,6 +2091,15 @@ async def connect(dsn=None, *,
20792091
if record_class is not protocol.Record:
20802092
_check_record_class(record_class)
20812093

2094+
try:
2095+
target_session_attribute = SessionAttribute(target_session_attribute)
2096+
except ValueError as exc:
2097+
raise exceptions.InterfaceError(
2098+
"target_session_attribute is expected to be one of "
2099+
"'any', 'primary', 'standby' or 'prefer-standby'"
2100+
", got {!r}".format(target_session_attribute)
2101+
) from exc
2102+
20822103
if loop is None:
20832104
loop = asyncio.get_event_loop()
20842105

@@ -2100,6 +2121,7 @@ async def connect(dsn=None, *,
21002121
statement_cache_size=statement_cache_size,
21012122
max_cached_statement_lifetime=max_cached_statement_lifetime,
21022123
max_cacheable_statement_size=max_cacheable_statement_size,
2124+
target_session_attribute=target_session_attribute
21032125
)
21042126

21052127

asyncpg/exceptions/_base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
__all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError',
1414
'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage',
1515
'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError',
16-
'UnsupportedClientFeatureError')
16+
'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched')
1717

1818

1919
def _is_asyncpg_class(cls):
@@ -244,6 +244,10 @@ class ProtocolError(InternalClientError):
244244
"""Unexpected condition in the handling of PostgreSQL protocol input."""
245245

246246

247+
class TargetServerAttributeNotMatched(InternalClientError):
248+
"""Could not find a host that satisfies the target attribute requirement"""
249+
250+
247251
class OutdatedSchemaCacheError(InternalClientError):
248252
"""A value decoding error caused by a schema change before row fetching."""
249253

0 commit comments

Comments
 (0)