Skip to content

Commit 468d529

Browse files
MarkDaoustcopybara-github
authored andcommitted
feat: [Python] Implement async embedding batches for MLDev.
feat: [Python] Add BatchJob.done property. - Vertex's interface is generic enough that it already works with embedding batches. - MLDev's interface uses a separate method for embedding batches. - The SDK uses needs different argument names to indicate which method to call. - The different argument/result field names are also important to avoid the use of Union Types with inlinedRequests and inlinedResponses. PiperOrigin-RevId: 804658916
1 parent 8429744 commit 468d529

File tree

9 files changed

+1086
-172
lines changed

9 files changed

+1086
-172
lines changed

google/genai/_extra_utils.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,12 @@ def _get_bigquery_uri(
9090

9191

9292
def format_destination(
93-
src: Union[str, types.BatchJobSourceOrDict],
94-
config: Optional[types.CreateBatchJobConfigOrDict] = None,
93+
src: Union[str, types.BatchJobSource],
94+
config: Optional[types.CreateBatchJobConfig] = None,
9595
) -> types.CreateBatchJobConfig:
9696
"""Formats the destination uri based on the source uri for Vertex AI."""
97-
config = (
98-
types._CreateBatchJobParameters(config=config).config
99-
or types.CreateBatchJobConfig()
100-
)
97+
if config is None:
98+
config = types.CreateBatchJobConfig()
10199

102100
unique_name = None
103101
if not config.display_name:
@@ -113,8 +111,7 @@ def format_destination(
113111
elif bigquery_source_uri:
114112
unique_name = unique_name or _common.timestamped_unique_name()
115113
config.dest = f'{bigquery_source_uri}_dest_{unique_name}'
116-
else:
117-
raise ValueError(f'The source {src} is not supported.')
114+
118115
return config
119116

120117

google/genai/_transformers.py

Lines changed: 55 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -964,30 +964,30 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
964964

965965
def t_batch_job_source(
966966
client: _api_client.BaseApiClient,
967-
src: Union[
968-
str, List[types.InlinedRequestOrDict], types.BatchJobSourceOrDict
969-
],
967+
src: types.BatchJobSourceUnionDict,
970968
) -> types.BatchJobSource:
971969
if isinstance(src, dict):
972970
src = types.BatchJobSource(**src)
973971
if isinstance(src, types.BatchJobSource):
972+
vertex_sources = sum(
973+
[src.gcs_uri is not None, src.bigquery_uri is not None]
974+
)
975+
mldev_sources = sum([
976+
src.inlined_requests is not None,
977+
src.file_name is not None,
978+
])
974979
if client.vertexai:
975-
if src.gcs_uri and src.bigquery_uri:
980+
if mldev_sources or vertex_sources != 1:
976981
raise ValueError(
977-
'Only one of `gcs_uri` or `bigquery_uri` can be set.'
978-
)
979-
elif not src.gcs_uri and not src.bigquery_uri:
980-
raise ValueError(
981-
'One of `gcs_uri` or `bigquery_uri` must be set.'
982+
'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
983+
'sources are not supported in Vertex AI.'
982984
)
983985
else:
984-
if src.inlined_requests and src.file_name:
986+
if vertex_sources or mldev_sources != 1:
985987
raise ValueError(
986-
'Only one of `inlined_requests` or `file_name` can be set.'
987-
)
988-
elif not src.inlined_requests and not src.file_name:
989-
raise ValueError(
990-
'One of `inlined_requests` or `file_name` must be set.'
988+
'Exactly one of `inlined_requests`, `file_name`, '
989+
'`inlined_embed_content_requests`, or `embed_content_file_name` '
990+
'must be set, other sources are not supported in Gemini API.'
991991
)
992992
return src
993993

@@ -1012,6 +1012,29 @@ def t_batch_job_source(
10121012
raise ValueError(f'Unsupported source: {src}')
10131013

10141014

1015+
def t_embedding_batch_job_source(
1016+
client: _api_client.BaseApiClient,
1017+
src: types.EmbeddingsBatchJobSourceOrDict,
1018+
) -> types.EmbeddingsBatchJobSource:
1019+
if isinstance(src, dict):
1020+
src = types.EmbeddingsBatchJobSource(**src)
1021+
1022+
if isinstance(src, types.EmbeddingsBatchJobSource):
1023+
mldev_sources = sum([
1024+
src.inlined_requests is not None,
1025+
src.file_name is not None,
1026+
])
1027+
if mldev_sources != 1:
1028+
raise ValueError(
1029+
'Exactly one of `inlined_requests`, `file_name`, '
1030+
'`inlined_embed_content_requests`, or `embed_content_file_name` '
1031+
'must be set, other sources are not supported in Gemini API.'
1032+
)
1033+
return src
1034+
else:
1035+
raise ValueError(f'Unsupported source type: {type(src)}')
1036+
1037+
10151038
def t_batch_job_destination(
10161039
dest: Union[str, types.BatchJobDestinationOrDict],
10171040
) -> types.BatchJobDestination:
@@ -1037,6 +1060,23 @@ def t_batch_job_destination(
10371060
raise ValueError(f'Unsupported destination: {dest}')
10381061

10391062

1063+
def t_recv_batch_job_destination(dest: dict[str, Any]) -> dict[str, Any]:
1064+
# Rename inlinedResponses if it looks like an embedding response.
1065+
inline_responses = dest.get('inlinedResponses', {}).get(
1066+
'inlinedResponses', []
1067+
)
1068+
if not inline_responses:
1069+
return dest
1070+
for response in inline_responses:
1071+
inner_response = response.get('response', {})
1072+
if not inner_response:
1073+
continue
1074+
if 'embedding' in inner_response:
1075+
dest['inlinedEmbedContentResponses'] = dest.pop('inlinedResponses')
1076+
break
1077+
return dest
1078+
1079+
10401080
def t_batch_job_name(client: _api_client.BaseApiClient, name: str) -> str:
10411081
if not client.vertexai:
10421082
mldev_pattern = r'batches/[^/]+$'

0 commit comments

Comments
 (0)