Skip to content

Commit a17975a

Browse files
committed
update-attributes-convention
1 parent 44844e5 commit a17975a

File tree

2 files changed

+61
-27
lines changed

2 files changed

+61
-27
lines changed

graphrag_sdk/attribute.py

Lines changed: 54 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1+
import re
12
import json
2-
from graphrag_sdk.fixtures.regex import *
33
import logging
4-
import re
4+
from graphrag_sdk.fixtures.regex import *
55

66
logger = logging.getLogger(__name__)
77

@@ -15,6 +15,28 @@ class AttributeType:
1515
NUMBER = "number"
1616
BOOLEAN = "boolean"
1717
LIST = "list"
18+
POINT = "point"
19+
MAP = "map"
20+
VECTOR = "vectorf32"
21+
DATE = "date"
22+
DATE_TIME = "datetime"
23+
TIME = "time"
24+
DURATION = "duration"
25+
26+
# Synonyms for attribute types
27+
_SYNONYMS = {
28+
STRING: {"string"},
29+
NUMBER: {"integer", "float"},
30+
BOOLEAN: {"boolean"},
31+
LIST: {"list"},
32+
POINT: {"point"},
33+
MAP: {"map"},
34+
VECTOR: {"vectorf32"},
35+
DATE: {"date"},
36+
DATE_TIME: {"datetime", "local datetime"},
37+
TIME: {"time", "local time"},
38+
DURATION: {"duration"},
39+
}
1840

1941
@staticmethod
2042
def from_string(txt: str):
@@ -28,18 +50,17 @@ def from_string(txt: str):
2850
AttributeType: The corresponding AttributeType value.
2951
3052
Raises:
31-
Exception: If the provided attribute type is invalid.
53+
ValueError: If the provided attribute type is invalid.
3254
"""
33-
if txt.lower() == AttributeType.STRING:
34-
return AttributeType.STRING
35-
if txt.lower() == AttributeType.NUMBER:
36-
return AttributeType.NUMBER
37-
if txt.lower() == AttributeType.BOOLEAN:
38-
return AttributeType.BOOLEAN
39-
if txt.lower() == AttributeType.LIST:
40-
return AttributeType.LIST
41-
raise Exception(f"Invalid attribute type: {txt}")
42-
55+
# Graph representation of the attribute type
56+
normalized_txt = txt.lower()
57+
58+
# Find the matching attribute type
59+
for attr_type, synonyms in AttributeType._SYNONYMS.items():
60+
if normalized_txt in synonyms:
61+
return attr_type
62+
63+
raise ValueError(f"Invalid attribute type: {txt}")
4364

4465
class Attribute:
4566
""" Represents an attribute of an entity or relation in the ontology.
@@ -120,14 +141,6 @@ def from_string(txt: str):
120141
unique = "!" in txt
121142
required = "*" in txt
122143

123-
if attr_type not in [
124-
AttributeType.STRING,
125-
AttributeType.NUMBER,
126-
AttributeType.BOOLEAN,
127-
AttributeType.LIST,
128-
]:
129-
raise Exception(f"Invalid attribute type: {attr_type}")
130-
131144
return Attribute(name, AttributeType.from_string(attr_type), unique, required)
132145

133146
def to_json(self):
@@ -161,3 +174,23 @@ def __str__(self) -> str:
161174
str: A string representation of the Attribute object.
162175
"""
163176
return f"{self.name}: \"{self.type}{'!' if self.unique else ''}{'*' if self.required else ''}\""
177+
178+
def process_attributes_from_graph(attributes: list[list[str]]) -> list[Attribute]:
179+
"""
180+
Processes the attributes extracted from the graph and converts them into the SDK convention.
181+
182+
Args:
183+
attributes (list[list[str]]): The attributes extracted from the graph.
184+
185+
Returns:
186+
processed_attributes (list[Attribute]): The processed attributes.
187+
"""
188+
processed_attributes = []
189+
for attr in attributes:
190+
try:
191+
type = AttributeType.from_string(attr[0][1])
192+
processed_attributes.append(Attribute(attr[0][0],type))
193+
except:
194+
continue
195+
196+
return processed_attributes

graphrag_sdk/ontology.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from falkordb import Graph
66
from typing import Optional
77
from .relation import Relation
8-
from .attribute import Attribute
8+
from .attribute import process_attributes_from_graph
99
from graphrag_sdk.source import AbstractSource
1010
from graphrag_sdk.models import GenerativeModel
1111

@@ -130,17 +130,18 @@ def from_kg_graph(graph: Graph, sample_size: int = 100,):
130130
for lbls in n_labels:
131131
l = lbls[0]
132132
attributes = graph.query(
133-
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
133+
f"""MATCH (a:{l}) call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
134134
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
135-
ontology.add_entity(Entity(l, [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]))
135+
attributes = process_attributes_from_graph(attributes)
136+
ontology.add_entity(Entity(l, attributes))
136137

137138
# Extract attributes for each edge type, limited by the specified sample size.
138139
for e_type in e_types:
139140
e_t = e_type[0]
140141
attributes = graph.query(
141-
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
142+
f"""MATCH ()-[a:{e_t}]->() call {{ with a return [k in keys(a) | [k, typeof(a[k])]] as types }}
142143
WITH types limit {sample_size} unwind types as kt RETURN kt, count(1) ORDER BY kt[0]""").result_set
143-
attributes = [Attribute(attr[0][0], 'number' if attr[0][1] == 'Integer' or attr[0][1] == 'Float' else attr[0][1].lower()) for attr in attributes]
144+
attributes = process_attributes_from_graph(attributes)
144145
for s_lbls in n_labels:
145146
for t_lbls in n_labels:
146147
s_l = s_lbls[0]
@@ -380,4 +381,4 @@ def save_to_graph(self, graph: Graph):
380381
for relation in self.relations:
381382
query = relation.to_graph_query()
382383
logger.debug(f"Query: {query}")
383-
graph.query(query)
384+
graph.query(query)

0 commit comments

Comments
 (0)