增加了仅限GPT4的agent功能,陆续补充,中文版readme已写 (#1611)
This commit is contained in:
parent
c546b4271e
commit
5702554171
27
README.md
27
README.md
|
|
@ -189,6 +189,18 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/ch
|
|||
|
||||
关于如何使用自定义分词器和贡献自己的分词器,可以参考[Text Splitter 贡献说明](docs/splitter.md)。
|
||||
|
||||
## Agent生态
|
||||
### 基础的Agent
|
||||
在本版本中,我们实现了一个简单的基于OpenAI的React的Agent模型,目前,经过我们测试,仅有以下两个模型支持:
|
||||
+ OpenAI GPT4
|
||||
+ ChatGLM2-130B
|
||||
|
||||
目前版本的Agent仍然需要对提示词进行大量调试,调试位置
|
||||
|
||||
### 构建自己的Agent工具
|
||||
|
||||
详见 (docs/自定义Agent.md)
|
||||
|
||||
## Docker 部署
|
||||
|
||||
🐳 Docker 镜像地址: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3)`
|
||||
|
|
@ -392,23 +404,24 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
|
|||
- [ ] 结构化数据接入
|
||||
- [X] .csv
|
||||
- [ ] .xlsx
|
||||
- [ ] 分词及召回
|
||||
- [ ] 接入不同类型 TextSplitter
|
||||
- [ ] 优化依据中文标点符号设计的 ChineseTextSplitter
|
||||
- [ ] 重新实现上下文拼接召回
|
||||
- [X] 分词及召回
|
||||
- [X] 接入不同类型 TextSplitter
|
||||
- [X] 优化依据中文标点符号设计的 ChineseTextSplitter
|
||||
- [X] 重新实现上下文拼接召回
|
||||
- [ ] 本地网页接入
|
||||
- [ ] SQL 接入
|
||||
- [ ] 知识图谱/图数据库接入
|
||||
- [X] 搜索引擎接入
|
||||
- [X] Bing 搜索
|
||||
- [X] DuckDuckGo 搜索
|
||||
- [ ] Agent 实现
|
||||
- [X] Agent 实现
|
||||
- [X]基础React形式的Agent实现,包括调用计算器等
|
||||
- [X] LLM 模型接入
|
||||
- [X] 支持通过调用 [FastChat](https://github.com/lm-sys/fastchat) api 调用 llm
|
||||
- [ ] 支持 ChatGLM API 等 LLM API 的接入
|
||||
- [X] 支持 ChatGLM API 等 LLM API 的接入
|
||||
- [X] Embedding 模型接入
|
||||
- [X] 支持调用 HuggingFace 中各开源 Emebdding 模型
|
||||
- [ ] 支持 OpenAI Embedding API 等 Embedding API 的接入
|
||||
- [X] 支持 OpenAI Embedding API 等 Embedding API 的接入
|
||||
- [X] 基于 FastAPI 的 API 方式调用
|
||||
- [X] Web UI
|
||||
- [X] 基于 Streamlit 的 Web UI
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
langchain==0.0.287
|
||||
langchain>=0.0.302
|
||||
fschat[model_worker]==0.2.29
|
||||
openai
|
||||
sentence_transformers
|
||||
transformers>=4.31.0
|
||||
torch~=2.0.0
|
||||
fastapi~=0.99.1
|
||||
transformers>=4.33.0
|
||||
torch>=2.0.1
|
||||
torchvision
|
||||
torchaudio
|
||||
fastapi>=0.103.1
|
||||
nltk~=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
starlette~=0.27.0
|
||||
|
|
@ -40,9 +42,13 @@ pandas~=2.0.3
|
|||
streamlit>=1.26.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox >=1.1.6, <=1.1.7
|
||||
streamlit-chatbox>=1.1.9
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
watchdog
|
||||
tqdm
|
||||
websockets
|
||||
tiktoken
|
||||
einops
|
||||
scipy
|
||||
transformers_stream_generator==0.0.4
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
langchain==0.0.287
|
||||
langchain>=0.0.302
|
||||
fschat[model_worker]==0.2.29
|
||||
openai
|
||||
sentence_transformers
|
||||
transformers>=4.31.0
|
||||
torch~=2.0.0
|
||||
fastapi~=0.99.1
|
||||
transformers>=4.33.0
|
||||
torch >=2.0.1
|
||||
torchvision
|
||||
torchaudio
|
||||
fastapi>=0.103.1
|
||||
nltk~=3.8.1
|
||||
uvicorn~=0.23.1
|
||||
starlette~=0.27.0
|
||||
|
|
@ -17,13 +19,14 @@ accelerate
|
|||
spacy
|
||||
PyMuPDF==1.22.5
|
||||
rapidocr_onnxruntime>=1.3.2
|
||||
|
||||
requests
|
||||
pathlib
|
||||
pytest
|
||||
scikit-learn
|
||||
numexpr
|
||||
vllm==0.1.7; sys_platform == "linux"
|
||||
|
||||
|
||||
# online api libs
|
||||
# zhipuai
|
||||
# dashscope>=1.10.0 # qwen
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ pandas~=2.0.3
|
|||
streamlit>=1.26.0
|
||||
streamlit-option-menu>=0.3.6
|
||||
streamlit-antd-components>=0.1.11
|
||||
streamlit-chatbox >=1.1.6, <=1.1.7
|
||||
streamlit-chatbox>=1.1.9
|
||||
streamlit-aggrid>=0.3.4.post3
|
||||
httpx~=0.24.1
|
||||
nltk
|
||||
|
|
|
|||
|
|
@ -0,0 +1,109 @@
|
|||
from uuid import UUID
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
import json
|
||||
import asyncio
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain.schema import AgentFinish, AgentAction
|
||||
from langchain.schema.output import LLMResult
|
||||
|
||||
|
||||
def dumps(obj: Dict) -> str:
|
||||
return json.dumps(obj, ensure_ascii=False)
|
||||
|
||||
|
||||
class Status:
|
||||
start: int = 1
|
||||
running: int = 2
|
||||
complete: int = 3
|
||||
agent_action: int = 4
|
||||
agent_finish: int = 5
|
||||
error: int = 6
|
||||
make_tool: int = 7
|
||||
|
||||
|
||||
class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.queue = asyncio.Queue()
|
||||
self.done = asyncio.Event()
|
||||
self.cur_tool = {}
|
||||
self.out = True
|
||||
|
||||
async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None,
|
||||
metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
|
||||
self.cur_tool = {
|
||||
"tool_name": serialized["name"],
|
||||
"input_str": input_str,
|
||||
"output_str": "",
|
||||
"status": Status.agent_action,
|
||||
"run_id": run_id.hex,
|
||||
"llm_token": "",
|
||||
"final_answer": "",
|
||||
"error": "",
|
||||
}
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
|
||||
tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
self.cur_tool.update(
|
||||
status=Status.agent_finish,
|
||||
output_str=output.replace("Answer:", ""),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
|
||||
parent_run_id: UUID | None = None, tags: List[str] | None = None, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
self.cur_tool.update(
|
||||
status=Status.error,
|
||||
error=str(error),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
if token:
|
||||
if token == "Action":
|
||||
self.out = False
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token="\n\n",
|
||||
)
|
||||
|
||||
if self.out:
|
||||
self.cur_tool.update(
|
||||
status=Status.running,
|
||||
llm_token=token,
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.start,
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||||
self.out = True
|
||||
self.cur_tool.update(
|
||||
status=Status.complete,
|
||||
llm_token="",
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
|
||||
self.cur_tool.update(
|
||||
status=Status.error,
|
||||
error=str(error),
|
||||
)
|
||||
self.queue.put_nowait(dumps(self.cur_tool))
|
||||
|
||||
async def on_agent_finish(
|
||||
self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.cur_tool = {}
|
||||
|
|
@ -1,22 +1,23 @@
|
|||
template = """
|
||||
尽可能地回答以下问题。你可以使用以下工具:{tools}
|
||||
请按照以下格式进行:
|
||||
Question: 需要你回答的输入问题
|
||||
Thought: 你应该总是思考该做什么
|
||||
|
||||
Question: 需要你回答的输入问题。
|
||||
Thought: 你应该总是思考该做什么,并告诉我你要用什么工具。
|
||||
Action: 需要使用的工具,应该是[{tool_names}]中的一个
|
||||
Action Input: 传入工具的内容
|
||||
Observation: 行动的结果
|
||||
... (这个Thought/Action/Action Input/Observation可以重复N次)
|
||||
Thought: 我现在知道最后的答案
|
||||
Final Answer: 对原始输入问题的最终答案
|
||||
|
||||
Thought: 通过使用工具,我是否知道了答案,如果知道,就自然的回答问题,如果不知道,继续使用工具或者自己的知识 \n
|
||||
Final Answer: 这个问题的答案是,输出完整的句子。
|
||||
现在开始!
|
||||
|
||||
之前的对话:
|
||||
{history}
|
||||
|
||||
New question: {input}
|
||||
Thought: {agent_scratchpad}"""
|
||||
New question:
|
||||
{input}
|
||||
Thought:
|
||||
{agent_scratchpad}"""
|
||||
|
||||
|
||||
# ChatGPT 提示词模板
|
||||
|
|
@ -84,7 +85,7 @@ class CustomOutputParser(AgentOutputParser):
|
|||
return AgentFinish(
|
||||
# Return values is generally always a dictionary with a single `output` key
|
||||
# It is not recommended to try anything else at the moment :)
|
||||
return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
|
||||
return_values={"output": llm_output.replace("Final Answer:", "").strip()},
|
||||
log=llm_output,
|
||||
)
|
||||
# Parse out the action and action input
|
||||
|
|
@ -95,10 +96,13 @@ class CustomOutputParser(AgentOutputParser):
|
|||
return_values={"output": f"调用agent失败: `{llm_output}`"},
|
||||
log=llm_output,
|
||||
)
|
||||
raise OutputParserException(f"调用agent失败: `{llm_output}`")
|
||||
action = match.group(1).strip()
|
||||
action_input = match.group(2)
|
||||
# Return the action and action input
|
||||
return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
|
||||
return AgentAction(
|
||||
tool=action,
|
||||
tool_input=action_input.strip(" ").strip('"'),
|
||||
log=llm_output
|
||||
)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from configs.model_config import LLM_MODEL,TEMPERATURE
|
|||
|
||||
_PROMPT_TEMPLATE = '''
|
||||
# 指令
|
||||
接下来,作为一个专业的翻译专家,当我给出英文句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
|
||||
接下来,作为一个专业的翻译专家,当我给出句子或段落时,你将提供通顺且具有可读性的对应语言的翻译。注意:
|
||||
1. 确保翻译结果流畅且易于理解
|
||||
2. 无论提供的是陈述句或疑问句,只进行翻译
|
||||
3. 不添加与原文无关的内容
|
||||
|
|
@ -21,6 +21,14 @@ _PROMPT_TEMPLATE = '''
|
|||
${{翻译结果}}
|
||||
```
|
||||
答案: ${{答案}}
|
||||
|
||||
以下是一个例子
|
||||
问题: 翻译13成英语
|
||||
```text
|
||||
13 English
|
||||
```output
|
||||
thirteen
|
||||
答案: thirteen
|
||||
'''
|
||||
|
||||
PROMPT = PromptTemplate(
|
||||
|
|
|
|||
|
|
@ -113,7 +113,7 @@ def weather(query):
|
|||
return "只能查看24小时内的天气,无法回答"
|
||||
if time == "None":
|
||||
time = "24" # 免费的版本只能24小时内的天气
|
||||
key = "" # 和风天气API Key
|
||||
key = "315625cdca234137944d7f8956106a3e" # 和风天气API Key
|
||||
if key == "":
|
||||
return "请先在代码中填入和风天气API Key"
|
||||
city_info = get_city_info(location=location, adm=adm, key=key)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from langchain.memory import ConversationBufferWindowMemory
|
||||
from server.agent.tools import tools, tool_names
|
||||
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
|
||||
from langchain.agents import AgentExecutor, LLMSingleActionAgent
|
||||
from server.agent.custom_template import CustomOutputParser, prompt
|
||||
from fastapi import Body
|
||||
|
|
@ -7,38 +8,31 @@ from fastapi.responses import StreamingResponse
|
|||
from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
|
||||
from server.utils import wrap_done, get_ChatOpenAI
|
||||
from langchain import LLMChain
|
||||
from langchain.callbacks import AsyncIteratorCallbackHandler
|
||||
from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
|
||||
from typing import AsyncIterable
|
||||
import asyncio
|
||||
from langchain.prompts.chat import ChatPromptTemplate
|
||||
from typing import List
|
||||
from server.chat.utils import History
|
||||
|
||||
memory = ConversationBufferWindowMemory(k=HISTORY_LEN)
|
||||
import json
|
||||
async def agent_chat(query: str = Body(..., description="用户输入", examples=["恼羞成怒"]),
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
):
|
||||
history: List[History] = Body([],
|
||||
description="历史对话",
|
||||
examples=[[
|
||||
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
|
||||
{"role": "assistant", "content": "虎头虎脑"}]]
|
||||
),
|
||||
stream: bool = Body(False, description="流式输出"),
|
||||
model_name: str = Body(LLM_MODEL, description="LLM 模型名称。"),
|
||||
temperature: float = Body(TEMPERATURE, description="LLM 采样温度", gt=0.0, le=1.0),
|
||||
# top_p: float = Body(TOP_P, description="LLM 核采样。勿与temperature同时设置", gt=0.0, lt=1.0),
|
||||
):
|
||||
history = [History.from_data(h) for h in history]
|
||||
|
||||
async def chat_iterator(query: str,
|
||||
history: List[History] = [],
|
||||
model_name: str = LLM_MODEL,
|
||||
) -> AsyncIterable[str]:
|
||||
callback = AsyncFinalIteratorCallbackHandler()
|
||||
async def chat_iterator() -> AsyncIterable[str]:
|
||||
callback = CustomAsyncIteratorCallbackHandler()
|
||||
model = get_ChatOpenAI(
|
||||
model_name=model_name,
|
||||
temperature=temperature,
|
||||
callbacks=[callback],
|
||||
)
|
||||
output_parser = CustomOutputParser()
|
||||
llm_chain = LLMChain(llm=model, prompt=prompt)
|
||||
|
|
@ -46,28 +40,69 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
|
|||
llm_chain=llm_chain,
|
||||
output_parser=output_parser,
|
||||
stop=["\nObservation:"],
|
||||
allowed_tools=tool_names
|
||||
allowed_tools=tool_names,
|
||||
)
|
||||
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory,
|
||||
callbacks=[callback])
|
||||
# 把history转成agent的memory
|
||||
memory = ConversationBufferWindowMemory(k=100)
|
||||
|
||||
for message in history:
|
||||
# 检查消息的角色
|
||||
if message.role == 'user':
|
||||
# 添加用户消息
|
||||
memory.chat_memory.add_user_message(message.content)
|
||||
else:
|
||||
# 添加AI消息
|
||||
memory.chat_memory.add_ai_message(message.content)
|
||||
|
||||
agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
|
||||
tools=tools,
|
||||
verbose=True,
|
||||
memory=memory,
|
||||
)
|
||||
# TODO: history is not used
|
||||
input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
|
||||
chat_prompt = ChatPromptTemplate.from_messages(
|
||||
[i.to_msg_template() for i in history] + [input_msg])
|
||||
task = asyncio.create_task(wrap_done(
|
||||
agent_executor.acall(query),
|
||||
agent_executor.acall(query, callbacks=[callback], include_run_info=True),
|
||||
callback.done),
|
||||
)
|
||||
if stream:
|
||||
async for token in callback.aiter():
|
||||
async for chunk in callback.aiter():
|
||||
tools_use = []
|
||||
# Use server-sent-events to stream the response
|
||||
yield token
|
||||
data = json.loads(chunk)
|
||||
if data["status"] == Status.start or data["status"] == Status.complete:
|
||||
continue
|
||||
if data["status"] == Status.agent_action:
|
||||
yield json.dumps({"answer": "(正在使用工具,请注意工具栏变化) \n\n"}, ensure_ascii=False)
|
||||
if data["status"] == Status.agent_finish:
|
||||
tools_use.append("工具名称: " + data["tool_name"])
|
||||
tools_use.append("工具输入: " + data["input_str"])
|
||||
tools_use.append("工具输出: " + data["output_str"])
|
||||
yield json.dumps({"tools": tools_use}, ensure_ascii=False)
|
||||
yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
|
||||
|
||||
else:
|
||||
answer = ""
|
||||
async for token in callback.aiter():
|
||||
answer += token
|
||||
yield answer
|
||||
pass
|
||||
# agent必须要steram=True
|
||||
# result = []
|
||||
# async for chunk in callback.aiter():
|
||||
# data = json.loads(chunk)
|
||||
# status = data["status"]
|
||||
# if status == Status.start:
|
||||
# result.append(chunk)
|
||||
# elif status == Status.running:
|
||||
# result[-1]["llm_token"] += chunk["llm_token"]
|
||||
# elif status == Status.complete:
|
||||
# result[-1]["status"] = Status.complete
|
||||
# elif status == Status.agent_finish:
|
||||
# result.append(chunk)
|
||||
# elif status == Status.agent_finish:
|
||||
# pass
|
||||
# yield dumps(result)
|
||||
|
||||
await task
|
||||
|
||||
return StreamingResponse(chat_iterator(query, history, model_name),
|
||||
return StreamingResponse(chat_iterator(),
|
||||
media_type="text/event-stream")
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ from configs import LLM_MODEL, TEMPERATURE
|
|||
from server.utils import get_model_worker_config
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
chat_box = ChatBox(
|
||||
assistant_avatar=os.path.join(
|
||||
"img",
|
||||
|
|
@ -71,7 +70,7 @@ def dialogue_page(api: ApiRequest):
|
|||
|
||||
def on_llm_change():
|
||||
config = get_model_worker_config(llm_model)
|
||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||
if not config.get("online_api"): # 只有本地model_worker可以切换模型
|
||||
st.session_state["prev_llm_model"] = llm_model
|
||||
st.session_state["cur_llm_model"] = st.session_state.llm_model
|
||||
|
||||
|
|
@ -90,15 +89,15 @@ def dialogue_page(api: ApiRequest):
|
|||
llm_models = running_models + available_models
|
||||
index = llm_models.index(st.session_state.get("cur_llm_model", LLM_MODEL))
|
||||
llm_model = st.selectbox("选择LLM模型:",
|
||||
llm_models,
|
||||
index,
|
||||
format_func=llm_model_format_func,
|
||||
on_change=on_llm_change,
|
||||
key="llm_model",
|
||||
)
|
||||
llm_models,
|
||||
index,
|
||||
format_func=llm_model_format_func,
|
||||
on_change=on_llm_change,
|
||||
key="llm_model",
|
||||
)
|
||||
if (st.session_state.get("prev_llm_model") != llm_model
|
||||
and not get_model_worker_config(llm_model).get("online_api")
|
||||
and llm_model not in running_models):
|
||||
and not get_model_worker_config(llm_model).get("online_api")
|
||||
and llm_model not in running_models):
|
||||
with st.spinner(f"正在加载模型: {llm_model},请勿进行操作或刷新页面"):
|
||||
prev_model = st.session_state.get("prev_llm_model")
|
||||
r = api.change_llm_model(prev_model, llm_model)
|
||||
|
|
@ -150,18 +149,6 @@ def dialogue_page(api: ApiRequest):
|
|||
chat_box.ai_say("正在思考...")
|
||||
text = ""
|
||||
r = api.chat_chat(prompt, history=history, model=llm_model, temperature=temperature)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
break
|
||||
text += t
|
||||
chat_box.update_msg(text)
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
chat_box.ai_say("正在调用工具回答...")
|
||||
text = ""
|
||||
r = api.agent_chat(prompt, history=history, model=llm_model, temperature=temperature)
|
||||
for t in r:
|
||||
if error_msg := check_error_msg(t): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
|
@ -170,6 +157,31 @@ def dialogue_page(api: ApiRequest):
|
|||
chat_box.update_msg(text)
|
||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||
|
||||
|
||||
elif dialogue_mode == "自定义Agent问答":
|
||||
chat_box.ai_say([
|
||||
f"正在思考和寻找工具 ...",])
|
||||
text = ""
|
||||
element_index = 0
|
||||
for d in api.agent_chat(prompt,
|
||||
history=history,
|
||||
model=llm_model,
|
||||
temperature=temperature):
|
||||
try:
|
||||
d = json.loads(d)
|
||||
except:
|
||||
pass
|
||||
if error_msg := check_error_msg(d): # check whether error occured
|
||||
st.error(error_msg)
|
||||
|
||||
elif chunk := d.get("answer"):
|
||||
text += chunk
|
||||
chat_box.update_msg(text, element_index=0)
|
||||
elif chunk := d.get("tools"):
|
||||
element_index += 1
|
||||
chat_box.insert_msg(Markdown("...", in_expander=True, title="使用工具...", state="complete"))
|
||||
chat_box.update_msg("\n\n".join(d.get("tools", [])), element_index=element_index, streaming=False)
|
||||
chat_box.update_msg(text, element_index=0, streaming=False)
|
||||
elif dialogue_mode == "知识库问答":
|
||||
chat_box.ai_say([
|
||||
f"正在查询知识库 `{selected_kb}` ...",
|
||||
|
|
|
|||
Loading…
Reference in New Issue