@@ -1927,62 +1927,121 @@ async def mock_call(*args, **kwargs):
             assert call_time < 0.2

     @pytest.mark.asyncio
-    async def test_read_rows_sharded_batching(self):
+    async def test_read_rows_sharded_concurrency_limit(self):
         """
-        Large queries should be processed in batches to limit concurrency
-        operation timeout should change between batches
+        Only 10 queries should be processed concurrently. Others should be queued
+
+        Should start a new query as soon as the previous one finishes
         """
-        from google.cloud.bigtable.data._async.client import TableAsync
         from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT

         assert _CONCURRENCY_LIMIT == 10  # change this test if this changes
+        num_queries = 15

-        n_queries = 90
-        expected_num_batches = n_queries // _CONCURRENCY_LIMIT
-        query_list = [ReadRowsQuery() for _ in range(n_queries)]
-
-        table_mock = AsyncMock()
-        start_operation_timeout = 10
-        start_attempt_timeout = 3
-        table_mock.default_read_rows_operation_timeout = start_operation_timeout
-        table_mock.default_read_rows_attempt_timeout = start_attempt_timeout
-        # clock ticks one second on each check
-        with mock.patch("time.monotonic", side_effect=range(0, 100000)):
-            with mock.patch("asyncio.gather", AsyncMock()) as gather_mock:
-                await TableAsync.read_rows_sharded(table_mock, query_list)
-                # should have individual calls for each query
-                assert table_mock.read_rows.call_count == n_queries
-                # should have single gather call for each batch
-                assert gather_mock.call_count == expected_num_batches
-                # ensure that timeouts decrease over time
-                kwargs = [
-                    table_mock.read_rows.call_args_list[idx][1]
-                    for idx in range(n_queries)
-                ]
-                for batch_idx in range(expected_num_batches):
-                    batch_kwargs = kwargs[
-                        batch_idx
-                        * _CONCURRENCY_LIMIT : (batch_idx + 1)
-                        * _CONCURRENCY_LIMIT
+        # each of the first 10 queries takes longer than the last;
+        # later rpcs will have to wait on the first 10
+        increment_time = 0.05
+        max_time = increment_time * (_CONCURRENCY_LIMIT - 1)
+        rpc_times = [min(i * increment_time, max_time) for i in range(num_queries)]
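+        # (resulting schedule: rpc i < 10 sleeps i * 0.05s, so the first ten
+        # finish at t = 0.00, 0.05, ..., 0.45; each completion frees one slot,
+        # letting queued query 10 + i start at roughly t = i * 0.05)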
+
+        async def mock_call(*args, **kwargs):
+            next_sleep = rpc_times.pop(0)
+            await asyncio.sleep(next_sleep)
+            return [mock.Mock()]
+
+        starting_timeout = 10
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    await table.read_rows_sharded(
+                        queries, operation_timeout=starting_timeout
+                    )
+                    assert read_rows.call_count == num_queries
+                    # check operation timeouts to see how far into the operation each rpc started
+                    rpc_start_list = [
+                        starting_timeout - kwargs["operation_timeout"]
+                        for _, kwargs in read_rows.call_args_list
                     ]
-                    for req_kwargs in batch_kwargs:
-                        # each batch should have the same operation_timeout, and it should decrease in each batch
-                        expected_operation_timeout = start_operation_timeout - (
-                            batch_idx + 1
-                        )
-                        assert (
-                            req_kwargs["operation_timeout"]
-                            == expected_operation_timeout
-                        )
-                        # each attempt_timeout should start with default value, but decrease when operation_timeout reaches it
-                        expected_attempt_timeout = min(
-                            start_attempt_timeout, expected_operation_timeout
+                    eps = 0.01
+                    # first 10 should start immediately
+                    assert all(
+                        rpc_start_list[i] < eps for i in range(_CONCURRENCY_LIMIT)
+                    )
+                    # next rpcs should start as first ones finish
+                    for i in range(num_queries - _CONCURRENCY_LIMIT):
+                        idx = i + _CONCURRENCY_LIMIT
+                        assert rpc_start_list[idx] - (i * increment_time) < eps
+
+    @pytest.mark.asyncio
+    async def test_read_rows_sharded_expirary(self):
+        """
+        If the operation times out before all shards complete, should raise
+        a ShardedReadRowsExceptionGroup
+        """
+        from google.cloud.bigtable.data._async.client import _CONCURRENCY_LIMIT
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        operation_timeout = 0.1
+
+        # let the first batch complete, but the next batch times out
+        num_queries = 15
+        sleeps = [0] * _CONCURRENCY_LIMIT + [DeadlineExceeded("times up")] * (
+            num_queries - _CONCURRENCY_LIMIT
+        )
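+        # (the first 10 calls return immediately; the remaining 5 raise, standing
+        # in for shards that run out of time before completing)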
+
+        async def mock_call(*args, **kwargs):
+            next_item = sleeps.pop(0)
+            if isinstance(next_item, Exception):
+                raise next_item
+            else:
+                await asyncio.sleep(next_item)
+            return [mock.Mock()]
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(num_queries)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        await table.read_rows_sharded(
+                            queries, operation_timeout=operation_timeout
                         )
-                        assert req_kwargs["attempt_timeout"] == expected_attempt_timeout
-        # await all created coroutines to avoid warnings
-        for i in range(len(gather_mock.call_args_list)):
-            for j in range(len(gather_mock.call_args_list[i][0])):
-                await gather_mock.call_args_list[i][0][j]
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == num_queries - _CONCURRENCY_LIMIT
+                    # should keep successful queries
+                    assert len(exc.value.successful_rows) == _CONCURRENCY_LIMIT
+
+    @pytest.mark.asyncio
+    async def test_read_rows_sharded_negative_batch_timeout(self):
+        """
+        Try to run a batch that starts after the operation timeout has expired
+
+        Those queries should fail with DeadlineExceeded errors
+        """
+        from google.cloud.bigtable.data.exceptions import ShardedReadRowsExceptionGroup
+        from google.api_core.exceptions import DeadlineExceeded
+
+        async def mock_call(*args, **kwargs):
+            await asyncio.sleep(0.05)
+            return [mock.Mock()]
+
+        async with _make_client() as client:
+            async with client.get_table("instance", "table") as table:
+                with mock.patch.object(table, "read_rows") as read_rows:
+                    read_rows.side_effect = mock_call
+                    queries = [ReadRowsQuery() for _ in range(15)]
+                    with pytest.raises(ShardedReadRowsExceptionGroup) as exc:
+                        await table.read_rows_sharded(queries, operation_timeout=0.01)
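+                    # (each mocked rpc takes 0.05s against a 0.01s budget, so the
+                    # 5 queries queued behind the first 10 can never start before
+                    # the deadline and should fail without being attempted)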
+                    assert isinstance(exc.value, ShardedReadRowsExceptionGroup)
+                    assert len(exc.value.exceptions) == 5
+                    assert all(
+                        isinstance(e.__cause__, DeadlineExceeded)
+                        for e in exc.value.exceptions
+                    )


 class TestSampleRowKeys:
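
For context, the behavior these tests pin down (at most _CONCURRENCY_LIMIT shards
in flight, each shard given only the time remaining on the shared deadline, and
per-shard failures collected without discarding successful results) could be
implemented roughly as in the sketch below. This is a minimal illustrative sketch,
not the client's actual code: the helper name read_rows_sharded_sketch, the
semaphore-based approach, and the built-in ExceptionGroup (Python 3.11+) standing
in for ShardedReadRowsExceptionGroup are all assumptions.

import asyncio
import time

_CONCURRENCY_LIMIT = 10  # mirrors the constant the tests assert on


async def read_rows_sharded_sketch(table, queries, *, operation_timeout=60.0):
    """Run sharded queries with capped concurrency and a shared deadline (sketch)."""
    semaphore = asyncio.Semaphore(_CONCURRENCY_LIMIT)
    deadline = time.monotonic() + operation_timeout

    async def run_one(query):
        async with semaphore:  # at most 10 rpcs in flight; the rest queue here
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # this shard never got a chance to start: fail it without sending
                raise TimeoutError("operation_timeout expired before shard started")
            # pass down only the remaining budget; this is why the tests can
            # recover each rpc's start time from its operation_timeout kwarg
            return await table.read_rows(query, operation_timeout=remaining)

    results = await asyncio.gather(
        *(run_one(q) for q in queries), return_exceptions=True
    )
    successes = [r for r in results if not isinstance(r, Exception)]
    failures = [r for r in results if isinstance(r, Exception)]
    if failures:
        # the real client raises ShardedReadRowsExceptionGroup, which also
        # carries the successful rows; ExceptionGroup is only a stand-in
        raise ExceptionGroup("sharded read failures", failures)
    return [row for rows in successes for row in rows]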