Skip to content

Commit d63d610

Browse files
committed
Solve this a better way
1 parent 9ba6afb commit d63d610

File tree

3 files changed

+2313
-2283
lines changed

3 files changed

+2313
-2283
lines changed

py/llama_cloud_services/beta/agent_data/schema.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ class Person(BaseModel):
4141
from llama_cloud import ExtractRun
4242
from llama_cloud.types.agent_data import AgentData
4343
from llama_cloud.types.aggregate_group import AggregateGroup
44-
from pydantic import BaseModel, Field, ValidationError
44+
from pydantic import BaseModel, Field, ValidationError, model_validator, ConfigDict
4545
from typing import (
4646
Generic,
4747
List,
@@ -201,9 +201,12 @@ class ExtractedFieldMetadata(BaseModel):
201201
description="The original text this field's value was derived from",
202202
)
203203

204+
# Reject unknown keys so a nested dict of sub-fields is not mistaken for (and silently absorbed into) this model
205+
model_config = ConfigDict(extra="forbid")
206+
204207

205208
ExtractedFieldMetaDataDict = Dict[
206-
str, Union[Dict[str, Any], ExtractedFieldMetadata, list[Any]]
209+
str, Union[ExtractedFieldMetadata, Dict[str, Any], list[Any]]
207210
]
208211

209212

@@ -223,7 +226,7 @@ def parse_extracted_field_metadata(
223226
def _parse_extracted_field_metadata_recursive(
224227
field_value: Any,
225228
additional_fields: dict[str, Any] = {},
226-
) -> Union[Dict[str, Any], ExtractedFieldMetadata, list[Any]]:
229+
) -> Union[ExtractedFieldMetadata, Dict[str, Any], list[Any]]:
227230
"""
228231
Parse the extracted field metadata into a dictionary of field names to field metadata.
229232
"""
@@ -238,6 +241,8 @@ def _parse_extracted_field_metadata_recursive(
238241
if len(indicator_fields.intersection(field_value.keys())) > 0:
239242
try:
240243
merged = {**field_value, **additional_fields}
244+
allowed_fields = ExtractedFieldMetadata.model_fields.keys()
245+
merged = {k: v for k, v in merged.items() if k in allowed_fields}
241246
validated = ExtractedFieldMetadata.model_validate(merged)
242247

243248
# grab the citation from the array. This is just an array for backwards compatibility.
@@ -340,6 +345,28 @@ class ExtractedData(BaseModel, Generic[ExtractedT]):
340345
description="Additional metadata about the extracted data, such as errors, tokens, etc.",
341346
)
342347

348+
@model_validator(mode="before")
349+
@classmethod
350+
def _normalize_field_metadata_on_input(cls, value: Any) -> Any:
351+
# Ensure any inbound representation (including JSON round-trips)
352+
# gets normalized so nested dicts become ExtractedFieldMetadata where appropriate.
353+
if (
354+
isinstance(value, dict)
355+
and "field_metadata" in value
356+
and isinstance(value["field_metadata"], dict)
357+
):
358+
try:
359+
value = {
360+
**value,
361+
"field_metadata": parse_extracted_field_metadata(
362+
value["field_metadata"]
363+
),
364+
}
365+
except Exception:
366+
# Ignore normalization failures here: the unmodified value is returned below so pydantic's regular validation can report detailed errors
367+
pass
368+
return value
369+
343370
@classmethod
344371
def create(
345372
cls,

py/unit_tests/beta/agent/test_agent_data_schema.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -554,7 +554,7 @@ def test_full_parse_nested_dimensions():
554554
with open(Path(__file__).parent.parent.parent / "data" / "capacitor.json") as f:
555555
data = json.load(f)
556556
result = ExtractedData.from_extraction_result(ExtractRun.parse_obj(data), Capacitor)
557-
assert result.field_metadata == {
557+
expected = {
558558
"dimensions": {
559559
"diameter": ExtractedFieldMetadata(
560560
reasoning="VERBATIM EXTRACTION",
@@ -577,3 +577,6 @@ def test_full_parse_nested_dimensions():
577577
),
578578
}
579579
}
580+
assert result.field_metadata == expected
581+
parsed = ExtractedData.model_validate_json(result.model_dump_json())
582+
assert parsed.field_metadata == expected

0 commit comments

Comments
 (0)