From a4731c83938c9da95606b05431133fe87e96871f Mon Sep 17 00:00:00 2001 From: Fantix King Date: Sun, 21 Mar 2021 23:25:59 -0400 Subject: [PATCH 1/2] Add sslmode=allow support and fix =prefer retry We didn't really retry the connection without SSL if the first SSL connection fails under sslmode=prefer, that led to an issue when the server has SSL support but explicitly denies SSL connection through pg_hba.conf. This commit adds a retry in a new connection, which makes it easy to implement the sslmode=allow retry. Fixes #716 --- asyncpg/connect_utils.py | 142 ++++++++++++++++------- asyncpg/connection.py | 3 +- asyncpg/protocol/protocol.pxd | 2 + asyncpg/protocol/protocol.pyx | 10 ++ tests/test_connect.py | 208 +++++++++++++++++++++++++++++++--- 5 files changed, 308 insertions(+), 57 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index acfe87e4..e601f1d0 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -7,6 +7,7 @@ import asyncio import collections +import enum import functools import getpass import os @@ -28,6 +29,21 @@ from . import protocol +class SSLMode(enum.IntEnum): + disable = 0 + allow = 1 + prefer = 2 + require = 3 + verify_ca = 4 + verify_full = 5 + + @classmethod + def parse(cls, sslmode): + if isinstance(sslmode, cls): + return sslmode + return getattr(cls, sslmode.replace('-', '_')) + + _ConnectionParameters = collections.namedtuple( 'ConnectionParameters', [ @@ -35,7 +51,7 @@ 'password', 'database', 'ssl', - 'ssl_is_advisory', + 'sslmode', 'connect_timeout', 'server_settings', ]) @@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if ssl is None and have_tcp_addrs: ssl = 'prefer' - # ssl_is_advisory is only allowed to come from the sslmode parameter. - ssl_is_advisory = None - if isinstance(ssl, str): - SSLMODES = { - 'disable': 0, - 'allow': 1, - 'prefer': 2, - 'require': 3, - 'verify-ca': 4, - 'verify-full': 5, - } + if isinstance(ssl, (str, SSLMode)): try: - sslmode = SSLMODES[ssl] - except KeyError: - modes = ', '.join(SSLMODES.keys()) + sslmode = SSLMode.parse(ssl) + except AttributeError: + modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) raise exceptions.InterfaceError( '`sslmode` parameter must be one of: {}'.format(modes)) - # sslmode 'allow' is currently handled as 'prefer' because we're - # missing the "retry with SSL" behavior for 'allow', but do have the - # "retry without SSL" behavior for 'prefer'. - # Not changing 'allow' to 'prefer' here would be effectively the same - # as changing 'allow' to 'disable'. - if sslmode == SSLMODES['allow']: - sslmode = SSLMODES['prefer'] - # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html # Not implemented: sslcert & sslkey & sslrootcert & sslcrl params. - if sslmode <= SSLMODES['allow']: + if sslmode < SSLMode.allow: ssl = False - ssl_is_advisory = sslmode >= SSLMODES['allow'] else: ssl = ssl_module.create_default_context() - ssl.check_hostname = sslmode >= SSLMODES['verify-full'] + ssl.check_hostname = sslmode >= SSLMode.verify_full ssl.verify_mode = ssl_module.CERT_REQUIRED - if sslmode <= SSLMODES['require']: + if sslmode <= SSLMode.require: ssl.verify_mode = ssl_module.CERT_NONE - ssl_is_advisory = sslmode <= SSLMODES['prefer'] elif ssl is True: ssl = ssl_module.create_default_context() + sslmode = SSLMode.verify_full + else: + sslmode = SSLMode.disable if server_settings is not None and ( not isinstance(server_settings, dict) or @@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, params = _ConnectionParameters( user=user, password=password, database=database, ssl=ssl, - ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout, + sslmode=sslmode, connect_timeout=connect_timeout, server_settings=server_settings) return addrs, params @@ -520,9 +519,8 @@ def data_received(self, data): data == b'N'): # ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE, # since the only way to get ssl_is_advisory is from - # sslmode=prefer (or sslmode=allow). But be extra sure to - # disallow insecure connections when the ssl context asks for - # real security. + # sslmode=prefer. But be extra sure to disallow insecure + # connections when the ssl context asks for real security. self.on_data.set_result(False) else: self.on_data.set_exception( @@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *, new_tr = tr pg_proto = protocol_factory() + pg_proto.is_ssl = do_ssl_upgrade pg_proto.connection_made(new_tr) new_tr.set_protocol(pg_proto) @@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *, tr.close() try: - return await conn_factory(sock=sock) + new_tr, pg_proto = await conn_factory(sock=sock) + pg_proto.is_ssl = do_ssl_upgrade + return new_tr, pg_proto except (Exception, asyncio.CancelledError): sock.close() raise @@ -605,8 +606,6 @@ async def _connect_addr( if timeout <= 0: raise asyncio.TimeoutError - connected = _create_future(loop) - params_input = params if callable(params.password): if inspect.iscoroutinefunction(params.password): @@ -615,6 +614,44 @@ async def _connect_addr( password = params.password() params = params._replace(password=password) + args = (addr, loop, config, connection_class, record_class, params_input) + + # prepare the params (which attempt has ssl) for the 2 attempts + if params.sslmode == SSLMode.allow: + params_retry = params + params = params._replace(ssl=None) + elif params.sslmode == SSLMode.prefer: + params_retry = params._replace(ssl=None) + else: + # skip retry if we don't have to + return await __connect_addr(params, timeout, *args) + + # first attempt + before = time.monotonic() + try: + return await __connect_addr(params, timeout, *args) + except ConnectionError: + pass + + # second attempt + timeout -= time.monotonic() - before + if timeout <= 0: + raise asyncio.TimeoutError + else: + return await __connect_addr(params_retry, timeout, *args) + + +async def __connect_addr( + params, + timeout, + addr, + loop, + config, + connection_class, + record_class, + params_input, +): + connected = _create_future(loop) proto_factory = lambda: protocol.Protocol( addr, connected, params, record_class, loop) @@ -625,7 +662,7 @@ async def _connect_addr( elif params.ssl: connector = _create_ssl_connection( proto_factory, *addr, loop=loop, ssl_context=params.ssl, - ssl_is_advisory=params.ssl_is_advisory) + ssl_is_advisory=params.sslmode == SSLMode.prefer) else: connector = loop.create_connection(proto_factory, *addr) @@ -638,6 +675,34 @@ async def _connect_addr( if timeout <= 0: raise asyncio.TimeoutError await compat.wait_for(connected, timeout=timeout) + except ( + exceptions.InvalidAuthorizationSpecificationError, + exceptions.ConnectionDoesNotExistError, # seen on Windows + ): + tr.close() + + if ( + params.sslmode == SSLMode.allow and not pr.is_ssl or + params.sslmode == SSLMode.prefer and pr.is_ssl + ): + # Elevate the error to ConnectionError to trigger retry when: + # 1. First attempt with sslmode=allow, ssl=None failed + # 2. First attempt with sslmode=prefer, ssl=ctx failed while the + # server claimed to support SSL (returning "S" for SSLRequest) + # (likely because pg_hba.conf rejected the connection) + raise ConnectionError("Connection rejected trying {} SSL".format( + 'with' if pr.is_ssl else 'without')) + + else: + # but will NOT retry if: + # 1. First attempt with sslmode=prefer failed but the server + # doesn't support SSL (returning 'N' for SSLRequest), because + # we already tried to connect without SSL thru ssl_is_advisory + # 2. Second attempt with sslmode=prefer, ssl=None failed + # 3. Second attempt with sslmode=allow, ssl=ctx failed + # 4. Any other sslmode + raise + except (Exception, asyncio.CancelledError): tr.close() raise @@ -684,6 +749,7 @@ class CancelProto(asyncio.Protocol): def __init__(self): self.on_disconnect = _create_future(loop) + self.is_ssl = False def connection_lost(self, exc): if not self.on_disconnect.done(): @@ -692,13 +758,13 @@ def connection_lost(self, exc): if isinstance(addr, str): tr, pr = await loop.create_unix_connection(CancelProto, addr) else: - if params.ssl: + if params.ssl and params.sslmode != SSLMode.allow: tr, pr = await _create_ssl_connection( CancelProto, *addr, loop=loop, ssl_context=params.ssl, - ssl_is_advisory=params.ssl_is_advisory) + ssl_is_advisory=params.sslmode == SSLMode.prefer) else: tr, pr = await loop.create_connection( CancelProto, *addr) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 2e86fde0..043c6ddd 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1879,7 +1879,8 @@ async def connect(dsn=None, *, - ``'disable'`` - SSL is disabled (equivalent to ``False``) - ``'prefer'`` - try SSL first, fallback to non-SSL connection if SSL connection fails - - ``'allow'`` - currently equivalent to ``'prefer'`` + - ``'allow'`` - try without SSL first, then retry with SSL if the first + attempt fails. - ``'require'`` - only try an SSL connection. Certificate verification errors are ignored - ``'verify-ca'`` - only try an SSL connection, and verify diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index 772d6432..5f144e55 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol): readonly uint64_t queries_count + bint _is_ssl + PreparedStatementState statement cdef get_connection(self) diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index 4df256e6..3a1594a5 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol): self.queries_count = 0 + self._is_ssl = False + try: self.create_future = loop.create_future except AttributeError: @@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol): def resume_writing(self): self.writing_allowed.set() + @property + def is_ssl(self): + return self._is_ssl + + @is_ssl.setter + def is_ssl(self, value): + self._is_ssl = value + class Timer: def __init__(self, budget): diff --git a/tests/test_connect.py b/tests/test_connect.py index 5adb977d..7b08f93d 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -24,6 +24,7 @@ from asyncpg import connect_utils from asyncpg import cluster as pg_cluster from asyncpg import exceptions +from asyncpg.connect_utils import SSLMode from asyncpg.serverversion import split_server_version_string _system = platform.uname().system @@ -308,6 +309,7 @@ class TestConnectParams(tb.TestCase): TESTS = [ { + 'name': 'all_env_default_ssl', 'env': { 'PGUSER': 'user', 'PGDATABASE': 'testdb', @@ -320,10 +322,11 @@ class TestConnectParams(tb.TestCase): 'password': 'passw', 'database': 'testdb', 'ssl': True, - 'ssl_is_advisory': True}) + 'sslmode': SSLMode.prefer}) }, { + 'name': 'params_override_env', 'env': { 'PGUSER': 'user', 'PGDATABASE': 'testdb', @@ -345,6 +348,56 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'params_override_env_and_dsn', + 'env': { + 'PGUSER': 'user', + 'PGDATABASE': 'testdb', + 'PGPASSWORD': 'passw', + 'PGHOST': 'host', + 'PGPORT': '123', + 'PGSSLMODE': 'allow' + }, + + 'dsn': 'postgres://user3:123123@localhost/abcdef', + + 'host': 'host2', + 'port': '456', + 'user': 'user2', + 'password': 'passw2', + 'database': 'db2', + 'ssl': False, + + 'result': ([('host2', 456)], { + 'user': 'user2', + 'password': 'passw2', + 'database': 'db2', + 'sslmode': SSLMode.disable, + 'ssl': False}) + }, + + { + 'name': 'dsn_overrides_env_partially', + 'env': { + 'PGUSER': 'user', + 'PGDATABASE': 'testdb', + 'PGPASSWORD': 'passw', + 'PGHOST': 'host', + 'PGPORT': '123', + 'PGSSLMODE': 'allow' + }, + + 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', + + 'result': ([('localhost', 5555)], { + 'user': 'user3', + 'password': '123123', + 'database': 'abcdef', + 'ssl': True, + 'sslmode': SSLMode.allow}) + }, + + { + 'name': 'params_override_env_and_dsn_ssl_prefer', 'env': { 'PGUSER': 'user', 'PGDATABASE': 'testdb', @@ -367,10 +420,12 @@ class TestConnectParams(tb.TestCase): 'user': 'user2', 'password': 'passw2', 'database': 'db2', + 'sslmode': SSLMode.disable, 'ssl': False}) }, { + 'name': 'dsn_overrides_env_partially_ssl_prefer', 'env': { 'PGUSER': 'user', 'PGDATABASE': 'testdb', @@ -387,10 +442,11 @@ class TestConnectParams(tb.TestCase): 'password': '123123', 'database': 'abcdef', 'ssl': True, - 'ssl_is_advisory': True}) + 'sslmode': SSLMode.prefer}) }, { + 'name': 'dsn_only', 'dsn': 'postgres://user3:123123@localhost:5555/abcdef', 'result': ([('localhost', 5555)], { 'user': 'user3', @@ -399,6 +455,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_multi_host', 'dsn': 'postgresql://user@host1,host2/db', 'result': ([('host1', 5432), ('host2', 5432)], { 'database': 'db', @@ -407,6 +464,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_multi_host_and_port', 'dsn': 'postgresql://user@host1:1111,host2:2222/db', 'result': ([('host1', 1111), ('host2', 2222)], { 'database': 'db', @@ -415,6 +473,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_combines_env_multi_host', 'env': { 'PGHOST': 'host1:1111,host2:2222', 'PGUSER': 'foo', @@ -427,6 +486,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_multi_host_combines_env', 'env': { 'PGUSER': 'foo', }, @@ -438,6 +498,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'params_multi_host_dsn_env_mix', 'env': { 'PGUSER': 'foo', }, @@ -450,6 +511,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'params_combine_dsn_settings_override_and_ssl', 'dsn': 'postgresql://user3:123123@localhost:5555/' 'abcdef?param=sss¶m=123&host=testhost&user=testuser' '&port=2222&database=testdb&sslmode=require', @@ -464,10 +526,11 @@ class TestConnectParams(tb.TestCase): 'password': 'ask', 'database': 'db', 'ssl': True, - 'ssl_is_advisory': False}) + 'sslmode': SSLMode.require}) }, { + 'name': 'params_settings_and_ssl_override_dsn', 'dsn': 'postgresql://user3:123123@localhost:5555/' 'abcdef?param=sss¶m=123&host=testhost&user=testuser' '&port=2222&database=testdb&sslmode=disable', @@ -483,10 +546,12 @@ class TestConnectParams(tb.TestCase): 'user': 'me', 'password': 'ask', 'database': 'db', + 'sslmode': SSLMode.verify_full, 'ssl': True}) }, { + 'name': 'dsn_only_unix', 'dsn': 'postgresql:///dbname?host=/unix_sock/test&user=spam', 'result': ([os.path.join('/unix_sock/test', '.s.PGSQL.5432')], { 'user': 'spam', @@ -494,6 +559,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_quoted', 'dsn': 'postgresql://us%40r:p%40ss@h%40st1,h%40st2:543%33/d%62', 'result': ( [('h@st1', 5432), ('h@st2', 5433)], @@ -506,6 +572,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_unquoted_host', 'dsn': 'postgresql://user:p@ss@host/db', 'result': ( [('ss@host', 5432)], @@ -518,6 +585,7 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_quoted_params', 'dsn': 'postgresql:///d%62?user=us%40r&host=h%40st&port=543%33', 'result': ( [('h@st', 5433)], @@ -529,10 +597,12 @@ class TestConnectParams(tb.TestCase): }, { + 'name': 'dsn_only_illegal_protocol', 'dsn': 'pq:///dbname?host=/unix_sock/test&user=spam', 'error': (ValueError, 'invalid DSN') }, { + 'name': 'dsn_params_ports_mismatch_dsn_multi_hosts', 'dsn': 'postgresql://host1,host2,host3/db', 'port': [111, 222], 'error': ( @@ -541,17 +611,20 @@ class TestConnectParams(tb.TestCase): ) }, { + 'name': 'dsn_only_quoted_unix_host_port_in_params', 'dsn': 'postgres://user@?port=56226&host=%2Ftmp', 'result': ( [os.path.join('/tmp', '.s.PGSQL.56226')], { 'user': 'user', 'database': 'user', + 'sslmode': SSLMode.disable, 'ssl': None } ) }, { + 'name': 'dsn_only_cloudsql', 'dsn': 'postgres:///db?host=/cloudsql/' 'project:region:instance-name&user=spam', 'result': ( @@ -565,6 +638,7 @@ class TestConnectParams(tb.TestCase): ) }, { + 'name': 'dsn_only_cloudsql_unix_and_tcp', 'dsn': 'postgres:///db?host=127.0.0.1:5432,/cloudsql/' 'project:region:instance-name,localhost:5433&user=spam', 'result': ( @@ -579,7 +653,7 @@ class TestConnectParams(tb.TestCase): 'user': 'spam', 'database': 'db', 'ssl': True, - 'ssl_is_advisory': True + 'sslmode': SSLMode.prefer, } ) }, @@ -663,7 +737,7 @@ def run_testcase(self, testcase): # Avoid the hassle of specifying the default SSL mode # unless explicitly tested for. params.pop('ssl', None) - params.pop('ssl_is_advisory', None) + params.pop('sslmode', None) self.assertEqual(expected, result, 'Testcase: {}'.format(testcase)) @@ -1050,6 +1124,7 @@ async def verify_works(sslmode): dsn='postgresql://foo/?sslmode=' + sslmode, host='localhost') self.assertEqual(await con.fetchval('SELECT 42'), 42) + self.assertFalse(con._protocol.is_ssl) finally: if con: await con.close() @@ -1058,7 +1133,7 @@ async def verify_fails(sslmode): con = None try: with self.assertRaises(ConnectionError): - await self.connect( + con = await self.connect( dsn='postgresql://foo/?sslmode=' + sslmode, host='localhost') await con.fetchval('SELECT 42') @@ -1082,8 +1157,7 @@ async def test_connection_implicit_host(self): await con.close() -@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') -class TestSSLConnection(tb.ConnectedTestCase): +class BaseTestSSLConnection(tb.ConnectedTestCase): @classmethod def get_server_settings(cls): conf = super().get_server_settings() @@ -1109,15 +1183,7 @@ def setUp(self): create_script = [] create_script.append('CREATE ROLE ssl_user WITH LOGIN;') - self.cluster.add_hba_entry( - type='hostssl', address=ipaddress.ip_network('127.0.0.0/24'), - database='postgres', user='ssl_user', - auth_method='trust') - - self.cluster.add_hba_entry( - type='hostssl', address=ipaddress.ip_network('::1/128'), - database='postgres', user='ssl_user', - auth_method='trust') + self._add_hba_entry() # Put hba changes into effect self.cluster.reload() @@ -1136,6 +1202,23 @@ def tearDown(self): super().tearDown() + def _add_hba_entry(self): + raise NotImplementedError() + + +@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') +class TestSSLConnection(BaseTestSSLConnection): + def _add_hba_entry(self): + self.cluster.add_hba_entry( + type='hostssl', address=ipaddress.ip_network('127.0.0.0/24'), + database='postgres', user='ssl_user', + auth_method='trust') + + self.cluster.add_hba_entry( + type='hostssl', address=ipaddress.ip_network('::1/128'), + database='postgres', user='ssl_user', + auth_method='trust') + async def test_ssl_connection_custom_context(self): ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23) ssl_context.load_verify_locations(SSL_CA_CERT_FILE) @@ -1164,6 +1247,7 @@ async def verify_works(sslmode, *, host='localhost'): host=host, user='ssl_user') self.assertEqual(await con.fetchval('SELECT 42'), 42) + self.assertTrue(con._protocol.is_ssl) finally: if con: await con.close() @@ -1176,7 +1260,7 @@ async def verify_fails(sslmode, *, host='localhost', try: self.loop.set_exception_handler(lambda *args: None) with self.assertRaises(exn_type): - await self.connect( + con = await self.connect( dsn='postgresql://foo/?sslmode=' + sslmode, host=host, user='ssl_user') @@ -1272,6 +1356,94 @@ async def test_executemany_uvloop_ssl_issue_700(self): await con.close() +@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster') +class TestNoSSLConnection(BaseTestSSLConnection): + def _add_hba_entry(self): + self.cluster.add_hba_entry( + type='hostnossl', address=ipaddress.ip_network('127.0.0.0/24'), + database='postgres', user='ssl_user', + auth_method='trust') + + self.cluster.add_hba_entry( + type='hostnossl', address=ipaddress.ip_network('::1/128'), + database='postgres', user='ssl_user', + auth_method='trust') + + async def test_nossl_connection_sslmode(self): + async def verify_works(sslmode, *, host='localhost'): + con = None + try: + con = await self.connect( + dsn='postgresql://foo/?sslmode=' + sslmode, + host=host, + user='ssl_user') + self.assertEqual(await con.fetchval('SELECT 42'), 42) + self.assertFalse(con._protocol.is_ssl) + finally: + if con: + await con.close() + + async def verify_fails(sslmode, *, host='localhost', + exn_type=ssl.SSLError): + # XXX: uvloop artifact + old_handler = self.loop.get_exception_handler() + con = None + try: + self.loop.set_exception_handler(lambda *args: None) + with self.assertRaises(exn_type): + con = await self.connect( + dsn='postgresql://foo/?sslmode=' + sslmode, + host=host, + user='ssl_user') + await con.fetchval('SELECT 42') + finally: + if con: + await con.close() + self.loop.set_exception_handler(old_handler) + + invalid_auth_err = asyncpg.InvalidAuthorizationSpecificationError + await verify_works('disable') + await verify_works('allow') + await verify_works('prefer') + await verify_fails('require', exn_type=invalid_auth_err) + await verify_fails('verify-ca') + await verify_fails('verify-full') + + async def test_nossl_connection_prefer_cancel(self): + con = await self.connect( + dsn='postgresql://foo/?sslmode=prefer', + host='localhost', + user='ssl_user') + self.assertFalse(con._protocol.is_ssl) + with self.assertRaises(asyncio.TimeoutError): + await con.execute('SELECT pg_sleep(5)', timeout=0.5) + val = await con.fetchval('SELECT 123') + self.assertEqual(val, 123) + + async def test_nossl_connection_pool(self): + pool = await self.create_pool( + host='localhost', + user='ssl_user', + database='postgres', + min_size=5, + max_size=10, + ssl='prefer') + + async def worker(): + async with pool.acquire() as con: + self.assertFalse(con._protocol.is_ssl) + self.assertEqual(await con.fetchval('SELECT 42'), 42) + + with self.assertRaises(asyncio.TimeoutError): + await con.execute('SELECT pg_sleep(5)', timeout=0.5) + + self.assertEqual(await con.fetchval('SELECT 43'), 43) + + tasks = [worker() for _ in range(100)] + await asyncio.gather(*tasks) + await pool.close() + + class TestConnectionGC(tb.ClusterTestCase): async def _run_no_explicit_close_test(self): From 25323fdf72e7c7cf91ff8143e3c898446d19a4e8 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Tue, 23 Mar 2021 23:18:38 -0400 Subject: [PATCH 2/2] CRF: use internal exception to mark the retry --- asyncpg/connect_utils.py | 22 ++++++++++++++-------- 1 file changed, 14 insertions(+), 8 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index e601f1d0..3fd64252 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -624,13 +624,13 @@ async def _connect_addr( params_retry = params._replace(ssl=None) else: # skip retry if we don't have to - return await __connect_addr(params, timeout, *args) + return await __connect_addr(params, timeout, False, *args) # first attempt before = time.monotonic() try: - return await __connect_addr(params, timeout, *args) - except ConnectionError: + return await __connect_addr(params, timeout, True, *args) + except _Retry: pass # second attempt @@ -638,12 +638,17 @@ async def _connect_addr( if timeout <= 0: raise asyncio.TimeoutError else: - return await __connect_addr(params_retry, timeout, *args) + return await __connect_addr(params_retry, timeout, False, *args) + + +class _Retry(Exception): + pass async def __connect_addr( params, timeout, + retry, addr, loop, config, @@ -681,17 +686,18 @@ async def __connect_addr( ): tr.close() - if ( + # retry=True here is a redundant check because we don't want to + # accidentally raise the internal _Retry to the outer world + if retry and ( params.sslmode == SSLMode.allow and not pr.is_ssl or params.sslmode == SSLMode.prefer and pr.is_ssl ): - # Elevate the error to ConnectionError to trigger retry when: + # Trigger retry when: # 1. First attempt with sslmode=allow, ssl=None failed # 2. First attempt with sslmode=prefer, ssl=ctx failed while the # server claimed to support SSL (returning "S" for SSLRequest) # (likely because pg_hba.conf rejected the connection) - raise ConnectionError("Connection rejected trying {} SSL".format( - 'with' if pr.is_ssl else 'without')) + raise _Retry() else: # but will NOT retry if: