From c4fe3393b3ca56babd5acdc4a7d25cea58693492 Mon Sep 17 00:00:00 2001
From: liunux4odoo <41217877+liunux4odoo@users.noreply.github.com>
Date: Thu, 30 Nov 2023 11:39:41 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E8=87=AA=E5=AE=9A=E4=B9=89?=
=?UTF-8?q?=E5=91=BD=E4=BB=A4=EF=BC=9A=20(#2229)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
/new [conv_name] 新建会话
/del [conv_name] 删除会话
/clear [conv_name] 清空会话
/help 命令帮助
新增依赖:streamlit-modal
---
requirements.txt | 1 +
requirements_webui.txt | 1 +
webui_pages/dialogue/dialogue.py | 333 +++++++++++++++++--------------
3 files changed, 190 insertions(+), 145 deletions(-)
diff --git a/requirements.txt b/requirements.txt
index 27e68e5..6b46886 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -58,6 +58,7 @@ streamlit~=1.28.2 # # on win, make sure write its path in environment variable
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
+streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]~=0.24.1
watchdog
\ No newline at end of file
diff --git a/requirements_webui.txt b/requirements_webui.txt
index 0fe3dc8..86ca8af 100644
--- a/requirements_webui.txt
+++ b/requirements_webui.txt
@@ -4,6 +4,7 @@ streamlit~=1.28.2
streamlit-option-menu>=0.3.6
streamlit-antd-components>=0.2.3
streamlit-chatbox>=1.1.11
+streamlit-modal==0.1.0
streamlit-aggrid>=0.3.4.post3
httpx[brotli,http2,socks]~=0.24.1
watchdog
diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py
index 3310fb3..064c8a4 100644
--- a/webui_pages/dialogue/dialogue.py
+++ b/webui_pages/dialogue/dialogue.py
@@ -1,8 +1,11 @@
import streamlit as st
from webui_pages.utils import *
from streamlit_chatbox import *
+from streamlit_modal import Modal
from datetime import datetime
import os
+import re
+import time
from configs import (TEMPERATURE, HISTORY_LEN, PROMPT_TEMPLATES,
DEFAULT_KNOWLEDGE_BASE, DEFAULT_SEARCH_ENGINE, SUPPORT_AGENT_MODEL)
from server.knowledge_base.utils import LOADER_DICT
@@ -47,6 +50,53 @@ def upload_temp_docs(files, _api: ApiRequest) -> str:
return _api.upload_temp_docs(files).get("data", {}).get("id")
+def parse_command(text: str, modal: Modal) -> bool:
+ '''
+ 检查用户是否输入了自定义命令,当前支持:
+ /new {session_name}。如果未提供名称,默认为“会话X”
+ /del {session_name}。如果未提供名称,在会话数量>1的情况下,删除当前会话。
+ /clear {session_name}。如果未提供名称,默认清除当前会话
+ /help。查看命令帮助
+ 返回值:输入的是命令返回True,否则返回False
+ '''
+ if m := re.match(r"/([^\s]+)\s*(.*)", text):
+ cmd, name = m.groups()
+ name = name.strip()
+ conv_names = chat_box.get_chat_names()
+ if cmd == "help":
+ modal.open()
+ elif cmd == "new":
+ if not name:
+ i = 1
+ while True:
+ name = f"会话{i}"
+ if name not in conv_names:
+ break
+ i += 1
+ if name in st.session_state["conversation_ids"]:
+ st.error(f"该会话名称 “{name}” 已存在")
+ time.sleep(1)
+ else:
+ st.session_state["conversation_ids"][name] = uuid.uuid4().hex
+ st.session_state["cur_conv_name"] = name
+ elif cmd == "del":
+ name = name or st.session_state.get("cur_conv_name")
+ if len(conv_names) == 1:
+ st.error("这是最后一个会话,无法删除")
+ time.sleep(1)
+ elif not name or name not in st.session_state["conversation_ids"]:
+ st.error(f"无效的会话名称:“{name}”")
+ time.sleep(1)
+ else:
+ st.session_state["conversation_ids"].pop(name, None)
+ chat_box.del_chat_name(name)
+ st.session_state["cur_conv_name"] = ""
+ elif cmd == "clear":
+ chat_box.reset_history(name=name or None)
+ return True
+ return False
+
+
def dialogue_page(api: ApiRequest, is_lite: bool = False):
st.session_state.setdefault("conversation_ids", {})
st.session_state["conversation_ids"].setdefault(chat_box.cur_chat_name, uuid.uuid4().hex)
@@ -60,25 +110,15 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
)
chat_box.init_session()
+ # 弹出自定义命令帮助信息
+ modal = Modal("自定义命令", key="cmd_help", max_width="500")
+ if modal.is_open():
+ with modal.container():
+ cmds = [x for x in parse_command.__doc__.split("\n") if x.strip().startswith("/")]
+ st.write("\n\n".join(cmds))
+
with st.sidebar:
# 多会话
- cols = st.columns([3, 1])
- conv_name = cols[0].text_input("会话名称")
- with cols[1]:
- if st.button("添加"):
- if not conv_name or conv_name in st.session_state["conversation_ids"]:
- st.error("请指定有效的会话名称")
- else:
- st.session_state["conversation_ids"][conv_name] = uuid.uuid4().hex
- st.session_state["cur_conv_name"] = conv_name
- st.session_state["conv_name"] = ""
- if st.button("删除"):
- if not conv_name or conv_name not in st.session_state["conversation_ids"]:
- st.error("请指定有效的会话名称")
- else:
- st.session_state["conversation_ids"].pop(conv_name, None)
- st.session_state["cur_conv_name"] = ""
-
conv_names = list(st.session_state["conversation_ids"].keys())
index = 0
if st.session_state.get("cur_conv_name") in conv_names:
@@ -236,7 +276,7 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
# Display chat messages from history on app rerun
chat_box.output_messages()
- chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter "
+ chat_input_placeholder = "请输入对话内容,换行请使用Shift+Enter。输入/help查看自定义命令 "
def on_feedback(
feedback,
@@ -256,139 +296,142 @@ def dialogue_page(api: ApiRequest, is_lite: bool = False):
}
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
- history = get_messages_history(history_len)
- chat_box.user_say(prompt)
- if dialogue_mode == "LLM 对话":
- chat_box.ai_say("正在思考...")
- text = ""
- message_id = ""
- r = api.chat_chat(prompt,
- history=history,
- conversation_id=conversation_id,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature)
- for t in r:
- if error_msg := check_error_msg(t): # check whether error occured
- st.error(error_msg)
- break
- text += t.get("text", "")
- chat_box.update_msg(text)
- message_id = t.get("message_id", "")
+ if parse_command(text=prompt, modal=modal): # 用户输入自定义命令
+ st.rerun()
+ else:
+ history = get_messages_history(history_len)
+ chat_box.user_say(prompt)
+ if dialogue_mode == "LLM 对话":
+ chat_box.ai_say("正在思考...")
+ text = ""
+ message_id = ""
+ r = api.chat_chat(prompt,
+ history=history,
+ conversation_id=conversation_id,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature)
+ for t in r:
+ if error_msg := check_error_msg(t): # check whether error occurred
+ st.error(error_msg)
+ break
+ text += t.get("text", "")
+ chat_box.update_msg(text)
+ message_id = t.get("message_id", "")
- metadata = {
- "message_id": message_id,
- }
- chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
- chat_box.show_feedback(**feedback_kwargs,
- key=message_id,
- on_submit=on_feedback,
- kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
+ metadata = {
+ "message_id": message_id,
+ }
+ chat_box.update_msg(text, streaming=False, metadata=metadata) # 更新最终的字符串,去除光标
+ chat_box.show_feedback(**feedback_kwargs,
+ key=message_id,
+ on_submit=on_feedback,
+ kwargs={"message_id": message_id, "history_index": len(chat_box.history) - 1})
- elif dialogue_mode == "自定义Agent问答":
- if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+ elif dialogue_mode == "自定义Agent问答":
+ if not any(agent in llm_model for agent in SUPPORT_AGENT_MODEL):
+ chat_box.ai_say([
+ f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
+ Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+ ])
+ else:
+ chat_box.ai_say([
+ f"正在思考...",
+ Markdown("...", in_expander=True, title="思考过程", state="complete"),
+
+ ])
+ text = ""
+ ans = ""
+ for d in api.agent_chat(prompt,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature,
+ ):
+ try:
+ d = json.loads(d)
+ except:
+ pass
+ if error_msg := check_error_msg(d): # check whether error occurred
+ st.error(error_msg)
+ if chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=1)
+ if chunk := d.get("final_answer"):
+ ans += chunk
+ chat_box.update_msg(ans, element_index=0)
+ if chunk := d.get("tools"):
+ text += "\n\n".join(d.get("tools", []))
+ chat_box.update_msg(text, element_index=1)
+ chat_box.update_msg(ans, element_index=0, streaming=False)
+ chat_box.update_msg(text, element_index=1, streaming=False)
+ elif dialogue_mode == "知识库问答":
chat_box.ai_say([
- f"正在思考... \n\n 该模型并没有进行Agent对齐,请更换支持Agent的模型获得更好的体验!\n\n\n",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
+ f"正在查询知识库 `{selected_kb}` ...",
+ Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
])
- else:
+ text = ""
+ for d in api.knowledge_base_chat(prompt,
+ knowledge_base_name=selected_kb,
+ top_k=kb_top_k,
+ score_threshold=score_threshold,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature):
+ if error_msg := check_error_msg(d): # check whether error occurred
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ elif dialogue_mode == "文件对话":
+ if st.session_state["file_chat_id"] is None:
+ st.error("请先上传文件再进行对话")
+ st.stop()
chat_box.ai_say([
- f"正在思考...",
- Markdown("...", in_expander=True, title="思考过程", state="complete"),
-
+ f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
+ Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
])
- text = ""
- ans = ""
- for d in api.agent_chat(prompt,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- ):
- try:
- d = json.loads(d)
- except:
- pass
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- if chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=1)
- if chunk := d.get("final_answer"):
- ans += chunk
- chat_box.update_msg(ans, element_index=0)
- if chunk := d.get("tools"):
- text += "\n\n".join(d.get("tools", []))
- chat_box.update_msg(text, element_index=1)
- chat_box.update_msg(ans, element_index=0, streaming=False)
- chat_box.update_msg(text, element_index=1, streaming=False)
- elif dialogue_mode == "知识库问答":
- chat_box.ai_say([
- f"正在查询知识库 `{selected_kb}` ...",
- Markdown("...", in_expander=True, title="知识库匹配结果", state="complete"),
- ])
- text = ""
- for d in api.knowledge_base_chat(prompt,
- knowledge_base_name=selected_kb,
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "文件对话":
- if st.session_state["file_chat_id"] is None:
- st.error("请先上传文件再进行对话")
- st.stop()
- chat_box.ai_say([
- f"正在查询文件 `{st.session_state['file_chat_id']}` ...",
- Markdown("...", in_expander=True, title="文件匹配结果", state="complete"),
- ])
- text = ""
- for d in api.file_chat(prompt,
- knowledge_id=st.session_state["file_chat_id"],
- top_k=kb_top_k,
- score_threshold=score_threshold,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
- elif dialogue_mode == "搜索引擎问答":
- chat_box.ai_say([
- f"正在执行 `{search_engine}` 搜索...",
- Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
- ])
- text = ""
- for d in api.search_engine_chat(prompt,
- search_engine_name=search_engine,
- top_k=se_top_k,
- history=history,
- model=llm_model,
- prompt_name=prompt_template_name,
- temperature=temperature,
- split_result=se_top_k > 1):
- if error_msg := check_error_msg(d): # check whether error occured
- st.error(error_msg)
- elif chunk := d.get("answer"):
- text += chunk
- chat_box.update_msg(text, element_index=0)
- chat_box.update_msg(text, element_index=0, streaming=False)
- chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ text = ""
+ for d in api.file_chat(prompt,
+ knowledge_id=st.session_state["file_chat_id"],
+ top_k=kb_top_k,
+ score_threshold=score_threshold,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature):
+ if error_msg := check_error_msg(d): # check whether error occurred
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
+ elif dialogue_mode == "搜索引擎问答":
+ chat_box.ai_say([
+ f"正在执行 `{search_engine}` 搜索...",
+ Markdown("...", in_expander=True, title="网络搜索结果", state="complete"),
+ ])
+ text = ""
+ for d in api.search_engine_chat(prompt,
+ search_engine_name=search_engine,
+ top_k=se_top_k,
+ history=history,
+ model=llm_model,
+ prompt_name=prompt_template_name,
+ temperature=temperature,
+ split_result=se_top_k > 1):
+ if error_msg := check_error_msg(d): # check whether error occurred
+ st.error(error_msg)
+ elif chunk := d.get("answer"):
+ text += chunk
+ chat_box.update_msg(text, element_index=0)
+ chat_box.update_msg(text, element_index=0, streaming=False)
+ chat_box.update_msg("\n\n".join(d.get("docs", [])), element_index=1, streaming=False)
if st.session_state.get("need_rerun"):
st.session_state["need_rerun"] = False