From f1f742ce442799ed868d52983b6378b9f11a9e69 Mon Sep 17 00:00:00 2001
From: imClumsyPanda
Date: Wed, 7 Jun 2023 23:18:47 +0800
Subject: [PATCH] add self-defined class MyFAISS

---
 chains/local_doc_qa.py   |  91 ++++---------------------
 vectorstores/MyFAISS.py  | 114 +++++++++++++++++++++++++++++++++++++++
 vectorstores/__init__.py |   1 +
 3 files changed, 124 insertions(+), 82 deletions(-)
 create mode 100644 vectorstores/MyFAISS.py
 create mode 100644 vectorstores/__init__.py

diff --git a/chains/local_doc_qa.py b/chains/local_doc_qa.py
index c96ede5..96a4ccd 100644
--- a/chains/local_doc_qa.py
+++ b/chains/local_doc_qa.py
@@ -1,5 +1,5 @@
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from langchain.vectorstores import FAISS
+from vectorstores import MyFAISS
 from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
 from configs.model_config import *
 import datetime
@@ -32,7 +32,7 @@ HuggingFaceEmbeddings.__hash__ = _embeddings_hash
 # will keep CACHED_VS_NUM of vector store caches
 @lru_cache(CACHED_VS_NUM)
 def load_vector_store(vs_path, embeddings):
-    return FAISS.load_local(vs_path, embeddings)
+    return MyFAISS.load_local(vs_path, embeddings)
 
 
 def tree(filepath, ignore_dir_names=None, ignore_file_names=None):
@@ -107,78 +107,6 @@ def generate_prompt(related_docs: List[str],
     return prompt
 
 
-def seperate_list(ls: List[int]) -> List[List[int]]:
-    lists = []
-    ls1 = [ls[0]]
-    for i in range(1, len(ls)):
-        if ls[i - 1] + 1 == ls[i]:
-            ls1.append(ls[i])
-        else:
-            lists.append(ls1)
-            ls1 = [ls[i]]
-    lists.append(ls1)
-    return lists
-
-
-def similarity_search_with_score_by_vector(
-        self, embedding: List[float], k: int = 4
-) -> List[Tuple[Document, float]]:
-    scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
-    docs = []
-    id_set = set()
-    store_len = len(self.index_to_docstore_id)
-    for j, i in enumerate(indices[0]):
-        if i == -1 or 0 < self.score_threshold < scores[0][j]:
-            # This happens when not enough docs are returned.
-            continue
-        _id = self.index_to_docstore_id[i]
-        doc = self.docstore.search(_id)
-        if not self.chunk_conent:
-            if not isinstance(doc, Document):
-                raise ValueError(f"Could not find document for id {_id}, got {doc}")
-            doc.metadata["score"] = int(scores[0][j])
-            docs.append(doc)
-            continue
-        id_set.add(i)
-        docs_len = len(doc.page_content)
-        for k in range(1, max(i, store_len - i)):
-            break_flag = False
-            for l in [i + k, i - k]:
-                if 0 <= l < len(self.index_to_docstore_id):
-                    _id0 = self.index_to_docstore_id[l]
-                    doc0 = self.docstore.search(_id0)
-                    if docs_len + len(doc0.page_content) > self.chunk_size:
-                        break_flag = True
-                        break
-                    elif doc0.metadata["source"] == doc.metadata["source"]:
-                        docs_len += len(doc0.page_content)
-                        id_set.add(l)
-            if break_flag:
-                break
-    if not self.chunk_conent:
-        return docs
-    if len(id_set) == 0 and self.score_threshold > 0:
-        return []
-    id_list = sorted(list(id_set))
-    id_lists = seperate_list(id_list)
-    for id_seq in id_lists:
-        for id in id_seq:
-            if id == id_seq[0]:
-                _id = self.index_to_docstore_id[id]
-                doc = self.docstore.search(_id)
-            else:
-                _id0 = self.index_to_docstore_id[id]
-                doc0 = self.docstore.search(_id0)
-                doc.page_content += " " + doc0.page_content
-        if not isinstance(doc, Document):
-            raise ValueError(f"Could not find document for id {_id}, got {doc}")
-        doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
-        doc.metadata["score"] = int(doc_score)
-        docs.append(doc)
-    torch_gc()
-    return docs
-
-
 def search_result2docs(search_results):
     docs = []
     for result in search_results:
@@ -263,7 +191,7 @@ class LocalDocQA:
                 if not vs_path:
                     vs_path = os.path.join(VS_ROOT_PATH,
                                            f"""{"".join(lazy_pinyin(os.path.splitext(file)[0]))}_FAISS_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}""")
-                vector_store = FAISS.from_documents(docs, self.embeddings)  # docs is a list of Document objects
+                vector_store = MyFAISS.from_documents(docs, self.embeddings)  # docs is a list of Document objects
                 torch_gc()
 
             vector_store.save_local(vs_path)
@@ -281,11 +209,11 @@ class LocalDocQA:
             if not one_content_segmentation:
                 text_splitter = ChineseTextSplitter(pdf=False, sentence_size=sentence_size)
                 docs = text_splitter.split_documents(docs)
-            if os.path.isdir(vs_path) and os.path.isfile(vs_path+"/index.faiss"):
+            if os.path.isdir(vs_path) and os.path.isfile(vs_path + "/index.faiss"):
                 vector_store = load_vector_store(vs_path, self.embeddings)
                 vector_store.add_documents(docs)
             else:
-                vector_store = FAISS.from_documents(docs, self.embeddings)  ## docs is a list of Document objects
+                vector_store = MyFAISS.from_documents(docs, self.embeddings)  ## docs is a list of Document objects
                 torch_gc()
             vector_store.save_local(vs_path)
             return vs_path, [one_title]
@@ -295,13 +223,12 @@ class LocalDocQA:
 
     def get_knowledge_based_answer(self, query, vs_path, chat_history=[], streaming: bool = STREAMING):
         vector_store = load_vector_store(vs_path, self.embeddings)
-        FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_size = self.chunk_size
         vector_store.chunk_conent = self.chunk_conent
         vector_store.score_threshold = self.score_threshold
         related_docs_with_score = vector_store.similarity_search_with_score(query, k=self.top_k)
         torch_gc()
-        if len(related_docs_with_score)>0:
+        if len(related_docs_with_score) > 0:
             prompt = generate_prompt(related_docs_with_score, query)
         else:
             prompt = query
@@ -326,7 +253,7 @@ class LocalDocQA:
                                          score_threshold=VECTOR_SEARCH_SCORE_THRESHOLD,
                                          vector_search_top_k=VECTOR_SEARCH_TOP_K, chunk_size=CHUNK_SIZE):
         vector_store = load_vector_store(vs_path, self.embeddings)
-        FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
+        # FAISS.similarity_search_with_score_by_vector = similarity_search_with_score_by_vector
         vector_store.chunk_conent = chunk_conent
         vector_store.score_threshold = score_threshold
         vector_store.chunk_size = chunk_size
@@ -381,8 +308,8 @@ if __name__ == "__main__":
                                                                  streaming=True):
         print(resp["result"][last_print_len:], end="", flush=True)
         last_print_len = len(resp["result"])
-    source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
-                   else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
+    source_text = [f"""出处 [{inum + 1}] {doc.metadata['source'] if doc.metadata['source'].startswith("http")
+                   else os.path.split(doc.metadata['source'])[-1]}:\n\n{doc.page_content}\n\n"""
                    # f"""相关度:{doc.metadata['score']}\n\n"""
                    for inum, doc in enumerate(resp["source_documents"])]
 
diff --git a/vectorstores/MyFAISS.py b/vectorstores/MyFAISS.py
new file mode 100644
index 0000000..fedbb34
--- /dev/null
+++ b/vectorstores/MyFAISS.py
@@ -0,0 +1,114 @@
+from langchain.vectorstores import FAISS
+from langchain.vectorstores.base import VectorStore
+from langchain.vectorstores.faiss import dependable_faiss_import
+from typing import Any, Callable, List, Tuple, Dict
+from langchain.docstore.base import Docstore
+from langchain.docstore.document import Document
+import numpy as np
+
+
+class MyFAISS(FAISS, VectorStore):
+    def __init__(
+            self,
+            embedding_function: Callable,
+            index: Any,
+            docstore: Docstore,
+            index_to_docstore_id: Dict[int, str],
+            normalize_L2: bool = False,
+    ):
+        super().__init__(embedding_function=embedding_function,
+                         index=index,
+                         docstore=docstore,
+                         index_to_docstore_id=index_to_docstore_id,
+                         normalize_L2=normalize_L2)
+
+    # def similarity_search_with_score_by_vector(
+    #         self, embedding: List[float], k: int = 4
+    # ) -> List[Tuple[Document, float]]:
+    #     faiss = dependable_faiss_import()
+    #     vector = np.array([embedding], dtype=np.float32)
+    #     if self._normalize_L2:
+    #         faiss.normalize_L2(vector)
+    #     scores, indices = self.index.search(vector, k)
+    #     docs = []
+    #     for j, i in enumerate(indices[0]):
+    #         if i == -1:
+    #             # This happens when not enough docs are returned.
+    #             continue
+    #         _id = self.index_to_docstore_id[i]
+    #         doc = self.docstore.search(_id)
+    #         if not isinstance(doc, Document):
+    #             raise ValueError(f"Could not find document for id {_id}, got {doc}")
+    #
+    #         docs.append((doc, scores[0][j]))
+    #     return docs
+
+    def seperate_list(self, ls: List[int]) -> List[List[int]]:
+        # TODO: also check whether adjacent indices belong to the same source document
+        lists = []
+        ls1 = [ls[0]]
+        for i in range(1, len(ls)):
+            if ls[i - 1] + 1 == ls[i]:
+                ls1.append(ls[i])
+            else:
+                lists.append(ls1)
+                ls1 = [ls[i]]
+        lists.append(ls1)
+        return lists
+
+    def similarity_search_with_score_by_vector(
+            self, embedding: List[float], k: int = 4
+    ) -> List[Document]:
+        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+        docs = []
+        id_set = set()
+        store_len = len(self.index_to_docstore_id)
+        for j, i in enumerate(indices[0]):
+            if i == -1 or 0 < self.score_threshold < scores[0][j]:
+                # This happens when not enough docs are returned.
+                continue
+            _id = self.index_to_docstore_id[i]
+            doc = self.docstore.search(_id)
+            if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
+                if not isinstance(doc, Document):
+                    raise ValueError(f"Could not find document for id {_id}, got {doc}")
+                doc.metadata["score"] = int(scores[0][j])
+                docs.append(doc)
+                continue
+            id_set.add(i)
+            docs_len = len(doc.page_content)
+            for k in range(1, max(i, store_len - i)):
+                break_flag = False
+                for l in [i + k, i - k]:
+                    if 0 <= l < len(self.index_to_docstore_id):
+                        _id0 = self.index_to_docstore_id[l]
+                        doc0 = self.docstore.search(_id0)
+                        if docs_len + len(doc0.page_content) > self.chunk_size:
+                            break_flag = True
+                            break
+                        elif doc0.metadata["source"] == doc.metadata["source"]:
+                            docs_len += len(doc0.page_content)
+                            id_set.add(l)
+                if break_flag:
+                    break
+        if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False):
+            return docs
+        if len(id_set) == 0 and self.score_threshold > 0:
+            return []
+        id_list = sorted(list(id_set))
+        id_lists = self.seperate_list(id_list)
+        for id_seq in id_lists:
+            for id in id_seq:
+                if id == id_seq[0]:
+                    _id = self.index_to_docstore_id[id]
+                    doc = self.docstore.search(_id)
+                else:
+                    _id0 = self.index_to_docstore_id[id]
+                    doc0 = self.docstore.search(_id0)
+                    doc.page_content += " " + doc0.page_content
+            if not isinstance(doc, Document):
+                raise ValueError(f"Could not find document for id {_id}, got {doc}")
+            doc_score = min([scores[0][id] for id in [indices[0].tolist().index(i) for i in id_seq if i in indices[0]]])
+            doc.metadata["score"] = int(doc_score)
+            docs.append(doc)
+        return docs
diff --git a/vectorstores/__init__.py b/vectorstores/__init__.py
new file mode 100644
index 0000000..d08d3e3
--- /dev/null
+++ b/vectorstores/__init__.py
@@ -0,0 +1 @@
+from .MyFAISS import MyFAISS
\ No newline at end of file
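
Usage note (illustrative sketch, not part of the patch): after this change, chains/local_doc_qa.py no longer monkey-patches FAISS.similarity_search_with_score_by_vector; the chunk-expansion logic lives in MyFAISS, and callers are expected to set chunk_conent, chunk_size, and score_threshold on the store instance before searching, as local_doc_qa.py does. A minimal flow might look like the following; the embedding model name, store path, and numeric values are placeholders, not values taken from this patch.

    # Sketch only: model name, path, and parameter values below are assumptions.
    from langchain.embeddings.huggingface import HuggingFaceEmbeddings
    from vectorstores import MyFAISS

    embeddings = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")  # placeholder model
    vector_store = MyFAISS.load_local("/path/to/vector_store", embeddings)               # placeholder path

    # The overridden similarity_search_with_score_by_vector reads these attributes
    # from the instance, so they must be set before the first query.
    vector_store.chunk_conent = True      # expand each hit with neighbouring chunks from the same source
    vector_store.chunk_size = 250         # stop expanding once the merged text exceeds this length
    vector_store.score_threshold = 500    # drop hits whose distance score exceeds this value

    related_docs = vector_store.similarity_search_with_score("example query", k=5)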