文件对话和知识库对话 API 接口实现全异步操作,防止阻塞 (#2256)

* EmbeddingFunAdapter 支持异步操作;文件对话和知识库对话 API 接口实现全异步操作,防止阻塞

* 修复: 使 list_files_from_folder 返回相对路径
This commit is contained in:
liunux4odoo 2023-12-02 19:22:44 +08:00 committed by GitHub
parent dcb76984bc
commit 7d2de47bcf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 51 additions and 19 deletions

View File

@ -40,13 +40,12 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
callbacks = [callback] callbacks = [callback]
memory = None memory = None
if conversation_id: # 负责保存llm response到message db
message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id) message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
# 负责保存llm response到message db conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id, chat_type="llm_chat",
chat_type="llm_chat", query=query)
query=query) callbacks.append(conversation_callback)
callbacks.append(conversation_callback)
if isinstance(max_tokens, int) and max_tokens <= 0: if isinstance(max_tokens, int) and max_tokens <= 0:
max_tokens = None max_tokens = None

View File

@ -125,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
callbacks=[callback], callbacks=[callback],
) )
embed_func = EmbeddingsFunAdapter() embed_func = EmbeddingsFunAdapter()
embeddings = embed_func.embed_query(query) embeddings = await embed_func.aembed_query(query)
with memo_faiss_pool.acquire(knowledge_id) as vs: with memo_faiss_pool.acquire(knowledge_id) as vs:
docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold) docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
docs = [x[0] for x in docs] docs = [x[0] for x in docs]

View File

@ -1,5 +1,6 @@
from fastapi import Body, Request from fastapi import Body, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from fastapi.concurrency import run_in_threadpool
from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE) from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
from server.utils import wrap_done, get_ChatOpenAI from server.utils import wrap_done, get_ChatOpenAI
from server.utils import BaseResponse, get_prompt_template from server.utils import BaseResponse, get_prompt_template
@ -10,9 +11,7 @@ import asyncio
from langchain.prompts.chat import ChatPromptTemplate from langchain.prompts.chat import ChatPromptTemplate
from server.chat.utils import History from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBServiceFactory from server.knowledge_base.kb_service.base import KBServiceFactory
from server.knowledge_base.utils import get_doc_path
import json import json
from pathlib import Path
from urllib.parse import urlencode from urllib.parse import urlencode
from server.knowledge_base.kb_doc_api import search_docs from server.knowledge_base.kb_doc_api import search_docs
@ -72,7 +71,11 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
max_tokens=max_tokens, max_tokens=max_tokens,
callbacks=[callback], callbacks=[callback],
) )
docs = search_docs(query, knowledge_base_name, top_k, score_threshold) docs = await run_in_threadpool(search_docs,
query=query,
knowledge_base_name=knowledge_base_name,
top_k=top_k,
score_threshold=score_threshold)
context = "\n".join([doc.page_content for doc in docs]) context = "\n".join([doc.page_content for doc in docs])
if len(docs) == 0: # 如果没有找到相关文档使用empty模板 if len(docs) == 0: # 如果没有找到相关文档使用empty模板
prompt_template = get_prompt_template("knowledge_base_chat", "empty") prompt_template = get_prompt_template("knowledge_base_chat", "empty")

View File

@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
from server.model_workers.base import ApiEmbeddingsParams from server.model_workers.base import ApiEmbeddingsParams
from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
from fastapi import Body from fastapi import Body
from fastapi.concurrency import run_in_threadpool
from typing import Dict, List from typing import Dict, List
@ -39,6 +40,32 @@ def embed_texts(
logger.error(e) logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}") return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
async def aembed_texts(
texts: List[str],
embed_model: str = EMBEDDING_MODEL,
to_query: bool = False,
) -> BaseResponse:
'''
对文本进行向量化返回数据格式BaseResponse(data=List[List[float]])
'''
try:
if embed_model in list_embed_models(): # 使用本地Embeddings模型
from server.utils import load_local_embeddings
embeddings = load_local_embeddings(model=embed_model)
return BaseResponse(data=await embeddings.aembed_documents(texts))
if embed_model in list_online_embed_models(): # 使用在线API
return await run_in_threadpool(embed_texts,
texts=texts,
embed_model=embed_model,
to_query=to_query)
except Exception as e:
logger.error(e)
return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
def embed_texts_endpoint( def embed_texts_endpoint(
texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]), texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"), embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型除了本地部署的Embedding模型也支持在线API({online_embed_models})提供的嵌入服务。"),

View File

@ -26,8 +26,7 @@ from server.knowledge_base.utils import (
from typing import List, Union, Dict, Optional from typing import List, Union, Dict, Optional
from server.embeddings_api import embed_texts from server.embeddings_api import embed_texts, aembed_texts, embed_documents
from server.embeddings_api import embed_documents
from server.knowledge_base.model.kb_document_model import DocumentWithVSId from server.knowledge_base.model.kb_document_model import DocumentWithVSId
@ -405,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
normalized_query_embed = normalize(query_embed_2d) normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回 return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
# TODO: 暂不支持异步 async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]: embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
# return normalize(await self.embeddings.aembed_documents(texts)) return normalize(embeddings).tolist()
# async def aembed_query(self, text: str) -> List[float]: async def aembed_query(self, text: str) -> List[float]:
# return normalize(await self.embeddings.aembed_query(text)) embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
query_embed = embeddings[0]
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
normalized_query_embed = normalize(query_embed_2d)
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
def score_threshold_process(score_threshold, k, docs): def score_threshold_process(score_threshold, k, docs):

View File

@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
for target_entry in target_it: for target_entry in target_it:
process_entry(target_entry) process_entry(target_entry)
elif entry.is_file(): elif entry.is_file():
result.append(entry.path) result.append(os.path.relpath(entry.path, doc_path))
elif entry.is_dir(): elif entry.is_dir():
with os.scandir(entry.path) as it: with os.scandir(entry.path) as it:
for sub_entry in it: for sub_entry in it: