Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ jobs:
with:
python-version: 3.x # Update with desired Python version

- uses: datadog/[email protected]
with:
check_mode: "true"

- uses: astral-sh/ruff-action@v3
with:
src: >-
"./graphrag_sdk"

- name: Cache Poetry virtualenv
id: cache
uses: actions/cache@v4
Expand Down
2 changes: 1 addition & 1 deletion graphrag_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
"Relation",
"Attribute",
"AttributeType",
]
]
2 changes: 1 addition & 1 deletion graphrag_sdk/agents/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .agent import Agent

__all__ = ['Agent']
__all__ = ["Agent"]
1 change: 0 additions & 1 deletion graphrag_sdk/agents/agent.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import ABC, abstractmethod
from graphrag_sdk.models.model import GenerativeModelChatSession


class AgentResponseCode:
Expand Down
3 changes: 1 addition & 2 deletions graphrag_sdk/agents/kg_agent.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from graphrag_sdk.kg import KnowledgeGraph
from .agent import Agent
from graphrag_sdk.models import GenerativeModelChatSession


class KGAgent(Agent):
Expand Down Expand Up @@ -136,7 +135,7 @@ def run(self, params: dict) -> str:

"""
output = self.chat_session.send_message(params["prompt"])
return output['response']
return output["response"]

def __repr__(self):
"""
Expand Down
25 changes: 12 additions & 13 deletions graphrag_sdk/attribute.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import json
from graphrag_sdk.fixtures.regex import *
import logging
import re

Expand Down Expand Up @@ -42,18 +41,18 @@ def from_string(txt: str):


class Attribute:
""" Represents an attribute of an entity or relation in the ontology.

Args:
name (str): The name of the attribute.
attr_type (AttributeType): The type of the attribute.
unique (bool): Whether the attribute is unique.
required (bool): Whether the attribute is required.

Examples:
>>> attr = Attribute("name", AttributeType.STRING, True, True)
>>> print(attr)
name: "string!*"
"""Represents an attribute of an entity or relation in the ontology.

Args:
name (str): The name of the attribute.
attr_type (AttributeType): The type of the attribute.
unique (bool): Whether the attribute is unique.
required (bool): Whether the attribute is required.

Examples:
>>> attr = Attribute("name", AttributeType.STRING, True, True)
>>> print(attr)
name: "string!*"
"""

def __init__(
Expand Down
49 changes: 29 additions & 20 deletions graphrag_sdk/chat_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,17 @@ class ChatSession:
>>> chat_session.send_message("What is the capital of France?")
"""

def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology, graph: Graph,
cypher_system_instruction: str, qa_system_instruction: str,
cypher_gen_prompt: str, qa_prompt: str, cypher_gen_prompt_history: str):
def __init__(
self,
model_config: KnowledgeGraphModelConfig,
ontology: Ontology,
graph: Graph,
cypher_system_instruction: str,
qa_system_instruction: str,
cypher_gen_prompt: str,
qa_prompt: str,
cypher_gen_prompt_history: str,
):
"""
Initializes a new ChatSession object.

Expand All @@ -45,21 +53,22 @@ def __init__(self, model_config: KnowledgeGraphModelConfig, ontology: Ontology,
self.model_config = model_config
self.graph = graph
self.ontology = ontology
cypher_system_instruction = cypher_system_instruction.format(ontology=str(ontology.to_json()))
cypher_system_instruction = cypher_system_instruction.format(
ontology=str(ontology.to_json())
)


self.cypher_prompt = cypher_gen_prompt
self.qa_prompt = qa_prompt
self.cypher_prompt_with_history = cypher_gen_prompt_history

self.cypher_chat_session = (
model_config.cypher_generation.with_system_instruction(
cypher_system_instruction
).start_chat()
)
self.qa_chat_session = model_config.qa.with_system_instruction(
qa_system_instruction
).start_chat()
qa_system_instruction
).start_chat()
self.last_answer = None

def send_message(self, message: str):
Expand All @@ -71,9 +80,9 @@ def send_message(self, message: str):

Returns:
dict: The response to the message in the following format:
{"question": message,
"response": answer,
"context": context,
{"question": message,
"response": answer,
"context": context,
"cypher": cypher}
"""
cypher_step = GraphQueryGenerationStep(
Expand All @@ -82,7 +91,7 @@ def send_message(self, message: str):
ontology=self.ontology,
last_answer=self.last_answer,
cypher_prompt=self.cypher_prompt,
cypher_prompt_with_history=self.cypher_prompt_with_history
cypher_prompt_with_history=self.cypher_prompt_with_history,
)

