Skip to content

Fix issues with inet type I/O #203

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 10, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions asyncpg/protocol/codecs/network.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,37 @@ _ipaddr = ipaddress.ip_address
_ipnet = ipaddress.ip_network


cdef inline _net_encode(WriteBuffer buf, int32_t version, uint8_t bits,
cdef inline uint8_t _ip_max_prefix_len(int32_t family):
# Maximum number of bits in the network prefix of the specified
# IP protocol version.
if family == PGSQL_AF_INET:
return 32
else:
return 128


cdef inline int32_t _ip_addr_len(int32_t family):
# Length of address in bytes for the specified IP protocol version.
if family == PGSQL_AF_INET:
return 4
else:
return 16


cdef inline int8_t _ver_to_family(int32_t version):
if version == 4:
return PGSQL_AF_INET
else:
return PGSQL_AF_INET6


cdef inline _net_encode(WriteBuffer buf, int8_t family, uint32_t bits,
int8_t is_cidr, bytes addr):

cdef:
char *addrbytes
ssize_t addrlen
int8_t family

family = PGSQL_AF_INET if version == 4 else PGSQL_AF_INET6
cpython.PyBytes_AsStringAndSize(addr, &addrbytes, &addrlen)

buf.write_int32(4 + <int32_t>addrlen)
Expand All @@ -41,28 +63,31 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf):
cdef:
int32_t family = <int32_t>buf.read(1)[0]
uint8_t bits = <uint8_t>buf.read(1)[0]
uint32_t is_cidr = <uint32_t>buf.read(1)[0]
uint32_t addrlen = <uint32_t>buf.read(1)[0]
int32_t is_cidr = <int32_t>buf.read(1)[0]
int32_t addrlen = <int32_t>buf.read(1)[0]
bytes addr
uint8_t max_prefix_len = _ip_max_prefix_len(family)

if family != PGSQL_AF_INET and family != PGSQL_AF_INET6:
raise ValueError('invalid address family in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))

if bits > (32 if family == PGSQL_AF_INET else 128):
raise ValueError('invalid bits in "{}" value'.format(
max_prefix_len = _ip_max_prefix_len(family)

if bits > max_prefix_len:
raise ValueError('invalid network prefix length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))

if addrlen != (4 if family == PGSQL_AF_INET else 16):
raise ValueError('invalid length in "{}" value'.format(
if addrlen != _ip_addr_len(family):
raise ValueError('invalid address length in "{}" value'.format(
'cidr' if is_cidr else 'inet'
))

addr = cpython.PyBytes_FromStringAndSize(buf.read(addrlen), addrlen)

if is_cidr or bits > 0:
if is_cidr or bits != max_prefix_len:
return _ipnet(addr).supernet(new_prefix=cpython.PyLong_FromLong(bits))
else:
return _ipaddr(addr)
Expand All @@ -71,15 +96,17 @@ cdef net_decode(ConnectionSettings settings, FastReadBuffer buf):
cdef cidr_encode(ConnectionSettings settings, WriteBuffer buf, obj):
cdef:
object ipnet
int8_t family

ipnet = _ipnet(obj)
_net_encode(buf, ipnet.version, ipnet.prefixlen, 1,
ipnet.network_address.packed)
family = _ver_to_family(ipnet.version)
_net_encode(buf, family, ipnet.prefixlen, 1, ipnet.network_address.packed)


cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj):
cdef:
object ipaddr
int8_t family

try:
ipaddr = _ipaddr(obj)
Expand All @@ -88,7 +115,8 @@ cdef inet_encode(ConnectionSettings settings, WriteBuffer buf, obj):
# for the host datatype.
cidr_encode(settings, buf, obj)
else:
_net_encode(buf, ipaddr.version, 0, 0, ipaddr.packed)
family = _ver_to_family(ipaddr.version)
_net_encode(buf, family, _ip_max_prefix_len(family), 0, ipaddr.packed)


cdef init_network_codecs():
Expand Down
46 changes: 42 additions & 4 deletions tests/test_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,32 @@ def _timezone(offset):
output=ipaddress.IPv4Network('127.0.0.0/8')),
dict(
input='127.0.0.1/32',
output=ipaddress.IPv4Network('127.0.0.1/32')),
output=ipaddress.IPv4Address('127.0.0.1')),
# Postgres appends /32 when casting to text explicitly, but
# *not* in inet_out.
dict(
input='10.11.12.13',
textoutput='10.11.12.13/32'
),
dict(
input=ipaddress.IPv4Address('10.11.12.13'),
textoutput='10.11.12.13/32'
),
dict(
input=ipaddress.IPv4Network('10.11.12.13'),
textoutput='10.11.12.13/32'
),
dict(
textinput='10.11.12.13',
output=ipaddress.IPv4Address('10.11.12.13'),
),
dict(
# Non-zero address bits after the network prefix are permitted
# by postgres, but are invalid in Python
# (and zeroed out by supernet()).
textinput='10.11.12.13/0',
output=ipaddress.IPv4Network('0.0.0.0/0'),
),
]),
('macaddr', 'macaddr', [
'00:00:00:00:00:00',
Expand Down Expand Up @@ -369,20 +394,33 @@ async def test_standard_codecs(self):
"SELECT $1::" + typname
)

textst = await self.con.prepare(
text_in = await self.con.prepare(
"SELECT $1::text::" + typname
)

text_out = await self.con.prepare(
"SELECT $1::" + typname + "::text"
)

for sample in sample_data:
with self.subTest(sample=sample, typname=typname):
stmt = st
if isinstance(sample, dict):
if 'textinput' in sample:
inputval = sample['textinput']
stmt = textst
stmt = text_in
else:
inputval = sample['input']
outputval = sample['output']

if 'textoutput' in sample:
outputval = sample['textoutput']
if stmt is text_in:
raise ValueError(
'cannot test "textin" and'
' "textout" simultaneously')
stmt = text_out
else:
outputval = sample['output']
else:
inputval = outputval = sample

Expand Down