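"""FastAPI handler for agent-based chat.

Builds a LangChain agent (a ChatGLM3-specific agent when the active model is
a chatglm3 variant, otherwise an LLMSingleActionAgent), replays the request
history into window memory, and streams tool-call events and LLM tokens back
to the client as JSON chunks.
"""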
from langchain.memory import ConversationBufferWindowMemory
from server.agent.custom_agent.ChatGLM3Agent import initialize_glm3_agent
from server.agent.tools_select import tools, tool_names
from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
from langchain.agents import LLMSingleActionAgent, AgentExecutor
from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
from fastapi import Body
from fastapi.responses import StreamingResponse
from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
from langchain.chains import LLMChain
from typing import AsyncIterable, List, Optional
import asyncio
from server.chat.utils import History
import json
from server.agent import model_container
from server.knowledge_base.kb_service.base import get_kb_details


async def agent_chat(query: str = Body(..., description="User input", examples=["恼羞成怒"]),
                     history: List[History] = Body([],
                                                   description="Chat history",
                                                   examples=[[
                                                       {"role": "user", "content": "请使用知识库工具查询今天北京天气"},
                                                       {"role": "assistant",
                                                        "content": "使用天气查询工具查询到今天北京多云,10-14摄氏度,东北风2级,易感冒"}]]
                                                   ),
                     stream: bool = Body(False, description="Stream the output"),
                     model_name: str = Body(LLM_MODELS[0], description="Name of the LLM model."),
                     temperature: float = Body(TEMPERATURE, description="LLM sampling temperature", ge=0.0, le=1.0),
                     max_tokens: Optional[int] = Body(None, description="Maximum number of tokens for the LLM to generate; None (the default) means the model's maximum"),
                     prompt_name: str = Body("default",
                                             description="Name of the prompt template to use (configured in configs/prompt_config.py)"),
                     # top_p: float = Body(TOP_P, description="LLM nucleus sampling; do not set together with temperature", gt=0.0, lt=1.0),
                     ):
    history = [History.from_data(h) for h in history]

    async def agent_chat_iterator(
            query: str,
            history: Optional[List[History]],
            model_name: str = LLM_MODELS[0],
            prompt_name: str = prompt_name,
    ) -> AsyncIterable[str]:
        # max_tokens is reassigned below; declare it nonlocal so Python does
        # not treat it as a local of this coroutine (which would raise
        # UnboundLocalError on the isinstance check).
        nonlocal max_tokens
        callback = CustomAsyncIteratorCallbackHandler()
        if isinstance(max_tokens, int) and max_tokens <= 0:
            max_tokens = None

        model = get_ChatOpenAI(
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            callbacks=[callback],
        )

        # Pass the knowledge bases to the agent through module-level globals
        # so that its tools can access them
        kb_list = {x["kb_name"]: x for x in get_kb_details()}
        model_container.DATABASE = {name: details['kb_info'] for name, details in kb_list.items()}
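        # Illustrative only (hypothetical tool body, not part of this module):
        # a tool defined under server.agent can read these globals, e.g.
        #
        #     from server.agent import model_container
        #     kb_overview = "\n".join(f"{name}: {info}"
        #                             for name, info in model_container.DATABASE.items())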

        if Agent_MODEL:
            # A dedicated agent model is configured; use it to drive the agent
            model_agent = get_ChatOpenAI(
                model_name=Agent_MODEL,
                temperature=temperature,
                max_tokens=max_tokens,
                callbacks=[callback],
            )
            model_container.MODEL = model_agent
        else:
            model_container.MODEL = model

        prompt_template = get_prompt_template("agent_chat", prompt_name)
        prompt_template_agent = CustomPromptTemplate(
            template=prompt_template,
            tools=tools,
            input_variables=["input", "intermediate_steps", "history"]
        )
        output_parser = CustomOutputParser()
        llm_chain = LLMChain(llm=model, prompt=prompt_template_agent)

        # Convert the request history into the agent's memory
        memory = ConversationBufferWindowMemory(k=HISTORY_LEN * 2)
        for message in history:
            # Check the message role
            if message.role == 'user':
                # Add a user message
                memory.chat_memory.add_user_message(message.content)
            else:
                # Add an AI message
                memory.chat_memory.add_ai_message(message.content)

        if "chatglm3" in model_container.MODEL.model_name:
            agent_executor = initialize_glm3_agent(
                llm=model,
                tools=tools,
                callback_manager=None,
                # The LangChain prompt is not constructed directly here; it is built inside the GLM3 agent.
                prompt=prompt_template,
                input_variables=["input", "intermediate_steps", "history"],
                memory=memory,
                verbose=True,
            )
        else:
            agent = LLMSingleActionAgent(
                llm_chain=llm_chain,
                output_parser=output_parser,
                stop=["\nObservation:", "Observation"],
                allowed_tools=tool_names,
            )
            agent_executor = AgentExecutor.from_agent_and_tools(agent=agent,
                                                                tools=tools,
                                                                verbose=True,
                                                                memory=memory,
                                                                )

        # Retry until the agent call has been scheduled; transient failures
        # while creating the task are swallowed and retried.
        while True:
            try:
                task = asyncio.create_task(wrap_done(
                    agent_executor.acall(query, callbacks=[callback], include_run_info=True),
                    callback.done))
                break
            except Exception:
                pass

        if stream:
            async for chunk in callback.aiter():
                tools_use = []
                # Use server-sent events to stream the response
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用失败")
                    tools_use.append("错误信息: " + data["error"])
                    tools_use.append("重新开始尝试")
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.tool_finish:
                    tools_use.append("\n```\n")
                    tools_use.append("工具名称: " + data["tool_name"])
                    tools_use.append("工具状态: " + "调用成功")
                    tools_use.append("工具输入: " + data["input_str"])
                    tools_use.append("工具输出: " + data["output_str"])
                    tools_use.append("\n```\n")
                    yield json.dumps({"tools": tools_use}, ensure_ascii=False)
                elif data["status"] == Status.agent_finish:
                    yield json.dumps({"final_answer": data["final_answer"]}, ensure_ascii=False)
                else:
                    yield json.dumps({"answer": data["llm_token"]}, ensure_ascii=False)
        else:
            answer = ""
            final_answer = ""
            async for chunk in callback.aiter():
                # Accumulate the full response and emit it once at the end
                data = json.loads(chunk)
                if data["status"] == Status.start or data["status"] == Status.complete:
                    continue
                elif data["status"] == Status.error:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用失败" + "\n"
                    answer += "错误信息: " + data["error"] + "\n"
                    answer += "\n```\n"
                elif data["status"] == Status.tool_finish:
                    answer += "\n```\n"
                    answer += "工具名称: " + data["tool_name"] + "\n"
                    answer += "工具状态: " + "调用成功" + "\n"
                    answer += "工具输入: " + data["input_str"] + "\n"
                    answer += "工具输出: " + data["output_str"] + "\n"
                    answer += "\n```\n"
                elif data["status"] == Status.agent_finish:
                    final_answer = data["final_answer"]
                else:
                    answer += data["llm_token"]

            yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
        await task

    return StreamingResponse(agent_chat_iterator(query=query,
                                                 history=history,
                                                 model_name=model_name,
                                                 prompt_name=prompt_name),
                             media_type="text/event-stream")
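

if __name__ == "__main__":
    # Minimal smoke-test client, for illustration only. Assumptions not
    # guaranteed by this module: the API server is reachable at
    # http://127.0.0.1:7861 and mounts this handler at POST /chat/agent_chat
    # (adjust both for your deployment). Each streamed chunk is one JSON
    # object produced by agent_chat_iterator; chunk boundaries usually line up
    # with the individual yields, but a robust client should buffer and
    # re-split if they do not.
    import httpx

    with httpx.stream("POST", "http://127.0.0.1:7861/chat/agent_chat",
                      json={"query": "今天北京天气如何?", "stream": True},
                      timeout=None) as resp:
        for chunk in resp.iter_text():
            if chunk.strip():
                print(chunk, flush=True)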