-
Notifications
You must be signed in to change notification settings - Fork 0
[feat] create vector embeddings #22
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
6112985
27f5fe0
d5cdedb
42d4529
7512db3
2b98880
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -34,3 +34,6 @@ yarn-error.log* | |
# typescript | ||
*.tsbuildinfo | ||
next-env.d.ts | ||
|
||
# python | ||
/adopt-an-inmate-venv | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import os | ||
from dotenv import load_dotenv | ||
import supabase | ||
import vecs | ||
from sentence_transformers import SentenceTransformer | ||
from config import MODEL_NAME, MODEL_DIMENSION, VECS_COLLECTION_NAME, SUPABASE_TABLE_NAME | ||
|
||
load_dotenv(os.path.join(os.path.dirname(__file__), "../../.env.local")) | ||
|
||
# Initialize model | ||
model = SentenceTransformer(MODEL_NAME) | ||
|
||
# Initialize Supabase | ||
SUPABASE_URL = os.getenv("NEXT_PUBLIC_SUPABASE_URL") | ||
SUPABASE_ANON_KEY = os.getenv("NEXT_PUBLIC_SUPABASE_ANON_KEY") | ||
supabase_client = supabase.create_client(SUPABASE_URL, SUPABASE_ANON_KEY) | ||
adoptee_table = supabase_client.table(SUPABASE_TABLE_NAME).select("*").execute().data | ||
|
||
# Initialize Vecs | ||
DB_CONNECTION = os.getenv("DATABASE_URL") | ||
vx = vecs.create_client(DB_CONNECTION) | ||
adoptee_vector_collection = vx.get_or_create_collection( | ||
name=VECS_COLLECTION_NAME, | ||
dimension=MODEL_DIMENSION | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
# Model configuration | ||
MODEL_NAME = "paraphrase-MiniLM-L3-v2" | ||
|
||
dimensions = {"paraphrase-MiniLM-L3-v2": 384} | ||
MODEL_DIMENSION = dimensions[MODEL_NAME] | ||
|
||
# Supabase configuration | ||
SUPABASE_TABLE_NAME = "adoptee" | ||
|
||
# Collection configuration | ||
VECS_COLLECTION_NAME = "adoptee_vector" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from tqdm import tqdm | ||
from clients import model, vx, adoptee_vector_collection, adoptee_table | ||
|
||
def upsert_data(model, database_table, vector_collection, batch_size=64): | ||
""" | ||
Encodes and upserts data to a vector database in batches. | ||
|
||
Args: | ||
model: The embedding model. | ||
database_table (list): A list of dictionaries containing the data. | ||
vector_collection: The vector collection to which to upsert records. | ||
batch_size (int): The number of records to process per batch. | ||
""" | ||
|
||
for i in tqdm(range(0, len(database_table), batch_size)): | ||
batch = database_table[i:i + batch_size] | ||
|
||
ids = [row['id'] for row in batch] | ||
bios = [row['bio'] for row in batch] | ||
|
||
embeddings = model.encode(bios, show_progress_bar=False).tolist() | ||
|
||
records = [] | ||
|
||
for j, row in enumerate(batch): | ||
metadata = { | ||
"bio": row["bio"], | ||
"gender": row["gender"], | ||
"age": row["age"], | ||
"veteran_status": row["veteran_status"], | ||
"offense": row["offense"], | ||
"state": row["state"] | ||
} | ||
|
||
records.append(((ids[j], embeddings[j], metadata))) | ||
|
||
try: | ||
vector_collection.upsert(records) | ||
print(f"Successfully upserted batch starting at index {i}") | ||
except Exception as e: | ||
print(f"Upsert failed for batch starting at index {i}: {e}") | ||
|
||
if __name__ == "__main__": | ||
upsert_data(model, adoptee_table, adoptee_vector_collection) | ||
vx.disconnect() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,6 +22,7 @@ const eslintConfig = [ | |
'build', | ||
'.vscode', | ||
'next-env.d.ts', | ||
'adopt-an-inmate-venv', | ||
], | ||
}, | ||
{ | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
sentence-transformers | ||
supabase | ||
vecs | ||
python-dotenv | ||
tqdm |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
# | ||
# This file is autogenerated by pip-compile with Python 3.13 | ||
# by the following command: | ||
# | ||
# pip-compile --output-file=requirements.txt requirements.in | ||
# | ||
annotated-types==0.7.0 | ||
# via pydantic | ||
anyio==4.11.0 | ||
# via httpx | ||
certifi==2025.10.5 | ||
# via | ||
# httpcore | ||
# httpx | ||
# requests | ||
cffi==2.0.0 | ||
# via cryptography | ||
charset-normalizer==3.4.3 | ||
# via requests | ||
cryptography==46.0.2 | ||
# via pyjwt | ||
deprecated==1.2.18 | ||
# via vecs | ||
deprecation==2.1.0 | ||
# via | ||
# postgrest | ||
# storage3 | ||
filelock==3.20.0 | ||
# via | ||
# huggingface-hub | ||
# torch | ||
# transformers | ||
flupy==1.2.3 | ||
# via vecs | ||
fsspec==2025.9.0 | ||
# via | ||
# huggingface-hub | ||
# torch | ||
h11==0.16.0 | ||
# via httpcore | ||
h2==4.3.0 | ||
# via httpx | ||
hf-xet==1.1.10 | ||
# via huggingface-hub | ||
hpack==4.1.0 | ||
# via h2 | ||
httpcore==1.0.9 | ||
# via httpx | ||
httpx[http2]==0.28.1 | ||
# via | ||
# postgrest | ||
# storage3 | ||
# supabase | ||
# supabase-auth | ||
# supabase-functions | ||
huggingface-hub==0.35.3 | ||
# via | ||
# sentence-transformers | ||
# tokenizers | ||
# transformers | ||
hyperframe==6.1.0 | ||
# via h2 | ||
idna==3.10 | ||
# via | ||
# anyio | ||
# httpx | ||
# requests | ||
# yarl | ||
jinja2==3.1.6 | ||
# via torch | ||
joblib==1.5.2 | ||
# via scikit-learn | ||
markupsafe==3.0.3 | ||
# via jinja2 | ||
mpmath==1.3.0 | ||
# via sympy | ||
multidict==6.7.0 | ||
# via yarl | ||
networkx==3.5 | ||
# via torch | ||
numpy==2.3.3 | ||
# via | ||
# pgvector | ||
# scikit-learn | ||
# scipy | ||
# transformers | ||
packaging==25.0 | ||
# via | ||
# deprecation | ||
# huggingface-hub | ||
# transformers | ||
pgvector==0.3.6 | ||
# via vecs | ||
pillow==11.3.0 | ||
# via sentence-transformers | ||
postgrest==2.22.0 | ||
# via supabase | ||
propcache==0.4.1 | ||
# via yarl | ||
psycopg2-binary==2.9.11 | ||
# via vecs | ||
pycparser==2.23 | ||
# via cffi | ||
pydantic==2.12.0 | ||
# via | ||
# postgrest | ||
# realtime | ||
# storage3 | ||
# supabase-auth | ||
pydantic-core==2.41.1 | ||
# via pydantic | ||
pyjwt[crypto]==2.10.1 | ||
# via supabase-auth | ||
python-dotenv==1.1.1 | ||
# via -r requirements.in | ||
pyyaml==6.0.3 | ||
# via | ||
# huggingface-hub | ||
# transformers | ||
realtime==2.22.0 | ||
# via supabase | ||
regex==2025.9.18 | ||
# via transformers | ||
requests==2.32.5 | ||
# via | ||
# huggingface-hub | ||
# transformers | ||
safetensors==0.6.2 | ||
# via transformers | ||
scikit-learn==1.7.2 | ||
# via sentence-transformers | ||
scipy==1.16.2 | ||
# via | ||
# scikit-learn | ||
# sentence-transformers | ||
sentence-transformers==5.1.1 | ||
# via -r requirements.in | ||
sniffio==1.3.1 | ||
# via anyio | ||
sqlalchemy==2.0.44 | ||
# via vecs | ||
storage3==2.22.0 | ||
# via supabase | ||
strenum==0.4.15 | ||
# via supabase-functions | ||
supabase==2.22.0 | ||
# via -r requirements.in | ||
supabase-auth==2.22.0 | ||
# via supabase | ||
supabase-functions==2.22.0 | ||
# via supabase | ||
sympy==1.14.0 | ||
# via torch | ||
threadpoolctl==3.6.0 | ||
# via scikit-learn | ||
tokenizers==0.22.1 | ||
# via transformers | ||
torch==2.8.0 | ||
# via sentence-transformers | ||
tqdm==4.67.1 | ||
# via | ||
# -r requirements.in | ||
# huggingface-hub | ||
# sentence-transformers | ||
# transformers | ||
transformers==4.57.0 | ||
# via sentence-transformers | ||
typing-extensions==4.15.0 | ||
# via | ||
# flupy | ||
# huggingface-hub | ||
# pydantic | ||
# pydantic-core | ||
# realtime | ||
# sentence-transformers | ||
# sqlalchemy | ||
# torch | ||
# typing-inspection | ||
typing-inspection==0.4.2 | ||
# via pydantic | ||
urllib3==2.5.0 | ||
# via requests | ||
vecs==0.4.5 | ||
# via -r requirements.in | ||
websockets==15.0.1 | ||
# via realtime | ||
wrapt==1.17.3 | ||
# via deprecated | ||
yarl==1.22.0 | ||
# via | ||
# postgrest | ||
# storage3 | ||
# supabase-functions | ||
|
||
# The following packages are considered to be unsafe in a requirements file: | ||
# setuptools |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would also recommend adding a line for
__pycache__
.It's generally best to not add binaries to the code repo, especially if it is generated by a package manager like
pip
, since it can become an attraction for merge conflicts.Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@carolynzhuang once you land this change we can merge ur PR!