Skip to content

Handle environments without home dir #1011

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions asyncpg/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import asyncio
import pathlib
import platform
import typing


SYSTEM = platform.uname().system
Expand All @@ -18,7 +19,7 @@

CSIDL_APPDATA = 0x001a

def get_pg_home_directory() -> pathlib.Path:
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
# We cannot simply use expanduser() as that returns the user's
# home directory, whereas Postgres stores its config in
# %AppData% on Windows.
Expand All @@ -30,8 +31,11 @@ def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path(buf.value) / 'postgresql'

else:
def get_pg_home_directory() -> pathlib.Path:
return pathlib.Path.home()
def get_pg_home_directory() -> typing.Optional[pathlib.Path]:
try:
return pathlib.Path.home()
except (RuntimeError, KeyError):
return None


async def wait_closed(stream):
Expand Down
49 changes: 32 additions & 17 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,13 @@ def _parse_tls_version(tls_version):
)


def _dot_postgresql_path(filename) -> pathlib.Path:
return (pathlib.Path.home() / '.postgresql' / filename).resolve()
def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
try:
homedir = pathlib.Path.home()
except (RuntimeError, KeyError):
return None

return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
Expand Down Expand Up @@ -501,11 +506,16 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl.load_verify_locations(cafile=sslrootcert)
ssl.verify_mode = ssl_module.CERT_REQUIRED
else:
sslrootcert = _dot_postgresql_path('root.crt')
try:
sslrootcert = _dot_postgresql_path('root.crt')
assert sslrootcert is not None
ssl.load_verify_locations(cafile=sslrootcert)
except FileNotFoundError:
except (AssertionError, FileNotFoundError):
if sslmode > SSLMode.require:
if sslrootcert is None:
raise RuntimeError(
'Cannot determine home directory'
)
raise ValueError(
f'root certificate file "{sslrootcert}" does '
f'not exist\nEither provide the file or '
Expand All @@ -526,18 +536,20 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
else:
sslcrl = _dot_postgresql_path('root.crl')
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
pass
else:
ssl.verify_flags |= ssl_module.VERIFY_CRL_CHECK_CHAIN
if sslcrl is not None:
try:
ssl.load_verify_locations(cafile=sslcrl)
except FileNotFoundError:
pass
else:
ssl.verify_flags |= \
ssl_module.VERIFY_CRL_CHECK_CHAIN

if sslkey is None:
sslkey = os.getenv('PGSSLKEY')
if not sslkey:
sslkey = _dot_postgresql_path('postgresql.key')
if not sslkey.exists():
if sslkey is not None and not sslkey.exists():
sslkey = None
if not sslpassword:
sslpassword = ''
Expand All @@ -549,12 +561,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
)
else:
sslcert = _dot_postgresql_path('postgresql.crt')
try:
ssl.load_cert_chain(
sslcert, keyfile=sslkey, password=lambda: sslpassword
)
except FileNotFoundError:
pass
if sslcert is not None:
try:
ssl.load_cert_chain(
sslcert,
keyfile=sslkey,
password=lambda: sslpassword
)
except FileNotFoundError:
pass

# OpenSSL 1.1.1 keylog file, copied from create_default_context()
if hasattr(ssl, 'keylog_filename'):
Expand Down
29 changes: 29 additions & 0 deletions tests/test_connect.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@ def mock_dot_postgresql(*, ca=True, crl=False, client=False, protected=False):
yield


@contextlib.contextmanager
def mock_no_home_dir():
with unittest.mock.patch(
'pathlib.Path.home', unittest.mock.Mock(side_effect=RuntimeError)
):
yield


class TestSettings(tb.ConnectedTestCase):

async def test_get_settings_01(self):
Expand Down Expand Up @@ -1257,6 +1265,27 @@ async def test_connection_implicit_host(self):
user=conn_spec.get('user'))
await con.close()

@unittest.skipIf(os.environ.get('PGHOST'), 'unmanaged cluster')
async def test_connection_no_home_dir(self):
with mock_no_home_dir():
con = await self.connect(
dsn='postgresql://foo/',
user='postgres',
database='postgres',
host='localhost')
await con.fetchval('SELECT 42')
await con.close()

with self.assertRaisesRegex(
RuntimeError,
'Cannot determine home directory'
):
with mock_no_home_dir():
await self.connect(
host='localhost',
user='ssl_user',
ssl='verify-full')


class BaseTestSSLConnection(tb.ConnectedTestCase):
@classmethod
Expand Down