diff --git a/docs/source/api.rst b/docs/source/api.rst index 3a18ba95a..c2a119daa 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -160,6 +160,8 @@ Driver Configuration Additional configuration can be provided via the :class:`neo4j.Driver` constructor. ++ :ref:`session-connection-timeout-ref` ++ :ref:`update-routing-table-timeout-ref` + :ref:`connection-acquisition-timeout-ref` + :ref:`connection-timeout-ref` + :ref:`encrypted-ref` @@ -172,12 +174,59 @@ Additional configuration can be provided via the :class:`neo4j.Driver` construct + :ref:`user-agent-ref` +.. _session-connection-timeout-ref: + +``session_connection_timeout`` +------------------------------ +The maximum amount of time in seconds the session will wait when trying to +establish a usable read/write connection to the remote host. +This encompasses *everything* that needs to happen for this, including, +if necessary, updating the routing table, fetching a connection from the pool, +and, if necessary fully establishing a new connection with the reader/writer. + +Since this process may involve updating the routing table, acquiring a +connection from the pool, or establishing a new connection, it should be chosen +larger than :ref:`update-routing-table-timeout-ref`, +:ref:`connection-acquisition-timeout-ref`, and :ref:`connection-timeout-ref`. + +:Type: ``float`` +:Default: ``float("inf")`` + +.. versionadded:: 4.4.5 + + +.. _update-routing-table-timeout-ref: + +``update_routing_table_timeout`` +-------------------------------- +The maximum amount of time in seconds the driver will attempt to fetch a new +routing table. This encompasses *everything* that needs to happen for this, +including fetching connections from the pool, performing handshakes, and +requesting and receiving a fresh routing table. 
+ +Since this process may involve acquiring a connection from the pool, or +establishing a new connection, it should be chosen larger than +:ref:`connection-acquisition-timeout-ref` and :ref:`connection-timeout-ref`. + +This setting only has an effect for :ref:`neo4j-driver-ref`, but not for +:ref:`bolt-driver-ref` as it does no routing at all. + +:Type: ``float`` +:Default: ``90.0`` + +.. versionadded:: 4.4.5 + + .. _connection-acquisition-timeout-ref: ``connection_acquisition_timeout`` ---------------------------------- -The maximum amount of time in seconds a session will wait when requesting a connection from the connection pool. -Since the process of acquiring a connection may involve creating a new connection, ensure that the value of this configuration is higher than the configured :ref:`connection-timeout-ref`. +The maximum amount of time in seconds the driver will wait to either acquire an +idle connection from the pool (including potential liveness checks) or create a +new connection when the pool is not full and all existing connection are in use. + +Since this process may involve opening a new connection including handshakes, +it should be chosen larger than :ref:`connection-timeout-ref`. :Type: ``float`` :Default: ``60.0`` @@ -187,7 +236,11 @@ Since the process of acquiring a connection may involve creating a new connectio ``connection_timeout`` ---------------------- -The maximum amount of time in seconds to wait for a TCP connection to be established. +The maximum amount of time in seconds to wait for a TCP connection to be +established. + +This *does not* include any handshake(s), or authentication required before the +connection can be used to perform database related work. 
:Type: ``float`` :Default: ``30.0`` diff --git a/neo4j/__init__.py b/neo4j/__init__.py index 30c001c99..6266a4462 100644 --- a/neo4j/__init__.py +++ b/neo4j/__init__.py @@ -60,6 +60,7 @@ from logging import getLogger +from neo4j._deadline import Deadline from neo4j.addressing import ( Address, IPv4Address, @@ -451,6 +452,7 @@ def _verify_routing_connectivity(self): ) table = self._pool.get_routing_table_for_default_database() + timeout = self._default_workspace_config.connection_acquisition_timeout routing_info = {} for ix in list(table.routers): try: @@ -459,8 +461,7 @@ def _verify_routing_connectivity(self): database=self._default_workspace_config.database, imp_user=self._default_workspace_config.impersonated_user, bookmarks=None, - timeout=self._default_workspace_config - .connection_acquisition_timeout + deadline=Deadline(timeout) ) except (ServiceUnavailable, SessionExpired, Neo4jError): routing_info[ix] = None diff --git a/neo4j/_deadline.py b/neo4j/_deadline.py new file mode 100644 index 000000000..cfcc9035d --- /dev/null +++ b/neo4j/_deadline.py @@ -0,0 +1,98 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from contextlib import contextmanager +from time import perf_counter + + +class Deadline: + def __init__(self, timeout): + if timeout is None or timeout == float("inf"): + self._deadline = float("inf") + else: + self._deadline = perf_counter() + timeout + self._original_timeout = timeout + + @property + def original_timeout(self): + return self._original_timeout + + def expired(self): + return self.to_timeout() == 0 + + def to_timeout(self): + if self._deadline == float("inf"): + return None + timeout = self._deadline - perf_counter() + return timeout if timeout > 0 else 0 + + def __eq__(self, other): + if isinstance(other, Deadline): + return self._deadline == other._deadline + return NotImplemented + + def __gt__(self, other): + if isinstance(other, Deadline): + return self._deadline > other._deadline + return NotImplemented + + def __ge__(self, other): + if isinstance(other, Deadline): + return self._deadline >= other._deadline + return NotImplemented + + def __lt__(self, other): + if isinstance(other, Deadline): + return self._deadline < other._deadline + return NotImplemented + + def __le__(self, other): + if isinstance(other, Deadline): + return self._deadline <= other._deadline + return NotImplemented + + @classmethod + def from_timeout_or_deadline(cls, timeout): + if isinstance(timeout, cls): + return timeout + return cls(timeout) + + +merge_deadlines = min + + +def merge_deadlines_and_timeouts(*deadline): + deadlines = map(Deadline.from_timeout_or_deadline, deadline) + return merge_deadlines(deadlines) + + +@contextmanager +def connection_deadline(connection, deadline): + original_deadline = connection.socket.get_deadline() + if deadline is None and original_deadline is not None: + # nothing to do here + yield + return + deadline = merge_deadlines( + (d for d in (deadline, original_deadline) if d is not None) + ) + connection.socket.set_deadline(deadline) + try: + yield + finally: + connection.socket.set_deadline(original_deadline) diff --git 
a/neo4j/_exceptions.py b/neo4j/_exceptions.py index 67db7f6cf..c2fdce540 100644 --- a/neo4j/_exceptions.py +++ b/neo4j/_exceptions.py @@ -172,3 +172,7 @@ def transaction(self): class BoltProtocolError(BoltError): """ Raised when an unexpected or unsupported protocol event occurs. """ + + +class SocketDeadlineExceeded(RuntimeError): + """Raised from sockets with deadlines when a timeout occurs.""" diff --git a/neo4j/conf.py b/neo4j/conf.py index f74dd2e51..e3fd7f968 100644 --- a/neo4j/conf.py +++ b/neo4j/conf.py @@ -185,6 +185,13 @@ class PoolConfig(Config): connection_timeout = 30.0 # seconds # The maximum amount of time to wait for a TCP connection to be established. + #: Update Routing Table Timeout + update_routing_table_timeout = 90.0 # seconds + # The maximum amount of time to wait for updating the routing table. + # This includes everything necessary for this to happen. + # Including opening sockets, requesting and receiving the routing table, + # etc. + #: Trust trust = TRUST_SYSTEM_CA_SIGNED_CERTIFICATES # Specify how to determine the authenticity of encryption certificates provided by the Neo4j instance on connection. @@ -256,6 +263,12 @@ class WorkspaceConfig(Config): """ WorkSpace configuration. """ + #: Session Connection Timeout + session_connection_timeout = float("inf") # seconds + # The maximum amount of time to wait for a session to obtain a usable + # read/write connection. This includes everything necessary for this to # happen. Including fetching routing tables, opening sockets, etc. + #: Connection Acquisition Timeout connection_acquisition_timeout = 60.0 # seconds # The maximum amount of time a session will wait when requesting a connection from the connection pool. 
diff --git a/neo4j/exceptions.py b/neo4j/exceptions.py index c69c75c73..aba256539 100644 --- a/neo4j/exceptions.py +++ b/neo4j/exceptions.py @@ -152,6 +152,8 @@ def __str__(self): class ClientError(Neo4jError): """ The Client sent a bad request - changing the request might yield a successful outcome. """ + def __str__(self): + return super(Neo4jError, self).__str__() class DatabaseError(Neo4jError): diff --git a/neo4j/io/__init__.py b/neo4j/io/__init__.py index 9aae2f998..40d12b575 100644 --- a/neo4j/io/__init__.py +++ b/neo4j/io/__init__.py @@ -39,24 +39,10 @@ defaultdict, deque, ) +from contextlib import contextmanager import logging from logging import getLogger from random import choice -import selectors -from socket import ( - AF_INET, - AF_INET6, - SHUT_RDWR, - SO_KEEPALIVE, - socket, - SOL_SOCKET, - timeout as SocketTimeout, -) -from ssl import ( - CertificateError, - HAS_SNI, - SSLError, -) from threading import ( Condition, RLock, @@ -66,8 +52,12 @@ from neo4j._exceptions import ( BoltError, BoltHandshakeError, - BoltProtocolError, - BoltSecurityError, +) +from neo4j._deadline import ( + connection_deadline, + Deadline, + merge_deadlines, + merge_deadlines_and_timeouts, ) from neo4j.addressing import Address from neo4j.api import ( @@ -101,6 +91,7 @@ Outbox, Response, ) +from neo4j.io._socket import BoltSocket from neo4j.meta import get_user_agent from neo4j.packstream import ( Packer, @@ -133,6 +124,7 @@ class Bolt(abc.ABC): in_use = False # The socket + _closing = False _closed = False # The socket @@ -289,7 +281,7 @@ def ping(cls, address, *, timeout=None, **config): """ config = PoolConfig.consume(config) try: - s, protocol_version, handshake, data = connect( + s, protocol_version, handshake, data = BoltSocket.connect( address, timeout=timeout, custom_resolver=config.resolver, @@ -299,7 +291,7 @@ def ping(cls, address, *, timeout=None, **config): except (ServiceUnavailable, SessionExpired, BoltHandshakeError): return None else: - _close_socket(s) + 
BoltSocket.close_socket(s) return protocol_version @classmethod @@ -315,10 +307,24 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_ :raise BoltHandshakeError: raised if the Bolt Protocol can not negotiate a protocol version. :raise ServiceUnavailable: raised if there was a connection issue. """ + def time_remaining(): + if timeout is None: + return None + t = timeout - (perf_counter() - t0) + return t if t > 0 else 0 + + t0 = perf_counter() pool_config = PoolConfig.consume(pool_config) - s, pool_config.protocol_version, handshake, data = connect( + + socket_connection_timeout = pool_config.connection_timeout + if socket_connection_timeout is None: + socket_connection_timeout = time_remaining() + elif timeout is not None: + socket_connection_timeout = min(pool_config.connection_timeout, + time_remaining()) + s, pool_config.protocol_version, handshake, data = BoltSocket.connect( address, - timeout=timeout, + timeout=socket_connection_timeout, custom_resolver=pool_config.resolver, ssl_context=pool_config.get_ssl_context(), keep_alive=pool_config.keep_alive, @@ -346,7 +352,7 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_ bolt_cls = Bolt4x4 else: log.debug("[#%04X] S: ", s.getsockname()[1]) - _close_socket(s) + BoltSocket.close_socket(s) supported_versions = Bolt.protocol_handlers().keys() raise BoltHandshakeError("The Neo4J server does not support communication with this driver. 
This driver have support for Bolt Protocols {}".format(supported_versions), address=address, request_data=handshake, response_data=data) @@ -357,9 +363,13 @@ def open(cls, address, *, auth=None, timeout=None, routing_context=None, **pool_ ) try: - connection.hello() + connection.socket.set_deadline(time_remaining()) + try: + connection.hello() + finally: + connection.socket.set_deadline(None) except Exception: - connection.close() + connection.close_non_blocking() raise return connection @@ -476,6 +486,11 @@ def reset(self): """ pass + @abc.abstractmethod + def goodbye(self): + """Append a GOODBYE message to the outgoing queue.""" + pass + def _append(self, signature, fields=(), response=None): """ Appends a message to the outgoing queue. @@ -550,16 +565,20 @@ def _set_defunct(self, message, error=None, silent=False): direct_driver = isinstance(self.pool, BoltPool) if error: - log.debug("[#%04X] %s", self.socket.getsockname()[1], error) + log.debug("[#%04X] %r", self.socket.getsockname()[1], error) log.error(message) # We were attempting to receive data but the connection # has unexpectedly terminated. So, we need to close the # connection from the client side, and remove the address # from the connection pool. self._defunct = True - self.close() - if self.pool: - self.pool.deactivate(address=self.unresolved_address) + if not self._closing: + # If we fail while closing the connection, there is no need to + # remove the connection from the pool, nor to try to close the + # connection again. + self.close() + if self.pool: + self.pool.deactivate(address=self.unresolved_address) # Iterate through the outstanding responses, and if any correspond # to COMMIT requests then raise an error to signal that we are # unable to confirm that the COMMIT completed successfully. @@ -593,11 +612,36 @@ def stale(self): def set_stale(self): self._stale = True - @abc.abstractmethod def close(self): - """ Close the connection. 
+ """Close the connection.""" + if self._closed or self._closing: + return + self._closing = True + if not self._defunct: + self.goodbye() + try: + self._send_all() + except (OSError, BoltError, DriverError): + pass + log.debug("[#%04X] C: ", self.local_port) + try: + self.socket.close() + except OSError: + pass + finally: + self._closed = True + + async def close_non_blocking(self): + """Set the socket to non-blocking and close it. + This will try to send the `GOODBYE` message (given the socket is not + marked as defunct). However, should the write operation require + blocking (e.g., a full network buffer), then the socket will be closed + immediately (without `GOODBYE` message). """ - pass + if self._closed or self._closing: + return + self.socket.settimeout(0) + self.close() @abc.abstractmethod def closed(self): @@ -621,6 +665,7 @@ def __init__(self, opener, pool_config, workspace_config): self.pool_config = pool_config self.workspace_config = workspace_config self.connections = defaultdict(deque) + self.connections_reservations = defaultdict(lambda: 0) self.lock = RLock() self.cond = Condition(self.lock) @@ -630,96 +675,136 @@ def __enter__(self): def __exit__(self, exc_type, exc_value, traceback): self.close() - def _acquire(self, address, timeout): + def _acquire_from_pool(self, address): + with self.lock: + for connection in list(self.connections.get(address, [])): + if connection.in_use: + continue + connection.pool = self + connection.in_use = True + return connection + return None # no free connection available + + def _acquire_from_pool_checked( + self, address, health_check, deadline + ): + while not deadline.expired(): + connection = self._acquire_from_pool(address) + if not connection: + return None # no free connection available + if not health_check(connection, deadline): + # `close` is a noop on already closed connections. + # This is to make sure that the connection is + # gracefully closed, e.g. if it's just marked as + # `stale` but still alive. 
+ if log.isEnabledFor(logging.DEBUG): + log.debug( + "[#%04X] C: removing old connection " + "(closed=%s, defunct=%s, stale=%s, in_use=%s)", + connection.local_port, + connection.closed(), connection.defunct(), + connection.stale(), connection.in_use + ) + connection.close() + with self.lock: + try: + self.connections.get(address, []).remove(connection) + except ValueError: + # If closure fails (e.g. because the server went + # down), all connections to the same address will + # be removed. Therefore, we silently ignore if the + # connection isn't in the pool anymore. + pass + continue # try again with a new connection + else: + return connection + + def _acquire_new_later(self, address, deadline): + def connection_creator(): + released_reservation = False + try: + try: + connection = self.opener( + address, deadline.to_timeout() + ) + except ServiceUnavailable: + self.deactivate(address) + raise + connection.pool = self + connection.in_use = True + with self.lock: + self.connections_reservations[address] -= 1 + released_reservation = True + self.connections[address].append(connection) + return connection + finally: + if not released_reservation: + with self.lock: + self.connections_reservations[address] -= 1 + + max_pool_size = self.pool_config.max_connection_pool_size + infinite_pool_size = (max_pool_size < 0 + or max_pool_size == float("inf")) + with self.lock: + connections = self.connections[address] + pool_size = (len(connections) + + self.connections_reservations[address]) + can_create_new_connection = (infinite_pool_size + or pool_size < max_pool_size) + self.connections_reservations[address] += 1 + if can_create_new_connection: + return connection_creator + + def _acquire(self, address, deadline): + """ Acquire a connection to a given address from the pool. The address supplied should always be an IP address, not a host name. This method is thread safe. 
""" - t0 = perf_counter() - if timeout is None: - timeout = self.workspace_config.connection_acquisition_timeout + def health_check(connection_, _deadline): + if (connection_.closed() + or connection_.defunct() + or connection_.stale()): + return False + return True - with self.lock: - def time_remaining(): - t = timeout - (perf_counter() - t0) - return t if t > 0 else 0 - - while True: - # try to find a free connection in pool - for connection in list(self.connections.get(address, [])): - if (connection.closed() or connection.defunct() - or (connection.stale() and not connection.in_use)): - # `close` is a noop on already closed connections. - # This is to make sure that the connection is gracefully - # closed, e.g. if it's just marked as `stale` but still - # alive. - if log.isEnabledFor(logging.DEBUG): - log.debug( - "[#%04X] C: removing old connection " - "(closed=%s, defunct=%s, stale=%s, in_use=%s)", - connection.local_port, - connection.closed(), connection.defunct(), - connection.stale(), connection.in_use - ) - connection.close() - try: - self.connections.get(address, []).remove(connection) - except ValueError: - # If closure fails (e.g. because the server went - # down), all connections to the same address will - # be removed. Therefore, we silently ignore if the - # connection isn't in the pool anymore. 
- pass - continue - if not connection.in_use: - connection.in_use = True - return connection - # all connections in pool are in-use - connections = self.connections[address] - max_pool_size = self.pool_config.max_connection_pool_size - infinite_pool_size = (max_pool_size < 0 - or max_pool_size == float("inf")) - can_create_new_connection = ( - infinite_pool_size - or len(connections) < max_pool_size + while True: + # try to find a free connection in the pool + connection = self._acquire_from_pool_checked( + address, health_check, deadline + ) + if connection: + return connection + # all connections in pool are in-use + with self.lock: + connection_creator = self._acquire_new_later( + address, deadline ) - if can_create_new_connection: - timeout = min(self.pool_config.connection_timeout, - time_remaining()) - try: - connection = self.opener(address, timeout) - except ServiceUnavailable: - self.deactivate(address) - raise - else: - connection.pool = self - connection.in_use = True - connections.append(connection) - return connection + if connection_creator: + break # failed to obtain a connection from pool because the # pool is full and no free connection in the pool - if time_remaining(): - self.cond.wait(time_remaining()) - # if timed out, then we throw error. 
This time - # computation is needed, as with python 2.7, we - # cannot tell if the condition is notified or - # timed out when we come to this line - if not time_remaining(): - raise ClientError("Failed to obtain a connection from pool " - "within {!r}s".format(timeout)) - else: - raise ClientError("Failed to obtain a connection from pool " - "within {!r}s".format(timeout)) + timeout = deadline.to_timeout() + if ( + timeout == 0 # deadline expired + or not self.cond.wait(timeout) + ): + raise ClientError( + "Failed to obtain a connection from pool within {!r}s" + .format(deadline.original_timeout) + ) + return connection_creator() - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + def acquire(self, access_mode, timeout, acquisition_timeout, database, + bookmarks): """ Acquire a connection to a server that can satisfy a set of parameters. :param access_mode: - :param timeout: + :param timeout: total timeout (including potential preparation) + :param acquisition_timeout: timeout for actually acquiring a connection :param database: :param bookmarks: """ @@ -800,6 +885,9 @@ def close(self): pass +BoltSocket.Bolt = Bolt + + class BoltPool(IOPool): @classmethod @@ -829,9 +917,11 @@ def __init__(self, opener, pool_config, workspace_config, address): def __repr__(self): return "<{} address={!r}>".format(self.__class__.__name__, self.address) - def acquire(self, access_mode=None, timeout=None, database=None, bookmarks=None): + def acquire(self, access_mode, timeout, acquisition_timeout, database, + bookmarks): # The access_mode and database is not needed for a direct connection, its just there for consistency. 
- return self._acquire(self.address, timeout) + deadline = merge_deadlines_and_timeouts(timeout, acquisition_timeout) + return self._acquire(self.address, deadline) class Neo4jPool(IOPool): @@ -887,6 +977,22 @@ def __repr__(self): """ return "<{} addresses={!r}>".format(self.__class__.__name__, self.get_default_database_initial_router_addresses()) + @contextmanager + def _refresh_lock_deadline(self, deadline): + timeout = deadline.to_timeout() + if timeout is None: + timeout = -1 + if not self.refresh_lock.acquire(timeout=timeout): + raise ClientError( + "pool failed to update routing table within {!r}s (timeout)" + .format(deadline.original_timeout) + ) + + try: + yield + finally: + self.refresh_lock.release() + @property def first_initial_routing_address(self): return self.get_default_database_initial_router_addresses()[0] @@ -920,7 +1026,7 @@ def get_or_create_routing_table(self, database): return self.routing_tables[database] def fetch_routing_info(self, address, database, imp_user, bookmarks, - timeout): + deadline): """ Fetch raw routing info from a given router address. 
:param address: router address @@ -930,30 +1036,31 @@ def fetch_routing_info(self, address, database, imp_user, bookmarks, :type imp_user: str or None :param bookmarks: iterable of bookmark values after which the routing info should be fetched - :param timeout: connection acquisition timeout in seconds + :param deadline: connection acquisition deadline :return: list of routing records, or None if no connection could be established or if no readers or writers are present :raise ServiceUnavailable: if the server does not support routing, or if routing support is broken or outdated """ - cx = self._acquire(address, timeout) + cx = self._acquire(address, deadline) try: - routing_table = cx.route( - database or self.workspace_config.database, - imp_user or self.workspace_config.impersonated_user, - bookmarks - ) + with connection_deadline(cx, deadline): + routing_table = cx.route( + database or self.workspace_config.database, + imp_user or self.workspace_config.impersonated_user, + bookmarks + ) finally: self.release(cx) return routing_table - def fetch_routing_table(self, *, address, timeout, database, imp_user, + def fetch_routing_table(self, *, address, deadline, database, imp_user, bookmarks): """ Fetch a routing table from a given router address. :param address: router address - :param timeout: seconds + :param deadline: deadline :param database: the database name :type: str :param imp_user: the user to impersonate while fetching the routing @@ -967,7 +1074,7 @@ def fetch_routing_table(self, *, address, timeout, database, imp_user, new_routing_info = None try: new_routing_info = self.fetch_routing_info( - address, database, imp_user, bookmarks, timeout + address, database, imp_user, bookmarks, deadline ) except Neo4jError as e: # checks if the code is an error that is caused by the client. 
In @@ -1009,8 +1116,10 @@ def fetch_routing_table(self, *, address, timeout, database, imp_user, # At least one of each is fine, so return this table return new_routing_table - def _update_routing_table_from(self, *routers, database=None, imp_user=None, - bookmarks=None, database_callback=None): + def _update_routing_table_from( + self, *routers, database, imp_user, bookmarks, deadline, + database_callback + ): """ Try to update routing tables with the given routers. :return: True if the routing table is successfully updated, @@ -1019,27 +1128,30 @@ def _update_routing_table_from(self, *routers, database=None, imp_user=None, log.debug("Attempting to update routing table from {}".format(", ".join(map(repr, routers)))) for router in routers: for address in router.resolve(resolver=self.pool_config.resolver): + if deadline.expired(): + return False new_routing_table = self.fetch_routing_table( - address=address, - timeout=self.pool_config.connection_timeout, - database=database, imp_user=imp_user, bookmarks=bookmarks + address=address, deadline=deadline, database=database, + imp_user=imp_user, bookmarks=bookmarks ) if new_routing_table is not None: - new_databse = new_routing_table.database - self.get_or_create_routing_table(new_databse)\ - .update(new_routing_table) + new_database = new_routing_table.database + old_routing_table = self.get_or_create_routing_table( + new_database + ) + old_routing_table.update(new_routing_table) log.debug( "[#0000] C: address=%r (%r)", - address, self.routing_tables[new_databse] + address, self.routing_tables[new_database] ) if callable(database_callback): - database_callback(new_databse) + database_callback(new_database) return True self.deactivate(router) return False def update_routing_table(self, *, database, imp_user, bookmarks, - database_callback=None): + timeout=None, database_callback=None): """ Update the routing table from the first router able to provide valid routing information. 
@@ -1048,6 +1160,7 @@ def update_routing_table(self, *, database, imp_user, bookmarks, table :type imp_user: str or None :param bookmarks: bookmarks used when fetching routing table + :param timeout: timeout in seconds for how long to try updating :param database_callback: A callback function that will be called with the database name as only argument when a new routing table has been acquired. This database name might different from `database` if that @@ -1056,7 +1169,10 @@ def update_routing_table(self, *, database, imp_user, bookmarks, :raise neo4j.exceptions.ServiceUnavailable: """ - with self.refresh_lock: + deadline = merge_deadlines_and_timeouts( + timeout, self.pool_config.update_routing_table_timeout + ) + with self._refresh_lock_deadline(deadline): # copied because it can be modified existing_routers = set( self.get_or_create_routing_table(database).routers @@ -1070,14 +1186,14 @@ def update_routing_table(self, *, database, imp_user, bookmarks, if self._update_routing_table_from( self.first_initial_routing_address, database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback + deadline=deadline, database_callback=database_callback ): # Why is only the first initial routing address used? return if self._update_routing_table_from( *(existing_routers - {self.first_initial_routing_address}), database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback + deadline=deadline, database_callback=database_callback ): return @@ -1085,7 +1201,7 @@ def update_routing_table(self, *, database, imp_user, bookmarks, if self._update_routing_table_from( self.first_initial_routing_address, database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback + deadline=deadline, database_callback=database_callback ): # Why is only the first initial routing address used? 
return @@ -1100,8 +1216,10 @@ def update_connection_pool(self, *, database): if address.unresolved not in servers: super(Neo4jPool, self).deactivate(address) - def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, - bookmarks, database_callback=None): + def ensure_routing_table_is_fresh( + self, *, access_mode, database, imp_user, bookmarks, deadline=None, + database_callback=None + ): """ Update the routing table if stale. This method performs two freshness checks, before and after acquiring @@ -1115,7 +1233,7 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, :return: `True` if an update was required, `False` otherwise. """ from neo4j.api import READ_ACCESS - with self.refresh_lock: + with self._refresh_lock_deadline(deadline): if self.get_or_create_routing_table(database)\ .is_fresh(readonly=(access_mode == READ_ACCESS)): # Readers are fresh. @@ -1123,7 +1241,7 @@ def ensure_routing_table_is_fresh(self, *, access_mode, database, imp_user, self.update_routing_table( database=database, imp_user=imp_user, bookmarks=bookmarks, - database_callback=database_callback + timeout=deadline, database_callback=database_callback ) self.update_connection_pool(database=database) @@ -1162,24 +1280,33 @@ def _select_address(self, *, access_mode, database): ) return choice(addresses_by_usage[min(addresses_by_usage)]) - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + def acquire(self, access_mode, timeout, acquisition_timeout, database, + bookmarks): if access_mode not in (WRITE_ACCESS, READ_ACCESS): raise ClientError("Non valid 'access_mode'; {}".format(access_mode)) if not timeout: raise ClientError("'timeout' must be a float larger than 0; {}" .format(timeout)) + if not acquisition_timeout: + raise ClientError("'acquisition_timeout' must be a float larger " + "than 0; {}".format(acquisition_timeout)) + deadline = Deadline.from_timeout_or_deadline(timeout) from neo4j.api import 
check_access_mode access_mode = check_access_mode(access_mode) - with self.refresh_lock: + with self._refresh_lock_deadline(deadline): log.debug("[#0000] C: %r", self.routing_tables) self.ensure_routing_table_is_fresh( access_mode=access_mode, database=database, imp_user=None, - bookmarks=bookmarks + bookmarks=bookmarks, deadline=deadline ) + # Making sure the routing table is fresh is not considered part of the + # connection acquisition. Hence, the acquisition_timeout starts now! + deadline = merge_deadlines( + deadline, Deadline.from_timeout_or_deadline(acquisition_timeout) + ) while True: try: # Get an address for a connection that have the fewest in-use @@ -1190,7 +1317,8 @@ def acquire(self, access_mode=None, timeout=None, database=None, raise SessionExpired("Failed to obtain connection towards '%s' server." % access_mode) from err try: log.debug("[#0000] C: database=%r address=%r", database, address) - connection = self._acquire(address, timeout=timeout) # should always be a resolved address + # should always be a resolved address + connection = self._acquire(address, deadline) except ServiceUnavailable: self.deactivate(address=address) else: @@ -1220,182 +1348,6 @@ def on_write_failure(self, address): log.debug("[#0000] C: table=%r", self.routing_tables) -def _connect(resolved_address, timeout, keep_alive): - """ - - :param resolved_address: - :param timeout: seconds - :param keep_alive: True or False - :return: socket object - """ - - s = None # The socket - - try: - if len(resolved_address) == 2: - s = socket(AF_INET) - elif len(resolved_address) == 4: - s = socket(AF_INET6) - else: - raise ValueError("Unsupported address {!r}".format(resolved_address)) - t = s.gettimeout() - if timeout: - s.settimeout(timeout) - log.debug("[#0000] C: %s", resolved_address) - s.connect(resolved_address) - s.settimeout(t) - keep_alive = 1 if keep_alive else 0 - s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) - except SocketTimeout: - log.debug("[#0000] C: %s", 
resolved_address) - log.debug("[#0000] C: %s", resolved_address) - _close_socket(s) - raise ServiceUnavailable("Timed out trying to establish connection to {!r}".format(resolved_address)) - except OSError as error: - log.debug("[#0000] C: %s %s", type(error).__name__, - " ".join(map(repr, error.args))) - log.debug("[#0000] C: %s", resolved_address) - s.close() - raise ServiceUnavailable("Failed to establish connection to {!r} (reason {})".format(resolved_address, error)) - else: - return s - - -def _secure(s, host, ssl_context): - local_port = s.getsockname()[1] - # Secure the connection if an SSL context has been provided - if ssl_context: - last_error = None - log.debug("[#%04X] C: %s", local_port, host) - try: - sni_host = host if HAS_SNI and host else None - s = ssl_context.wrap_socket(s, server_hostname=sni_host) - except (OSError, SSLError, CertificateError) as cause: - raise BoltSecurityError( - message="Failed to establish encrypted connection.", - address=(host, local_port) - ) from cause - # Check that the server provides a certificate - der_encoded_server_certificate = s.getpeercert(binary_form=True) - if der_encoded_server_certificate is None: - raise BoltProtocolError( - "When using an encrypted socket, the server should always " - "provide a certificate", address=(host, local_port) - ) - return s - return s - - -def _handshake(s, resolved_address): - """ - - :param s: Socket - :param resolved_address: - - :return: (socket, version, client_handshake, server_response_data) - """ - local_port = s.getsockname()[1] - - # TODO: Optimize logging code - handshake = Bolt.get_handshake() - import struct - handshake = struct.unpack(">16B", handshake) - handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)] - - supported_versions = [("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in handshake] - - log.debug("[#%04X] C: 0x%08X", local_port, int.from_bytes(Bolt.MAGIC_PREAMBLE, byteorder="big")) - log.debug("[#%04X] C: %s %s %s %s", 
local_port, *supported_versions) - - data = Bolt.MAGIC_PREAMBLE + Bolt.get_handshake() - s.sendall(data) - - # Handle the handshake response - ready_to_read = False - with selectors.DefaultSelector() as selector: - selector.register(s, selectors.EVENT_READ) - selector.select(1) - try: - data = s.recv(4) - except OSError: - raise ServiceUnavailable("Failed to read any data from server {!r} " - "after connected".format(resolved_address)) - data_size = len(data) - if data_size == 0: - # If no data is returned after a successful select - # response, the server has closed the connection - log.debug("[#%04X] S: ", local_port) - _close_socket(s) - raise ServiceUnavailable("Connection to {address} closed without handshake response".format(address=resolved_address)) - if data_size != 4: - # Some garbled data has been received - log.debug("[#%04X] S: @*#!", local_port) - s.close() - raise BoltProtocolError("Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (resolved_address, data), address=resolved_address) - elif data == b"HTTP": - log.debug("[#%04X] S: ", local_port) - _close_socket(s) - raise ServiceUnavailable("Cannot to connect to Bolt service on {!r} " - "(looks like HTTP)".format(resolved_address)) - agreed_version = data[-1], data[-2] - log.debug("[#%04X] S: 0x%06X%02X", local_port, agreed_version[1], agreed_version[0]) - return s, agreed_version, handshake, data - - -def _close_socket(socket_): - try: - socket_.shutdown(SHUT_RDWR) - socket_.close() - except OSError: - pass - - -def connect(address, *, timeout, custom_resolver, ssl_context, keep_alive): - """ Connect and perform a handshake and return a valid Connection object, - assuming a protocol version can be agreed. 
- """ - errors = [] - # Establish a connection to the host and port specified - # Catches refused connections see: - # https://docs.python.org/2/library/errno.html - - resolved_addresses = Address(address).resolve(resolver=custom_resolver) - for resolved_address in resolved_addresses: - s = None - try: - s = _connect(resolved_address, timeout, keep_alive) - s = _secure(s, resolved_address.host_name, ssl_context) - return _handshake(s, resolved_address) - except (BoltError, DriverError, OSError) as error: - try: - local_port = s.getsockname()[1] - except (OSError, AttributeError): - local_port = 0 - err_str = error.__class__.__name__ - if str(error): - err_str += ": " + str(error) - log.debug("[#%04X] C: %s", local_port, err_str) - if s: - _close_socket(s) - errors.append(error) - except Exception: - if s: - _close_socket(s) - raise - if not errors: - raise ServiceUnavailable( - "Couldn't connect to %s (resolved to %s)" % ( - str(address), tuple(map(str, resolved_addresses))) - ) - else: - raise ServiceUnavailable( - "Couldn't connect to %s (resolved to %s):\n%s" % ( - str(address), tuple(map(str, resolved_addresses)), - "\n".join(map(str, errors)) - ) - ) from errors[0] - - def check_supported_server_product(agent): """ Checks that a server product is supported by the driver by looking at the server agent string. diff --git a/neo4j/io/_bolt3.py b/neo4j/io/_bolt3.py index 2fe88363b..21f71b180 100644 --- a/neo4j/io/_bolt3.py +++ b/neo4j/io/_bolt3.py @@ -306,6 +306,10 @@ def fail(metadata): self.send_all() self.fetch_all() + def goodbye(self): + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + def fetch_message(self): """ Receive at most one message from the server, if available. 
diff --git a/neo4j/io/_bolt4.py b/neo4j/io/_bolt4.py index 1acc8ab92..c29c6aebb 100644 --- a/neo4j/io/_bolt4.py +++ b/neo4j/io/_bolt4.py @@ -258,6 +258,10 @@ def fail(metadata): self.send_all() self.fetch_all() + def goodbye(self): + log.debug("[#%04X] C: GOODBYE", self.local_port) + self._append(b"\x02", ()) + def fetch_message(self): """ Receive at most one message from the server, if available. @@ -318,25 +322,6 @@ def fetch_message(self): return len(details), 1 - def close(self): - """ Close the connection. - """ - if not self._closed: - if not self._defunct: - log.debug("[#%04X] C: GOODBYE", self.local_port) - self._append(b"\x02", ()) - try: - self._send_all() - except (OSError, BoltError, DriverError): - pass - log.debug("[#%04X] C: ", self.local_port) - try: - self.socket.close() - except OSError: - pass - finally: - self._closed = True - def closed(self): return self._closed diff --git a/neo4j/io/_common.py b/neo4j/io/_common.py index c974462f5..a3f079108 100644 --- a/neo4j/io/_common.py +++ b/neo4j/io/_common.py @@ -23,6 +23,7 @@ import socket from struct import pack as struct_pack +from neo4j._exceptions import SocketDeadlineExceeded from neo4j.exceptions import ( Neo4jError, ServiceUnavailable, @@ -69,7 +70,7 @@ def _yield_messages(self, sock): # Reset for new message unpacker.reset() - except (OSError, socket.timeout) as error: + except (OSError, socket.timeout, SocketDeadlineExceeded) as error: self.on_error(error) def pop(self): diff --git a/neo4j/io/_socket.py b/neo4j/io/_socket.py new file mode 100644 index 000000000..f3cc3f313 --- /dev/null +++ b/neo4j/io/_socket.py @@ -0,0 +1,316 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# This file is part of Neo4j. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import logging +import selectors +import socket +from socket import ( + AF_INET, + AF_INET6, + SHUT_RDWR, + SO_KEEPALIVE, + socket, + SOL_SOCKET, + timeout as SocketTimeout, +) +from ssl import ( + CertificateError, + HAS_SNI, + SSLError, +) +import struct + +from .._exceptions import ( + BoltError, + BoltProtocolError, + BoltSecurityError, + SocketDeadlineExceeded, +) +from .._deadline import Deadline +from ..addressing import Address +from ..exceptions import ( + DriverError, + ServiceUnavailable, +) + + +log = logging.getLogger("neo4j") + + +def _sanitize_deadline(deadline): + if deadline is None: + return None + deadline = Deadline.from_timeout_or_deadline(deadline) + if deadline.to_timeout() is None: + return None + return deadline + + +class BoltSocket: + Bolt = None + + def __init__(self, socket_: socket): + self._socket = socket_ + self._deadline = None + + @property + def _socket(self): + return self.__socket + + @_socket.setter + def _socket(self, socket_: socket): + self.__socket = socket_ + self.getsockname = socket_.getsockname + self.getpeername = socket_.getpeername + if hasattr(socket, "getpeercert"): + self.getpeercert = socket_.getpeercert + elif hasattr(self, "getpeercert"): + del self.getpeercert + self.gettimeout = socket_.gettimeout + self.settimeout = socket_.settimeout + + def _wait_for_io(self, func, *args, **kwargs): + if self._deadline is None: + return func(*args, **kwargs) + timeout = self._socket.gettimeout() + deadline_timeout = self._deadline.to_timeout() + if deadline_timeout <= 0: + raise SocketDeadlineExceeded("timed 
out") + if timeout is None or deadline_timeout <= timeout: + self._socket.settimeout(deadline_timeout) + try: + return func(*args, **kwargs) + except SocketTimeout as e: + raise SocketDeadlineExceeded("timed out") from e + finally: + self._socket.settimeout(timeout) + return func(*args, **kwargs) + + def get_deadline(self): + return self._deadline + + def set_deadline(self, deadline): + self._deadline = _sanitize_deadline(deadline) + + def recv(self, n): + return self._wait_for_io(self._socket.recv, n) + + def recv_into(self, buffer, nbytes): + return self._wait_for_io(self._socket.recv_into, buffer, nbytes) + + def sendall(self, data): + return self._wait_for_io(self._socket.sendall, data) + + def close(self): + self._socket.shutdown(SHUT_RDWR) + self._socket.close() + + @classmethod + def _connect(cls, resolved_address, timeout, keep_alive): + """ + + :param resolved_address: + :param timeout: seconds + :param keep_alive: True or False + :return: socket object + """ + + s = None # The socket + + try: + if len(resolved_address) == 2: + s = socket(AF_INET) + elif len(resolved_address) == 4: + s = socket(AF_INET6) + else: + raise ValueError( + "Unsupported address {!r}".format(resolved_address)) + t = s.gettimeout() + if timeout: + s.settimeout(timeout) + log.debug("[#0000] C: %s", resolved_address) + s.connect(resolved_address) + s.settimeout(t) + keep_alive = 1 if keep_alive else 0 + s.setsockopt(SOL_SOCKET, SO_KEEPALIVE, keep_alive) + return s + except SocketTimeout: + log.debug("[#0000] C: %s", resolved_address) + log.debug("[#0000] C: %s", resolved_address) + cls.close_socket(s) + raise ServiceUnavailable( + "Timed out trying to establish connection to {!r}".format( + resolved_address)) + except OSError as error: + log.debug("[#0000] C: %s %s", type(error).__name__, + " ".join(map(repr, error.args))) + log.debug("[#0000] C: %s", resolved_address) + s.close() + raise ServiceUnavailable( + "Failed to establish connection to {!r} (reason {})".format( + 
resolved_address, error)) + + @classmethod + def _secure(cls, s, host, ssl_context): + local_port = s.getsockname()[1] + # Secure the connection if an SSL context has been provided + if ssl_context: + log.debug("[#%04X] C: %s", local_port, host) + try: + sni_host = host if HAS_SNI and host else None + s = ssl_context.wrap_socket(s, server_hostname=sni_host) + except (OSError, SSLError, CertificateError) as cause: + raise BoltSecurityError( + message="Failed to establish encrypted connection.", + address=(host, local_port) + ) from cause + # Check that the server provides a certificate + der_encoded_server_certificate = s.getpeercert(binary_form=True) + if der_encoded_server_certificate is None: + raise BoltProtocolError( + "When using an encrypted socket, the server should always " + "provide a certificate", address=(host, local_port) + ) + return s + return s + + @classmethod + def _handshake(cls, s, resolved_address): + """ + + :param s: Socket + :param resolved_address: + + :return: (socket, version, client_handshake, server_response_data) + """ + local_port = s.getsockname()[1] + + # TODO: Optimize logging code + handshake = cls.Bolt.get_handshake() + handshake = struct.unpack(">16B", handshake) + handshake = [handshake[i:i + 4] for i in range(0, len(handshake), 4)] + + supported_versions = [ + ("0x%02X%02X%02X%02X" % (vx[0], vx[1], vx[2], vx[3])) for vx in + handshake] + + log.debug("[#%04X] C: 0x%08X", local_port, + int.from_bytes(cls.Bolt.MAGIC_PREAMBLE, byteorder="big")) + log.debug("[#%04X] C: %s %s %s %s", local_port, + *supported_versions) + + data = cls.Bolt.MAGIC_PREAMBLE + cls.Bolt.get_handshake() + s.sendall(data) + + # Handle the handshake response + ready_to_read = False + with selectors.DefaultSelector() as selector: + selector.register(s, selectors.EVENT_READ) + selector.select(1) + try: + data = s.recv(4) + except OSError: + raise ServiceUnavailable( + "Failed to read any data from server {!r} " + "after connected".format(resolved_address)) + 
data_size = len(data)
+        if data_size == 0:
+            # If no data is returned after a successful select
+            # response, the server has closed the connection
+            log.debug("[#%04X] S: ", local_port)
+            BoltSocket.close_socket(s)
+            raise ServiceUnavailable(
+                "Connection to {address} closed without handshake response".format(
+                    address=resolved_address))
+        if data_size != 4:
+            # Some garbled data has been received
+            log.debug("[#%04X] S: @*#!", local_port)
+            s.close()
+            raise BoltProtocolError(
+                "Expected four byte Bolt handshake response from %r, received %r instead; check for incorrect port number" % (
+                    resolved_address, data), address=resolved_address)
+        elif data == b"HTTP":
+            log.debug("[#%04X] S: ", local_port)
+            BoltSocket.close_socket(s)
+            raise ServiceUnavailable(
+                "Cannot to connect to Bolt service on {!r} "
+                "(looks like HTTP)".format(resolved_address))
+        agreed_version = data[-1], data[-2]
+        log.debug("[#%04X] S: 0x%06X%02X", local_port,
+                  agreed_version[1], agreed_version[0])
+        return cls(s), agreed_version, handshake, data
+
+    @classmethod
+    def close_socket(cls, socket_):
+        try:
+            if isinstance(socket_, BoltSocket):
+                socket_.close()
+            else:
+                socket_.shutdown(SHUT_RDWR)
+                socket_.close()
+        except OSError:
+            pass
+
+    @classmethod
+    def connect(cls, address, *, timeout, custom_resolver, ssl_context,
+                keep_alive):
+        """ Connect and perform a handshake and return a valid Connection object,
+        assuming a protocol version can be agreed. 
+ """ + errors = [] + # Establish a connection to the host and port specified + # Catches refused connections see: + # https://docs.python.org/2/library/errno.html + + resolved_addresses = Address(address).resolve(resolver=custom_resolver) + for resolved_address in resolved_addresses: + s = None + try: + s = BoltSocket._connect(resolved_address, timeout, keep_alive) + s = BoltSocket._secure(s, resolved_address.host_name, + ssl_context) + return BoltSocket._handshake(s, resolved_address) + except (BoltError, DriverError, OSError) as error: + try: + local_port = s.getsockname()[1] + except (OSError, AttributeError): + local_port = 0 + err_str = error.__class__.__name__ + if str(error): + err_str += ": " + str(error) + log.debug("[#%04X] C: %s", local_port, + err_str) + if s: + BoltSocket.close_socket(s) + errors.append(error) + except Exception: + if s: + BoltSocket.close_socket(s) + raise + if not errors: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s)" % ( + str(address), tuple(map(str, resolved_addresses))) + ) + else: + raise ServiceUnavailable( + "Couldn't connect to %s (resolved to %s):\n%s" % ( + str(address), tuple(map(str, resolved_addresses)), + "\n".join(map(str, errors)) + ) + ) from errors[0] diff --git a/neo4j/work/__init__.py b/neo4j/work/__init__.py index 154f73819..288b7a130 100644 --- a/neo4j/work/__init__.py +++ b/neo4j/work/__init__.py @@ -19,6 +19,7 @@ # limitations under the License. 
+from neo4j._deadline import Deadline from neo4j.conf import WorkspaceConfig from neo4j.exceptions import ServiceUnavailable from neo4j.io import Neo4jPool @@ -53,6 +54,7 @@ def _set_cached_database(self, database): self._config.database = database def _connect(self, access_mode): + timeout = Deadline(self._config.session_connection_timeout) if self._connection: # TODO: Investigate this # log.warning("FIXME: should always disconnect before connect") @@ -74,11 +76,13 @@ def _connect(self, access_mode): database=self._config.database, imp_user=self._config.impersonated_user, bookmarks=self._bookmarks, + timeout=timeout, database_callback=self._set_cached_database ) self._connection = self._pool.acquire( access_mode=access_mode, - timeout=self._config.connection_acquisition_timeout, + timeout=timeout, + acquisition_timeout=self._config.connection_acquisition_timeout, database=self._config.database, bookmarks=self._bookmarks ) diff --git a/testkitbackend/requests.py b/testkitbackend/requests.py index 8b7b31e4f..a6a5d9e33 100644 --- a/testkitbackend/requests.py +++ b/testkitbackend/requests.py @@ -81,13 +81,15 @@ def NewDriver(backend, data): backend, data["resolverRegistered"], data["domainNameResolverRegistered"] ) - if data.get("connectionTimeoutMs"): - kwargs["connection_timeout"] = data["connectionTimeoutMs"] / 1000 - if data.get("maxTxRetryTimeMs"): - kwargs["max_transaction_retry_time"] = data["maxTxRetryTimeMs"] / 1000 - if data.get("connectionAcquisitionTimeoutMs"): - kwargs["connection_acquisition_timeout"] = \ - data["connectionAcquisitionTimeoutMs"] / 1000 + for timeout_testkit, timeout_driver in ( + ("connectionTimeoutMs", "connection_timeout"), + ("maxTxRetryTimeMs", "max_transaction_retry_time"), + ("connectionAcquisitionTimeoutMs", "connection_acquisition_timeout"), + ("sessionConnectionTimeoutMs", "session_connection_timeout"), + ("updateRoutingTableTimeoutMs", "update_routing_table_timeout"), + ): + if data.get(timeout_testkit) is not None: + 
kwargs[timeout_driver] = data[timeout_testkit] / 1000 if data.get("maxConnectionPoolSize"): kwargs["max_connection_pool_size"] = data["maxConnectionPoolSize"] if data.get("fetchSize"): diff --git a/testkitbackend/test_config.json b/testkitbackend/test_config.json index 1c7b1445a..075349cc7 100644 --- a/testkitbackend/test_config.json +++ b/testkitbackend/test_config.json @@ -26,10 +26,13 @@ "Driver rejects empty queries before sending it to the server" }, "features": { + "Feature:API:ConnectionAcquisitionTimeout": true, "Feature:API:Liveness.Check": false, "Feature:API:Result.List": true, "Feature:API:Result.Peek": true, "Feature:API:Result.Single": "Does not raise error when not exactly one record is available. To be fixed in 5.0.", + "Feature:API:SessionConnectionTimeout": true, + "Feature:API:UpdateRoutingTableTimeout": true, "Feature:Auth:Bearer": true, "Feature:Auth:Custom": true, "Feature:Auth:Kerberos": true, diff --git a/tests/unit/io/test_direct.py b/tests/unit/io/test_direct.py index 9e572fb26..aab2dc948 100644 --- a/tests/unit/io/test_direct.py +++ b/tests/unit/io/test_direct.py @@ -37,9 +37,9 @@ PoolConfig, WorkspaceConfig, ) +from neo4j._deadline import Deadline from neo4j.io import ( Bolt, - BoltPool, IOPool ) from neo4j.exceptions import ( @@ -104,8 +104,8 @@ def opener(addr, timeout): super().__init__(opener, self.pool_config, self.workspace_config) self.address = address - def acquire(self, access_mode=None, timeout=None, database=None, - bookmarks=None): + def acquire(self, access_mode, timeout, acquisition_timeout, database, + bookmarks): return self._acquire(self.address, timeout) @@ -196,14 +196,14 @@ def assert_pool_size(self, address, expected_active, expected_inactive, pool=Non def test_can_acquire(self): address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) + connection = self.pool._acquire(address, Deadline(3)) assert connection.address == address self.assert_pool_size(address, 1, 0) def 
test_can_acquire_twice(self): address = ("127.0.0.1", 7687) - connection_1 = self.pool._acquire(address, timeout=3) - connection_2 = self.pool._acquire(address, timeout=3) + connection_1 = self.pool._acquire(address, Deadline(3)) + connection_2 = self.pool._acquire(address, Deadline(3)) assert connection_1.address == address assert connection_2.address == address assert connection_1 is not connection_2 @@ -212,8 +212,8 @@ def test_can_acquire_twice(self): def test_can_acquire_two_addresses(self): address_1 = ("127.0.0.1", 7687) address_2 = ("127.0.0.1", 7474) - connection_1 = self.pool._acquire(address_1, timeout=3) - connection_2 = self.pool._acquire(address_2, timeout=3) + connection_1 = self.pool._acquire(address_1, Deadline(3)) + connection_2 = self.pool._acquire(address_2, Deadline(3)) assert connection_1.address == address_1 assert connection_2.address == address_2 self.assert_pool_size(address_1, 1, 0) @@ -221,14 +221,14 @@ def test_can_acquire_two_addresses(self): def test_can_acquire_and_release(self): address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) + connection = self.pool._acquire(address, Deadline(3)) self.assert_pool_size(address, 1, 0) self.pool.release(connection) self.assert_pool_size(address, 0, 1) def test_releasing_twice(self): address = ("127.0.0.1", 7687) - connection = self.pool._acquire(address, timeout=3) + connection = self.pool._acquire(address, Deadline(3)) self.pool.release(connection) self.assert_pool_size(address, 0, 1) self.pool.release(connection) @@ -237,7 +237,7 @@ def test_releasing_twice(self): def test_in_use_count(self): address = ("127.0.0.1", 7687) self.assertEqual(self.pool.in_use_connection_count(address), 0) - connection = self.pool._acquire(address, timeout=3) + connection = self.pool._acquire(address, Deadline(3)) self.assertEqual(self.pool.in_use_connection_count(address), 1) self.pool.release(connection) self.assertEqual(self.pool.in_use_connection_count(address), 0) @@ -245,10 
+245,10 @@ def test_in_use_count(self): def test_max_conn_pool_size(self): with FakeBoltPool((), max_connection_pool_size=1) as pool: address = ("127.0.0.1", 7687) - pool._acquire(address, timeout=0) + pool._acquire(address, Deadline(0)) self.assertEqual(pool.in_use_connection_count(address), 1) with self.assertRaises(ClientError): - pool._acquire(address, timeout=0) + pool._acquire(address, Deadline(0)) self.assertEqual(pool.in_use_connection_count(address), 1) def test_multithread(self): @@ -268,7 +268,7 @@ def test_multithread(self): t.start() threads.append(t) - if not acquired_counter.wait(5, timeout=1): + if not acquired_counter.wait(5, 1): raise RuntimeError("Acquire threads not fast enough") # The pool size should be 5, all are in-use self.assert_pool_size(address, 5, 0, pool) @@ -277,7 +277,7 @@ def test_multithread(self): # wait for all threads to release connections back to pool for t in threads: - t.join(timeout=1) + t.join(1) # The pool size is still 5, but all are free self.assert_pool_size(address, 0, 5, pool) @@ -288,7 +288,9 @@ def test(is_reset): with mock.patch(__name__ + ".QuickConnection.reset", new_callable=mock.MagicMock) as reset_mock: is_reset_mock.return_value = is_reset - connection = self.pool._acquire(address, timeout=3) + connection = self.pool._acquire( + address, Deadline(3) + ) self.assertIsInstance(connection, QuickConnection) self.assertEqual(is_reset_mock.call_count, 0) self.assertEqual(reset_mock.call_count, 0) @@ -303,7 +305,7 @@ def test(is_reset): def acquire_release_conn(pool, address, acquired_counter, release_event): - conn = pool._acquire(address, timeout=3) + conn = pool._acquire(address, Deadline(3)) acquired_counter.increment() release_event.wait() pool.release(conn) diff --git a/tests/unit/io/test_neo4j_pool.py b/tests/unit/io/test_neo4j_pool.py index 5aba030a9..5853ceac6 100644 --- a/tests/unit/io/test_neo4j_pool.py +++ b/tests/unit/io/test_neo4j_pool.py @@ -74,37 +74,37 @@ def open_(addr, timeout): def 
test_acquires_new_routing_table_if_deleted(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") del pool.routing_tables["test_db"] - cx = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") def test_acquires_new_routing_table_if_stale(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables.get("test_db") old_value = pool.routing_tables["test_db"].last_updated_time pool.routing_tables["test_db"].ttl = 0 - cx = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx) assert pool.routing_tables["test_db"].last_updated_time > old_value def test_removes_old_routing_table(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db1", None) pool.release(cx) assert pool.routing_tables.get("test_db1") - cx = pool.acquire(READ_ACCESS, 30, "test_db2", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db2", None) pool.release(cx) assert pool.routing_tables.get("test_db2") @@ -113,7 +113,7 @@ def test_removes_old_routing_table(opener): pool.routing_tables["test_db2"].ttl = \ -RoutingConfig.routing_table_purge_delay - cx = pool.acquire(READ_ACCESS, 30, "test_db1", None) + cx = pool.acquire(READ_ACCESS, 30, 60, "test_db1", None) pool.release(cx) assert pool.routing_tables["test_db1"].last_updated_time > old_value assert "test_db2" not in pool.routing_tables @@ -123,7 +123,7 @@ def 
test_removes_old_routing_table(opener): def test_chooses_right_connection_type(opener, type_): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) cx1 = pool.acquire(READ_ACCESS if type_ == "r" else WRITE_ACCESS, - 30, "test_db", None) + 30, 60, "test_db", None) pool.release(cx1) if type_ == "r": assert cx1.addr == READER_ADDRESS @@ -133,9 +133,9 @@ def test_chooses_right_connection_type(opener, type_): def test_reuses_connection(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) assert cx1 is cx2 @@ -148,7 +148,7 @@ def break_connection(): cx_close_mock_side_effect() pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. exceeding) and then breaking when @@ -158,7 +158,7 @@ def break_connection(): if break_on_close: cx_close_mock_side_effect = cx_close_mock.side_effect cx_close_mock.side_effect = break_connection - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx2) if break_on_close: cx1.close.assert_called() @@ -172,11 +172,11 @@ def break_connection(): def test_does_not_close_stale_connections_in_use(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) assert cx1 in pool.connections[cx1.addr] # simulate connection going stale (e.g. 
exceeding) while being in use cx1.stale.return_value = True - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx2) cx1.close.assert_not_called() assert cx2 is not cx1 @@ -189,7 +189,7 @@ def test_does_not_close_stale_connections_in_use(opener): # it should be closed when trying to acquire the next connection cx1.close.assert_not_called() - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx3 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx3) cx1.close.assert_called_once() assert cx2 is cx3 @@ -200,7 +200,7 @@ def test_does_not_close_stale_connections_in_use(opener): def test_release_resets_connections(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.is_reset_mock.return_value = False cx1.is_reset_mock.reset_mock() pool.release(cx1) @@ -210,7 +210,7 @@ def test_release_resets_connections(opener): def test_release_does_not_resets_closed_connections(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.closed.return_value = True cx1.closed.reset_mock() cx1.is_reset_mock.reset_mock() @@ -222,7 +222,7 @@ def test_release_does_not_resets_closed_connections(opener): def test_release_does_not_resets_defunct_connections(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) cx1.defunct.return_value = True cx1.defunct.reset_mock() cx1.is_reset_mock.reset_mock() @@ -243,8 +243,8 @@ def close_side_effect(): # create pool with 2 idle connections pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = 
pool.acquire(READ_ACCESS, 30, "test_db", None) - cx2 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) + cx2 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) pool.release(cx1) pool.release(cx2) @@ -256,7 +256,7 @@ def close_side_effect(): # unreachable cx1.stale.return_value = True - cx3 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx3 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) assert cx3 is not cx1 assert cx3 is not cx2 @@ -264,10 +264,10 @@ def close_side_effect(): def test_failing_opener_leaves_connections_in_use_alone(opener): pool = Neo4jPool(opener, PoolConfig(), WorkspaceConfig(), ROUTER_ADDRESS) - cx1 = pool.acquire(READ_ACCESS, 30, "test_db", None) + cx1 = pool.acquire(READ_ACCESS, 30, 60, "test_db", None) opener.side_effect = ServiceUnavailable("Server overloaded") with pytest.raises((ServiceUnavailable, SessionExpired)): - pool.acquire(READ_ACCESS, 30, "test_db", None) + pool.acquire(READ_ACCESS, 30, 60, "test_db", None) assert not cx1.closed() diff --git a/tests/unit/test_conf.py b/tests/unit/test_conf.py index 6e685edae..8ecf5e8b1 100644 --- a/tests/unit/test_conf.py +++ b/tests/unit/test_conf.py @@ -43,6 +43,7 @@ test_pool_config = { "connection_timeout": 30.0, + "update_routing_table_timeout": 90.0, "init_size": 1, "keep_alive": True, "max_connection_lifetime": 3600, @@ -55,6 +56,7 @@ } test_session_config = { + "session_connection_timeout": 180.0, "connection_acquisition_timeout": 60.0, "max_transaction_retry_time": 30.0, "initial_retry_delay": 1.0, diff --git a/tests/unit/test_driver.py b/tests/unit/test_driver.py index 0c1192e47..beb1b9c78 100644 --- a/tests/unit/test_driver.py +++ b/tests/unit/test_driver.py @@ -136,6 +136,7 @@ def test_driver_opens_write_session_by_default(uri, mocker): acquire_mock.assert_called_once_with( access_mode=WRITE_ACCESS, timeout=mocker.ANY, + acquisition_timeout=mocker.ANY, database=mocker.ANY, bookmarks=mocker.ANY ) diff --git 
a/tests/unit/work/_fake_connection.py b/tests/unit/work/_fake_connection.py index c1f78a192..c94026ea2 100644 --- a/tests/unit/work/_fake_connection.py +++ b/tests/unit/work/_fake_connection.py @@ -25,6 +25,7 @@ import pytest from neo4j import ServerInfo +from neo4j._deadline import Deadline class FakeConnection(mock.NonCallableMagicMock): @@ -38,6 +39,18 @@ def __init__(self, *args, **kwargs): self.attach_mock(mock.Mock(return_value=False), "defunct") self.attach_mock(mock.Mock(return_value=False), "stale") self.attach_mock(mock.Mock(return_value=False), "closed") + self.attach_mock(mock.Mock(return_value=False), "socket") + self.socket.attach_mock( + mock.Mock(return_value=None), "get_deadline" + ) + + def set_deadline_side_effect(deadline): + deadline = Deadline.from_timeout_or_deadline(deadline) + self.socket.get_deadline.return_value = deadline + + self.socket.attach_mock( + mock.Mock(side_effect=set_deadline_side_effect), "set_deadline" + ) def close_side_effect(): self.closed.return_value = True