diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py index 8ae38db..536c5cd 100644 --- a/server/chat/agent_chat.py +++ b/server/chat/agent_chat.py @@ -44,6 +44,9 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = CustomAsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/chat.py b/server/chat/chat.py index acf3ec0..47ec871 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -45,7 +45,8 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼 chat_type="llm_chat", query=query) callbacks.append(conversation_callback) - + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None model = get_ChatOpenAI( model_name=model_name, diff --git a/server/chat/completion.py b/server/chat/completion.py index ee5e2d1..beda026 100644 --- a/server/chat/completion.py +++ b/server/chat/completion.py @@ -28,6 +28,9 @@ async def completion(query: str = Body(..., description="用户输入", examples echo: bool = echo, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_OpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py index ea3475a..a4db917 100644 --- a/server/chat/file_chat.py +++ b/server/chat/file_chat.py @@ -114,6 +114,9 @@ async def file_chat(query: str = Body(..., description="用户输入", examples= async def knowledge_base_chat_iterator() -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 0ea99a6..a99a045 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -49,6 +49,9 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入", prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature, diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 8325b4d..98b26c6 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -148,6 +148,9 @@ async def search_engine_chat(query: str = Body(..., description="用户输入", prompt_name: str = prompt_name, ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() + if isinstance(max_tokens, int) and max_tokens <= 0: + max_tokens = None + model = get_ChatOpenAI( model_name=model_name, temperature=temperature,