Skip to content

*Experimental* support for kwargs. #196

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions asyncpg/.gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
*.html
/*.c
1 change: 1 addition & 0 deletions asyncpg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .pool import create_pool # NOQA
from .protocol import Record # NOQA
from .types import * # NOQA
from .query_pp import keyword_parameters # NOQA


__all__ = ('connect', 'create_pool', 'Record', 'Connection') + \
Expand Down
6 changes: 4 additions & 2 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
'statement_cache_size',
'max_cached_statement_lifetime',
'max_cacheable_statement_size',
'query_pp'
])


Expand Down Expand Up @@ -210,7 +211,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database,
timeout, command_timeout, statement_cache_size,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, server_settings):
ssl, server_settings, query_pp):

local_vars = locals()
for var_name in {'max_cacheable_statement_size',
Expand Down Expand Up @@ -245,7 +246,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, database,
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,)
max_cacheable_statement_size=max_cacheable_statement_size,
query_pp=query_pp)

return addrs, params, config

Expand Down
46 changes: 27 additions & 19 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import cursor
from . import exceptions
from . import introspection
from . import query_pp
from . import prepared_stmt
from . import protocol
from . import serverversion
Expand All @@ -44,7 +45,7 @@ class Connection(metaclass=ConnectionMeta):
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners')
'_log_listeners', '_query_pp')

def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
Expand All @@ -66,6 +67,7 @@ def __init__(self, protocol, transport, loop,
self._addr = addr
self._config = config
self._params = params
self._query_pp = config.query_pp

self._stmt_cache = _StatementCache(
loop=loop,
Expand Down Expand Up @@ -206,7 +208,8 @@ def transaction(self, *, isolation='read_committed', readonly=False,
self._check_open()
return transaction.Transaction(self, isolation, readonly, deferrable)

async def execute(self, query: str, *args, timeout: float=None) -> str:
async def execute(self, query: str, *args, kwargs=None,
timeout: float=None) -> str:
"""Execute an SQL command (or commands).

This method can execute many SQL commands at once, when no arguments
Expand Down Expand Up @@ -239,7 +242,8 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
if not args:
return await self._protocol.query(query, timeout)

_, status, _ = await self._execute(query, args, 0, timeout, True)
_, status, _ = await self._execute(
query, args, kwargs, 0, timeout, True)
return status.decode()

async def executemany(self, command: str, args, *, timeout: float=None):
Expand Down Expand Up @@ -305,7 +309,7 @@ async def _get_statement(self, query, timeout, *, named: bool=False):

return statement

def cursor(self, query, *args, prefetch=None, timeout=None):
def cursor(self, query, *args, kwargs=None, prefetch=None, timeout=None):
"""Return a *cursor factory* for the specified query.

:param args: Query arguments.
Expand All @@ -316,7 +320,7 @@ def cursor(self, query, *args, prefetch=None, timeout=None):
:return: A :class:`~cursor.CursorFactory` object.
"""
self._check_open()
return cursor.CursorFactory(self, query, None, args,
return cursor.CursorFactory(self, query, None, args, kwargs,
prefetch, timeout)

async def prepare(self, query, *, timeout=None):
Expand All @@ -329,9 +333,9 @@ async def prepare(self, query, *, timeout=None):
"""
self._check_open()
stmt = await self._get_statement(query, timeout, named=True)
return prepared_stmt.PreparedStatement(self, query, stmt)
return prepared_stmt.PreparedStatement(self, stmt)

async def fetch(self, query, *args, timeout=None) -> list:
async def fetch(self, query, *args, kwargs=None, timeout=None) -> list:
"""Run a query and return the results as a list of :class:`Record`.

:param str query: Query text.
Expand All @@ -341,9 +345,10 @@ async def fetch(self, query, *args, timeout=None) -> list:
:return list: A list of :class:`Record` instances.
"""
self._check_open()
return await self._execute(query, args, 0, timeout)
return await self._execute(query, args, kwargs, 0, timeout)

async def fetchval(self, query, *args, column=0, timeout=None):
async def fetchval(self, query, *args, kwargs=None,
column=0, timeout=None):
"""Run a query and return a value in the first row.

:param str query: Query text.
Expand All @@ -359,12 +364,12 @@ async def fetchval(self, query, *args, column=0, timeout=None):
None if no records were returned by the query.
"""
self._check_open()
data = await self._execute(query, args, 1, timeout)
data = await self._execute(query, args, kwargs, 1, timeout)
if not data:
return None
return data[0][column]

async def fetchrow(self, query, *args, timeout=None):
async def fetchrow(self, query, *args, kwargs=None, timeout=None):
"""Run a query and return the first row.

:param str query: Query text
Expand All @@ -375,7 +380,7 @@ async def fetchrow(self, query, *args, timeout=None):
no records were returned by the query.
"""
self._check_open()
data = await self._execute(query, args, 1, timeout)
data = await self._execute(query, args, kwargs, 1, timeout)
if not data:
return None
return data[0]
Expand Down Expand Up @@ -451,7 +456,7 @@ async def copy_from_table(self, table_name, *, output,

return await self._copy_out(copy_stmt, output, timeout)

async def copy_from_query(self, query, *args, output,
async def copy_from_query(self, query, *args, kwargs=None, output,
timeout=None, format=None, oids=None,
delimiter=None, null=None, header=None,
quote=None, escape=None, force_quote=None,
Expand Down Expand Up @@ -504,8 +509,8 @@ async def copy_from_query(self, query, *args, output,
force_quote=force_quote, encoding=encoding
)

if args:
query = await utils._mogrify(self, query, args)
if args or kwargs:
query = await utils._mogrify(self, query, args, kwargs)

copy_stmt = 'COPY ({query}) TO STDOUT {opts}'.format(
query=query, opts=opts)
Expand Down Expand Up @@ -1208,9 +1213,10 @@ def _drop_global_statement_cache(self):
else:
self._drop_local_statement_cache()

async def _execute(self, query, args, limit, timeout, return_status=False):
async def _execute(self, query, args, kwargs,
limit, timeout, return_status=False):
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
stmt, args, kwargs, '', limit, return_status, timeout)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
return await self._do_execute(query, executor, timeout)
Expand Down Expand Up @@ -1287,7 +1293,8 @@ async def connect(dsn=None, *,
command_timeout=None,
ssl=None,
connection_class=Connection,
server_settings=None):
server_settings=None,
query_pp=None):
r"""A coroutine to establish a connection to a PostgreSQL server.

Returns a new :class:`~asyncpg.connection.Connection` object.
Expand Down Expand Up @@ -1403,7 +1410,8 @@ class of the returned connection object. Must be a subclass of
command_timeout=command_timeout,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size)
max_cacheable_statement_size=max_cacheable_statement_size,
query_pp=query_pp)


class _StatementCacheEntry:
Expand Down
27 changes: 17 additions & 10 deletions asyncpg/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@ class CursorFactory(connresource.ConnectionResource):
results of a large query.
"""

__slots__ = ('_state', '_args', '_prefetch', '_query', '_timeout')
__slots__ = ('_state', '_args', '_kwargs', '_prefetch',
'_query', '_timeout')

def __init__(self, connection, query, state, args, prefetch, timeout):
def __init__(self, connection, query, state, args, kwargs,
prefetch, timeout):
super().__init__(connection)
self._args = args
self._kwargs = kwargs
self._prefetch = prefetch
self._query = query
self._timeout = timeout
Expand All @@ -37,7 +40,7 @@ def __aiter__(self):
prefetch = 50 if self._prefetch is None else self._prefetch
return CursorIterator(self._connection,
self._query, self._state,
self._args, prefetch,
self._args, self._kwargs, prefetch,
self._timeout)

@connresource.guarded
Expand All @@ -46,7 +49,7 @@ def __await__(self):
raise exceptions.InterfaceError(
'prefetch argument can only be specified for iterable cursor')
cursor = Cursor(self._connection, self._query,
self._state, self._args)
self._state, self._args, self._kwargs)
return cursor._init(self._timeout).__await__()

def __del__(self):
Expand All @@ -57,11 +60,13 @@ def __del__(self):

class BaseCursor(connresource.ConnectionResource):

__slots__ = ('_state', '_args', '_portal_name', '_exhausted', '_query')
__slots__ = ('_state', '_args', '_kwargs', '_portal_name',
'_exhausted', '_query')

def __init__(self, connection, query, state, args):
def __init__(self, connection, query, state, args, kwargs):
super().__init__(connection)
self._args = args
self._kwargs = kwargs
self._state = state
if state is not None:
state.attach()
Expand Down Expand Up @@ -94,7 +99,8 @@ async def _bind_exec(self, n, timeout):

self._portal_name = con._get_unique_id('portal')
buffer, _, self._exhausted = await protocol.bind_execute(
self._state, self._args, self._portal_name, n, True, timeout)
self._state, self._args, self._kwargs, self._portal_name,
n, True, timeout)
return buffer

async def _bind(self, timeout):
Expand All @@ -108,7 +114,7 @@ async def _bind(self, timeout):
protocol = con._protocol

self._portal_name = con._get_unique_id('portal')
buffer = await protocol.bind(self._state, self._args,
buffer = await protocol.bind(self._state, self._args, self._kwargs,
self._portal_name,
timeout)
return buffer
Expand Down Expand Up @@ -151,8 +157,9 @@ class CursorIterator(BaseCursor):

__slots__ = ('_buffer', '_prefetch', '_timeout')

def __init__(self, connection, query, state, args, prefetch, timeout):
super().__init__(connection, query, state, args)
def __init__(self, connection, query, state, args, kwargs,
prefetch, timeout):
super().__init__(connection, query, state, args, kwargs)

if prefetch <= 0:
raise exceptions.InterfaceError(
Expand Down
Loading