Skip to content

Commit 40096c7

Browse files
manmax31 and hwchase17 authored
Add BGE embeddings support (#8848)
- Description: [BGE-large](https://huggingface.co/BAAI/bge-large-en) embeddings from BAAI are at the top of [MTEB leaderboard](https://huggingface.co/spaces/mteb/leaderboard). Hence adding support for it. - Tag maintainer: @baskaryan - Twitter handle: @ManabChetia3 --------- Co-authored-by: Harrison Chase <[email protected]>
1 parent fbc83df commit 40096c7

File tree

3 files changed

+176
-0
lines changed

3 files changed

+176
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"id": "719619d3",
6+
"metadata": {},
7+
"source": [
8+
"# BGE Hugging Face Embeddings\n",
9+
"\n",
10+
"This notebook shows how to use BGE Embeddings through Hugging Face"
11+
]
12+
},
13+
{
14+
"cell_type": "code",
15+
"execution_count": 8,
16+
"id": "f7a54279",
17+
"metadata": {
18+
"scrolled": true
19+
},
20+
"outputs": [],
21+
"source": [
22+
"# !pip install sentence_transformers"
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 5,
28+
"id": "9e1d5b6b",
29+
"metadata": {},
30+
"outputs": [],
31+
"source": [
32+
"from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
33+
"\n",
34+
"model_name = \"BAAI/bge-small-en\"\n",
35+
"model_kwargs = {'device': 'cpu'}\n",
36+
"encode_kwargs = {'normalize_embeddings': False}\n",
37+
"hf = HuggingFaceBgeEmbeddings(\n",
38+
" model_name=model_name,\n",
39+
" model_kwargs=model_kwargs,\n",
40+
" encode_kwargs=encode_kwargs\n",
41+
")"
42+
]
43+
},
44+
{
45+
"cell_type": "code",
46+
"execution_count": 7,
47+
"id": "e59d1a89",
48+
"metadata": {},
49+
"outputs": [],
50+
"source": [
51+
"embedding = hf.embed_query(\"hi this is harrison\")"
52+
]
53+
},
54+
{
55+
"cell_type": "code",
56+
"execution_count": null,
57+
"id": "e596315f",
58+
"metadata": {},
59+
"outputs": [],
60+
"source": []
61+
}
62+
],
63+
"metadata": {
64+
"kernelspec": {
65+
"display_name": "Python 3 (ipykernel)",
66+
"language": "python",
67+
"name": "python3"
68+
},
69+
"language_info": {
70+
"codemirror_mode": {
71+
"name": "ipython",
72+
"version": 3
73+
},
74+
"file_extension": ".py",
75+
"mimetype": "text/x-python",
76+
"name": "python",
77+
"nbconvert_exporter": "python",
78+
"pygments_lexer": "ipython3",
79+
"version": "3.10.1"
80+
}
81+
},
82+
"nbformat": 4,
83+
"nbformat_minor": 5
84+
}

libs/langchain/langchain/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from langchain.embeddings.google_palm import GooglePalmEmbeddings
3232
from langchain.embeddings.gpt4all import GPT4AllEmbeddings
3333
from langchain.embeddings.huggingface import (
34+
HuggingFaceBgeEmbeddings,
3435
HuggingFaceEmbeddings,
3536
HuggingFaceInstructEmbeddings,
3637
)
@@ -97,6 +98,7 @@
9798
"XinferenceEmbeddings",
9899
"LocalAIEmbeddings",
99100
"AwaEmbeddings",
101+
"HuggingFaceBgeEmbeddings",
100102
]
101103

102104

libs/langchain/langchain/embeddings/huggingface.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,17 @@
66

77
# Default checkpoints for the embedding wrappers defined in this module.
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
# Instructor-style retrieval instructions. The trailing space is
# significant: the instruction is joined directly with the input text.
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
    "Represent the question for retrieving supporting documents: "
)
# BGE retrieval instructions — presumably taken from the BAAI/bge model
# cards; TODO confirm against the published usage examples.
DEFAULT_EMBED_BGE_INSTRUCTION = (
    "Represent this sentence for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION = (
    "Represent this question for searching relevant passages: "
)
1320

1421

1522
class HuggingFaceEmbeddings(BaseModel, Embeddings):
@@ -169,3 +176,86 @@ def embed_query(self, text: str) -> List[float]:
169176
instruction_pair = [self.query_instruction, text]
170177
embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0]
171178
return embedding.tolist()
179+
180+
181+
class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
    """HuggingFace BGE sentence_transformers embedding models.

    BGE checkpoints are plain ``SentenceTransformer`` models: unlike the
    INSTRUCTOR models, ``encode`` accepts only strings, so the retrieval
    instruction must be string-concatenated onto the query text rather than
    passed as an ``[instruction, text]`` pair. Documents are embedded
    without any instruction; only queries receive ``query_instruction``.

    To use, you should have the ``sentence_transformers`` python package
    installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceBgeEmbeddings

            model_name = "BAAI/bge-large-en"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': False}
            hf = HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_BGE_MODEL
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to pass when calling the `encode` method of the model."""
    embed_instruction: str = DEFAULT_EMBED_BGE_INSTRUCTION
    """Instruction to use for embedding documents.
    Unused by BGE (documents are embedded raw); kept for compatibility."""
    query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION
    """Instruction to use for embedding query."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Raises:
            ImportError: If ``sentence_transformers`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence_transformers`."
            ) from exc

        self.client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # Newlines can skew sentence-transformer embeddings; flatten them,
        # matching the other HuggingFace wrappers in this module.
        texts = [t.replace("\n", " ") for t in texts]
        # BGE embeds documents raw — the retrieval instruction applies to
        # queries only, and plain SentenceTransformer.encode takes strings,
        # not [instruction, text] pairs.
        embeddings = self.client.encode(texts, **self.encode_kwargs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        # Prepend the instruction as a string; SentenceTransformer does not
        # understand the INSTRUCTOR-style [instruction, text] pair format.
        embedding = self.client.encode(
            self.query_instruction + text, **self.encode_kwargs
        )
        return embedding.tolist()

0 commit comments

Comments (0)