Skip to content

Commit 5f76257

Browse files
committed
Code clean-up
1 parent 315077f commit 5f76257

File tree

2 files changed

+63
-61
lines changed

2 files changed

+63
-61
lines changed

neo4j/io/_bolt3.py

Lines changed: 49 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,28 +61,43 @@ class ServerStates(Enum):
6161
FAILED = "FAILED"
6262

6363

64-
STATE_TRANSITIONS = {
65-
ServerStates.CONNECTED: {
66-
"hello": ServerStates.READY,
67-
},
68-
ServerStates.READY: {
69-
"run": ServerStates.STREAMING,
70-
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
71-
},
72-
ServerStates.STREAMING: {
73-
"pull": ServerStates.READY,
74-
"discard": ServerStates.READY,
75-
"reset": ServerStates.READY,
76-
},
77-
ServerStates.TX_READY_OR_TX_STREAMING: {
78-
"commit": ServerStates.READY,
79-
"rollback": ServerStates.READY,
80-
"reset": ServerStates.READY,
81-
},
82-
ServerStates.FAILED: {
83-
"reset": ServerStates.READY,
64+
class ServerStateManager:
65+
_STATE_TRANSITIONS = {
66+
ServerStates.CONNECTED: {
67+
"hello": ServerStates.READY,
68+
},
69+
ServerStates.READY: {
70+
"run": ServerStates.STREAMING,
71+
"begin": ServerStates.TX_READY_OR_TX_STREAMING,
72+
},
73+
ServerStates.STREAMING: {
74+
"pull": ServerStates.READY,
75+
"discard": ServerStates.READY,
76+
"reset": ServerStates.READY,
77+
},
78+
ServerStates.TX_READY_OR_TX_STREAMING: {
79+
"commit": ServerStates.READY,
80+
"rollback": ServerStates.READY,
81+
"reset": ServerStates.READY,
82+
},
83+
ServerStates.FAILED: {
84+
"reset": ServerStates.READY,
85+
}
8486
}
85-
}
87+
88+
def __init__(self, init_state, on_change=None):
89+
self.state = init_state
90+
self._on_change = on_change
91+
92+
def transition(self, metadata, message):
93+
if metadata.get("has_more"):
94+
return
95+
state_before = self.state
96+
self.state = self._STATE_TRANSITIONS\
97+
.get(self.state, {})\
98+
.get(message, self.state)
99+
if state_before != self.state and callable(self._on_change):
100+
self._on_change(state_before, self.state)
86101

87102

88103
class Bolt3(Bolt):
@@ -97,15 +112,23 @@ class Bolt3(Bolt):
97112

98113
supports_multiple_databases = False
99114

100-
_server_state = ServerStates.CONNECTED
115+
def __init__(self, *args, **kwargs):
116+
super().__init__(*args, **kwargs)
117+
self._server_state_manager = ServerStateManager(
118+
ServerStates.CONNECTED, on_change=self._on_server_state_change
119+
)
120+
121+
def _on_server_state_change(self, old_state, new_state):
122+
log.debug("[#%04X] State: %s > %s", self.local_port,
123+
old_state.name, new_state.name)
101124

102125
@property
103126
def is_reset(self):
104127
if self.responses:
105128
# we can't be sure of the server's state as there are still pending
106129
# responses.
107130
return False
108-
return self._server_state == ServerStates.READY
131+
return self._server_state_manager.state == ServerStates.READY
109132

110133
@property
111134
def encrypted(self):
@@ -213,7 +236,6 @@ def pull(self, n=-1, qid=-1, **handlers):
213236
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
214237
log.debug("[#%04X] C: PULL_ALL", self.local_port)
215238
self._append(b"\x3F", (), Response(self, "pull", **handlers))
216-
self._is_reset = False
217239

218240
def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None, **handlers):
219241
if db is not None:
@@ -238,7 +260,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
238260
raise TypeError("Timeout must be specified as a number of seconds")
239261
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
240262
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
241-
self._is_reset = False
242263

243264
def commit(self, **handlers):
244265
log.debug("[#%04X] C: COMMIT", self.local_port)
@@ -260,18 +281,6 @@ def fail(metadata):
260281
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
261282
self.send_all()
262283
self.fetch_all()
263-
self._is_reset = True
264-
265-
def _update_server_state_on_success(self, metadata, message):
266-
if metadata.get("has_more"):
267-
return
268-
state_before = self._server_state
269-
self._server_state = STATE_TRANSITIONS\
270-
.get(self._server_state, {})\
271-
.get(message, self._server_state)
272-
if state_before != self._server_state:
273-
log.debug("[#%04X] State: %s", self.local_port,
274-
self._server_state.name)
275284

