Skip to content

Commit 0ba8ee2

Browse files
authored
feat: reuse doc/chunk/emb results, augment chunks, web service (#6)
* feat: chunk augmentation Signed-off-by: Keming <[email protected]> * fix linter Signed-off-by: Keming <[email protected]> * align naming Signed-off-by: Keming <[email protected]> * feat: allow reuse the doc/chunk/emb Signed-off-by: Keming <[email protected]> * add web, fix augment Signed-off-by: Keming <[email protected]> * add examples Signed-off-by: Keming <[email protected]> * fix lint Signed-off-by: Keming <[email protected]> --------- Signed-off-by: Keming <[email protected]>
1 parent e024e6e commit 0ba8ee2

19 files changed

+1102
-270
lines changed

Makefile

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,3 +18,6 @@ publish: build
1818

1919
test:
2020
@uv run pytest -v tests
21+
22+
sync:
23+
@uv sync --all-extras

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,14 @@ timeline
2727
: Filter
2828
Rerank: ColBERT
2929
```
30+
31+
## Development
32+
33+
```bash
34+
docker run --rm -d -e POSTGRES_PASSWORD=postgres -p 5432:5432 tensorchord/vchord-postgres:pg17-v0.2.0
35+
envd up
36+
# inside the envd env, sync all the dependencies
37+
make sync
38+
# format the code
39+
make format
40+
```

test.py renamed to examples/basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
LocalLoader,
55
Pipeline,
66
SimpleExtractor,
7-
SpacyEmbedding,
8-
SpacySegmenter,
7+
SpacyChunker,
8+
SpacyDenseEmbedding,
99
VectorChordClient,
1010
)
1111

@@ -16,8 +16,8 @@
1616
),
1717
loader=LocalLoader("data", include=[".pdf"]),
1818
extractor=SimpleExtractor(),
19-
segmenter=SpacySegmenter(),
20-
emb=SpacyEmbedding(),
19+
chunker=SpacyChunker(),
20+
emb=SpacyDenseEmbedding(),
2121
)
2222
pipe.run()
2323

examples/gemini.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from rich import print
2+
3+
from vechord import (
4+
GeminiAugmenter,
5+
GeminiDenseEmbedding,
6+
GeminiExtractor,
7+
LocalLoader,
8+
Pipeline,
9+
VectorChordClient,
10+
WordLlamaChunker,
11+
)
12+
13+
if __name__ == "__main__":
14+
pipe = Pipeline(
15+
client=VectorChordClient(
16+
"local_pdf", "postgresql://postgres:[email protected]:5432/"
17+
),
18+
loader=LocalLoader("data", include=[".pdf"]),
19+
extractor=GeminiExtractor(),
20+
chunker=WordLlamaChunker(),
21+
emb=GeminiDenseEmbedding(),
22+
augmenter=GeminiAugmenter(),
23+
)
24+
pipe.run()
25+
26+
print(pipe.query("vector search"))

pyproject.toml

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,15 @@ description = "VectorChord Python SDK"
55
readme = "README.md"
66
requires-python = ">=3.9"
77
dependencies = [
8-
"en-core-web-sm",
8+
"falcon>=4.0.2",
99
"httpx>=0.28.1",
1010
"msgspec>=0.19.0",
1111
"numpy>=2.0.2",
12-
"openai>=1.59.7",
1312
"pgvector>=0.3.6",
1413
"pillow>=11.1.0",
1514
"psycopg[binary]>=3.2.3",
1615
"pypdfium2>=4.30.1",
1716
"rich>=13.9.4",
18-
"spacy>=3.8.4",
19-
"trio>=0.28.0",
2017
]
2118

2219
[project.scripts]
@@ -26,6 +23,16 @@ vechord = "vechord.main:main"
2623
gemini = [
2724
"google-generativeai>=0.8.4",
2825
]
26+
openai = [
27+
"openai>=1.60.2",
28+
]
29+
spacy = [
30+
"en-core-web-sm",
31+
"spacy>=3.8.4",
32+
]
33+
wordllama = [
34+
"wordllama>=0.3.8.post20",
35+
]
2936

3037
[build-system]
3138
requires = ["pdm-backend"]
@@ -45,7 +52,7 @@ ignore = ["E501"]
4552
[tool.ruff.lint.isort]
4653
known-first-party = ["vechord"]
4754
[tool.ruff.lint.pylint]
48-
max-args = 7
55+
max-args = 5
4956

5057
[tool.pdm]
5158
distribution = true

uv.lock

Lines changed: 365 additions & 174 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

vechord/__init__.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,29 @@
1+
from vechord.augment import GeminiAugmenter
2+
from vechord.chunk import RegexChunker, SpacyChunker, WordLlamaChunker
13
from vechord.client import VectorChordClient
2-
from vechord.embedding import GeminiEmbedding, OpenAIEmbedding, SpacyEmbedding
4+
from vechord.embedding import (
5+
GeminiDenseEmbedding,
6+
OpenAIDenseEmbedding,
7+
SpacyDenseEmbedding,
8+
)
39
from vechord.extract import GeminiExtractor, SimpleExtractor
410
from vechord.load import LocalLoader
511
from vechord.model import Chunk, Document
612
from vechord.pipeline import Pipeline
7-
from vechord.segment import RegexSegmenter, SpacySegmenter
813

914
__all__ = [
1015
"Chunk",
1116
"Document",
12-
"GeminiEmbedding",
17+
"GeminiAugmenter",
18+
"GeminiDenseEmbedding",
1319
"GeminiExtractor",
1420
"LocalLoader",
15-
"OpenAIEmbedding",
21+
"OpenAIDenseEmbedding",
1622
"Pipeline",
17-
"RegexSegmenter",
23+
"RegexChunker",
1824
"SimpleExtractor",
19-
"SpacyEmbedding",
20-
"SpacySegmenter",
25+
"SpacyChunker",
26+
"SpacyDenseEmbedding",
2127
"VectorChordClient",
28+
"WordLlamaChunker",
2229
]

vechord/augment.py

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import os
2+
from abc import ABC, abstractmethod
3+
from datetime import timedelta
4+
5+
from vechord.log import logger
6+
7+
8+
class BaseAugmenter(ABC):
9+
@abstractmethod
10+
def reset(self, doc: str):
11+
"""Cache the document for augmentation."""
12+
raise NotImplementedError
13+
14+
@abstractmethod
15+
def name(self) -> str:
16+
raise NotImplementedError
17+
18+
@abstractmethod
19+
def augment_context(self, chunks: list[str]) -> list[str]:
20+
raise NotImplementedError
21+
22+
@abstractmethod
23+
def augment_query(self, chunks: list[str]) -> list[str]:
24+
raise NotImplementedError
25+
26+
@abstractmethod
27+
def summarize_doc(self) -> str:
28+
raise NotImplementedError
29+
30+
31+
class GeminiAugmenter(BaseAugmenter):
32+
def __init__(self, model: str = "models/gemini-1.5-flash-001", ttl_sec: int = 600):
33+
"""Gemini Augmenter with cache.
34+
35+
Minimal cache token is 32768.
36+
"""
37+
key = os.environ.get("GEMINI_API_KEY")
38+
if not key:
39+
raise ValueError("env GEMINI_API_KEY not set")
40+
41+
self.model_name = model
42+
self.ttl_sec = ttl_sec
43+
self.min_token = 32768
44+
45+
def name(self) -> str:
46+
return f"gemini_augment_{self.model_name}"
47+
48+
def reset(self, doc: str):
49+
import google.generativeai as genai
50+
51+
self.client = genai.GenerativeModel(model_name=self.model_name)
52+
tokens = self.client.count_tokens(doc).total_tokens
53+
self.doc = "" # empty means doc is in the cache
54+
if tokens <= self.min_token:
55+
# cannot use cache due to the Gemini token limit
56+
self.doc = doc
57+
else:
58+
logger.debug("use cache since the doc has %d tokens", tokens)
59+
cache = genai.caching.CachedContent.create(
60+
model=self.model_name,
61+
system_instruction=(
62+
"You are an expert on the natural language understanding. "
63+
"Answer the questions based on the whole document you have access to."
64+
),
65+
contents=doc,
66+
ttl=timedelta(seconds=self.ttl_sec),
67+
)
68+
self.client = genai.GenerativeModel.from_cached_content(
69+
cached_content=cache
70+
)
71+
72+
def augment(self, chunks: list[str], prompt: str) -> list[str]:
73+
res = []
74+
try:
75+
for chunk in chunks:
76+
context = prompt.format(chunk=chunk)
77+
if self.doc:
78+
context = f"<document>{self.doc}</document>\n" + context
79+
response = self.client.generate_content([context])
80+
res.append(response.text)
81+
except Exception as e:
82+
logger.error("GeminiAugmenter error: %s", e)
83+
breakpoint()
84+
return res
85+
86+
def augment_context(self, chunks: list[str]) -> list[str]:
87+
prompt = (
88+
"Here is the chunk we want to situate within the whole document "
89+
"<chunk>{chunk}</chunk>"
90+
"Please give a short succinct context to situate this chunk within "
91+
"the overall document for the purposes of improving search retrieval "
92+
"of the chunk. Answer only with the succinct context and nothing else."
93+
)
94+
return self.augment(chunks, prompt)
95+
96+
def augment_query(self, chunks: list[str]) -> list[str]:
97+
prompt = (
98+
"Here is the chunk we want to ask questions about "
99+
"<chunk>{chunk}</chunk>"
100+
"Please ask questions about this chunk based on the overall document "
101+
"for the purposes of improving search retrieval of the chunk. "
102+
"Answer only with the question and nothing else."
103+
)
104+
return self.augment(chunks, prompt)
105+
106+
def summarize_doc(self) -> str:
107+
prompt = (
108+
"Summarize the provided document concisely while preserving its key "
109+
"ideas, main arguments, and essential details. Ensure clarity and "
110+
"coherence, avoiding unnecessary repetition."
111+
)
112+
if self.doc:
113+
prompt = f"<document>{self.doc}</document>\n" + prompt
114+
response = self.client.generate_content([prompt])
115+
return response.text

vechord/segment.py renamed to vechord/chunk.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22
from abc import ABC, abstractmethod
33

44

5-
class BaseSegmenter(ABC):
5+
class BaseChunker(ABC):
66
@abstractmethod
77
def segment(self, text: str) -> list[str]:
88
raise NotImplementedError
99

10+
@abstractmethod
11+
def name(self) -> str:
12+
raise NotImplementedError
13+
1014

11-
class RegexSegmenter(BaseSegmenter):
15+
class RegexChunker(BaseChunker):
1216
def __init__(
1317
self,
14-
size: int = 1000,
18+
size: int = 1536,
1519
overlap: int = 200,
1620
separator: str = r"\s{2,}",
1721
concat: str = ". ",
@@ -21,6 +25,9 @@ def __init__(
2125
self.separator = re.compile(separator)
2226
self.concatenator = concat
2327

28+
def name(self) -> str:
29+
return f"regex_chunk_{self.size}_{self.overlap}"
30+
2431
def keep_overlap(self, pieces: list[str]) -> list[str]:
2532
length = 0
2633
i = len(pieces) - 1
@@ -69,11 +76,31 @@ def segment(self, text: str) -> list[str]:
6976
return [*chunks, remaining] if remaining else chunks
7077

7178

72-
class SpacySegmenter(BaseSegmenter):
73-
def __init__(self):
79+
class SpacyChunker(BaseChunker):
80+
def __init__(self, model: str = "en_core_web_sm"):
81+
"""A semantic sentence Chunker based on SpaCy."""
7482
import spacy
7583

76-
self.nlp = spacy.load("en_core_web_sm", enable=["parser", "tok2vec"])
84+
self.model = model
85+
self.nlp = spacy.load(model, enable=["parser", "tok2vec"])
86+
87+
def name(self) -> str:
88+
return f"spacy_chunk_{self.model}"
7789

7890
def segment(self, text: str) -> list[str]:
7991
return [sent.text for sent in self.nlp(text).sents]
92+
93+
94+
class WordLlamaChunker(BaseChunker):
95+
def __init__(self, size: int = 1536):
96+
"""A semantic chunker based on WordLlama."""
97+
from wordllama import WordLlama
98+
99+
self.model = WordLlama.load()
100+
self.size = size
101+
102+
def name(self) -> str:
103+
return f"wordllama_chunk_{self.size}"
104+
105+
def segment(self, text: str) -> list[str]:
106+
return self.model.split(text, target_size=self.size)

0 commit comments

Comments
 (0)