update knowledge base KBService and API:
1. make HuggingFaceEmbeddings hashable
2. unify the embeddings loading method for all KBService implementations
3. make ApiRequest skip empty content when streaming JSON, to avoid dict KeyError
parent ec76adc81d
commit b98f5fd0b9
@@ -11,8 +11,8 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
-from server.knowledge_base import (list_kbs, create_kb, delete_kb,
-                                   list_docs, upload_doc, delete_doc,
+from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
+from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
                                               update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse
@@ -91,7 +91,7 @@ async def download_doc():
 async def recreate_vector_store(
     knowledge_base_name: str,
     allow_empty_kb: bool = True,
-    vs_type: Union[str, SupportedVSType] = "faiss",
+    vs_type: str = "faiss",
 ):
     '''
     recreate vector store from the content.
@@ -3,7 +3,6 @@ from abc import ABC, abstractmethod
 import os
 from functools import lru_cache

-from langchain.embeddings import HuggingFaceEmbeddings
 from langchain.embeddings.base import Embeddings
 from langchain.docstore.document import Document
@@ -34,6 +33,9 @@ class KBService(ABC):
         self.doc_path = get_doc_path(self.kb_name)
         self.do_init()

+    def _load_embeddings(self, embed_device: str = EMBEDDING_DEVICE) -> Embeddings:
+        return load_embeddings(self.embed_model, embed_device)
+
     def create_kb(self):
         """
         Create the knowledge base.
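`load_embeddings` itself is defined outside these hunks (the FAISS-service hunk below moves its import to server.knowledge_base.utils). A plausible shape, consistent with the `lru_cache` import and the hashability fix in this commit, is a per-(model, device) cache; this is an assumed sketch, not code from the diff:

    from functools import lru_cache
    from langchain.embeddings.huggingface import HuggingFaceEmbeddings

    @lru_cache(1)
    def load_embeddings(model: str, device: str) -> HuggingFaceEmbeddings:
        # Cache the heavyweight model load; the (model, device) strings form
        # the cache key. Assumed implementation, not shown in this diff.
        return HuggingFaceEmbeddings(model_name=model,
                                     model_kwargs={"device": device})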
@@ -63,7 +65,7 @@ class KBService(ABC):
         Add a file to the knowledge base.
         """
         docs = kb_file.file2text()
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         self.do_add_doc(docs, embeddings)
         status = add_doc_to_db(kb_file)
         return status
@@ -88,8 +90,8 @@ class KBService(ABC):
     def search_docs(self,
                     query: str,
                     top_k: int = VECTOR_SEARCH_TOP_K,
-                    embedding_device: str = EMBEDDING_DEVICE, ):
-        embeddings = load_embeddings(self.embed_model, embedding_device)
+                    ):
+        embeddings = self._load_embeddings()
         docs = self.do_search(query, top_k, embeddings)
         return docs
@@ -142,7 +144,8 @@ class KBService(ABC):
     @abstractmethod
     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
         """
         Add documents to the knowledge base; subclasses implement their own logic.
         """
@@ -1,28 +1,40 @@
 import os
 import shutil

-from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_DEVICE
-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from configs.model_config import KB_ROOT_PATH, CACHED_VS_NUM, EMBEDDING_MODEL, EMBEDDING_DEVICE
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from functools import lru_cache
-from server.knowledge_base.utils import get_vs_path, KnowledgeFile
+from server.knowledge_base.utils import get_vs_path, load_embeddings, KnowledgeFile
 from langchain.vectorstores import FAISS
 from langchain.embeddings.base import Embeddings
+from langchain.embeddings.huggingface import HuggingFaceEmbeddings
 from typing import List
 from langchain.docstore.document import Document
 from server.utils import torch_gc
 import numpy as np


+# make HuggingFaceEmbeddings hashable
+def _embeddings_hash(self):
+    return hash(self.model_name)
+
+HuggingFaceEmbeddings.__hash__ = _embeddings_hash
+
+
 _VECTOR_STORE_TICKS = {}


 @lru_cache(CACHED_VS_NUM)
 def load_vector_store(
         knowledge_base_name: str,
-        embeddings: Embeddings,
-        tick: int,  # tick will be changed by upload_doc etc. and make cache refreshed.
+        embed_model: str = EMBEDDING_MODEL,
+        embed_device: str = EMBEDDING_DEVICE,
+        embeddings: Embeddings = None,
+        tick: int = 0,  # tick will be changed by upload_doc etc. and make cache refreshed.
 ):
     print(f"loading vector store in '{knowledge_base_name}'.")
     vs_path = get_vs_path(knowledge_base_name)
+    if embeddings is None:
+        embeddings = load_embeddings(embed_model, embed_device)
     search_index = FAISS.load_local(vs_path, embeddings)
     return search_index
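The `__hash__` monkey-patch above exists because `load_vector_store` is wrapped in `@lru_cache`, and `lru_cache` builds its cache key by hashing every argument. `HuggingFaceEmbeddings` is a pydantic model, which defines equality without a hash and is therefore unhashable by default, so passing an instance through the cached function would raise `TypeError`. A standalone illustration of the failure mode (not project code):

    from functools import lru_cache

    class FakeEmbeddings:
        __hash__ = None  # mimics a pydantic model: __eq__ defined, no __hash__

    @lru_cache(maxsize=1)
    def load_store(name: str, embeddings):
        return (name, embeddings)

    try:
        load_store("kb", FakeEmbeddings())   # lru_cache must hash each argument
    except TypeError as e:
        print(e)                             # unhashable type: 'FakeEmbeddings'

    FakeEmbeddings.__hash__ = lambda self: 0  # the diff's patch hashes model_name instead
    print(load_store("kb", FakeEmbeddings())) # now cacheable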
@@ -84,14 +96,15 @@ class FaissKBService(KBService):
                   embeddings: Embeddings,
                   ) -> List[Document]:
         search_index = load_vector_store(self.kb_name,
-                                         embeddings,
-                                         _VECTOR_STORE_TICKS.get(self.kb_name))
+                                         embeddings=embeddings,
+                                         tick=_VECTOR_STORE_TICKS.get(self.kb_name))
         docs = search_index.similarity_search(query, k=top_k)
         return docs

     def do_add_doc(self,
                    docs: List[Document],
-                   embeddings: Embeddings):
+                   embeddings: Embeddings,
+                   ):
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             vector_store.add_documents(docs)
@@ -99,14 +112,15 @@ class FaissKBService(KBService):
         else:
             if not os.path.exists(self.vs_path):
                 os.makedirs(self.vs_path)
-            vector_store = FAISS.from_documents(docs, embeddings)  # docs is a list of Documents
+            vector_store = FAISS.from_documents(
+                docs, embeddings)  # docs is a list of Documents
         torch_gc()
         vector_store.save_local(self.vs_path)
         refresh_vs_cache(self.kb_name)

     def do_delete_doc(self,
                       kb_file: KnowledgeFile):
-        embeddings = load_embeddings(self.embed_model, EMBEDDING_DEVICE)
+        embeddings = self._load_embeddings()
         if os.path.exists(self.vs_path) and "index.faiss" in os.listdir(self.vs_path):
             vector_store = FAISS.load_local(self.vs_path, embeddings)
             ids = [k for k, v in vector_store.docstore._dict.items() if v.metadata["source"] == kb_file.filepath]
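`refresh_vs_cache` is not shown in this diff. An implementation consistent with the tick comment on `load_vector_store` would bump a per-KB counter so the next lookup misses the lru_cache and reloads from disk (assumed shape):

    def refresh_vs_cache(kb_name: str):
        # Changing the tick changes load_vector_store's argument tuple, so the
        # cached FAISS index for the old tick can never be returned again.
        _VECTOR_STORE_TICKS[kb_name] = _VECTOR_STORE_TICKS.get(kb_name, 0) + 1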
@@ -6,7 +6,7 @@ from langchain.vectorstores import Milvus

 from configs.model_config import EMBEDDING_DEVICE, kbs_config

-from server.knowledge_base.kb_service.base import KBService, SupportedVSType, load_embeddings
+from server.knowledge_base.kb_service.base import KBService, SupportedVSType
 from server.knowledge_base.utils import KnowledgeFile
@@ -33,11 +33,10 @@ class MilvusKBService(KBService):
     def vs_type(self) -> str:
         return SupportedVSType.MILVUS

-    def _load_milvus(self, embedding_device: str = EMBEDDING_DEVICE, embeddings: Embeddings = None):
-        _embeddings = embeddings
-        if _embeddings is None:
-            _embeddings = load_embeddings(self.embed_model, embedding_device)
-        self.milvus = Milvus(embedding_function=_embeddings,
+    def _load_milvus(self, embeddings: Embeddings = None):
+        if embeddings is None:
+            embeddings = self._load_embeddings()
+        self.milvus = Milvus(embedding_function=embeddings,
                              collection_name=self.kb_name, connection_args=kbs_config.get("milvus"))

     def do_init(self):
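The reworked `_load_milvus` keeps optional embedding injection while deferring the default to the shared base-class helper; both call styles below are supported (illustrative sketch only, construction and `my_embeddings` assumed):

    kb = MilvusKBService("my_kb")              # hypothetical construction
    kb._load_milvus()                          # default: embeddings via self._load_embeddings()
    kb._load_milvus(embeddings=my_embeddings)  # injected: caller-supplied object used as-is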
@@ -169,7 +169,7 @@ class ApiRequest:
         for chunk in iter_over_async(response.body_iterator, loop):
             if as_json and chunk:
                 yield json.loads(chunk)
-            else:
+            elif chunk.strip():
                 yield chunk

     def _httpx_stream2generator(
@@ -184,7 +184,7 @@ class ApiRequest:
             for chunk in r.iter_text(None):
                 if as_json and chunk:
                     yield json.loads(chunk)
-                else:
+                elif chunk.strip():
                     yield chunk

     # chat-related operations
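This `else:` → `elif chunk.strip():` change (applied in both `_fastapi_stream2generator` and `_httpx_stream2generator`) is the KeyError fix from the commit message: a stream can deliver empty chunks, and the old `else:` branch yielded them as bare strings to consumers expecting dicts. A standalone sketch of the guard (hypothetical consumer code):

    import json

    def stream2generator(chunks, as_json=False):
        for chunk in chunks:
            if as_json and chunk:
                yield json.loads(chunk)
            elif chunk.strip():          # old code: bare `else:` leaked "" through
                yield chunk

    # With the guard, the empty chunk is dropped instead of reaching the
    # consumer, which would otherwise index into "" and fail.
    for item in stream2generator(['{"answer": "hi"}', ""], as_json=True):
        print(item["answer"])            # prints: hi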
@@ -250,6 +250,7 @@ class ApiRequest:
         query: str,
         knowledge_base_name: str,
         top_k: int = VECTOR_SEARCH_TOP_K,
+        history: List[Dict] = [],
         no_remote_api: bool = None,
     ):
         '''
@@ -260,12 +261,12 @@ class ApiRequest:

         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(query, knowledge_base_name, top_k)
+            response = knowledge_base_chat(query, knowledge_base_name, top_k, history)
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
                 "/chat/knowledge_base_chat",
-                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k},
+                json={"query": query, "knowledge_base_name": knowledge_base_name, "top_k": top_k, "history": history},
                 stream=True,
             )
             return self._httpx_stream2generator(response, as_json=True)
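With `history` threaded through both the local and the HTTP code paths, callers can carry prior turns into the knowledge-base chat. A hedged usage sketch: the message keys are an assumption based on common chat payloads, and `ApiRequest()` is assumed to construct with usable defaults:

    api = ApiRequest()
    history = [
        {"role": "user", "content": "What does this knowledge base cover?"},
        {"role": "assistant", "content": "Product documentation."},
    ]
    for d in api.knowledge_base_chat("Summarize the install steps.",
                                     knowledge_base_name="samples",
                                     top_k=3,
                                     history=history):
        print(d.get("answer", ""), end="")  # each streamed dict carries a partial answer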