1+
2+ import logging
3+ import unittest
4+ from json import loads
5+ from falkordb import FalkorDB
6+ from dotenv import load_dotenv
7+ from graphrag_sdk .entity import Entity
8+ from graphrag_sdk .ontology import Ontology
9+ from graphrag_sdk .relation import Relation
10+ from graphrag_sdk .attribute import Attribute , AttributeType
11+ from graphrag_sdk .models .gemini import GeminiGenerativeModel
12+ from graphrag_sdk import KnowledgeGraph , KnowledgeGraphModelConfig
13+
14+ load_dotenv ()
15+
16+ logging .basicConfig (level = logging .DEBUG )
17+ logger = logging .getLogger (__name__ )
18+
19+
20+ class TestOntologyFromKG (unittest .TestCase ):
21+ @classmethod
22+ def setUpClass (cls ):
23+
24+ cls .ontology = Ontology ()
25+ cls .ontology .add_entity (
26+ Entity (
27+ label = "City" ,
28+ attributes = [
29+ Attribute (
30+ name = "name" ,
31+ attr_type = AttributeType .STRING ,
32+ required = False ,
33+ unique = False ,
34+ ),
35+ Attribute (
36+ name = "population" ,
37+ attr_type = AttributeType .NUMBER ,
38+ required = False ,
39+ unique = False ,
40+ ),
41+ Attribute (
42+ name = "weather" ,
43+ attr_type = AttributeType .STRING ,
44+ required = False ,
45+ unique = False ,
46+ ),
47+ ],
48+ )
49+ )
50+ cls .ontology .add_entity (
51+ Entity (
52+ label = "Country" ,
53+ attributes = [
54+ Attribute (
55+ name = "name" ,
56+ attr_type = AttributeType .STRING ,
57+ required = False ,
58+ unique = False ,
59+ ),
60+ ],
61+ )
62+ )
63+ cls .ontology .add_entity (
64+ Entity (
65+ label = "Restaurant" ,
66+ attributes = [
67+ Attribute (
68+ name = "description" ,
69+ attr_type = AttributeType .STRING ,
70+ required = False ,
71+ unique = False ,
72+ ),
73+ Attribute (
74+ name = "food_type" ,
75+ attr_type = AttributeType .STRING ,
76+ required = False ,
77+ unique = False ,
78+ ),
79+ Attribute (
80+ name = "name" ,
81+ attr_type = AttributeType .STRING ,
82+ required = False ,
83+ unique = False ,
84+ ),
85+ Attribute (
86+ name = "rating" ,
87+ attr_type = AttributeType .NUMBER ,
88+ required = False ,
89+ unique = False ,
90+ ),
91+ ],
92+ )
93+ )
94+ cls .ontology .add_relation (
95+ Relation (
96+ label = "IN_COUNTRY" ,
97+ source = "City" ,
98+ target = "Country" ,
99+ )
100+ )
101+ cls .ontology .add_relation (
102+ Relation (
103+ label = "IN_CITY" ,
104+ source = "Restaurant" ,
105+ target = "City" ,
106+ )
107+ )
108+ cls .model = GeminiGenerativeModel ("gemini-1.5-flash-001" )
109+ cls .kg = KnowledgeGraph (
110+ name = "test_ontology" ,
111+ ontology = cls .ontology ,
112+ model_config = KnowledgeGraphModelConfig .with_model (cls .model ),
113+ )
114+ cls .import_data (cls .kg )
115+
116+ @classmethod
117+ def import_data (
118+ self ,
119+ kg : KnowledgeGraph ,
120+ ):
121+ with open ("tests/data/cities.json" ) as f :
122+ cities = loads (f .read ())
123+ with open ("tests/data/restaurants.json" ) as f :
124+ restaurants = loads (f .read ())
125+
126+ for city in cities :
127+ kg .add_node (
128+ "City" ,
129+ {
130+ "name" : city ["name" ],
131+ "weather" : city ["weather" ],
132+ "population" : city ["population" ],
133+ },
134+ )
135+ kg .add_node ("Country" , {"name" : city ["country" ]})
136+ kg .add_edge (
137+ "IN_COUNTRY" ,
138+ "City" ,
139+ "Country" ,
140+ {"name" : city ["name" ]},
141+ {"name" : city ["country" ]},
142+ )
143+
144+ for restaurant in restaurants :
145+ kg .add_node (
146+ "Restaurant" ,
147+ {
148+ "name" : restaurant ["name" ],
149+ "description" : restaurant ["description" ],
150+ "rating" : restaurant ["rating" ],
151+ "food_type" : restaurant ["food_type" ],
152+ },
153+ )
154+ kg .add_edge (
155+ "IN_CITY" ,
156+ "Restaurant" ,
157+ "City" ,
158+ {"name" : restaurant ["name" ]},
159+ {"name" : restaurant ["city" ]},
160+ )
161+
162+ # Delete graph after tests
163+ @classmethod
164+ def tearDownClass (cls ):
165+ logger .info ("Cleaning up test graph..." )
166+ cls .kg .delete ()
167+
168+ def test_ontology_serialization (self ):
169+ logger .info ("Testing ontology serialization..." )
170+ db = FalkorDB ()
171+ graph = db .select_graph ("test_ontology" )
172+ ontology = Ontology .from_kg_graph (graph = graph )
173+ self .assertEqual (ontology .to_json (), self .ontology .to_json ())
0 commit comments