From 7d580d9a47c9455f420e87e9fbdd3188a712582e Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Tue, 28 Nov 2023 21:00:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=EF=BC=9A*=5Fchat=20=E6=8E=A5?= =?UTF-8?q?=E5=8F=A3=E4=B8=AD=E9=87=8D=E8=AE=BE=20max=5Ftokens=20=E6=97=B6?= =?UTF-8?q?=E5=AF=BC=E8=87=B4=20`local=20variable=20'max=5Ftokens'=20refer?= =?UTF-8?q?enced=20before=20assignment`=20=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/chat/agent_chat.py | 1 + server/chat/chat.py | 2 +- server/chat/completion.py | 1 + server/chat/file_chat.py | 1 + server/chat/knowledge_base_chat.py | 1 + server/chat/search_engine_chat.py | 1 + 6 files changed, 6 insertions(+), 1 deletion(-) diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 536c5cd..51d4d1f 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -43,6 +43,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = CustomAsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/chat.py b/server/chat/chat.py index 47ec871..0885283 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -34,7 +34,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): async def chat_iterator() -> AsyncIterable[str]: - nonlocal history + nonlocal history, max_tokens callback = AsyncIteratorCallbackHandler() callbacks = [callback] memory = None diff --git a/server/chat/completion.py b/server/chat/completion.py index beda026..31ae3ff 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -27,6 +27,7 @@ async def completion(query: str = Body(..., description="用户输入", examples prompt_name: str = prompt_name, echo: bool = echo, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index a4db917..ef4a522 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -113,6 +113,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= history = [History.from_data(h) for h in history] async def knowledge_base_chat_iterator() -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index a99a045..b82e1c0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -48,6 +48,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 98b26c6..5b77e99 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -147,6 +147,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None