Skip to content

Use the general statement cache for type introspection #199

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 2, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 28 additions & 32 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ class Connection(metaclass=ConnectionMeta):
Connections are created by calling :func:`~asyncpg.connection.connect`.
"""

__slots__ = ('_protocol', '_transport', '_loop', '_types_stmt',
'_type_by_name_stmt', '_top_xact', '_uid', '_aborted',
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_uid', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
Expand All @@ -53,8 +53,6 @@ def __init__(self, protocol, transport, loop,
self._protocol = protocol
self._transport = transport
self._loop = loop
self._types_stmt = None
self._type_by_name_stmt = None
self._top_xact = None
self._uid = 0
self._aborted = False
Expand Down Expand Up @@ -286,14 +284,17 @@ async def _get_statement(self, query, timeout, *, named: bool=False):
stmt_name = ''

statement = await self._protocol.prepare(stmt_name, query, timeout)

ready = statement._init_types()
if ready is not True:
if self._types_stmt is None:
self._types_stmt = await self.prepare(self._intro_query)

types = await self._types_stmt.fetch(list(ready))
types, intro_stmt = await self.__execute(
self._intro_query, (list(ready),), 0, timeout)
self._protocol.get_settings().register_data_types(types)
if not intro_stmt.name and not statement.name:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just check for statement.name? The current check won’t work correctly if statement is anonymous and intro_stmt is not. Also, tests?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think that's true. From the docs:

An unnamed prepared statement lasts only until the next Parse statement specifying the unnamed statement as destination is issued.

I added another test to show that.

# The introspection query has used an anonymous statement,
# which has blown away the anonymous statement we've prepared
# for the query, so we need to re-prepare it.
statement = await self._protocol.prepare(
stmt_name, query, timeout)

if use_cache:
self._stmt_cache.put(query, statement)
Expand Down Expand Up @@ -886,12 +887,8 @@ async def set_type_codec(self, typename, *,
"asyncpg 0.13.0. Use the `format` keyword argument instead.",
DeprecationWarning, stacklevel=2)

if self._type_by_name_stmt is None:
self._type_by_name_stmt = await self.prepare(
introspection.TYPE_BY_NAME)

typeinfo = await self._type_by_name_stmt.fetchrow(
typename, schema)
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

Expand Down Expand Up @@ -921,12 +918,8 @@ async def reset_type_codec(self, typename, *, schema='public'):
.. versionadded:: 0.12.0
"""

if self._type_by_name_stmt is None:
self._type_by_name_stmt = await self.prepare(
introspection.TYPE_BY_NAME)

typeinfo = await self._type_by_name_stmt.fetchrow(
typename, schema)
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

Expand All @@ -949,12 +942,8 @@ async def set_builtin_type_codec(self, typename, *,
"""
self._check_open()

if self._type_by_name_stmt is None:
self._type_by_name_stmt = await self.prepare(
introspection.TYPE_BY_NAME)

typeinfo = await self._type_by_name_stmt.fetchrow(
typename, schema)
typeinfo = await self.fetchrow(
introspection.TYPE_BY_NAME, typename, schema)
if not typeinfo:
raise ValueError('unknown type: {}.{}'.format(schema, typename))

Expand Down Expand Up @@ -1209,18 +1198,25 @@ def _drop_global_statement_cache(self):
self._drop_local_statement_cache()

async def _execute(self, query, args, limit, timeout, return_status=False):
with self._stmt_exclusive_section:
result, _ = await self.__execute(
query, args, limit, timeout, return_status=return_status)
return result

async def __execute(self, query, args, limit, timeout,
return_status=False):
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
return await self._do_execute(query, executor, timeout)
return await self._do_execute(query, executor, timeout)

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
return await self._do_execute(query, executor, timeout)
result, _ = await self._do_execute(query, executor, timeout)
return result

async def _do_execute(self, query, executor, timeout, retry=True):
if timeout is None:
Expand Down Expand Up @@ -1269,10 +1265,10 @@ async def _do_execute(self, query, executor, timeout, retry=True):
if self._protocol.is_in_transaction() or not retry:
raise
else:
result = await self._do_execute(
return await self._do_execute(
query, executor, timeout, retry=False)

return result
return result, stmt


async def connect(dsn=None, *,
Expand Down
50 changes: 49 additions & 1 deletion tests/test_introspection.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
MAX_RUNTIME = 0.1


class TestTimeout(tb.ConnectedTestCase):
class TestIntrospection(tb.ConnectedTestCase):
@classmethod
def setUpClass(cls):
super().setUpClass()
Expand Down Expand Up @@ -44,3 +44,51 @@ async def test_introspection_on_large_db(self):

with self.assertRunUnder(MAX_RUNTIME):
await self.con.fetchval('SELECT $1::int[]', [1, 2])

@tb.with_connection_options(statement_cache_size=0)
async def test_introspection_no_stmt_cache_01(self):
self.assertEqual(self.con._stmt_cache.get_max_size(), 0)
await self.con.fetchval('SELECT $1::int[]', [1, 2])

await self.con.execute('''
CREATE EXTENSION IF NOT EXISTS hstore
''')

try:
await self.con.set_builtin_type_codec(
'hstore', codec_name='pg_contrib.hstore')
finally:
await self.con.execute('''
DROP EXTENSION hstore
''')

self.assertEqual(self.con._uid, 0)

@tb.with_connection_options(max_cacheable_statement_size=1)
async def test_introspection_no_stmt_cache_02(self):
# max_cacheable_statement_size will disable caching both for
# the user query and for the introspection query.
await self.con.fetchval('SELECT $1::int[]', [1, 2])

await self.con.execute('''
CREATE EXTENSION IF NOT EXISTS hstore
''')

try:
await self.con.set_builtin_type_codec(
'hstore', codec_name='pg_contrib.hstore')
finally:
await self.con.execute('''
DROP EXTENSION hstore
''')

self.assertEqual(self.con._uid, 0)

@tb.with_connection_options(max_cacheable_statement_size=10000)
async def test_introspection_no_stmt_cache_03(self):
# max_cacheable_statement_size will disable caching for
# the user query but not for the introspection query.
await self.con.fetchval(
"SELECT $1::int[], '{foo}'".format(foo='a' * 10000), [1, 2])

self.assertEqual(self.con._uid, 1)