@@ -61,28 +61,43 @@ class ServerStates(Enum):
61
61
FAILED = "FAILED"
62
62
63
63
64
- STATE_TRANSITIONS = {
65
- ServerStates .CONNECTED : {
66
- "hello" : ServerStates .READY ,
67
- },
68
- ServerStates .READY : {
69
- "run" : ServerStates .STREAMING ,
70
- "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
71
- },
72
- ServerStates .STREAMING : {
73
- "pull" : ServerStates .READY ,
74
- "discard" : ServerStates .READY ,
75
- "reset" : ServerStates .READY ,
76
- },
77
- ServerStates .TX_READY_OR_TX_STREAMING : {
78
- "commit" : ServerStates .READY ,
79
- "rollback" : ServerStates .READY ,
80
- "reset" : ServerStates .READY ,
81
- },
82
- ServerStates .FAILED : {
83
- "reset" : ServerStates .READY ,
64
+ class ServerStateManager :
65
+ _STATE_TRANSITIONS = {
66
+ ServerStates .CONNECTED : {
67
+ "hello" : ServerStates .READY ,
68
+ },
69
+ ServerStates .READY : {
70
+ "run" : ServerStates .STREAMING ,
71
+ "begin" : ServerStates .TX_READY_OR_TX_STREAMING ,
72
+ },
73
+ ServerStates .STREAMING : {
74
+ "pull" : ServerStates .READY ,
75
+ "discard" : ServerStates .READY ,
76
+ "reset" : ServerStates .READY ,
77
+ },
78
+ ServerStates .TX_READY_OR_TX_STREAMING : {
79
+ "commit" : ServerStates .READY ,
80
+ "rollback" : ServerStates .READY ,
81
+ "reset" : ServerStates .READY ,
82
+ },
83
+ ServerStates .FAILED : {
84
+ "reset" : ServerStates .READY ,
85
+ }
84
86
}
85
- }
87
+
88
+ def __init__ (self , init_state , on_change = None ):
89
+ self .state = init_state
90
+ self ._on_change = on_change
91
+
92
+ def transition (self , metadata , message ):
93
+ if metadata .get ("has_more" ):
94
+ return
95
+ state_before = self .state
96
+ self .state = self ._STATE_TRANSITIONS \
97
+ .get (self .state , {})\
98
+ .get (message , self .state )
99
+ if state_before != self .state and callable (self ._on_change ):
100
+ self ._on_change (state_before , self .state )
86
101
87
102
88
103
class Bolt3 (Bolt ):
@@ -97,15 +112,23 @@ class Bolt3(Bolt):
97
112
98
113
supports_multiple_databases = False
99
114
100
- _server_state = ServerStates .CONNECTED
115
+ def __init__ (self , * args , ** kwargs ):
116
+ super ().__init__ (* args , ** kwargs )
117
+ self ._server_state_manager = ServerStateManager (
118
+ ServerStates .CONNECTED , on_change = self ._on_server_state_change
119
+ )
120
+
121
+ def _on_server_state_change (self , old_state , new_state ):
122
+ log .debug ("[#%04X] State: %s > %s" , self .local_port ,
123
+ old_state .name , new_state .name )
101
124
102
125
@property
103
126
def is_reset (self ):
104
127
if self .responses :
105
128
# we can't be sure of the server's state as there are still pending
106
129
# responses.
107
130
return False
108
- return self ._server_state == ServerStates .READY
131
+ return self ._server_state_manager . state == ServerStates .READY
109
132
110
133
@property
111
134
def encrypted (self ):
@@ -213,7 +236,6 @@ def pull(self, n=-1, qid=-1, **handlers):
213
236
# Just ignore n and qid, it is not supported in the Bolt 3 Protocol.
214
237
log .debug ("[#%04X] C: PULL_ALL" , self .local_port )
215
238
self ._append (b"\x3F " , (), Response (self , "pull" , ** handlers ))
216
- self ._is_reset = False
217
239
218
240
def begin (self , mode = None , bookmarks = None , metadata = None , timeout = None , db = None , ** handlers ):
219
241
if db is not None :
@@ -238,7 +260,6 @@ def begin(self, mode=None, bookmarks=None, metadata=None, timeout=None, db=None,
238
260
raise TypeError ("Timeout must be specified as a number of seconds" )
239
261
log .debug ("[#%04X] C: BEGIN %r" , self .local_port , extra )
240
262
self ._append (b"\x11 " , (extra ,), Response (self , "begin" , ** handlers ))
241
- self ._is_reset = False
242
263
243
264
def commit (self , ** handlers ):
244
265
log .debug ("[#%04X] C: COMMIT" , self .local_port )
@@ -260,18 +281,6 @@ def fail(metadata):
260
281
self ._append (b"\x0F " , response = Response (self , "reset" , on_failure = fail ))
261
282
self .send_all ()
262
283
self .fetch_all ()
263
- self ._is_reset = True
264
-
265
- def _update_server_state_on_success (self , metadata , message ):
266
- if metadata .get ("has_more" ):
267
- return
268
- state_before = self ._server_state
269
- self ._server_state = STATE_TRANSITIONS \
270
- .get (self ._server_state , {})\
271
- .get (message , self ._server_state )
272
- if state_before != self ._server_state :
273
- log .debug ("[#%04X] State: %s" , self .local_port ,
274
- self ._server_state .name )
275
284
276
285
def fetch_message (self ):
277
286
""" Receive at most one message from the server, if available.
@@ -304,15 +313,15 @@ def fetch_message(self):
304
313
response .complete = True
305
314
if summary_signature == b"\x70 " :
306
315
log .debug ("[#%04X] S: SUCCESS %r" , self .local_port , summary_metadata )
307
- self ._update_server_state_on_success ( summary_metadata ,
308
- response . message )
316
+ self ._server_state_manager . transition ( response . message ,
317
+ summary_metadata )
309
318
response .on_success (summary_metadata or {})
310
319
elif summary_signature == b"\x7E " :
311
320
log .debug ("[#%04X] S: IGNORED" , self .local_port )
312
321
response .on_ignored (summary_metadata or {})
313
322
elif summary_signature == b"\x7F " :
314
323
log .debug ("[#%04X] S: FAILURE %r" , self .local_port , summary_metadata )
315
- self ._server_state = ServerStates .FAILED
324
+ self ._server_state_manager . state = ServerStates .FAILED
316
325
try :
317
326
response .on_failure (summary_metadata or {})
318
327
except (ServiceUnavailable , DatabaseUnavailable ):
0 commit comments