update knowledge base KBService and API:

1. make HuggingFaceEmbeddings hashable
2. unify the embeddings-loading method across all KBService subclasses
3. make ApiRequest skip empty content when streaming JSON, to avoid a
   KeyError on the parsed dict
liunux4odoo 2023-08-09 10:46:01 +08:00
parent ec76adc81d
commit b98f5fd0b9
6 changed files with 45 additions and 28 deletions

View File

@@ -11,8 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
-from server.knowledge_base import (list_kbs, create_kb, delete_kb,
-                                   list_docs, upload_doc, delete_doc,
-                                   update_doc, recreate_vector_store)
+from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
+                                              update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse

View File

@@ -91,7 +91,7 @@ async def download_doc():
 async def recreate_vector_store(
     knowledge_base_name: str,
     allow_empty_kb: bool = True,
-    vs_type: Union[str, SupportedVSType] = "faiss",
+    vs_type: str = "faiss",
 ):
     '''
     recreate vector store from the content.

View File

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 import os
 from functools import lru_cache
-from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
@@ -34,6 +33,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+        return load_embeddings(self.embed_model, embed_device)
+
     def create_kb(self):
         """
         Create the knowledge base
@@ -63,7 +65,7 @@ class KBService(ABC):
         Add a file to the knowledge base
         """
         docs = kb_file.file2text()
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         self.do_add_doc(docs, embeddings)
         status = add_doc_to_db(kb_file)
         return status
@@ -88,8 +90,8 @@ class KBService(ABC):
     def search_docs(self,
                     query: str,
                     top_k: int = VECTOR_SEARCH_TOP_K,
-                    embedding_device: str = EMBEDDING_DEVICE, ):
-        embeddings = load_embeddings(self.embed_model, embedding_device)
+                    ):
+        embeddings = self._load_embeddings()
         docs = self.do_search(query, top_k, embeddings)
         return docs
@@ -142,7 +144,8 @@ class KBService(ABC):
     @abstractmethod
     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
         """
         Add documents to the knowledge base (subclasses implement their own logic)
         """

View File

@@ -1,28 +1,40 @@
 import os
 import shutil
-from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from functools import lru_cache
-from server.knowledge_base.utils import get_vs_path, KnowledgeFile
+from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from typing import List
 from langchain.docstore.document import Document
 from server.utils import torch_gc
 import numpy as np

+
+# make HuggingFaceEmbeddings hashable
+def _embeddings_hash(self):
+    return hash(self.model_name)
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+
 _VECTOR_STORE_TICKS = {}

 @lru_cache(CACHED_VS_NUM)
 def load_vector_store(
         knowledge_base_name: str,
-        embeddings: Embeddings,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
+        embed_model: str = EMBEDDING_MODEL,
+        embed_device: str = EMBEDDING_DEVICE,
+        embeddings: Embeddings = None,
+        tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ):
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
+    if embeddings is None:
+        embeddings = load_embeddings(embed_model, embed_device)
     search_index = FAISS.load_local(vs_path, embeddings)
     return search_index
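Why the __hash__ patch: functools.lru_cache builds its cache key from the call arguments, so each argument must be hashable, and load_vector_store above now accepts an embeddings object. HuggingFaceEmbeddings (a pydantic model) is not hashable by default, so the cached call would raise TypeError. A minimal sketch of the failure and the fix, using a stand-in class:

    from functools import lru_cache

    class FakeEmbeddings:          # stand-in for HuggingFaceEmbeddings
        __hash__ = None            # mutable pydantic models are unhashable like this
        def __init__(self, model_name: str):
            self.model_name = model_name

    @lru_cache()
    def load_vector_store(kb_name: str, embeddings=None, tick: int = 0):
        return f"index for {kb_name} (tick {tick})"

    emb = FakeEmbeddings("text2vec")
    try:
        load_vector_store("samples", embeddings=emb)
    except TypeError:
        print("unhashable before the patch")

    # the commit's fix: key the hash on model_name
    FakeEmbeddings.__hash__ = lambda self: hash(self.model_name)
    print(load_vector_store("samples", embeddings=emb))  # now cacheable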
@@ -84,14 +96,15 @@ class FaissKBService(KBService):
                   embeddings: Embeddings,
                   ) -> List[Document]:
         search_index = load_vector_store(self.kb_name,
-                                         embeddings,
-                                         _VECTOR_STORE_TICKS.get(self.kb_name))
+                                         embeddings=embeddings,
+                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name))
         docs = search_index.similarity_search(query, k=top_k)
         return docs

     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             vector_store.add_documents(docs)
@@ -99,14 +112,15 @@ class FaissKBService(KBService):
         else:
             if not os.path.exists(self.vs_path):
                 os.makedirs(self.vs_path)
-            vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Document objects
+            vector_store = FAISS.from_documents(
+                docs, embeddings)  # docs is a list of Document objects
         torch_gc()
         vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)

     def do_delete_doc(self,
                       kb_file: KnowledgeFile):
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
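The refresh_vs_cache(self.kb_name) call above is what keeps the lru_cache honest: the tick argument exists only so that bumping it changes load_vector_store's argument tuple and forces a cache miss. refresh_vs_cache itself is not shown in this diff; under the assumption that it simply increments the per-KB counter, the mechanism is:

    _VECTOR_STORE_TICKS = {}

    def refresh_vs_cache(kb_name: str):
        # assumed behavior: bump the per-KB tick so the next
        # load_vector_store(..., tick=_VECTOR_STORE_TICKS.get(kb_name))
        # call sees a new argument tuple and reloads the index from disk
        _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1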

View File

@@ -6,7 +6,7 @@ from langchain.vectorstores import Milvus
 from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
@@ -33,11 +33,10 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS

-    def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.milvus = Milvus(embedding_function=_embeddings,
+    def _load_milvus(self, embeddings: Embeddings = None):
+        if embeddings is None:
+            embeddings = self._load_embeddings()
+        self.milvus = Milvus(embedding_function=embeddings,
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

     def do_init(self):

View File

@@ -169,7 +169,7 @@ class ApiRequest:
         for chunk in iter_over_async(response.body_iterator, loop):
             if as_json and chunk:
                 yield json.loads(chunk)
-            else:
+            elif chunk.strip():
                 yield chunk

     def _httpx_stream2generator(
@@ -184,7 +184,7 @@ class ApiRequest:
         for chunk in r.iter_text(None):
             if as_json and chunk:
                 yield json.loads(chunk)
-            else:
+            elif chunk.strip():
                 yield chunk
    # chat-related operations
@@ -250,6 +250,7 @@ class ApiRequest:
         query: str,
         knowledge_base_name: str,
         top_k: int = VECTOR_SEARCH_TOP_K,
+        history: List[Dict] = [],
         no_remote_api: bool = None,
     ):
         '''
@@ -260,12 +261,12 @@ class ApiRequest:
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(query, knowledge_base_name, top_k)
+            response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
-                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
+                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
                 stream=True,
             )
             return self._httpx_stream2generator(response, as_json=True)
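With the new history parameter, the web UI can feed prior turns into the knowledge-base chat. A hedged usage sketch; the {'role', 'content'} message shape, the 'samples' KB name, the default ApiRequest constructor, and the 'answer' key are assumptions, not taken from this diff:

    api = ApiRequest()                      # assuming a default base_url
    history = [
        {"role": "user", "content": "What does this knowledge base cover?"},
        {"role": "assistant", "content": "It indexes the project documents."},
    ]
    for d in api.knowledge_base_chat("How do I rebuild the index?",
                                     knowledge_base_name="samples",
                                     top_k=3,
                                     history=history):
        print(d.get("answer", ""), end="")  # streamed JSON dicts; key name assumed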