Skip to content

Commit 0933632

Browse files
kiberguscopybara-github
authored andcommitted
fix: Make t_part and t_content conform to their type annotations: they now handle FileDict correctly and t_contents handles PartUnionDict correctly.
PiperOrigin-RevId: 814727063
1 parent 6a6d588 commit 0933632

File tree

4 files changed

+61
-4
lines changed

4 files changed

+61
-4
lines changed

google/genai/_transformers.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,10 @@ def t_part(part: Optional[types.PartUnionDict]) -> types.Part:
366366
raise ValueError('file uri and mime_type are required.')
367367
return types.Part.from_uri(file_uri=part.uri, mime_type=part.mime_type)
368368
if isinstance(part, dict):
369-
return types.Part.model_validate(part)
369+
try:
370+
return types.Part.model_validate(part)
371+
except pydantic.ValidationError:
372+
return types.Part(file_data=types.FileData.model_validate(part))
370373
if isinstance(part, types.Part):
371374
return part
372375

@@ -420,7 +423,7 @@ def t_image_predictions(
420423

421424

422425
def t_content(
423-
content: Optional[ContentType],
426+
content: Union[ContentType, types.ContentDict, None],
424427
) -> types.Content:
425428
if content is None:
426429
raise ValueError('content is required.')
@@ -430,12 +433,14 @@ def t_content(
430433
try:
431434
return types.Content.model_validate(content)
432435
except pydantic.ValidationError:
433-
possible_part = types.Part.model_validate(content)
436+
possible_part = t_part(content) # type: ignore[arg-type]
434437
return (
435438
types.ModelContent(parts=[possible_part])
436439
if possible_part.function_call
437440
else types.UserContent(parts=[possible_part])
438441
)
442+
if isinstance(content, types.File):
443+
return types.UserContent(parts=[t_part(content)])
439444
if isinstance(content, types.Part):
440445
return (
441446
types.ModelContent(parts=[content])
@@ -495,11 +500,18 @@ def _is_part(
495500
return True
496501

497502
if isinstance(part, dict):
503+
if not part:
504+
# Empty dict should be considered as Content, not Part.
505+
return False
498506
try:
499507
types.Part.model_validate(part)
500508
return True
501509
except pydantic.ValidationError:
502-
return False
510+
try:
511+
types.FileData.model_validate(part)
512+
return True
513+
except pydantic.ValidationError:
514+
return False
503515

504516
if 'image' in part.__class__.__name__.lower():
505517
try:

google/genai/tests/transformers/test_t_content.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ def test_file_no_mime_type():
122122
t.t_content(types.File(uri='gs://test'))
123123

124124

125+
def test_file_dict():
126+
assert t.t_content({'file_uri': 'gs://test', 'mime_type': 'image/png'}) == types.UserContent(
127+
parts=[
128+
types.Part(
129+
file_data=types.FileData(
130+
file_uri='gs://test', mime_type='image/png'
131+
)
132+
)
133+
]
134+
)
135+
125136
def test_int():
126137
try:
127138
t.t_content(1)

google/genai/tests/transformers/test_t_contents.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,32 @@ def test_file():
122122
]
123123

124124

125+
def test_file_dict():
126+
assert t.t_contents({'file_uri': 'gs://test', 'mime_type': 'image/png'}) == [
127+
types.UserContent(
128+
parts=[
129+
types.Part(
130+
file_data=types.FileData(
131+
file_uri='gs://test', mime_type='image/png'
132+
)
133+
)
134+
]
135+
)
136+
]
137+
138+
def test_file_dict_list():
139+
assert t.t_contents([{'file_uri': 'gs://test', 'mime_type': 'image/png'}]) == [
140+
types.UserContent(
141+
parts=[
142+
types.Part(
143+
file_data=types.FileData(
144+
file_uri='gs://test', mime_type='image/png'
145+
)
146+
)
147+
]
148+
)
149+
]
150+
125151
def test_file_no_uri():
126152
with pytest.raises(ValueError):
127153
t.t_contents(types.File(mime_type='image/png'))

google/genai/tests/transformers/test_t_part.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,14 @@ def test_file():
4343
)
4444

4545

46+
def test_file_dict():
47+
assert t.t_part(
48+
{'file_uri': 'gs://test', 'mime_type': 'image/png'}
49+
) == types.Part(
50+
file_data=types.FileData(file_uri='gs://test', mime_type='image/png')
51+
)
52+
53+
4654
def test_file_no_uri():
4755
with pytest.raises(ValueError):
4856
t.t_part(types.File(mime_type='image/png'))

0 commit comments

Comments
 (0)