add parameter `max_tokens` to 4 chat APIs with default value 1024 (#1744)

liunux4odoo 2023-10-12 16:18:56 +08:00 committed by GitHub
parent 1ac173958d
commit cd748128c3
6 changed files with 18 additions and 2 deletions
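
The new field rides along with the existing request body parameters of each endpoint. A minimal client sketch, assuming the API server listens on its default local address and that /chat/chat is the route of the plain chat endpoint (both assumptions); the JSON keys mirror the Body(...) parameters in the diffs below:

import requests

# Hypothetical call against the plain chat endpoint; host, port and route are
# assumptions, the JSON keys follow the Body(...) parameters shown below.
resp = requests.post(
    "http://127.0.0.1:7861/chat/chat",   # assumed default API address
    json={
        "query": "Hello",
        "stream": False,
        "model_name": "chatglm2-6b",     # hypothetical model name
        "temperature": 0.7,
        "max_tokens": 512,               # the parameter added by this commit
    },
)
print(resp.text)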


@@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="User input", examples
stream: bool = Body(False, description="Stream the output"),
model_name: str = Body(LLM_MODEL, description="LLM model name."),
temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="Limit on the number of tokens the LLM may generate; currently defaults to 1024"), # TODO: once fastchat is updated, set the default to None so the model's maximum is used automatically.
prompt_name: str = Body("agent_chat",
description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
# top_p: float = Body(TOP_P, description="LLM nucleus sampling. Do not set together with temperature", gt=0.0, lt=1.0),
@@ -41,6 +42,7 @@ async def agent_chat(query: str = Body(..., description="User input", examples
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
)
prompt_template = CustomPromptTemplate(


@@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="User input", examples=["恼
stream: bool = Body(False, description="Stream the output"),
model_name: str = Body(LLM_MODEL, description="LLM model name."),
temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="Limit on the number of tokens the LLM may generate; currently defaults to 1024"), # TODO: once fastchat is updated, set the default to None so the model's maximum is used automatically.
# top_p: float = Body(TOP_P, description="LLM nucleus sampling. Do not set together with temperature", gt=0.0, lt=1.0),
prompt_name: str = Body("llm_chat", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
):
@@ -36,6 +37,7 @@ async def chat(query: str = Body(..., description="User input", examples=["恼
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)


@@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="User input",
stream: bool = Body(False, description="Stream the output"),
model_name: str = Body(LLM_MODEL, description="LLM model name."),
temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="Limit on the number of tokens the LLM may generate; currently defaults to 1024"), # TODO: once fastchat is updated, set the default to None so the model's maximum is used automatically.
prompt_name: str = Body("knowledge_base_chat", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
local_doc_url: bool = Body(False, description="Return knowledge files as local paths (true) or URLs (false)"),
request: Request = None,
@@ -51,6 +52,7 @@ async def knowledge_base_chat(query: str = Body(..., description="User input",
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)


@@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="User input",
stream: bool = Body(False, description="Stream the output"),
model_name: str = Body(LLM_MODEL, description="LLM model name."),
temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
max_tokens: int = Body(1024, description="Limit on the number of tokens the LLM may generate; currently defaults to 1024"), # TODO: once fastchat is updated, set the default to None so the model's maximum is used automatically.
prompt_name: str = Body("knowledge_base_chat", description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
@@ -93,6 +94,7 @@ async def search_engine_chat(query: str = Body(..., description="User input",
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
callbacks=[callback],
)


@@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
def get_ChatOpenAI(
model_name: str,
temperature: float,
max_tokens: int = None,
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
@@ -48,6 +49,7 @@ def get_ChatOpenAI(
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
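
Callers that omit max_tokens keep the previous behaviour (None is passed through to LangChain's ChatOpenAI), while the four chat endpoints now forward their 1024 default. A minimal sketch of a direct call, assuming the helper lives in the project's server utils module and using a hypothetical model name:

from server.utils import get_ChatOpenAI   # assumed import path

# max_tokens is handed straight to ChatOpenAI, which includes it in the
# OpenAI-compatible request sent to the fastchat server.
model = get_ChatOpenAI(
    model_name="chatglm2-6b",   # hypothetical model name
    temperature=0.7,
    max_tokens=1024,            # pass None to leave the limit to the backend
)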
@@ -144,7 +146,7 @@ def run_async(cor):
return loop.run_until_complete(cor)
def iter_over_async(ait, loop):
def iter_over_async(ait, loop=None):
'''
Wrap an async generator into a synchronous generator.
'''
@@ -157,6 +159,12 @@ def iter_over_async(ait, loop):
except StopAsyncIteration:
return True, None
if loop is None:
try:
loop = asyncio.get_event_loop()
except:
loop = asyncio.new_event_loop()
while True:
done, obj = loop.run_until_complete(get_next())
if done:
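
The second hunk makes the event loop argument optional: when loop is None, iter_over_async now obtains or creates a loop itself before draining the async generator. A self-contained sketch of the same pattern (an illustrative re-statement, not the project's code verbatim):

import asyncio

def iter_over_async(ait, loop=None):
    """Drive an async generator from synchronous code, creating a loop on demand."""
    ait = ait.__aiter__()

    async def get_next():
        try:
            return False, await ait.__anext__()
        except StopAsyncIteration:
            return True, None

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj

async def numbers():
    for i in range(3):
        yield i

for n in iter_over_async(numbers()):   # no loop argument needed any more
    print(n)                           # prints 0, 1, 2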


@@ -166,7 +166,7 @@ def dialogue_page(api: ApiRequest):
ans = ""
support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # models that currently support Agent
if not any(agent in llm_model for agent in support_agent):
ans += "Thinking... \n\n <span style='color:red'>This model has not been aligned for Agent use, so Agent features will not work properly</span>\n\n\n<span style='color:red'>Please switch to a model that supports Agent, such as GPT4 or Qwen-14B, for a better experience</span> \n\n\n"
ans += "Thinking... \n\n <span style='color:red'>This model has not been aligned for Agent use, so Agent features will not work properly</span>\n\n\n<span style='color:red'>Please switch to a model that supports Agent, such as GPT4 or Qwen-14B, for a better experience</span> \n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)