(context, cypher) = cypher_step.run(message)
Expand All @@ -92,8 +101,8 @@ def send_message(self, message: str):
"question": message,
"response": "I am sorry, I could not find the answer to your question",
"context": None,
"cypher": None
}
"cypher": None,
}

qa_step = QAStep(
chat_session=self.qa_chat_session,
Expand All @@ -102,10 +111,10 @@ def send_message(self, message: str):

answer = qa_step.run(message, cypher, context)
self.last_answer = answer

return {
"question": message,
"response": answer,
"context": context,
"cypher": cypher
}
"question": message,
"response": answer,
"context": context,
"cypher": cypher,
}
2 changes: 1 addition & 1 deletion graphrag_sdk/document.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
class Document():
class Document:
"""
Common class containing text extracted from a source
"""
Expand Down
6 changes: 1 addition & 5 deletions graphrag_sdk/document_loaders/jsonl.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,6 @@ def load(self) -> Iterator[Document]:
num_documents = num_rows // self.rows_per_document
for i in range(num_documents):
content = "\n".join(
rows[
i
* self.rows_per_document : (i + 1)
* self.rows_per_document
]
rows[i * self.rows_per_document : (i + 1) * self.rows_per_document]
)
yield Document(content)
20 changes: 9 additions & 11 deletions graphrag_sdk/document_loaders/pdf.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Iterator
from graphrag_sdk.document import Document

class PDFLoader():

class PDFLoader:
"""
Load PDF
"""
Expand All @@ -15,11 +16,11 @@ def __init__(self, path: str) -> None:
"""

try:
import pypdf
except ImportError:
raise ImportError(
__import__("pypdf")
except ModuleNotFoundError:
raise ModuleNotFoundError(
"pypdf package not found, please install it with " "`pip install pypdf`"
)
)
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Improve exception chaining in ModuleNotFoundError handling

Use exception chaining to preserve the original error context. This helps with debugging by maintaining the full error traceback.

         try:
             __import__("pypdf")
-        except ModuleNotFoundError:
+        except ModuleNotFoundError as err:
             raise ModuleNotFoundError(
                 "pypdf package not found, please install it with " "`pip install pypdf`"
-            )
+            ) from err
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
__import__("pypdf")
except ModuleNotFoundError:
raise ModuleNotFoundError(
"pypdf package not found, please install it with " "`pip install pypdf`"
)
)
__import__("pypdf")
except ModuleNotFoundError as err:
raise ModuleNotFoundError(
"pypdf package not found, please install it with " "`pip install pypdf`"
) from err
🧰 Tools
🪛 Ruff (0.8.2)

21-23: Within an except clause, raise exceptions with raise ... from err or raise ... from None to distinguish them from errors in exception handling

(B904)


self.path = path

Expand All @@ -30,11 +31,8 @@ def load(self) -> Iterator[Document]:
Returns:
Iterator[Document]: document iterator
"""
from pypdf import PdfReader # pylint: disable=import-outside-toplevel

from pypdf import PdfReader # pylint: disable=import-outside-toplevel

reader = PdfReader(self.path)
yield from [
Document(page.extract_text())
for page in reader.pages
]
yield from [Document(page.extract_text()) for page in reader.pages]
9 changes: 4 additions & 5 deletions graphrag_sdk/document_loaders/text.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Iterator
from graphrag_sdk.document import Document

class TextLoader():

