Commit e507bea

Merge pull request #40 from FalkorDB/litellm-int
LiteLLM integration
2 parents 0fa57ce + f2e2574

File tree

8 files changed: +1289 -54 lines changed

.env.template

Lines changed: 9 additions & 1 deletion
@@ -5,4 +5,12 @@
 # Endpoint format: https://{your-resource-name}.openai.azure.com
 AZURE_ENDPOINT="AZURE_ENDPOINT"
 # API Version (e.g., 2023-05-15)
-AZURE_API_VERSION="AZURE_API_VERSION"
+AZURE_API_VERSION="AZURE_API_VERSION"
+
+# For LiteLLM usage
+GEMINI_API_KEY = "GEMINI_API_KEY"
+OLLAMA_API_BASE = "OLLAMA_API_BASE" # "http://localhost:11434"
+
+AZURE_API_KEY = "AZURE_API_KEY" # "my-azure-api-key"
+AZURE_API_BASE = "AZURE_API_BASE" # "https://example-endpoint.openai.azure.com"
+AZURE_API_VERSION = "AZURE_API_VERSION" # "2023-05-15"
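
To sanity-check these settings, the short sketch below uses LiteLLM's `validate_environment` helper (the same call the new `LiteModel` class makes at construction time) to report which keys are still missing for a provider-prefixed model name. The model names and the use of python-dotenv are illustrative assumptions, not part of this commit.

# Sketch: confirm the variables from .env.template are visible to LiteLLM.
# Assumes python-dotenv is installed; model names are only examples.
from dotenv import load_dotenv
from litellm import validate_environment

load_dotenv()  # load .env into the process environment

for model in ["gemini/gemini-1.5-flash", "azure/gpt-4o", "ollama/llama3"]:
    env_val = validate_environment(model)
    if env_val["keys_in_environment"]:
        print(f"{model}: environment OK")
    else:
        print(f"{model}: missing {env_val['missing_keys']}")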

.wordlist.txt

Lines changed: 2 additions & 1 deletion
@@ -27,4 +27,5 @@ www
 faq
 Ollama
 ollama
-Cypher
+Cypher
+LiteLLM

README.md

Lines changed: 3 additions & 1 deletion
@@ -13,7 +13,7 @@ GraphRAG-SDK is a comprehensive solution for building Graph Retrieval-Augmented

 * Ontology Management: Manage ontologies either manually or automatically from unstructured data.
 * Knowledge Graph (KG): Construct and query knowledge graphs for efficient data retrieval.
-* LLMs Integration: Support for OpenAI and Google Gemini models.
+* LLMs Integration: Support for OpenAI, Google Gemini, Ollama and LiteLLM models.
 * Multi-Agent System: Multi-agent orchestrators using KG-based agents.

 ## Get Started
@@ -50,6 +50,8 @@ Currently, this SDK supports the following LLMs API:
 * [Google](https://makersuite.google.com/app/apikey) Recommended model: `gemini-1.5-flash-001`
 * [Azure-OpenAI](https://ai.azure.com) Recommended model: `gpt-4o`
 * [Ollama](https://ollama.com/) Available only for the Q&A step (after the knowledge graph (KG) is created). Recommended model: `llama3`.
+* [LiteLLM](https://docs.litellm.ai): A framework supporting inference of large language models, allowing flexibility in deployment and use cases.
+

 ## Basic Usage
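
As a rough illustration of the option added to the README, here is a minimal sketch of using the LiteLLM-backed model from this commit on its own. The import path is inferred from the new file's location, and the model names follow the `<provider>/<model_name>` convention documented in that file.

# Sketch only: provider-prefixed model names routed through the new LiteModel class.
from graphrag_sdk.models.litellm import LiteModel  # path inferred from graphrag_sdk/models/litellm.py

model = LiteModel(model_name="gemini/gemini-1.5-flash")  # or "azure/gpt-4o", "ollama/llama3:8b"
chat = model.start_chat()
response = chat.send_message("What does GraphRAG-SDK do?")
print(response.text)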

graphrag_sdk/models/litellm.py

Lines changed: 291 additions & 0 deletions
@@ -0,0 +1,291 @@
import os
import logging
from ollama import Client
from typing import Optional
from litellm import completion, validate_environment, utils as litellm_utils
from .model import (
    OutputMethod,
    GenerativeModel,
    GenerativeModelConfig,
    GenerationResponse,
    FinishReason,
    GenerativeModelChatSession,
)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

class LiteModel(GenerativeModel):
    """
    A generative model that interfaces with LiteLLM for chat completions.
    """

    def __init__(
        self,
        model_name: str,
        generation_config: Optional[GenerativeModelConfig] = None,
        system_instruction: Optional[str] = None,
    ):
        """
        Initialize the LiteModel with the required parameters.

        LiteLLM model_name format: <provider>/<model_name>
        Examples:
            - openai/gpt-4o
            - azure/gpt-4o
            - gemini/gemini-1.5-pro
            - ollama/llama3:8b

        Args:
            model_name (str): The name and the provider for the LiteLLM client.
            generation_config (Optional[GenerativeModelConfig]): Configuration settings for generation.
            system_instruction (Optional[str]): Instruction to guide the model.
        """
        env_val = validate_environment(model_name)
        if not env_val['keys_in_environment']:
            raise ValueError(f"Missing {env_val['missing_keys']} in the environment.")
        self.model_name, provider, _, _ = litellm_utils.get_llm_provider(model_name)
        self.model = model_name

        if provider == "ollama":
            self.ollama_client = Client()
            self.check_and_pull_model()
        if not self.check_valid_key(model_name):
            raise ValueError(f"Invalid keys for model {model_name}.")

        self.generation_config = generation_config or GenerativeModelConfig()
        self.system_instruction = system_instruction

    def check_valid_key(self, model: str):
        """
        Checks if the environment key is valid for a specific model by making a litellm.completion call with max_tokens=10.

        Args:
            model (str): The name of the model to check the key against.

        Returns:
            bool: True if the key is valid for the model, False otherwise.
        """
        messages = [{"role": "user", "content": "Hey, how's it going?"}]
        try:
            completion(
                model=model, messages=messages, max_tokens=10
            )
            return True
        except:
            return False

    def check_and_pull_model(self) -> None:
        """
        Checks if the specified model is available locally, and pulls it if not.

        Logs:
            - Info: If the model is already available or after successfully pulling the model.
            - Error: If there is a failure in pulling the model.

        Raises:
            Exception: If there is an error during the model pull process.
        """
        # Get the list of available models
        response = self.ollama_client.list()  # This returns a dictionary
        available_models = [model['name'] for model in response['models']]  # Extract model names

        # Check if the model is already pulled
        if self.model_name in available_models:
            logger.info(f"The model '{self.model_name}' is already available.")
        else:
            logger.info(f"Pulling the model '{self.model_name}'...")
            try:
                self.ollama_client.pull(self.model_name)  # Pull the model
                logger.info(f"Model '{self.model_name}' pulled successfully.")
            except Exception as e:
                logger.error(f"Failed to pull the model '{self.model_name}': {e}")
                raise ValueError(f"Failed to pull the model '{self.model_name}': {e}")

    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
        """
        Set or update the system instruction for a new model instance.

        Args:
            system_instruction (str): Instruction for guiding the model's behavior.

        Returns:
            GenerativeModel: The updated model instance.
        """
        self.system_instruction = system_instruction
        return self

    def start_chat(self, args: Optional[dict] = None) -> GenerativeModelChatSession:
        """
        Start a new chat session.

        Args:
            args (Optional[dict]): Additional arguments for the chat session.

        Returns:
            GenerativeModelChatSession: A new instance of the chat session.
        """
        return LiteModelChatSession(self, args)

    def parse_generate_content_response(self, response: any) -> GenerationResponse:
        """
        Parse the model's response and extract content for the user.

        Args:
            response (any): The raw response from the model.

        Returns:
            GenerationResponse: Parsed response containing the generated text.
        """
        return GenerationResponse(
            text=response.choices[0].message.content,
            finish_reason=(
                FinishReason.STOP
                if response.choices[0].finish_reason == "stop"
                else (
                    FinishReason.MAX_TOKENS
                    if response.choices[0].finish_reason == "length"
                    else FinishReason.OTHER
                )
            ),
        )

    def to_json(self) -> dict:
        """
        Serialize the model's configuration and state to JSON format.

        Returns:
            dict: The serialized JSON data.
        """
        return {
            "model_name": self.model_name,
            "generation_config": self.generation_config.to_json(),
            "system_instruction": self.system_instruction,
        }

    @staticmethod
    def from_json(json: dict) -> "GenerativeModel":
        """
        Deserialize a JSON object to create an instance of LiteModel.

        Args:
            json (dict): The serialized JSON data.

        Returns:
            GenerativeModel: A new instance of the model.
        """
        return LiteModel(
            json["model_name"],
            generation_config=GenerativeModelConfig.from_json(
                json["generation_config"]
            ),
            system_instruction=json["system_instruction"],
        )


class LiteModelChatSession(GenerativeModelChatSession):
    """
    A chat session for interacting with the LiteLLM model, maintaining conversation history.
    """

    def __init__(self, model: LiteModel, args: Optional[dict] = None):
        """
        Initialize the chat session and set up the conversation history.

        Args:
            model (LiteModel): The model instance for the session.
            args (Optional[dict]): Additional arguments for customization.
        """
        self._model = model
        self._args = args
        self._chat_history = (
            [{"role": "system", "content": self._model.system_instruction}]
            if self._model.system_instruction is not None
            else []
        )

    def get_chat_history(self) -> list[dict]:
        """
        Retrieve the conversation history for the current chat session.

        Returns:
            list[dict]: The chat session's conversation history.
        """
        return self._chat_history.copy()

    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
        """
        Send a message in the chat session and receive the model's response.

        Args:
            message (str): The message to send.
            output_method (OutputMethod): Format for the model's output.

        Returns:
            GenerationResponse: The generated response.
        """
        generation_config = self._adjust_generation_config(output_method)
        self._chat_history.append({"role": "user", "content": message})
        try:
            response = completion(
                model=self._model.model,
                messages=self._chat_history,
                **generation_config
            )
        except Exception as e:
            raise ValueError(f"Error during completion request, please check the credentials - {e}")
        content = self._model.parse_generate_content_response(response)
        self._chat_history.append({"role": "assistant", "content": content.text})
        return content

    def _adjust_generation_config(self, output_method: OutputMethod):
        """
        Adjust the generation configuration based on the specified output method.

        Args:
            output_method (OutputMethod): The desired output method (e.g., default or JSON).

        Returns:
            dict: The adjusted configuration settings for generation.
        """
        config = self._model.generation_config.to_json()
        if output_method == OutputMethod.JSON:
            config['temperature'] = 0
            config['response_format'] = {"type": "json_object"}

        return config

    def delete_last_message(self):
        """
        Deletes the last message exchange (user message and assistant response) from the chat history.
        Preserves the system message if present.

        Example:
            Before:
            [
                {"role": "system", "content": "System message"},
                {"role": "user", "content": "User message"},
                {"role": "assistant", "content": "Assistant response"},
            ]
            After:
            [
                {"role": "system", "content": "System message"},
            ]

        Note: Does nothing if the chat history is empty or contains only a system message.
        """
        # Keep at least the system message if present
        min_length = 1 if self._model.system_instruction else 0
        if len(self._chat_history) - 2 >= min_length:
            self._chat_history.pop()
            self._chat_history.pop()
        else:
            # Reset to initial state with just system message if present
            self._chat_history = (
                [{"role": "system", "content": self._model.system_instruction}]
                if self._model.system_instruction is not None
                else []
            )
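
To tie the pieces of this new module together, here is a hedged usage sketch of `LiteModel` and `LiteModelChatSession`: passing `OutputMethod.JSON` makes `_adjust_generation_config` pin `temperature` to 0 and request a JSON response format, and `delete_last_message` rolls back the last user/assistant exchange while preserving the system message. Import paths are inferred from the relative imports above; the model name is only an example.

# Usage sketch (not part of the commit); model name and import paths are assumptions.
from graphrag_sdk.models.litellm import LiteModel
from graphrag_sdk.models.model import OutputMethod

model = LiteModel("openai/gpt-4o").with_system_instruction("Answer briefly.")
chat = model.start_chat()

# JSON output: temperature is forced to 0 and response_format={"type": "json_object"}.
reply = chat.send_message('Return {"status": "ok"} as a JSON object.', output_method=OutputMethod.JSON)
print(reply.text)

# Roll back the last user/assistant exchange; the system message stays in place.
chat.delete_last_message()
print(chat.get_chat_history())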

graphrag_sdk/steps/create_ontology_step.py

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 from graphrag_sdk.models import (
     GenerativeModel,
     GenerativeModelChatSession,
-    GenerativeModelConfig,
     GenerationResponse,
     FinishReason,
 )
