Skip to content

Quicker unpacking #102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 2, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 28 additions & 25 deletions neo4j/v1/bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,7 @@ def fill(self):
ready_to_read, _, _ = select((self.socket,), (), (), 0)
received = self.socket.recv(65539)
if received:
if __debug__:
log_debug("S: b%r", received)
log_debug("S: b%r", received)
self.buffer[len(self.buffer):] = received
else:
if ready_to_read is not None:
Expand Down Expand Up @@ -174,8 +173,7 @@ def send(self):
""" Send all queued messages to the server.
"""
data = self.raw.getvalue()
if __debug__:
log_debug("C: b%r", data)
log_debug("C: b%r", data)
self.socket.sendall(data)

self.raw.seek(self.raw.truncate(0))
Expand Down Expand Up @@ -264,8 +262,7 @@ def append(self, signature, fields=(), response=None):
:arg fields: the fields of the message as a tuple
:arg response: a response object to handle callbacks
"""
if __debug__:
log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields)))
log_info("C: %s %r", message_names[signature], fields)

self.packer.pack_struct_header(len(fields), signature)
for field in fields:
Expand Down Expand Up @@ -329,29 +326,36 @@ def fetch(self):
self.defunct = True
self.close()
raise
# Unpack from the raw byte stream and call the relevant message handler(s)
self.unpacker.load(message_data)
size, signature = self.unpacker.unpack_structure_header()
fields = [self.unpacker.unpack() for _ in range(size)]

if __debug__:
log_info("S: %s %r", message_names[signature], fields)
unpacker = self.unpacker
unpacker.load(message_data)
size, signature = unpacker.unpack_structure_header()
if size > 1:
raise ProtocolError("Expected one field")

if signature == SUCCESS:
metadata = unpacker.unpack_map()
log_info("S: SUCCESS (%r)", metadata)
response = self.responses.popleft()
response.complete = True
response.on_success(*fields)
response.on_success(metadata or {})
elif signature == RECORD:
data = unpacker.unpack_list()
log_info("S: RECORD (%r)", data)
response = self.responses[0]
response.on_record(*fields)
response.on_record(data or [])
elif signature == IGNORED:
metadata = unpacker.unpack_map()
log_info("S: IGNORED (%r)", metadata)
response = self.responses.popleft()
response.complete = True
response.on_ignored(*fields)
response.on_ignored(metadata or {})
elif signature == FAILURE:
metadata = unpacker.unpack_map()
log_info("S: FAILURE (%r)", metadata)
response = self.responses.popleft()
response.complete = True
response.on_failure(*fields)
response.on_failure(metadata or {})
else:
raise ProtocolError("Unexpected response message with signature %02X" % signature)

Expand All @@ -365,8 +369,7 @@ def close(self):
""" Close the connection.
"""
if not self.closed:
if __debug__:
log_info("~~ [CLOSE]")
log_info("~~ [CLOSE]")
self.channel.socket.close()
self.closed = True

Expand Down Expand Up @@ -476,7 +479,7 @@ def connect(address, ssl_context=None, **config):
# Establish a connection to the host and port specified
# Catches refused connections see:
# https://docs.python.org/2/library/errno.html
if __debug__: log_info("~~ [CONNECT] %s", address)
log_info("~~ [CONNECT] %s", address)
try:
s = create_connection(address)
except SocketError as error:
Expand All @@ -488,7 +491,7 @@ def connect(address, ssl_context=None, **config):
# Secure the connection if an SSL context has been provided
if ssl_context and SSL_AVAILABLE:
host, port = address
if __debug__: log_info("~~ [SECURE] %s", host)
log_info("~~ [SECURE] %s", host)
try:
s = ssl_context.wrap_socket(s, server_hostname=host if HAS_SNI else None)
except SSLError as cause:
Expand All @@ -514,9 +517,9 @@ def connect(address, ssl_context=None, **config):
# Send details of the protocol versions supported
supported_versions = [1, 0, 0, 0]
handshake = [MAGIC_PREAMBLE] + supported_versions
if __debug__: log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions)
log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions)
data = b"".join(struct_pack(">I", num) for num in handshake)
if __debug__: log_debug("C: b%r", data)
log_debug("C: b%r", data)
s.sendall(data)

# Handle the handshake response
Expand All @@ -531,15 +534,15 @@ def connect(address, ssl_context=None, **config):
log_error("S: [CLOSE]")
raise ProtocolError("Connection to %r closed without handshake response" % (address,))
if data_size == 4:
if __debug__: log_debug("S: b%r", data)
log_debug("S: b%r", data)
else:
# Some other garbled data has been received
log_error("S: @*#!")
raise ProtocolError("Expected four byte handshake response, received %r instead" % data)
agreed_version, = struct_unpack(">I", data)
if __debug__: log_info("S: [HANDSHAKE] %d", agreed_version)
log_info("S: [HANDSHAKE] %d", agreed_version)
if agreed_version == 0:
if __debug__: log_info("~~ [CLOSE]")
log_info("~~ [CLOSE]")
s.shutdown(SHUT_RDWR)
s.close()
elif agreed_version == 1:
Expand Down
141 changes: 85 additions & 56 deletions neo4j/v1/packstream.py
Original file line number Diff line number Diff line change
Expand Up @@ -708,64 +708,12 @@ def unpack(self):
return self.read(byte_size).decode(ENCODING)

# List
elif marker_high == 0x90:
size = marker & 0x0F
return [unpack() for _ in range(size)]
elif marker == 0xD4: # LIST_8:
size = UNPACKED_UINT_8[self.read_bytes(1)]
return [unpack() for _ in range(size)]
elif marker == 0xD5: # LIST_16:
size = UNPACKED_UINT_16[self.read_bytes(2)]
return [unpack() for _ in range(size)]
elif marker == 0xD6: # LIST_32:
size = struct_unpack(UINT_32_STRUCT, self.read(4))[0]
return [unpack() for _ in range(size)]
elif marker == 0xD7: # LIST_STREAM:
value = []
item = None
while item is not EndOfStream:
item = unpack()
if item is not EndOfStream:
value.append(item)
return value
elif 0x90 <= marker <= 0x9F or 0xD4 <= marker <= 0xD7:
return self._unpack_list(marker)

# Map
elif marker_high == 0xA0:
size = marker & 0x0F
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xD8: # MAP_8:
size = UNPACKED_UINT_8[self.read_bytes(1)]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xD9: # MAP_16:
size = UNPACKED_UINT_16[self.read_bytes(2)]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xDA: # MAP_32:
size = struct_unpack(UINT_32_STRUCT, self.read(4))[0]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xDB: # MAP_STREAM:
value = {}
key = None
while key is not EndOfStream:
key = unpack()
if key is not EndOfStream:
value[key] = unpack()
return value
elif 0xA0 <= marker <= 0xAF or 0xD8 <= marker <= 0xDB:
return self._unpack_map(marker)

# Structure
elif marker_high == 0xB0:
Expand Down Expand Up @@ -793,6 +741,87 @@ def unpack(self):
else:
raise RuntimeError("Unknown PackStream marker %02X" % marker)

def unpack_list(self):
marker = self.read_marker()
return self._unpack_list(marker)

def _unpack_list(self, marker):
marker_high = marker & 0xF0
unpack = self.unpack
if marker_high == 0x90:
size = marker & 0x0F
if size == 0:
return []
elif size == 1:
return [unpack()]
else:
return [unpack() for _ in range(size)]
elif marker == 0xD4: # LIST_8:
size = UNPACKED_UINT_8[self.read_bytes(1)]
return [unpack() for _ in range(size)]
elif marker == 0xD5: # LIST_16:
size = UNPACKED_UINT_16[self.read_bytes(2)]
return [unpack() for _ in range(size)]
elif marker == 0xD6: # LIST_32:
size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0]
return [unpack() for _ in range(size)]
elif marker == 0xD7: # LIST_STREAM:
value = []
item = None
while item is not EndOfStream:
item = unpack()
if item is not EndOfStream:
value.append(item)
return value
else:
return None

def unpack_map(self):
marker = self.read_marker()
return self._unpack_map(marker)

def _unpack_map(self, marker):
marker_high = marker & 0xF0
unpack = self.unpack
if marker_high == 0xA0:
size = marker & 0x0F
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xD8: # MAP_8:
size = UNPACKED_UINT_8[self.read_bytes(1)]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xD9: # MAP_16:
size = UNPACKED_UINT_16[self.read_bytes(2)]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xDA: # MAP_32:
size = struct_unpack(UINT_32_STRUCT, self.read_bytes(4))[0]
value = {}
for _ in range(size):
key = unpack()
value[key] = unpack()
return value
elif marker == 0xDB: # MAP_STREAM:
value = {}
key = None
while key is not EndOfStream:
key = unpack()
if key is not EndOfStream:
value[key] = unpack()
return value
else:
return None

def unpack_structure_header(self):
marker = self.read_marker()
if marker == -1:
Expand Down