Add agent functionality (currently GPT-4 only), to be extended over time; Chinese README written (#1611)
commit 5702554171 (parent c546b4271e)
README.md (27 lines changed):

```diff
@@ -189,6 +189,18 @@ docker run -d --gpus all -p 80:8501 registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3
 
 For how to use a custom text splitter or contribute your own, see the [Text Splitter contribution guide](docs/splitter.md).
 
+## Agent Ecosystem
+
+### Basic Agent
+
+In this version we implemented a simple ReAct-style agent on top of the OpenAI interface. So far, in our testing, only the following two models are supported:
+
++ OpenAI GPT4
++ ChatGLM2-130B
+
+The agent in the current version still requires extensive prompt tuning; the place to tune is the prompt template (server/agent/custom_template.py).
+
+### Building Your Own Agent Tools
+
+See [docs/自定义Agent.md](docs/自定义Agent.md) for details.
+
 ## Docker Deployment
 
 🐳 Docker image: `registry.cn-beijing.aliyuncs.com/chatchat/chatchat:0.2.3`
@@ -392,23 +404,24 @@ CUDA_VISIBLE_DEVICES=0,1 python startup.py -a
 - [ ] Structured data ingestion
   - [X] .csv
   - [ ] .xlsx
-- [ ] Text splitting and retrieval
+- [X] Text splitting and retrieval
-  - [ ] Support different types of TextSplitter
+  - [X] Support different types of TextSplitter
-  - [ ] Improve ChineseTextSplitter, which is designed around Chinese punctuation
+  - [X] Improve ChineseTextSplitter, which is designed around Chinese punctuation
-  - [ ] Re-implement context-stitching retrieval
+  - [X] Re-implement context-stitching retrieval
 - [ ] Local web page ingestion
 - [ ] SQL ingestion
 - [ ] Knowledge graph / graph database ingestion
 - [X] Search engine integration
   - [X] Bing search
   - [X] DuckDuckGo search
-- [ ] Agent implementation
+- [X] Agent implementation
+  - [X] Basic ReAct-style agent implementation, including tools such as a calculator
 - [X] LLM model integration
   - [X] Support calling LLMs through the [FastChat](https://github.com/lm-sys/fastchat) API
-  - [ ] Support LLM APIs such as the ChatGLM API
+  - [X] Support LLM APIs such as the ChatGLM API
 - [X] Embedding model integration
   - [X] Support open-source embedding models from HuggingFace
-  - [ ] Support embedding APIs such as the OpenAI Embedding API
+  - [X] Support embedding APIs such as the OpenAI Embedding API
 - [X] API access based on FastAPI
 - [X] Web UI
   - [X] Streamlit-based Web UI
```
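The "Building Your Own Agent Tools" section above defers to docs/自定义Agent.md. As a rough sketch of the general shape (all names here are hypothetical, not taken from that document), a tool is a plain function wrapped in a LangChain `Tool` and registered where `server/agent/tools.py` builds its `tools` and `tool_names` lists:

```python
# Hypothetical sketch; the authoritative steps are in docs/自定义Agent.md.
from langchain.agents import Tool

def word_count(text: str) -> str:
    """Toy tool: count the words in the input text."""
    return str(len(text.split()))

tools = [
    Tool.from_function(
        func=word_count,
        name="word_count",
        description="Useful for counting the words in a piece of text.",
    ),
    # ... plus the built-in tools (calculator, translate, weather, ...)
]
tool_names = [tool.name for tool in tools]
```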
requirements.txt:

```diff
@@ -1,10 +1,12 @@
-langchain==0.0.287
+langchain>=0.0.302
 fschat[model_worker]==0.2.29
 openai
 sentence_transformers
-transformers>=4.31.0
+transformers>=4.33.0
-torch~=2.0.0
+torch>=2.0.1
-fastapi~=0.99.1
+torchvision
+torchaudio
+fastapi>=0.103.1
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
@@ -40,9 +42,13 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox >=1.1.6, <=1.1.7
+streamlit-chatbox>=1.1.9
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 watchdog
 tqdm
 websockets
+tiktoken
+einops
+scipy
+transformers_stream_generator==0.0.4
```
requirements_api.txt:

```diff
@@ -1,10 +1,12 @@
-langchain==0.0.287
+langchain>=0.0.302
 fschat[model_worker]==0.2.29
 openai
 sentence_transformers
-transformers>=4.31.0
+transformers>=4.33.0
-torch~=2.0.0
+torch>=2.0.1
-fastapi~=0.99.1
+torchvision
+torchaudio
+fastapi>=0.103.1
 nltk~=3.8.1
 uvicorn~=0.23.1
 starlette~=0.27.0
@@ -17,13 +19,14 @@ accelerate
 spacy
 PyMuPDF==1.22.5
 rapidocr_onnxruntime>=1.3.2
+
 requests
 pathlib
 pytest
 scikit-learn
 numexpr
 vllm==0.1.7; sys_platform == "linux"
 
 # online api libs
 # zhipuai
 # dashscope>=1.10.0 # qwen
```
requirements_webui.txt:

```diff
@@ -3,7 +3,7 @@ pandas~=2.0.3
 streamlit>=1.26.0
 streamlit-option-menu>=0.3.6
 streamlit-antd-components>=0.1.11
-streamlit-chatbox >=1.1.6, <=1.1.7
+streamlit-chatbox>=1.1.9
 streamlit-aggrid>=0.3.4.post3
 httpx~=0.24.1
 nltk
```
server/agent/callbacks.py (new file):

```diff
@@ -0,0 +1,109 @@
+from uuid import UUID
+from langchain.callbacks import AsyncIteratorCallbackHandler
+import json
+import asyncio
+from typing import Any, Dict, List, Optional
+
+from langchain.schema import AgentFinish, AgentAction
+from langchain.schema.output import LLMResult
+
+
+def dumps(obj: Dict) -> str:
+    return json.dumps(obj, ensure_ascii=False)
+
+
+class Status:
+    start: int = 1
+    running: int = 2
+    complete: int = 3
+    agent_action: int = 4
+    agent_finish: int = 5
+    error: int = 6
+    make_tool: int = 7
+
+
+class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
+    def __init__(self):
+        super().__init__()
+        self.queue = asyncio.Queue()
+        self.done = asyncio.Event()
+        self.cur_tool = {}
+        self.out = True
+
+    async def on_tool_start(self, serialized: Dict[str, Any], input_str: str, *, run_id: UUID,
+                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
+                            metadata: Dict[str, Any] | None = None, **kwargs: Any) -> None:
+        self.cur_tool = {
+            "tool_name": serialized["name"],
+            "input_str": input_str,
+            "output_str": "",
+            "status": Status.agent_action,
+            "run_id": run_id.hex,
+            "llm_token": "",
+            "final_answer": "",
+            "error": "",
+        }
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None,
+                          tags: List[str] | None = None, **kwargs: Any) -> None:
+        self.out = True
+        self.cur_tool.update(
+            status=Status.agent_finish,
+            output_str=output.replace("Answer:", ""),
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_tool_error(self, error: Exception | KeyboardInterrupt, *, run_id: UUID,
+                            parent_run_id: UUID | None = None, tags: List[str] | None = None,
+                            **kwargs: Any) -> None:
+        self.out = True
+        self.cur_tool.update(
+            status=Status.error,
+            error=str(error),
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
+        if token:
+            if token == "Action":
+                self.out = False
+                self.cur_tool.update(
+                    status=Status.running,
+                    llm_token="\n\n",
+                )
+
+            if self.out:
+                self.cur_tool.update(
+                    status=Status.running,
+                    llm_token=token,
+                )
+                self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_llm_start(self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any) -> None:
+        self.cur_tool.update(
+            status=Status.start,
+            llm_token="",
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
+        self.out = True
+        self.cur_tool.update(
+            status=Status.complete,
+            llm_token="",
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_llm_error(self, error: Exception | KeyboardInterrupt, **kwargs: Any) -> None:
+        self.cur_tool.update(
+            status=Status.error,
+            error=str(error),
+        )
+        self.queue.put_nowait(dumps(self.cur_tool))
+
+    async def on_agent_finish(
+        self, finish: AgentFinish, *, run_id: UUID, parent_run_id: Optional[UUID] = None,
+        tags: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        self.cur_tool = {}
```
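Every hook above serializes the current `cur_tool` snapshot into the queue, so whoever drains the handler sees a stream of status-tagged JSON events. A minimal consumption sketch, assuming the agent runs in a separate task that eventually sets `callback.done` (as `agent_chat.py` below does):

```python
import json

from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status

async def drain(callback: CustomAsyncIteratorCallbackHandler) -> None:
    # aiter(), inherited from AsyncIteratorCallbackHandler, yields the JSON
    # strings queued by the on_* hooks until callback.done is set.
    async for chunk in callback.aiter():
        event = json.loads(chunk)
        if event["status"] == Status.agent_action:
            print(f"-> tool call: {event['tool_name']}({event['input_str']})")
        elif event["status"] == Status.running:
            print(event["llm_token"], end="")
        elif event["status"] == Status.error:
            print("error:", event["error"])
```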
server/agent/custom_template.py:

```diff
@@ -1,22 +1,23 @@
 template = """
 Answer the following questions as best you can. You have access to the following tools: {tools}
 Please use the following format:
-Question: the input question you must answer
-Thought: you should always think about what to do
+Question: the input question you must answer.
+Thought: you should always think about what to do, and tell me which tool you are going to use.
 Action: the tool to use, which should be one of [{tool_names}]
 Action Input: the input to pass to the tool
 Observation: the result of the action
 ... (this Thought/Action/Action Input/Observation can repeat N times)
-Thought: I now know the final answer
-Final Answer: the final answer to the original input question
+Thought: after using the tools, do I know the answer? If so, answer the question naturally; if not, keep using the tools or my own knowledge \n
+Final Answer: the answer to this question, output as a complete sentence.
 
 Begin!
 
 Previous conversation:
 {history}
 
-New question: {input}
-Thought: {agent_scratchpad}"""
+New question:
+{input}
+Thought:
+{agent_scratchpad}"""
 
 
 # ChatGPT prompt template
@@ -84,7 +85,7 @@ class CustomOutputParser(AgentOutputParser):
         return AgentFinish(
             # Return values are generally always a dictionary with a single `output` key
             # It is not recommended to try anything else at the moment :)
-            return_values={"output": llm_output.split("Final Answer:")[-1].strip()},
+            return_values={"output": llm_output.replace("Final Answer:", "").strip()},
             log=llm_output,
         )
         # Parse out the action and action input
@@ -95,10 +96,13 @@ class CustomOutputParser(AgentOutputParser):
                 return_values={"output": f"Agent call failed: `{llm_output}`"},
                 log=llm_output,
             )
-            raise OutputParserException(f"Agent call failed: `{llm_output}`")
         action = match.group(1).strip()
         action_input = match.group(2)
         # Return the action and action input
-        return AgentAction(tool=action, tool_input=action_input.strip(" ").strip('"'), log=llm_output)
+        return AgentAction(
+            tool=action,
+            tool_input=action_input.strip(" ").strip('"'),
+            log=llm_output
+        )
```
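To see what the parser change means in practice, here is a rough usage sketch. It assumes the usual ReAct regex in the unchanged part of the file; the inputs are illustrative only:

```python
from server.agent.custom_template import CustomOutputParser

parser = CustomOutputParser()

# An intermediate completion: parsed into an AgentAction by the regex
# branch above this hunk (tool name and input are illustrative).
step = parser.parse(
    "Thought: I need the weather, and I will use the weather tool.\n"
    "Action: weather\n"
    'Action Input: "上海"'
)
# step.tool == "weather"; the surrounding quotes are stripped from tool_input.

# A terminal completion: with the new replace()-based branch, text around
# the "Final Answer:" marker is kept rather than split away.
final = parser.parse("Final Answer: It will rain in Shanghai tonight.")
# final.return_values["output"] == "It will rain in Shanghai tonight."
```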
server/agent/tools.py:

````diff
@@ -10,7 +10,7 @@ from configs.model_config import LLM_MODEL,TEMPERATURE
 
 _PROMPT_TEMPLATE = '''
 # Instruction
-From now on, act as a professional translator: when I give you an English sentence or paragraph, provide a fluent, readable translation in the target language. Note:
+From now on, act as a professional translator: when I give you a sentence or paragraph, provide a fluent, readable translation in the target language. Note:
 1. Make sure the translation is fluent and easy to understand
 2. Whether the input is a statement or a question, only translate it
 3. Do not add anything unrelated to the original text
@@ -21,6 +21,14 @@ _PROMPT_TEMPLATE = '''
 ${{translation}}
 ```
 Answer: ${{answer}}
+
+Here is an example:
+Question: translate 13 into English
+```text
+13 English
+```output
+thirteen
+Answer: thirteen
 '''
 
 PROMPT = PromptTemplate(
@@ -113,7 +113,7 @@ def weather(query):
         return "Only weather within 24 hours can be queried; unable to answer"
     if time == "None":
         time = "24"  # The free tier only covers weather within 24 hours
-    key = ""  # QWeather (和风天气) API key
+    key = "315625cdca234137944d7f8956106a3e"  # QWeather (和风天气) API key
     if key == "":
         return "Please fill in a QWeather (和风天气) API key in the code first"
     city_info = get_city_info(location=location, adm=adm, key=key)
````
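One caveat with this hunk: it hardcodes a QWeather key in source, which also makes the `if key == ""` check dead code. A common alternative, sketched here around a hypothetical `QWEATHER_API_KEY` environment variable (not part of this commit), keeps the key out of version control:

```python
import os

def get_qweather_key() -> str:
    # Hypothetical: read the QWeather (和风天气) key from the environment
    # instead of embedding it in tools.py.
    key = os.environ.get("QWEATHER_API_KEY", "")
    if not key:
        raise RuntimeError("Please set the QWEATHER_API_KEY environment variable first")
    return key
```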
server/chat/agent_chat.py:

```diff
@@ -1,5 +1,6 @@
 from langchain.memory import ConversationBufferWindowMemory
 from server.agent.tools import tools, tool_names
+from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status, dumps
 from langchain.agents import AgentExecutor, LLMSingleActionAgent
 from server.agent.custom_template import CustomOutputParser, prompt
 from fastapi import Body
@@ -7,15 +8,12 @@ from fastapi.responses import StreamingResponse
 from configs.model_config import LLM_MODEL, TEMPERATURE, HISTORY_LEN
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain import LLMChain
-from langchain.callbacks import AsyncIteratorCallbackHandler
-from langchain.callbacks.streaming_aiter_final_only import AsyncFinalIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts.chat import ChatPromptTemplate
 from typing import List
 from server.chat.utils import History
+import json
-memory = ConversationBufferWindowMemory(k=HISTORY_LEN)
 
 async def agent_chat(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
                      history: List[History] = Body([],
                                                    description="Dialogue history",
@@ -30,15 +28,11 @@ async def agent_chat(query: str = Body(..., description="User input", examples
                     ):
     history = [History.from_data(h) for h in history]
 
-    async def chat_iterator(query: str,
-                            history: List[History] = [],
-                            model_name: str = LLM_MODEL,
-                            ) -> AsyncIterable[str]:
-        callback = AsyncFinalIteratorCallbackHandler()
+    async def chat_iterator() -> AsyncIterable[str]:
+        callback = CustomAsyncIteratorCallbackHandler()
         model = get_ChatOpenAI(
             model_name=model_name,
             temperature=temperature,
-            callbacks=[callback],
         )
         output_parser = CustomOutputParser()
         llm_chain = LLMChain(llm=model, prompt=prompt)
@@ -46,28 +40,69 @@ async def agent_chat(query: str = Body(..., description="User input", examples
             llm_chain=llm_chain,
             output_parser=output_parser,
             stop=["\nObservation:"],
-            allowed_tools=tool_names
+            allowed_tools=tool_names,
         )
-        agent_executor = AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=True, memory=memory,
-                                                            callbacks=[callback])
+        # Convert history into the agent's memory
+        memory = ConversationBufferWindowMemory(k=100)
+
+        for message in history:
+            # Check the message role
+            if message.role == 'user':
+                # Add a user message
+                memory.chat_memory.add_user_message(message.content)
+            else:
+                # Add an AI message
+                memory.chat_memory.add_ai_message(message.content)
+
+        agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
+                                                            tools=tools,
+                                                            verbose=True,
+                                                            memory=memory,
+                                                            )
+        # TODO: history is not used
         input_msg = History(role="user", content="{{ input }}").to_msg_template(False)
         chat_prompt = ChatPromptTemplate.from_messages(
             [i.to_msg_template() for i in history] + [input_msg])
         task = asyncio.create_task(wrap_done(
-            agent_executor.acall(query),
+            agent_executor.acall(query, callbacks=[callback], include_run_info=True),
             callback.done),
         )
         if stream:
-            async for token in callback.aiter():
+            async for chunk in callback.aiter():
+                tools_use = []
                 # Use server-sent-events to stream the response
-                yield token
+                data = json.loads(chunk)
+                if data["status"] == Status.start or data["status"] == Status.complete:
+                    continue
+                if data["status"] == Status.agent_action:
+                    yield json.dumps({"answer": "(Using a tool; watch the tool panel for changes) \n\n"}, ensure_ascii=False)
+                if data["status"] == Status.agent_finish:
+                    tools_use.append("Tool name: " + data["tool_name"])
+                    tools_use.append("Tool input: " + data["input_str"])
+                    tools_use.append("Tool output: " + data["output_str"])
+                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
+                yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
 
         else:
-            answer = ""
-            async for token in callback.aiter():
-                answer += token
-            yield answer
+            pass
+            # The agent requires stream=True
+            # result = []
+            # async for chunk in callback.aiter():
+            #     data = json.loads(chunk)
+            #     status = data["status"]
+            #     if status == Status.start:
+            #         result.append(chunk)
+            #     elif status == Status.running:
+            #         result[-1]["llm_token"] += chunk["llm_token"]
+            #     elif status == Status.complete:
+            #         result[-1]["status"] = Status.complete
+            #     elif status == Status.agent_finish:
+            #         result.append(chunk)
+            #     elif status == Status.agent_finish:
+            #         pass
+            #     yield dumps(result)
 
         await task
 
-    return StreamingResponse(chat_iterator(query, history, model_name),
+    return StreamingResponse(chat_iterator(),
                              media_type="text/event-stream")
```
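The endpoint streams standalone JSON chunks carrying either an `answer` or a `tools` key, which is what the Web UI change below consumes. A rough client-side sketch; the `/chat/agent_chat` path and port 7861 are assumptions based on the module location and the project's usual API defaults, not confirmed by this diff:

```python
import json
import requests

with requests.post(
    "http://127.0.0.1:7861/chat/agent_chat",
    json={"query": "恼羞成怒", "history": [], "stream": True},
    stream=True,
) as resp:
    # Each yielded chunk is assumed to arrive as one unsplit JSON object,
    # which generally holds for these small streamed payloads.
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        if not chunk:
            continue
        data = json.loads(chunk)
        if "answer" in data:
            print(data["answer"], end="", flush=True)
        elif "tools" in data:
            print("\n" + "\n".join(data["tools"]))
```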
webui_pages/dialogue/dialogue.py:

```diff
@@ -8,7 +8,6 @@ from configs import LLM_MODEL, TEMPERATURE
 from server.utils import get_model_worker_config
 from typing import List, Dict
-
 
 chat_box = ChatBox(
     assistant_avatar=os.path.join(
         "img",
@@ -158,18 +157,31 @@ def dialogue_page(api: ApiRequest):
                     chat_box.update_msg(text)
                 chat_box.update_msg(text, streaming=False)  # Update the final string, removing the cursor
-
-            elif dialogue_mode == "自定义Agent问答":
-                chat_box.ai_say("Calling tools to answer...")
-                text = ""
-                r = api.agent_chat(prompt, history=history, model=llm_model, temperature=temperature)
-                for t in r:
-                    if error_msg := check_error_msg(t):  # check whether an error occurred
-                        st.error(error_msg)
-                        break
-                    text += t
-                    chat_box.update_msg(text)
-                chat_box.update_msg(text, streaming=False)  # Update the final string, removing the cursor
-
+            elif dialogue_mode == "自定义Agent问答":
+                chat_box.ai_say([
+                    f"Thinking and selecting tools ...",])
+                text = ""
+                element_index = 0
+                for d in api.agent_chat(prompt,
+                                        history=history,
+                                        model=llm_model,
+                                        temperature=temperature):
+                    try:
+                        d = json.loads(d)
+                    except:
+                        pass
+                    if error_msg := check_error_msg(d):  # check whether an error occurred
+                        st.error(error_msg)
+                    elif chunk := d.get("answer"):
+                        text += chunk
+                        chat_box.update_msg(text, element_index=0)
+                    elif chunk := d.get("tools"):
+                        element_index += 1
+                        chat_box.insert_msg(Markdown("...", in_expander=True, title="Using tool...", state="complete"))
+                        chat_box.update_msg("\n\n".join(d.get("tools", [])), element_index=element_index, streaming=False)
+                chat_box.update_msg(text, element_index=0, streaming=False)
             elif dialogue_mode == "知识库问答":
                 chat_box.ai_say([
                     f"Querying knowledge base `{selected_kb}` ...",
```