2023-08-06 23:43:54 +08:00
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
|
2023-08-10 00:36:51 +08:00
|
|
|
from configs.model_config import (
|
|
|
|
|
KB_ROOT_PATH,
|
|
|
|
|
CACHED_VS_NUM,
|
|
|
|
|
EMBEDDING_MODEL,
|
2023-09-08 08:55:12 +08:00
|
|
|
SCORE_THRESHOLD,
|
|
|
|
|
logger,
|
2023-08-10 00:36:51 +08:00
|
|
|
)
|
2023-08-09 10:46:01 +08:00
|
|
|
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
2023-08-06 23:43:54 +08:00
|
|
|
from functools import lru_cache
|
2023-08-09 10:46:01 +08:00
|
|
|
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
2023-08-06 23:43:54 +08:00
|
|
|
from langchain.vectorstores import FAISS
|
|
|
|
|
from langchain.embeddings.base import Embeddings
|
2023-09-01 22:54:57 +08:00
|
|
|
from typing import List, Dict, Optional
|
2023-08-06 23:43:54 +08:00
|
|
|
from langchain.docstore.document import Document
|
2023-08-31 17:33:43 +08:00
|
|
|
from server.utils import torch_gc, embedding_device
|
2023-08-06 23:43:54 +08:00
|
|
|
|
2023-08-09 10:46:01 +08:00
|
|
|
|
2023-08-06 23:43:54 +08:00
|
|
|
# Cache-busting counters: knowledge base name -> tick.
# The tick is forwarded into load_faiss_vector_store(); bumping it changes the
# lru_cache key and therefore forces the vector store to be reloaded from disk
# on the next access (see refresh_vs_cache below).
_VECTOR_STORE_TICKS = {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@lru_cache(CACHED_VS_NUM)
def load_faiss_vector_store(
        knowledge_base_name: str,
        embed_model: str = EMBEDDING_MODEL,
        embed_device: str = None,  # None -> detect at call time via embedding_device()
        embeddings: Embeddings = None,
        tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
) -> FAISS:
    """Load (and memoize) the FAISS vector store for a knowledge base.

    Results are cached by ``lru_cache`` keyed on all arguments; ``tick`` exists
    purely to change the cache key so callers can force a reload (see
    ``refresh_vs_cache``).

    Args:
        knowledge_base_name: name of the KB whose store should be loaded.
        embed_model: embedding model name used when no ``embeddings`` is given.
        embed_device: device for the embedding model. Defaults to ``None``,
            which resolves to ``embedding_device()`` when the function runs —
            the previous ``= embedding_device()`` default was evaluated once at
            import time and could not react to later device changes.
        embeddings: pre-built embeddings object; built from ``embed_model`` /
            ``embed_device`` when omitted.
        tick: cache-busting counter supplied from ``_VECTOR_STORE_TICKS``.

    Returns:
        The loaded (or freshly initialized, empty) FAISS index.
    """
    logger.info(f"loading vector store in '{knowledge_base_name}'.")
    vs_path = get_vs_path(knowledge_base_name)
    if embed_device is None:
        embed_device = embedding_device()
    if embeddings is None:
        embeddings = load_embeddings(embed_model, embed_device)

    # exist_ok avoids a race if another worker creates the directory first.
    os.makedirs(vs_path, exist_ok=True)

    if "index.faiss" in os.listdir(vs_path):
        search_index = FAISS.load_local(vs_path, embeddings, normalize_L2=True)
    else:
        # Create an empty vector store: FAISS.from_documents needs at least one
        # document to infer dimensions, so add a dummy doc and delete it again.
        doc = Document(page_content="init", metadata={})
        search_index = FAISS.from_documents([doc], embeddings, normalize_L2=True)
        ids = list(search_index.docstore._dict)
        search_index.delete(ids)
        search_index.save_local(vs_path)

    if tick == 0:  # vector store is loaded first time
        _VECTOR_STORE_TICKS[knowledge_base_name] = 0

    return search_index
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def refresh_vs_cache(kb_name: str):
    """Invalidate the cached vector store of *kb_name*.

    Increments the knowledge base's tick counter so that the next call to
    ``load_faiss_vector_store`` sees a new ``lru_cache`` key and reloads the
    store from disk.
    """
    previous = _VECTOR_STORE_TICKS.get(kb_name, 0)
    _VECTOR_STORE_TICKS[kb_name] = previous + 1
    logger.info(f"知识库 {kb_name} 缓存刷新:{_VECTOR_STORE_TICKS[kb_name]}")
|
2023-08-06 23:43:54 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class FaissKBService(KBService):
    """KBService implementation backed by a local FAISS index on disk.

    Vector store instances are obtained through the module-level
    ``load_faiss_vector_store`` cache; cache invalidation goes through
    ``refresh_vs_cache`` and the ``_VECTOR_STORE_TICKS`` counters.
    """

    # Filesystem path of this KB's FAISS index directory (set in do_init).
    vs_path: str
    # Root folder of this KB under KB_ROOT_PATH (set in do_init).
    kb_path: str

    def vs_type(self) -> str:
        """Return the vector-store type identifier for this service."""
        return SupportedVSType.FAISS

    def get_vs_path(self):
        """Return the path of the ``vector_store`` directory inside the KB folder."""
        return os.path.join(self.get_kb_path(), "vector_store")

    def get_kb_path(self):
        """Return the root folder of this knowledge base under KB_ROOT_PATH."""
        return os.path.join(KB_ROOT_PATH, self.kb_name)

    def load_vector_store(self) -> FAISS:
        """Load this KB's FAISS store via the module-level lru_cache.

        The current tick is passed so that a previously refreshed cache entry
        is not reused.
        """
        return load_faiss_vector_store(
            knowledge_base_name=self.kb_name,
            embed_model=self.embed_model,
            tick=_VECTOR_STORE_TICKS.get(self.kb_name, 0),
        )

    def save_vector_store(self, vector_store: FAISS = None):
        """Persist *vector_store* (or the currently loaded one) to ``vs_path``.

        Returns the vector store that was saved.
        """
        vector_store = vector_store or self.load_vector_store()
        vector_store.save_local(self.vs_path)
        return vector_store

    def refresh_vs_cache(self):
        """Invalidate this KB's cached vector store (delegates to module function)."""
        refresh_vs_cache(self.kb_name)

    def get_doc_by_id(self, id: str) -> Optional[Document]:
        """Return the stored Document with the given docstore id, or None.

        NOTE(review): reaches into langchain's private ``docstore._dict`` —
        may break with langchain upgrades.
        """
        vector_store = self.load_vector_store()
        return vector_store.docstore._dict.get(id)

    def do_init(self):
        """Resolve and cache the KB and vector-store paths on this instance."""
        self.kb_path = self.get_kb_path()
        self.vs_path = self.get_vs_path()

    def do_create_kb(self):
        """Create the vector-store directory and initialize an (empty) index."""
        if not os.path.exists(self.vs_path):
            os.makedirs(self.vs_path)
        # Loading creates and saves an empty index if none exists yet.
        self.load_vector_store()

    def do_drop_kb(self):
        """Delete the whole knowledge base: vector store first, then the KB folder."""
        self.clear_vs()
        shutil.rmtree(self.kb_path)

    def do_search(self,
                  query: str,
                  top_k: int,
                  score_threshold: float = SCORE_THRESHOLD,
                  embeddings: Embeddings = None,  # accepted for interface parity; unused here
                  ) -> List[Document]:
        """Return up to *top_k* (Document, score) pairs matching *query*.

        NOTE(review): despite the annotation, similarity_search_with_score
        returns (Document, score) tuples, not bare Documents.
        """
        search_index = self.load_vector_store()
        docs = search_index.similarity_search_with_score(query, k=top_k, score_threshold=score_threshold)
        return docs

    def do_add_doc(self,
                   docs: List[Document],
                   **kwargs,
                   ) -> List[Dict]:
        """Add *docs* to the index and return [{"id", "metadata"}, ...] infos.

        Pass ``not_refresh_vs_cache=True`` in kwargs to skip persisting and
        cache invalidation (useful when batching many additions).
        """
        vector_store = self.load_vector_store()
        ids = vector_store.add_documents(docs)
        doc_infos = [{"id": id, "metadata": doc.metadata} for id, doc in zip(ids, docs)]
        torch_gc()  # release GPU/CPU memory held by the embedding step
        if not kwargs.get("not_refresh_vs_cache"):
            # Persist before invalidating so the next load sees the new docs.
            vector_store.save_local(self.vs_path)
            self.refresh_vs_cache()
        return doc_infos

    def do_delete_doc(self,
                      kb_file: KnowledgeFile,
                      **kwargs):
        """Delete all vectors whose metadata "source" equals *kb_file.filepath*.

        Returns None when nothing matched, otherwise the (mutated) vector
        store. Honors ``not_refresh_vs_cache`` like ``do_add_doc``.
        """
        vector_store = self.load_vector_store()

        ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata.get("source") == kb_file.filepath]
        if len(ids) == 0:
            return None

        vector_store.delete(ids)
        if not kwargs.get("not_refresh_vs_cache"):
            vector_store.save_local(self.vs_path)
            self.refresh_vs_cache()

        return vector_store

    def do_clear_vs(self):
        """Wipe the on-disk vector store, recreate the empty directory, and refresh the cache."""
        shutil.rmtree(self.vs_path)
        os.makedirs(self.vs_path)
        self.refresh_vs_cache()

    def exist_doc(self, file_name: str):
        """Report where *file_name* exists: "in_db", "in_folder", or False.

        First consults the base class (database records), then falls back to
        checking the KB's ``content`` folder on disk.
        """
        if super().exist_doc(file_name):
            return "in_db"

        content_path = os.path.join(self.kb_path, "content")
        if os.path.isfile(os.path.join(content_path, file_name)):
            return "in_folder"
        else:
            return False
|
2023-08-27 11:21:10 +08:00
|
|
|
|
2023-08-28 13:50:35 +08:00
|
|
|
|
|
|
|
|
def _demo():
    """Smoke test: add, delete, drop, then search against a throwaway KB."""
    service = FaissKBService("test")
    service.add_doc(KnowledgeFile("README.md", "test"))
    service.delete_doc(KnowledgeFile("README.md", "test"))
    service.do_drop_kb()
    print(service.search_docs("如何启动api服务"))


if __name__ == '__main__':
    _demo()
|