Fix: resetting max_tokens in the *_chat endpoints raised `local variable 'max_tokens' referenced before assignment`

liunux4odoo 2023-11-28 21:00:00 +08:00
parent 8b695dba03
commit 7d580d9a47
6 changed files with 6 additions and 1 deletion
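
All six hunks apply the same one-line fix: each endpoint resets max_tokens inside a nested async generator, and in Python an assignment to a name anywhere in a function makes that name local to the whole function, so the earlier read of max_tokens fails unless the name is declared nonlocal. A minimal runnable sketch of the pattern (simplified, with hypothetical names; not the project's actual handler code):

import asyncio
from typing import AsyncIterable, Optional

async def chat(query: str, max_tokens: Optional[int] = -1):
    async def chat_iterator() -> AsyncIterable[str]:
        # Without this declaration, the assignment below makes max_tokens
        # local to chat_iterator, so the isinstance() check reads it before
        # any value is bound and raises UnboundLocalError
        # ("local variable 'max_tokens' referenced before assignment").
        nonlocal max_tokens
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None  # treat non-positive values as "no limit"
        yield f"{query} (max_tokens={max_tokens})"

    async for chunk in chat_iterator():
        print(chunk)

asyncio.run(chat("hello"))  # prints: hello (max_tokens=None)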


@@ -43,6 +43,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
             model_name: str = LLM_MODELS[0],
             prompt_name: str = prompt_name,
     ) -> AsyncIterable[str]:
+        nonlocal max_tokens
         callback = CustomAsyncIteratorCallbackHandler()
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None


@@ -34,7 +34,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
              prompt_name: str = Body("default", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
              ):
     async def chat_iterator() -> AsyncIterable[str]:
-        nonlocal history
+        nonlocal history, max_tokens
         callback = AsyncIteratorCallbackHandler()
         callbacks = [callback]
         memory = None


@@ -27,6 +27,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
             prompt_name: str = prompt_name,
             echo: bool = echo,
     ) -> AsyncIterable[str]:
+        nonlocal max_tokens
         callback = AsyncIteratorCallbackHandler()
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None


@@ -113,6 +113,7 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
     history = [History.from_data(h) for h in history]
 
     async def knowledge_base_chat_iterator() -> AsyncIterable[str]:
+        nonlocal max_tokens
         callback = AsyncIteratorCallbackHandler()
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None


@@ -48,6 +48,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
             model_name: str = LLM_MODELS[0],
             prompt_name: str = prompt_name,
     ) -> AsyncIterable[str]:
+        nonlocal max_tokens
         callback = AsyncIteratorCallbackHandler()
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None


@@ -147,6 +147,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
             model_name: str = LLM_MODELS[0],
             prompt_name: str = prompt_name,
     ) -> AsyncIterable[str]:
+        nonlocal max_tokens
         callback = AsyncIteratorCallbackHandler()
         if isinstance(max_tokens, int) and max_tokens <= 0:
             max_tokens = None