Change the sync/async of API view functions to improve API concurrency (#1414)
1. The four chat-type endpoints are changed to async. 2. For knowledge-base operations, endpoints that modify the vector store use async to avoid FAISS write errors, while endpoints that only read the vector store are changed to sync to improve concurrency.
parent 1195eb75eb · commit 775870a516
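The pattern applied throughout this commit: in FastAPI, an async def endpoint runs on the single event-loop thread, so code between two awaits never interleaves with another request, which keeps FAISS index writes from racing each other; a plain def endpoint is dispatched to a worker threadpool, so many requests can execute it in parallel, which suits read-only vector searches. A minimal sketch of the two styles (hypothetical endpoints, not code from this repo):

    from fastapi import FastAPI

    app = FastAPI()

    @app.post("/kb/write")       # hypothetical endpoint for illustration
    async def write_doc():
        # Runs on the event loop: as long as the index mutation here
        # contains no await, no other request can interleave with it.
        return {"msg": "written"}

    @app.get("/kb/search")       # hypothetical endpoint for illustration
    def search_docs():
        # Runs in a threadpool worker: concurrent with other requests,
        # fine for read-only lookups.
        return {"docs": []}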
server/chat/chat.py
@@ -12,8 +12,8 @@ from typing import List
 from server.chat.utils import History
 
 
-def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
-         history: List[History] = Body([],
+async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
+               history: List[History] = Body([],
                                        description="历史对话",
                                        examples=[[
                                            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
server/chat/knowledge_base_chat.py
@@ -18,11 +18,11 @@ from urllib.parse import urlencode
 from server.knowledge_base.kb_doc_api import search_docs
 
 
-def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
-                        knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
-                        top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
-                        score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
-                        history: List[History] = Body([],
+async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
+                              knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]),
+                              top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                              score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1),
+                              history: List[History] = Body([],
                                        description="历史对话",
                                        examples=[[
                                            {"role": "user",
@@ -30,10 +30,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                            {"role": "assistant",
                                             "content": "虎头虎脑"}]]
                                        ),
-                        stream: bool = Body(False, description="流式输出"),
-                        model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
-                        local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
-                        request: Request = None,
+                              stream: bool = Body(False, description="流式输出"),
+                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+                              local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
+                              request: Request = None,
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
server/chat/openai_chat.py
@@ -29,13 +29,13 @@ async def openai_chat(msg: OpenAiChatMsgIn):
     print(f"{openai.api_base=}")
     print(msg)
 
-    def get_response(msg):
+    async def get_response(msg):
         data = msg.dict()
 
         try:
-            response = openai.ChatCompletion.create(**data)
+            response = await openai.ChatCompletion.acreate(**data)
             if msg.stream:
-                for data in response:
+                async for data in response:
                     if choices := data.choices:
                         if chunk := choices[0].get("delta", {}).get("content"):
                             print(chunk, end="", flush=True)
@@ -46,8 +46,7 @@ async def openai_chat(msg: OpenAiChatMsgIn):
             print(answer)
             yield(answer)
         except Exception as e:
-            print(type(e))
-            logger.error(e)
+            logger.error(f"获取ChatCompletion时出错:{e}")
 
     return StreamingResponse(
         get_response(msg),
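Note on the openai_chat change above: openai.ChatCompletion.create(**data) is a blocking call, and looping over its stream with a plain for would stall the event loop between chunks; acreate plus async for keeps the loop free while waiting on the API. A standalone sketch of the pattern against the openai-python 0.x API (model name and prompt are placeholders):

    import asyncio
    import openai

    async def stream_chat(prompt: str):
        response = await openai.ChatCompletion.acreate(
            model="gpt-3.5-turbo",          # placeholder model
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
        # each chunk arrives without blocking the event loop
        async for data in response:
            if chunk := data.choices[0].get("delta", {}).get("content"):
                print(chunk, end="", flush=True)

    # asyncio.run(stream_chat("你好"))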
server/chat/search_engine_chat.py
@@ -2,6 +2,7 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper
 from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY
 from fastapi import Body
 from fastapi.responses import StreamingResponse
+from fastapi.concurrency import run_in_threadpool
 from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE)
 from server.chat.utils import wrap_done
 from server.utils import BaseResponse
@@ -47,29 +48,30 @@ def search_result2docs(search_results):
     return docs
 
 
-def lookup_search_engine(
+async def lookup_search_engine(
         query: str,
         search_engine_name: str,
         top_k: int = SEARCH_ENGINE_TOP_K,
 ):
-    results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k)
+    search_engine = SEARCH_ENGINES[search_engine_name]
+    results = await run_in_threadpool(search_engine, query, result_len=top_k)
     docs = search_result2docs(results)
     return docs
 
 
-def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
-                       search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
-                       top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
-                       history: List[History] = Body([],
-                                       description="历史对话",
-                                       examples=[[
-                                           {"role": "user",
-                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
-                                           {"role": "assistant",
-                                            "content": "虎头虎脑"}]]
-                                       ),
-                       stream: bool = Body(False, description="流式输出"),
-                       model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
+async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]),
+                             search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]),
+                             top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                             history: List[History] = Body([],
+                                       description="历史对话",
+                                       examples=[[
+                                           {"role": "user",
+                                            "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                           {"role": "assistant",
+                                            "content": "虎头虎脑"}]]
+                                       ),
+                             stream: bool = Body(False, description="流式输出"),
+                             model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
 ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -96,7 +98,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         openai_proxy=llm_model_dict[model_name].get("openai_proxy")
     )
 
-    docs = lookup_search_engine(query, search_engine_name, top_k)
+    docs = await lookup_search_engine(query, search_engine_name, top_k)
     context = "\n".join([doc.page_content for doc in docs])
 
     input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False)
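The lookup_search_engine change above shows the standard bridge for blocking I/O inside async code: fastapi.concurrency.run_in_threadpool (re-exported from Starlette) awaits a plain callable on a worker thread so the event loop stays responsive, forwarding positional and keyword arguments to it. A small self-contained sketch with a stand-in blocking function:

    import asyncio
    import time
    from fastapi.concurrency import run_in_threadpool

    def blocking_search(query: str, result_len: int = 3):
        time.sleep(1)  # stands in for a blocking HTTP request
        return [f"result {i} for {query}" for i in range(result_len)]

    async def lookup(query: str):
        # arguments after the callable are passed through to it
        return await run_in_threadpool(blocking_search, query, result_len=3)

    # asyncio.run(lookup("你好"))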
server/knowledge_base/kb_api.py
@@ -7,7 +7,7 @@ from configs.model_config import EMBEDDING_MODEL, logger
 from fastapi import Body
 
 
-async def list_kbs():
+def list_kbs():
     # Get List of Knowledge Base
     return ListResponse(data=list_kbs_from_db())
 
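list_kbs can go back to a plain def because of how Starlette dispatches handlers: coroutine functions are awaited on the event loop, everything else is pushed to a threadpool. Roughly (a simplified sketch of the idea, not Starlette's actual source):

    import inspect
    from fastapi.concurrency import run_in_threadpool

    async def dispatch(handler, *args, **kwargs):
        if inspect.iscoroutinefunction(handler):
            return await handler(*args, **kwargs)              # event-loop thread
        return await run_in_threadpool(handler, *args, **kwargs)  # worker thread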
server/knowledge_base/kb_doc_api.py
@@ -34,7 +34,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=["
     return data
 
 
-async def list_files(
+def list_files(
         knowledge_base_name: str
 ) -> ListResponse:
     if not validate_kb_name(knowledge_base_name):
@@ -258,10 +258,10 @@ async def update_docs(
     return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files})
 
 
-async def download_doc(
-    knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
-    file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
-    preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
+def download_doc(
+    knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]),
+    file_name: str = Query(...,description="文件名称", examples=["test.txt"]),
+    preview: bool = Query(False, description="是:浏览器内预览;否:下载"),
 ):
     '''
     下载知识库文档
webui_pages/utils.py
@@ -48,6 +48,8 @@ class ApiRequest:
         self.base_url = base_url
         self.timeout = timeout
         self.no_remote_api = no_remote_api
+        if no_remote_api:
+            logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。")
 
     def _parse_url(self, url: str) -> str:
         if (not url.startswith("http")
@@ -270,7 +272,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.chat.openai_chat import openai_chat
-            response = openai_chat(msg)
+            response = run_async(openai_chat(msg))
             return self._fastapi_stream2generator(response)
         else:
             data = msg.dict(exclude_unset=True, exclude_none=True)
@@ -280,7 +282,7 @@ class ApiRequest:
             response = self.post(
                 "/chat/fastchat",
                 json=data,
-                stream=stream,
+                stream=True,
             )
             return self._httpx_stream2generator(response)
 
@@ -310,7 +312,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.chat.chat import chat
-            response = chat(**data)
+            response = run_async(chat(**data))
             return self._fastapi_stream2generator(response)
         else:
             response = self.post("/chat/chat", json=data, stream=True)
@@ -349,7 +351,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.chat.knowledge_base_chat import knowledge_base_chat
-            response = knowledge_base_chat(**data)
+            response = run_async(knowledge_base_chat(**data))
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
@@ -387,7 +389,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.chat.search_engine_chat import search_engine_chat
-            response = search_engine_chat(**data)
+            response = run_async(search_engine_chat(**data))
             return self._fastapi_stream2generator(response, as_json=True)
         else:
             response = self.post(
@@ -427,7 +429,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.knowledge_base.kb_api import list_kbs
-            response = run_async(list_kbs())
+            response = list_kbs()
             return response.data
         else:
             response = self.get("/knowledge_base/list_knowledge_bases")
@@ -499,7 +501,7 @@ class ApiRequest:
 
         if no_remote_api:
             from server.knowledge_base.kb_doc_api import list_files
-            response = run_async(list_files(knowledge_base_name))
+            response = list_files(knowledge_base_name)
             return response.data
         else:
             response = self.get(
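The run_async helper used above drives the now-async view functions from the synchronous web UI when no_remote_api is set; its definition is not part of this diff. A helper with the required behavior could look roughly like this (implementation assumed, not shown in the commit):

    import asyncio

    def run_async(cor):
        # Run a coroutine to completion from synchronous code,
        # reusing the current event loop when one exists.
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
        return loop.run_until_complete(cor)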