class HotStandbyTestCase(ClusterTestCase):
    """Base test case that provisions a primary cluster and a hot standby.

    ``setup_cluster`` starts ``master_cluster`` with WAL streaming enabled,
    creates a ``replication`` role on it and then starts
    ``standby_cluster`` as a ``pg_cluster.HotStandbyCluster`` that follows
    the primary.
    """

    @classmethod
    def setup_cluster(cls):
        """Start the primary cluster, then a standby replicating from it."""
        cls.master_cluster = cls.new_cluster(pg_cluster.TempCluster)
        cls.start_cluster(
            cls.master_cluster,
            server_settings={
                'max_wal_senders': 10,
                'wal_level': 'hot_standby'
            }
        )

        con = None

        try:
            con = cls.loop.run_until_complete(
                cls.master_cluster.connect(
                    database='postgres', user='postgres', loop=cls.loop))

            # The standby authenticates against the primary as this role.
            cls.loop.run_until_complete(
                con.execute('''
                    CREATE ROLE replication WITH LOGIN REPLICATION
                '''))

            cls.master_cluster.trust_local_replication_by('replication')

            conn_spec = cls.master_cluster.get_connection_spec()

            cls.standby_cluster = cls.new_cluster(
                pg_cluster.HotStandbyCluster,
                cluster_kwargs={
                    'master': conn_spec,
                    'replication_user': 'replication'
                }
            )
            cls.start_cluster(
                cls.standby_cluster,
                server_settings={
                    'hot_standby': True
                }
            )

        finally:
            # Always release the bootstrap connection, even if the standby
            # failed to start.
            if con is not None:
                cls.loop.run_until_complete(con.close())

    @classmethod
    def get_cluster_connection_spec(cls, cluster, kwargs=None):
        """Return connection arguments for *cluster* overlaid with *kwargs*.

        :param cluster: Cluster object providing ``get_connection_spec()``.
        :param dict kwargs: Optional overrides; ``None`` means no overrides.
        :return: A dict of keyword arguments for ``connect()``.
        """
        # A ``None`` sentinel avoids sharing one dict object between calls.
        kwargs = {} if kwargs is None else kwargs

        conn_spec = cluster.get_connection_spec()
        if kwargs.get('dsn'):
            # An explicit DSN replaces the cluster's own host entry.
            conn_spec.pop('host')
        conn_spec.update(kwargs)
        if not os.environ.get('PGHOST') and not kwargs.get('dsn'):
            conn_spec.setdefault('database', 'postgres')
            conn_spec.setdefault('user', 'postgres')
        return conn_spec

    @classmethod
    def get_connection_spec(cls, kwargs=None):
        """Return a multi-host spec listing the primary, then the standby.

        ``database`` and ``user`` come from the primary's spec; entries in
        *kwargs* are applied last and therefore win over computed values.
        """
        kwargs = {} if kwargs is None else kwargs

        primary_spec = cls.get_cluster_connection_spec(
            cls.master_cluster, kwargs
        )
        standby_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster, kwargs
        )
        return {
            'host': [primary_spec['host'], standby_spec['host']],
            'port': [primary_spec['port'], standby_spec['port']],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            **kwargs
        }

    @classmethod
    def connect_primary(cls, **kwargs):
        """Return a ``connect()`` awaitable targeting the primary cluster."""
        conn_spec = cls.get_cluster_connection_spec(cls.master_cluster, kwargs)
        return pg_connection.connect(**conn_spec, loop=cls.loop)

    @classmethod
    def connect_standby(cls, **kwargs):
        """Return a ``connect()`` awaitable targeting the standby cluster."""
        conn_spec = cls.get_cluster_connection_spec(
            cls.standby_cluster,
            kwargs
        )
        return pg_connection.connect(**conn_spec, loop=cls.loop)
class SessionAttribute(str, enum.Enum):
    """Accepted values of the ``target_session_attribute`` argument."""

    any = 'any'
    primary = 'primary'
    standby = 'standby'
    prefer_standby = 'prefer-standby'


def _accept_in_hot_standby(should_be_in_hot_standby: bool):
    """Build an async predicate that matches servers by hot-standby state.

    If the server didn't report "in_hot_standby" at startup, we must determine
    the state by checking "SELECT pg_catalog.pg_is_in_recovery()".

    :param bool should_be_in_hot_standby:
        The hot-standby state a connection must have to be accepted.
    :return: A coroutine function taking a connection and returning a bool.
    """
    async def can_be_used(connection):
        settings = connection.get_settings()
        hot_standby_status = getattr(settings, 'in_hot_standby', None)
        if hot_standby_status is not None:
            is_in_hot_standby = hot_standby_status == 'on'
        else:
            is_in_hot_standby = await connection.fetchval(
                "SELECT pg_catalog.pg_is_in_recovery()"
            )

        return is_in_hot_standby == should_be_in_hot_standby

    return can_be_used


async def _accept_any(_):
    """Accept every connection regardless of the server's role."""
    return True


# Maps each session attribute to the async predicate deciding whether a
# freshly opened connection satisfies it.  ``prefer_standby`` uses the
# standby check here; its fallback to any host is handled in ``_connect``.
target_attrs_check = {
    SessionAttribute.any: _accept_any,
    SessionAttribute.primary: _accept_in_hot_standby(False),
    SessionAttribute.standby: _accept_in_hot_standby(True),
    SessionAttribute.prefer_standby: _accept_in_hot_standby(True),
}


