Langchain-Chatchat/libs/chatchat-server/chatchat/webui_pages/dialogue/dialogue.py

594 lines
24 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
import base64
import hashlib
import io
import os
import uuid
from datetime import datetime
from PIL import Image as PILImage
from typing import Dict, List
from urllib.parse import urlencode
# from audio_recorder_streamlit import audio_recorder
import openai
import streamlit as st
import streamlit_antd_components as sac
from streamlit_chatbox import *
from streamlit_extras.bottom_container import bottom
from streamlit_paste_button import paste_image_button
from chatchat.settings import Settings
from chatchat.server.callback_handler.agent_callback_handler import AgentStatus
from chatchat.server.knowledge_base.model.kb_document_model import DocumentWithVSId
from chatchat.server.knowledge_base.utils import format_reference
from chatchat.server.utils import MsgType, get_config_models, get_config_platforms, get_default_llm
from chatchat.webui_pages.utils import *
chat_box = ChatBox(assistant_avatar=get_img_base64("chatchat_icon_blue_square_v2.png"))
def save_session(conv_name: str = None):
"""save session state to chat context"""
chat_box.context_from_session(
conv_name, exclude=["selected_page", "prompt", "cur_conv_name", "upload_image"]
)
def restore_session(conv_name: str = None):
"""restore sesstion state from chat context"""
chat_box.context_to_session(
conv_name, exclude=["selected_page", "prompt", "cur_conv_name", "upload_image"]
)
def rerun():
"""
save chat context before rerun
"""
save_session()
st.rerun()
def get_messages_history(
history_len: int, content_in_expander: bool = False
) -> List[Dict]:
"""
返回消息历史
content_in_expander控制是否返回expander元素中的内容一般导出的时候可以选上传入LLM的history不需要
"""
def filter(msg):
content = [
x for x in msg["elements"] if x._output_method in ["markdown", "text"]
]
if not content_in_expander:
content = [x for x in content if not x._in_expander]
content = [x.content for x in content]
return {
"role": msg["role"],
"content": "\n\n".join(content),
}
messages = chat_box.filter_history(history_len=history_len, filter=filter)
if sys_msg := chat_box.context.get("system_message"):
messages = [{"role": "system", "content": sys_msg}] + messages
return messages
@st.cache_data
def upload_temp_docs(files, _api: ApiRequest) -> str:
"""
将文件上传到临时目录用于文件对话
返回临时向量库ID
"""
return _api.upload_temp_docs(files).get("data", {}).get("id")
@st.cache_data
def upload_image_file(file_name: str, content: bytes) -> dict:
'''upload image for vision model using openai sdk'''
client = openai.Client(base_url=f"{api_address()}/v1", api_key="NONE")
return client.files.create(file=(file_name, content), purpose="assistants").to_dict()
def get_image_file_url(upload_file: dict) -> str:
file_id = upload_file.get("id")
return f"{api_address(True)}/v1/files/{file_id}/content"
def add_conv(name: str = ""):
conv_names = chat_box.get_chat_names()
if not name:
i = len(conv_names) + 1
while True:
name = f"会话{i}"
if name not in conv_names:
break
i += 1
if name in conv_names:
sac.alert(
"创建新会话出错",
f"该会话名称 “{name}” 已存在",
color="error",
closable=True,
)
else:
chat_box.use_chat_name(name)
st.session_state["cur_conv_name"] = name
def del_conv(name: str = None):
conv_names = chat_box.get_chat_names()
name = name or chat_box.cur_chat_name
if len(conv_names) == 1:
sac.alert(
"删除会话出错", f"这是最后一个会话,无法删除", color="error", closable=True
)
elif not name or name not in conv_names:
sac.alert(
"删除会话出错", f"无效的会话名称:“{name}", color="error", closable=True
)
else:
chat_box.del_chat_name(name)
# restore_session()
st.session_state["cur_conv_name"] = chat_box.cur_chat_name
def clear_conv(name: str = None):
chat_box.reset_history(name=name or None)
# @st.cache_data
def list_tools(_api: ApiRequest):
return _api.list_tools() or {}
def dialogue_page(
api: ApiRequest,
is_lite: bool = False,
):
ctx = chat_box.context
ctx.setdefault("uid", uuid.uuid4().hex)
ctx.setdefault("file_chat_id", None)
ctx.setdefault("llm_model", get_default_llm())
ctx.setdefault("temperature", Settings.model_settings.TEMPERATURE)
st.session_state.setdefault("cur_conv_name", chat_box.cur_chat_name)
st.session_state.setdefault("last_conv_name", chat_box.cur_chat_name)
# sac on_change callbacks not working since st>=1.34
if st.session_state.cur_conv_name != st.session_state.last_conv_name:
save_session(st.session_state.last_conv_name)
restore_session(st.session_state.cur_conv_name)
st.session_state.last_conv_name = st.session_state.cur_conv_name
# st.write(chat_box.cur_chat_name)
# st.write(st.session_state)
# st.write(chat_box.context)
@st.experimental_dialog("模型配置", width="large")
def llm_model_setting():
# 模型
cols = st.columns(3)
platforms = ["所有"] + list(get_config_platforms())
platform = cols[0].selectbox("选择模型平台", platforms, key="platform")
llm_models = list(
get_config_models(
model_type="llm", platform_name=None if platform == "所有" else platform
)
)
llm_models += list(
get_config_models(
model_type="image2text", platform_name=None if platform == "所有" else platform
)
)
llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model")
temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature")
system_message = st.text_area("System Message:", key="system_message")
if st.button("OK"):
rerun()
@st.experimental_dialog("重命名会话")
def rename_conversation():
name = st.text_input("会话名称")
if st.button("OK"):
chat_box.change_chat_name(name)
restore_session()
st.session_state["cur_conv_name"] = name
rerun()
with st.sidebar:
tab1, tab2 = st.tabs(["工具设置", "会话设置"])
with tab1:
use_agent = st.checkbox(
"启用Agent", help="请确保选择的模型具备Agent能力", key="use_agent"
)
output_agent = st.checkbox("显示 Agent 过程", key="output_agent")
# 选择工具
tools = list_tools(api)
tool_names = ["None"] + list(tools)
if use_agent:
# selected_tools = sac.checkbox(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# check_all=True, key="selected_tools")
selected_tools = st.multiselect(
"选择工具",
list(tools),
format_func=lambda x: tools[x]["title"],
key="selected_tools",
)
else:
# selected_tool = sac.buttons(list(tools), format_func=lambda x: tools[x]["title"], label="选择工具",
# key="selected_tool")
selected_tool = st.selectbox(
"选择工具",
tool_names,
format_func=lambda x: tools.get(x, {"title": "None"})["title"],
key="selected_tool",
)
selected_tools = [selected_tool]
selected_tool_configs = {
name: tool["config"]
for name, tool in tools.items()
if name in selected_tools
}
if "None" in selected_tools:
selected_tools.remove("None")
# 当不启用Agent时手动生成工具参数
# TODO: 需要更精细的控制控件
tool_input = {}
if not use_agent and len(selected_tools) == 1:
with st.expander("工具参数", True):
for k, v in tools[selected_tools[0]]["args"].items():
if choices := v.get("choices", v.get("enum")):
tool_input[k] = st.selectbox(v["title"], choices)
else:
if v["type"] == "integer":
tool_input[k] = st.slider(
v["title"], value=v.get("default")
)
elif v["type"] == "number":
tool_input[k] = st.slider(
v["title"], value=v.get("default"), step=0.1
)
else:
tool_input[k] = st.text_input(
v["title"], v.get("default")
)
# uploaded_file = st.file_uploader("上传附件", accept_multiple_files=False)
# files_upload = process_files(files=[uploaded_file]) if uploaded_file else None
files_upload = None
# 用于图片对话、文生图的图片
upload_image = None
def on_upload_file_change():
if f := st.session_state.get("upload_image"):
name = ".".join(f.name.split(".")[:-1]) + ".png"
st.session_state["cur_image"] = (name, PILImage.open(f))
else:
st.session_state["cur_image"] = (None, None)
st.session_state.pop("paste_image", None)
st.file_uploader("上传图片", ["bmp", "jpg", "jpeg", "png"],
accept_multiple_files=False,
key="upload_image",
on_change=on_upload_file_change)
paste_image = paste_image_button("黏贴图像", key="paste_image")
cur_image = st.session_state.get("cur_image", (None, None))
if cur_image[1] is None and paste_image.image_data is not None:
name = hashlib.md5(paste_image.image_data.tobytes()).hexdigest()+".png"
cur_image = (name, paste_image.image_data)
if cur_image[1] is not None:
st.image(cur_image[1])
buffer = io.BytesIO()
cur_image[1].save(buffer, format="png")
upload_image = upload_image_file(cur_image[0], buffer.getvalue())
with tab2:
# 会话
cols = st.columns(3)
conv_names = chat_box.get_chat_names()
def on_conv_change():
print(conversation_name, st.session_state.cur_conv_name)
save_session(conversation_name)
restore_session(st.session_state.cur_conv_name)
conversation_name = sac.buttons(
conv_names,
label="当前会话:",
key="cur_conv_name",
# on_change=on_conv_change, # not work
)
chat_box.use_chat_name(conversation_name)
conversation_id = chat_box.context["uid"]
if cols[0].button("新建", on_click=add_conv):
...
if cols[1].button("重命名"):
rename_conversation()
if cols[2].button("删除", on_click=del_conv):
...
# Display chat messages from history on app rerun
chat_box.output_messages()
chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。"
# def on_feedback(
# feedback,
# message_id: str = "",
# history_index: int = -1,
# ):
# reason = feedback["text"]
# score_int = chat_box.set_feedback(feedback=feedback, history_index=history_index)
# api.chat_feedback(message_id=message_id,
# score=score_int,
# reason=reason)
# st.session_state["need_rerun"] = True
# feedback_kwargs = {
# "feedback_type": "thumbs",
# "optional_text_label": "欢迎反馈您打分的理由",
# }
# TODO: 这里的内容有点奇怪从后端导入Settings.model_settings.LLM_MODEL_CONFIG然后又从前端传到后端。需要优化
# 传入后端的内容
llm_model_config = Settings.model_settings.LLM_MODEL_CONFIG
chat_model_config = {key: {} for key in llm_model_config.keys()}
for key in llm_model_config:
if c := llm_model_config[key]:
model = c.get("model", "").strip() or get_default_llm()
chat_model_config[key][model] = llm_model_config[key]
llm_model = ctx.get("llm_model")
if llm_model is not None:
chat_model_config["llm_model"][llm_model] = llm_model_config["llm_model"].get(
llm_model, {}
)
# chat input
with bottom():
cols = st.columns([1, 0.2, 15, 1])
if cols[0].button(":gear:", help="模型配置"):
widget_keys = ["platform", "llm_model", "temperature", "system_message"]
chat_box.context_to_session(include=widget_keys)
llm_model_setting()
if cols[-1].button(":wastebasket:", help="清空对话"):
chat_box.reset_history()
rerun()
# with cols[1]:
# mic_audio = audio_recorder("", icon_size="2x", key="mic_audio")
prompt = cols[2].chat_input(chat_input_placeholder, key="prompt")
if prompt:
history = get_messages_history(
chat_model_config["llm_model"]
.get(next(iter(chat_model_config["llm_model"])), {})
.get("history_len", 1)
)
is_vision_chat = upload_image and not selected_tools
if is_vision_chat: # multimodal chat
chat_box.user_say([Image(get_image_file_url(upload_image), width=100), Markdown(prompt)])
else:
chat_box.user_say(prompt)
if files_upload:
if files_upload["images"]:
st.markdown(
f'<img src="data:image/jpeg;base64,{files_upload["images"][0]}" width="300">',
unsafe_allow_html=True,
)
elif files_upload["videos"]:
st.markdown(
f'<video width="400" height="300" controls><source src="data:video/mp4;base64,{files_upload["videos"][0]}" type="video/mp4"></video>',
unsafe_allow_html=True,
)
elif files_upload["audios"]:
st.markdown(
f'<audio controls><source src="data:audio/wav;base64,{files_upload["audios"][0]}" type="audio/wav"></audio>',
unsafe_allow_html=True,
)
chat_box.ai_say("正在思考...")
text = ""
started = False
client = openai.Client(base_url=f"{api_address()}/chat", api_key="NONE")
if is_vision_chat: # multimodal chat
content = [
{"type": "text", "text": prompt},
{"type": "image_url", "image_url": {"url": get_image_file_url(upload_image)}}
]
messages = [{"role": "user", "content": content}]
else:
messages = history + [{"role": "user", "content": prompt}]
tools = list(selected_tool_configs)
if len(selected_tools) == 1:
tool_choice = selected_tools[0]
else:
tool_choice = None
# 如果 tool_input 中有空的字段,设为用户输入
for k in tool_input:
if tool_input[k] in [None, ""]:
tool_input[k] = prompt
extra_body = dict(
metadata=files_upload,
chat_model_config=chat_model_config,
conversation_id=conversation_id,
tool_input=tool_input,
upload_image=upload_image,
)
stream = not is_vision_chat
params = dict(
messages=messages,
model=llm_model,
stream=stream, # TODOxinference qwen-vl-chat 流式输出会出错,后续看更新
extra_body=extra_body,
)
if tools:
params["tools"] = tools
if tool_choice:
params["tool_choice"] = tool_choice
if Settings.model_settings.MAX_TOKENS:
params["max_tokens"] = Settings.model_settings.MAX_TOKENS
if stream:
try:
for d in client.chat.completions.create(**params):
# import rich
# rich.print(d)
message_id = d.message_id
metadata = {
"message_id": message_id,
}
# clear initial message
if not started:
chat_box.update_msg("", streaming=False)
started = True
if d.status == AgentStatus.error:
st.error(d.choices[0].delta.content)
elif d.status == AgentStatus.llm_start:
if not output_agent:
continue
chat_box.insert_msg("正在解读工具输出结果...")
text = d.choices[0].delta.content or ""
elif d.status == AgentStatus.llm_new_token:
if not output_agent:
continue
text += d.choices[0].delta.content or ""
chat_box.update_msg(
text.replace("\n", "\n\n"), streaming=True, metadata=metadata
)
elif d.status == AgentStatus.llm_end:
if not output_agent:
continue
text += d.choices[0].delta.content or ""
chat_box.update_msg(
text.replace("\n", "\n\n"), streaming=False, metadata=metadata
)
# tool 的输出与 llm 输出重复了
# elif d.status == AgentStatus.tool_start:
# formatted_data = {
# "Function": d.choices[0].delta.tool_calls[0].function.name,
# "function_input": d.choices[0].delta.tool_calls[0].function.arguments,
# }
# formatted_json = json.dumps(formatted_data, indent=2, ensure_ascii=False)
# text = """\n```{}\n```\n""".format(formatted_json)
# chat_box.insert_msg( # TODO: insert text directly not shown
# Markdown(text, title="Function call", in_expander=True, expanded=True, state="running"))
# elif d.status == AgentStatus.tool_end:
# tool_output = d.choices[0].delta.tool_calls[0].tool_output
# if d.message_type == MsgType.IMAGE:
# for url in json.loads(tool_output).get("images", []):
# url = f"{api.base_url}/media/{url}"
# chat_box.insert_msg(Image(url))
# chat_box.update_msg(expanded=False, state="complete")
# else:
# text += """\n```\nObservation:\n{}\n```\n""".format(tool_output)
# chat_box.update_msg(text, streaming=False, expanded=False, state="complete")
elif d.status == AgentStatus.agent_finish:
text = d.choices[0].delta.content or ""
chat_box.update_msg(text.replace("\n", "\n\n"))
elif d.status is None: # not agent chat
if getattr(d, "is_ref", False):
context = str(d.tool_output)
if isinstance(d.tool_output, dict):
docs = d.tool_output.get("docs", [])
source_documents = format_reference(kb_name=d.tool_output.get("knowledge_base"),
docs=docs,
api_base_url=api_address(is_public=True))
context = "\n".join(source_documents)
chat_box.insert_msg(
Markdown(
context,
in_expander=True,
state="complete",
title="参考资料",
)
)
chat_box.insert_msg("")
elif getattr(d, "tool_call", None) == "text2images": # TODO特定工具特别处理需要更通用的处理方式
for img in d.tool_output.get("images", []):
chat_box.insert_msg(Image(f"{api.base_url}/media/{img}"), pos=-2)
else:
text += d.choices[0].delta.content or ""
chat_box.update_msg(
text.replace("\n", "\n\n"), streaming=True, metadata=metadata
)
chat_box.update_msg(text, streaming=False, metadata=metadata)
except Exception as e:
st.error(e.body)
else:
try:
d =client.chat.completions.create(**params)
chat_box.update_msg(d.choices[0].message.content or "", streaming=False)
except Exception as e:
st.error(e.body)
# if os.path.exists("tmp/image.jpg"):
# with open("tmp/image.jpg", "rb") as image_file:
# encoded_string = base64.b64encode(image_file.read()).decode()
# img_tag = (
# f'<img src="data:image/jpeg;base64,{encoded_string}" width="300">'
# )
# st.markdown(img_tag, unsafe_allow_html=True)
# os.remove("tmp/image.jpg")
# chat_box.show_feedback(**feedback_kwargs,
# key=message_id,
# on_submit=on_feedback,
# kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
# elif dialogue_mode == "文件对话":
# if st.session_state["file_chat_id"] is None:
# st.error("请先上传文件再进行对话")
# st.stop()
# chat_box.ai_say([
# f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
# Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
# ])
# text = ""
# for d in api.file_chat(prompt,
# knowledge_id=st.session_state["file_chat_id"],
# top_k=kb_top_k,
# score_threshold=score_threshold,
# history=history,
# model=llm_model,
# prompt_name=prompt_template_name,
# temperature=temperature):
# if error_msg := check_error_msg(d):
# st.error(error_msg)
# elif chunk := d.get("answer"):
# text += chunk
# chat_box.update_msg(text, element_index=0)
# chat_box.update_msg(text, element_index=0, streaming=False)
# chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
now = datetime.now()
with tab2:
cols = st.columns(2)
export_btn = cols[0]
if cols[1].button(
"清空对话",
use_container_width=True,
):
chat_box.reset_history()
rerun()
export_btn.download_button(
"导出记录",
"".join(chat_box.export2md()),
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
mime="text/markdown",
use_container_width=True,
)
# st.write(chat_box.history)