Skip to content

Commit c4c88ee

Browse files
committed
add-litellm
1 parent 7c671ec commit c4c88ee

File tree

2 files changed

+236
-0
lines changed

2 files changed

+236
-0
lines changed

graphrag_sdk/models/litellm.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
from typing import Optional
2+
from litellm import completion
3+
from .model import (
4+
OutputMethod,
5+
GenerativeModel,
6+
GenerativeModelConfig,
7+
GenerationResponse,
8+
FinishReason,
9+
GenerativeModelChatSession,
10+
)
11+
12+
13+
class LiteLLMGenerativeModel(GenerativeModel):
    """
    A generative model that interfaces with LiteLLM for chat completions.

    Inference is delegated to ``litellm.completion``, so any provider that
    LiteLLM supports can be selected via ``model_name``.
    """

    def __init__(
        self,
        model_name: str,
        generation_config: Optional[GenerativeModelConfig] = None,
        system_instruction: Optional[str] = None,
    ):
        """
        Args:
            model_name: LiteLLM model identifier (e.g. "gpt-4o").
            generation_config: Generation options; a fresh
                GenerativeModelConfig is used when omitted.
            system_instruction: Optional system prompt seeded into every
                chat session started from this model.
        """
        self.model_name = model_name
        self.generation_config = generation_config or GenerativeModelConfig()
        self.system_instruction = system_instruction

    def with_system_instruction(self, system_instruction: str) -> "GenerativeModel":
        """Set the system instruction and return self (fluent API)."""
        self.system_instruction = system_instruction
        return self

    def start_chat(self, args: Optional[dict] = None) -> GenerativeModelChatSession:
        """Open a new chat session bound to this model."""
        return LiteLLMChatSession(self, args)

    def parse_generate_content_response(self, response: object) -> GenerationResponse:
        """
        Convert a raw LiteLLM completion response into a GenerationResponse.

        Maps the provider's finish-reason string onto the SDK enum:
        "stop" -> STOP, "length" -> MAX_TOKENS, anything else -> OTHER.

        Note: the original annotated ``response`` with the builtin ``any``
        function; ``object`` is used instead as the honest "opaque value"
        type (the response shape is LiteLLM's ModelResponse).
        """
        # A lookup table replaces the original nested conditional
        # expression — same mapping, easier to read and extend.
        finish_reason_map = {
            "stop": FinishReason.STOP,
            "length": FinishReason.MAX_TOKENS,
        }
        choice = response.choices[0]
        return GenerationResponse(
            text=choice.message.content,
            finish_reason=finish_reason_map.get(
                choice.finish_reason, FinishReason.OTHER
            ),
        )

    def to_json(self) -> dict:
        """Serialize this model's configuration to a JSON-compatible dict."""
        return {
            "model_name": self.model_name,
            "generation_config": self.generation_config.to_json(),
            "system_instruction": self.system_instruction,
        }

    @staticmethod
    def from_json(json: dict) -> "GenerativeModel":
        """Reconstruct a LiteLLMGenerativeModel from its to_json() output."""
        return LiteLLMGenerativeModel(
            json["model_name"],
            generation_config=GenerativeModelConfig.from_json(
                json["generation_config"]
            ),
            system_instruction=json["system_instruction"],
        )
