add parameter `max_tokens` to 4 chat APIs with default value 1024 (#1744)

liunux4odoo authored 2023-10-12 16:18:56 +08:00, committed by GitHub
parent 1ac173958d
commit cd748128c3
6 changed files with 18 additions and 2 deletions

View File

@@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
         prompt_name: str = Body("agent_chat",
                                 description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
@@ -41,6 +42,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
     model = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
     )
     prompt_template = CustomPromptTemplate(

View File

@@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
         # top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
         prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         ):
@@ -36,6 +37,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
     model = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
         callbacks=[callback],
     )
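
For reference, a minimal sketch of what a client request to the plain chat endpoint looks like with the new field. This is not part of the commit; the host/port assume the API server's default address (127.0.0.1:7861) and the route is assumed to be mounted at /chat/chat — adjust both to your deployment.

    import requests

    # Hedged usage sketch: pass the new max_tokens body field to cap how many
    # tokens the model may generate for this reply.
    resp = requests.post(
        "http://127.0.0.1:7861/chat/chat",  # host, port and route are assumptions
        json={
            "query": "Hello",
            "stream": False,
            "temperature": 0.7,
            "max_tokens": 256,              # new field added by this commit
        },
    )
    print(resp.text)

Omitting max_tokens keeps the new default of 1024.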

View File

@@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
         prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
         request: Request = None,
@@ -51,6 +52,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
     model = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
         callbacks=[callback],
     )
     docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
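
The knowledge-base endpoint takes the same field alongside its retrieval parameters. A hedged sketch, assuming the route is /chat/knowledge_base_chat and that knowledge_base_name and top_k are body fields (they feed the search_docs call shown above):

    import requests

    # Hedged usage sketch: retrieval parameters plus the new generation cap.
    resp = requests.post(
        "http://127.0.0.1:7861/chat/knowledge_base_chat",  # route path assumed
        json={
            "query": "What does this project do?",
            "knowledge_base_name": "samples",  # hypothetical knowledge base name
            "top_k": 3,
            "stream": False,
            "max_tokens": 512,                 # new field added by this commit
        },
    )
    print(resp.text)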

View File

@@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
         stream: bool = Body(False, description="流式输出"),
         model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
         temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+        max_tokens: int = Body(1024, description="限制LLM生成Token数量当前默认为1024"), # TODO: fastchat更新后默认值设为None自动使用LLM支持的最大值。
         prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
         ):
     if search_engine_name not in SEARCH_ENGINES.keys():
@@ -93,6 +94,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
     model = get_ChatOpenAI(
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
         callbacks=[callback],
     )

View File

@@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
 def get_ChatOpenAI(
         model_name: str,
         temperature: float,
+        max_tokens: int = None,
         streaming: bool = True,
         callbacks: List[Callable] = [],
         verbose: bool = True,
@@ -48,6 +49,7 @@ def get_ChatOpenAI(
         openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
         model_name=model_name,
         temperature=temperature,
+        max_tokens=max_tokens,
         openai_proxy=config.get("openai_proxy"),
         **kwargs
     )
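
Server-side, the cap is simply forwarded to langchain's ChatOpenAI. A minimal sketch of a caller, assuming the helper lives in server/utils.py and that the model name below is registered in the project's configs:

    from server.utils import get_ChatOpenAI  # module path assumed

    # Hedged illustration: max_tokens now flows through to ChatOpenAI; passing
    # None leaves the limit to the serving backend (the TODO above targets this).
    model = get_ChatOpenAI(
        model_name="chatglm2-6b",             # hypothetical model name
        temperature=0.7,
        max_tokens=512,
    )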
@@ -144,7 +146,7 @@ def run_async(cor):
     return loop.run_until_complete(cor)
 
-def iter_over_async(ait, loop):
+def iter_over_async(ait, loop=None):
     '''
     将异步生成器封装成同步生成器.
     '''
@@ -157,6 +159,12 @@ def iter_over_async(ait, loop):
         except StopAsyncIteration:
             return True, None
 
+    if loop is None:
+        try:
+            loop = asyncio.get_event_loop()
+        except:
+            loop = asyncio.new_event_loop()
+
     while True:
         done, obj = loop.run_until_complete(get_next())
         if done:
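
The loop argument of iter_over_async is now optional, so synchronous code can consume an async generator without wiring up an event loop first. A small sketch with a made-up generator (module path assumed):

    import asyncio
    from server.utils import iter_over_async  # module path assumed

    async def numbers():                      # stand-in async generator
        for i in range(3):
            await asyncio.sleep(0)
            yield i

    for n in iter_over_async(numbers()):      # no explicit loop needed anymore
        print(n)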

View File

@@ -166,7 +166,7 @@ def dialogue_page(api: ApiRequest):
         ans = ""
         support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"]  # 目前支持agent的模型
         if not any(agent in llm_model for agent in support_agent):
             ans += "正在思考... \n\n <span style='color:red'>模型并没有进行Agent对齐无法正常使用Agent功能</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验 </span> \n\n\n"
             chat_box.update_msg(ans, element_index=0, streaming=False)