Skip to content

Commit a21d9b2

Browse files
committed
update-tests
1 parent a17975a commit a21d9b2

File tree

1 file changed

+110
-146
lines changed

1 file changed

+110
-146
lines changed

tests/test_ontology_from_kg.py

Lines changed: 110 additions & 146 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
1+
import pytest
22
import logging
3-
import unittest
43
from json import loads
54
from falkordb import FalkorDB
65
from dotenv import load_dotenv
@@ -16,158 +15,123 @@
1615
logging.basicConfig(level=logging.DEBUG)
1716
logger = logging.getLogger(__name__)
1817

19-
20-
class TestOntologyFromKG(unittest.TestCase):
21-
@classmethod
22-
def setUpClass(cls):
23-
24-
cls.ontology = Ontology()
25-
cls.ontology.add_entity(
26-
Entity(
27-
label="City",
28-
attributes=[
29-
Attribute(
30-
name="name",
31-
attr_type=AttributeType.STRING,
32-
required=False,
33-
unique=False,
34-
),
35-
Attribute(
36-
name="population",
37-
attr_type=AttributeType.NUMBER,
38-
required=False,
39-
unique=False,
40-
),
41-
Attribute(
42-
name="weather",
43-
attr_type=AttributeType.STRING,
44-
required=False,
45-
unique=False,
46-
),
47-
],
48-
)
18+
@pytest.fixture
19+
def ontology_kg_setup():
20+
"""
21+
Sets up an ontology, initializes the KnowledgeGraph, and imports data.
22+
"""
23+
# Build up the ontology.
24+
ontology = Ontology()
25+
ontology.add_entity(
26+
Entity(
27+
label="City",
28+
attributes=[
29+
Attribute("name", AttributeType.STRING, required=False, unique=False),
30+
Attribute("population", AttributeType.NUMBER, required=False, unique=False),
31+
Attribute("weather", AttributeType.STRING, required=False, unique=False),
32+
],
4933
)
50-
cls.ontology.add_entity(
51-
Entity(
52-
label="Country",
53-
attributes=[
54-
Attribute(
55-
name="name",
56-
attr_type=AttributeType.STRING,
57-
required=False,
58-
unique=False,
59-
),
60-
],
61-
)
34+
)
35+
ontology.add_entity(
36+
Entity(
37+
label="Country",
38+
attributes=[
39+
Attribute("name", AttributeType.STRING, required=False, unique=False),
40+
],
6241
)
63-
cls.ontology.add_entity(
64-
Entity(
65-
label="Restaurant",
66-
attributes=[
67-
Attribute(
68-
name="description",
69-
attr_type=AttributeType.STRING,
70-
required=False,
71-
unique=False,
72-
),
73-
Attribute(
74-
name="food_type",
75-
attr_type=AttributeType.STRING,
76-
required=False,
77-
unique=False,
78-
),
79-
Attribute(
80-
name="name",
81-
attr_type=AttributeType.STRING,
82-
required=False,
83-
unique=False,
84-
),
85-
Attribute(
86-
name="rating",
87-
attr_type=AttributeType.NUMBER,
88-
required=False,
89-
unique=False,
90-
),
91-
],
92-
)
42+
)
43+
ontology.add_entity(
44+
Entity(
45+
label="Restaurant",
46+
attributes=[
47+
Attribute("description", AttributeType.STRING, required=False, unique=False),
48+
Attribute("food_type", AttributeType.STRING, required=False, unique=False),
49+
Attribute("name", AttributeType.STRING, required=False, unique=False),
50+
Attribute("rating", AttributeType.NUMBER, required=False, unique=False),
51+
],
52+
)
53+
)
54+
ontology.add_relation(Relation(label="IN_COUNTRY", source="City", target="Country"))
55+
ontology.add_relation(Relation(label="IN_CITY", source="Restaurant", target="City"))
56+
57+
# Create a model and a knowledge graph.
58+
model = GeminiGenerativeModel("gemini-1.5-flash-001")
59+
kg = KnowledgeGraph(
60+
name="test_ontology",
61+
ontology=ontology,
62+
model_config=KnowledgeGraphModelConfig.with_model(model),
63+
)
64+
65+
# Import test data.
66+
with open("tests/data/cities.json") as f:
67+
cities = loads(f.read())
68+
with open("tests/data/restaurants.json") as f:
69+
restaurants = loads(f.read())
70+
71+
for city in cities:
72+
kg.add_node(
73+
"City",
74+
{
75+
"name": city["name"],
76+
"weather": city["weather"],
77+
"population": city["population"],
78+
},
9379
)
94-
cls.ontology.add_relation(
95-
Relation(
96-
label="IN_COUNTRY",
97-
source="City",
98-
target="Country",
99-
)
80+
kg.add_node("Country", {"name": city["country"]})
81+
kg.add_edge(
82+
"IN_COUNTRY",
83+
"City",
84+
"Country",
85+
{"name": city["name"]},
86+
{"name": city["country"]},
10087
)
101-
cls.ontology.add_relation(
102-
Relation(
103-
label="IN_CITY",
104-
source="Restaurant",
105-
target="City",
106-
)
88+
89+
for restaurant in restaurants:
90+
kg.add_node(
91+
"Restaurant",
92+
{
93+
"name": restaurant["name"],
94+
"description": restaurant["description"],
95+
"rating": restaurant["rating"],
96+
"food_type": restaurant["food_type"],
97+
},
10798
)
108-
cls.model = GeminiGenerativeModel("gemini-1.5-flash-001")
109-
cls.kg = KnowledgeGraph(
110-
name="test_ontology",
111-
ontology=cls.ontology,
112-
model_config=KnowledgeGraphModelConfig.with_model(cls.model),
99+
kg.add_edge(
100+
"IN_CITY",
101+
"Restaurant",
102+
"City",
103+
{"name": restaurant["name"]},
104+
{"name": restaurant["city"]},
113105
)
114-
cls.import_data(cls.kg)
115-
116-
@classmethod
117-
def import_data(
118-
self,
119-
kg: KnowledgeGraph,
120-
):
121-
with open("tests/data/cities.json") as f:
122-
cities = loads(f.read())
123-
with open("tests/data/restaurants.json") as f:
124-
restaurants = loads(f.read())
125-
126-
for city in cities:
127-
kg.add_node(
128-
"City",
129-
{
130-
"name": city["name"],
131-
"weather": city["weather"],
132-
"population": city["population"],
133-
},
134-
)
135-
kg.add_node("Country", {"name": city["country"]})
136-
kg.add_edge(
137-
"IN_COUNTRY",
138-
"City",
139-
"Country",
140-
{"name": city["name"]},
141-
{"name": city["country"]},
142-
)
143-
144-
for restaurant in restaurants:
145-
kg.add_node(
146-
"Restaurant",
147-
{
148-
"name": restaurant["name"],
149-
"description": restaurant["description"],
150-
"rating": restaurant["rating"],
151-
"food_type": restaurant["food_type"],
152-
},
153-
)
154-
kg.add_edge(
155-
"IN_CITY",
156-
"Restaurant",
157-
"City",
158-
{"name": restaurant["name"]},
159-
{"name": restaurant["city"]},
160-
)
161-
162-
# Delete graph after tests
163-
@classmethod
164-
def tearDownClass(cls):
106+
107+
return ontology, kg
108+
109+
110+
@pytest.fixture
111+
def delete_kg():
112+
"""
113+
Returns a function that deletes a given knowledge graph.
114+
"""
115+
def cleanup(kg):
165116
logger.info("Cleaning up test graph...")
166-
cls.kg.delete()
117+
kg.delete()
118+
119+
return cleanup
120+
167121

168-
def test_ontology_serialization(self):
122+
class TestOntologyFromKG:
123+
def test_ontology_serialization(self, ontology_kg_setup, delete_kg):
124+
"""
125+
Tests serializing the Ontology from the knowledge graph.
126+
"""
127+
ontology, kg = ontology_kg_setup
169128
logger.info("Testing ontology serialization...")
129+
170130
db = FalkorDB()
171131
graph = db.select_graph("test_ontology")
172-
ontology = Ontology.from_kg_graph(graph=graph)
173-
self.assertEqual(ontology.to_json(), self.ontology.to_json())
132+
loaded_ontology = Ontology.from_kg_graph(graph=graph)
133+
134+
assert loaded_ontology.to_json() == ontology.to_json()
135+
136+
# Now clean up the KG by calling the function from delete_kg fixture
137+
delete_kg(kg)

0 commit comments

Comments
 (0)