@@ -964,30 +964,30 @@ def t_cached_content_name(client: _api_client.BaseApiClient, name: str) -> str:
964964
965965def t_batch_job_source (
966966 client : _api_client .BaseApiClient ,
967- src : Union [
968- str , List [types .InlinedRequestOrDict ], types .BatchJobSourceOrDict
969- ],
967+ src : types .BatchJobSourceUnionDict ,
970968) -> types .BatchJobSource :
971969 if isinstance (src , dict ):
972970 src = types .BatchJobSource (** src )
973971 if isinstance (src , types .BatchJobSource ):
972+ vertex_sources = sum (
973+ [src .gcs_uri is not None , src .bigquery_uri is not None ]
974+ )
975+ mldev_sources = sum ([
976+ src .inlined_requests is not None ,
977+ src .file_name is not None ,
978+ ])
974979 if client .vertexai :
975- if src . gcs_uri and src . bigquery_uri :
980+ if mldev_sources or vertex_sources != 1 :
976981 raise ValueError (
977- 'Only one of `gcs_uri` or `bigquery_uri` can be set.'
978- )
979- elif not src .gcs_uri and not src .bigquery_uri :
980- raise ValueError (
981- 'One of `gcs_uri` or `bigquery_uri` must be set.'
982+ 'Exactly one of `gcs_uri` or `bigquery_uri` must be set, other '
983+ 'sources are not supported in Vertex AI.'
982984 )
983985 else :
984- if src . inlined_requests and src . file_name :
986+ if vertex_sources or mldev_sources != 1 :
985987 raise ValueError (
986- 'Only one of `inlined_requests` or `file_name` can be set.'
987- )
988- elif not src .inlined_requests and not src .file_name :
989- raise ValueError (
990- 'One of `inlined_requests` or `file_name` must be set.'
988+ 'Exactly one of `inlined_requests`, `file_name`, '
989+ '`inlined_embed_content_requests`, or `embed_content_file_name` '
990+ 'must be set, other sources are not supported in Gemini API.'
991991 )
992992 return src
993993
@@ -1012,6 +1012,29 @@ def t_batch_job_source(
10121012 raise ValueError (f'Unsupported source: { src } ' )
10131013
10141014
1015+ def t_embedding_batch_job_source (
1016+ client : _api_client .BaseApiClient ,
1017+ src : types .EmbeddingsBatchJobSourceOrDict ,
1018+ ) -> types .EmbeddingsBatchJobSource :
1019+ if isinstance (src , dict ):
1020+ src = types .EmbeddingsBatchJobSource (** src )
1021+
1022+ if isinstance (src , types .EmbeddingsBatchJobSource ):
1023+ mldev_sources = sum ([
1024+ src .inlined_requests is not None ,
1025+ src .file_name is not None ,
1026+ ])
1027+ if mldev_sources != 1 :
1028+ raise ValueError (
1029+ 'Exactly one of `inlined_requests`, `file_name`, '
1030+ '`inlined_embed_content_requests`, or `embed_content_file_name` '
1031+ 'must be set, other sources are not supported in Gemini API.'
1032+ )
1033+ return src
1034+ else :
1035+ raise ValueError (f'Unsupported source type: { type (src )} ' )
1036+
1037+
10151038def t_batch_job_destination (
10161039 dest : Union [str , types .BatchJobDestinationOrDict ],
10171040) -> types .BatchJobDestination :
@@ -1037,6 +1060,23 @@ def t_batch_job_destination(
10371060 raise ValueError (f'Unsupported destination: { dest } ' )
10381061
10391062
1063+ def t_recv_batch_job_destination (dest : dict [str , Any ]) -> dict [str , Any ]:
1064+ # Rename inlinedResponses if it looks like an embedding response.
1065+ inline_responses = dest .get ('inlinedResponses' , {}).get (
1066+ 'inlinedResponses' , []
1067+ )
1068+ if not inline_responses :
1069+ return dest
1070+ for response in inline_responses :
1071+ inner_response = response .get ('response' , {})
1072+ if not inner_response :
1073+ continue
1074+ if 'embedding' in inner_response :
1075+ dest ['inlinedEmbedContentResponses' ] = dest .pop ('inlinedResponses' )
1076+ break
1077+ return dest
1078+
1079+
10401080def t_batch_job_name (client : _api_client .BaseApiClient , name : str ) -> str :
10411081 if not client .vertexai :
10421082 mldev_pattern = r'batches/[^/]+$'
0 commit comments