276285
def fetch_message(self):
277286
""" Receive at most one message from the server, if available.
@@ -304,15 +313,15 @@ def fetch_message(self):
304313
response.complete = True
305314
if summary_signature == b"\x70":
306315
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
307-
self._update_server_state_on_success(summary_metadata,
308-
response.message)
316+
self._server_state_manager.transition(response.message,
317+
summary_metadata)
309318
response.on_success(summary_metadata or {})
310319
elif summary_signature == b"\x7E":
311320
log.debug("[#%04X] S: IGNORED", self.local_port)
312321
response.on_ignored(summary_metadata or {})
313322
elif summary_signature == b"\x7F":
314323
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
315-
self._server_state = ServerStates.FAILED
324+
self._server_state_manager.state = ServerStates.FAILED
316325
try:
317326
response.on_failure(summary_metadata or {})
318327
except (ServiceUnavailable, DatabaseUnavailable):

neo4j/io/_bolt4.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@
4949
Response,
5050
)
5151
from neo4j.io._bolt3 import (
52+
ServerStateManager,
5253
ServerStates,
53-
STATE_TRANSITIONS,
5454
)
5555

5656

@@ -69,15 +69,23 @@ class Bolt4x0(Bolt):
6969

7070
supports_multiple_databases = True
7171

72-
_server_state = ServerStates.CONNECTED
72+
def __init__(self, *args, **kwargs):
73+
super().__init__(*args, **kwargs)
74+
self._server_state_manager = ServerStateManager(
75+
ServerStates.CONNECTED, on_change=self._on_server_state_change
76+
)
77+
78+
def _on_server_state_change(self, old_state, new_state):
79+
log.debug("[#%04X] State: %s > %s", self.local_port,
80+
old_state.name, new_state.name)
7381

7482
@property
7583
def is_reset(self):
7684
if self.responses:
7785
# we can't be sure of the server's state as there are still pending
7886
# responses.
7987
return False
80-
return self._server_state == ServerStates.READY
88+
return self._server_state_manager.state == ServerStates.READY
8189

8290
@property
8391
def encrypted(self):
@@ -181,7 +189,6 @@ def run(self, query, parameters=None, mode=None, bookmarks=None, metadata=None,
181189
**handlers))
182190
else:
183191
self._append(b"\x10", fields, Response(self, "run", **handlers))
184-
self._is_reset = False
185192

186193
def discard(self, n=-1, qid=-1, **handlers):
187194
extra = {"n": n}
@@ -196,7 +203,6 @@ def pull(self, n=-1, qid=-1, **handlers):
196203
extra["qid"] = qid
197204
log.debug("[#%04X] C: PULL %r", self.local_port, extra)
198205
self._append(b"\x3F", (extra,), Response(self, "pull", **handlers))
199-
self._is_reset = False
200206

201207
def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
202208
db=None, **handlers):
@@ -222,7 +228,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None,
222228
raise TypeError("Timeout must be specified as a number of seconds")
223229
log.debug("[#%04X] C: BEGIN %r", self.local_port, extra)
224230
self._append(b"\x11", (extra,), Response(self, "begin", **handlers))
225-
self._is_reset = False
226231

227232
def commit(self, **handlers):
228233
log.debug("[#%04X] C: COMMIT", self.local_port)
@@ -244,18 +249,6 @@ def fail(metadata):
244249
self._append(b"\x0F", response=Response(self, "reset", on_failure=fail))
245250
self.send_all()
246251
self.fetch_all()
247-
self._is_reset = True
248-
249-
def _update_server_state_on_success(self, metadata, message):
250-
if metadata.get("has_more"):
251-
return
252-
state_before = self._server_state
253-
self._server_state = STATE_TRANSITIONS\
254-
.get(self._server_state, {})\
255-
.get(message, self._server_state)
256-
if state_before != self._server_state:
257-
log.debug("[#%04X] [%s]", self.local_port,
258-
self._server_state.name)
259252

260253
def fetch_message(self):
261254
""" Receive at most one message from the server, if available.
@@ -288,15 +281,15 @@ def fetch_message(self):
288281
response.complete = True
289282
if summary_signature == b"\x70":
290283
log.debug("[#%04X] S: SUCCESS %r", self.local_port, summary_metadata)
291-
self._update_server_state_on_success(summary_metadata,
292-
response.message)
284+
self._server_state_manager.transition(response.message,
285+
summary_metadata)
293286
response.on_success(summary_metadata or {})
294287
elif summary_signature == b"\x7E":
295288
log.debug("[#%04X] S: IGNORED", self.local_port)
296289
response.on_ignored(summary_metadata or {})
297290
elif summary_signature == b"\x7F":
298291
log.debug("[#%04X] S: FAILURE %r", self.local_port, summary_metadata)
299-
self._server_state = ServerStates.FAILED
292+
self._server_state_manager.state = ServerStates.FAILED
300293
try:
301294
response.on_failure(summary_metadata or {})
302295
except (ServiceUnavailable, DatabaseUnavailable):

0 commit comments

Comments
 (0)