Skip to content

Commit 575227f

Browse files
committed
Add support for reading passwords from .pgpass
This largely mirrors libpq's behaviour with respect to ~/.pgpass. Fixes: #267.
1 parent 2f558c2 commit 575227f

File tree

5 files changed

+368
-11
lines changed

5 files changed

+368
-11
lines changed

asyncpg/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,4 @@
1515
__all__ = ('connect', 'create_pool', 'Record', 'Connection') + \
1616
exceptions.__all__ # NOQA
1717

18-
__version__ = '0.15.0'
18+
__version__ = '0.16.0.dev0'

asyncpg/compat.py

+21
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,13 @@
77

88
import functools
99
import os
10+
import pathlib
11+
import platform
1012
import sys
1113

1214

1315
PY_36 = sys.version_info >= (3, 6)
16+
SYSTEM = platform.uname().system
1417

1518

1619
if sys.version_info < (3, 5, 2):
@@ -45,3 +48,21 @@ def fspath(path):
4548
type(path).__name__
4649
)
4750
)
51+
52+
53+
if SYSTEM == 'Windows':
54+
import ctypes.wintypes
55+
56+
CSIDL_APPDATA = 0x001a
57+
58+
def get_pg_home_directory() -> pathlib.Path:
59+
buf = ctypes.create_unicode_buffer(ctypes.wintypes.MAX_PATH)
60+
r = ctypes.windll.shell32.SHGetFolderPathW(0, CSIDL_APPDATA, 0, 0, buf)
61+
if not r:
62+
return None
63+
else:
64+
return pathlib.Path(buf.value)
65+
66+
else:
67+
def get_pg_home_directory() -> pathlib.Path:
68+
return pathlib.Path.home()

asyncpg/connect_utils.py

+126-6
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,18 @@
99
import collections
1010
import getpass
1111
import os
12+
import pathlib
1213
import platform
14+
import re
1315
import socket
16+
import stat
1417
import struct
1518
import time
19+
import typing
1620
import urllib.parse
21+
import warnings
1722

23+
from . import compat
1824
from . import exceptions
1925
from . import protocol
2026

@@ -44,9 +50,92 @@
4450
_system = platform.uname().system
4551

4652

