From cd748128c351ff8f6e728f76d0a23e4ce5b13e73 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Thu, 12 Oct 2023 16:18:56 +0800
Subject: [PATCH] add parameter `max_tokens` to 4 chat APIs with default value
 1024 (#1744)

---
 server/chat/agent_chat.py          |  2 ++
 server/chat/chat.py                |  2 ++
 server/chat/knowledge_base_chat.py |  2 ++
 server/chat/search_engine_chat.py  |  2 ++
 server/utils.py                    | 10 +++++++++-
 webui_pages/dialogue/dialogue.py   |  2 +-
 6 files changed, 18 insertions(+), 2 deletions(-)

diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index df0ce8b..1f51802 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
                      stream: bool = Body(False, description="流式输出"),
                      model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                      temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                     max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
                      prompt_name: str = Body("agent_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                      # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
@@ -41,6 +42,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
+            max_tokens=max_tokens,
         )

         prompt_template = CustomPromptTemplate(
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 6d3c9ce..2f37f99 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
                stream: bool = Body(False, description="流式输出"),
                model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+               max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
                # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
                prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                ):
@@ -36,6 +37,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
+            max_tokens=max_tokens,
             callbacks=[callback],
         )
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 9c70ee5..87e7098 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
                               stream: bool = Body(False, description="流式输出"),
                               model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                               temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                              max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
                               prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                               local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
                               request: Request = None,
@@ -51,6 +52,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
+            max_tokens=max_tokens,
             callbacks=[callback],
         )
         docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 00708b7..930c9bd 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              stream: bool = Body(False, description="流式输出"),
                              model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
                              temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+                             max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
                              prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
                              ):
     if search_engine_name not in SEARCH_ENGINES.keys():
@@ -93,6 +94,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
+            max_tokens=max_tokens,
             callbacks=[callback],
         )
diff --git a/server/utils.py b/server/utils.py
index 6030748..bae9fe4 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 def get_ChatOpenAI(
         model_name: str,
         temperature: float,
+        max_tokens: int = None,
         streaming: bool = True,
         callbacks: List[Callable] = [],
         verbose: bool = True,
@@ -48,6 +49,7 @@ def get_ChatOpenAI(
         openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
         openai_proxy=config.get("openai_proxy"),
         **kwargs
     )
@@ -144,7 +146,7 @@ def run_async(cor):
         return loop.run_until_complete(cor)


-def iter_over_async(ait, loop):
+def iter_over_async(ait, loop=None):
     '''
     将异步生成器封装成同步生成器.
     '''
@@ -157,6 +159,12 @@ def iter_over_async(ait, loop):
         except StopAsyncIteration:
             return True, None

+    if loop is None:
+        try:
+            loop = asyncio.get_event_loop()
+        except:
+            loop = asyncio.new_event_loop()
+
     while True:
         done, obj = loop.run_until_complete(get_next())
         if done:
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 7af2dc5..2664b30 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -166,7 +166,7 @@ def dialogue_page(api: ApiRequest):
             ans = ""
             support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
             if not any(agent in llm_model for agent in support_agent):
-                ans += "正在思考... \n\n 改模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
+                ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
                 chat_box.update_msg(ans, element_index=0, streaming=False)
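
A minimal, hypothetical client-side sketch of how the new `max_tokens` field can be supplied once this patch is applied: it rides along in the JSON body next to the existing `query`, `stream`, and `temperature` fields. The server address and the `/chat/chat` route below are illustrative assumptions and are not defined by the diff above.

# sketch_max_tokens_client.py -- illustration only, not part of the patch
import requests

resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",  # assumed API server address and route
    json={
        "query": "hello",
        "stream": False,
        "temperature": 0.7,
        "max_tokens": 512,  # new optional field; falls back to 1024 when omitted
    },
)
print(resp.status_code, resp.text)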