diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ea62b7a4..9cfd5f5c 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -38,8 +38,8 @@ class Connection(metaclass=ConnectionMeta): Connections are created by calling :func:`~asyncpg.connection.connect`. """ - __slots__ = ('_protocol', '_transport', '_loop', '_types_stmt', - '_type_by_name_stmt', '_top_xact', '_uid', '_aborted', + __slots__ = ('_protocol', '_transport', '_loop', + '_top_xact', '_uid', '_aborted', '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', @@ -53,8 +53,6 @@ def __init__(self, protocol, transport, loop, self._protocol = protocol self._transport = transport self._loop = loop - self._types_stmt = None - self._type_by_name_stmt = None self._top_xact = None self._uid = 0 self._aborted = False @@ -286,14 +284,17 @@ async def _get_statement(self, query, timeout, *, named: bool=False): stmt_name = '' statement = await self._protocol.prepare(stmt_name, query, timeout) - ready = statement._init_types() if ready is not True: - if self._types_stmt is None: - self._types_stmt = await self.prepare(self._intro_query) - - types = await self._types_stmt.fetch(list(ready)) + types, intro_stmt = await self.__execute( + self._intro_query, (list(ready),), 0, timeout) self._protocol.get_settings().register_data_types(types) + if not intro_stmt.name and not statement.name: + # The introspection query has used an anonymous statement, + # which has blown away the anonymous statement we've prepared + # for the query, so we need to re-prepare it. + statement = await self._protocol.prepare( + stmt_name, query, timeout) if use_cache: self._stmt_cache.put(query, statement) @@ -886,12 +887,8 @@ async def set_type_codec(self, typename, *, "asyncpg 0.13.0. Use the `format` keyword argument instead.", DeprecationWarning, stacklevel=2) - if self._type_by_name_stmt is None: - self._type_by_name_stmt = await self.prepare( - introspection.TYPE_BY_NAME) - - typeinfo = await self._type_by_name_stmt.fetchrow( - typename, schema) + typeinfo = await self.fetchrow( + introspection.TYPE_BY_NAME, typename, schema) if not typeinfo: raise ValueError('unknown type: {}.{}'.format(schema, typename)) @@ -921,12 +918,8 @@ async def reset_type_codec(self, typename, *, schema='public'): .. versionadded:: 0.12.0 """ - if self._type_by_name_stmt is None: - self._type_by_name_stmt = await self.prepare( - introspection.TYPE_BY_NAME) - - typeinfo = await self._type_by_name_stmt.fetchrow( - typename, schema) + typeinfo = await self.fetchrow( + introspection.TYPE_BY_NAME, typename, schema) if not typeinfo: raise ValueError('unknown type: {}.{}'.format(schema, typename)) @@ -949,12 +942,8 @@ async def set_builtin_type_codec(self, typename, *, """ self._check_open() - if self._type_by_name_stmt is None: - self._type_by_name_stmt = await self.prepare( - introspection.TYPE_BY_NAME) - - typeinfo = await self._type_by_name_stmt.fetchrow( - typename, schema) + typeinfo = await self.fetchrow( + introspection.TYPE_BY_NAME, typename, schema) if not typeinfo: raise ValueError('unknown type: {}.{}'.format(schema, typename)) @@ -1209,18 +1198,25 @@ def _drop_global_statement_cache(self): self._drop_local_statement_cache() async def _execute(self, query, args, limit, timeout, return_status=False): + with self._stmt_exclusive_section: + result, _ = await self.__execute( + query, args, limit, timeout, return_status=return_status) + return result + + async def __execute(self, query, args, limit, timeout, + return_status=False): executor = lambda stmt, timeout: self._protocol.bind_execute( stmt, args, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) - with self._stmt_exclusive_section: - return await self._do_execute(query, executor, timeout) + return await self._do_execute(query, executor, timeout) async def _executemany(self, query, args, timeout): executor = lambda stmt, timeout: self._protocol.bind_execute_many( stmt, args, '', timeout) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: - return await self._do_execute(query, executor, timeout) + result, _ = await self._do_execute(query, executor, timeout) + return result async def _do_execute(self, query, executor, timeout, retry=True): if timeout is None: @@ -1269,10 +1265,10 @@ async def _do_execute(self, query, executor, timeout, retry=True): if self._protocol.is_in_transaction() or not retry: raise else: - result = await self._do_execute( + return await self._do_execute( query, executor, timeout, retry=False) - return result + return result, stmt async def connect(dsn=None, *, diff --git a/tests/test_introspection.py b/tests/test_introspection.py index d96ba690..aa286211 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -11,7 +11,7 @@ MAX_RUNTIME = 0.1 -class TestTimeout(tb.ConnectedTestCase): +class TestIntrospection(tb.ConnectedTestCase): @classmethod def setUpClass(cls): super().setUpClass() @@ -44,3 +44,51 @@ async def test_introspection_on_large_db(self): with self.assertRunUnder(MAX_RUNTIME): await self.con.fetchval('SELECT $1::int[]', [1, 2]) + + @tb.with_connection_options(statement_cache_size=0) + async def test_introspection_no_stmt_cache_01(self): + self.assertEqual(self.con._stmt_cache.get_max_size(), 0) + await self.con.fetchval('SELECT $1::int[]', [1, 2]) + + await self.con.execute(''' + CREATE EXTENSION IF NOT EXISTS hstore + ''') + + try: + await self.con.set_builtin_type_codec( + 'hstore', codec_name='pg_contrib.hstore') + finally: + await self.con.execute(''' + DROP EXTENSION hstore + ''') + + self.assertEqual(self.con._uid, 0) + + @tb.with_connection_options(max_cacheable_statement_size=1) + async def test_introspection_no_stmt_cache_02(self): + # max_cacheable_statement_size will disable caching both for + # the user query and for the introspection query. + await self.con.fetchval('SELECT $1::int[]', [1, 2]) + + await self.con.execute(''' + CREATE EXTENSION IF NOT EXISTS hstore + ''') + + try: + await self.con.set_builtin_type_codec( + 'hstore', codec_name='pg_contrib.hstore') + finally: + await self.con.execute(''' + DROP EXTENSION hstore + ''') + + self.assertEqual(self.con._uid, 0) + + @tb.with_connection_options(max_cacheable_statement_size=10000) + async def test_introspection_no_stmt_cache_03(self): + # max_cacheable_statement_size will disable caching for + # the user query but not for the introspection query. + await self.con.fetchval( + "SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2]) + + self.assertEqual(self.con._uid, 1)