async def _can_use_connection(connection, attr: SessionAttribute):
    """Return whether *connection* satisfies the session attribute *attr*."""
    can_use = target_attrs_check[attr]
    return await can_use(connection)


async def _connect(*, loop, timeout, connection_class, record_class, **kwargs):
    """Connect to the first address matching ``target_session_attribute``.

    Addresses are tried in order.  Every connection that opens is kept as a
    candidate; the first one accepted by the attribute check is returned.
    For ``prefer-standby``, if no address was accepted, a random candidate
    is returned instead.  All candidates that are not returned are closed.

    :raises: The last connection error seen, if any; otherwise
        ``TargetServerAttributeNotMatched`` when no host qualified.
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    addrs, params, config = _parse_connect_arguments(timeout=timeout, **kwargs)
    target_attr = params.target_session_attribute

    candidates = []
    chosen_connection = None
    last_error = None
    for addr in addrs:
        before = time.monotonic()
        try:
            # NOTE(review): the ``params``/``config`` arguments sit in lines
            # elided between two diff hunks -- confirm against the full file.
            conn = await _connect_addr(
                addr=addr,
                loop=loop,
                timeout=timeout,
                params=params,
                config=config,
                connection_class=connection_class,
                record_class=record_class,
            )
            candidates.append(conn)
            if await _can_use_connection(conn, target_attr):
                chosen_connection = conn
                break
        except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
            last_error = ex
        finally:
            # The timeout budget is shared across all attempted addresses.
            timeout -= time.monotonic() - before
    else:
        # Reached only when no address was accepted (no ``break``).
        if target_attr == SessionAttribute.prefer_standby and candidates:
            chosen_connection = random.choice(candidates)

    # ``gather`` takes awaitables as positional arguments, so the generator
    # must be unpacked; passing it whole raises TypeError.
    await asyncio.gather(
        *(c.close() for c in candidates if c is not chosen_connection),
        return_exceptions=True
    )

    # Identity check: do not depend on the connection object's truthiness.
    if chosen_connection is not None:
        return chosen_connection

    raise last_error or exceptions.TargetServerAttributeNotMatched(
        'None of the hosts match the target attribute requirement '
        '{!r}'.format(target_attr)
    )
class TargetServerAttributeNotMatched(InternalClientError):
    """Could not find a host that satisfies the target attribute requirement.

    Raised when none of the candidate hosts matches the requested
    ``target_session_attribute``.
    """
class TestConnectionAttributes(tb.HotStandbyTestCase):
    """Exercise ``target_session_attribute`` on a primary/standby pair."""

    async def _run_connection_test(
        self, connect, target_attribute, expected_host
    ):
        """Connect with *target_attribute* and check the peer host."""
        conn = await connect(target_session_attribute=target_attribute)
        try:
            self.assertTrue(
                _get_connected_host(conn).startswith(expected_host))
        finally:
            # Close even when the assertion fails so nothing leaks.
            await conn.close()

    async def test_target_server_attribute_host(self):
        master_host = self.master_cluster.get_connection_spec()['host']
        standby_host = self.standby_cluster.get_connection_spec()['host']
        tests = [
            (self.connect_primary, 'primary', master_host),
            (self.connect_standby, 'standby', standby_host),
        ]

        for connect, target_attr, expected_host in tests:
            await self._run_connection_test(
                connect, target_attr, expected_host
            )

    async def test_target_attribute_not_matched(self):
        tests = [
            (self.connect_standby, 'primary'),
            (self.connect_primary, 'standby'),
        ]

        for connect, target_attr in tests:
            with self.assertRaises(exceptions.TargetServerAttributeNotMatched):
                con = await connect(target_session_attribute=target_attr)
                # Reached only if connect() wrongly succeeded: release the
                # connection before assertRaises reports the failure.
                await con.close()

    async def test_prefer_standby_when_standby_is_up(self):
        con = await self.connect(target_session_attribute='prefer-standby')
        try:
            standby_host = self.standby_cluster.get_connection_spec()['host']
            connected_host = _get_connected_host(con)
            self.assertTrue(connected_host.startswith(standby_host))
        finally:
            await con.close()

    async def test_prefer_standby_picks_master_when_standby_is_down(self):
        primary_spec = self.get_cluster_connection_spec(self.master_cluster)
        # The second host does not exist, so only the primary can connect.
        connection_spec = {
            'host': [
                primary_spec['host'],
                '/var/test/a/cluster/that/does/not/exist',
            ],
            'port': [primary_spec['port'], 12345],
            'database': primary_spec['database'],
            'user': primary_spec['user'],
            'target_session_attribute': 'prefer-standby'
        }

        con = await connection.connect(**connection_spec, loop=self.loop)
        try:
            master_host = self.master_cluster.get_connection_spec()['host']
            connected_host = _get_connected_host(con)
            self.assertTrue(connected_host.startswith(master_host))
        finally:
            await con.close()


def _get_connected_host(con):
    """Return the transport's ``peername`` for connection *con*."""
    return con._transport.get_extra_info('peername')
- - cls.standby_cluster = cls.new_cluster( - pg_cluster.HotStandbyCluster, - cluster_kwargs={ - 'master': conn_spec, - 'replication_user': 'replication' - } - ) - cls.start_cluster( - cls.standby_cluster, - server_settings={ - 'hot_standby': True - } - ) - - finally: - if con is not None: - cls.loop.run_until_complete(con.close()) - +class TestHotStandby(tb.HotStandbyTestCase): def create_pool(self, **kwargs): conn_spec = self.standby_cluster.get_connection_spec() conn_spec.update(kwargs)