@@ -222,3 +222,105 @@ def once(tx):
222
222
for failure in failures :
223
223
with self .subTest (failure = failure ):
224
224
_test ()
225
+
226
+ def test_reset_fails_after_pull (self ):
227
+ def _test (invalid_response_ , api_ ):
228
+ def check_exception (exc ):
229
+ self .assertEqual (
230
+ exc .exception .code ,
231
+ "Neo.TransientError.Statement."
232
+ "RemoteExecutionTransientError"
233
+ )
234
+ if self .driver_supports_features (
235
+ types .Feature .API_RETRYABLE_EXCEPTION
236
+ ):
237
+ self .assertTrue (exc .exception .retryable )
238
+
239
+ def api_call (session_ ):
240
+ if api_ == "session" :
241
+ with self .assertRaises (types .DriverError ) as exc :
242
+ result = session_ .run ("RETURN 1 AS n" )
243
+ list (result )
244
+ check_exception (exc )
245
+ elif api_ == "explicit_tx" :
246
+ tx = session_ .begin_transaction ()
247
+ try :
248
+ with self .assertRaises (types .DriverError ) as exc :
249
+ result = tx .run ("RETURN 1 AS n" )
250
+ list (result )
251
+ check_exception (exc )
252
+ finally :
253
+ tx .close ()
254
+ elif api_ == "managed_tx" :
255
+ run = 0
256
+
257
+ def work (tx ):
258
+ nonlocal run
259
+ run += 1
260
+ if run == 1 :
261
+ with self .assertRaises (types .DriverError ) as exc :
262
+ result = tx .run ("RETURN 1 AS n" )
263
+ list (result )
264
+ check_exception (exc )
265
+ raise exc .exception
266
+ else :
267
+ result = tx .run ("RETURN 1 AS n" )
268
+ return list (result )
269
+
270
+ records = session_ .execute_write (work )
271
+ assert len (records ) == 1
272
+ self .assertEqual (records , [
273
+ types .Record (values = [types .CypherInt (1 )])
274
+ ])
275
+ else :
276
+ raise ValueError (f"Unknown API: { api_ } " )
277
+
278
+ self ._server .start (
279
+ path = self .script_path ("reset_fails_after_pull.script" ),
280
+ vars_ = {
281
+ "#INVALID_RESPONSE#" : invalid_response_ ,
282
+ }
283
+ )
284
+ auth = types .AuthorizationToken ("basic" , principal = "" ,
285
+ credentials = "" )
286
+ driver = Driver (self ._backend ,
287
+ "bolt://%s" % self ._server .address , auth )
288
+ try :
289
+ session = driver .session ("r" )
290
+ try :
291
+ api_call (session )
292
+
293
+ finally :
294
+ session .close ()
295
+ # driver should've killed the misbehaving connection
296
+ try :
297
+ self .assertEqual (
298
+ self ._server .count_responses ("<HANGUP>" ), 1
299
+ )
300
+ finally :
301
+ self ._server ._dump ()
302
+ finally :
303
+ driver .close ()
304
+ self ._server .done ()
305
+
306
+ invalid_responses = (
307
+ (
308
+ 'S: FAILURE {"code": "Neo.ClientError.General.Unknown", '
309
+ '"message": "The driver should ignore this error!"}'
310
+ ),
311
+ "S: IGNORED" ,
312
+ (
313
+ "# MIXED \n "
314
+ "IF: invalid_responses <= 1\n "
315
+ ' S: FAILURE {"code": "Neo.ClientError.General.Unknown", '
316
+ '"message": "The driver should ignore this error!"}\n '
317
+ "ELSE:\n "
318
+ " S: IGNORED\n "
319
+ )
320
+ )
321
+ for invalid_response in invalid_responses :
322
+ for api in ("session" , "explicit_tx" , "managed_tx" ):
323
+ with self .subTest (response = invalid_response [2 :10 ].strip (),
324
+ api = api ):
325
+ _test (invalid_response , api )
326
+ self ._server .reset ()
0 commit comments