update knowledge base KBService and API:

1. make HuggingFaceEmbeddings hashable
2. unify the embeddings loading method across all KBService subclasses
3. make ApiRequest skip empty content when streaming JSON, avoiding a dict
   KeyError
liunux4odoo committed on 2023-08-09 10:46:01 +08:00
parent ec76adc81d
commit b98f5fd0b9
6 changed files with 45 additions and 28 deletions


@@ -11,8 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
-from server.knowledge_base import (list_kbs, create_kb, delete_kb,
-                                   list_docs, upload_doc, delete_doc,
-                                   update_doc, recreate_vector_store)
+from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
+                                              update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse


@@ -91,7 +91,7 @@ async def download_doc():
 async def recreate_vector_store(
     knowledge_base_name: str,
     allow_empty_kb: bool = True,
-    vs_type: Union[str, SupportedVSType] = "faiss",
+    vs_type: str = "faiss",
 ):
     '''
     recreate vector store from the content.

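Narrowing vs_type from Union[str, SupportedVSType] to a plain str keeps the FastAPI parameter schema simple while the service layer can still reject unknown backends. A minimal sketch of such a check, assuming SupportedVSType exposes FAISS and MILVUS string constants as the other files in this commit suggest (the helper itself is not part of this commit):

    # sketch only -- validation helper is an assumption, not shown in this diff
    from server.knowledge_base.kb_service.base import SupportedVSType

    def validate_vs_type(vs_type: str) -> str:
        supported = {SupportedVSType.FAISS, SupportedVSType.MILVUS}  # assumed members
        if vs_type not in supported:
            raise ValueError(f"unsupported vector store type: {vs_type}")
        return vs_type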

@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 import os
-from functools import lru_cache
 from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
@@ -34,6 +33,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+        return load_embeddings(self.embed_model, embed_device)
+
     def create_kb(self):
         """
         Create the knowledge base
@@ -63,7 +65,7 @@ class KBService(ABC):
         Add a file to the knowledge base
         """
         docs = kb_file.file2text()
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         self.do_add_doc(docs, embeddings)
         status = add_doc_to_db(kb_file)
         return status
@@ -88,8 +90,8 @@ class KBService(ABC):
     def search_docs(self,
                     query: str,
                     top_k: int = VECTOR_SEARCH_TOP_K,
-                    embedding_device: str = EMBEDDING_DEVICE, ):
-        embeddings = load_embeddings(self.embed_model, embedding_device)
+                    ):
+        embeddings = self._load_embeddings()
         docs = self.do_search(query, top_k, embeddings)
         return docs
@@ -142,7 +144,8 @@ class KBService(ABC):
     @abstractmethod
     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
        """
        Add documents to the knowledge base; subclasses implement their own logic
        """


@@ -1,28 +1,40 @@
 import os
 import shutil
-from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
+from functools import lru_cache
-from server.knowledge_base.utils import get_vs_path, KnowledgeFile
+from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
 from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from typing import List
 from langchain.docstore.document import Document
 from server.utils import torch_gc
 import numpy as np

+# make HuggingFaceEmbeddings hashable
+def _embeddings_hash(self):
+    return hash(self.model_name)
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
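This patch exists because load_vector_store below sits behind @lru_cache, which hashes every argument. HuggingFaceEmbeddings is a pydantic model that defines __eq__ without __hash__, so instances are unhashable by default, and passing one into the cached function would raise TypeError: unhashable type. After the patch, instances hash by model name (illustration only; the model name is an arbitrary example):

    e1 = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
    e2 = HuggingFaceEmbeddings(model_name="GanymedeNil/text2vec-large-chinese")
    assert hash(e1) == hash(e2)  # same model_name -> same lru_cache slot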
 _VECTOR_STORE_TICKS = {}

 @lru_cache(CACHED_VS_NUM)
 def load_vector_store(
     knowledge_base_name: str,
-    embeddings: Embeddings,
-    tick: int,  # tick is changed by upload_doc etc., forcing a cache refresh
+    embed_model: str = EMBEDDING_MODEL,
+    embed_device: str = EMBEDDING_DEVICE,
+    embeddings: Embeddings = None,
+    tick: int = 0,  # tick is changed by upload_doc etc., forcing a cache refresh
 ):
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
+    if embeddings is None:
+        embeddings = load_embeddings(embed_model, embed_device)
     search_index = FAISS.load_local(vs_path, embeddings)
     return search_index
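refresh_vs_cache, called after writes further down in this file, is most plausibly just a tick bump: a new tick changes the lru_cache key, so the next load_vector_store call reloads the index from disk. A sketch under that assumption:

    def refresh_vs_cache(kb_name: str):
        # new tick -> new cache key -> the stale FAISS index is not reused
        _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1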
@@ -84,14 +96,15 @@ class FaissKBService(KBService):
                   embeddings: Embeddings,
                   ) -> List[Document]:
         search_index = load_vector_store(self.kb_name,
-                                         embeddings,
-                                         _VECTOR_STORE_TICKS.get(self.kb_name))
+                                         embeddings=embeddings,
+                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name))
         docs = search_index.similarity_search(query, k=top_k)
         return docs
     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             vector_store.add_documents(docs)
@@ -99,14 +112,15 @@ class FaissKBService(KBService):
         else:
             if not os.path.exists(self.vs_path):
                 os.makedirs(self.vs_path)
-            vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Documents
+            vector_store = FAISS.from_documents(
+                docs, embeddings)  # docs is a list of Documents
         torch_gc()
         vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)

     def do_delete_doc(self,
                       kb_file: KnowledgeFile):
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]


@@ -6,7 +6,7 @@ from langchain.vectorstores import Milvus
 from configs.model_config import EMBEDDING_DEVICE, kbs_config
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
@@ -33,11 +33,10 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS

-    def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.milvus = Milvus(embedding_function=_embeddings,
+    def _load_milvus(self, embeddings: Embeddings = None):
+        if embeddings is None:
+            embeddings = self._load_embeddings()
+        self.milvus = Milvus(embedding_function=embeddings,
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

     def do_init(self):

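Call sites no longer thread a device argument through; they either pass an embeddings object or let the base-class helper supply one. A hypothetical call (the constructor arguments are assumed, not shown in this diff):

    kb = MilvusKBService("samples")  # hypothetical: knowledge base name only
    kb._load_milvus()                # embeddings=None -> self._load_embeddings()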

@@ -169,7 +169,7 @@ class ApiRequest:
         for chunk in iter_over_async(response.body_iterator, loop):
             if as_json and chunk:
                 yield json.loads(chunk)
-            else:
+            elif chunk.strip():
                 yield chunk

     def _httpx_stream2generator(
@@ -184,7 +184,7 @@ class ApiRequest:
         for chunk in r.iter_text(None):
             if as_json and chunk:
                 yield json.loads(chunk)
-            else:
+            elif chunk.strip():
                 yield chunk
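The elif chunk.strip() guard is the fix named in the commit message: streamed responses can contain empty or whitespace-only chunks, and the old bare else yielded them to consumers that immediately index into a parsed dict, raising KeyError. A toy illustration of the filtering behaviour:

    for chunk in ['{"answer": "hi"}', "", "\n"]:
        if chunk.strip():          # blank chunks are dropped instead of yielded
            print(repr(chunk))     # only the JSON payload survives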
    # chat-related operations
@@ -250,6 +250,7 @@ class ApiRequest:
         query: str,
         knowledge_base_name: str,
         top_k: int = VECTOR_SEARCH_TOP_K,
+        history: List[Dict] = [],
         no_remote_api: bool = None,
     ):
         '''
@@ -260,12 +261,12 @@ class ApiRequest:
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(query, knowledge_base_name, top_k)
+            response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
-                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
+                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
                 stream=True,
             )
             return self._httpx_stream2generator(response, as_json=True)
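With the new history parameter the web UI can carry prior turns into knowledge-base chat. A hypothetical usage, assuming the stream yields dicts with an "answer" field and that ApiRequest's default base_url points at a running server:

    api = ApiRequest()  # assumed default base_url
    history = [{"role": "user", "content": "hello"},
               {"role": "assistant", "content": "Hi! How can I help?"}]
    for d in api.knowledge_base_chat("summarize our chat", "samples", history=history):
        print(d.get("answer", ""), end="")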