Skip to content

Add sslmode=allow support and fix =prefer retry #720

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 110 additions & 38 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import asyncio
import collections
import enum
import functools
import getpass
import os
Expand All @@ -28,14 +29,29 @@
from . import protocol


class SSLMode(enum.IntEnum):
disable = 0
allow = 1
prefer = 2
require = 3
verify_ca = 4
verify_full = 5

@classmethod
def parse(cls, sslmode):
if isinstance(sslmode, cls):
return sslmode
return getattr(cls, sslmode.replace('-', '_'))


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
'user',
'password',
'database',
'ssl',
'ssl_is_advisory',
'sslmode',
'connect_timeout',
'server_settings',
])
Expand Down Expand Up @@ -402,46 +418,29 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if ssl is None and have_tcp_addrs:
ssl = 'prefer'

# ssl_is_advisory is only allowed to come from the sslmode parameter.
ssl_is_advisory = None
if isinstance(ssl, str):
SSLMODES = {
'disable': 0,
'allow': 1,
'prefer': 2,
'require': 3,
'verify-ca': 4,
'verify-full': 5,
}
if isinstance(ssl, (str, SSLMode)):
try:
sslmode = SSLMODES[ssl]
except KeyError:
modes = ', '.join(SSLMODES.keys())
sslmode = SSLMode.parse(ssl)
except AttributeError:
modes = ', '.join(m.name.replace('_', '-') for m in SSLMode)
raise exceptions.InterfaceError(
'`sslmode` parameter must be one of: {}'.format(modes))

# sslmode 'allow' is currently handled as 'prefer' because we're
# missing the "retry with SSL" behavior for 'allow', but do have the
# "retry without SSL" behavior for 'prefer'.
# Not changing 'allow' to 'prefer' here would be effectively the same
# as changing 'allow' to 'disable'.
if sslmode == SSLMODES['allow']:
sslmode = SSLMODES['prefer']

# docs at https://www.postgresql.org/docs/10/static/libpq-connect.html
# Not implemented: sslcert & sslkey & sslrootcert & sslcrl params.
if sslmode <= SSLMODES['allow']:
if sslmode < SSLMode.allow:
ssl = False
ssl_is_advisory = sslmode >= SSLMODES['allow']
else:
ssl = ssl_module.create_default_context()
ssl.check_hostname = sslmode >= SSLMODES['verify-full']
ssl.check_hostname = sslmode >= SSLMode.verify_full
ssl.verify_mode = ssl_module.CERT_REQUIRED
if sslmode <= SSLMODES['require']:
if sslmode <= SSLMode.require:
ssl.verify_mode = ssl_module.CERT_NONE
ssl_is_advisory = sslmode <= SSLMODES['prefer']
elif ssl is True:
ssl = ssl_module.create_default_context()
sslmode = SSLMode.verify_full
else:
sslmode = SSLMode.disable

if server_settings is not None and (
not isinstance(server_settings, dict) or
Expand All @@ -453,7 +452,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
ssl_is_advisory=ssl_is_advisory, connect_timeout=connect_timeout,
sslmode=sslmode, connect_timeout=connect_timeout,
server_settings=server_settings)

