diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index df0ce8b..1f51802 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -25,6 +25,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
prompt_name: str = Body("agent_chat",
description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
@@ -41,6 +42,7 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
+ max_tokens=max_tokens,
)
prompt_template = CustomPromptTemplate(
diff --git a/server/chat/chat.py b/server/chat/chat.py
index 6d3c9ce..2f37f99 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -22,6 +22,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
prompt_name: str = Body("llm_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
@@ -36,6 +37,7 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
+ max_tokens=max_tokens,
callbacks=[callback],
)
diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py
index 9c70ee5..87e7098 100644
--- a/server/chat/knowledge_base_chat.py
+++ b/server/chat/knowledge_base_chat.py
@@ -31,6 +31,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
local_doc_url: bool = Body(False, description="知识文件返回本地路径(true)或URL(false)"),
request: Request = None,
@@ -51,6 +52,7 @@ async def knowledge_base_chat(query: str = Body(..., description="用户输入",
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
+ max_tokens=max_tokens,
callbacks=[callback],
)
docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 00708b7..930c9bd 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -72,6 +72,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
stream: bool = Body(False, description="流式输出"),
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", ge=0.0, le=1.0),
+ max_tokens: int = Body(1024, description="限制LLM生成Token数量,当前默认为1024"), # TODO: fastchat更新后默认值设为None,自动使用LLM支持的最大值。
prompt_name: str = Body("knowledge_base_chat", description="使用的prompt模板名称(在configs/prompt_config.py中配置)"),
):
if search_engine_name not in SEARCH_ENGINES.keys():
@@ -93,6 +94,7 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
model = get_ChatOpenAI(
model_name=model_name,
temperature=temperature,
+ max_tokens=max_tokens,
callbacks=[callback],
)
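Each of the four chat endpoints above now accepts an optional `max_tokens` field in its request body. A minimal client-side sketch, assuming the plain LLM chat route from server/chat/chat.py is mounted at /chat/chat and the API server listens on 127.0.0.1:7861 (host, port and model name are assumptions about the deployment, not part of this diff):

```python
import requests

# Hypothetical request against the /chat/chat endpoint; host, port and the
# model name are placeholders and depend on the local deployment.
payload = {
    "query": "Hello",
    "stream": False,
    "model_name": "chatglm2-6b",   # placeholder model name
    "temperature": 0.7,
    "max_tokens": 512,             # new field: cap the number of generated tokens
    "prompt_name": "llm_chat",
}

resp = requests.post("http://127.0.0.1:7861/chat/chat", json=payload)
print(resp.text)
```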
diff --git a/server/utils.py b/server/utils.py
index 6030748..bae9fe4 100644
--- a/server/utils.py
+++ b/server/utils.py
@@ -34,6 +34,7 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
def get_ChatOpenAI(
model_name: str,
temperature: float,
+ max_tokens: int = None,
streaming: bool = True,
callbacks: List[Callable] = [],
verbose: bool = True,
@@ -48,6 +49,7 @@ def get_ChatOpenAI(
openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
model_name=model_name,
temperature=temperature,
+ max_tokens=max_tokens,
openai_proxy=config.get("openai_proxy"),
**kwargs
)
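With the new keyword, `get_ChatOpenAI` forwards the limit straight to LangChain's `ChatOpenAI`; passing `None` (the helper's default) presumably leaves the limit to the serving backend, which is what the TODO comments above anticipate. A minimal usage sketch with a placeholder model name:

```python
from server.utils import get_ChatOpenAI

model = get_ChatOpenAI(
    model_name="chatglm2-6b",   # placeholder; any model served through fastchat
    temperature=0.7,
    max_tokens=1024,            # cap the completion length; None defers to the backend
)
# `model` is an ordinary LangChain ChatOpenAI instance and is used as before.
```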
@@ -144,7 +146,7 @@ def run_async(cor):
return loop.run_until_complete(cor)
-def iter_over_async(ait, loop):
+def iter_over_async(ait, loop=None):
'''
将异步生成器封装成同步生成器.
'''
@@ -157,6 +159,12 @@ def iter_over_async(ait, loop):
except StopAsyncIteration:
return True, None
+    if loop is None:
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:  # no usable event loop in this thread
+            loop = asyncio.new_event_loop()
+
while True:
done, obj = loop.run_until_complete(get_next())
if done:
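The relaxed signature means callers no longer have to create and pass an event loop just to consume an async generator synchronously. A small sketch (the async generator below is a stand-in for illustration, not code from the repository):

```python
import asyncio
from server.utils import iter_over_async

async def numbers():
    # stand-in async generator
    for i in range(3):
        await asyncio.sleep(0)
        yield i

# With loop=None the helper falls back to the current event loop,
# or creates a fresh one if none is available.
for item in iter_over_async(numbers()):
    print(item)
```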
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 7af2dc5..2664b30 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -166,7 +166,7 @@ def dialogue_page(api: ApiRequest):
ans = ""
support_agent = ["gpt", "Qwen", "qwen-api", "baichuan-api"] # 目前支持agent的模型
if not any(agent in llm_model for agent in support_agent):
- ans += "正在思考... \n\n 改模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
+ ans += "正在思考... \n\n 该模型并没有进行Agent对齐,无法正常使用Agent功能!\n\n\n请更换 GPT4或Qwen-14B等支持Agent的模型获得更好的体验! \n\n\n"
chat_box.update_msg(ans, element_index=0, streaming=False)