Skip to content

Fix pool connection ownership #732

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions neo4j/_async/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,31 +241,24 @@ async def deactivate(self, address):
connections = self.connections[address]
except KeyError: # already removed from the connection pool
return
for conn in list(connections):
if not conn.in_use:
connections.remove(conn)
try:
await conn.close()
except OSError:
pass
if not connections:
await self.remove(address)
closable_connections = [
conn for conn in connections if not conn.in_use
]
# First remove all connections in question, then try to close them.
# If closing of a connection fails, we will end up in this method
# again.
for conn in closable_connections:
connections.remove(conn)
for conn in closable_connections:
await conn.close()
if not self.connections[address]:
del self.connections[address]

def on_write_failure(self, address):
raise WriteServiceUnavailable(
"No write service available for pool {}".format(self)
)

async def remove(self, address):
""" Remove an address from the connection pool, if present, closing
all connections to that address.
"""
async with self.lock:
for connection in self.connections.pop(address, ()):
try:
await connection.close()
except OSError:
pass

async def close(self):
""" Close all connections and empty the pool.
Expand All @@ -274,7 +267,8 @@ async def close(self):
try:
async with self.lock:
for address in list(self.connections):
await self.remove(address)
for connection in self.connections.pop(address, ()):
await connection.close()
except TypeError:
pass

Expand Down
34 changes: 14 additions & 20 deletions neo4j/_sync/io/_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,31 +241,24 @@ def deactivate(self, address):
connections = self.connections[address]
except KeyError: # already removed from the connection pool
return
for conn in list(connections):
if not conn.in_use:
connections.remove(conn)
try:
conn.close()
except OSError:
pass
if not connections:
self.remove(address)
closable_connections = [
conn for conn in connections if not conn.in_use
]
# First remove all connections in question, then try to close them.
# If closing of a connection fails, we will end up in this method
# again.
for conn in closable_connections:
connections.remove(conn)
for conn in closable_connections:
conn.close()
if not self.connections[address]:
del self.connections[address]

def on_write_failure(self, address):
raise WriteServiceUnavailable(
"No write service available for pool {}".format(self)
)

def remove(self, address):
""" Remove an address from the connection pool, if present, closing
all connections to that address.
"""
with self.lock:
for connection in self.connections.pop(address, ()):
try:
connection.close()
except OSError:
pass

def close(self):
""" Close all connections and empty the pool.
Expand All @@ -274,7 +267,8 @@ def close(self):
try:
with self.lock:
for address in list(self.connections):
self.remove(address)
for connection in self.connections.pop(address, ()):
connection.close()
except TypeError:
pass

Expand Down
47 changes: 47 additions & 0 deletions tests/unit/async_/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,50 @@ def liveness_side_effect(*args, **kwargs):
cx3.reset.assert_awaited_once()
assert cx1 not in pool.connections[cx1.addr]
assert cx3 in pool.connections[cx1.addr]


@mark_async_test
async def test_multiple_broken_connections_on_close(opener, mocker):
def mock_connection_breaks_on_close(cx):
async def close_side_effect():
cx.closed.return_value = True
cx.defunct.return_value = True
await pool.deactivate(READER_ADDRESS)

cx.attach_mock(mocker.AsyncMock(side_effect=close_side_effect),
"close")

# create pool with 2 idle connections
pool = AsyncNeo4jPool(
opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
)
cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
cx2 = await pool.acquire(READ_ACCESS, 30, "test_db", None)
await pool.release(cx1)
await pool.release(cx2)

# both will loose connection
mock_connection_breaks_on_close(cx1)
mock_connection_breaks_on_close(cx2)

# force pool to close cx1, which will make it realize that the server is
# unreachable
cx1.stale.return_value = True

cx3 = await pool.acquire(READ_ACCESS, 30, "test_db", None)

assert cx3 is not cx1
assert cx3 is not cx2


@mark_async_test
async def test_failing_opener_leaves_connections_in_use_alone(opener):
pool = AsyncNeo4jPool(
opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
)
cx1 = await pool.acquire(READ_ACCESS, 30, "test_db", None)

opener.side_effect = ServiceUnavailable("Server overloaded")
with pytest.raises((ServiceUnavailable, SessionExpired)):
await pool.acquire(READ_ACCESS, 30, "test_db", None)
assert not cx1.closed()
47 changes: 47 additions & 0 deletions tests/unit/sync/io/test_neo4j_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,3 +384,50 @@ def liveness_side_effect(*args, **kwargs):
cx3.reset.assert_called_once()
assert cx1 not in pool.connections[cx1.addr]
assert cx3 in pool.connections[cx1.addr]


@mark_sync_test
def test_multiple_broken_connections_on_close(opener, mocker):
def mock_connection_breaks_on_close(cx):
def close_side_effect():
cx.closed.return_value = True
cx.defunct.return_value = True
pool.deactivate(READER_ADDRESS)

cx.attach_mock(mocker.Mock(side_effect=close_side_effect),
"close")

# create pool with 2 idle connections
pool = Neo4jPool(
opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)
cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None)
pool.release(cx1)
pool.release(cx2)

# both will loose connection
mock_connection_breaks_on_close(cx1)
mock_connection_breaks_on_close(cx2)

# force pool to close cx1, which will make it realize that the server is
# unreachable
cx1.stale.return_value = True

cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None)

assert cx3 is not cx1
assert cx3 is not cx2


@mark_sync_test
def test_failing_opener_leaves_connections_in_use_alone(opener):
pool = Neo4jPool(
opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS
)
cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None)

opener.side_effect = ServiceUnavailable("Server overloaded")
with pytest.raises((ServiceUnavailable, SessionExpired)):
pool.acquire(READ_ACCESS, 30, "test_db", None)
assert not cx1.closed()