@@ -377,7 +377,7 @@ async def download_ranges(
        attempt_count = 0

        def send_ranges_and_get_bytes(
-           requests: List[_storage_v2.ReadRange],
+           requests_generator,
            state: Dict[str, Any],
            metadata: Optional[List[Tuple[str, str]]] = None,
        ):
@@ -387,7 +387,7 @@ async def generator():

            if attempt_count > 1:
                logger.info(
-                   f"Resuming download (attempt {attempt_count - 1}) for {len(requests)} ranges."
+                   f"Resuming download (attempt {attempt_count - 1})."
                )

            async with lock:
@@ -436,17 +436,28 @@ async def generator():
                )
                self._is_stream_open = True

-               pending_read_ids = {r.read_id for r in requests}
+               # Stream requests directly without materializing
+               pending_read_ids = set()
+               current_batch = []
+
+               for read_range in requests_generator:
+                   pending_read_ids.add(read_range.read_id)
+                   current_batch.append(read_range)
+
+                   # Send batch when it reaches max size
+                   if len(current_batch) >= _MAX_READ_RANGES_PER_BIDI_READ_REQUEST:
+                       await self.read_obj_str.send(
+                           _storage_v2.BidiReadObjectRequest(read_ranges=current_batch)
+                       )
+                       current_batch = []

-               # Send Requests
-               for i in range(
-                   0, len(requests), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST
-               ):
-                   batch = requests[i : i + _MAX_READ_RANGES_PER_BIDI_READ_REQUEST]
-                   await self.read_obj_str.send(
-                       _storage_v2.BidiReadObjectRequest(read_ranges=batch)
-                   )
+               # Send remaining partial batch
+               if current_batch:
+                   await self.read_obj_str.send(
+                       _storage_v2.BidiReadObjectRequest(read_ranges=current_batch)
+                   )

                # Receive responses
                while pending_read_ids:
                    response = await self.read_obj_str.recv()
                    if response is None:
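The new send path consumes the range generator lazily and flushes a fixed-size batch as soon as it fills, plus one trailing partial batch. Below is a minimal, self-contained sketch of that pattern; the `send_in_batches` name, the `send` coroutine, and the demo values are illustrative, not the library's API.

```python
import asyncio
from typing import Awaitable, Callable, Iterable, List, TypeVar

T = TypeVar("T")


async def send_in_batches(
    items: Iterable[T],
    send: Callable[[List[T]], Awaitable[None]],
    max_batch_size: int,
) -> int:
    """Consume `items` lazily and send fixed-size batches; return count sent."""
    batch: List[T] = []
    total = 0
    for item in items:
        batch.append(item)
        if len(batch) >= max_batch_size:
            await send(batch)
            total += len(batch)
            batch = []  # start a fresh list rather than mutating the sent one
    if batch:  # flush the trailing partial batch
        await send(batch)
        total += len(batch)
    return total


async def _demo() -> None:
    async def fake_send(batch: List[int]) -> None:
        print("sending", batch)

    print("total:", await send_in_batches(range(7), fake_send, max_batch_size=3))


if __name__ == "__main__":
    asyncio.run(_demo())
```

Rebinding `batch = []` instead of calling `batch.clear()` matters if the transport keeps a reference to the list after `send` returns; the diff above makes the same choice with `current_batch`.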
@@ -21,7 +21,7 @@
if you want to use these Rapid Storage APIs.

"""
-from typing import Optional
+from typing import List, Optional, Tuple
from google.cloud import _storage_v2
from google.cloud.storage._experimental.asyncio.async_grpc_client import AsyncGrpcClient
from google.cloud.storage._experimental.asyncio.async_abstract_object_stream import (
@@ -62,6 +62,7 @@ def __init__(
object_name: str,
generation_number: Optional[int] = None, # None means new object
write_handle: Optional[bytes] = None,
+        routing_token: Optional[str] = None,
) -> None:
if client is None:
raise ValueError("client must be provided")
@@ -77,6 +78,7 @@
)
self.client: AsyncGrpcClient.grpc_client = client
self.write_handle: Optional[bytes] = write_handle
+        self.routing_token: Optional[str] = routing_token

self._full_bucket_name = f"projects/_/buckets/{self.bucket_name}"

@@ -91,13 +93,15 @@ def __init__(
self.persisted_size = 0
self.object_resource: Optional[_storage_v2.Object] = None

-    async def open(self) -> None:
+    async def open(self, metadata: Optional[List[Tuple[str, str]]] = None) -> None:
        """Opening an object for write should do its state lookup
        to know what the persisted size is.
        """
if self._is_stream_open:
raise ValueError("Stream is already open")

+        write_handle = self.write_handle if self.write_handle else None

# Create a new object or overwrite existing one if generation_number
# is None. This makes it consistent with GCS JSON API behavior.
# Created object type would be Appendable Object.
@@ -116,37 +120,47 @@ async def open(self) -> None:
                bucket=self._full_bucket_name,
                object=self.object_name,
                generation=self.generation_number,
+               write_handle=write_handle,
+               routing_token=self.routing_token if self.routing_token else None,
            ),
        )

+       request_params = [f"bucket={self._full_bucket_name}"]
+       other_metadata = []
+       if metadata:
+           for key, value in metadata:
+               if key == "x-goog-request-params":
+                   request_params.append(value)
+               else:
+                   other_metadata.append((key, value))
+
+       current_metadata = other_metadata
+       current_metadata.append(("x-goog-request-params", ",".join(request_params)))
+
        self.socket_like_rpc = AsyncBidiRpc(
-           self.rpc, initial_request=self.first_bidi_write_req, metadata=self.metadata
+           self.rpc,
+           initial_request=self.first_bidi_write_req,
+           metadata=current_metadata,
        )

        await self.socket_like_rpc.open()  # this is actually 1 send
        response = await self.socket_like_rpc.recv()
        self._is_stream_open = True

-       if not response.resource:
-           raise ValueError(
-               "Failed to obtain object resource after opening the stream"
-           )
-       if not response.resource.generation:
-           raise ValueError(
-               "Failed to obtain object generation after opening the stream"
-           )
-       if not response.write_handle:
-           raise ValueError("Failed to obtain write_handle after opening the stream")
-
-       if not response.resource.size:
-           # Appending to a 0 byte appendable object.
-           self.persisted_size = 0
-       else:
-           self.persisted_size = response.resource.size
-
-       self.generation_number = response.resource.generation
-       self.write_handle = response.write_handle
+       if response.persisted_size:
+           self.persisted_size = response.persisted_size
+
+       if response.resource:
+           if not response.resource.size:
+               # Appending to a 0 byte appendable object.
+               self.persisted_size = 0
+           else:
+               self.persisted_size = response.resource.size
+           self.generation_number = response.resource.generation
+
+       if response.write_handle:
+           self.write_handle = response.write_handle

async def close(self) -> None:
"""Closes the bidi-gRPC connection."""
@@ -181,7 +195,16 @@ async def recv(self) -> _storage_v2.BidiWriteObjectResponse:
"""
if not self._is_stream_open:
raise ValueError("Stream is not open")
-       return await self.socket_like_rpc.recv()
+       response = await self.socket_like_rpc.recv()
+       # Update the write handle and persisted state if present in the response
+       if response:
+           if response.write_handle:
+               self.write_handle = response.write_handle
+           if response.persisted_size is not None:
+               self.persisted_size = response.persisted_size
+           if response.resource and response.resource.size:
+               self.persisted_size = response.resource.size
+       return response

@property
def is_stream_open(self) -> bool:
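Two behaviors change in `open()`: caller-supplied metadata is folded into a single `x-goog-request-params` entry (so routing parameters such as a redirect's routing token ride alongside the mandatory bucket parameter in one merged value), and the first response's fields are now treated as optional rather than raising. The merge step can be sketched as a plain function; the names here are illustrative, not the PR's API:

```python
from typing import List, Optional, Tuple


def merge_request_params(
    bucket_path: str,
    metadata: Optional[List[Tuple[str, str]]] = None,
) -> List[Tuple[str, str]]:
    """Collapse any caller-supplied x-goog-request-params entries into one
    comma-joined value alongside the mandatory bucket routing parameter."""
    request_params = [f"bucket={bucket_path}"]
    merged: List[Tuple[str, str]] = []
    for key, value in metadata or []:
        if key == "x-goog-request-params":
            request_params.append(value)  # fold into the single merged entry
        else:
            merged.append((key, value))   # pass other metadata through as-is
    merged.append(("x-goog-request-params", ",".join(request_params)))
    return merged


# merge_request_params("projects/_/buckets/b", [("x-goog-request-params", "routing_token=t")])
# -> [("x-goog-request-params", "bucket=projects/_/buckets/b,routing_token=t")]
```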
46 changes: 44 additions & 2 deletions google/cloud/storage/_experimental/asyncio/retry/_helpers.py
@@ -18,12 +18,19 @@
from typing import Tuple, Optional

from google.api_core import exceptions
-from google.cloud._storage_v2.types import BidiReadObjectRedirectedError
+from google.cloud._storage_v2.types import (
+    BidiReadObjectRedirectedError,
+    BidiWriteObjectRedirectedError,
+)
from google.rpc import status_pb2

_BIDI_READ_REDIRECTED_TYPE_URL = (
"type.googleapis.com/google.storage.v2.BidiReadObjectRedirectedError"
)
+_BIDI_WRITE_REDIRECTED_TYPE_URL = (
+    "type.googleapis.com/google.storage.v2.BidiWriteObjectRedirectedError"
+)
+logger = logging.getLogger(__name__)


def _handle_redirect(
@@ -78,6 +85,41 @@ def _handle_redirect(
read_handle = redirect_proto.read_handle
break
except Exception as e:
logging.ERROR(f"Error unpacking redirect: {e}")
logger.error(f"Error unpacking redirect: {e}")

return routing_token, read_handle


+def _extract_bidi_writes_redirect_proto(exc: Exception):
> Reviewer (Collaborator): this method is almost the same as the `_handle_redirect` method. Please merge both of them. Also please refer to this comment.

+    grpc_error = None
+    if isinstance(exc, exceptions.Aborted) and exc.errors:
+        grpc_error = exc.errors[0]
+
+    if grpc_error:
+        if isinstance(grpc_error, BidiWriteObjectRedirectedError):
+            return grpc_error
+
+        if hasattr(grpc_error, "trailing_metadata"):
+            trailers = grpc_error.trailing_metadata()
+            if not trailers:
+                return
+
+            status_details_bin = None
+            for key, value in trailers:
+                if key == "grpc-status-details-bin":
+                    status_details_bin = value
+                    break
+
+            if status_details_bin:
+                status_proto = status_pb2.Status()
+                try:
+                    status_proto.ParseFromString(status_details_bin)
+                    for detail in status_proto.details:
+                        if detail.type_url == _BIDI_WRITE_REDIRECTED_TYPE_URL:
+                            redirect_proto = BidiWriteObjectRedirectedError.deserialize(
+                                detail.value
+                            )
+                            return redirect_proto
+                except Exception:
+                    logger.error("Error unpacking redirect details from gRPC error.")
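The reviewer asks for `_extract_bidi_writes_redirect_proto` and `_handle_redirect` to be merged. One plausible shape is a single extractor parameterized by the redirect proto class and its type URL; this is a hedged sketch, not the PR's final code, and it assumes the proto classes expose the proto-plus `deserialize` classmethod (as the existing helpers already rely on):

```python
import logging
from typing import Optional, Type, TypeVar

from google.api_core import exceptions
from google.rpc import status_pb2

logger = logging.getLogger(__name__)
ProtoT = TypeVar("ProtoT")


def _extract_redirect_proto(
    exc: Exception, proto_cls: Type[ProtoT], type_url: str
) -> Optional[ProtoT]:
    """Pull a redirect detail proto out of an Aborted gRPC error, if present."""
    if not (isinstance(exc, exceptions.Aborted) and exc.errors):
        return None
    grpc_error = exc.errors[0]
    if isinstance(grpc_error, proto_cls):
        return grpc_error  # already the typed redirect error
    if not hasattr(grpc_error, "trailing_metadata"):
        return None
    trailers = grpc_error.trailing_metadata() or ()
    details_bin = next(
        (value for key, value in trailers if key == "grpc-status-details-bin"), None
    )
    if details_bin is None:
        return None
    status_proto = status_pb2.Status()
    try:
        status_proto.ParseFromString(details_bin)
        for detail in status_proto.details:
            if detail.type_url == type_url:
                return proto_cls.deserialize(detail.value)
    except Exception:
        logger.error("Error unpacking redirect details from gRPC error.")
    return None
```

Both call sites then become one-liners, e.g. `_extract_redirect_proto(exc, BidiWriteObjectRedirectedError, _BIDI_WRITE_REDIRECTED_TYPE_URL)`, and the read path can keep its routing-token and read-handle unpacking on top of the returned proto.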
@@ -27,16 +27,20 @@ class _BaseResumptionStrategy(abc.ABC):
"""

@abc.abstractmethod
-    def generate_requests(self, state: Any) -> Iterable[Any]:
-        """Generates the next batch of requests based on the current state.
+    def generate_requests(self, state: Any):
+        """Generates requests based on the current state as a generator.

        This method is called at the beginning of each retry attempt. It should
-        inspect the provided state object and generate the appropriate list of
-        request protos to send to the server. For example, a read strategy
-        would use this to implement "Smarter Resumption" by creating smaller
-        `ReadRange` requests for partially downloaded ranges. For bidi-writes,
-        it will set the `write_offset` field to the persisted size received
-        from the server in the next request.
+        inspect the provided state object and yield request protos to send to
+        the server. For example, a read strategy would use this to implement
+        "Smarter Resumption" by creating smaller `ReadRange` requests for
+        partially downloaded ranges. For bidi-writes, it will set the
+        `write_offset` field to the persisted size received from the server
+        in the next request.
+
+        This is a generator that yields requests incrementally rather than
+        returning them all at once, allowing for better memory efficiency
+        and on-demand generation.

:type state: Any
:param state: An object containing all the state needed for the
Expand Down
@@ -50,8 +50,8 @@ async def execute(self, initial_state: Any, retry_policy):
state = initial_state

async def attempt():
-           requests = self._strategy.generate_requests(state)
-           stream = self._send_and_recv(requests, state)
+           requests_generator = self._strategy.generate_requests(state)
+           stream = self._send_and_recv(requests_generator, state)
try:
async for response in stream:
self._strategy.update_state_from_response(response, state)
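Note why the executor calls `generate_requests(state)` inside `attempt()` rather than once outside: a Python generator is exhausted after a single pass, so each retry attempt needs a fresh one built from the updated state. A short demonstration:

```python
def gen():
    yield from (1, 2, 3)


g = gen()
assert list(g) == [1, 2, 3]
assert list(g) == []  # exhausted; a retry reusing `g` would send no requests
```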
@@ -49,14 +49,16 @@ def __init__(
class _ReadResumptionStrategy(_BaseResumptionStrategy):
"""The concrete resumption strategy for bidi reads."""

-    def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]:
+    def generate_requests(self, state: Dict[str, Any]):
"""Generates new ReadRange requests for all incomplete downloads.

+        This is a generator that yields requests one at a time for incomplete
+        downloads, allowing for better memory efficiency and incremental processing.

:type state: dict
:param state: A dictionary mapping a read_id to its corresponding
_DownloadState object.
"""
-        pending_requests = []
download_states: Dict[int, _DownloadState] = state["download_states"]

for read_id, read_state in download_states.items():
@@ -74,8 +76,7 @@ def generate_requests(self, state: Dict[str, Any]) -> List[storage_v2.ReadRange]
read_length=new_length,
read_id=read_id,
)
-            pending_requests.append(new_request)
-        return pending_requests
+            yield new_request

def update_state_from_response(
self, response: storage_v2.BidiReadObjectResponse, state: Dict[str, Any]
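The strategy's resumption arithmetic — advance the offset by what already arrived, shrink the length by the same amount — can be shown with plain dicts standing in for `ReadRange` protos. The `_DownloadState` below is a simplified stand-in for the real bookkeeping class, invented for this sketch:

```python
from dataclasses import dataclass
from typing import Dict, Iterator


@dataclass
class _DownloadState:
    read_offset: int      # offset originally requested
    read_length: int      # length originally requested
    bytes_received: int   # bytes already delivered for this read_id
    done: bool = False


def generate_requests(states: Dict[int, _DownloadState]) -> Iterator[dict]:
    """Yield a shrunken range for every read that is not yet complete."""
    for read_id, st in states.items():
        if st.done:
            continue
        yield {
            "read_offset": st.read_offset + st.bytes_received,
            "read_length": st.read_length - st.bytes_received,
            "read_id": read_id,
        }


# A 1024-byte read at offset 0 that has received 600 bytes resumes as:
# {"read_offset": 600, "read_length": 424, "read_id": ...}
```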