Skip to content

Commit daa924a

Browse files
authored
feat!: use orchestrators from jobkit (#248)
Signed-off-by: Michele Dolfi <[email protected]>
1 parent e63197e commit daa924a

30 files changed

+816
-2000
lines changed

docling_serve/__main__.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@
3030
def version_callback(value: bool) -> None:
3131
if value:
3232
docling_serve_version = importlib.metadata.version("docling_serve")
33+
docling_jobkit_version = importlib.metadata.version("docling-jobkit")
3334
docling_version = importlib.metadata.version("docling")
3435
docling_core_version = importlib.metadata.version("docling-core")
3536
docling_ibm_models_version = importlib.metadata.version("docling-ibm-models")
@@ -38,6 +39,7 @@ def version_callback(value: bool) -> None:
3839
py_impl_version = sys.implementation.cache_tag
3940
py_lang_version = platform.python_version()
4041
console.print(f"Docling Serve version: {docling_serve_version}")
42+
console.print(f"Docling Jobkit version: {docling_jobkit_version}")
4143
console.print(f"Docling version: {docling_version}")
4244
console.print(f"Docling Core version: {docling_core_version}")
4345
console.print(f"Docling IBM Models version: {docling_ibm_models_version}")

docling_serve/app.py

Lines changed: 57 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -28,12 +28,18 @@
2828
from scalar_fastapi import get_scalar_api_reference
2929

3030
from docling.datamodel.base_models import DocumentStream
31-
32-
from docling_serve.datamodel.callback import (
31+
from docling_jobkit.datamodel.callback import (
3332
ProgressCallbackRequest,
3433
ProgressCallbackResponse,
3534
)
36-
from docling_serve.datamodel.convert import ConvertDocumentsOptions
35+
from docling_jobkit.datamodel.task import Task, TaskSource
36+
from docling_jobkit.orchestrators.base_orchestrator import (
37+
BaseOrchestrator,
38+
ProgressInvalid,
39+
TaskNotFoundError,
40+
)
41+
42+
from docling_serve.datamodel.convert import ConvertDocumentsRequestOptions
3743
from docling_serve.datamodel.requests import (
3844
ConvertDocumentFileSourcesRequest,
3945
ConvertDocumentHttpSourcesRequest,
@@ -47,17 +53,12 @@
4753
TaskStatusResponse,
4854
WebsocketMessage,
4955
)
50-
from docling_serve.datamodel.task import Task, TaskSource
51-
from docling_serve.docling_conversion import _get_converter_from_hash
52-
from docling_serve.engines.async_orchestrator import (
53-
BaseAsyncOrchestrator,
54-
ProgressInvalid,
55-
)
56-
from docling_serve.engines.async_orchestrator_factory import get_async_orchestrator
57-
from docling_serve.engines.base_orchestrator import TaskNotFoundError
5856
from docling_serve.helper_functions import FormDepends
57+
from docling_serve.orchestrator_factory import get_async_orchestrator
58+
from docling_serve.response_preparation import prepare_response
5959
from docling_serve.settings import docling_serve_settings
6060
from docling_serve.storage import get_scratch
61+
from docling_serve.websocker_notifier import WebsocketNotifier
6162

6263

6364
# Set up custom logging as we'll be intermixed with FastAPI/Uvicorn's logging
@@ -95,9 +96,12 @@ def format(self, record):
9596
# Context manager to initialize and clean up the lifespan of the FastAPI app
9697
@asynccontextmanager
9798
async def lifespan(app: FastAPI):
98-
orchestrator = get_async_orchestrator()
9999
scratch_dir = get_scratch()
100100

101+
orchestrator = get_async_orchestrator()
102+
notifier = WebsocketNotifier(orchestrator)
103+
orchestrator.bind_notifier(notifier)
104+
101105
# Warm up processing cache
102106
if docling_serve_settings.load_models_at_boot:
103107
await orchestrator.warm_up_caches()
@@ -230,7 +234,7 @@ async def scalar_html():
230234
########################
231235

232236
async def _enque_source(
233-
orchestrator: BaseAsyncOrchestrator, conversion_request: ConvertDocumentsRequest
237+
orchestrator: BaseOrchestrator, conversion_request: ConvertDocumentsRequest
234238
) -> Task:
235239
sources: list[TaskSource] = []
236240
if isinstance(conversion_request, ConvertDocumentFileSourcesRequest):
@@ -244,9 +248,9 @@ async def _enque_source(
244248
return task
245249

246250
async def _enque_file(
247-
orchestrator: BaseAsyncOrchestrator,
251+
orchestrator: BaseOrchestrator,
248252
files: list[UploadFile],
249-
options: ConvertDocumentsOptions,
253+
options: ConvertDocumentsRequestOptions,
250254
) -> Task:
251255
_log.info(f"Received {len(files)} files for processing.")
252256

@@ -261,9 +265,7 @@ async def _enque_file(
261265
task = await orchestrator.enqueue(sources=file_sources, options=options)
262266
return task
263267

264-
async def _wait_task_complete(
265-
orchestrator: BaseAsyncOrchestrator, task_id: str
266-
) -> bool:
268+
async def _wait_task_complete(orchestrator: BaseOrchestrator, task_id: str) -> bool:
267269
start_time = time.monotonic()
268270
while True:
269271
task = await orchestrator.task_status(task_id=task_id)
@@ -309,32 +311,28 @@ def api_check() -> HealthCheckResponse:
309311
)
310312
async def process_url(
311313
background_tasks: BackgroundTasks,
312-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
314+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
313315
conversion_request: ConvertDocumentsRequest,
314316
):
315317
task = await _enque_source(
316318
orchestrator=orchestrator, conversion_request=conversion_request
317319
)
318-
success = await _wait_task_complete(
320+
completed = await _wait_task_complete(
319321
orchestrator=orchestrator, task_id=task.task_id
320322
)
321323

322-
if not success:
324+
if not completed:
323325
# TODO: abort task!
324326
return HTTPException(
325327
status_code=504,
326328
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
327329
)
328330

329-
result = await orchestrator.task_result(
330-
task_id=task.task_id, background_tasks=background_tasks
331+
task = await orchestrator.get_raw_task(task_id=task.task_id)
332+
response = await prepare_response(
333+
task=task, orchestrator=orchestrator, background_tasks=background_tasks
331334
)
332-
if result is None:
333-
raise HTTPException(
334-
status_code=404,
335-
detail="Task result not found. Please wait for a completion status.",
336-
)
337-
return result
335+
return response
338336

339337
# Convert a document from file(s)
340338
@app.post(
@@ -348,43 +346,39 @@ async def process_url(
348346
)
349347
async def process_file(
350348
background_tasks: BackgroundTasks,
351-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
349+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
352350
files: list[UploadFile],
353351
options: Annotated[
354-
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
352+
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
355353
],
356354
):
357355
task = await _enque_file(
358356
orchestrator=orchestrator, files=files, options=options
359357
)
360-
success = await _wait_task_complete(
358+
completed = await _wait_task_complete(
361359
orchestrator=orchestrator, task_id=task.task_id
362360
)
363361

364-
if not success:
362+
if not completed:
365363
# TODO: abort task!
366364
return HTTPException(
367365
status_code=504,
368366
detail=f"Conversion is taking too long. The maximum wait time is configure as DOCLING_SERVE_MAX_SYNC_WAIT={docling_serve_settings.max_sync_wait}.",
369367
)
370368

371-
result = await orchestrator.task_result(
372-
task_id=task.task_id, background_tasks=background_tasks
369+
task = await orchestrator.get_raw_task(task_id=task.task_id)
370+
response = await prepare_response(
371+
task=task, orchestrator=orchestrator, background_tasks=background_tasks
373372
)
374-
if result is None:
375-
raise HTTPException(
376-
status_code=404,
377-
detail="Task result not found. Please wait for a completion status.",
378-
)
379-
return result
373+
return response
380374

381375
# Convert a document from URL(s) using the async api
382376
@app.post(
383377
"/v1alpha/convert/source/async",
384378
response_model=TaskStatusResponse,
385379
)
386380
async def process_url_async(
387-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
381+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
388382
conversion_request: ConvertDocumentsRequest,
389383
):
390384
task = await _enque_source(
@@ -406,11 +400,11 @@ async def process_url_async(
406400
response_model=TaskStatusResponse,
407401
)
408402
async def process_file_async(
409-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
403+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
410404
background_tasks: BackgroundTasks,
411405
files: list[UploadFile],
412406
options: Annotated[
413-
ConvertDocumentsOptions, FormDepends(ConvertDocumentsOptions)
407+
ConvertDocumentsRequestOptions, FormDepends(ConvertDocumentsRequestOptions)
414408
],
415409
):
416410
task = await _enque_file(
@@ -432,7 +426,7 @@ async def process_file_async(
432426
response_model=TaskStatusResponse,
433427
)
434428
async def task_status_poll(
435-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
429+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
436430
task_id: str,
437431
wait: Annotated[
438432
float, Query(help="Number of seconds to wait for a completed status.")
@@ -456,9 +450,10 @@ async def task_status_poll(
456450
)
457451
async def task_status_ws(
458452
websocket: WebSocket,
459-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
453+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
460454
task_id: str,
461455
):
456+
assert isinstance(orchestrator.notifier, WebsocketNotifier)
462457
await websocket.accept()
463458

464459
if task_id not in orchestrator.tasks:
@@ -473,7 +468,7 @@ async def task_status_ws(
473468
task = orchestrator.tasks[task_id]
474469

475470
# Track active WebSocket connections for this job
476-
orchestrator.task_subscribers[task_id].add(websocket)
471+
orchestrator.notifier.task_subscribers[task_id].add(websocket)
477472

478473
try:
479474
task_queue_position = await orchestrator.get_queue_position(task_id=task_id)
@@ -511,7 +506,7 @@ async def task_status_ws(
511506
_log.info(f"WebSocket disconnected for job {task_id}")
512507

513508
finally:
514-
orchestrator.task_subscribers[task_id].remove(websocket)
509+
orchestrator.notifier.task_subscribers[task_id].remove(websocket)
515510

516511
# Task result
517512
@app.get(
@@ -524,27 +519,26 @@ async def task_status_ws(
524519
},
525520
)
526521
async def task_result(
527-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
522+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
528523
background_tasks: BackgroundTasks,
529524
task_id: str,
530525
):
531-
result = await orchestrator.task_result(
532-
task_id=task_id, background_tasks=background_tasks
533-
)
534-
if result is None:
535-
raise HTTPException(
536-
status_code=404,
537-
detail="Task result not found. Please wait for a completion status.",
526+
try:
527+
task = await orchestrator.get_raw_task(task_id=task_id)
528+
response = await prepare_response(
529+
task=task, orchestrator=orchestrator, background_tasks=background_tasks
538530
)
539-
return result
531+
return response
532+
except TaskNotFoundError:
533+
raise HTTPException(status_code=404, detail="Task not found.")
540534

541535
# Update task progress
542536
@app.post(
543537
"/v1alpha/callback/task/progress",
544538
response_model=ProgressCallbackResponse,
545539
)
546540
async def callback_task_progress(
547-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
541+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
548542
request: ProgressCallbackRequest,
549543
):
550544
try:
@@ -564,8 +558,10 @@ async def callback_task_progress(
564558
"/v1alpha/clear/converters",
565559
response_model=ClearResponse,
566560
)
567-
async def clear_converters():
568-
_get_converter_from_hash.cache_clear()
561+
async def clear_converters(
562+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
563+
):
564+
await orchestrator.clear_converters()
569565
return ClearResponse()
570566

571567
# Clean results
@@ -574,7 +570,7 @@ async def clear_converters():
574570
response_model=ClearResponse,
575571
)
576572
async def clear_results(
577-
orchestrator: Annotated[BaseAsyncOrchestrator, Depends(get_async_orchestrator)],
573+
orchestrator: Annotated[BaseOrchestrator, Depends(get_async_orchestrator)],
578574
older_then: float = 3600,
579575
):
580576
await orchestrator.clear_results(older_than=older_then)

docling_serve/datamodel/callback.py

Lines changed: 0 additions & 50 deletions
This file was deleted.

0 commit comments

Comments
 (0)