Skip to content

Commit 31f54bc

Browse files
authored
feat: support passing a pre-uploaded file directly (#871)
* feat: support passing a pre-uploaded file directly * bump version
1 parent b1ae7bb commit 31f54bc

File tree

4 files changed

+193
-21
lines changed

4 files changed

+193
-21
lines changed

py/llama_cloud_services/extract/extract.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,7 @@ def _validate(self) -> None:
227227
raise ValueError(f"Unsupported file type: {type(self.file)}")
228228

229229

230-
FileInput = Union[str, Path, BufferedIOBase, SourceText]
230+
FileInput = Union[str, Path, BufferedIOBase, SourceText, File]
231231

232232

233233
def run_in_thread(
@@ -406,6 +406,8 @@ async def upload_file(self, file_input: SourceText) -> File:
406406

407407
async def _upload_file(self, file_input: FileInput) -> File:
408408
source_text = None
409+
if isinstance(file_input, File):
410+
return file_input
409411
if isinstance(file_input, SourceText):
410412
source_text = file_input
411413
elif isinstance(file_input, (str, Path)):
@@ -533,7 +535,7 @@ async def queue_extraction(
533535

534536
upload_tasks = [self._upload_file(file) for file in files]
535537
with augment_async_errors():
536-
uploaded_files = await run_jobs(
538+
uploaded_files: List[File] = await run_jobs(
537539
upload_tasks,
538540
workers=self.num_workers,
539541
desc="Uploading files",
@@ -987,8 +989,13 @@ def _get_mime_type(
987989
f"Could not determine file type. Please provide a filename with one of these supported extensions: {supported_list}"
988990
)
989991

990-
def _convert_file_to_file_data(self, file_input: FileInput) -> Union[FileData, str]:
992+
def _convert_file_to_file_data(
993+
self, file_input: FileInput
994+
) -> Union[FileData, str, File]:
991995
"""Convert FileInput to FileData or text string for stateless extraction."""
996+
if isinstance(file_input, File):
997+
return file_input
998+
992999
if isinstance(file_input, SourceText):
9931000
if file_input.text_content is not None:
9941001
return file_input.text_content
@@ -1084,24 +1091,23 @@ async def queue_extraction(
10841091
for file_input in files:
10851092
file_data_or_text = self._convert_file_to_file_data(file_input)
10861093

1087-
if isinstance(file_data_or_text, str):
1094+
if isinstance(file_data_or_text, File):
1095+
file_args = {"file_id": file_data_or_text.id}
1096+
1097+
elif isinstance(file_data_or_text, str):
10881098
# It's text content
1089-
job = await self._async_client.llama_extract.extract_stateless(
1090-
project_id=self._project_id,
1091-
organization_id=self._organization_id,
1092-
data_schema=processed_schema,
1093-
config=config,
1094-
text=file_data_or_text,
1095-
)
1099+
file_args = {"text": file_data_or_text}
10961100
else:
10971101
# It's FileData
1098-
job = await self._async_client.llama_extract.extract_stateless(
1099-
project_id=self._project_id,
1100-
organization_id=self._organization_id,
1101-
data_schema=processed_schema,
1102-
config=config,
1103-
file=file_data_or_text,
1104-
)
1102+
file_args = {"file": file_data_or_text}
1103+
1104+
job = await self._async_client.llama_extract.extract_stateless(
1105+
project_id=self._project_id,
1106+
organization_id=self._organization_id,
1107+
data_schema=processed_schema,
1108+
config=config,
1109+
**file_args,
1110+
)
11051111
jobs.append(job)
11061112

11071113
return jobs[0] if len(jobs) == 1 else jobs

py/llama_parse/pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,13 @@ dev = [
1111

1212
[project]
1313
name = "llama-parse"
14-
version = "0.6.59"
14+
version = "0.6.60"
1515
description = "Parse files into RAG-Optimized formats."
1616
authors = [{name = "Logan Markewich", email = "[email protected]"}]
1717
requires-python = ">=3.9,<4.0"
1818
readme = "README.md"
1919
license = "MIT"
20-
dependencies = ["llama-cloud-services>=0.6.59"]
20+
dependencies = ["llama-cloud-services>=0.6.60"]
2121

2222
[project.scripts]
2323
llama-parse = "llama_parse.cli.main:parse"

py/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ dev = [
1919

2020
[project]
2121
name = "llama-cloud-services"
22-
version = "0.6.59"
22+
version = "0.6.60"
2323
description = "Tailored SDK clients for LlamaCloud services."
2424
authors = [{name = "Logan Markewich", email = "[email protected]"}]
2525
requires-python = ">=3.9,<4.0"

py/unit_tests/extract/test_extract.py

Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import os
2+
from types import SimpleNamespace
3+
4+
import pytest
5+
6+
from llama_cloud.types import File as CloudFile
7+
from llama_cloud_services.extract import LlamaExtract, ExtractionAgent
8+
9+
10+
@pytest.fixture(autouse=True)
def _set_dummy_env(monkeypatch):
    """Inject dummy LlamaCloud credentials into the environment for every test."""
    for var, value in (
        ("LLAMA_CLOUD_API_KEY", "test-api-key"),
        ("LLAMA_CLOUD_BASE_URL", "https://example.test"),
    ):
        monkeypatch.setenv(var, value)
14+
15+
16+
@pytest.fixture
def llama_file() -> CloudFile:
    """A pre-uploaded LlamaCloud ``File`` object used as direct extraction input."""
    file_fields = {
        "id": "file_123",
        "name": "sample.pdf",
        "external_file_id": "ext_123",
        "project_id": "proj_123",
    }
    return CloudFile(**file_fields)
24+
25+
26+
@pytest.fixture
def extractor() -> LlamaExtract:
    """A LlamaExtract client configured from the dummy env credentials."""
    api_key = os.environ["LLAMA_CLOUD_API_KEY"]
    base_url = os.environ["LLAMA_CLOUD_BASE_URL"]
    return LlamaExtract(api_key=api_key, base_url=base_url, verify=False)
33+
34+
35+
@pytest.fixture
def no_external_validation(monkeypatch):
    """Silence config warnings and skip the remote schema-validation round-trip."""
    import llama_cloud_services.extract.extract as extract_mod

    async def _passthrough_schema(client, data_schema):
        # Remote validation normally calls the API; just echo the schema back.
        return data_schema

    monkeypatch.setattr(
        extract_mod,
        "_extraction_config_warning",
        lambda *_args, **_kwargs: None,
    )
    monkeypatch.setattr(extract_mod, "_validate_schema", _passthrough_schema)
47+
48+
49+
def test_convert_fileinput_accepts_llama_file_directly(
    extractor: LlamaExtract, llama_file: CloudFile
):
    """A ``File`` given as input must be returned unchanged (identical object)."""
    converted = extractor._convert_file_to_file_data(llama_file)
    assert converted is llama_file
54+
55+
56+
@pytest.mark.asyncio
async def test_queue_extraction_with_llama_file_uses_file_id(
    extractor: LlamaExtract, llama_file: CloudFile, no_external_validation, monkeypatch
):
    """Queueing with a pre-uploaded File must pass ``file_id`` — never file/text."""
    recorded_calls = []

    async def fake_extract_stateless(**kwargs):
        recorded_calls.append(kwargs)
        return SimpleNamespace(id="job_1")

    # Replace the client method that would otherwise hit the network.
    monkeypatch.setattr(
        extractor._async_client.llama_extract,
        "extract_stateless",
        fake_extract_stateless,
    )

    # Minimal schema; config warnings are already disabled by the fixture.
    schema = {"type": "object", "properties": {}}
    dummy_config = SimpleNamespace()

    job = await extractor.queue_extraction(schema, dummy_config, llama_file)

    assert job.id == "job_1"
    assert len(recorded_calls) == 1
    sent_kwargs = recorded_calls[0]
    assert sent_kwargs.get("file_id") == llama_file.id
    assert "file" not in sent_kwargs
    assert "text" not in sent_kwargs
85+
86+
87+
@pytest.mark.asyncio
async def test_extraction_agent_upload_file_accepts_llama_file_directly(
    llama_file: CloudFile,
):
    """``_upload_file`` must short-circuit and return a File object untouched."""
    # Bare stubs suffice: no network call should be attempted for a File input.
    stub_client = SimpleNamespace()
    stub_agent = SimpleNamespace(id="agent_1", name="dummy", data_schema={}, config={})

    agent = ExtractionAgent(
        client=stub_client,
        agent=stub_agent,
        project_id=None,
        organization_id=None,
        check_interval=0,
        max_timeout=0,
        num_workers=1,
        show_progress=False,
        verbose=False,
        verify=False,
        httpx_timeout=1,
    )

    uploaded = await agent._upload_file(llama_file)
    assert uploaded is llama_file
111+
112+
113+
@pytest.mark.asyncio
async def test_extraction_agent_aextract_accepts_llama_file(
    monkeypatch, llama_file: CloudFile
):
    """``aextract`` with a pre-uploaded File must run a job keyed by that file's id."""
    extract_iface = SimpleNamespace()

    async def fake_run_job(**kwargs):
        # The request forwarded to the API must carry our file's id.
        request = kwargs.get("request")
        assert hasattr(request, "file_id")
        assert request.file_id == llama_file.id
        return SimpleNamespace(id="job_42")

    extract_iface.run_job = fake_run_job
    stub_client = SimpleNamespace(llama_extract=extract_iface)
    stub_agent = SimpleNamespace(id="agent_1", name="dummy", data_schema={}, config={})

    agent = ExtractionAgent(
        client=stub_client,
        agent=stub_agent,
        project_id=None,
        organization_id=None,
        check_interval=0,
        max_timeout=0,
        num_workers=1,
        show_progress=False,
        verbose=False,
        verify=False,
        httpx_timeout=1,
    )

    # Record that _upload_file receives our File and returns it directly.
    seen = {}

    async def fake_upload_file(file_input):
        seen["upload_called_with"] = file_input
        assert file_input is llama_file
        return file_input

    monkeypatch.setattr(agent, "_upload_file", fake_upload_file)

    # Short-circuit the polling loop with an already-successful result.
    async def fake_wait(job_id: str):
        assert job_id == "job_42"
        return SimpleNamespace(id="run_42", status="SUCCESS", data={})

    monkeypatch.setattr(agent, "_wait_for_job_result", fake_wait)

    result = await agent.aextract(llama_file)

    assert seen.get("upload_called_with") is llama_file
    assert result.status == "SUCCESS"
    assert result.id == "run_42"

0 commit comments

Comments
 (0)