Skip to content

Commit 2b104e0

Browse files
elpransElvis Pranskevichus
authored and
Elvis Pranskevichus
committed
Fix Connection.reset() on read-only connections
Fixes: #48
1 parent 330dbca commit 2b104e0

File tree

4 files changed

+171
-16
lines changed

4 files changed

+171
-16
lines changed

asyncpg/_testbase.py

Lines changed: 27 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,19 @@ def wrapper(self, *args, __meth__=meth, **kwargs):
7070

7171
class TestCase(unittest.TestCase, metaclass=TestCaseMeta):
7272

73-
def setUp(self):
73+
@classmethod
74+
def setUpClass(cls):
7475
if os.environ.get('USE_UVLOOP'):
7576
import uvloop
7677
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
7778

7879
loop = asyncio.new_event_loop()
7980
asyncio.set_event_loop(None)
80-
self.loop = loop
81+
cls.loop = loop
8182

82-
def tearDown(self):
83-
self.loop.close()
83+
@classmethod
84+
def tearDownClass(cls):
85+
cls.loop.close()
8486
asyncio.set_event_loop(None)
8587

8688
@contextlib.contextmanager
@@ -97,7 +99,16 @@ def assertRunUnder(self, delta):
9799
_default_cluster = None
98100

99101

100-
def _start_cluster(server_settings={}):
102+
def _start_cluster(ClusterCls, cluster_kwargs, server_settings):
103+
cluster = ClusterCls(**cluster_kwargs)
104+
cluster.init()
105+
cluster.trust_local_connections()
106+
cluster.start(port='dynamic', server_settings=server_settings)
107+
atexit.register(_shutdown_cluster, cluster)
108+
return cluster
109+
110+
111+
def _start_default_cluster(server_settings={}):
101112
global _default_cluster
102113

103114
if _default_cluster is None:
@@ -106,12 +117,8 @@ def _start_cluster(server_settings={}):
106117
# Using existing cluster, assuming it is initialized and running
107118
_default_cluster = pg_cluster.RunningCluster()
108119
else:
109-
_default_cluster = pg_cluster.TempCluster()
110-
_default_cluster.init()
111-
_default_cluster.trust_local_connections()
112-
_default_cluster.start(port='dynamic',
113-
server_settings=server_settings)
114-
atexit.register(_shutdown_cluster, _default_cluster)
120+
_default_cluster = _start_cluster(
121+
pg_cluster.TempCluster, {}, server_settings)
115122

116123
return _default_cluster
117124

@@ -122,9 +129,10 @@ def _shutdown_cluster(cluster):
122129

123130

