diff --git a/server/api.py b/server/api.py index ee3f152..eecd478 100644 --- a/server/api.py +++ b/server/api.py @@ -11,8 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) -from server.knowledge_base import (list_kbs, create_kb, delete_kb, - list_docs, upload_doc, delete_doc, +from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb +from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc, update_doc, recreate_vector_store) from server.utils import BaseResponse, ListResponse diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 90c3deb..636e606 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -91,7 +91,7 @@ async def download_doc(): async def recreate_vector_store( knowledge_base_name: str, allow_empty_kb: bool = True, - vs_type: Union[str, SupportedVSType] = "faiss", + vs_type: str = "faiss", ): ''' recreate vector store from the content. diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index c7cbf46..5af47e7 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -3,7 +3,6 @@ from abc import ABC, abstractmethod import os from functools import lru_cache -from langchain.embeddings import HuggingFaceEmbeddings from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document @@ -34,6 +33,9 @@ class KBService(ABC): self.doc_path = get_doc_path(self.kb_name) self.do_init() + def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings: + return load_embeddings(self.embed_model, embed_device) + def create_kb(self): """ 创建知识库 @@ -63,7 +65,7 @@ class KBService(ABC): 向知识库添加文件 """ docs = kb_file.file2text() - embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) + embeddings = self._load_embeddings() self.do_add_doc(docs, embeddings) status = add_doc_to_db(kb_file) return status @@ -88,8 +90,8 @@ class KBService(ABC): def search_docs(self, query: str, top_k: int = VECTOR_SEARCH_TOP_K, - embedding_device: str = EMBEDDING_DEVICE, ): - embeddings = load_embeddings(self.embed_model, embedding_device) + ): + embeddings = self._load_embeddings() docs = self.do_search(query, top_k, embeddings) return docs @@ -142,7 +144,8 @@ class KBService(ABC): @abstractmethod def do_add_doc(self, docs: List[Document], - embeddings: Embeddings): + embeddings: Embeddings, + ): """ 向知识库添加文档子类实自己逻辑 """ diff --git a/server/knowledge_base/kb_service/faiss_kb_service.py b/server/knowledge_base/kb_service/faiss_kb_service.py index 9798206..064eb43 100644 --- a/server/knowledge_base/kb_service/faiss_kb_service.py +++ b/server/knowledge_base/kb_service/faiss_kb_service.py @@ -1,28 +1,40 @@ import os import shutil -from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE -from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings +from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE +from server.knowledge_base.kb_service.base import KBService, SupportedVSType from functools import lru_cache -from server.knowledge_base.utils import get_vs_path, KnowledgeFile +from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile from langchain.vectorstores import FAISS from langchain.embeddings.base import Embeddings +from langchain.embeddings.huggingface import HuggingFaceEmbeddings from typing import List from langchain.docstore.document import Document from server.utils import torch_gc import numpy as np + +# make HuggingFaceEmbeddings hashable +def _embeddings_hash(self): + return hash(self.model_name) +HuggingFaceEmbeddings.__hash__ = _embeddings_hash + + _VECTOR_STORE_TICKS = {} @lru_cache(CACHED_VS_NUM) def load_vector_store( knowledge_base_name: str, - embeddings: Embeddings, - tick: int, # tick will be changed by upload_doc etc. and make cache refreshed. + embed_model: str = EMBEDDING_MODEL, + embed_device: str = EMBEDDING_DEVICE, + embeddings: Embeddings = None, + tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed. ): print(f"loading vector store in '{knowledge_base_name}'.") vs_path = get_vs_path(knowledge_base_name) + if embeddings is None: + embeddings = load_embeddings(embed_model, embed_device) search_index = FAISS.load_local(vs_path, embeddings) return search_index @@ -84,14 +96,15 @@ class FaissKBService(KBService): embeddings: Embeddings, ) -> List[Document]: search_index = load_vector_store(self.kb_name, - embeddings, - _VECTOR_STORE_TICKS.get(self.kb_name)) + embeddings=embeddings, + tick=_VECTOR_STORE_TICKS.get(self.kb_name)) docs = search_index.similarity_search(query, k=top_k) return docs def do_add_doc(self, docs: List[Document], - embeddings: Embeddings): + embeddings: Embeddings, + ): if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): vector_store = FAISS.load_local(self.vs_path, embeddings) vector_store.add_documents(docs) @@ -99,14 +112,15 @@ class FaissKBService(KBService): else: if not os.path.exists(self.vs_path): os.makedirs(self.vs_path) - vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表 + vector_store = FAISS.from_documents( + docs, embeddings) # docs 为Document列表 torch_gc() vector_store.save_local(self.vs_path) refresh_vs_cache(self.kb_name) def do_delete_doc(self, kb_file: KnowledgeFile): - embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE) + embeddings = self._load_embeddings() if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path): vector_store = FAISS.load_local(self.vs_path, embeddings) ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath] diff --git a/server/knowledge_base/kb_service/milvus_kb_service.py b/server/knowledge_base/kb_service/milvus_kb_service.py index 641a43b..8898a0e 100644 --- a/server/knowledge_base/kb_service/milvus_kb_service.py +++ b/server/knowledge_base/kb_service/milvus_kb_service.py @@ -6,7 +6,7 @@ from langchain.vectorstores import Milvus from configs.model_config import EMBEDDING_DEVICE, kbs_config -from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings +from server.knowledge_base.kb_service.base import KBService, SupportedVSType from server.knowledge_base.utils import KnowledgeFile @@ -33,11 +33,10 @@ class MilvusKBService(KBService): def vs_type(self) -> str: return SupportedVSType.MILVUS - def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None): - _embeddings = embeddings - if _embeddings is None: - _embeddings = load_embeddings(self.embed_model, embedding_device) - self.milvus = Milvus(embedding_function=_embeddings, + def _load_milvus(self, embeddings: Embeddings = None): + if embeddings is None: + embeddings = self._load_embeddings() + self.milvus = Milvus(embedding_function=embeddings, collection_name=self.kb_name, connection_args=kbs_config.get("milvus")) def do_init(self): diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 7054115..4f20757 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -169,7 +169,7 @@ class ApiRequest: for chunk in iter_over_async(response.body_iterator, loop): if as_json and chunk: yield json.loads(chunk) - else: + elif chunk.strip(): yield chunk def _httpx_stream2generator( @@ -184,7 +184,7 @@ class ApiRequest: for chunk in r.iter_text(None): if as_json and chunk: yield json.loads(chunk) - else: + elif chunk.strip(): yield chunk # 对话相关操作 @@ -250,6 +250,7 @@ class ApiRequest: query: str, knowledge_base_name: str, top_k: int = VECTOR_SEARCH_TOP_K, + history: List[Dict] = [], no_remote_api: bool = None, ): ''' @@ -260,12 +261,12 @@ class ApiRequest: if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat - response = knowledge_base_chat(query, knowledge_base_name, top_k) + response = knowledge_base_chat(query, knowledge_base_name, top_k, history) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/chat/knowledge_base_chat", - json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k}, + json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history}, stream=True, ) return self._httpx_stream2generator(response, as_json=True)