diff --git a/asyncpg/.gitignore b/asyncpg/.gitignore index 2d19fc76..544eb359 100644 --- a/asyncpg/.gitignore +++ b/asyncpg/.gitignore @@ -1 +1,2 @@ *.html +/*.c diff --git a/asyncpg/__init__.py b/asyncpg/__init__.py index 380dc26e..21fd5456 100644 --- a/asyncpg/__init__.py +++ b/asyncpg/__init__.py @@ -10,6 +10,7 @@ from .pool import create_pool # NOQA from .protocol import Record # NOQA from .types import * # NOQA +from .query_pp import keyword_parameters # NOQA __all__ = ('connect', 'create_pool', 'Record', 'Connection') + \ diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index da616523..f68aa9ff 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -38,6 +38,7 @@ 'statement_cache_size', 'max_cached_statement_lifetime', 'max_cacheable_statement_size', + 'query_pp' ]) @@ -210,7 +211,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database, timeout, command_timeout, statement_cache_size, max_cached_statement_lifetime, max_cacheable_statement_size, - ssl, server_settings): + ssl, server_settings, query_pp): local_vars = locals() for var_name in {'max_cacheable_statement_size', @@ -245,7 +246,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database, command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size,) + max_cacheable_statement_size=max_cacheable_statement_size, + query_pp=query_pp) return addrs, params, config diff --git a/asyncpg/connection.py b/asyncpg/connection.py index ea62b7a4..8851a9c0 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -18,6 +18,7 @@ from . import cursor from . import exceptions from . import introspection +from . import query_pp from . import prepared_stmt from . import protocol from . import serverversion @@ -44,7 +45,7 @@ class Connection(metaclass=ConnectionMeta): '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', - '_log_listeners') + '_log_listeners', '_query_pp') def __init__(self, protocol, transport, loop, addr: (str, int) or str, @@ -66,6 +67,7 @@ def __init__(self, protocol, transport, loop, self._addr = addr self._config = config self._params = params + self._query_pp = config.query_pp self._stmt_cache = _StatementCache( loop=loop, @@ -206,7 +208,8 @@ def transaction(self, *, isolation='read_committed', readonly=False, self._check_open() return transaction.Transaction(self, isolation, readonly, deferrable) - async def execute(self, query: str, *args, timeout: float=None) -> str: + async def execute(self, query: str, *args, kwargs=None, + timeout: float=None) -> str: """Execute an SQL command (or commands). This method can execute many SQL commands at once, when no arguments @@ -239,7 +242,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str: if not args: return await self._protocol.query(query, timeout) - _, status, _ = await self._execute(query, args, 0, timeout, True) + _, status, _ = await self._execute( + query, args, kwargs, 0, timeout, True) return status.decode() async def executemany(self, command: str, args, *, timeout: float=None): @@ -305,7 +309,7 @@ async def _get_statement(self, query, timeout, *, named: bool=False): return statement - def cursor(self, query, *args, prefetch=None, timeout=None): + def cursor(self, query, *args, kwargs=None, prefetch=None, timeout=None): """Return a *cursor factory* for the specified query. :param args: Query arguments. @@ -316,7 +320,7 @@ def cursor(self, query, *args, prefetch=None, timeout=None): :return: A :class:`~cursor.CursorFactory` object. """ self._check_open() - return cursor.CursorFactory(self, query, None, args, + return cursor.CursorFactory(self, query, None, args, kwargs, prefetch, timeout) async def prepare(self, query, *, timeout=None): @@ -329,9 +333,9 @@ async def prepare(self, query, *, timeout=None): """ self._check_open() stmt = await self._get_statement(query, timeout, named=True) - return prepared_stmt.PreparedStatement(self, query, stmt) + return prepared_stmt.PreparedStatement(self, stmt) - async def fetch(self, query, *args, timeout=None) -> list: + async def fetch(self, query, *args, kwargs=None, timeout=None) -> list: """Run a query and return the results as a list of :class:`Record`. :param str query: Query text. @@ -341,9 +345,10 @@ async def fetch(self, query, *args, timeout=None) -> list: :return list: A list of :class:`Record` instances. """ self._check_open() - return await self._execute(query, args, 0, timeout) + return await self._execute(query, args, kwargs, 0, timeout) - async def fetchval(self, query, *args, column=0, timeout=None): + async def fetchval(self, query, *args, kwargs=None, + column=0, timeout=None): """Run a query and return a value in the first row. :param str query: Query text. @@ -359,12 +364,12 @@ async def fetchval(self, query, *args, column=0, timeout=None): None if no records were returned by the query. """ self._check_open() - data = await self._execute(query, args, 1, timeout) + data = await self._execute(query, args, kwargs, 1, timeout) if not data: return None return data[0][column] - async def fetchrow(self, query, *args, timeout=None): + async def fetchrow(self, query, *args, kwargs=None, timeout=None): """Run a query and return the first row. :param str query: Query text @@ -375,7 +380,7 @@ async def fetchrow(self, query, *args, timeout=None): no records were returned by the query. """ self._check_open() - data = await self._execute(query, args, 1, timeout) + data = await self._execute(query, args, kwargs, 1, timeout) if not data: return None return data[0] @@ -451,7 +456,7 @@ async def copy_from_table(self, table_name, *, output, return await self._copy_out(copy_stmt, output, timeout) - async def copy_from_query(self, query, *args, output, + async def copy_from_query(self, query, *args, kwargs=None, output, timeout=None, format=None, oids=None, delimiter=None, null=None, header=None, quote=None, escape=None, force_quote=None, @@ -504,8 +509,8 @@ async def copy_from_query(self, query, *args, output, force_quote=force_quote, encoding=encoding ) - if args: - query = await utils._mogrify(self, query, args) + if args or kwargs: + query = await utils._mogrify(self, query, args, kwargs) copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format( query=query, opts=opts) @@ -1208,9 +1213,10 @@ def _drop_global_statement_cache(self): else: self._drop_local_statement_cache() - async def _execute(self, query, args, limit, timeout, return_status=False): + async def _execute(self, query, args, kwargs, + limit, timeout, return_status=False): executor = lambda stmt, timeout: self._protocol.bind_execute( - stmt, args, '', limit, return_status, timeout) + stmt, args, kwargs, '', limit, return_status, timeout) timeout = self._protocol._get_timeout(timeout) with self._stmt_exclusive_section: return await self._do_execute(query, executor, timeout) @@ -1287,7 +1293,8 @@ async def connect(dsn=None, *, command_timeout=None, ssl=None, connection_class=Connection, - server_settings=None): + server_settings=None, + query_pp=None): r"""A coroutine to establish a connection to a PostgreSQL server. Returns a new :class:`~asyncpg.connection.Connection` object. @@ -1403,7 +1410,8 @@ class of the returned connection object. Must be a subclass of command_timeout=command_timeout, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, - max_cacheable_statement_size=max_cacheable_statement_size) + max_cacheable_statement_size=max_cacheable_statement_size, + query_pp=query_pp) class _StatementCacheEntry: diff --git a/asyncpg/cursor.py b/asyncpg/cursor.py index 030def0e..b6933b0f 100644 --- a/asyncpg/cursor.py +++ b/asyncpg/cursor.py @@ -19,11 +19,14 @@ class CursorFactory(connresource.ConnectionResource): results of a large query. """ - __slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout') + __slots__ = ('_state', '_args', '_kwargs', '_prefetch', + '_query', '_timeout') - def __init__(self, connection, query, state, args, prefetch, timeout): + def __init__(self, connection, query, state, args, kwargs, + prefetch, timeout): super().__init__(connection) self._args = args + self._kwargs = kwargs self._prefetch = prefetch self._query = query self._timeout = timeout @@ -37,7 +40,7 @@ def __aiter__(self): prefetch = 50 if self._prefetch is None else self._prefetch return CursorIterator(self._connection, self._query, self._state, - self._args, prefetch, + self._args, self._kwargs, prefetch, self._timeout) @connresource.guarded @@ -46,7 +49,7 @@ def __await__(self): raise exceptions.InterfaceError( 'prefetch argument can only be specified for iterable cursor') cursor = Cursor(self._connection, self._query, - self._state, self._args) + self._state, self._args, self._kwargs) return cursor._init(self._timeout).__await__() def __del__(self): @@ -57,11 +60,13 @@ def __del__(self): class BaseCursor(connresource.ConnectionResource): - __slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query') + __slots__ = ('_state', '_args', '_kwargs', '_portal_name', + '_exhausted', '_query') - def __init__(self, connection, query, state, args): + def __init__(self, connection, query, state, args, kwargs): super().__init__(connection) self._args = args + self._kwargs = kwargs self._state = state if state is not None: state.attach() @@ -94,7 +99,8 @@ async def _bind_exec(self, n, timeout): self._portal_name = con._get_unique_id('portal') buffer, _, self._exhausted = await protocol.bind_execute( - self._state, self._args, self._portal_name, n, True, timeout) + self._state, self._args, self._kwargs, self._portal_name, + n, True, timeout) return buffer async def _bind(self, timeout): @@ -108,7 +114,7 @@ async def _bind(self, timeout): protocol = con._protocol self._portal_name = con._get_unique_id('portal') - buffer = await protocol.bind(self._state, self._args, + buffer = await protocol.bind(self._state, self._args, self._kwargs, self._portal_name, timeout) return buffer @@ -151,8 +157,9 @@ class CursorIterator(BaseCursor): __slots__ = ('_buffer', '_prefetch', '_timeout') - def __init__(self, connection, query, state, args, prefetch, timeout): - super().__init__(connection, query, state, args) + def __init__(self, connection, query, state, args, kwargs, + prefetch, timeout): + super().__init__(connection, query, state, args, kwargs) if prefetch <= 0: raise exceptions.InterfaceError( diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index ce95d6e0..d4ab3629 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -15,12 +15,11 @@ class PreparedStatement(connresource.ConnectionResource): """A representation of a prepared statement.""" - __slots__ = ('_state', '_query', '_last_status') + __slots__ = ('_state', '_last_status') - def __init__(self, connection, query, state): + def __init__(self, connection, state): super().__init__(connection) self._state = state - self._query = query state.attach() self._last_status = None @@ -33,7 +32,7 @@ def get_query(self) -> str: stmt = await connection.prepare('SELECT $1::int') assert stmt.get_query() == "SELECT $1::int" """ - return self._query + return self._state.query @connresource.guarded def get_statusmsg(self) -> str: @@ -92,7 +91,7 @@ def get_attributes(self): return self._state._get_attributes() @connresource.guarded - def cursor(self, *args, prefetch=None, + def cursor(self, *args, kwargs=None, prefetch=None, timeout=None) -> cursor.CursorFactory: """Return a *cursor factory* for the prepared statement. @@ -103,12 +102,12 @@ def cursor(self, *args, prefetch=None, :return: A :class:`~cursor.CursorFactory` object. """ - return cursor.CursorFactory(self._connection, self._query, - self._state, args, prefetch, + return cursor.CursorFactory(self._connection, self._state.query, + self._state, args, kwargs, prefetch, timeout) @connresource.guarded - async def explain(self, *args, analyze=False): + async def explain(self, *args, kwargs=None, analyze=False): """Return the execution plan of the statement. :param args: Query arguments. @@ -141,16 +140,18 @@ async def explain(self, *args, analyze=False): tr = self._connection.transaction() await tr.start() try: - data = await self._connection.fetchval(query, *args) + data = await self._connection.fetchval( + query, *args, kwargs=kwargs) finally: await tr.rollback() else: - data = await self._connection.fetchval(query, *args) + data = await self._connection.fetchval( + query, *args, kwargs=kwargs) return json.loads(data) @connresource.guarded - async def fetch(self, *args, timeout=None): + async def fetch(self, *args, kwargs=None, timeout=None): r"""Execute the statement and return a list of :class:`Record` objects. :param str query: Query text @@ -159,11 +160,11 @@ async def fetch(self, *args, timeout=None): :return: A list of :class:`Record` instances. """ - data = await self.__bind_execute(args, 0, timeout) + data = await self.__bind_execute(args, kwargs, 0, timeout) return data @connresource.guarded - async def fetchval(self, *args, column=0, timeout=None): + async def fetchval(self, *args, kwargs=None, column=0, timeout=None): """Execute the statement and return a value in the first row. :param args: Query arguments. @@ -176,13 +177,13 @@ async def fetchval(self, *args, column=0, timeout=None): :return: The value of the specified column of the first record. """ - data = await self.__bind_execute(args, 1, timeout) + data = await self.__bind_execute(args, kwargs, 1, timeout) if not data: return None return data[0][column] @connresource.guarded - async def fetchrow(self, *args, timeout=None): + async def fetchrow(self, *args, kwargs=None, timeout=None): """Execute the statement and return the first row. :param str query: Query text @@ -191,15 +192,15 @@ async def fetchrow(self, *args, timeout=None): :return: The first row as a :class:`Record` instance. """ - data = await self.__bind_execute(args, 1, timeout) + data = await self.__bind_execute(args, kwargs, 1, timeout) if not data: return None return data[0] - async def __bind_execute(self, args, limit, timeout): + async def __bind_execute(self, args, kwargs, limit, timeout): protocol = self._connection._protocol data, status, _ = await protocol.bind_execute( - self._state, args, '', limit, True, timeout) + self._state, args, kwargs, '', limit, True, timeout) self._last_status = status return data diff --git a/asyncpg/protocol/prepared_stmt.pxd b/asyncpg/protocol/prepared_stmt.pxd index 8dab35b1..df5b41e8 100644 --- a/asyncpg/protocol/prepared_stmt.pxd +++ b/asyncpg/protocol/prepared_stmt.pxd @@ -29,9 +29,14 @@ cdef class PreparedStatementState: bint have_text_cols tuple rows_codecs - cdef _encode_bind_msg(self, args) + bint have_query_pp + tuple kwargs_order + + cdef _encode_bind_msg(self, args, kwargs) cdef _ensure_rows_decoder(self) cdef _ensure_args_encoder(self) cdef _set_row_desc(self, object desc) cdef _set_args_desc(self, object desc) cdef _decode_row(self, const char* cbuf, ssize_t buf_len) + + cpdef apply_kwargs(self, args, kwargs) diff --git a/asyncpg/protocol/prepared_stmt.pyx b/asyncpg/protocol/prepared_stmt.pyx index 3edb56f0..109d602b 100644 --- a/asyncpg/protocol/prepared_stmt.pyx +++ b/asyncpg/protocol/prepared_stmt.pyx @@ -11,9 +11,12 @@ from asyncpg import exceptions @cython.final cdef class PreparedStatementState: - def __cinit__(self, str name, str query, BaseProtocol protocol): + def __cinit__(self, str name, str query, BaseProtocol protocol, + bint have_query_pp, tuple kwargs_order): self.name = name self.query = query + self.have_query_pp = have_query_pp + self.kwargs_order = kwargs_order self.protocol = protocol self.settings = protocol.settings self.row_desc = self.parameters_desc = None @@ -91,7 +94,51 @@ cdef class PreparedStatementState: def mark_closed(self): self.closed = True - cdef _encode_bind_msg(self, args): + cpdef apply_kwargs(self, args, kwargs): + cdef: + bint has_args = bool(args) + bint has_kwargs = bool(kwargs) + + if has_args and has_kwargs: + raise exceptions.InterfaceError( + 'got both `kwargs` and positional arguments') + + if not self.have_query_pp: + if has_kwargs: + raise exceptions.InterfaceError( + 'no query preprocessor is defined on the connection ' + 'to enable support for the `kwargs` argument') + else: + if self.kwargs_order: + if not has_kwargs: + raise exceptions.InterfaceError( + 'query has keyword parameters but no `kwargs` ' + 'argument was provided') + + if len(kwargs) != len(self.kwargs_order): + missing_kwargs = set(self.kwargs_order) - set(kwargs) + raise exceptions.InterfaceError( + 'missing values for the following keyword ' + 'arguments: {!r}'.format(missing_kwargs)) + + args = [] + for name in self.kwargs_order: + try: + val = kwargs[name] + except KeyError: + raise exceptions.InterfaceError( + 'missing a value for the {!r} keyword ' + 'argument'.format(name)) from None + args.append(val) + else: + if has_kwargs: + raise exceptions.InterfaceError( + 'cannot use `kwargs`: query preprocessor found no ' + 'keyword parameters in the query') + + return args + + cdef _encode_bind_msg(self, args, kwargs): cdef: int idx WriteBuffer writer @@ -104,6 +151,8 @@ cdef class PreparedStatementState: self._ensure_args_encoder() self._ensure_rows_decoder() + args = self.apply_kwargs(args, kwargs) + writer = WriteBuffer.new() num_args_passed = len(args) diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index a602854f..c3fb82cc 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -37,6 +37,7 @@ cdef class BaseProtocol(CoreProtocol): object timeout_callback object completed_callback object connection + object query_pp bint is_reading str last_query diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index ac9e08d3..571acfe8 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -117,6 +117,9 @@ cdef class BaseProtocol(CoreProtocol): self.queries_count = 0 + self.connection = None + self.query_pp = None + try: self.create_future = loop.create_future except AttributeError: @@ -124,6 +127,7 @@ cdef class BaseProtocol(CoreProtocol): def set_connection(self, connection): self.connection = connection + self.query_pp = connection._query_pp def get_server_pid(self): return self.backend_pid @@ -156,18 +160,28 @@ cdef class BaseProtocol(CoreProtocol): self._check_state() timeout = self._get_timeout_impl(timeout) + kwargs_order = None + if self.query_pp is not None: + try: + query, kwargs_order = self.query_pp(query) + except Exception as ex: + raise apg_exc.InternalClientError( + 'exception while calling query preprocessor') from ex + waiter = self._new_waiter(timeout) try: self._prepare(stmt_name, query) # network op self.last_query = query - self.statement = PreparedStatementState(stmt_name, query, self) + self.statement = PreparedStatementState( + stmt_name, query, self, + self.query_pp is not None, kwargs_order) except Exception as ex: waiter.set_exception(ex) self._coreproto_error() finally: return await waiter - async def bind_execute(self, PreparedStatementState state, args, + async def bind_execute(self, PreparedStatementState state, args, kwargs, str portal_name, int limit, return_extra, timeout): @@ -179,7 +193,7 @@ cdef class BaseProtocol(CoreProtocol): self._check_state() timeout = self._get_timeout_impl(timeout) - args_buf = state._encode_bind_msg(args) + args_buf = state._encode_bind_msg(args, kwargs) waiter = self._new_waiter(timeout) try: @@ -214,7 +228,7 @@ cdef class BaseProtocol(CoreProtocol): # Make sure the argument sequence is encoded lazily with # this generator expression to keep the memory pressure under # control. - data_gen = (state._encode_bind_msg(b) for b in args) + data_gen = (state._encode_bind_msg(b, None) for b in args) arg_bufs = iter(data_gen) waiter = self._new_waiter(timeout) @@ -234,7 +248,7 @@ cdef class BaseProtocol(CoreProtocol): finally: return await waiter - async def bind(self, PreparedStatementState state, args, + async def bind(self, PreparedStatementState state, args, kwargs, str portal_name, timeout): if self.cancel_waiter is not None: @@ -245,7 +259,7 @@ cdef class BaseProtocol(CoreProtocol): self._check_state() timeout = self._get_timeout_impl(timeout) - args_buf = state._encode_bind_msg(args) + args_buf = state._encode_bind_msg(args, kwargs) waiter = self._new_waiter(timeout) try: diff --git a/asyncpg/protocol/python.pxd b/asyncpg/protocol/python.pxd index 3ae2d14b..14d39d34 100644 --- a/asyncpg/protocol/python.pxd +++ b/asyncpg/protocol/python.pxd @@ -31,3 +31,8 @@ cdef extern from "Python.h": int kind, const void *buffer, Py_ssize_t size) int PyUnicode_4BYTE_KIND + + int PyUnicode_KIND(object o) + void *PyUnicode_DATA(object o) + Py_UCS4 PyUnicode_READ(int kind, void *data, ssize_t index) + void PyUnicode_WRITE(int kind, void *data, Py_ssize_t index, Py_UCS4 value) diff --git a/asyncpg/query_pp.pyx b/asyncpg/query_pp.pyx new file mode 100644 index 00000000..bdbedc0d --- /dev/null +++ b/asyncpg/query_pp.pyx @@ -0,0 +1,380 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import asyncpg +import functools + +cimport cpython + +from asyncpg.protocol.python cimport ( + PyUnicode_KIND, PyUnicode_READ, PyUnicode_WRITE, PyUnicode_DATA, + PyUnicode_FromKindAndData) + + +cdef enum ParamLexMode: + MODE_NORMAL + MODE_STRING + MODE_ESTRING + MODE_QUOTE + MODE_COMMENT + MODE_BLOCK_COMMENT + MODE_DOLLAR + + +cdef inline Py_ssize_t _buf_append( + str source, Py_ssize_t source_from, Py_ssize_t source_to, + int dest_kind, void *dest, Py_ssize_t to_dest_pos): + + cdef: + void *source_buf = PyUnicode_DATA(source) + int source_kind = PyUnicode_KIND(source) + Py_ssize_t i + Py_UCS4 ch + + for i in range(source_from, source_to): + ch = PyUnicode_READ(source_kind, source_buf, i) + PyUnicode_WRITE(dest_kind, dest, to_dest_pos, ch) + to_dest_pos += 1 + + return to_dest_pos + + +cpdef transform_kwargs(str query): + # This optimized lexer is 10x faster on average + # than an equivalent pure-Python implementation. + + cdef: + ParamLexMode mode = MODE_NORMAL + Py_ssize_t i = 0 + Py_ssize_t query_len = len(query) + bint lexing = True + bint eof + dict named_params + list named_params_list + + str arg_name + str tag + str marker + + int ukind + + Py_UCS4 ch + Py_UCS4 prev_ch + Py_UCS4 prev_prev_ch + + bint has_positional_only = False + + void *query_buf = PyUnicode_DATA(query) + void *new_query_buf = NULL + Py_ssize_t new_query_buf_pos = 0 + + Py_ssize_t dollar_started = 0 + + if not query_len: + return query, None + + ukind = PyUnicode_KIND(query) + + new_query_buf = cpython.PyMem_Malloc( + ((ukind * 1.4 * query_len) * sizeof(void*))); + if new_query_buf == NULL: + raise MemoryError + + try: + named_params_list = [] + named_params = {} + + ch = PyUnicode_READ(ukind, query_buf, 0) + PyUnicode_WRITE(ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + while lexing: + if mode == MODE_NORMAL: + # Normal lexer mode when we are not parsing a string literal, + # a quoted identifier, a comment, or a `$..` sequence. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + if ch == u'"': + mode = MODE_QUOTE + break + elif ch == u"'": + prev_prev_ch = u'\0' + if query_len > 2: + prev_prev_ch = PyUnicode_READ( + ukind, query_buf, i - 2) + + if (prev_ch == u'E' or + prev_ch == u'e') and ( + # Check that we don't have + # `CASE'\'` situation. + (prev_prev_ch < u'A' or + prev_prev_ch > u'Z') and + (prev_prev_ch < u'a' or + prev_prev_ch > u'z')): + mode = MODE_ESTRING + else: + mode = MODE_STRING + break + elif prev_ch == u'-' and ch == u'-': + mode = MODE_COMMENT + break + elif prev_ch == u'/' and ch == u'*': + mode = MODE_BLOCK_COMMENT + break + elif ch == u'$': + mode = MODE_DOLLAR + break + + elif mode == MODE_QUOTE: + # Quoted identifier, such as `"name"`. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + if ch == u'"': + mode = MODE_NORMAL + break + + elif mode == MODE_STRING: + # Regular string literal, such as `'aaaa'`. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + if ch == u"'": + mode = MODE_NORMAL + break + + elif mode == MODE_ESTRING: + # String literal prefixed with `e` or `E`. + # For example: `E'aaa'` or `e'\''`. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + + if ch == u"'": + if prev_ch == u'\\': + continue + else: + mode = MODE_NORMAL + break + + elif mode == MODE_COMMENT: + # Single line "--" comment. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + if ch == u"\n": + mode = MODE_NORMAL + break + + elif mode == MODE_BLOCK_COMMENT: + # Block /* .. */ comment. + + while True: + i += 1 + if i >= query_len: + lexing = False + break + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + if prev_ch == u'*' and ch == u'/': + mode = MODE_NORMAL + break + + elif mode == MODE_DOLLAR: + # Whenever we see '$' we switch to this mode. + # The '$' character can be a start of: + # - an argument, such as `$1` or `$foo`; + # - a quoted string, such as `$$ .. $$` or `$foo$ .. $foo$`. + + while True: + i += 1 + eof = False + if i >= query_len: + ch = u'\0' + eof = True + else: + prev_ch = ch + ch = PyUnicode_READ(ukind, query_buf, i) + + if ch == u'$': + # We found a second '$' character, this looks like + # the beginning of a quoted string like + # `$$` or `$foo$`. + + assert ch != u'\0' + + if dollar_started == 0: + tag = '' + else: + tag = query[dollar_started:i] + marker = '$' + tag + '$' + + dollar_started = 0 + + new_query_buf_pos = _buf_append( + tag, 0, len(tag), + ukind, new_query_buf, new_query_buf_pos) + + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + + res = query.find(marker, i) + if res == -1: + # Can't find the matching end, as in + # "SELECT $aa$ ... " query. The query is + # likely invalid, but we don't care. + new_query_buf_pos = _buf_append( + query, i + 1, len(query), + ukind, new_query_buf, new_query_buf_pos) + + lexing = False + break + else: + # Found the end marker. + new_query_buf_pos = _buf_append( + query, i + 1, res + len(marker), + ukind, new_query_buf, new_query_buf_pos) + + i = res + len(marker) - 1 + mode = MODE_NORMAL + break + + elif not (ch >= u'A' and ch <= u'Z' or + ch >= u'a' and ch <= u'z' or + ch >= u'0' and ch <= u'9' or + ch == u'_') or eof: + + # This looks like an argument. + + if dollar_started == 0: + if not eof: + PyUnicode_WRITE( + ukind, new_query_buf, + new_query_buf_pos, ch) + new_query_buf_pos += 1 + mode = MODE_NORMAL + break + else: + arg_name = query[dollar_started:i] + + dollar_started = 0 + positional_only = arg_name.isdecimal() + + if positional_only and not named_params: + has_positional_only = True + + new_query_buf_pos = _buf_append( + arg_name, 0, len(arg_name), + ukind, new_query_buf, new_query_buf_pos) + + elif not positional_only and not has_positional_only: + if arg_name[0].isdecimal(): + raise ValueError( + 'invalid argument name {!r}: first ' + 'character is a digit') + + if arg_name not in named_params: + named_params[arg_name] = len(named_params) + 1 + named_params_list.append(arg_name) + + arg_name = str(named_params[arg_name]) + new_query_buf_pos = _buf_append( + arg_name, 0, len(arg_name), + ukind, new_query_buf, new_query_buf_pos) + + else: + raise ValueError( + 'queries with both named and positional-only ' + 'arguments are not supported') + + if not eof: + PyUnicode_WRITE( + ukind, new_query_buf, new_query_buf_pos, ch) + new_query_buf_pos += 1 + mode = MODE_NORMAL + break + + else: + assert ch != u'\0' + if dollar_started == 0: + dollar_started = i + + if i >= query_len: + lexing = False + break + + if named_params_list: + new_query = PyUnicode_FromKindAndData( + ukind, new_query_buf, new_query_buf_pos) + return new_query, tuple(named_params_list) + else: + return query, None + + finally: + cpython.PyMem_Free(new_query_buf) + + +@functools.lru_cache(250) +def keyword_parameters(query): + return transform_kwargs(query) diff --git a/asyncpg/utils.py b/asyncpg/utils.py index 3940e04d..7add1400 100644 --- a/asyncpg/utils.py +++ b/asyncpg/utils.py @@ -16,11 +16,17 @@ def _quote_literal(string): return "'{}'".format(string.replace("'", "''")) -async def _mogrify(conn, query, args): +async def _mogrify(conn, query, args, kwargs): """Safely inline arguments to query text.""" # Introspect the target query for argument types and # build a list of safely-quoted fully-qualified type names. ps = await conn.prepare(query) + + # Get the query from the prepared statement, as it potentially + # could be preprocessed. + query = ps.get_query() + args = ps._state.apply_kwargs(args, kwargs) + paramtypes = [] for t in ps.get_parameters(): if t.name.endswith('[]'): diff --git a/setup.py b/setup.py index fd8b0c26..ca842923 100644 --- a/setup.py +++ b/setup.py @@ -211,6 +211,12 @@ def _patch_cfile(self, cfile): ["asyncpg/protocol/record/recordobj.c", "asyncpg/protocol/protocol.pyx"], extra_compile_args=CFLAGS, + extra_link_args=LDFLAGS), + + setuptools.Extension( + "asyncpg.query_pp", + ["asyncpg/query_pp.pyx"], + extra_compile_args=CFLAGS, extra_link_args=LDFLAGS) ], cmdclass={'build_ext': build_ext}, diff --git a/tests/test_copy.py b/tests/test_copy.py index 6bedb23c..6f12abc9 100644 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -165,6 +165,30 @@ async def test_copy_from_query_with_args(self): ] ) + @tb.with_connection_options(query_pp=asyncpg.keyword_parameters) + async def test_copy_from_query_with_kwargs(self): + f = io.BytesIO() + + res = await self.con.copy_from_query(''' + SELECT + i, i * 10 + FROM + generate_series(1, 5) AS i + WHERE + i = $i + ''', kwargs=dict(i=3), output=f) + + self.assertEqual(res, 'COPY 1') + + output = f.getvalue().decode().split('\n') + self.assertEqual( + output, + [ + '3\t30', + '' + ] + ) + async def test_copy_from_query_to_path(self): with tempfile.NamedTemporaryFile() as f: f.close() diff --git a/tests/test_kwargs_transformer.py b/tests/test_kwargs_transformer.py new file mode 100644 index 00000000..a8dd3809 --- /dev/null +++ b/tests/test_kwargs_transformer.py @@ -0,0 +1,357 @@ +# Copyright (C) 2016-present the asyncpg authors and contributors +# +# +# This module is part of asyncpg and is released under +# the Apache 2.0 License: http://www.apache.org/licenses/LICENSE-2.0 + + +import unittest + +from asyncpg.query_pp import transform_kwargs # NOQA + + +class TestKwargsTransformer(unittest.TestCase): + + queries = [ + ('select 1', + 'select 1', + None), + + ('select $a', + 'select $1', + ('a',)), + + ('select $1', + 'select $1', + None), + + ('select $1 + "$1" + \'$1\' + $2', + 'select $1 + "$1" + \'$1\' + $2', + None), + + ('select $abc1', + 'select $1', + ('abc1',)), + + ('select $abc1 ', + 'select $1 ', + ('abc1',)), + + (''' + SELECT """" + $foo + "$foo""" + ''', + ''' + SELECT """" + $1 + "$foo""" + ''', + ('foo',)), + + (''' + SELECT """" + $foo + "$foo""" + -- $bar " + - $baz + ''', + ''' + SELECT """" + $1 + "$foo""" + -- $bar " + - $2 + ''', + ('foo', 'baz')), + + (''' + SELECT """" + $1 + "$foo""" + -- $bar " + - $2 + ''', + ''' + SELECT """" + $1 + "$foo""" + -- $bar " + - $2 + ''', + None), + + (r''' + SELECT E'\'' + $foo + E'$foo\'' + -- $bar " + - $baz + ''', + r''' + SELECT E'\'' + $1 + E'$foo\'' + -- $bar " + - $2 + ''', + ('foo', 'baz')), + + (r''' + SELECT e'\'' + $foo + e'$foo\'' + -- $bar " + - $baz + ''', + r''' + SELECT e'\'' + $1 + e'$foo\'' + -- $bar " + - $2 + ''', + ('foo', 'baz')), + + (r''' + SELECT '\' || $foo + '$foo\' + $fiz -- $bar " + - $baz + ''', + r''' + SELECT '\' || $1 + '$foo\' + $2 -- $bar " + - $3 + ''', + ('foo', 'fiz', 'baz')), + + (r''' + SELECT CASE'\' WHEN $a THEN $b ELSE $a END; + ''', + r''' + SELECT CASE'\' WHEN $1 THEN $2 ELSE $1 END; + ''', + ('a', 'b')), + + (r''' + SELECT case'\' WHEN $a THEN $b ELSE $a END; + ''', + r''' + SELECT case'\' WHEN $1 THEN $2 ELSE $1 END; + ''', + ('a', 'b')), + + (""" + SELECT 'Baz''' + $foo + /* x * + - $y + $z **// $a + """, + """ + SELECT 'Baz''' + $1 + /* x * + - $y + $z **// $2 + """, + ('foo', 'a')), + + (""" + SELECT 'Baz' 'fiz' + $foo + /* x * + - $y + $z **// $a + """, + """ + SELECT 'Baz' 'fiz' + $1 + /* x * + - $y + $z **// $2 + """, + ('foo', 'a')), + + (""" + SELECT $$'Baz''' + $foo + /* x * $$ + - $y + $z **// $a + """, + """ + SELECT $$'Baz''' + $foo + /* x * $$ + - $1 + $2 **// $3 + """, + ('y', 'z', 'a')), + + (""" + SELECT $abc_a$'Baz''' + $foo + /* x * $abc_a$ + - $y + $z **// $a + """, + """ + SELECT $abc_a$'Baz''' + $foo + /* x * $abc_a$ + - $1 + $2 **// $3 + """, + ('y', 'z', 'a')), + ] + + # We should ignore any errors in queries. + invalid_queries = [ + ('', + '', + None), + + ('расколбас $f', + 'расколбас $1', + ('f',)), + + (' ', + ' ', + None), + + (' ', + ' ', + None), + + ('$', + '$', + None), + + ('$$', + '$$', + None), + + ('$$$', + '$$$', + None), + + ('$$$$', + '$$$$', + None), + + ('$$$$$', + '$$$$$', + None), + + ('"', + '"', + None), + + ('""', + '""', + None), + + ('"""', + '"""', + None), + + ('""""', + '""""', + None), + + ('"""""', + '"""""', + None), + + ('e', + 'e', + None), + + ('e"', + 'e"', + None), + + ("e'", + "e'", + None), + + ('select $', + 'select $', + None), + + ('select $ ', + 'select $ ', + None), + + ('select $ as', + 'select $ as', + None), + + ('select $ $as', + 'select $ $1', + ('as',)), + + ('select "as', + 'select "as', + None), + + ("select 'as", + "select 'as", + None), + + ("select 'as\\", + "select 'as\\", + None), + + ("select $foo + 'as\'a ", + "select $1 + 'as\'a ", + ('foo',)), + + ("select $foo + E'as\'a ", + "select $1 + E'as\'a ", + ('foo',)), + + ("select $$as", + "select $$as", + None), + + ("select $foo$as", + "select $foo$as", + None), + ] + + # Can't combine named and positional-only arguments. + invalid_params_combos = [ + 'select $foo + $1', + 'select $1 + $foo' + ] + + # Invalid arguments names + invalid_names = [ + 'select $1foo', + ] + + def test_params_lex_valid_queries(self): + for query, expected_query, expected_params in self.queries: + with self.subTest(query=query): + new_query, new_params = transform_kwargs(query) + self.assertEqual(new_query, expected_query) + self.assertEqual(new_params, expected_params) + + def test_params_lex_invalid_queries(self): + for query, expected_query, expected_params in self.invalid_queries: + with self.subTest(query=query): + new_query, new_params = transform_kwargs(query) + self.assertEqual(new_query, expected_query) + self.assertEqual(new_params, expected_params) + + def test_params_lex_param_combination(self): + for query in self.invalid_params_combos: + with self.subTest(query=query): + with self.assertRaisesRegex(ValueError, 'queries with both'): + transform_kwargs(query) + + def test_params_lex_invalid_names(self): + for query in self.invalid_names: + with self.subTest(query=query): + with self.assertRaisesRegex(ValueError, + 'invalid argument name'): + transform_kwargs(query) + + def test_params_lex_worstcase_args(self): + + def count_base61(): + ar = ('abcdefghijklmnopqrstuvwxyz' + 'ABCDEFGHIJKLMNOPQRSTUVWXYZ' + '_') + arl = len(ar) + + far = ar + '0123456789' + farl = len(far) + + i = 0 + while True: + num = i + buf = '' + + pos = num % arl + num = num // arl + buf += ar[pos] + + while num: + pos = num % farl + num = num // farl + buf += far[pos] + + yield buf + i += 1 + + # Build a query that has 33000 unique and shortest possible + # named arguments. + query = 'SELECT ' + expected_params = [] + expected_query = query + for i, name in enumerate(count_base61()): + if i == 33000: # Max number of arguments Postgres can accept. + break + query += f'${name}+' + expected_query += f'${i + 1}+' + expected_params.append(name) + + query = query[:-1] + expected_query = expected_query[:-1] + + new_query, params = transform_kwargs(query) + self.assertEqual(new_query, expected_query) + self.assertEqual(params, tuple(expected_params)) + + # We use '1.4' as a coefficient when we pre-allocate a buffer + # for the new query in `query_pp.pyx`. + self.assertLess(len(new_query) / len(query), 1.4) diff --git a/tests/test_prepare.py b/tests/test_prepare.py index 4425859f..e1afb1a1 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -241,7 +241,7 @@ async def test_prepare_12_stmt_gc(self): self.assertEqual(len(cache), cache_max) self.assertEqual(len(self.con._stmts_to_close), 1) - async def test_prepare_13_connect(self): + async def test_prepare_13_fetch_methods(self): v = await self.con.fetchval( 'SELECT $1::smallint AS foo', 10, column='foo') self.assertEqual(v, 10) @@ -569,3 +569,13 @@ async def test_prepare_30_invalid_arg_count(self): exceptions.InterfaceError, 'the server expects 0 arguments for this query, 1 was passed'): await self.con.fetchval('SELECT 1', 1) + + @tb.with_connection_options(query_pp=asyncpg.keyword_parameters) + async def test_prepare_31_fetch_methods_kwargs(self): + r = await self.con.fetchrow( + 'SELECT $r::smallint * 2 AS test', kwargs=dict(r=10)) + self.assertEqual(r['test'], 20) + + rows = await self.con.fetch( + 'SELECT generate_series(0,$r::int)', kwargs=dict(r=3)) + self.assertEqual([r[0] for r in rows], [0, 1, 2, 3]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 30cecc32..5dc75f63 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -26,12 +26,12 @@ async def test_mogrify_simple(self): for typename, data, expected in cases: with self.subTest(value=data, type=typename): mogrified = await utils._mogrify( - self.con, 'SELECT $1::{}'.format(typename), [data]) + self.con, 'SELECT $1::{}'.format(typename), [data], None) self.assertEqual(mogrified, expected) async def test_mogrify_multiple(self): mogrified = await utils._mogrify( self.con, 'SELECT $1::int, $2::int[]', - [1, [2, 3, 4, 5]]) + [1, [2, 3, 4, 5]], None) expected = "SELECT '1'::int, '{2,3,4,5}'::int[]" self.assertEqual(mogrified, expected)