124131
class ClusterTestCase(TestCase):
125-
def setUp(self):
126-
super().setUp()
127-
self.cluster = _start_cluster({
132+
@classmethod
133+
def setUpClass(cls):
134+
super().setUpClass()
135+
cls.cluster = _start_default_cluster({
128136
'log_connections': 'on'
129137
})
130138

@@ -133,6 +141,11 @@ def create_pool(self, **kwargs):
133141
conn_spec.update(kwargs)
134142
return pg_pool.create_pool(loop=self.loop, **conn_spec)
135143

144+
@classmethod
145+
def start_cluster(cls, ClusterCls, *,
146+
cluster_kwargs={}, server_settings={}):
147+
return _start_cluster(ClusterCls, cluster_kwargs, server_settings)
148+
136149

137150
class ConnectedTestCase(ClusterTestCase):
138151

asyncpg/cluster.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import socket
1717
import subprocess
1818
import tempfile
19+
import textwrap
1920
import time
2021

2122
import asyncpg
@@ -332,6 +333,20 @@ def trust_local_connections(self):
332333
if status == 'running':
333334
self.reload()
334335

336+
def trust_local_replication_by(self, user):
337+
if _system != 'Windows':
338+
self.add_hba_entry(type='local', database='replication',
339+
user=user, auth_method='trust')
340+
self.add_hba_entry(type='host', address='127.0.0.1/32',
341+
database='replication', user=user,
342+
auth_method='trust')
343+
self.add_hba_entry(type='host', address='::1/128',
344+
database='replication', user=user,
345+
auth_method='trust')
346+
status = self.get_status()
347+
if status == 'running':
348+
self.reload()
349+
335350
def _init_env(self):
336351
self._pg_config = self._find_pg_config(self._pg_config_path)
337352
self._pg_config_data = self._run_pg_config(self._pg_config)
@@ -489,6 +504,55 @@ def __init__(self, *,
489504
super().__init__(self._data_dir, pg_config_path=pg_config_path)
490505

491506

507+
class HotStandbyCluster(TempCluster):
508+
def __init__(self, *,
509+
master, replication_user,
510+
data_dir_suffix=None, data_dir_prefix=None,
511+
data_dir_parent=None, pg_config_path=None):
512+
self._master = master
513+
self._repl_user = replication_user
514+
super().__init__(
515+
data_dir_suffix=data_dir_suffix,
516+
data_dir_prefix=data_dir_prefix,
517+
data_dir_parent=data_dir_parent,
518+
pg_config_path=pg_config_path)
519+
520+
def _init_env(self):
521+
super()._init_env()
522+
self._pg_basebackup = self._find_pg_binary('pg_basebackup')
523+
524+
def init(self, **settings):
525+
"""Initialize cluster."""
526+
if self.get_status() != 'not-initialized':
527+
raise ClusterError(
528+
'cluster in {!r} has already been initialized'.format(
529+
self._data_dir))
530+
531+
process = subprocess.run(
532+
[self._pg_basebackup, '-h', self._master['host'],
533+
'-p', self._master['port'], '-D', self._data_dir,
534+
'-U', self._repl_user],
535+
stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
536+
537+
output = process.stdout
538+
539+
if process.returncode != 0:
540+
raise ClusterError(
541+
'pg_basebackup init exited with status {:d}:\n{}'.format(
542+
process.returncode, output.decode()))
543+
544+
with open(os.path.join(self._data_dir, 'recovery.conf'), 'w') as f:
545+
f.write(textwrap.dedent("""\
546+
standby_mode = 'on'
547+
primary_conninfo = 'host={host} port={port} user={user}'
548+
""".format(
549+
host=self._master['host'],
550+
port=self._master['port'],
551+
user=self._repl_user)))
552+
553+
return output.decode()
554+
555+
492556
class RunningCluster(Cluster):
493557
def __init__(self, **kwargs):
494558
self.conn_spec = kwargs

asyncpg/connection.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,11 +369,19 @@ def terminate(self):
369369

370370
async def reset(self):
371371
self._listeners = {}
372+
372373
await self.execute('''
374+
DO $$
375+
BEGIN
376+
PERFORM * FROM pg_listening_channels() LIMIT 1;
377+
IF FOUND THEN
378+
UNLISTEN *;
379+
END IF;
380+
END;
381+
$$;
373382
SET SESSION AUTHORIZATION DEFAULT;
374383
RESET ALL;
375384
CLOSE ALL;
376-
UNLISTEN *;
377385
SELECT pg_advisory_unlock_all();
378386
''')
379387

tests/test_pool.py

Lines changed: 71 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
import platform
1010

1111
from asyncpg import _testbase as tb
12-
12+
from asyncpg import cluster as pg_cluster
13+
from asyncpg import pool as pg_pool
1314

1415
_system = platform.uname().system
1516

@@ -148,3 +149,72 @@ async def worker():
148149
# Reset cluster's pg_hba.conf since we've meddled with it
149150
self.cluster.trust_local_connections()
150151
self.cluster.reload()
152+
153+
154+
class TestHostStandby(tb.ConnectedTestCase):
155+
@classmethod
156+
def setUpClass(cls):
157+
super().setUpClass()
158+
159+
cls.master_cluster = cls.start_cluster(
160+
pg_cluster.TempCluster,
161+
server_settings={
162+
'max_wal_senders': 10,
163+
'wal_level': 'hot_standby'
164+
})
165+
166+
con = None
167+
168+
try:
169+
con = cls.loop.run_until_complete(
170+
cls.master_cluster.connect(database='postgres', loop=cls.loop))
171+
172+
cls.loop.run_until_complete(
173+
con.execute('''
174+
CREATE ROLE replication WITH LOGIN REPLICATION
175+
'''))
176+
177+
cls.master_cluster.trust_local_replication_by('replication')
178+
179+
conn_spec = cls.master_cluster.get_connection_spec()
180+
181+
cls.standby_cluster = cls.start_cluster(
182+
pg_cluster.HotStandbyCluster,
183+
cluster_kwargs={
184+
'master': conn_spec,
185+
'replication_user': 'replication'
186+
},
187+
server_settings={
188+
'hot_standby': True
189+
})
190+
191+
finally:
192+
if con is not None:
193+
cls.loop.run_until_complete(con.close())
194+
195+
@classmethod
196+
def tearDownMethod(cls):
197+
cls.standby_cluster.stop()
198+
cls.standby_cluster.destroy()
199+
cls.master_cluster.stop()
200+
cls.master_cluster.destroy()
201+
202+
def create_pool(self, **kwargs):
203+
conn_spec = self.standby_cluster.get_connection_spec()
204+
conn_spec.update(kwargs)
205+
return pg_pool.create_pool(loop=self.loop, **conn_spec)
206+
207+
async def test_standby_pool_01(self):
208+
for n in {1, 3, 5, 10, 20, 100}:
209+
with self.subTest(tasksnum=n):
210+
pool = await self.create_pool(database='postgres',
211+
min_size=5, max_size=10)
212+
213+
async def worker():
214+
con = await pool.acquire()
215+
self.assertEqual(await con.fetchval('SELECT 1'), 1)
216+
await pool.release(con)
217+
218+
tasks = [worker() for _ in range(n)]
219+
await asyncio.gather(*tasks, loop=self.loop)
220+
await pool.close()

0 commit comments

Comments
 (0)