@@ -386,6 +386,22 @@ def _getTargetClass(self):
386386 from gcloud .bigtable .row_data import PartialRowsData
387387 return PartialRowsData
388388
389+ def _getDoNothingClass (self ):
390+ klass = self ._getTargetClass ()
391+
392+ class FakePartialRowsData (klass ):
393+
394+ def __init__ (self , * args , ** kwargs ):
395+ super (FakePartialRowsData , self ).__init__ (* args , ** kwargs )
396+ self ._consumed = []
397+
398+ def consume_next (self ):
399+ value = self ._response_iterator .next ()
400+ self ._consumed .append (value )
401+ return value
402+
403+ return FakePartialRowsData
404+
389405 def _makeOne (self , * args , ** kwargs ):
390406 return self ._getTargetClass ()(* args , ** kwargs )
391407
@@ -425,3 +441,84 @@ def test_rows_getter(self):
425441 partial_rows_data = self ._makeOne (None )
426442 partial_rows_data ._rows = value = object ()
427443 self .assertTrue (partial_rows_data .rows is value )
444+
445+ def test_cancel (self ):
446+ response_iterator = _MockCancellableIterator ()
447+ partial_rows_data = self ._makeOne (response_iterator )
448+ self .assertEqual (response_iterator .cancel_calls , 0 )
449+ partial_rows_data .cancel ()
450+ self .assertEqual (response_iterator .cancel_calls , 1 )
451+
452+ def test_consume_next (self ):
453+ from gcloud .bigtable ._generated import (
454+ bigtable_service_messages_pb2 as messages_pb2 )
455+ from gcloud .bigtable .row_data import PartialRowData
456+
457+ row_key = b'row-key'
458+ value_pb = messages_pb2 .ReadRowsResponse (row_key = row_key )
459+ response_iterator = _MockCancellableIterator (value_pb )
460+ partial_rows_data = self ._makeOne (response_iterator )
461+ self .assertEqual (partial_rows_data .rows , {})
462+ partial_rows_data .consume_next ()
463+ expected_rows = {row_key : PartialRowData (row_key )}
464+ self .assertEqual (partial_rows_data .rows , expected_rows )
465+
466+ def test_consume_next_row_exists (self ):
467+ from gcloud .bigtable ._generated import (
468+ bigtable_service_messages_pb2 as messages_pb2 )
469+ from gcloud .bigtable .row_data import PartialRowData
470+
471+ row_key = b'row-key'
472+ chunk = messages_pb2 .ReadRowsResponse .Chunk (commit_row = True )
473+ value_pb = messages_pb2 .ReadRowsResponse (row_key = row_key ,
474+ chunks = [chunk ])
475+ response_iterator = _MockCancellableIterator (value_pb )
476+ partial_rows_data = self ._makeOne (response_iterator )
477+ existing_values = PartialRowData (row_key )
478+ partial_rows_data ._rows [row_key ] = existing_values
479+ self .assertFalse (existing_values .committed )
480+ partial_rows_data .consume_next ()
481+ self .assertTrue (existing_values .committed )
482+ self .assertEqual (existing_values .cells , {})
483+
484+ def test_consume_next_empty_iter (self ):
485+ response_iterator = _MockCancellableIterator ()
486+ partial_rows_data = self ._makeOne (response_iterator )
487+ with self .assertRaises (StopIteration ):
488+ partial_rows_data .consume_next ()
489+
490+ def test_consume_all (self ):
491+ klass = self ._getDoNothingClass ()
492+
493+ value1 , value2 , value3 = object (), object (), object ()
494+ response_iterator = _MockCancellableIterator (value1 , value2 , value3 )
495+ partial_rows_data = klass (response_iterator )
496+ self .assertEqual (partial_rows_data ._consumed , [])
497+ partial_rows_data .consume_all ()
498+ self .assertEqual (partial_rows_data ._consumed , [value1 , value2 , value3 ])
499+
500+ def test_consume_all_with_max_loops (self ):
501+ klass = self ._getDoNothingClass ()
502+
503+ value1 , value2 , value3 = object (), object (), object ()
504+ response_iterator = _MockCancellableIterator (value1 , value2 , value3 )
505+ partial_rows_data = klass (response_iterator )
506+ self .assertEqual (partial_rows_data ._consumed , [])
507+ partial_rows_data .consume_all (max_loops = 1 )
508+ self .assertEqual (partial_rows_data ._consumed , [value1 ])
509+ # Make sure the iterator still has the remaining values.
510+ self .assertEqual (list (response_iterator .iter_values ), [value2 , value3 ])
511+
512+
513+ class _MockCancellableIterator (object ):
514+
515+ cancel_calls = 0
516+
517+ def __init__ (self , * values ):
518+ self .iter_values = iter (values )
519+
520+ def cancel (self ):
521+ self .cancel_calls += 1
522+
523+ def next (self ):
524+ return next (self .iter_values )
0 commit comments