9
9
import time
10
10
from collections .abc import AsyncIterable
11
11
from dataclasses import asdict
12
+ from typing import Any
12
13
13
14
from graphrag .callbacks .workflow_callbacks import WorkflowCallbacks
14
15
from graphrag .config .models .graph_rag_config import GraphRagConfig
@@ -28,7 +29,7 @@ async def run_pipeline(
28
29
config : GraphRagConfig ,
29
30
callbacks : WorkflowCallbacks ,
30
31
is_update_run : bool = False ,
31
- additional_context : dict | None = None ,
32
+ additional_context : dict [ str , Any ] | None = None ,
32
33
) -> AsyncIterable [PipelineRunResult ]:
33
34
"""Run all workflows using a simplified pipeline."""
34
35
root_dir = config .root_dir
@@ -41,8 +42,13 @@ async def run_pipeline(
41
42
state_json = await output_storage .get ("context.json" )
42
43
state = json .loads (state_json ) if state_json else {}
43
44
44
- for key , value in (additional_context or {}).items ():
45
- state ["additional_context" ][key ] = value
45
+ if additional_context is not None :
46
+ if "additional_context" not in state :
47
+ state ["additional_context" ] = {}
48
+
49
+ # add additional context to the state
50
+ for key , value in (additional_context or {}).items ():
51
+ state ["additional_context" ][key ] = value
46
52
47
53
if is_update_run :
48
54
logger .info ("Running incremental indexing." )
0 commit comments