update knowledge base kbservice and api:
1. make huggingfaceembeddings hashable 2. unify embeddings' loading method for all kbservie 3. make ApiRequest skip empty content when streaming json to avoid dict KeyError
This commit is contained in:
parent
ec76adc81d
commit
b98f5fd0b9
|
|
@ -11,8 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from starlette.responses import RedirectResponse
|
||||
from server.chat import (chat, knowledge_base_chat, openai_chat,
|
||||
search_engine_chat)
|
||||
from server.knowledge_base import (list_kbs, create_kb, delete_kb,
|
||||
list_docs, upload_doc, delete_doc,
|
||||
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
|
||||
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
|
||||
update_doc, recreate_vector_store)
|
||||
from server.utils import BaseResponse, ListResponse
|
||||
|
||||
|
|
|
|||
|
|
@ -91,7 +91,7 @@ async def download_doc():
|
|||
async def recreate_vector_store(
|
||||
knowledge_base_name: str,
|
||||
allow_empty_kb: bool = True,
|
||||
vs_type: Union[str, SupportedVSType] = "faiss",
|
||||
vs_type: str = "faiss",
|
||||
):
|
||||
'''
|
||||
recreate vector store from the content.
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
|
|||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from langchain.embeddings import HuggingFaceEmbeddings
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.docstore.document import Document
|
||||
|
||||
|
|
@ -34,6 +33,9 @@ class KBService(ABC):
|
|||
self.doc_path = get_doc_path(self.kb_name)
|
||||
self.do_init()
|
||||
|
||||
def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
|
||||
return load_embeddings(self.embed_model, embed_device)
|
||||
|
||||
def create_kb(self):
|
||||
"""
|
||||
创建知识库
|
||||
|
|
@ -63,7 +65,7 @@ class KBService(ABC):
|
|||
向知识库添加文件
|
||||
"""
|
||||
docs = kb_file.file2text()
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
embeddings = self._load_embeddings()
|
||||
self.do_add_doc(docs, embeddings)
|
||||
status = add_doc_to_db(kb_file)
|
||||
return status
|
||||
|
|
@ -88,8 +90,8 @@ class KBService(ABC):
|
|||
def search_docs(self,
|
||||
query: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
embedding_device: str = EMBEDDING_DEVICE, ):
|
||||
embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||
):
|
||||
embeddings = self._load_embeddings()
|
||||
docs = self.do_search(query, top_k, embeddings)
|
||||
return docs
|
||||
|
||||
|
|
@ -142,7 +144,8 @@ class KBService(ABC):
|
|||
@abstractmethod
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings):
|
||||
embeddings: Embeddings,
|
||||
):
|
||||
"""
|
||||
向知识库添加文档子类实自己逻辑
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,28 +1,40 @@
|
|||
import os
|
||||
import shutil
|
||||
|
||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||
from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from functools import lru_cache
|
||||
from server.knowledge_base.utils import get_vs_path, KnowledgeFile
|
||||
from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
|
||||
from langchain.vectorstores import FAISS
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
||||
from typing import List
|
||||
from langchain.docstore.document import Document
|
||||
from server.utils import torch_gc
|
||||
import numpy as np
|
||||
|
||||
|
||||
# make HuggingFaceEmbeddings hashable
|
||||
def _embeddings_hash(self):
|
||||
return hash(self.model_name)
|
||||
HuggingFaceEmbeddings.__hash__ = _embeddings_hash
|
||||
|
||||
|
||||
_VECTOR_STORE_TICKS = {}
|
||||
|
||||
|
||||
@lru_cache(CACHED_VS_NUM)
|
||||
def load_vector_store(
|
||||
knowledge_base_name: str,
|
||||
embeddings: Embeddings,
|
||||
tick: int, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
embed_model: str = EMBEDDING_MODEL,
|
||||
embed_device: str = EMBEDDING_DEVICE,
|
||||
embeddings: Embeddings = None,
|
||||
tick: int = 0, # tick will be changed by upload_doc etc. and make cache refreshed.
|
||||
):
|
||||
print(f"loading vector store in '{knowledge_base_name}'.")
|
||||
vs_path = get_vs_path(knowledge_base_name)
|
||||
if embeddings is None:
|
||||
embeddings = load_embeddings(embed_model, embed_device)
|
||||
search_index = FAISS.load_local(vs_path, embeddings)
|
||||
return search_index
|
||||
|
||||
|
|
@ -84,14 +96,15 @@ class FaissKBService(KBService):
|
|||
embeddings: Embeddings,
|
||||
) -> List[Document]:
|
||||
search_index = load_vector_store(self.kb_name,
|
||||
embeddings,
|
||||
_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
embeddings=embeddings,
|
||||
tick=_VECTOR_STORE_TICKS.get(self.kb_name))
|
||||
docs = search_index.similarity_search(query, k=top_k)
|
||||
return docs
|
||||
|
||||
def do_add_doc(self,
|
||||
docs: List[Document],
|
||||
embeddings: Embeddings):
|
||||
embeddings: Embeddings,
|
||||
):
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
vector_store.add_documents(docs)
|
||||
|
|
@ -99,14 +112,15 @@ class FaissKBService(KBService):
|
|||
else:
|
||||
if not os.path.exists(self.vs_path):
|
||||
os.makedirs(self.vs_path)
|
||||
vector_store = FAISS.from_documents(docs, embeddings) # docs 为Document列表
|
||||
vector_store = FAISS.from_documents(
|
||||
docs, embeddings) # docs 为Document列表
|
||||
torch_gc()
|
||||
vector_store.save_local(self.vs_path)
|
||||
refresh_vs_cache(self.kb_name)
|
||||
|
||||
def do_delete_doc(self,
|
||||
kb_file: KnowledgeFile):
|
||||
embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
|
||||
embeddings = self._load_embeddings()
|
||||
if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
|
||||
vector_store = FAISS.load_local(self.vs_path, embeddings)
|
||||
ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from langchain.vectorstores import Milvus
|
|||
|
||||
from configs.model_config import EMBEDDING_DEVICE, kbs_config
|
||||
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
|
||||
from server.knowledge_base.kb_service.base import KBService, SupportedVSType
|
||||
from server.knowledge_base.utils import KnowledgeFile
|
||||
|
||||
|
||||
|
|
@ -33,11 +33,10 @@ class MilvusKBService(KBService):
|
|||
def vs_type(self) -> str:
|
||||
return SupportedVSType.MILVUS
|
||||
|
||||
def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
|
||||
_embeddings = embeddings
|
||||
if _embeddings is None:
|
||||
_embeddings = load_embeddings(self.embed_model, embedding_device)
|
||||
self.milvus = Milvus(embedding_function=_embeddings,
|
||||
def _load_milvus(self, embeddings: Embeddings = None):
|
||||
if embeddings is None:
|
||||
embeddings = self._load_embeddings()
|
||||
self.milvus = Milvus(embedding_function=embeddings,
|
||||
collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))
|
||||
|
||||
def do_init(self):
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ class ApiRequest:
|
|||
for chunk in iter_over_async(response.body_iterator, loop):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
else:
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
|
||||
def _httpx_stream2generator(
|
||||
|
|
@ -184,7 +184,7 @@ class ApiRequest:
|
|||
for chunk in r.iter_text(None):
|
||||
if as_json and chunk:
|
||||
yield json.loads(chunk)
|
||||
else:
|
||||
elif chunk.strip():
|
||||
yield chunk
|
||||
|
||||
# 对话相关操作
|
||||
|
|
@ -250,6 +250,7 @@ class ApiRequest:
|
|||
query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
history: List[Dict] = [],
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
|
|
@ -260,12 +261,12 @@ class ApiRequest:
|
|||
|
||||
if no_remote_api:
|
||||
from server.chat.knowledge_base_chat import knowledge_base_chat
|
||||
response = knowledge_base_chat(query, knowledge_base_name, top_k)
|
||||
response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
|
||||
return self._fastapi_stream2generator(response, as_json=True)
|
||||
else:
|
||||
response = self.post(
|
||||
"/chat/knowledge_base_chat",
|
||||
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
|
||||
json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
|
||||
stream=True,
|
||||
)
|
||||
return self._httpx_stream2generator(response, as_json=True)
|
||||
|
|
|
|||
Loading…
Reference in New Issue