diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 96a8f776..4b32f5ba 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -296,17 +296,34 @@ async def _get_statement(self, query, timeout, *, named: bool=False, stmt_name = '' statement = await self._protocol.prepare(stmt_name, query, timeout) - ready = statement._init_types() - if ready is not True: - types, intro_stmt = await self.__execute( - self._intro_query, (list(ready),), 0, timeout) - self._protocol.get_settings().register_data_types(types) + need_reprepare = False + types_with_missing_codecs = statement._init_types() + tries = 0 + while types_with_missing_codecs: + settings = self._protocol.get_settings() + + # Introspect newly seen types and populate the + # codec cache. + types, intro_stmt = await self._introspect_types( + types_with_missing_codecs, timeout) + + settings.register_data_types(types) + # The introspection query has used an anonymous statement, # which has blown away the anonymous statement we've prepared # for the query, so we need to re-prepare it. need_reprepare = not intro_stmt.name and not statement.name - else: - need_reprepare = False + types_with_missing_codecs = statement._init_types() + tries += 1 + if tries > 5: + # In the vast majority of cases there will be only + # one iteration. In rare cases, there might be a race + # with reload_schema_state(), which would cause a + # second try. More than five is clearly a bug. + raise exceptions.InternalClientError( + 'could not resolve query result and/or argument types ' + 'in {} attempts'.format(tries) + ) # Now that types have been resolved, populate the codec pipeline # for the statement. @@ -326,6 +343,10 @@ async def _get_statement(self, query, timeout, *, named: bool=False, return statement + async def _introspect_types(self, typeoids, timeout): + return await self.__execute( + self._intro_query, (list(typeoids),), 0, timeout) + def cursor(self, query, *args, prefetch=None, timeout=None): """Return a *cursor factory* for the specified query. @@ -1271,6 +1292,18 @@ def _drop_global_statement_cache(self): else: self._drop_local_statement_cache() + def _drop_local_type_cache(self): + self._protocol.get_settings().clear_type_cache() + + def _drop_global_type_cache(self): + if self._proxy is not None: + # This connection is a member of a pool, so we delegate + # the cache drop to the pool. + pool = self._proxy._holder._pool + pool._drop_type_cache() + else: + self._drop_local_type_cache() + async def reload_schema_state(self): """Indicate that the database schema information must be reloaded. @@ -1313,9 +1346,7 @@ async def reload_schema_state(self): .. versionadded:: 0.14.0 """ - # It is enough to clear the type cache only once, not in each - # connection in the pool. - self._protocol.get_settings().clear_type_cache() + self._drop_global_type_cache() self._drop_global_statement_cache() async def _execute(self, query, args, limit, timeout, return_status=False): diff --git a/asyncpg/pool.py b/asyncpg/pool.py index f60de047..b3f6d181 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -614,6 +614,12 @@ def _drop_statement_cache(self): if ch._con is not None: ch._con._drop_local_statement_cache() + def _drop_type_cache(self): + # Drop type codec cache for all connections in the pool. + for ch in self._holders: + if ch._con is not None: + ch._con._drop_local_type_cache() + def __await__(self): return self._async__init__().__await__() diff --git a/asyncpg/protocol/codecs/base.pxd b/asyncpg/protocol/codecs/base.pxd index 59e75238..5bfb4f32 100644 --- a/asyncpg/protocol/codecs/base.pxd +++ b/asyncpg/protocol/codecs/base.pxd @@ -167,8 +167,8 @@ cdef class Codec: cdef class DataCodecConfig: cdef: - dict _type_codecs_cache - dict _local_type_codecs + dict _derived_type_codecs + dict _custom_type_codecs cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format) cdef inline Codec get_local_codec(self, uint32_t oid) diff --git a/asyncpg/protocol/codecs/base.pyx b/asyncpg/protocol/codecs/base.pyx index c1348781..663941bb 100644 --- a/asyncpg/protocol/codecs/base.pyx +++ b/asyncpg/protocol/codecs/base.pyx @@ -10,7 +10,6 @@ from asyncpg.exceptions import OutdatedSchemaCacheError cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2] cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2] -cdef dict TYPE_CODECS_CACHE = {} cdef dict EXTRA_CODECS = {} @@ -391,12 +390,11 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl: cdef class DataCodecConfig: def __init__(self, cache_key): - try: - self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key] - except KeyError: - self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key] = {} - - self._local_type_codecs = {} + # Codec instance cache for derived types: + # composites, arrays, ranges, domains and their combinations. + self._derived_type_codecs = {} + # Codec instances set up by the user for the connection. + self._custom_type_codecs = {} def add_types(self, types): cdef: @@ -451,7 +449,7 @@ cdef class DataCodecConfig: elem_delim = ti['elemdelim'][0] - self._type_codecs_cache[oid, elem_format] = \ + self._derived_type_codecs[oid, elem_format] = \ Codec.new_array_codec( oid, name, schema, elem_codec, elem_delim) @@ -483,7 +481,7 @@ cdef class DataCodecConfig: if has_text_elements: format = PG_FORMAT_TEXT - self._type_codecs_cache[oid, format] = \ + self._derived_type_codecs[oid, format] = \ Codec.new_composite_codec( oid, name, schema, format, comp_elem_codecs, comp_type_attrs, element_names) @@ -502,7 +500,7 @@ cdef class DataCodecConfig: elem_codec = self.declare_fallback_codec( base_type, name, schema) - self._type_codecs_cache[oid, format] = elem_codec + self._derived_type_codecs[oid, format] = elem_codec elif ti['kind'] == b'r': # Range type @@ -523,7 +521,7 @@ cdef class DataCodecConfig: elem_codec = self.declare_fallback_codec( range_subtype_oid, name, schema) - self._type_codecs_cache[oid, elem_format] = \ + self._derived_type_codecs[oid, elem_format] = \ Codec.new_range_codec(oid, name, schema, elem_codec) elif ti['kind'] == b'e': @@ -554,13 +552,13 @@ cdef class DataCodecConfig: # Clear all previous overrides (this also clears type cache). self.remove_python_codec(typeoid, typename, typeschema) - self._local_type_codecs[typeoid] = \ + self._custom_type_codecs[typeoid] = \ Codec.new_python_codec(oid, typename, typeschema, typekind, encoder, decoder, c_encoder, c_decoder, format, xformat) def remove_python_codec(self, typeoid, typename, typeschema): - self._local_type_codecs.pop(typeoid, None) + self._custom_type_codecs.pop(typeoid, None) self.clear_type_cache() def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind, @@ -592,7 +590,7 @@ cdef class DataCodecConfig: codec.schema = typeschema codec.kind = typekind - self._local_type_codecs[typeoid] = codec + self._custom_type_codecs[typeoid] = codec break else: raise ValueError('unknown alias target: {}'.format(alias_to)) @@ -604,7 +602,7 @@ cdef class DataCodecConfig: self.clear_type_cache() def clear_type_cache(self): - self._type_codecs_cache.clear() + self._derived_type_codecs.clear() def declare_fallback_codec(self, uint32_t oid, str name, str schema): cdef Codec codec @@ -654,12 +652,12 @@ cdef class DataCodecConfig: return codec else: try: - return self._type_codecs_cache[oid, format] + return self._derived_type_codecs[oid, format] except KeyError: return None cdef inline Codec get_local_codec(self, uint32_t oid): - return self._local_type_codecs.get(oid) + return self._custom_type_codecs.get(oid) cdef inline Codec get_core_codec( diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index e8ea038c..7e0d6e31 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -63,24 +63,21 @@ cdef class PreparedStatementState: def _init_types(self): cdef: Codec codec - set result = set() + set missing = set() if self.parameters_desc: for p_oid in self.parameters_desc: codec = self.settings.get_data_codec(p_oid) if codec is None or not codec.has_encoder(): - result.add(p_oid) + missing.add(p_oid) if self.row_desc: for rdesc in self.row_desc: codec = self.settings.get_data_codec((rdesc[3])) if codec is None or not codec.has_decoder(): - result.add(rdesc[3]) + missing.add(rdesc[3]) - if len(result): - return result - else: - return True + return missing cpdef _init_codecs(self): self._ensure_args_encoder() diff --git a/asyncpg/protocol/settings.pxd b/asyncpg/protocol/settings.pxd index fe4ef507..6e4adf21 100644 --- a/asyncpg/protocol/settings.pxd +++ b/asyncpg/protocol/settings.pxd @@ -25,4 +25,5 @@ cdef class ConnectionSettings: cpdef inline clear_type_cache(self) cpdef inline set_builtin_type_codec( self, typeoid, typename, typeschema, typekind, alias_to) - cpdef inline Codec get_data_codec(self, uint32_t oid, ServerDataFormat format=*) + cpdef inline Codec get_data_codec( + self, uint32_t oid, ServerDataFormat format=*) diff --git a/tests/test_introspection.py b/tests/test_introspection.py index fcf5885d..92ebc7a4 100644 --- a/tests/test_introspection.py +++ b/tests/test_introspection.py @@ -5,6 +5,7 @@ # the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 +import asyncio import json from asyncpg import _testbase as tb @@ -14,6 +15,16 @@ MAX_RUNTIME = 0.1 +class SlowIntrospectionConnection(apg_con.Connection): + """Connection class to test introspection races.""" + introspect_count = 0 + + async def _introspect_types(self, *args, **kwargs): + self.introspect_count += 1 + await asyncio.sleep(0.4, loop=self._loop) + return await super()._introspect_types(*args, **kwargs) + + class TestIntrospection(tb.ConnectedTestCase): @classmethod def setUpClass(cls): @@ -125,3 +136,42 @@ async def test_introspection_sticks_for_ps(self): finally: await self.con.reset_type_codec( 'json', schema='pg_catalog') + + async def test_introspection_retries_after_cache_bust(self): + # Test that codec cache bust racing with the introspection + # query would cause introspection to retry. + slow_intro_conn = await self.connect( + connection_class=SlowIntrospectionConnection) + try: + await self.con.execute(''' + CREATE DOMAIN intro_1_t AS int; + CREATE DOMAIN intro_2_t AS int; + ''') + + await slow_intro_conn.fetchval(''' + SELECT $1::intro_1_t + ''', 10) + # slow_intro_conn cache is now populated with intro_1_t + + async def wait_and_drop(): + await asyncio.sleep(0.1, loop=self.loop) + await slow_intro_conn.reload_schema_state() + + # Now, in parallel, run another query that + # references both intro_1_t and intro_2_t. + await asyncio.gather( + slow_intro_conn.fetchval(''' + SELECT $1::intro_1_t, $2::intro_2_t + ''', 10, 20), + wait_and_drop() + ) + + # Initial query + two tries for the second query. + self.assertEqual(slow_intro_conn.introspect_count, 3) + + finally: + await self.con.execute(''' + DROP DOMAIN intro_1_t; + DROP DOMAIN intro_2_t; + ''') + await slow_intro_conn.close()