Restructure the knowledge-base modules under server: move knowledge-base operations from knowledge_base_chat to knowledge_base.utils; optimize embeddings loading in kb_doc_api.
commit 88682c87ff
parent 9a18218293
server/chat/knowledge_base_chat.py

@@ -1,8 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
+                                  VECTOR_SEARCH_TOP_K)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 import os
@@ -13,41 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from server.knowledge_base.utils import get_vs_path
-from functools import lru_cache
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-):
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def lookup_vs(
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = VECTOR_SEARCH_TOP_K,
-        embedding_model: str = EMBEDDING_MODEL,
-        embedding_device: str = EMBEDDING_DEVICE,
-):
-    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
-    docs = search_index.similarity_search(query, k=top_k)
-    return docs
+from server.knowledge_base.utils import lookup_vs
 
 
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
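
With retrieval moved into server.knowledge_base.utils, the chat module keeps a single lookup_vs import. A minimal sketch of how the handler can consume it, assuming PROMPT_TEMPLATE exposes context/question placeholders (illustrative only, not lines from this commit):

    # illustrative usage inside knowledge_base_chat (assumed, not from the commit)
    docs = lookup_vs(query, knowledge_base_name, top_k=VECTOR_SEARCH_TOP_K)
    context = "\n".join(doc.page_content for doc in docs)
    prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])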
server/knowledge_base/kb_doc_api.py
@@ -8,6 +8,7 @@ from server.utils import BaseResponse, ListResponse, torch_gc
 from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
                                          get_vs_path, get_file_path, file2text)
 from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE
+from server.knowledge_base.utils import load_embeddings, refresh_vs_cache
 
 
 async def list_docs(knowledge_base_name: str):
@@ -55,8 +56,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
     filepath = get_file_path(knowledge_base_name, file.filename)
     docs = file2text(filepath)
     loaded_files = [file]
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
-                                       model_kwargs={'device': EMBEDDING_DEVICE})
+    embeddings = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)
     if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
         vector_store = FAISS.load_local(vs_path, embeddings)
         vector_store.add_documents(docs)
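
Because load_embeddings is wrapped in @lru_cache(1) (defined in utils.py below), repeated uploads with the same model and device reuse one HuggingFaceEmbeddings instance instead of re-initializing the model on every request. A quick illustration of the cache behavior, following standard functools.lru_cache semantics:

    e1 = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)
    e2 = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)
    assert e1 is e2  # cache hit: the embedding model is loaded only once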
@@ -69,6 +69,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
         vector_store.save_local(vs_path)
     if len(loaded_files) > 0:
         file_status = f"成功上传文件 {file.filename}"
+        refresh_vs_cache(knowledge_base_name)
         return BaseResponse(code=200, msg=file_status)
     else:
         file_status = f"上传文件 {file.filename} 失败"
@@ -95,6 +96,7 @@ async def delete_doc(knowledge_base_name: str,
     # TODO: rewrite removal of a file from the vector store
     status = ""  # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))
     if "success" in status:
+        refresh_vs_cache(knowledge_base_name)
         return BaseResponse(code=200, msg=f"document {doc_name} delete success")
     else:
         return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
@@ -104,9 +106,10 @@ async def delete_doc(knowledge_base_name: str,
 
 async def update_doc():
     # TODO: replace an existing file
+    # refresh_vs_cache(knowledge_base_name)
     pass
 
 
 async def download_doc():
     # TODO: download a file
     pass
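
Both write paths (upload and delete) now end with refresh_vs_cache(knowledge_base_name), so the next lookup_vs sees the modified index rather than a stale cached copy. A sketch of the intended sequence; the names come from the diff, but the flow itself is an assumption:

    # assumed end-to-end flow, not code from the commit
    await upload_doc(file, knowledge_base_name)   # writes the index, then bumps the cache tick
    docs = lookup_vs(query, knowledge_base_name)  # cache miss: vector store reloaded from disk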
server/knowledge_base/utils.py
@@ -1,5 +1,13 @@
 import os
 from configs.model_config import KB_ROOT_PATH
+from langchain.vectorstores import FAISS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
+                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
+from functools import lru_cache
+
+
+_VECTOR_STORE_TICKS = {}
 
 
 def get_kb_path(knowledge_base_name: str):
@@ -36,6 +44,46 @@ def file2text(filepath):
     return docs
 
 
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
+
+
+@lru_cache(CACHED_VS_NUM)
+def load_vector_store(
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
+        tick: int,  # tick is bumped by upload_doc etc. to force a cache refresh
+):
+    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
+    embeddings = load_embeddings(embedding_model, embedding_device)
+    vs_path = get_vs_path(knowledge_base_name)
+    search_index = FAISS.load_local(vs_path, embeddings)
+    return search_index
+
+
+def lookup_vs(
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        embedding_model: str = EMBEDDING_MODEL,
+        embedding_device: str = EMBEDDING_DEVICE,
+):
+    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device, _VECTOR_STORE_TICKS.get(knowledge_base_name, 0))
+    docs = search_index.similarity_search(query, k=top_k)
+    return docs
+
+
+def refresh_vs_cache(kb_name: str):
+    '''
+    invalidate the cached vector store so it is reloaded on next access
+    '''
+    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
+
+
 if __name__ == "__main__":
     filepath = "/Users/liuqian/PycharmProjects/chatchat/knowledge_base/123/content/test.txt"
     docs = file2text(filepath)
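
The tick parameter is the entire invalidation mechanism: functools.lru_cache keys on all arguments, so bumping a knowledge base's tick changes the cache key and forces load_vector_store to re-read the index from disk on the next call. A self-contained sketch of the same pattern under assumed names (standalone illustration, not code from this repository):

    from functools import lru_cache

    _TICKS = {}

    @lru_cache(maxsize=8)
    def load_resource(name: str, tick: int):
        # expensive load; runs again whenever (name, tick) is a new cache key
        print(f"loading {name} (tick={tick})")
        return object()

    def get_resource(name: str):
        return load_resource(name, _TICKS.get(name, 0))

    def refresh(name: str):
        # bump the tick so the next get_resource() misses the cache
        _TICKS[name] = _TICKS.get(name, 0) + 1

    get_resource("kb1")  # prints: loading kb1 (tick=0)
    get_resource("kb1")  # cache hit, silent
    refresh("kb1")
    get_resource("kb1")  # prints: loading kb1 (tick=1)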