Fix type codec cache races #301

Merged
merged 1 commit on Jun 1, 2018
51 changes: 41 additions & 10 deletions asyncpg/connection.py
@@ -296,17 +296,34 @@ async def _get_statement(self, query, timeout, *, named: bool=False,
stmt_name = ''

statement = await self._protocol.prepare(stmt_name, query, timeout)
ready = statement._init_types()
if ready is not True:
types, intro_stmt = await self.__execute(
self._intro_query, (list(ready),), 0, timeout)
self._protocol.get_settings().register_data_types(types)
need_reprepare = False
types_with_missing_codecs = statement._init_types()
tries = 0
while types_with_missing_codecs:
settings = self._protocol.get_settings()

# Introspect newly seen types and populate the
# codec cache.
types, intro_stmt = await self._introspect_types(
types_with_missing_codecs, timeout)

settings.register_data_types(types)

# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
need_reprepare = not intro_stmt.name and not statement.name
else:
need_reprepare = False
types_with_missing_codecs = statement._init_types()
tries += 1
if tries > 5:
# In the vast majority of cases there will be only
# one iteration. In rare cases, there might be a race
# with reload_schema_state(), which would cause a
# second try. More than five is clearly a bug.
raise exceptions.InternalClientError(
'could not resolve query result and/or argument types '
'in {} attempts'.format(tries)
)

# Now that types have been resolved, populate the codec pipeline
# for the statement.
@@ -326,6 +343,10 @@ async def _get_statement(self, query, timeout, *, named: bool=False,

return statement

async def _introspect_types(self, typeoids, timeout):
return await self.__execute(
self._intro_query, (list(typeoids),), 0, timeout)

def cursor(self, query, *args, prefetch=None, timeout=None):
"""Return a *cursor factory* for the specified query.

@@ -1271,6 +1292,18 @@ def _drop_global_statement_cache(self):
else:
self._drop_local_statement_cache()

def _drop_local_type_cache(self):
self._protocol.get_settings().clear_type_cache()

def _drop_global_type_cache(self):
if self._proxy is not None:
# This connection is a member of a pool, so we delegate
# the cache drop to the pool.
pool = self._proxy._holder._pool
pool._drop_type_cache()
else:
self._drop_local_type_cache()

async def reload_schema_state(self):
"""Indicate that the database schema information must be reloaded.

@@ -1313,9 +1346,7 @@ async def reload_schema_state(self):

.. versionadded:: 0.14.0
"""
# It is enough to clear the type cache only once, not in each
# connection in the pool.
self._protocol.get_settings().clear_type_cache()
self._drop_global_type_cache()
self._drop_global_statement_cache()

async def _execute(self, query, args, limit, timeout, return_status=False):
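The core of this fix is the retry loop in `_get_statement()` above. A minimal sketch of that pattern, using stand-in helper names (`missing_types`, `introspect`, `register`) rather than asyncpg's private `_init_types()` / `_introspect_types()` API:

```python
async def resolve_types(statement, introspect, register, max_tries=5):
    # Keep introspecting until the statement has codecs for all of its
    # parameter and result types; a concurrent reload_schema_state() can
    # clear the codec cache mid-flight, which is what forces the re-check.
    tries = 0
    missing = statement.missing_types()      # OIDs without codecs yet
    while missing:
        types = await introspect(missing)    # catalog lookup
        register(types)                      # populate the codec cache
        missing = statement.missing_types()  # re-check after a possible bust
        tries += 1
        if tries > max_tries:
            raise RuntimeError(
                'could not resolve types in {} attempts'.format(tries))
```

Bounding the number of attempts turns an unexpected livelock into a loud error (the real code raises `InternalClientError`) instead of spinning forever.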
6 changes: 6 additions & 0 deletions asyncpg/pool.py
@@ -614,6 +614,12 @@ def _drop_statement_cache(self):
if ch._con is not None:
ch._con._drop_local_statement_cache()

def _drop_type_cache(self):
# Drop type codec cache for all connections in the pool.
for ch in self._holders:
if ch._con is not None:
ch._con._drop_local_type_cache()

def __await__(self):
return self._async__init__().__await__()

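Dropping the type cache follows the same delegation shape already used for the statement cache: a pooled connection forwards the drop to its pool, which fans it out to every live connection. A rough plain-Python model of that wiring, with simplified attribute names standing in for asyncpg's internals:

```python
class Holder:
    # One slot in the pool; `connection` is None when no connection
    # is currently attached to the slot.
    def __init__(self, connection=None):
        self.connection = connection


class Pool:
    def __init__(self, holders):
        self.holders = holders

    def drop_type_cache(self):
        # Fan the cache drop out to every live connection in the pool.
        for holder in self.holders:
            if holder.connection is not None:
                holder.connection.drop_local_type_cache()


class Connection:
    def __init__(self, pool=None):
        self.pool = pool                 # set when the connection is pooled
        self.type_codecs = {}

    def drop_local_type_cache(self):
        self.type_codecs.clear()

    def drop_global_type_cache(self):
        if self.pool is not None:
            self.pool.drop_type_cache()  # clear the siblings as well
        else:
            self.drop_local_type_cache()
```

This is why calling `reload_schema_state()` on one pooled connection is enough: the pool clears the codec cache on all of its members, not just the connection the call was made on.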
4 changes: 2 additions & 2 deletions asyncpg/protocol/codecs/base.pxd
@@ -167,8 +167,8 @@ cdef class Codec:

cdef class DataCodecConfig:
cdef:
dict _type_codecs_cache
dict _local_type_codecs
dict _derived_type_codecs
dict _custom_type_codecs

cdef inline Codec get_codec(self, uint32_t oid, ServerDataFormat format)
cdef inline Codec get_local_codec(self, uint32_t oid)
32 changes: 15 additions & 17 deletions asyncpg/protocol/codecs/base.pyx
@@ -10,7 +10,6 @@ from asyncpg.exceptions import OutdatedSchemaCacheError

cdef void* binary_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef void* text_codec_map[(MAXSUPPORTEDOID + 1) * 2]
cdef dict TYPE_CODECS_CACHE = {}
cdef dict EXTRA_CODECS = {}


@@ -391,12 +390,11 @@ cdef uint32_t pylong_as_oid(val) except? 0xFFFFFFFFl:

cdef class DataCodecConfig:
def __init__(self, cache_key):
try:
self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key]
except KeyError:
self._type_codecs_cache = TYPE_CODECS_CACHE[cache_key] = {}

self._local_type_codecs = {}
# Codec instance cache for derived types:
# composites, arrays, ranges, domains and their combinations.
self._derived_type_codecs = {}
# Codec instances set up by the user for the connection.
self._custom_type_codecs = {}

def add_types(self, types):
cdef:
@@ -451,7 +449,7 @@

elem_delim = <Py_UCS4>ti['elemdelim'][0]

self._type_codecs_cache[oid, elem_format] = \
self._derived_type_codecs[oid, elem_format] = \
Codec.new_array_codec(
oid, name, schema, elem_codec, elem_delim)

@@ -483,7 +481,7 @@
if has_text_elements:
format = PG_FORMAT_TEXT

self._type_codecs_cache[oid, format] = \
self._derived_type_codecs[oid, format] = \
Codec.new_composite_codec(
oid, name, schema, format, comp_elem_codecs,
comp_type_attrs, element_names)
@@ -502,7 +500,7 @@
elem_codec = self.declare_fallback_codec(
base_type, name, schema)

self._type_codecs_cache[oid, format] = elem_codec
self._derived_type_codecs[oid, format] = elem_codec

elif ti['kind'] == b'r':
# Range type
@@ -523,7 +521,7 @@
elem_codec = self.declare_fallback_codec(
range_subtype_oid, name, schema)

self._type_codecs_cache[oid, elem_format] = \
self._derived_type_codecs[oid, elem_format] = \
Codec.new_range_codec(oid, name, schema, elem_codec)

elif ti['kind'] == b'e':
@@ -554,13 +552,13 @@
# Clear all previous overrides (this also clears type cache).
self.remove_python_codec(typeoid, typename, typeschema)

self._local_type_codecs[typeoid] = \
self._custom_type_codecs[typeoid] = \
Codec.new_python_codec(oid, typename, typeschema, typekind,
encoder, decoder, c_encoder, c_decoder,
format, xformat)

def remove_python_codec(self, typeoid, typename, typeschema):
self._local_type_codecs.pop(typeoid, None)
self._custom_type_codecs.pop(typeoid, None)
self.clear_type_cache()

def _set_builtin_type_codec(self, typeoid, typename, typeschema, typekind,
@@ -592,7 +590,7 @@
codec.schema = typeschema
codec.kind = typekind

self._local_type_codecs[typeoid] = codec
self._custom_type_codecs[typeoid] = codec
break
else:
raise ValueError('unknown alias target: {}'.format(alias_to))
@@ -604,7 +602,7 @@
self.clear_type_cache()

def clear_type_cache(self):
self._type_codecs_cache.clear()
self._derived_type_codecs.clear()

def declare_fallback_codec(self, uint32_t oid, str name, str schema):
cdef Codec codec
@@ -654,12 +652,12 @@
return codec
else:
try:
return self._type_codecs_cache[oid, format]
return self._derived_type_codecs[oid, format]
except KeyError:
return None

cdef inline Codec get_local_codec(self, uint32_t oid):
return self._local_type_codecs.get(oid)
return self._custom_type_codecs.get(oid)


cdef inline Codec get_core_codec(
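The renames here carry the key distinction of the fix: codecs derived from introspection (arrays, composites, ranges, domains) live in a cache that `clear_type_cache()` may wipe at any time, while codecs registered explicitly by the user survive the wipe. The module-level `TYPE_CODECS_CACHE` keyed by `cache_key` is gone as well, so derived codecs are no longer shared between connections behind the scenes. A plain-Python approximation of the split (names simplified from the Cython class):

```python
class CodecConfig:
    def __init__(self):
        # Derived from introspection on demand: arrays, composites,
        # ranges, domains and their combinations.  Safe to rebuild.
        self.derived_type_codecs = {}
        # Registered explicitly by the user (e.g. via set_type_codec());
        # must survive a schema-cache bust.
        self.custom_type_codecs = {}

    def clear_type_cache(self):
        # Only the derived codecs are dropped; user codecs stay put.
        self.derived_type_codecs.clear()

    def get_codec(self, oid, fmt):
        return self.derived_type_codecs.get((oid, fmt))

    def get_custom_codec(self, oid):
        return self.custom_type_codecs.get(oid)
```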
11 changes: 4 additions & 7 deletions asyncpg/protocol/prepared_stmt.pyx
@@ -63,24 +63,21 @@ cdef class PreparedStatementState:
def _init_types(self):
cdef:
Codec codec
set result = set()
set missing = set()

if self.parameters_desc:
for p_oid in self.parameters_desc:
codec = self.settings.get_data_codec(<uint32_t>p_oid)
if codec is None or not codec.has_encoder():
result.add(p_oid)
missing.add(p_oid)

if self.row_desc:
for rdesc in self.row_desc:
codec = self.settings.get_data_codec(<uint32_t>(rdesc[3]))
if codec is None or not codec.has_decoder():
result.add(rdesc[3])
missing.add(rdesc[3])

if len(result):
return result
else:
return True
return missing

cpdef _init_codecs(self):
self._ensure_args_encoder()
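`_init_types()` now always returns a set of OIDs that still lack codecs, with an empty set meaning "ready", instead of the old set-or-`True` mix; that is what lets the caller in connection.py use a plain `while` loop. A toy illustration of the contract (stand-in function, not asyncpg's actual code; 23 and 25 are the int4 and text OIDs):

```python
def missing_codecs(oids, codec_table):
    """Return the OIDs that do not yet have a codec (possibly empty)."""
    return {oid for oid in oids if oid not in codec_table}


codecs = {23: 'int4', 25: 'text'}
print(missing_codecs([23, 25], codecs))    # set()  -> falsy, nothing to do
print(missing_codecs([23, 600], codecs))   # {600}  -> needs introspection
```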
3 changes: 2 additions & 1 deletion asyncpg/protocol/settings.pxd
@@ -25,4 +25,5 @@ cdef class ConnectionSettings:
cpdef inline clear_type_cache(self)
cpdef inline set_builtin_type_codec(
self, typeoid, typename, typeschema, typekind, alias_to)
cpdef inline Codec get_data_codec(self, uint32_t oid, ServerDataFormat format=*)
cpdef inline Codec get_data_codec(
self, uint32_t oid, ServerDataFormat format=*)
50 changes: 50 additions & 0 deletions tests/test_introspection.py
@@ -5,6 +5,7 @@
# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0


import asyncio
import json

from asyncpg import _testbase as tb
@@ -14,6 +15,16 @@
MAX_RUNTIME = 0.1


class SlowIntrospectionConnection(apg_con.Connection):
"""Connection class to test introspection races."""
introspect_count = 0

async def _introspect_types(self, *args, **kwargs):
self.introspect_count += 1
await asyncio.sleep(0.4, loop=self._loop)
return await super()._introspect_types(*args, **kwargs)


class TestIntrospection(tb.ConnectedTestCase):
@classmethod
def setUpClass(cls):
@@ -125,3 +136,42 @@ async def test_introspection_sticks_for_ps(self):
finally:
await self.con.reset_type_codec(
'json', schema='pg_catalog')

async def test_introspection_retries_after_cache_bust(self):
# Test that codec cache bust racing with the introspection
# query would cause introspection to retry.
slow_intro_conn = await self.connect(
connection_class=SlowIntrospectionConnection)
try:
await self.con.execute('''
CREATE DOMAIN intro_1_t AS int;
CREATE DOMAIN intro_2_t AS int;
''')

await slow_intro_conn.fetchval('''
SELECT $1::intro_1_t
''', 10)
# slow_intro_conn cache is now populated with intro_1_t

async def wait_and_drop():
await asyncio.sleep(0.1, loop=self.loop)
await slow_intro_conn.reload_schema_state()

# Now, in parallel, run another query that
# references both intro_1_t and intro_2_t.
await asyncio.gather(
slow_intro_conn.fetchval('''
SELECT $1::intro_1_t, $2::intro_2_t
''', 10, 20),
wait_and_drop()
)

# Initial query + two tries for the second query.
self.assertEqual(slow_intro_conn.introspect_count, 3)

finally:
await self.con.execute('''
DROP DOMAIN intro_1_t;
DROP DOMAIN intro_2_t;
''')
await slow_intro_conn.close()