withdraw conversation management func
This commit is contained in:
parent
1b07648238
commit
77364c046e
56
webui.py
56
webui.py
|
|
@ -20,43 +20,22 @@ if __name__ == "__main__":
|
||||||
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
|
f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了."
|
||||||
)
|
)
|
||||||
|
|
||||||
if "chat_list" not in st.session_state:
|
pages = {
|
||||||
st.session_state["chat_list"] = {"对话1": {"need_rename": True, "chat_no": 1}}
|
"对话": {
|
||||||
if "cur_chat_name" not in st.session_state:
|
"icon": "chat",
|
||||||
st.session_state["cur_chat_name"] = list(st.session_state["chat_list"].keys())[0]
|
|
||||||
if "need_chat_name" not in st.session_state:
|
|
||||||
st.session_state["need_chat_name"] = True
|
|
||||||
if "chat_count" not in st.session_state:
|
|
||||||
st.session_state["chat_count"] = 1
|
|
||||||
|
|
||||||
chat_list = [{"name": k, "chat_no": v.get("chat_no", 0)} for k, v in st.session_state.chat_list.items()]
|
|
||||||
chat_list = [x["name"] for x in sorted(chat_list, key=lambda x: x["chat_no"])]
|
|
||||||
pages1 = {i: {
|
|
||||||
"icon": "chat",
|
|
||||||
"func": dialogue_page,
|
|
||||||
} for i in chat_list}
|
|
||||||
|
|
||||||
pages2 = {
|
|
||||||
"新建对话": {
|
|
||||||
"icon": "plus-circle",
|
|
||||||
"func": dialogue_page,
|
"func": dialogue_page,
|
||||||
},
|
},
|
||||||
"---": {
|
|
||||||
"icon": None,
|
|
||||||
"func": None
|
|
||||||
},
|
|
||||||
"知识库管理": {
|
"知识库管理": {
|
||||||
"icon": "hdd-stack",
|
"icon": "hdd-stack",
|
||||||
"func": knowledge_base_page,
|
"func": knowledge_base_page,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
pages = {**pages1, **pages2}
|
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
options = chat_list + list(pages2)
|
options = list(pages)
|
||||||
icons = ["chat"] * len(chat_list) + [x["icon"] for x in pages2.values()]
|
icons = [x["icon"] for x in pages.values()]
|
||||||
|
|
||||||
default_index = list(pages).index(st.session_state["cur_chat_name"])
|
default_index = 0
|
||||||
selected_page = option_menu(
|
selected_page = option_menu(
|
||||||
"langchain-chatglm",
|
"langchain-chatglm",
|
||||||
options=options,
|
options=options,
|
||||||
|
|
@ -65,28 +44,5 @@ if __name__ == "__main__":
|
||||||
default_index=default_index,
|
default_index=default_index,
|
||||||
)
|
)
|
||||||
|
|
||||||
if selected_page == "新建对话":
|
|
||||||
cur_chat_name = st.session_state.get("cur_chat_name")
|
|
||||||
if (not st.session_state.get("create_chat")
|
|
||||||
and not st.session_state.get("renamde_chat")
|
|
||||||
and not st.session_state.get("delete_chat")):
|
|
||||||
st.session_state.chat_count += 1
|
|
||||||
chat_no = st.session_state.chat_count
|
|
||||||
new_chat_name = f"对话{chat_no}"
|
|
||||||
st.session_state.chat_list[new_chat_name] = {"need_rename": True, "chat_no": chat_no}
|
|
||||||
st.session_state["cur_chat_name"] = new_chat_name
|
|
||||||
st.experimental_rerun()
|
|
||||||
else:
|
|
||||||
if st.session_state.get("create_chat"):
|
|
||||||
st.session_state.create_chat = False
|
|
||||||
if st.session_state.get("renamde_chat"):
|
|
||||||
st.session_state.renamde_chat = False
|
|
||||||
st.experimental_rerun()
|
|
||||||
if st.session_state.get("delete_chat"):
|
|
||||||
st.session_state.delete_chat = False
|
|
||||||
st.experimental_rerun()
|
|
||||||
elif selected_page in st.session_state.chat_list:
|
|
||||||
st.session_state["selected_page"] = selected_page
|
|
||||||
|
|
||||||
if selected_page in pages:
|
if selected_page in pages:
|
||||||
pages[selected_page]["func"](api)
|
pages[selected_page]["func"](api)
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ def get_messages_history(history_len: int) -> List[Dict]:
|
||||||
|
|
||||||
def dialogue_page(api: ApiRequest):
|
def dialogue_page(api: ApiRequest):
|
||||||
chat_box.init_session()
|
chat_box.init_session()
|
||||||
chat_box.use_chat_name(st.session_state.selected_page)
|
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
# TODO: 对话模型与会话绑定
|
# TODO: 对话模型与会话绑定
|
||||||
|
|
@ -82,67 +81,47 @@ def dialogue_page(api: ApiRequest):
|
||||||
|
|
||||||
chat_box.output_messages()
|
chat_box.output_messages()
|
||||||
|
|
||||||
if st.session_state.chat_list.get(st.session_state.selected_page, {}).get("need_rename"):
|
chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
|
||||||
chat_input_placeholder = "请输入对话名称"
|
|
||||||
else:
|
|
||||||
chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter "
|
|
||||||
|
|
||||||
def on_prompt():
|
|
||||||
st.session_state.rename_chat = True
|
|
||||||
|
|
||||||
if prompt := st.chat_input(chat_input_placeholder, key="prompt", on_submit=on_prompt):
|
if prompt := st.chat_input(chat_input_placeholder, key="prompt"):
|
||||||
if st.session_state.chat_list.get(st.session_state.selected_page, {}).get("need_rename"):
|
history = get_messages_history(history_len)
|
||||||
if prompt in st.session_state.chat_list.keys():
|
chat_box.user_say(prompt)
|
||||||
st.toast("已有同名对话,请重新命名")
|
if dialogue_mode == "LLM 对话":
|
||||||
else:
|
chat_box.ai_say("正在思考...")
|
||||||
cur_chat_name = st.session_state.get("selected_page")
|
text = ""
|
||||||
st.session_state.chat_list[prompt] = {
|
r = api.chat_chat(prompt, history)
|
||||||
"need_rename": False,
|
for t in r:
|
||||||
"chat_no": st.session_state.chat_list[cur_chat_name]["chat_no"]}
|
text += t
|
||||||
st.session_state.chat_list.pop(cur_chat_name)
|
chat_box.update_msg(text)
|
||||||
chat_box.del_chat_name(cur_chat_name)
|
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
||||||
st.session_state.cur_chat_name = prompt
|
elif dialogue_mode == "知识库问答":
|
||||||
chat_box.use_chat_name(prompt)
|
|
||||||
st.experimental_rerun()
|
|
||||||
else:
|
|
||||||
history = get_messages_history(history_len)
|
history = get_messages_history(history_len)
|
||||||
chat_box.user_say(prompt)
|
chat_box.ai_say([
|
||||||
if dialogue_mode == "LLM 对话":
|
f"正在查询知识库 `{selected_kb}` ...",
|
||||||
chat_box.ai_say("正在思考...")
|
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
||||||
text = ""
|
])
|
||||||
r = api.chat_chat(prompt, history)
|
text = ""
|
||||||
for t in r:
|
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
|
||||||
text += t
|
text += d["answer"]
|
||||||
chat_box.update_msg(text)
|
chat_box.update_msg(text, 0)
|
||||||
chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标
|
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
|
||||||
elif dialogue_mode == "知识库问答":
|
chat_box.update_msg(text, 0, streaming=False)
|
||||||
history = get_messages_history(history_len)
|
elif dialogue_mode == "搜索引擎问答":
|
||||||
chat_box.ai_say([
|
chat_box.ai_say([
|
||||||
f"正在查询知识库 `{selected_kb}` ...",
|
f"正在执行 `{search_engine}` 搜索...",
|
||||||
Markdown("...", in_expander=True, title="知识库匹配结果"),
|
Markdown("...", in_expander=True, title="网络搜索结果"),
|
||||||
])
|
])
|
||||||
text = ""
|
text = ""
|
||||||
for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history):
|
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
||||||
text += d["answer"]
|
text += d["answer"]
|
||||||
chat_box.update_msg(text, 0)
|
chat_box.update_msg(text, 0)
|
||||||
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
|
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
|
||||||
chat_box.update_msg(text, 0, streaming=False)
|
chat_box.update_msg(text, 0, streaming=False)
|
||||||
elif dialogue_mode == "搜索引擎问答":
|
|
||||||
chat_box.ai_say([
|
|
||||||
f"正在执行 `{search_engine}` 搜索...",
|
|
||||||
Markdown("...", in_expander=True, title="网络搜索结果"),
|
|
||||||
])
|
|
||||||
text = ""
|
|
||||||
for d in api.search_engine_chat(prompt, search_engine, se_top_k):
|
|
||||||
text += d["answer"]
|
|
||||||
chat_box.update_msg(text, 0)
|
|
||||||
chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)
|
|
||||||
chat_box.update_msg(text, 0, streaming=False)
|
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
|
|
||||||
cols = st.columns(3)
|
cols = st.columns(2)
|
||||||
export_btn = cols[0]
|
export_btn = cols[0]
|
||||||
if cols[1].button(
|
if cols[1].button(
|
||||||
"清空对话",
|
"清空对话",
|
||||||
|
|
@ -151,26 +130,10 @@ def dialogue_page(api: ApiRequest):
|
||||||
chat_box.reset_history()
|
chat_box.reset_history()
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
def on_delete_chat():
|
|
||||||
st.session_state.delete_chat = True
|
|
||||||
|
|
||||||
if cols[2].button(
|
|
||||||
"删除对话",
|
|
||||||
disabled=len(st.session_state.chat_list) <= 1,
|
|
||||||
use_container_width=True,
|
|
||||||
on_click=on_delete_chat
|
|
||||||
):
|
|
||||||
cur_chat_name = st.session_state.get("selected_page")
|
|
||||||
chat_box.del_chat_name(cur_chat_name)
|
|
||||||
st.session_state.chat_list.pop(cur_chat_name)
|
|
||||||
st.session_state.cur_chat_name = list(st.session_state.chat_list.keys())[0]
|
|
||||||
chat_box.use_chat_name(st.session_state.cur_chat_name)
|
|
||||||
st.experimental_rerun()
|
|
||||||
|
|
||||||
export_btn.download_button(
|
export_btn.download_button(
|
||||||
"导出记录",
|
"导出记录",
|
||||||
"".join(chat_box.export2md(st.session_state.selected_page)),
|
"".join(chat_box.export2md()),
|
||||||
file_name=f"{now:%Y-%m-%d %H.%M}_{st.session_state.selected_page}.md",
|
file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
|
||||||
mime="text/markdown",
|
mime="text/markdown",
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue