diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 0631f976..c3607569 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -68,6 +68,7 @@ def parse(cls, sslmode): 'statement_cache_size', 'max_cached_statement_lifetime', 'max_cacheable_statement_size', + 'get_reset_query', ]) @@ -690,7 +691,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): + target_session_attrs, krbsrvname, gsslib, + get_reset_query): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -726,7 +728,9 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size,) + max_cacheable_statement_size=max_cacheable_statement_size, + get_reset_query=get_reset_query, + ) return addrs, params, config diff --git a/asyncpg/connection.py b/asyncpg/connection.py index e54d6df8..a0ea10a0 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -1659,25 +1659,25 @@ def _unwrap(self): return con_ref def _get_reset_query(self): - if self._reset_query is not None: - return self._reset_query + if self._reset_query is None: + get_reset_query = self._config.get_reset_query or self._get_default_reset_query + self._reset_query = get_reset_query(self._server_caps) - caps = self._server_caps + return self._reset_query + def _get_default_reset_query(self, server_caps): _reset_query = [] - if caps.advisory_locks: + + if server_caps.advisory_locks: _reset_query.append('SELECT pg_advisory_unlock_all();') - if caps.sql_close_all: + if server_caps.sql_close_all: _reset_query.append('CLOSE ALL;') - if caps.notifications and caps.plpgsql: + if server_caps.notifications and server_caps.plpgsql: _reset_query.append('UNLISTEN *;') - if caps.sql_reset: + if server_caps.sql_reset: _reset_query.append('RESET ALL;') - _reset_query = '\n'.join(_reset_query) - self._reset_query = _reset_query - - return _reset_query + return '\n'.join(_reset_query) def _set_proxy(self, proxy): if self._proxy is not None and proxy is not None: @@ -2009,7 +2009,8 @@ async def connect(dsn=None, *, server_settings=None, target_session_attrs=None, krbsrvname=None, - gsslib=None): + gsslib=None, + get_reset_query=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2245,6 +2246,12 @@ async def connect(dsn=None, *, GSS library to use for GSSAPI/SSPI authentication. Can be 'gssapi' or 'sspi'. Defaults to 'sspi' on Windows and 'gssapi' otherwise. + :param get_reset_query: + Function to build a query that should be executed when resetting + the connection. Takes a single argument of type `~.asyncpg.connection.ServerCapabilities` + that communicates auto-detected server capabilities. + Defaults to `None` which means use the default reset query builder + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2360,6 +2367,7 @@ async def connect(dsn=None, *, target_session_attrs=target_session_attrs, krbsrvname=krbsrvname, gsslib=gsslib, + get_reset_query=get_reset_query, )