diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 23a380a..58ee7e0 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -1,8 +1,7 @@
 from fastapi import Body
 from fastapi.responses import StreamingResponse
 from configs.model_config import (llm_model_dict, LLM_MODEL, PROMPT_TEMPLATE,
-                                  CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
-                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
+                                  VECTOR_SEARCH_TOP_K)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
 import os
@@ -13,41 +12,7 @@ from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
-from langchain.vectorstores import FAISS
-from langchain.embeddings.huggingface import HuggingFaceEmbeddings
-from server.knowledge_base.utils import get_vs_path
-from functools import lru_cache
-
-
-@lru_cache(1)
-def load_embeddings(model: str, device: str):
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
-                                       model_kwargs={'device': device})
-    return embeddings
-
-
-@lru_cache(CACHED_VS_NUM)
-def load_vector_store(
-        knowledge_base_name: str,
-        embedding_model: str,
-        embedding_device: str,
-):
-    embeddings = load_embeddings(embedding_model, embedding_device)
-    vs_path = get_vs_path(knowledge_base_name)
-    search_index = FAISS.load_local(vs_path, embeddings)
-    return search_index
-
-
-def lookup_vs(
-        query: str,
-        knowledge_base_name: str,
-        top_k: int = VECTOR_SEARCH_TOP_K,
-        embedding_model: str = EMBEDDING_MODEL,
-        embedding_device: str = EMBEDDING_DEVICE,
-):
-    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device)
-    docs = search_index.similarity_search(query, k=top_k)
-    return docs
+from server.knowledge_base.utils import lookup_vs
 
 
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index e0543ad..bc0ee64 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -8,6 +8,7 @@ from server.utils import BaseResponse, ListResponse, torch_gc
 from server.knowledge_base.utils import (validate_kb_name, get_kb_path, get_doc_path,
                                          get_vs_path, get_file_path, file2text)
 from configs.model_config import embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE
+from server.knowledge_base.utils import load_embeddings, refresh_vs_cache
 
 
 async def list_docs(knowledge_base_name: str):
@@ -55,8 +56,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
     filepath = get_file_path(knowledge_base_name, file.filename)
     docs = file2text(filepath)
     loaded_files = [file]
-    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[EMBEDDING_MODEL],
-                                       model_kwargs={'device': EMBEDDING_DEVICE})
+    embeddings = load_embeddings(EMBEDDING_MODEL, EMBEDDING_DEVICE)
     if os.path.exists(vs_path) and "index.faiss" in os.listdir(vs_path):
         vector_store = FAISS.load_local(vs_path, embeddings)
         vector_store.add_documents(docs)
@@ -69,6 +69,7 @@ async def upload_doc(file: UploadFile = File(description="上传文件"),
         vector_store.save_local(vs_path)
     if len(loaded_files) > 0:
         file_status = f"成功上传文件 {file.filename}"
+        refresh_vs_cache(knowledge_base_name)
         return BaseResponse(code=200, msg=file_status)
     else:
         file_status = f"上传文件 {file.filename} 失败"
@@ -95,6 +96,7 @@ async def delete_doc(knowledge_base_name: str,
     # TODO: 重写从向量库中删除文件
     status = ""  # local_doc_qa.delete_file_from_vector_store(doc_path, get_vs_path(knowledge_base_name))
     if "success" in status:
+        refresh_vs_cache(knowledge_base_name)
         return BaseResponse(code=200, msg=f"document {doc_name} delete success")
     else:
         return BaseResponse(code=500, msg=f"document {doc_name} delete fail")
@@ -104,9 +106,9 @@ async def delete_doc(knowledge_base_name: str,
 
 async def update_doc():
     # TODO: 替换文件
+    # refresh_vs_cache(knowledge_base_name)
     pass
 
-
 async def download_doc():
     # TODO: 下载文件
     pass
diff --git a/server/knowledge_base/utils.py b/server/knowledge_base/utils.py
index 39fa015..383c89d 100644
--- a/server/knowledge_base/utils.py
+++ b/server/knowledge_base/utils.py
@@ -1,5 +1,13 @@
 import os
 from configs.model_config import KB_ROOT_PATH
+from langchain.vectorstores import FAISS
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
+from configs.model_config import (CACHED_VS_NUM, VECTOR_SEARCH_TOP_K,
+                                  embedding_model_dict, EMBEDDING_MODEL, EMBEDDING_DEVICE)
+from functools import lru_cache
+
+
+_VECTOR_STORE_TICKS = {}
 
 
 def get_kb_path(knowledge_base_name: str):
@@ -36,6 +44,46 @@ def file2text(filepath):
     return docs
 
 
+@lru_cache(1)
+def load_embeddings(model: str, device: str):
+    embeddings = HuggingFaceEmbeddings(model_name=embedding_model_dict[model],
+                                       model_kwargs={'device': device})
+    return embeddings
+
+
+@lru_cache(CACHED_VS_NUM)
+def load_vector_store(
+        knowledge_base_name: str,
+        embedding_model: str,
+        embedding_device: str,
+        tick: int,  # bumped by refresh_vs_cache (e.g. after upload_doc) so the cached store is reloaded.
+):
+    print(f"loading vector store in '{knowledge_base_name}' with '{embedding_model}' embeddings.")
+    embeddings = load_embeddings(embedding_model, embedding_device)
+    vs_path = get_vs_path(knowledge_base_name)
+    search_index = FAISS.load_local(vs_path, embeddings)
+    return search_index
+
+
+def lookup_vs(
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        embedding_model: str = EMBEDDING_MODEL,
+        embedding_device: str = EMBEDDING_DEVICE,
+):
+    search_index = load_vector_store(knowledge_base_name, embedding_model, embedding_device, _VECTOR_STORE_TICKS.get(knowledge_base_name, 0))
+    docs = search_index.similarity_search(query, k=top_k)
+    return docs
+
+
+def refresh_vs_cache(kb_name: str):
+    '''
+    Invalidate the cached vector store for kb_name so it is reloaded on next access.
+    '''
+    _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1
+
+
 if __name__ == "__main__":
     filepath = "/Users/liuqian/PycharmProjects/chatchat/knowledge_base/123/content/test.txt"
     docs = file2text(filepath)
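
Note on the caching pattern above: load_vector_store is wrapped in lru_cache, so a knowledge base's FAISS index would otherwise be served stale after upload_doc/delete_doc rewrite it on disk. Threading a per-knowledge-base tick through the argument list makes the counter part of the cache key, and bumping it via refresh_vs_cache turns the next lookup into a cache miss. A minimal, self-contained sketch of the same pattern (the names load_resource, _TICKS, lookup, and refresh_cache are illustrative stand-ins, not part of this diff):

    from functools import lru_cache

    _TICKS = {}  # per-resource counters, mirroring _VECTOR_STORE_TICKS

    @lru_cache(maxsize=16)
    def load_resource(name: str, tick: int):
        # tick participates in the cache key, so a bumped tick is a miss
        print(f"loading '{name}' (tick={tick})")
        return {"name": name, "tick": tick}

    def refresh_cache(name: str):
        # the next lookup passes a tick no cached entry has seen
        _TICKS[name] = _TICKS.get(name, 0) + 1

    def lookup(name: str):
        return load_resource(name, _TICKS.get(name, 0))

    lookup("kb1")         # miss: prints "loading 'kb1' (tick=0)"
    lookup("kb1")         # hit: prints nothing
    refresh_cache("kb1")
    lookup("kb1")         # miss again: prints "loading 'kb1' (tick=1)"

One trade-off worth noting: bumping the tick does not evict the stale entry; it simply stops being requested and ages out of the LRU, so up to CACHED_VS_NUM superseded stores can linger in memory until then.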