@@ -41,7 +41,7 @@ class Person(BaseModel):
41
41
from llama_cloud import ExtractRun
42
42
from llama_cloud .types .agent_data import AgentData
43
43
from llama_cloud .types .aggregate_group import AggregateGroup
44
- from pydantic import BaseModel , Field , ValidationError
44
+ from pydantic import BaseModel , Field , ValidationError , model_validator , ConfigDict
45
45
from typing import (
46
46
Generic ,
47
47
List ,
@@ -201,9 +201,12 @@ class ExtractedFieldMetadata(BaseModel):
201
201
description = "The original text this field's value was derived from" ,
202
202
)
203
203
204
+ # Forbid unknown keys to avoid swallowing nested dicts
205
+ model_config = ConfigDict (extra = "forbid" )
206
+
204
207
205
208
# Values of parsed field metadata keyed by field name. A value is either a
# fully-validated ExtractedFieldMetadata leaf, a nested mapping of sub-fields,
# or a list (for array-valued fields). ExtractedFieldMetadata is listed first
# in the Union, presumably so pydantic prefers it when validating — confirm
# against the model's union-mode configuration.
ExtractedFieldMetaDataDict = Dict[
    str,
    Union[ExtractedFieldMetadata, Dict[str, Any], list[Any]],
]
208
211
209
212
@@ -223,7 +226,7 @@ def parse_extracted_field_metadata(
223
226
def _parse_extracted_field_metadata_recursive (
224
227
field_value : Any ,
225
228
additional_fields : dict [str , Any ] = {},
226
- ) -> Union [Dict [str , Any ], ExtractedFieldMetadata , list [Any ]]:
229
+ ) -> Union [ExtractedFieldMetadata , Dict [str , Any ], list [Any ]]:
227
230
"""
228
231
Parse the extracted field metadata into a dictionary of field names to field metadata.
229
232
"""
@@ -238,6 +241,8 @@ def _parse_extracted_field_metadata_recursive(
238
241
if len (indicator_fields .intersection (field_value .keys ())) > 0 :
239
242
try :
240
243
merged = {** field_value , ** additional_fields }
244
+ allowed_fields = ExtractedFieldMetadata .model_fields .keys ()
245
+ merged = {k : v for k , v in merged .items () if k in allowed_fields }
241
246
validated = ExtractedFieldMetadata .model_validate (merged )
242
247
243
248
# grab the citation from the array. This is just an array for backwards compatibility.
@@ -340,6 +345,28 @@ class ExtractedData(BaseModel, Generic[ExtractedT]):
340
345
description = "Additional metadata about the extracted data, such as errors, tokens, etc." ,
341
346
)
342
347
348
+ @model_validator (mode = "before" )
349
+ @classmethod
350
+ def _normalize_field_metadata_on_input (cls , value : Any ) -> Any :
351
+ # Ensure any inbound representation (including JSON round-trips)
352
+ # gets normalized so nested dicts become ExtractedFieldMetadata where appropriate.
353
+ if (
354
+ isinstance (value , dict )
355
+ and "field_metadata" in value
356
+ and isinstance (value ["field_metadata" ], dict )
357
+ ):
358
+ try :
359
+ value = {
360
+ ** value ,
361
+ "field_metadata" : parse_extracted_field_metadata (
362
+ value ["field_metadata" ]
363
+ ),
364
+ }
365
+ except Exception :
366
+ # Let pydantic surface detailed errors later rather than swallowing completely
367
+ pass
368
+ return value
369
+
343
370
@classmethod
344
371
def create (
345
372
cls ,
0 commit comments