update MyFAISS

imClumsyPanda 2023-06-12 00:06:06 +08:00
parent 27a9bf2433
commit 7863e0fea8
2 changed files with 17 additions and 30 deletions

View File

@@ -4,9 +4,7 @@ from langchain.document_loaders import UnstructuredFileLoader, TextLoader, CSVLoader
 from configs.model_config import *
 import datetime
 from textsplitter import ChineseTextSplitter
-from typing import List, Tuple, Dict
-from langchain.docstore.document import Document
-import numpy as np
+from typing import List
 from utils import torch_gc
 from tqdm import tqdm
 from pypinyin import lazy_pinyin

View File

@@ -22,27 +22,6 @@ class MyFAISS(FAISS, VectorStore):
                          index_to_docstore_id=index_to_docstore_id,
                          normalize_L2=normalize_L2)
-    # def similarity_search_with_score_by_vector(
-    #         self, embedding: List[float], k: int = 4
-    # ) -> List[Tuple[Document, float]]:
-    #     faiss = dependable_faiss_import()
-    #     vector = np.array([embedding], dtype=np.float32)
-    #     if self._normalize_L2:
-    #         faiss.normalize_L2(vector)
-    #     scores, indices = self.index.search(vector, k)
-    #     docs = []
-    #     for j, i in enumerate(indices[0]):
-    #         if i == -1:
-    #             # This happens when not enough docs are returned.
-    #             continue
-    #         _id = self.index_to_docstore_id[i]
-    #         doc = self.docstore.search(_id)
-    #         if not isinstance(doc, Document):
-    #             raise ValueError(f"Could not find document for id {_id}, got {doc}")
-    #
-    #         docs.append((doc, scores[0][j]))
-    #     return docs
-

     def seperate_list(self, ls: List[int]) -> List[List[int]]:
         # TODO: add a check for whether chunks belong to the same document
         lists = []
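
For context, seperate_list (unchanged by this commit apart from the TODO) groups a sorted list of docstore indices into runs of consecutive integers, so neighboring chunks can later be merged run by run. Its body is not shown in this diff, so the following is a minimal standalone sketch of the assumed behavior, not the repo's exact implementation:

    from typing import List

    def seperate_list(ls: List[int]) -> List[List[int]]:
        # Group sorted ints into runs of consecutive values,
        # e.g. [1, 2, 3, 7, 8, 11] -> [[1, 2, 3], [7, 8], [11]].
        lists: List[List[int]] = []
        if not ls:
            return lists
        run = [ls[0]]
        for n in ls[1:]:
            if n == run[-1] + 1:
                run.append(n)
            else:
                lists.append(run)
                run = [n]
        lists.append(run)
        return lists

    assert seperate_list([1, 2, 3, 7, 8, 11]) == [[1, 2, 3], [7, 8], [11]]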
@@ -59,7 +38,11 @@
     def similarity_search_with_score_by_vector(
             self, embedding: List[float], k: int = 4
     ) -> List[Document]:
-        scores, indices = self.index.search(np.array([embedding], dtype=np.float32), k)
+        faiss = dependable_faiss_import()
+        vector = np.array([embedding], dtype=np.float32)
+        if self._normalize_L2:
+            faiss.normalize_L2(vector)
+        scores, indices = self.index.search(vector, k)
         docs = []
         id_set = set()
         store_len = len(self.index_to_docstore_id)
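
This hunk restores the normalization step from the previously commented-out version: when the store was built with normalize_L2=True, the query vector must be L2-normalized before index.search so that inner-product scores behave like cosine similarity. faiss.normalize_L2 normalizes rows in place; a sketch of the equivalent numpy arithmetic, for illustration only:

    import numpy as np

    def normalize_l2(vec: np.ndarray) -> np.ndarray:
        # Row-wise L2 normalization, matching what faiss.normalize_L2 does in place.
        norms = np.linalg.norm(vec, axis=1, keepdims=True)
        return vec / np.maximum(norms, 1e-12)  # guard against zero vectors

    query = np.array([[3.0, 4.0]], dtype=np.float32)
    print(normalize_l2(query))  # [[0.6 0.8]] -- unit length, so dot products become cosines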
@@ -69,7 +52,7 @@
                 continue
             _id = self.index_to_docstore_id[i]
             doc = self.docstore.search(_id)
-            if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
+            if (not self.chunk_conent) or ("context_expand" in doc.metadata and not doc.metadata["context_expand"]):
                 if not isinstance(doc, Document):
                     raise ValueError(f"Could not find document for id {_id}, got {doc}")
                 doc.metadata["score"] = int(scores[0][j])
@@ -79,11 +62,17 @@
             docs_len = len(doc.page_content)
             for k in range(1, max(i, store_len - i)):
                 break_flag = False
-                for l in [i + k, i - k]:
-                    if 0 <= l < len(self.index_to_docstore_id):
+                if "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "forward":
+                    expand_range = [i + k]
+                elif "context_expand_method" in doc.metadata and doc.metadata["context_expand_method"] == "backward":
+                    expand_range = [i - k]
+                else:
+                    expand_range = [i + k, i - k]
+                for l in expand_range:
+                    if l not in id_set and 0 <= l < len(self.index_to_docstore_id):
                         _id0 = self.index_to_docstore_id[l]
                         doc0 = self.docstore.search(_id0)
-                        if docs_len + len(doc0.page_content) > self.chunk_size:
+                        if docs_len + len(doc0.page_content) > self.chunk_size or doc0.metadata["source"] != doc.metadata["source"]:
                             break_flag = True
                             break
                         elif doc0.metadata["source"] == doc.metadata["source"]:
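
The new branch reads an optional context_expand_method metadata field: "forward" only pulls chunks after the hit, "backward" only pulls chunks before it, and anything else expands in both directions. The added `l not in id_set` check and the source comparison stop the expansion from re-adding chunks or crossing file boundaries. A small sketch of just the range-selection logic, isolated from the class (expand_candidates is a hypothetical name; the branch bodies mirror the diff):

    from typing import Dict, List

    def expand_candidates(i: int, k: int, metadata: Dict) -> List[int]:
        # Pick neighbor indices at distance k from hit index i,
        # honoring the optional context_expand_method metadata switch.
        method = metadata.get("context_expand_method")
        if method == "forward":
            return [i + k]      # only look at later chunks
        if method == "backward":
            return [i - k]      # only look at earlier chunks
        return [i + k, i - k]   # default: expand in both directions

    assert expand_candidates(10, 2, {"context_expand_method": "forward"}) == [12]
    assert expand_candidates(10, 2, {}) == [12, 8]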
@@ -91,7 +80,7 @@
                             id_set.add(l)
                 if break_flag:
                     break
-        if (not self.chunk_conent) or ("add_context" in doc.metadata and doc.metadata["add_context"] == False):
+        if (not self.chunk_conent) or ("add_context" in doc.metadata and not doc.metadata["add_context"]):
             return docs
         if len(id_set) == 0 and self.score_threshold > 0:
             return []
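
For reference, a hypothetical caller-side sketch of how these metadata switches might be set on documents before indexing. The metadata keys come from the diff; the construction call is the from_documents classmethod inherited from langchain's FAISS, and chunk_conent ("conent" sic, as spelled in the repo), chunk_size, and score_threshold are attributes this class expects to be assigned elsewhere, so everything beyond the key names is an assumption:

    from langchain.docstore.document import Document

    docs = [
        Document(page_content="...chunk text...",
                 metadata={"source": "manual.md",
                           "context_expand": True,               # allow neighbor expansion
                           "context_expand_method": "forward"})  # only expand toward later chunks
    ]
    # vector_store = MyFAISS.from_documents(docs, embeddings)  # inherited from FAISS
    # vector_store.chunk_conent = True     # enable merging of neighboring chunks
    # vector_store.chunk_size = 500        # stop expanding once merged text exceeds this
    # vector_store.score_threshold = 0     # 0 disables the empty-result early return
    # related_docs = vector_store.similarity_search_with_score(query, k=4)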