35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
|
|
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||
|
|
|
||
|
|
from typing import Any, List
|
||
|
|
|
||
|
|
|
||
|
|
class MyEmbeddings(HuggingFaceEmbeddings):
    """HuggingFaceEmbeddings variant that always requests normalized vectors.

    Both embedding methods call ``self.client.encode`` with
    ``normalize_embeddings=True`` and replace newline characters in the
    input with single spaces before encoding.
    """

    def __init__(self, **kwargs: Any):
        """Forward every keyword argument unchanged to the parent constructor."""
        super().__init__(**kwargs)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of documents with the underlying client.

        Args:
            texts: The documents to embed.

        Returns:
            The client's encoded output converted with ``tolist()``;
            expected to hold one embedding per input text.
        """
        # Newlines are flattened to spaces before the text reaches the model.
        cleaned = [doc.replace("\n", " ") for doc in texts]
        vectors = self.client.encode(cleaned, normalize_embeddings=True)
        return vectors.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string with the underlying client.

        Args:
            text: The query to embed.

        Returns:
            The client's encoded output for ``text`` converted with
            ``tolist()``.
        """
        # Same newline flattening as embed_documents, applied to one string.
        flattened = text.replace("\n", " ")
        vector = self.client.encode(flattened, normalize_embeddings=True)
        return vector.tolist()
|