add paramter `max_tokens` to 4 chat api with default value 1024 (#1744)
This commit is contained in:
parent
1ac173958d
commit
cd748128c3
|
|
@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
||||
prompt_name: str = Body("agent_chat",
|
||||
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
|
|
@ -41,6 +42,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
|||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
|
||||
prompt_template = CustomPromptTemplate(
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
|
|
@ -36,6 +37,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
|
|||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
||||
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
|
||||
request: Request = None,
|
||||
|
|
@ -51,6 +52,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
|
|||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
|||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
|
||||
max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
|
||||
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
|
||||
):
|
||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||
|
|
@ -93,6 +94,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
|
|||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
callbacks=[callback],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
|
|||
def get_ChatOpenAI(
|
||||
model_name: str,
|
||||
temperature: float,
|
||||
max_tokens: int = None,
|
||||
streaming: bool = True,
|
||||
callbacks: List[Callable] = [],
|
||||
verbose: bool = True,
|
||||
|
|
@ -48,6 +49,7 @@ def get_ChatOpenAI(
|
|||
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
openai_proxy=config.get("openai_proxy"),
|
||||
**kwargs
|
||||
)
|
||||
|
|
@ -144,7 +146,7 @@ def run_async(cor):
|
|||
return loop.run_until_complete(cor)
|
||||
|
||||
|
||||
def iter_over_async(ait, loop):
|
||||
def iter_over_async(ait, loop=None):
|
||||
'''
|
||||
将异步生成器封装成同步生成器.
|
||||
'''
|
||||
|
|
@ -157,6 +159,12 @@ def iter_over_async(ait, loop):
|
|||
except StopAsyncIteration:
|
||||
return True, None
|
||||
|
||||
if loop is None:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
except:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
while True:
|
||||
done, obj = loop.run_until_complete(get_next())
|
||||
if done:
|
||||
|
|
|
|||
|
|
@ -166,7 +166,7 @@ def dialogue_page(api: ApiRequest):
|
|||
ans = ""
|
||||
support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
|
||||
if not any(agent in llm_model for agent in support_agent):
|
||||
ans += "正在思考... \n\n <span style='color:red'>改模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
|
||||
ans += "正在思考... \n\n <span style='color:red'>该模型并没有进行Agent对齐,无法正常使用Agent功能!</span>\n\n\n<span style='color:red'>请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! </span> \n\n\n"
|
||||
chat_box.update_msg(ans, element_index=0, streaming=False)
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue