From 775870a51696f9550550b192aa1813764238f45f Mon Sep 17 00:00:00 2001 From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com> Date: Fri, 8 Sep 2023 12:25:02 +0800 Subject: [PATCH] =?UTF-8?q?=E6=94=B9=E5=8F=98api=E8=A7=86=E5=9B=BE?= =?UTF-8?q?=E5=87=BD=E6=95=B0=E7=9A=84sync/async=EF=BC=8C=E6=8F=90?= =?UTF-8?q?=E9=AB=98api=E5=B9=B6=E5=8F=91=E8=83=BD=E5=8A=9B=EF=BC=9A=20(#1?= =?UTF-8?q?414)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 4个chat类接口改为async 2. 知识库操作,涉及向量库修改的使用async,避免FAISS写入错误;涉及向量库读取的改为sync,提高并发 --- server/chat/chat.py | 4 ++-- server/chat/knowledge_base_chat.py | 18 +++++++-------- server/chat/openai_chat.py | 9 ++++---- server/chat/search_engine_chat.py | 34 +++++++++++++++-------------- server/knowledge_base/kb_api.py | 2 +- server/knowledge_base/kb_doc_api.py | 10 ++++----- webui_pages/utils.py | 16 ++++++++------ 7 files changed, 48 insertions(+), 45 deletions(-) diff --git a/server/chat/chat.py b/server/chat/chat.py index ba23a5a..b4fd6bb 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -12,8 +12,8 @@ from typing import List from server.chat.utils import History -def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), - history: List[History] = Body([], +async def chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]), + history: List[History] = Body([], description="历史对话", examples=[[ {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 69ec25d..9aa0d5d 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -18,11 +18,11 @@ from urllib.parse import urlencode from server.knowledge_base.kb_doc_api import search_docs -def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), - top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), - score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), - history: List[History] = Body([], +async def knowledge_base_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + knowledge_base_name: str = Body(..., description="知识库名称", examples=["samples"]), + top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + score_threshold: float = Body(SCORE_THRESHOLD, description="知识库匹配相关度阈值,取值范围在0-1之间,SCORE越小,相关度越高,取到1相当于不筛选,建议设置在0.5左右", ge=0, le=1), + history: List[History] = Body([], description="历史对话", examples=[[ {"role": "user", @@ -30,10 +30,10 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp {"role": "assistant", "content": "虎头虎脑"}]] ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), - local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), - request: Request = None, + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), + local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"), + request: Request = None, ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: diff --git a/server/chat/openai_chat.py b/server/chat/openai_chat.py index a799c62..2414090 100644 --- a/server/chat/openai_chat.py +++ b/server/chat/openai_chat.py @@ -29,13 +29,13 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(f"{openai.api_base=}") print(msg) - def get_response(msg): + async def get_response(msg): data = msg.dict() try: - response = openai.ChatCompletion.create(**data) + response = await openai.ChatCompletion.acreate(**data) if msg.stream: - for data in response: + async for data in response: if choices := data.choices: if chunk := choices[0].get("delta", {}).get("content"): print(chunk, end="", flush=True) @@ -46,8 +46,7 @@ async def openai_chat(msg: OpenAiChatMsgIn): print(answer) yield(answer) except Exception as e: - print(type(e)) - logger.error(e) + logger.error(f"获取ChatCompletion时出错:{e}") return StreamingResponse( get_response(msg), diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8fe7dae..a77106f 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -2,6 +2,7 @@ from langchain.utilities import BingSearchAPIWrapper, DuckDuckGoSearchAPIWrapper from configs.model_config import BING_SEARCH_URL, BING_SUBSCRIPTION_KEY from fastapi import Body from fastapi.responses import StreamingResponse +from fastapi.concurrency import run_in_threadpool from configs.model_config import (llm_model_dict, LLM_MODEL, SEARCH_ENGINE_TOP_K, PROMPT_TEMPLATE) from server.chat.utils import wrap_done from server.utils import BaseResponse @@ -47,29 +48,30 @@ def search_result2docs(search_results): return docs -def lookup_search_engine( +async def lookup_search_engine( query: str, search_engine_name: str, top_k: int = SEARCH_ENGINE_TOP_K, ): - results = SEARCH_ENGINES[search_engine_name](query, result_len=top_k) + search_engine = SEARCH_ENGINES[search_engine_name] + results = await run_in_threadpool(search_engine, query, result_len=top_k) docs = search_result2docs(results) return docs -def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), - search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), - top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), - history: List[History] = Body([], - description="历史对话", - examples=[[ - {"role": "user", - "content": "我们来玩成语接龙,我先来,生龙活虎"}, - {"role": "assistant", - "content": "虎头虎脑"}]] - ), - stream: bool = Body(False, description="流式输出"), - model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), +async def search_engine_chat(query: str = Body(..., description="用户输入", examples=["你好"]), + search_engine_name: str = Body(..., description="搜索引擎名称", examples=["duckduckgo"]), + top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"), + history: List[History] = Body([], + description="历史对话", + examples=[[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}]] + ), + stream: bool = Body(False, description="流式输出"), + model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"), ): if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") @@ -96,7 +98,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl openai_proxy=llm_model_dict[model_name].get("openai_proxy") ) - docs = lookup_search_engine(query, search_engine_name, top_k) + docs = await lookup_search_engine(query, search_engine_name, top_k) context = "\n".join([doc.page_content for doc in docs]) input_msg = History(role="user", content=PROMPT_TEMPLATE).to_msg_template(False) diff --git a/server/knowledge_base/kb_api.py b/server/knowledge_base/kb_api.py index 4b135fe..0648a39 100644 --- a/server/knowledge_base/kb_api.py +++ b/server/knowledge_base/kb_api.py @@ -7,7 +7,7 @@ from configs.model_config import EMBEDDING_MODEL, logger from fastapi import Body -async def list_kbs(): +def list_kbs(): # Get List of Knowledge Base return ListResponse(data=list_kbs_from_db()) diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 24b1c8e..6539217 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -34,7 +34,7 @@ def search_docs(query: str = Body(..., description="用户输入", examples=[" return data -async def list_files( +def list_files( knowledge_base_name: str ) -> ListResponse: if not validate_kb_name(knowledge_base_name): @@ -258,10 +258,10 @@ async def update_docs( return BaseResponse(code=200, msg=f"更新文档完成", data={"failed_files": failed_files}) -async def download_doc( - knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]), - file_name: str = Query(...,description="文件名称", examples=["test.txt"]), - preview: bool = Query(False, description="是:浏览器内预览;否:下载"), +def download_doc( + knowledge_base_name: str = Query(...,description="知识库名称", examples=["samples"]), + file_name: str = Query(...,description="文件名称", examples=["test.txt"]), + preview: bool = Query(False, description="是:浏览器内预览;否:下载"), ): ''' 下载知识库文档 diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 27843f8..0292c2a 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -48,6 +48,8 @@ class ApiRequest: self.base_url = base_url self.timeout = timeout self.no_remote_api = no_remote_api + if no_remote_api: + logger.warn("将来可能取消对no_remote_api的支持,更新版本时请注意。") def _parse_url(self, url: str) -> str: if (not url.startswith("http") @@ -270,7 +272,7 @@ class ApiRequest: if no_remote_api: from server.chat.openai_chat import openai_chat - response = openai_chat(msg) + response = run_async(openai_chat(msg)) return self._fastapi_stream2generator(response) else: data = msg.dict(exclude_unset=True, exclude_none=True) @@ -280,7 +282,7 @@ class ApiRequest: response = self.post( "/chat/fastchat", json=data, - stream=stream, + stream=True, ) return self._httpx_stream2generator(response) @@ -310,7 +312,7 @@ class ApiRequest: if no_remote_api: from server.chat.chat import chat - response = chat(**data) + response = run_async(chat(**data)) return self._fastapi_stream2generator(response) else: response = self.post("/chat/chat", json=data, stream=True) @@ -349,7 +351,7 @@ class ApiRequest: if no_remote_api: from server.chat.knowledge_base_chat import knowledge_base_chat - response = knowledge_base_chat(**data) + response = run_async(knowledge_base_chat(**data)) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -387,7 +389,7 @@ class ApiRequest: if no_remote_api: from server.chat.search_engine_chat import search_engine_chat - response = search_engine_chat(**data) + response = run_async(search_engine_chat(**data)) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( @@ -427,7 +429,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import list_kbs - response = run_async(list_kbs()) + response = list_kbs() return response.data else: response = self.get("/knowledge_base/list_knowledge_bases") @@ -499,7 +501,7 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_doc_api import list_files - response = run_async(list_files(knowledge_base_name)) + response = list_files(knowledge_base_name) return response.data else: response = self.get(