Skip to content

Commit febb7a4

Browse files
committed
ci-test
1 parent 796d68e commit febb7a4

File tree

2 files changed

+179
-5
lines changed

2 files changed

+179
-5
lines changed

graphrag_sdk/ontology.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -131,22 +131,23 @@ def from_kg_graph(graph: Graph, sample_size: int = 100,):
131131
l = lbls[0]
132132
attributes = graph.query(
133133
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
134-
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1)""").result_set
135-
ontology.add_entity(Entity(l, [Attribute(attr[0][0], attr[0][1]) for attr in attributes]))
134+
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
135+
ontology.add_entity(Entity(l, [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]))
136136

137137
# Extract attributes for each edge type, limited by the specified sample size.
138138
for e_type in e_types:
139139
e_t = e_type[0]
140140
attributes = graph.query(
141141
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
142-
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1)""").result_set
143-
attributes = [Attribute(attr[0][0], attr[0][1]) for attr in attributes]
142+
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
143+
attributes = [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]
144144
for s_lbls in n_labels:
145145
for t_lbls in n_labels:
146146
s_l = s_lbls[0]
147147
t_l = t_lbls[0]
148148
# Check if a relationship exists between the source and target entity labels
149-
if graph.query(f"MATCH (s:{s_l})-[a:{e_t}]->(t:{t_l}) return a limit 1").result_set:
149+
result_set = graph.query(f"MATCH (s:{s_l})-[a:{e_t}]->(t:{t_l}) return a limit 1").result_set
150+
if len(result_set) > 0:
150151
ontology.add_relation(Relation(e_t, s_l, t_l, attributes))
151152

152153
return ontology

tests/test_ontology_from_kg.py

Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
2+
import logging
3+
import unittest
4+
from json import loads
5+
from falkordb import FalkorDB
6+
from dotenv import load_dotenv
7+
from graphrag_sdk.entity import Entity
8+
from graphrag_sdk.ontology import Ontology
9+
from graphrag_sdk.relation import Relation
10+
from graphrag_sdk.attribute import Attribute, AttributeType
11+
from graphrag_sdk.models.gemini import GeminiGenerativeModel
12+
from graphrag_sdk import KnowledgeGraph, KnowledgeGraphModelConfig
13+
14+
load_dotenv()
15+
16+
logging.basicConfig(level=logging.DEBUG)
17+
logger = logging.getLogger(__name__)
18+
19+
20+
class TestOntologyFromKG(unittest.TestCase):
21+
@classmethod
22+
def setUpClass(cls):
23+
24+
cls.ontology = Ontology()
25+
cls.ontology.add_entity(
26+
Entity(
27+
label="City",
28+
attributes=[
29+
Attribute(
30+
name="name",
31+
attr_type=AttributeType.STRING,
32+
required=False,
33+
unique=False,
34+
),
35+
Attribute(
36+
name="population",
37+
attr_type=AttributeType.NUMBER,
38+
required=False,
39+
unique=False,
40+
),
41+
Attribute(
42+
name="weather",
43+
attr_type=AttributeType.STRING,
44+
required=False,
45+
unique=False,
46+
),
47+
],
48+
)
49+
)
50+
cls.ontology.add_entity(
51+
Entity(
52+
label="Country",
53+
attributes=[
54+
Attribute(
55+
name="name",
56+
attr_type=AttributeType.STRING,
57+
required=False,
58+
unique=False,
59+
),
60+
],
61+
)
62+
)
63+
cls.ontology.add_entity(
64+
Entity(
65+
label="Restaurant",
66+
attributes=[
67+
Attribute(
68+
name="description",
69+
attr_type=AttributeType.STRING,
70+
required=False,
71+
unique=False,
72+
),
73+
Attribute(
74+
name="food_type",
75+
attr_type=AttributeType.STRING,
76+
required=False,
77+
unique=False,
78+
),
79+
Attribute(
80+
name="name",
81+
attr_type=AttributeType.STRING,
82+
required=False,
83+
unique=False,
84+
),
85+
Attribute(
86+
name="rating",
87+
attr_type=AttributeType.NUMBER,
88+
required=False,
89+
unique=False,
90+
),
91+
],
92+
)
93+
)
94+
cls.ontology.add_relation(
95+
Relation(
96+
label="IN_COUNTRY",
97+
source="City",
98+
target="Country",
99+
)
100+
)
101+
cls.ontology.add_relation(
102+
Relation(
103+
label="IN_CITY",
104+
source="Restaurant",
105+
target="City",
106+
)
107+
)
108+
cls.model = GeminiGenerativeModel("gemini-1.5-flash-001")
109+
cls.kg = KnowledgeGraph(
110+
name="test_ontology",
111+
ontology=cls.ontology,
112+
model_config=KnowledgeGraphModelConfig.with_model(cls.model),
113+
)
114+
cls.import_data(cls.kg)
115+
116+
@classmethod
117+
def import_data(
118+
self,
119+
kg: KnowledgeGraph,
120+
):
121+
with open("tests/data/cities.json") as f:
122+
cities = loads(f.read())
123+
with open("tests/data/restaurants.json") as f:
124+
restaurants = loads(f.read())
125+
126+
for city in cities:
127+
kg.add_node(
128+
"City",
129+
{
130+
"name": city["name"],
131+
"weather": city["weather"],
132+
"population": city["population"],
133+
},
134+
)
135+
kg.add_node("Country", {"name": city["country"]})
136+
kg.add_edge(
137+
"IN_COUNTRY",
138+
"City",
139+
"Country",
140+
{"name": city["name"]},
141+
{"name": city["country"]},
142+
)
143+
144+
for restaurant in restaurants:
145+
kg.add_node(
146+
"Restaurant",
147+
{
148+
"name": restaurant["name"],
149+
"description": restaurant["description"],
150+
"rating": restaurant["rating"],
151+
"food_type": restaurant["food_type"],
152+
},
153+
)
154+
kg.add_edge(
155+
"IN_CITY",
156+
"Restaurant",
157+
"City",
158+
{"name": restaurant["name"]},
159+
{"name": restaurant["city"]},
160+
)
161+
162+
# Delete graph after tests
163+
@classmethod
164+
def tearDownClass(cls):
165+
logger.info("Cleaning up test graph...")
166+
cls.kg.delete()
167+
168+
def test_ontology_serialization(self):
169+
logger.info("Testing ontology serialization...")
170+
db = FalkorDB()
171+
graph = db.select_graph("test_ontology")
172+
ontology = Ontology.from_kg_graph(graph=graph)
173+
self.assertEqual(ontology.to_json(), self.ontology.to_json())

0 commit comments

Comments
 (0)