File chat and knowledge base chat API endpoints are now fully asynchronous to avoid blocking (#2256)
* EmbeddingsFunAdapter supports async operation; the file chat and knowledge base chat API endpoints now run fully asynchronously to avoid blocking
* Fix: make list_files_from_folder return relative paths
This commit is contained in:
parent dcb76984bc
commit 7d2de47bcf
@@ -40,9 +40,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         callbacks = [callback]
         memory = None
 
-        if conversation_id:
-            message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
         # 负责保存llm response到message db
+        message_id = add_message_to_db(chat_type="llm_chat", query=query, conversation_id=conversation_id)
         conversation_callback = ConversationCallbackHandler(conversation_id=conversation_id, message_id=message_id,
                                                             chat_type="llm_chat",
                                                             query=query)
@@ -125,7 +125,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
             callbacks=[callback],
         )
         embed_func = EmbeddingsFunAdapter()
-        embeddings = embed_func.embed_query(query)
+        embeddings = await embed_func.aembed_query(query)
         with memo_faiss_pool.acquire(knowledge_id) as vs:
             docs = vs.similarity_search_with_score_by_vector(embeddings, k=top_k, score_threshold=score_threshold)
             docs = [x[0] for x in docs]
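For context on this one-line change: the synchronous embed_query runs the whole embedding call on the event loop, stalling every other request, while the awaited aembed_query yields control back to the loop. A minimal usage sketch (hypothetical, not part of the commit; it assumes EmbeddingsFunAdapter is importable from server.knowledge_base.kb_service.base and that no-arg construction works as shown above):

```python
import asyncio
from server.knowledge_base.kb_service.base import EmbeddingsFunAdapter  # assumed import path

async def embed_one_query(query: str) -> list:
    embed_func = EmbeddingsFunAdapter()
    # embed_func.embed_query(query) would block the event loop for the whole
    # embedding call; awaiting the async variant keeps the loop responsive.
    return await embed_func.aembed_query(query)

# asyncio.run(embed_one_query("hello"))
```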
@@ -1,5 +1,6 @@
 from fastapi import Body, Request
 from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool
 from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE)
 from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
@@ -10,9 +11,7 @@ import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBServiceFactory
-from server.knowledge_base.utils import get_doc_path
 import json
-from pathlib import Path
 from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
 
@@ -72,7 +71,11 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             max_tokens=max_tokens,
             callbacks=[callback],
         )
-        docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
+        docs = await run_in_threadpool(search_docs,
+                                       query=query,
+                                       knowledge_base_name=knowledge_base_name,
+                                       top_k=top_k,
+                                       score_threshold=score_threshold)
         context = "\n".join([doc.page_content for doc in docs])
         if len(docs) == 0: # 如果没有找到相关文档,使用empty模板
             prompt_template = get_prompt_template("knowledge_base_chat", "empty")
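The pattern above, wrapping the synchronous search_docs in fastapi.concurrency.run_in_threadpool, can be shown in isolation. The sketch below uses a made-up blocking function as a stand-in for search_docs; it is illustrative only:

```python
import time
from fastapi.concurrency import run_in_threadpool

def slow_search(query: str, top_k: int = 3) -> list:
    # Stand-in for a blocking call such as a vector-store or database lookup.
    time.sleep(1)
    return [f"doc-{i}: {query}" for i in range(top_k)]

async def handler(query: str) -> list:
    # Same result as slow_search(query, top_k=3), but the work runs on
    # Starlette's threadpool, so the event loop keeps serving other requests.
    return await run_in_threadpool(slow_search, query=query, top_k=3)
```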
@@ -3,6 +3,7 @@ from configs import EMBEDDING_MODEL, logger
 from server.model_workers.base import ApiEmbeddingsParams
 from server.utils import BaseResponse, get_model_worker_config, list_embed_models, list_online_embed_models
 from fastapi import Body
+from fastapi.concurrency import run_in_threadpool
 from typing import Dict, List
 
 
@@ -39,6 +40,32 @@ def embed_texts(
         logger.error(e)
         return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
 
+
+async def aembed_texts(
+    texts: List[str],
+    embed_model: str = EMBEDDING_MODEL,
+    to_query: bool = False,
+) -> BaseResponse:
+    '''
+    对文本进行向量化。返回数据格式:BaseResponse(data=List[List[float]])
+    '''
+    try:
+        if embed_model in list_embed_models(): # 使用本地Embeddings模型
+            from server.utils import load_local_embeddings
+
+            embeddings = load_local_embeddings(model=embed_model)
+            return BaseResponse(data=await embeddings.aembed_documents(texts))
+
+        if embed_model in list_online_embed_models(): # 使用在线API
+            return await run_in_threadpool(embed_texts,
+                                           texts=texts,
+                                           embed_model=embed_model,
+                                           to_query=to_query)
+    except Exception as e:
+        logger.error(e)
+        return BaseResponse(code=500, msg=f"文本向量化过程中出现错误:{e}")
+
+
 def embed_texts_endpoint(
     texts: List[str] = Body(..., description="要嵌入的文本列表", examples=[["hello", "world"]]),
     embed_model: str = Body(EMBEDDING_MODEL, description=f"使用的嵌入模型,除了本地部署的Embedding模型,也支持在线API({online_embed_models})提供的嵌入服务。"),
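A hypothetical caller of the new aembed_texts coroutine could look like the sketch below. It assumes the default embedding model is configured and relies only on the BaseResponse fields visible in this diff (code, msg, data):

```python
from server.embeddings_api import aembed_texts

async def embed_two_sentences() -> list:
    resp = await aembed_texts(texts=["hello", "world"], to_query=False)
    if resp.code == 500:
        raise RuntimeError(resp.msg)
    return resp.data  # List[List[float]], one vector per input text
```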
@@ -26,8 +26,7 @@ from server.knowledge_base.utils import (
 
 from typing import List, Union, Dict, Optional
 
-from server.embeddings_api import embed_texts
-from server.embeddings_api import embed_documents
+from server.embeddings_api import embed_texts, aembed_texts, embed_documents
 from server.knowledge_base.model.kb_document_model import DocumentWithVSId
 
 
|
|
@ -405,12 +404,16 @@ class EmbeddingsFunAdapter(Embeddings):
|
||||||
normalized_query_embed = normalize(query_embed_2d)
|
normalized_query_embed = normalize(query_embed_2d)
|
||||||
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||||
|
|
||||||
# TODO: 暂不支持异步
|
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||||
# async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
|
embeddings = (await aembed_texts(texts=texts, embed_model=self.embed_model, to_query=False)).data
|
||||||
# return normalize(await self.embeddings.aembed_documents(texts))
|
return normalize(embeddings).tolist()
|
||||||
|
|
||||||
# async def aembed_query(self, text: str) -> List[float]:
|
async def aembed_query(self, text: str) -> List[float]:
|
||||||
# return normalize(await self.embeddings.aembed_query(text))
|
embeddings = (await aembed_texts(texts=[text], embed_model=self.embed_model, to_query=True)).data
|
||||||
|
query_embed = embeddings[0]
|
||||||
|
query_embed_2d = np.reshape(query_embed, (1, -1)) # 将一维数组转换为二维数组
|
||||||
|
normalized_query_embed = normalize(query_embed_2d)
|
||||||
|
return normalized_query_embed[0].tolist() # 将结果转换为一维数组并返回
|
||||||
|
|
||||||
|
|
||||||
def score_threshold_process(score_threshold, k, docs):
|
def score_threshold_process(score_threshold, k, docs):
|
||||||
|
|
|
||||||
|
|
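The reshape-then-normalize steps in the new aembed_query mirror the existing synchronous embed_query shown at the top of this hunk. A small worked example of just that step, assuming normalize here is sklearn.preprocessing.normalize (consistent with the np.reshape / normalize(...).tolist() usage in this class):

```python
import numpy as np
from sklearn.preprocessing import normalize  # assumed source of `normalize`

query_embed = [3.0, 4.0]                           # toy 2-dimensional embedding
query_embed_2d = np.reshape(query_embed, (1, -1))  # shape (1, 2): normalize expects 2-D input
normalized = normalize(query_embed_2d)             # L2-normalizes each row
print(normalized[0].tolist())                      # [0.6, 0.8] -> back to a flat list
```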
@@ -71,7 +71,7 @@ def list_files_from_folder(kb_name: str):
                 for target_entry in target_it:
                     process_entry(target_entry)
         elif entry.is_file():
-            result.append(entry.path)
+            result.append(os.path.relpath(entry.path, doc_path))
         elif entry.is_dir():
             with os.scandir(entry.path) as it:
                 for sub_entry in it:
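This is the second fix from the commit message: entries are now recorded relative to the knowledge base's content directory rather than as absolute paths. A standalone illustration of the behavior (the paths are made up):

```python
import os

doc_path = "/srv/knowledge_base/samples/content"            # hypothetical content dir
entry_path = os.path.join(doc_path, "subdir", "readme.md")  # what os.scandir would yield

print(os.path.relpath(entry_path, doc_path))  # "subdir/readme.md" on POSIX systems
# Before this change the absolute entry.path was appended; the commit message
# flags returning relative paths as the fix for callers of this helper.
```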