class TextLoader:
"""
Load Text
"""
Expand All @@ -24,7 +25,5 @@ def load(self) -> Iterator[Document]:
Iterator[Document]: document iterator
"""

with open(self.path, 'r') as f:
yield Document(
f.read()
)
with open(self.path, "r") as f:
yield Document(f.read())
11 changes: 6 additions & 5 deletions graphrag_sdk/document_loaders/url.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from bs4 import BeautifulSoup
from graphrag_sdk.document import Document

class URLLoader():

class URLLoader:
"""
Load URL
"""
Expand All @@ -21,7 +22,7 @@ def __init__(self, url: str) -> None:

def _download(self) -> str:
try:
response = requests.get(self.url, headers={'User-Agent': 'Mozilla/5.0'})
response = requests.get(self.url, headers={"User-Agent": "Mozilla/5.0"})
response.raise_for_status() # Raise an HTTPError for bad responses (4xx and 5xx)
return response.text
except requests.exceptions.RequestException as e:
Expand All @@ -39,13 +40,13 @@ def load(self) -> Iterator[Document]:
content = self._download()

# extract text from HTML, populate content
soup = BeautifulSoup(content, 'html.parser')
soup = BeautifulSoup(content, "html.parser")

# Extract text from the HTML
content = soup.get_text()

# Remove extra newlines
content = re.sub(r'\n{2,}', '\n', content)
content = re.sub(r"\n{2,}", "\n", content)

yield Document(content)
#return f"{self.source}\n{self.content}"
# return f"{self.source}\n{self.content}"
2 changes: 1 addition & 1 deletion graphrag_sdk/fixtures/regex.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@

NODE_LABEL_REGEX = r"\(.+:(.*?)\)"

NODE_REGEX = r"\(.*?\)"
NODE_REGEX = r"\(.*?\)"
28 changes: 12 additions & 16 deletions graphrag_sdk/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,12 @@ def stringify_falkordb_response(response):
elif not isinstance(response[0], list):
data = str(response).strip()
else:
for l, _ in enumerate(response):
if not isinstance(response[l], list):
response[l] = str(response[l])
for line, _ in enumerate(response):
if not isinstance(response[line], list):
response[line] = str(response[line])
else:
for i, __ in enumerate(response[l]):
response[l][i] = str(response[l][i])
for i, __ in enumerate(response[line]):
response[line][i] = str(response[line][i])
data = str(response).strip()

return data
Expand All @@ -77,9 +77,7 @@ def extract_cypher(text: str):
return "".join(matches)


def validate_cypher(
cypher: str, ontology: graphrag_sdk.Ontology
) -> list[str] | None:
def validate_cypher(cypher: str, ontology: graphrag_sdk.Ontology) -> list[str] | None:
try:
if not cypher or len(cypher) == 0:
return ["Cypher statement is empty"]
Expand Down Expand Up @@ -126,10 +124,10 @@ def validate_cypher_relations_exist(cypher: str, ontology: graphrag_sdk.Ontology
for relation in relation_labels:
for label in relation.split("|"):
max_idx = min(
label.index("*") if "*" in label else len(label),
label.index("{") if "{" in label else len(label),
label.index("]") if "]" in label else len(label),
)
label.index("*") if "*" in label else len(label),
label.index("{") if "{" in label else len(label),
label.index("]") if "]" in label else len(label),
)
label = label[:max_idx]
if label not in [relation.label for relation in ontology.relations]:
not_found_relation_labels.append(label)
Expand All @@ -139,9 +137,7 @@ def validate_cypher_relations_exist(cypher: str, ontology: graphrag_sdk.Ontology
]


def validate_cypher_relation_directions(
cypher: str, ontology: graphrag_sdk.Ontology
):
def validate_cypher_relation_directions(cypher: str, ontology: graphrag_sdk.Ontology):

errors = []
relations = list(re.finditer(r"\[.*?\]", cypher))
Expand Down Expand Up @@ -215,4 +211,4 @@ def validate_cypher_relation_directions(
except Exception:
continue

return errors
return errors
Loading
Loading