Skip to content

Commit afa1309

Browse files
committed
filter-to-session
1 parent e29eb6c commit afa1309

File tree

5 files changed

+50
-24
lines changed

5 files changed

+50
-24
lines changed

graphrag_sdk/attribute.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -57,15 +57,15 @@ class Attribute:
5757
"""
5858

5959
def __init__(
60-
self, name: str, attr_type: AttributeType, unique: bool, required: bool = False
60+
self, name: str, attr_type: AttributeType, unique: bool = False, required: bool = False
6161
):
6262
"""
6363
Initialize a new Attribute object.
6464
6565
Args:
6666
name (str): The name of the attribute.
6767
attr_type (AttributeType): The type of the attribute.
68-
unique (bool): Indicates whether the attribute should be unique.
68+
unique (bool, optional): Indicates whether the attribute should be unique. Defaults to False.
6969
required (bool, optional): Indicates whether the attribute is required. Defaults to False.
7070
"""
7171
self.name = re.sub(r"([^a-zA-Z0-9_])", "_", name)
@@ -130,29 +130,24 @@ def from_string(txt: str):
130130

131131
return Attribute(name, AttributeType.from_string(attr_type), unique, required)
132132

133-
def to_json(self, include_all: bool = True):
133+
def to_json(self):
134134
"""
135135
Converts the attribute object to a JSON representation.
136136
137-
Args:
138-
include_all (bool): Whether to include both "unique" and "required" fields in the output. Default is True.
139-
140137
Returns:
141138
dict: A dictionary representing the attribute object in JSON format.
142139
The dictionary contains the following keys:
143140
- "name": The name of the attribute.
144141
- "type": The type of the attribute.
145-
Optionally includes:
146-
- "unique": A boolean indicating whether the attribute is unique (if include_all is True).
147-
- "required": A boolean indicating whether the attribute is required (if include_all is True).
142+
- "unique": A boolean indicating whether the attribute is unique.
143+
- "required": A boolean indicating whether the attribute is required.
148144
"""
149145
json_data = {
150146
"name": self.name,
151147
"type": self.type,
148+
"unique": self.unique,
149+
"required": self.required,
152150
}
153-
if include_all:
154-
json_data["unique"] = self.unique
155-
json_data["required"] = self.required
156151

157152
return json_data
158153

graphrag_sdk/chat_session.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import json
12
from falkordb import Graph
23
from graphrag_sdk.ontology import Ontology
34
from graphrag_sdk.steps.qa_step import QAStep
@@ -45,8 +46,11 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
4546
self.model_config = model_config
4647
self.graph = graph
4748
self.ontology = ontology
48-
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json(include_all=False)))
49-
49+
50+
# Filter the ontology to remove unique and required attributes that are not needed for Q&A.
51+
ontology_prompt = self.clean_ontology_for_prompt(ontology)
52+
53+
cypher_system_instruction = cypher_system_instruction.format(ontology=ontology_prompt)
5054

5155
self.cypher_prompt = cypher_gen_prompt
5256
self.qa_prompt = qa_prompt
@@ -108,4 +112,31 @@ def send_message(self, message: str):
108112
"response": answer,
109113
"context": context,
110114
"cypher": cypher
111-
}
115+
}
116+
117+
def clean_ontology_for_prompt(self, ontology: dict):
118+
"""
119+
Cleans the ontology by removing 'unique' and 'required' keys and prepares it for use in a prompt.
120+
121+
Args:
122+
ontology (dict): The ontology to clean and transform.
123+
124+
Returns:
125+
str: The cleaned ontology as a JSON string.
126+
"""
127+
# Convert the ontology object to a JSON.
128+
ontology = ontology.to_json()
129+
130+
# Remove unique and required attributes from the ontology.
131+
for entity in ontology["entities"]:
132+
for attribute in entity["attributes"]:
133+
del attribute['unique']
134+
del attribute['required']
135+
136+
for relation in ontology["relations"]:
137+
for attribute in relation["attributes"]:
138+
del attribute['unique']
139+
del attribute['required']
140+
141+
# Return the transformed ontology as a JSON string
142+
return json.dumps(ontology)

graphrag_sdk/entity.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def from_json(txt: dict | str):
8282
txt.get("description", ""),
8383
)
8484

85-
def to_json(self, include_all: bool = True) -> dict:
85+
def to_json(self) -> dict:
8686
"""
8787
Convert the entity object to a JSON representation.
8888
@@ -95,7 +95,7 @@ def to_json(self, include_all: bool = True) -> dict:
9595
"""
9696
return {
9797
"label": self.label,
98-
"attributes": [attr.to_json(include_all=include_all) for attr in self.attributes],
98+
"attributes": [attr.to_json() for attr in self.attributes],
9999
"description": self.description,
100100
}
101101

graphrag_sdk/ontology.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def from_kg_graph(graph: Graph, node_limit: int = 100,):
132132
attributes = graph.query(
133133
f"""MATCH (a:{label[0]}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
134134
WITH types limit {node_limit} unwind types as kt RETURN kt, count(1)""").result_set
135-
ontology.add_entity(Entity(label[0], [Attribute(attr[0][0], attr[0][1], False, False) for attr in attributes]))
135+
ontology.add_entity(Entity(label[0], [Attribute(attr[0][0], attr[0][1]) for attr in attributes]))
136136

137137
# Process each relationship type and extract attributes, limited to the specified number of nodes
138138
for label in r_labels:
@@ -143,7 +143,7 @@ def from_kg_graph(graph: Graph, node_limit: int = 100,):
143143
attributes = graph.query(
144144
f"""MATCH ()-[a:{label[0]}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
145145
WITH types limit {node_limit} unwind types as kt RETURN kt, count(1)""").result_set
146-
ontology.add_relation(Relation(label[0], label_s[0], label_t[0], [Attribute(attr[0][0], attr[0][1], False, False) for attr in attributes]))
146+
ontology.add_relation(Relation(label[0], label_s[0], label_t[0], [Attribute(attr[0][0], attr[0][1]) for attr in attributes]))
147147

148148
return ontology
149149

@@ -165,16 +165,16 @@ def add_relation(self, relation: Relation):
165165
"""
166166
self.relations.append(relation)
167167

168-
def to_json(self, include_all: bool = True) -> dict:
168+
def to_json(self) -> dict:
169169
"""
170170
Converts the ontology object to a JSON representation.
171171
172172
Returns:
173173
A dictionary representing the ontology object in JSON format.
174174
"""
175175
return {
176-
"entities": [entity.to_json(include_all) for entity in self.entities],
177-
"relations": [relation.to_json(include_all) for relation in self.relations],
176+
"entities": [entity.to_json() for entity in self.entities],
177+
"relations": [relation.to_json() for relation in self.relations],
178178
}
179179

180180
def merge_with(self, o: "Ontology"):

graphrag_sdk/relation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def from_string(txt: str):
205205
[Attribute.from_string(attr) for attr in attributes],
206206
)
207207

208-
def to_json(self, include_all: bool = True) -> dict:
208+
def to_json(self) -> dict:
209209
"""
210210
Converts the Relation object to a JSON dictionary.
211211
@@ -216,7 +216,7 @@ def to_json(self, include_all: bool = True) -> dict:
216216
"label": self.label,
217217
"source": self.source.to_json(),
218218
"target": self.target.to_json(),
219-
"attributes": [attr.to_json(include_all=include_all) for attr in self.attributes],
219+
"attributes": [attr.to_json() for attr in self.attributes],
220220
}
221221

222222
def combine(self, relation2: "Relation"):

0 commit comments

Comments
 (0)