53+
if _system == 'Windows':
54+
PGPASSFILE = 'pgpass.conf'
55+
else:
56+
PGPASSFILE = '.pgpass'
57+
58+
59+
def _read_password_file(passfile: pathlib.Path) \
60+
-> typing.List[typing.Tuple[str, ...]]:
61+
62+
if not passfile.is_file():
63+
warnings.warn(
64+
'password file {!r} is not a plain file'.format(passfile))
65+
66+
return None
67+
68+
if _system != 'Windows':
69+
if passfile.stat().st_mode & (stat.S_IRWXG | stat.S_IRWXO):
70+
warnings.warn(
71+
'password file {!r} has group or world access; '
72+
'permissions should be u=rw (0600) or less'.format(passfile))
73+
74+
return None
75+
76+
passtab = []
77+
78+
try:
79+
with passfile.open('rt') as f:
80+
for line in f:
81+
line = line.strip()
82+
if not line or line.startswith('#'):
83+
# Skip empty lines and comments.
84+
continue
85+
# Backslash escapes both itself and the colon,
86+
# which is a record separator.
87+
line = line.replace(R'\\', '\n')
88+
passtab.append(tuple(
89+
p.replace('\n', R'\\')
90+
for p in re.split(r'(?<!\\):', line, maxsplit=4)
91+
))
92+
except IOError:
93+
pass
94+
95+
return passtab
96+
97+
98+
def _read_password_from_pgpass(
99+
*, passfile: typing.Optional[pathlib.Path],
100+
hosts: typing.List[typing.Union[str, typing.Tuple[str, int]]],
101+
port: int, database: str, user: str):
102+
"""Parse the pgpass file and return the matching password.
103+
104+
:return:
105+
Password string, if found, ``None`` otherwise.
106+
"""
107+
108+
if not passfile.exists():
109+
return None
110+
111+
passtab = _read_password_file(passfile)
112+
if not passtab:
113+
return None
114+
115+
for host in hosts:
116+
if host.startswith('/'):
117+
# Unix sockets get normalized into 'localhost'
118+
host = 'localhost'
119+
120+
for phost, pport, pdatabase, puser, ppassword in passtab:
121+
if phost != '*' and phost != host:
122+
continue
123+
if pport != '*' and pport != str(port):
124+
continue
125+
if pdatabase != '*' and pdatabase != database:
126+
continue
127+
if puser != '*' and puser != user:
128+
continue
129+
130+
# Found a match.
131+
return ppassword
132+
133+
return None
134+
135+
47136
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
48-
password, database, ssl, connect_timeout,
49-
server_settings):
137+
password, passfile, database, ssl,
138+
connect_timeout, server_settings):
50139
if host is not None and not isinstance(host, str):
51140
raise TypeError(
52141
'host argument is expected to be str, got {!r}'.format(
@@ -113,6 +202,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
113202
if password is None:
114203
password = val
115204

205+
if 'passfile' in query:
206+
val = query.pop('passfile')
207+
if passfile is None:
208+
passfile = val
209+
116210
if query:
117211
if server_settings is None:
118212
server_settings = query
@@ -123,10 +217,14 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
123217
# https://www.postgresql.org/docs/current/static/libpq-envars.html
124218
# Note that env values may be an empty string in cases when
125219
# the variable is "unset" by setting it to an empty value
126-
#
220+
# `auth_hosts` is the version of host information for the purposes
221+
# of reading the pgpass file.
222+
auth_hosts = None
127223
if host is None:
128224
host = os.getenv('PGHOST')
129225
if not host:
226+
auth_hosts = ['localhost']
227+
130228
if _system == 'Windows':
131229
host = ['localhost']
132230
else:
@@ -137,6 +235,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
137235
if not isinstance(host, list):
138236
host = [host]
139237

238+
if auth_hosts is None:
239+
auth_hosts = host
240+
140241
if port is None:
141242
port = os.getenv('PGPORT')
142243
if port:
@@ -168,6 +269,24 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
168269
raise exceptions.InterfaceError(
169270
'could not determine database name to connect to')
170271

272+
if password is None:
273+
if passfile is None:
274+
passfile = os.getenv('PGPASSFILE')
275+
276+
if passfile is None:
277+
homedir = compat.get_pg_home_directory()
278+
if homedir:
279+
passfile = homedir / PGPASSFILE
280+
else:
281+
passfile = None
282+
else:
283+
passfile = pathlib.Path(passfile)
284+
285+
if passfile is not None:
286+
password = _read_password_from_pgpass(
287+
hosts=auth_hosts, port=port, database=database, user=user,
288+
passfile=passfile)
289+
171290
addrs = []
172291
for h in host:
173292
if h.startswith('/'):
@@ -206,8 +325,9 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
206325
return addrs, params
207326

208327

209-
def _parse_connect_arguments(*, dsn, host, port, user, password, database,
210-
timeout, command_timeout, statement_cache_size,
328+
def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
329+
database, timeout, command_timeout,
330+
statement_cache_size,
211331
max_cached_statement_lifetime,
212332
max_cacheable_statement_size,
213333
ssl, server_settings):
@@ -237,7 +357,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database,
237357

238358
addrs, params = _parse_connect_dsn_and_args(
239359
dsn=dsn, host=host, port=port, user=user,
240-
password=password, ssl=ssl,
360+
password=password, passfile=passfile, ssl=ssl,
241361
database=database, connect_timeout=timeout,
242362
server_settings=server_settings)
243363

asyncpg/connection.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -1382,7 +1382,7 @@ async def _do_execute(self, query, executor, timeout, retry=True):
13821382

13831383
async def connect(dsn=None, *,
13841384
host=None, port=None,
1385-
user=None, password=None,
1385+
user=None, password=None, passfile=None,
13861386
database=None,
13871387
loop=None,
13881388
timeout=60,
@@ -1424,6 +1424,11 @@ async def connect(dsn=None, *,
14241424
:param password:
14251425
password used for authentication
14261426
1427+
:param passfile:
1428+
the name of the file used to store passwords
1429+
(defaults to ``~/.pgpass``, or ``%APPDATA%\postgresql\pgpass.conf``
1430+
on Windows)
1431+
14271432
:param loop:
14281433
An asyncio event loop instance. If ``None``, the default
14291434
event loop will be used.
@@ -1489,6 +1494,10 @@ class of the returned connection object. Must be a subclass of
14891494
.. versionadded:: 0.11.0
14901495
Added ``connection_class`` parameter.
14911496
1497+
.. versionadded:: 0.16.0
1498+
Added ``passfile`` parameter
1499+
(and support for password files in general).
1500+
14921501
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
14931502
.. _create_default_context: https://docs.python.org/3/library/ssl.html#\
14941503
ssl.create_default_context
@@ -1503,7 +1512,8 @@ class of the returned connection object. Must be a subclass of
15031512

15041513
return await connect_utils._connect(
15051514
loop=loop, timeout=timeout, connection_class=connection_class,
1506-
dsn=dsn, host=host, port=port, user=user, password=password,
1515+
dsn=dsn, host=host, port=port, user=user,
1516+
password=password, passfile=passfile,
15071517
ssl=ssl, database=database,
15081518
server_settings=server_settings,
15091519
command_timeout=command_timeout,

0 commit comments

Comments
 (0)