|
| 1 | +import os |
| 2 | +from types import SimpleNamespace |
| 3 | + |
| 4 | +import pytest |
| 5 | + |
| 6 | +from llama_cloud.types import File as CloudFile |
| 7 | +from llama_cloud_services.extract import LlamaExtract, ExtractionAgent |
| 8 | + |
| 9 | + |
@pytest.fixture(autouse=True)
def _set_dummy_env(monkeypatch):
    """Set placeholder LlamaCloud environment variables for every test.

    Both variables are applied through ``monkeypatch.setenv``; the fixture
    is ``autouse`` so tests do not need to request it.
    """
    placeholder_env = {
        "LLAMA_CLOUD_API_KEY": "test-api-key",
        "LLAMA_CLOUD_BASE_URL": "https://example.test",
    }
    for var_name, var_value in placeholder_env.items():
        monkeypatch.setenv(var_name, var_value)
| 14 | + |
| 15 | + |
@pytest.fixture
def llama_file() -> CloudFile:
    """Return a ``CloudFile`` populated with fixed placeholder identifiers."""
    file_fields = {
        "id": "file_123",
        "name": "sample.pdf",
        "external_file_id": "ext_123",
        "project_id": "proj_123",
    }
    return CloudFile(**file_fields)
| 24 | + |
| 25 | + |
@pytest.fixture
def extractor() -> LlamaExtract:
    """Return a ``LlamaExtract`` configured from the environment.

    The API key and base URL are read from ``LLAMA_CLOUD_API_KEY`` and
    ``LLAMA_CLOUD_BASE_URL``; ``verify`` is passed as ``False``.
    """
    api_key = os.environ["LLAMA_CLOUD_API_KEY"]
    base_url = os.environ["LLAMA_CLOUD_BASE_URL"]
    return LlamaExtract(api_key=api_key, base_url=base_url, verify=False)
| 33 | + |
| 34 | + |
@pytest.fixture
def no_external_validation(monkeypatch):
    """Replace the extract module's config warning and schema validation.

    ``_extraction_config_warning`` becomes a function that accepts anything
    and returns ``None``; ``_validate_schema`` becomes a coroutine that
    returns the schema it was given unchanged.
    """
    import llama_cloud_services.extract.extract as extract_module

    def _silent_config_warning(*_args, **_kwargs):
        return None

    async def _passthrough_validate_schema(client, data_schema):
        return data_schema

    monkeypatch.setattr(
        extract_module, "_extraction_config_warning", _silent_config_warning
    )
    monkeypatch.setattr(
        extract_module, "_validate_schema", _passthrough_validate_schema
    )
| 47 | + |
| 48 | + |
def test_convert_fileinput_accepts_llama_file_directly(
    extractor: LlamaExtract, llama_file: CloudFile
):
    """``_convert_file_to_file_data`` must return the very same ``File`` object."""
    converted = extractor._convert_file_to_file_data(llama_file)
    assert converted is llama_file
| 54 | + |
| 55 | + |
@pytest.mark.asyncio
async def test_queue_extraction_with_llama_file_uses_file_id(
    extractor: LlamaExtract, llama_file: CloudFile, no_external_validation, monkeypatch
):
    """Queueing an extraction for a ``File`` must send its id as ``file_id``.

    ``extract_stateless`` on the async client is replaced by a recording
    stub; the test checks it is called exactly once with ``file_id`` set to
    the file's id and without ``file`` or ``text`` keyword arguments.
    """
    recorded_kwargs = []

    async def _stub_extract_stateless(**kwargs):
        recorded_kwargs.append(kwargs)
        return SimpleNamespace(id="job_1")

    # Swap out the client method so the call is captured locally.
    monkeypatch.setattr(
        extractor._async_client.llama_extract,
        "extract_stateless",
        _stub_extract_stateless,
    )

    # Smallest possible schema plus an attribute-less config object; the
    # config warning is disabled by the ``no_external_validation`` fixture.
    empty_schema = {"type": "object", "properties": {}}
    placeholder_config = SimpleNamespace()

    job = await extractor.queue_extraction(
        empty_schema, placeholder_config, llama_file
    )

    assert job.id == "job_1"
    assert len(recorded_kwargs) == 1

    sent = recorded_kwargs[0]
    assert "file_id" in sent
    assert sent["file_id"] == llama_file.id
    for forbidden_key in ("file", "text"):
        assert forbidden_key not in sent
| 85 | + |
| 86 | + |
@pytest.mark.asyncio
async def test_extraction_agent_upload_file_accepts_llama_file_directly(
    llama_file: CloudFile,
):
    """``ExtractionAgent._upload_file`` must return a passed ``File`` unchanged."""
    # The agent is assembled from bare namespaces standing in for the client
    # and the agent record.
    agent_kwargs = {
        "client": SimpleNamespace(),
        "agent": SimpleNamespace(
            id="agent_1", name="dummy", data_schema={}, config={}
        ),
        "project_id": None,
        "organization_id": None,
        "check_interval": 0,
        "max_timeout": 0,
        "num_workers": 1,
        "show_progress": False,
        "verbose": False,
        "verify": False,
        "httpx_timeout": 1,
    }
    agent = ExtractionAgent(**agent_kwargs)

    uploaded = await agent._upload_file(llama_file)
    assert uploaded is llama_file
| 111 | + |
| 112 | + |
@pytest.mark.asyncio
async def test_extraction_agent_aextract_accepts_llama_file(
    monkeypatch, llama_file: CloudFile
):
    """``ExtractionAgent.aextract`` must accept a ``File`` and use its id.

    ``run_job``, ``_upload_file`` and ``_wait_for_job_result`` are all
    replaced by local stubs, each of which asserts on the argument it
    receives; the final result must be the object the wait stub returns.
    """

    async def _stub_run_job(**kwargs):
        # The job request must carry the id of the file we passed in.
        job_request = kwargs.get("request")
        assert hasattr(job_request, "file_id")
        assert job_request.file_id == llama_file.id
        return SimpleNamespace(id="job_42")

    stub_client = SimpleNamespace(
        llama_extract=SimpleNamespace(run_job=_stub_run_job)
    )
    agent_kwargs = {
        "client": stub_client,
        "agent": SimpleNamespace(
            id="agent_1", name="dummy", data_schema={}, config={}
        ),
        "project_id": None,
        "organization_id": None,
        "check_interval": 0,
        "max_timeout": 0,
        "num_workers": 1,
        "show_progress": False,
        "verbose": False,
        "verify": False,
        "httpx_timeout": 1,
    }
    agent = ExtractionAgent(**agent_kwargs)

    # Record what _upload_file receives and hand the File straight back.
    observed = {}

    async def _stub_upload_file(file_input):
        observed["upload_called_with"] = file_input
        assert file_input is llama_file
        return file_input

    # Return a finished run immediately instead of waiting on the job.
    async def _stub_wait(job_id: str):
        assert job_id == "job_42"
        return SimpleNamespace(id="run_42", status="SUCCESS", data={})

    monkeypatch.setattr(agent, "_upload_file", _stub_upload_file)
    monkeypatch.setattr(agent, "_wait_for_job_result", _stub_wait)

    run = await agent.aextract(llama_file)

    assert observed.get("upload_called_with") is llama_file
    assert run.status == "SUCCESS"
    assert run.id == "run_42"
0 commit comments