diff --git a/chains/modules/embeddings.py b/chains/modules/embeddings.py
new file mode 100644
index 0000000..3abeddf
--- /dev/null
+++ b/chains/modules/embeddings.py
@@ -0,0 +1,34 @@
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+
+from typing import Any, List
+
+
+class MyEmbeddings(HuggingFaceEmbeddings):
+    def __init__(self, **kwargs: Any):
+        super().__init__(**kwargs)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Compute doc embeddings using a HuggingFace transformer model.
+
+        Args:
+            texts: The list of texts to embed.
+
+        Returns:
+            List of embeddings, one for each text.
+        """
+        texts = list(map(lambda x: x.replace("\n", " "), texts))
+        embeddings = self.client.encode(texts, normalize_embeddings=True)
+        return embeddings.tolist()
+
+    def embed_query(self, text: str) -> List[float]:
+        """Compute query embeddings using a HuggingFace transformer model.
+
+        Args:
+            text: The text to embed.
+
+        Returns:
+            Embeddings for the text.
+        """
+        text = text.replace("\n", " ")
+        embedding = self.client.encode(text, normalize_embeddings=True)
+        return embedding.tolist()
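A note on the file above: in this langchain version, `HuggingFaceEmbeddings.client` is a `sentence_transformers.SentenceTransformer`, so the added `normalize_embeddings=True` makes `encode` return unit-length vectors. That is what lets the inner-product index built in `vectorstores.py` below (`faiss.IndexFlatIP`) rank by cosine similarity. A minimal sketch of the equivalence, assuming `sentence-transformers` is installed; the model name is illustrative, not taken from this diff:

    from sentence_transformers import SentenceTransformer
    import numpy as np

    # Illustrative model; any sentence-transformers checkpoint behaves the same.
    model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    vecs = model.encode(["hello world", "hi there"], normalize_embeddings=True)

    print(np.linalg.norm(vecs, axis=1))  # ~[1. 1.]: vectors are unit length
    print(float(vecs[0] @ vecs[1]))      # inner product equals cosine similarity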
diff --git a/chains/modules/vectorstores.py b/chains/modules/vectorstores.py
new file mode 100644
index 0000000..da89775
--- /dev/null
+++ b/chains/modules/vectorstores.py
@@ -0,0 +1,121 @@
+from langchain.vectorstores import FAISS
+from typing import Any, Callable, List, Optional, Tuple, Dict
+from langchain.docstore.document import Document
+from langchain.docstore.base import Docstore
+
+from langchain.vectorstores.utils import maximal_marginal_relevance
+from langchain.embeddings.base import Embeddings
+import uuid
+from langchain.docstore.in_memory import InMemoryDocstore
+
+import numpy as np
+
+def dependable_faiss_import() -> Any:
+    """Import faiss if available, otherwise raise error."""
+    try:
+        import faiss
+    except ImportError:
+        raise ValueError(
+            "Could not import faiss python package. "
+            "Please install it with `pip install faiss` "
+            "or `pip install faiss-cpu` (depending on Python version)."
+        )
+    return faiss
+
+class FAISSVS(FAISS):
+    def __init__(self,
+                 embedding_function: Callable[..., Any],
+                 index: Any,
+                 docstore: Docstore,
+                 index_to_docstore_id: Dict[int, str]):
+        super().__init__(embedding_function, index, docstore, index_to_docstore_id)
+
+    def max_marginal_relevance_search_by_vector(
+        self, embedding: List[float], k: int = 4, fetch_k: int = 20, **kwargs: Any
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), fetch_k)
+        # -1 happens when not enough docs are returned.
+        embeddings = [self.index.reconstruct(int(i)) for i in indices[0] if i != -1]
+        mmr_selected = maximal_marginal_relevance(
+            np.array([embedding], dtype=np.float32), embeddings, k=k
+        )
+        selected_indices = [indices[0][i] for i in mmr_selected]
+        selected_scores = [scores[0][i] for i in mmr_selected]
+        docs = []
+        for i, score in zip(selected_indices, selected_scores):
+            if i == -1:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            docs.append((doc, score))
+        return docs
+
+    def max_marginal_relevance_search(
+        self,
+        query: str,
+        k: int = 4,
+        fetch_k: int = 20,
+        **kwargs: Any,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs selected using the maximal marginal relevance.
+
+        Maximal marginal relevance optimizes for similarity to query AND diversity
+        among selected documents.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            fetch_k: Number of Documents to fetch to pass to MMR algorithm.
+
+        Returns:
+            List of Documents with scores selected by maximal marginal relevance.
+        """
+        embedding = self.embedding_function(query)
+        docs = self.max_marginal_relevance_search_by_vector(embedding, k, fetch_k)
+        return docs
+
+    @classmethod
+    def __from(
+        cls,
+        texts: List[str],
+        embeddings: List[List[float]],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        **kwargs: Any,
+    ) -> FAISS:
+        faiss = dependable_faiss_import()
+        index = faiss.IndexFlatIP(len(embeddings[0]))
+        index.add(np.array(embeddings, dtype=np.float32))
+
+        # # my code, for speeding up search
+        # quantizer = faiss.IndexFlatL2(len(embeddings[0]))
+        # index = faiss.IndexIVFFlat(quantizer, len(embeddings[0]), 100)
+        # index.train(np.array(embeddings, dtype=np.float32))
+        # index.add(np.array(embeddings, dtype=np.float32))
+
+        documents = []
+        for i, text in enumerate(texts):
+            metadata = metadatas[i] if metadatas else {}
+            documents.append(Document(page_content=text, metadata=metadata))
+        index_to_id = {i: str(uuid.uuid4()) for i in range(len(documents))}
+        docstore = InMemoryDocstore(
+            {index_to_id[i]: doc for i, doc in enumerate(documents)}
+        )
+        return cls(embedding.embed_query, index, docstore, index_to_id)
+
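For context, a minimal usage sketch tying the two new modules together. Assumptions are flagged inline: the modules are importable as `chains.modules.*`, the installed langchain version matches the `FAISS` API subclassed above, and the embedding model name is illustrative rather than taken from this diff:

    from chains.modules.embeddings import MyEmbeddings
    from chains.modules.vectorstores import FAISSVS

    # Illustrative model name; any sentence-transformers model should work.
    embeddings = MyEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")

    vs = FAISSVS.from_texts(
        ["FAISS performs vector similarity search.",
         "LangChain connects language models to external data."],
        embeddings,
    )

    # Unlike the base class, these MMR methods return (Document, score) tuples.
    for doc, score in vs.max_marginal_relevance_search("what is FAISS?", k=1, fetch_k=2):
        print(score, doc.page_content)

One caveat worth flagging for review: because of Python name mangling, the private classmethod `__from` defined here compiles to `_FAISSVS__from`, while `FAISS.from_texts` in the langchain versions this appears to target resolves its own `_FAISS__from`. As written, `FAISSVS.from_texts` therefore still builds the parent's index type; the `IndexFlatIP` construction is only exercised if `from_texts`/`from_embeddings` are overridden as well (or the method is renamed).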