文件对话和知识库对话 API 接口实现全异步操作,防止阻塞 (#2256)

* EmbeddingFunAdapter 支持异步操作;文件对话和知识库对话 API 接口实现全异步操作,防止阻塞

* 修复: 使 list_files_from_folder 返回相对路径
This commit is contained in:
liunux4odoo 2023-12-02 19:22:44 +08:00 committed by GitHub
parent dcb76984bc
commit 7d2de47bcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 19 deletions

View File

@ -40,9 +40,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks = [callback]
memory = None
if conversation_id:
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
# 负责保存llm response到message db
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
chat_type="llm_chat",
query=query)

View File

@ -125,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
callbacks=[callback],
)
embed_func = EmbeddingsFunAdapter()
embeddings = embed_func.embed_query(query)
embeddings = await embed_func.aembed_query(query)
with memo_faiss_pool.acquire(knowledge_id) as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
docs = [x[0] for x in docs]

View File

@ -1,5 +1,6 @@
from fastapi import Body, Request
from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template
@ -10,9 +11,7 @@ import asyncio
from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_doc_path
import json
from pathlib import Path
from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs
@ -72,7 +71,11 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
max_tokens=max_tokens,
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty")

View File

@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List
@ -39,6 +40,32 @@ def embed_texts(
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
    texts: List[str],
    embed_model: str = EMBEDDING_MODEL,
    to_query: bool = False,
) -> BaseResponse:
    '''
    Embed a list of texts asynchronously.

    Args:
        texts: the texts to vectorize.
        embed_model: name of the embedding model; either a locally deployed
            model or an online API model.
        to_query: whether the texts are queries (passed through to online
            APIs that distinguish query vs. document embeddings).

    Returns:
        BaseResponse(data=List[List[float]]) on success, or a
        BaseResponse with code=500 and an error message on failure.
    '''
    try:
        if embed_model in list_embed_models():  # local Embeddings model: use its native async API
            from server.utils import load_local_embeddings

            embeddings = load_local_embeddings(model=embed_model)
            return BaseResponse(data=await embeddings.aembed_documents(texts))

        if embed_model in list_online_embed_models():  # online API: sync call, run in threadpool to avoid blocking the event loop
            return await run_in_threadpool(embed_texts,
                                           texts=texts,
                                           embed_model=embed_model,
                                           to_query=to_query)

        # Fix: previously fell through and implicitly returned None for an
        # unknown model; return an explicit error response instead.
        return BaseResponse(code=500, msg=f"指定的模型 {embed_model} 不支持 Embeddings 功能。")
    except Exception as e:
        logger.error(e)
        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"),

View File

@ -26,8 +26,7 @@ from server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional
from server.embeddings_api import embed_texts
from server.embeddings_api import embed_documents
from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@ -405,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
# TODO: 暂不支持异步
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# return normalize(await self.embeddings.aembed_documents(texts))
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
    """Asynchronously embed *texts* and return L2-normalized vectors as nested lists."""
    response = await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)
    # NOTE(review): assumes the call succeeded and `response.data` holds the
    # embedding matrix — an error BaseResponse would leave data unusable here.
    return normalize(response.data).tolist()
# async def aembed_query(self, text: str) -> List[float]:
# return normalize(await self.embeddings.aembed_query(text))
async def aembed_query(self, text: str) -> List[float]:
    """Asynchronously embed a single query string and return its L2-normalized vector."""
    response = await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)
    vector = response.data[0]
    # normalize() expects a 2-D input, so lift the vector to shape (1, dim),
    # normalize, then flatten back to a plain 1-D Python list.
    matrix = np.reshape(vector, (1, -1))
    return normalize(matrix)[0].tolist()
def score_threshold_process(score_threshold, k, docs):

View File

@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
for target_entry in target_it:
process_entry(target_entry)
elif entry.is_file():
result.append(entry.path)
result.append(os.path.relpath(entry.path, doc_path))
elif entry.is_dir():
with os.scandir(entry.path) as it:
for sub_entry in it: