DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
# Default checkpoint for the BAAI General Embedding (BGE) model family.
DEFAULT_BGE_MODEL = "BAAI/bge-large-en"
# Instructions used by the Instructor-style embeddings classes below.
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
DEFAULT_QUERY_INSTRUCTION = (
    "Represent the question for retrieving supporting documents: "
)
# NOTE(review): official BGE usage prepends an instruction to *queries only*
# ("Represent this sentence for searching relevant passages: ") and encodes
# documents without any instruction — confirm these two defaults against the
# BGE model card before relying on them.
DEFAULT_EMBED_BGE_INSTRUCTION = (
    "Represent this sentence for searching relevant passages: "
)
DEFAULT_QUERY_BGE_INSTRUCTION = (
    "Represent this question for searching relevant passages: "
)
13 | 20 |
|
14 | 21 |
|
15 | 22 | class HuggingFaceEmbeddings(BaseModel, Embeddings): |
@@ -169,3 +176,86 @@ def embed_query(self, text: str) -> List[float]: |
169 | 176 | instruction_pair = [self.query_instruction, text] |
170 | 177 | embedding = self.client.encode([instruction_pair], **self.encode_kwargs)[0] |
171 | 178 | return embedding.tolist() |
| 179 | + |
| 180 | + |
class HuggingFaceBgeEmbeddings(BaseModel, Embeddings):
    """HuggingFace BGE sentence_transformers embedding models.

    Unlike Instructor models (which take ``[instruction, text]`` pairs), BGE
    models are plain sentence-transformers checkpoints: the instruction must be
    *prepended to the text as a string* before encoding. ``SentenceTransformer.encode``
    expects ``str`` / ``List[str]`` inputs, so this class concatenates
    ``embed_instruction`` / ``query_instruction`` onto each input.

    To use, you should have the ``sentence_transformers`` python package installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceBgeEmbeddings

            model_name = "BAAI/bge-large-en"
            model_kwargs = {'device': 'cpu'}
            encode_kwargs = {'normalize_embeddings': False}
            hf = HuggingFaceBgeEmbeddings(
                model_name=model_name,
                model_kwargs=model_kwargs,
                encode_kwargs=encode_kwargs
            )
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_BGE_MODEL
    """Model name to use."""
    cache_folder: Optional[str] = None
    """Path to store models.
    Can be also set by SENTENCE_TRANSFORMERS_HOME environment variable."""
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to pass to the model."""
    encode_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Key word arguments to pass when calling the `encode` method of the model."""
    embed_instruction: str = DEFAULT_EMBED_BGE_INSTRUCTION
    """Instruction to use for embedding documents."""
    query_instruction: str = DEFAULT_QUERY_BGE_INSTRUCTION
    """Instruction to use for embedding query."""

    def __init__(self, **kwargs: Any):
        """Initialize the sentence_transformer.

        Raises:
            ImportError: If ``sentence_transformers`` is not installed.
        """
        super().__init__(**kwargs)
        try:
            import sentence_transformers

        except ImportError as exc:
            raise ImportError(
                "Could not import sentence_transformers python package. "
                "Please install it with `pip install sentence_transformers`."
            ) from exc

        self.client = sentence_transformers.SentenceTransformer(
            self.model_name, cache_folder=self.cache_folder, **self.model_kwargs
        )

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace transformer model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # BGE takes the instruction as a plain string prefix, not an
        # [instruction, text] pair (that API is Instructor-specific).
        # Newlines are collapsed to spaces, consistent with the other
        # sentence-transformers wrappers in this module.
        texts = [self.embed_instruction + t.replace("\n", " ") for t in texts]
        embeddings = self.client.encode(texts, **self.encode_kwargs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace transformer model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        # Same fix as embed_documents: concatenate the query instruction
        # instead of passing an [instruction, text] pair to encode().
        embedding = self.client.encode(
            self.query_instruction + text.replace("\n", " "), **self.encode_kwargs
        )
        return embedding.tolist()
0 commit comments