Skip to content

Commit b38a8d7

Browse files
committed
fix-excep-process-fail
1 parent f1e2bfd commit b38a8d7

File tree

4 files changed

+23
-16
lines changed

4 files changed

+23
-16
lines changed

graphrag_sdk/models/litellm.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def __init__(
2525
model_name: str,
2626
generation_config: Optional[GenerativeModelConfig] = None,
2727
system_instruction: Optional[str] = None,
28+
host: Optional[str] = None,
2829
):
2930
"""
3031
Initialize the LiteModel with the required parameters.
@@ -33,9 +34,13 @@ def __init__(
3334
model_name (str): The name and the provider for the LiteLLM client.
3435
generation_config (Optional[GenerativeModelConfig]): Configuration settings for generation.
3536
system_instruction (Optional[str]): Instruction to guide the model.
37+
host (Optional[str]): Host for connecting to the Ollama API.
3638
"""
3739

40+
self.host = host
3841
self.model = model_name
42+
43+
# LiteLLM model name format: <provider>/<model_name> - Example: openai/gpt-4o
3944
if "/" in model_name:
4045
self.provider = model_name.split("/")[0]
4146
self.model_name = model_name.split("/")[1]
@@ -67,7 +72,7 @@ def credentials_validation(self) -> None:
6772
raise ValueError("Missing Gemini API key in the environment.")
6873

6974
elif self.provider == "ollama":
70-
self.ollama_client = Client()
75+
self.ollama_client = Client(self.host)
7176
self.check_and_pull_model()
7277

7378
def check_and_pull_model(self) -> None:
@@ -224,11 +229,11 @@ def send_message(self, message: str, output_method: OutputMethod = OutputMethod.
224229
response = completion(
225230
model=self._model.model,
226231
messages=self._chat_history,
232+
api_base=self._model.host,
227233
**generation_config
228234
)
229235
except Exception as e:
230-
# Handle exception (e.g., log error, retry, or raise a custom exception)
231-
raise RuntimeError("Error during completion request") from e
236+
raise ValueError(f"Error during completion request, please check the credentials - {e}")
232237
content = self._model.parse_generate_content_response(response)
233238
self._chat_history.append({"role": "assistant", "content": content.text})
234239
return content

graphrag_sdk/steps/create_ontology_step.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
from graphrag_sdk.models import (
1818
GenerativeModel,
1919
GenerativeModelChatSession,
20-
GenerativeModelConfig,
2120
GenerationResponse,
2221
FinishReason,
2322
)
@@ -71,7 +70,12 @@ def run(self, boundaries: Optional[str] = None):
7170
tasks.append(task)
7271

7372
# Wait for all tasks to complete
74-
wait(tasks)
73+
done, _ = wait(tasks) # Get completed tasks
74+
for task in done:
75+
try:
76+
task.result() # Re-raise any exceptions from _process_source
77+
except Exception as e:
78+
raise ValueError(f"Error in task: {e}")
7579

7680
for task in tasks:
7781
self.ontology = self.ontology.merge_with(task.result())

graphrag_sdk/steps/extract_data_step.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,12 @@ def run(self, instructions: str = None):
9090
tasks.append(task)
9191

9292
# Wait for all tasks to complete
93-
wait(tasks)
93+
done, _ = wait(tasks) # Get completed tasks
94+
for task in done:
95+
try:
96+
task.result() # Re-raise any exceptions from _process_source
97+
except Exception as e:
98+
raise ValueError(f"Error in task: {e}")
9499

95100
def _process_source(
96101
self,
@@ -195,7 +200,7 @@ def _process_source(
195200

196201
except Exception as e:
197202
logger.exception(e)
198-
raise e
203+
raise ValueError(f"Error processing source: {e}")
199204

200205
def _create_entity(self, graph: Graph, args: dict, ontology: Ontology):
201206
# Get unique attributes from entity

tests/test_kg_litellm_openai.py

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,19 +84,12 @@ def test_kg_creation(self):
8484
self.kg.process_sources(sources)
8585

8686
chat = self.kg.chat_session()
87-
answer = chat.send_message("How many actors acted in a movie?")
87+
answer = chat.send_message("How many different actors there are?")
8888
answer = answer['response']
8989

9090
logger.info(f"Answer: {answer}")
9191

9292
actors_count = re.findall(r'\d+', answer)
9393
num_actors = 0 if len(actors_count) == 0 else int(actors_count[0])
9494

95-
assert num_actors > 10, "The number of actors found should be greater than 10"
96-
97-
def test_kg_delete(self):
98-
self.kg.delete()
99-
100-
db = FalkorDB()
101-
graphs = db.list_graphs()
102-
self.assertNotIn(self.graph_name, graphs)
95+
assert num_actors > 10, "The number of actors found should be greater than 10"

0 commit comments

Comments (0)