diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 536c5cd..51d4d1f 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -43,6 +43,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = CustomAsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/chat.py b/server/chat/chat.py index 47ec871..0885283 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -34,7 +34,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"), ): async def chat_iterator() -> AsyncIterable[str]: - nonlocal history + nonlocal history, max_tokens callback = AsyncIteratorCallbackHandler() callbacks = [callback] memory = None diff --git a/server/chat/completion.py b/server/chat/completion.py index beda026..31ae3ff 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -27,6 +27,7 @@ async def completion(query: str = Body(..., description="用户输入", examples prompt_name: str = prompt_name, echo: bool = echo, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index a4db917..ef4a522 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -113,6 +113,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= history = [History.from_data(h) for h in history] async def knowledge_base_chat_iterator() -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index a99a045..b82e1c0 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -48,6 +48,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 98b26c6..5b77e99 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -147,6 +147,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", model_name: str = LLM_MODELS[0], prompt_name: str = prompt_name, ) -> AsyncIterable[str]: + nonlocal max_tokens callback = AsyncIteratorCallbackHandler() if isinstance(max_tokens, int) and max_tokens <= 0: max_tokens = None