update kb_doc_api: make the FAISS cache work; delete vector store docs before adding duplicate docs
parent d694652b87
commit 150a78bfd9
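The hunks below pass a tick from _VECTOR_STORE_TICKS into load_vector_store() and call refresh_vs_cache() after every write, but neither the dict nor refresh_vs_cache() is defined in this diff. A minimal sketch of the assumed mechanism (the real project may differ): the tick is part of the memoization key, so bumping it after a write forces the next load to re-read the index from disk instead of returning a stale cached object.

    # Illustrative sketch only -- the names mirror the diff, the bodies are assumptions.
    from functools import lru_cache

    _VECTOR_STORE_TICKS = {}

    @lru_cache(maxsize=8)
    def load_vector_store(knowledge_base_name: str, tick: int = 0):
        # (name, tick) is the cache key: the same tick returns the cached index,
        # a bumped tick forces a fresh FAISS.load_local() from disk.
        ...

    def refresh_vs_cache(kb_name: str):
        # Called after add/delete/clear so readers stop seeing the stale index.
        _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1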
@@ -77,6 +77,7 @@ class KBService(ABC):
         """
         docs = kb_file.file2text()
         if docs:
+            self.delete_doc(kb_file)
             embeddings = self._load_embeddings()
             self.do_add_doc(docs, embeddings)
             status = add_doc_to_db(kb_file)
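The added self.delete_doc(kb_file) call makes re-uploading a file idempotent: any vectors previously indexed for that file are removed before its chunks are embedded and added again, which is the "delete vector store docs before adding duplicate docs" part of the commit message. A rough usage sketch (import paths and constructor signatures are assumptions, not shown in this diff):

    # Assumed import paths and signatures -- illustrative only.
    from server.knowledge_base.kb_service.faiss_kb_service import FaissKBService
    from server.knowledge_base.utils import KnowledgeFile

    kb = FaissKBService("samples")
    kb_file = KnowledgeFile("test.txt", "samples")

    kb.add_doc(kb_file)   # first upload: chunks are embedded and indexed
    kb.add_doc(kb_file)   # re-upload: old vectors for test.txt are deleted first, so no duplicates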
@@ -41,7 +41,23 @@ def load_vector_store(
     vs_path = get_vs_path(knowledge_base_name)
     if embeddings is None:
         embeddings = load_embeddings(embed_model, embed_device)
-    search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+
+    if not os.path.exists(vs_path):
+        os.makedirs(vs_path)
+
+    if "index.faiss" in os.listdir(vs_path):
+        search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
+    else:
+        # create an empty vector store
+        doc = Document(page_content="init", metadata={})
+        search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
+        ids = [k for k, v in search_index.docstore._dict.items()]
+        search_index.delete(ids)
+        search_index.save_local(vs_path)
+
     if tick == 0:  # vector store is loaded first time
         _VECTOR_STORE_TICKS[knowledge_base_name] = 0
 
     return search_index
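load_vector_store() now guarantees that vs_path exists and that an index can always be returned: when no index.faiss is on disk it bootstraps an empty store by indexing a single placeholder document (FAISS needs at least one vector to establish the embedding dimension), deleting it again, and saving the result. Callers are expected to pass the current tick so the cache is keyed on the knowledge base's state; a sketch of the assumed read path (embeddings argument elided):

    # Assumed calling pattern for a read -- not part of the commit itself.
    index = load_vector_store("samples",
                              tick=_VECTOR_STORE_TICKS.get("samples", 0))
    hits = index.similarity_search_with_score("What is FAISS?", k=3)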
@@ -74,8 +90,10 @@ class FaissKBService(KBService):
     def do_create_kb(self):
         if not os.path.exists(self.vs_path):
             os.makedirs(self.vs_path)
+        load_vector_store(self.kb_name)
 
     def do_drop_kb(self):
         self.clear_vs()
+        shutil.rmtree(self.kb_path)
 
     def do_search(self,
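With this change do_create_kb() does more than create the directory: the load_vector_store() call writes an empty index.faiss immediately, so a freshly created knowledge base can be searched (returning nothing) without special-casing a missing index, and do_drop_kb() additionally removes the whole kb_path after clearing the vector store. A sketch of the assumed lifecycle (constructor signature is an assumption):

    kb = FaissKBService("samples")   # assumed constructor
    kb.do_create_kb()                # makes vs_path and saves an empty index.faiss
    kb.do_drop_kb()                  # clears the vector store, then deletes kb_path entirely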
@@ -94,37 +112,35 @@ class FaissKBService(KBService):
                    docs: List[Document],
                    embeddings: Embeddings,
                    ):
-        if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-            vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
-            vector_store.add_documents(docs)
-            torch_gc()
-        else:
-            if not os.path.exists(self.vs_path):
-                os.makedirs(self.vs_path)
-            vector_store = FAISS.from_documents(
-                docs, embeddings, normalize_L2=True)  # docs 为Document列表
-            torch_gc()
+        vector_store = load_vector_store(self.kb_name,
+                                         embeddings=embeddings,
+                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+        vector_store.add_documents(docs)
+        torch_gc()
         vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)
 
     def do_delete_doc(self,
                       kb_file: KnowledgeFile):
         embeddings = self._load_embeddings()
-        if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
-            vector_store = FAISS.load_local(self.vs_path, embeddings, normalize_L2=True)
-            ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
-            if len(ids) == 0:
-                return None
-            vector_store.delete(ids)
-            vector_store.save_local(self.vs_path)
-            refresh_vs_cache(self.kb_name)
-            return True
-        else:
+        vector_store = load_vector_store(self.kb_name,
+                                         embeddings=embeddings,
+                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0))
+
+        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
+        if len(ids) == 0:
+            return None
+
+        vector_store.delete(ids)
+        vector_store.save_local(self.vs_path)
+        refresh_vs_cache(self.kb_name)
+
+        return True
 
     def do_clear_vs(self):
         shutil.rmtree(self.vs_path)
         os.makedirs(self.vs_path)
+        refresh_vs_cache(self.kb_name)
 
     def exist_doc(self, file_name: str):
         if super().exist_doc(file_name):
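Both write paths now share one pattern instead of re-opening index.faiss by hand: load the (cached) store via load_vector_store(), mutate it, persist it with save_local(), then call refresh_vs_cache() so other readers pick up the change. Distilled into a sketch (the helper name is mine, not part of the commit):

    def _mutate_vector_store(kb_name, vs_path, embeddings, mutate):
        vector_store = load_vector_store(kb_name,
                                         embeddings=embeddings,
                                         tick=_VECTOR_STORE_TICKS.get(kb_name, 0))
        mutate(vector_store)               # e.g. add_documents(docs) or delete(ids)
        vector_store.save_local(vs_path)   # persist the updated index to disk
        refresh_vs_cache(kb_name)          # bump the tick so cached copies are refreshed
        return vector_store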
@@ -8,6 +8,7 @@ root_path = Path(__file__).parent.parent.parent
 sys.path.append(str(root_path))
 from configs.server_config import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
+from server.knowledge_base.utils import get_kb_path
 
 from pprint import pprint
@@ -22,8 +23,11 @@ test_files = {
 
 def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
+    if not Path(get_kb_path(kb)).exists():
+        return
+
     url = api_base_url + api
     print("\n删除知识库")
     print("\n测试知识库存在,需要删除")
     r = requests.post(url, json=kb)
     data = r.json()
     pprint(data)
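The new guard turns test_delete_kb_before() into a no-op on a clean checkout, so the test file can run before the knowledge base has ever been created. A hypothetical companion check (not in this commit) that the deletion really removed the on-disk directory could look like:

    # Hypothetical follow-up assertion; endpoint and helpers as used above.
    def test_kb_really_deleted(api="/knowledge_base/delete_knowledge_base"):
        requests.post(api_base_url + api, json=kb)
        assert not Path(get_kb_path(kb)).exists()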