return addrs, params
Expand Down Expand Up @@ -520,9 +519,8 @@ def data_received(self, data):
data == b'N'):
# ssl_is_advisory will imply that ssl.verify_mode == CERT_NONE,
# since the only way to get ssl_is_advisory is from
# sslmode=prefer (or sslmode=allow). But be extra sure to
# disallow insecure connections when the ssl context asks for
# real security.
# sslmode=prefer. But be extra sure to disallow insecure
# connections when the ssl context asks for real security.
self.on_data.set_result(False)
else:
self.on_data.set_exception(
Expand Down Expand Up @@ -566,6 +564,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
new_tr = tr

pg_proto = protocol_factory()
pg_proto.is_ssl = do_ssl_upgrade
pg_proto.connection_made(new_tr)
new_tr.set_protocol(pg_proto)

Expand All @@ -584,7 +583,9 @@ async def _create_ssl_connection(protocol_factory, host, port, *,
tr.close()

try:
return await conn_factory(sock=sock)
new_tr, pg_proto = await conn_factory(sock=sock)
pg_proto.is_ssl = do_ssl_upgrade
return new_tr, pg_proto
except (Exception, asyncio.CancelledError):
sock.close()
raise
Expand All @@ -605,8 +606,6 @@ async def _connect_addr(
if timeout <= 0:
raise asyncio.TimeoutError

connected = _create_future(loop)

params_input = params
if callable(params.password):
if inspect.iscoroutinefunction(params.password):
Expand All @@ -615,6 +614,49 @@ async def _connect_addr(
password = params.password()

params = params._replace(password=password)
args = (addr, loop, config, connection_class, record_class, params_input)

# prepare the params (which attempt has ssl) for the 2 attempts
if params.sslmode == SSLMode.allow:
params_retry = params
params = params._replace(ssl=None)
elif params.sslmode == SSLMode.prefer:
params_retry = params._replace(ssl=None)
else:
# skip retry if we don't have to
return await __connect_addr(params, timeout, False, *args)

# first attempt
before = time.monotonic()
try:
return await __connect_addr(params, timeout, True, *args)
except _Retry:
pass

# second attempt
timeout -= time.monotonic() - before
if timeout <= 0:
raise asyncio.TimeoutError
else:
return await __connect_addr(params_retry, timeout, False, *args)


class _Retry(Exception):
pass


async def __connect_addr(
params,
timeout,
retry,
addr,
loop,
config,
connection_class,
record_class,
params_input,
):
connected = _create_future(loop)

proto_factory = lambda: protocol.Protocol(
addr, connected, params, record_class, loop)
Expand All @@ -625,7 +667,7 @@ async def _connect_addr(
elif params.ssl:
connector = _create_ssl_connection(
proto_factory, *addr, loop=loop, ssl_context=params.ssl,
ssl_is_advisory=params.ssl_is_advisory)
ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
connector = loop.create_connection(proto_factory, *addr)

Expand All @@ -638,6 +680,35 @@ async def _connect_addr(
if timeout <= 0:
raise asyncio.TimeoutError
await compat.wait_for(connected, timeout=timeout)
except (
exceptions.InvalidAuthorizationSpecificationError,
exceptions.ConnectionDoesNotExistError, # seen on Windows
):
tr.close()

# retry=True here is a redundant check because we don't want to
# accidentally raise the internal _Retry to the outer world
if retry and (
params.sslmode == SSLMode.allow and not pr.is_ssl or
params.sslmode == SSLMode.prefer and pr.is_ssl
):
# Trigger retry when:
# 1. First attempt with sslmode=allow, ssl=None failed
# 2. First attempt with sslmode=prefer, ssl=ctx failed while the
# server claimed to support SSL (returning "S" for SSLRequest)
# (likely because pg_hba.conf rejected the connection)
raise _Retry()

else:
# but will NOT retry if:
# 1. First attempt with sslmode=prefer failed but the server
# doesn't support SSL (returning 'N' for SSLRequest), because
# we already tried to connect without SSL thru ssl_is_advisory
# 2. Second attempt with sslmode=prefer, ssl=None failed
# 3. Second attempt with sslmode=allow, ssl=ctx failed
# 4. Any other sslmode
raise

except (Exception, asyncio.CancelledError):
tr.close()
raise
Expand Down Expand Up @@ -684,6 +755,7 @@ class CancelProto(asyncio.Protocol):

def __init__(self):
self.on_disconnect = _create_future(loop)
self.is_ssl = False

def connection_lost(self, exc):
if not self.on_disconnect.done():
Expand All @@ -692,13 +764,13 @@ def connection_lost(self, exc):
if isinstance(addr, str):
tr, pr = await loop.create_unix_connection(CancelProto, addr)
else:
if params.ssl:
if params.ssl and params.sslmode != SSLMode.allow:
tr, pr = await _create_ssl_connection(
CancelProto,
*addr,
loop=loop,
ssl_context=params.ssl,
ssl_is_advisory=params.ssl_is_advisory)
ssl_is_advisory=params.sslmode == SSLMode.prefer)
else:
tr, pr = await loop.create_connection(
CancelProto, *addr)
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1879,7 +1879,8 @@ async def connect(dsn=None, *,
- ``'disable'`` - SSL is disabled (equivalent to ``False``)
- ``'prefer'`` - try SSL first, fallback to non-SSL connection
if SSL connection fails
- ``'allow'`` - currently equivalent to ``'prefer'``
- ``'allow'`` - try without SSL first, then retry with SSL if the first
attempt fails.
- ``'require'`` - only try an SSL connection. Certificate
verification errors are ignored
- ``'verify-ca'`` - only try an SSL connection, and verify
Expand Down
2 changes: 2 additions & 0 deletions asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ cdef class BaseProtocol(CoreProtocol):

readonly uint64_t queries_count

bint _is_ssl

PreparedStatementState statement

cdef get_connection(self)
Expand Down
10 changes: 10 additions & 0 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ cdef class BaseProtocol(CoreProtocol):

self.queries_count = 0

self._is_ssl = False

try:
self.create_future = loop.create_future
except AttributeError:
Expand Down Expand Up @@ -943,6 +945,14 @@ cdef class BaseProtocol(CoreProtocol):
def resume_writing(self):
self.writing_allowed.set()

@property
def is_ssl(self):
return self._is_ssl

@is_ssl.setter
def is_ssl(self, value):
self._is_ssl = value


class Timer:
def __init__(self, budget):
Expand Down
Loading