# Langchain-Chatchat/vectorstores/MyFAISS.py
# (172 lines, 7.7 KiB, Python; captured from a repository "Raw | Normal View | History"
#  page whose first blame entry is dated 2023-06-07 23:18:47 +08:00)
import copy
import os
from typing import Any, Callable, Dict, List, Union

import numpy as np
from langchain.docstore.base import Docstore
from langchain.docstore.document import Document
from langchain.vectorstores import FAISS
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.faiss import dependable_faiss_import

from configs.model_config import *


class MyFAISS(FAISS, VectorStore):
    """LangChain ``FAISS`` store extended with score filtering, context expansion
    and per-source document management.

    Search results are returned as ``Document`` objects carrying their score in
    ``metadata["score"]``. When ``chunk_conent`` is enabled, each hit is merged
    with neighbouring chunks of the same source (up to ``chunk_size``
    characters). ``delete_doc`` / ``update_doc`` / ``list_docs`` operate on all
    chunks that share a ``metadata["source"]`` value.
    """

    def __init__(
            self,
            embedding_function: Callable,
            index: Any,
            docstore: Docstore,
            index_to_docstore_id: Dict[int, str],
            normalize_L2: bool = False,
    ):
        super().__init__(embedding_function=embedding_function,
                         index=index,
                         docstore=docstore,
                         index_to_docstore_id=index_to_docstore_id,
                         normalize_L2=normalize_L2)
        # Hits whose score is greater than this are dropped when it is positive.
        # VECTOR_SEARCH_SCORE_THRESHOLD / CHUNK_SIZE presumably come from the
        # `configs.model_config` star import -- confirm there.
        self.score_threshold = VECTOR_SEARCH_SCORE_THRESHOLD
        # Upper bound, in characters, on the text produced by context expansion.
        self.chunk_size = CHUNK_SIZE
        # Switches context expansion on. The misspelled name is kept because it
        # is a public attribute that other code may set.
        self.chunk_conent = False

    def seperate_list(self, ls: List[int]) -> List[List[int]]:
        """Split sorted index positions into runs that are merged into one document each.

        A new run starts whenever a position does not directly follow the
        previous one, or belongs to a different source than the run's first
        position.

        Args:
            ls: Index positions in ascending order.

        Returns:
            The runs, in input order; an empty list for empty input.
        """
        if not ls:
            return []
        runs = [[ls[0]]]
        run_source = self.index_to_docstore_source(ls[0])
        for prev, cur in zip(ls, ls[1:]):
            cur_source = self.index_to_docstore_source(cur)
            if cur == prev + 1 and cur_source == run_source:
                runs[-1].append(cur)
            else:
                runs.append([cur])
                run_source = cur_source
        return runs

    def similarity_search_with_score_by_vector(
            self, embedding: List[float], k: int = 4
    ) -> List[Document]:
        """Return the documents nearest to ``embedding``, each with ``metadata["score"]`` set.

        Unlike the base class, this returns bare ``Document`` objects (copies of
        the stored ones) rather than ``(document, score)`` tuples; the score is
        truncated to ``int``.

        A hit is skipped when FAISS returns padding (-1), when its position has
        no docstore id, or when ``score_threshold`` is positive and the score
        exceeds it. If ``chunk_conent`` is false, or the hit's metadata has a
        falsy ``"context_expand"``, the hit is returned on its own. Otherwise it
        is grown with neighbouring positions (see ``_expand_context``), and each
        run of adjacent same-source positions is concatenated into one document
        scored with the lowest search score among the hits it contains.

        Args:
            embedding: Query vector.
            k: Number of neighbours requested from the FAISS index.

        Returns:
            Non-expanded hits first (in search order), followed by the merged
            documents (in ascending index order).

        Raises:
            ValueError: If a docstore lookup does not yield a ``Document``.
        """
        faiss = dependable_faiss_import()
        vector = np.array([embedding], dtype=np.float32)
        if self._normalize_L2:
            faiss.normalize_L2(vector)
        scores, indices = self.index.search(vector, k)

        docs = []
        # Index positions taking part in context expansion: hits plus merged neighbours.
        id_set = set()
        store_len = len(self.index_to_docstore_id)

        for j, hit in enumerate(indices[0]):
            hit = int(hit)
            # -1 is FAISS padding, returned when fewer than k vectors are available.
            if hit == -1 or 0 < self.score_threshold < scores[0][j]:
                continue
            _id = self.index_to_docstore_id.get(hit)
            if _id is None:
                continue
            doc = self.docstore.search(_id)
            # Validate before doc.metadata is read below.
            if not isinstance(doc, Document):
                raise ValueError(f"Could not find document for id {_id}, got {doc}")
            if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
                # No expansion for this hit. Copy it so the score is not written
                # into (and shared through) the document held by the docstore.
                doc = copy.deepcopy(doc)
                doc.metadata["score"] = int(scores[0][j])
                docs.append(doc)
                continue
            self._expand_context(hit, doc, id_set, store_len)

        # Nothing was queued for expansion (always the case when chunk_conent is off).
        # Hits queued for expansion are kept even when no neighbour could be merged.
        if not id_set:
            return docs

        # Search score of each returned position; the first occurrence wins.
        hit_scores = {}
        for position, score in zip(indices[0].tolist(), scores[0].tolist()):
            hit_scores.setdefault(position, score)

        for run in self.seperate_list(sorted(id_set)):
            merged = None
            for position in run:
                _id = self.index_to_docstore_id[position]
                part = self.docstore.search(_id)
                if not isinstance(part, Document):
                    raise ValueError(f"Could not find document for id {_id}, got {part}")
                if merged is None:
                    # Deep copy so concatenation does not modify the stored document.
                    merged = copy.deepcopy(part)
                else:
                    merged.page_content += " " + part.page_content
            merged.metadata["score"] = int(min(hit_scores[p] for p in run if p in hit_scores))
            docs.append(merged)
        return docs

    def _expand_context(self, hit: int, doc: Document, id_set: set, store_len: int) -> None:
        """Add ``hit`` and the neighbouring positions merged into its context to ``id_set``.

        Neighbours are visited at growing offsets in the direction given by
        ``doc.metadata["context_expand_method"]`` (``"forward"``, ``"backward"``,
        anything else means both). Positions already in ``id_set`` or outside the
        index are skipped. Expansion stops at the first neighbour that would push
        the merged text past ``self.chunk_size`` characters or whose
        ``metadata["source"]`` differs from ``doc``'s.
        """
        id_set.add(hit)
        merged_len = len(doc.page_content)
        method = doc.metadata.get("context_expand_method")
        # Largest offset that can still reach index 0 backward or the last index forward.
        max_offset = max(hit, store_len - 1 - hit)
        for offset in range(1, max_offset + 1):
            if method == "forward":
                candidates = [hit + offset]
            elif method == "backward":
                candidates = [hit - offset]
            else:
                candidates = [hit + offset, hit - offset]
            for neighbour in candidates:
                if neighbour in id_set or not 0 <= neighbour < store_len:
                    continue
                neighbour_doc = self.docstore.search(self.index_to_docstore_id[neighbour])
                if (merged_len + len(neighbour_doc.page_content) > self.chunk_size
                        or neighbour_doc.metadata["source"] != doc.metadata["source"]):
                    return
                merged_len += len(neighbour_doc.page_content)
                id_set.add(neighbour)

    @staticmethod
    def _vs_path_for(source: Union[str, List[str]]) -> str:
        """Return ``<grandparent dir of source>/vector_store``; for a list, the first entry decides."""
        first = source if isinstance(source, str) else source[0]
        return os.path.join(os.path.dirname(os.path.dirname(first)), "vector_store")

    def delete_doc(self, source: Union[str, List[str]]) -> str:
        """Delete every chunk whose ``metadata["source"]`` matches ``source`` and save the store.

        The FAISS vectors, the ``index_to_docstore_id`` entries and the docstore
        entries are removed. The id mapping is then renumbered to match the
        compacted FAISS index, and the store is saved to the ``vector_store``
        directory derived from the (first) source path.

        Args:
            source: One source path, or a list of source paths.

        Returns:
            ``"docs delete success"``, or ``"docs delete fail"`` when nothing
            matched or an error occurred (the error is printed, not raised).
        """
        try:
            if isinstance(source, str):
                ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] == source]
            else:
                ids = [k for k, v in self.docstore._dict.items() if v.metadata["source"] in source]
            vs_path = self._vs_path_for(source)
            if not ids:
                return "docs delete fail"
            position_of = {doc_id: position for position, doc_id in self.index_to_docstore_id.items()}
            positions = [position_of[doc_id] for doc_id in ids]
            # remove_ids leaves the remaining vectors contiguous, e.g. positions 0, 3, 4 -> 0, 1, 2.
            self.index.remove_ids(np.array(positions, dtype=np.int64))
            for doc_id, position in zip(ids, positions):
                del self.index_to_docstore_id[position]
                del self.docstore._dict[doc_id]
            # Renumber the mapping the same way so it lines up with the FAISS index again
            # (0, 1, 3 -> 0, 1, 2). Updated in place to keep the same dict object.
            remaining = [doc_id for _, doc_id in sorted(self.index_to_docstore_id.items())]
            self.index_to_docstore_id.clear()
            self.index_to_docstore_id.update(enumerate(remaining))
            self.save_local(vs_path)
            return "docs delete success"
        except Exception as e:
            print(e)
            return "docs delete fail"

    def update_doc(self, source, new_docs):
        """Replace the chunks of ``source`` with ``new_docs`` and save the store.

        Args:
            source: Source path (or list of paths) whose existing chunks are removed.
            new_docs: Documents to add in their place.

        Returns:
            ``"docs update success"``, or ``"docs update fail"`` if adding or
            saving raised (the error is printed, not raised).
        """
        try:
            # The delete status is not checked: a source that is not indexed yet
            # ("docs delete fail") simply gets its new chunks added.
            self.delete_doc(source)
            self.add_documents(new_docs)
            # delete_doc only persisted the removal; save again so the new chunks
            # are not lost when the store is reloaded from disk.
            self.save_local(self._vs_path_for(source))
            return "docs update success"
        except Exception as e:
            print(e)
            return "docs update fail"

    def list_docs(self):
        """Return the distinct ``metadata["source"]`` values of all stored documents (unordered)."""
        return list({doc.metadata["source"] for doc in self.docstore._dict.values()})

    def index_to_docstore_source(self, i: int):
        """Return ``metadata["source"]`` of the document stored at FAISS index position ``i``."""
        _id = self.index_to_docstore_id[i]
        doc = self.docstore.search(_id)
        return doc.metadata["source"]