66+
67+
68+
class LiteLLMChatSession(GenerativeModelChatSession):
    """
    A chat session over a LiteLLMGenerativeModel that maintains the
    running message history (system / user / assistant turns).

    Note: the original defined a mutable class attribute ``_history = []``;
    it was dead (always shadowed by the instance attribute assigned in
    ``__init__``) and a shared-mutable-state trap, so it has been removed.
    """

    # Maximum number of characters of a user message forwarded to the model.
    # NOTE(review): the untruncated message is still stored in history, so a
    # later turn re-sends the full text — confirm this asymmetry is intended.
    _MAX_MESSAGE_CHARS = 14385

    def __init__(self, model: LiteLLMGenerativeModel, args: Optional[dict] = None):
        """
        Args:
            model: The generative model this session sends messages through.
            args: Optional extra session arguments (stored, not used here).
        """
        self._model = model
        self._args = args
        # Seed the history with the system instruction when one is set.
        self._history = (
            [{"role": "system", "content": self._model.system_instruction}]
            if self._model.system_instruction is not None
            else []
        )

    def send_message(self, message: str, output_method: OutputMethod = OutputMethod.DEFAULT) -> GenerationResponse:
        """
        Send a user message, record the exchange in history, and return the
        parsed model response.

        Args:
            message: The user's message text.
            output_method: DEFAULT for free-form text, JSON to force a
                JSON-object response.
        """
        generation_config = self._get_generation_config(output_method)
        prompt = [
            *self._history,
            {"role": "user", "content": message[: self._MAX_MESSAGE_CHARS]},
        ]
        response = completion(
            model=self._model.model_name,
            messages=prompt,
            **generation_config,
        )
        content = self._model.parse_generate_content_response(response)
        self._history.append({"role": "user", "content": message})
        self._history.append({"role": "assistant", "content": content.text})
        return content

    def _get_generation_config(self, output_method: OutputMethod) -> dict:
        """
        Build the kwargs passed to litellm.completion.

        For JSON output, force deterministic sampling (temperature 0) and
        request a JSON-object response format.
        """
        config = self._model.generation_config.to_json()
        if output_method == OutputMethod.JSON:
            config['temperature'] = 0
            config['response_format'] = { "type": "json_object" }
        return config

    def delete_last_message(self):
        """
        Deletes the last message exchange (user message and assistant response) from the chat history.
        Preserves the system message if present.

        Example:
            Before:
            [
                {"role": "system", "content": "System message"},
                {"role": "user", "content": "User message"},
                {"role": "assistant", "content": "Assistant response"},
            ]
            After:
            [
                {"role": "system", "content": "System message"},
            ]

        Note: Does nothing if the chat history is empty or contains only a system message.
        """
        # Keep at least the system message if present.
        min_length = 1 if self._model.system_instruction else 0
        if len(self._history) - 2 >= min_length:
            # Drop the assistant response, then the user message.
            self._history.pop()
            self._history.pop()
        else:
            # Reset to initial state with just the system message if present.
            self._history = (
                [{"role": "system", "content": self._model.system_instruction}]
                if self._model.system_instruction is not None
                else []
            )

tests/test_kg_litellm_openai.py

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import re
2+
import logging
3+
import unittest
4+
from falkordb import FalkorDB
5+
from dotenv import load_dotenv
6+
from graphrag_sdk.entity import Entity
7+
from graphrag_sdk.source import Source
8+
from graphrag_sdk.relation import Relation
9+
from graphrag_sdk.ontology import Ontology
10+
from graphrag_sdk.attribute import Attribute, AttributeType
11+
from graphrag_sdk.models.litellm import LiteLLMGenerativeModel
12+
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
13+
14+
# Pull environment variables (API keys, DB endpoints) from a local .env file.
load_dotenv()

# Verbose logging so test-run transcripts expose SDK internals.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
18+
19+
class TestKGLiteLLM(unittest.TestCase):
    """
    Knowledge-graph integration tests driven through the LiteLLM model
    wrapper (backed here by OpenAI's gpt-4o).
    """

    @staticmethod
    def _string_attr(name: str, *, unique: bool, required: bool) -> Attribute:
        """Build a STRING-typed Attribute with the given constraints."""
        return Attribute(
            name=name,
            attr_type=AttributeType.STRING,
            unique=unique,
            required=required,
        )

    @classmethod
    def setUpClass(cls):
        # Movie-domain ontology: Actor -[ACTED_IN]-> Movie.
        cls.ontology = Ontology([], [])

        cls.ontology.add_entity(
            Entity(
                label="Actor",
                attributes=[cls._string_attr("name", unique=True, required=True)],
            )
        )
        cls.ontology.add_entity(
            Entity(
                label="Movie",
                attributes=[cls._string_attr("title", unique=True, required=True)],
            )
        )
        cls.ontology.add_relation(
            Relation(
                label="ACTED_IN",
                source="Actor",
                target="Movie",
                attributes=[cls._string_attr("role", unique=False, required=False)],
            )
        )

        cls.graph_name = "IMDB_openai"
        llm = LiteLLMGenerativeModel(model_name="gpt-4o")
        cls.kg = KnowledgeGraph(
            name=cls.graph_name,
            ontology=cls.ontology,
            model_config=KnowledgeGraphModelConfig.with_model(llm),
        )

    def test_kg_creation(self):
        """Ingest a source document, then query the resulting graph."""
        file_path = "tests/data/madoff.txt"

        source_list = [Source(file_path)]

        self.kg.process_sources(source_list)

        session = self.kg.chat_session()
        answer = session.send_message("How many actors acted in a movie?")
        answer = answer['response']

        logger.info(f"Answer: {answer}")

        digit_matches = re.findall(r'\d+', answer)
        num_actors = int(digit_matches[0]) if digit_matches else 0

        assert num_actors > 10, "The number of actors found should be greater than 10"

    def test_kg_delete(self):
        """Deleting the graph removes it from the FalkorDB server."""
        self.kg.delete()

        db = FalkorDB()
        graphs = db.list_graphs()
        self.assertNotIn(self.graph_name, graphs)

0 commit comments

Comments
 (0)