diff --git a/webui.py b/webui.py index 2406c19..6077b81 100644 --- a/webui.py +++ b/webui.py @@ -20,43 +20,22 @@ if __name__ == "__main__": f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了." ) - if "chat_list" not in st.session_state: - st.session_state["chat_list"] = {"对话1": {"need_rename": True, "chat_no": 1}} - if "cur_chat_name" not in st.session_state: - st.session_state["cur_chat_name"] = list(st.session_state["chat_list"].keys())[0] - if "need_chat_name" not in st.session_state: - st.session_state["need_chat_name"] = True - if "chat_count" not in st.session_state: - st.session_state["chat_count"] = 1 - - chat_list = [{"name": k, "chat_no": v.get("chat_no", 0)} for k, v in st.session_state.chat_list.items()] - chat_list = [x["name"] for x in sorted(chat_list, key=lambda x: x["chat_no"])] - pages1 = {i: { - "icon": "chat", - "func": dialogue_page, - } for i in chat_list} - - pages2 = { - "新建对话": { - "icon": "plus-circle", + pages = { + "对话": { + "icon": "chat", "func": dialogue_page, }, - "---": { - "icon": None, - "func": None - }, "知识库管理": { "icon": "hdd-stack", "func": knowledge_base_page, }, } - pages = {**pages1, **pages2} with st.sidebar: - options = chat_list + list(pages2) - icons = ["chat"] * len(chat_list) + [x["icon"] for x in pages2.values()] + options = list(pages) + icons = [x["icon"] for x in pages.values()] - default_index = list(pages).index(st.session_state["cur_chat_name"]) + default_index = 0 selected_page = option_menu( "langchain-chatglm", options=options, @@ -65,28 +44,5 @@ if __name__ == "__main__": default_index=default_index, ) - if selected_page == "新建对话": - cur_chat_name = st.session_state.get("cur_chat_name") - if (not st.session_state.get("create_chat") - and not st.session_state.get("renamde_chat") - and not st.session_state.get("delete_chat")): - st.session_state.chat_count += 1 - chat_no = st.session_state.chat_count - new_chat_name = f"对话{chat_no}" - st.session_state.chat_list[new_chat_name] = {"need_rename": True, "chat_no": chat_no} - st.session_state["cur_chat_name"] = new_chat_name - st.experimental_rerun() - else: - if st.session_state.get("create_chat"): - st.session_state.create_chat = False - if st.session_state.get("renamde_chat"): - st.session_state.renamde_chat = False - st.experimental_rerun() - if st.session_state.get("delete_chat"): - st.session_state.delete_chat = False - st.experimental_rerun() - elif selected_page in st.session_state.chat_list: - st.session_state["selected_page"] = selected_page - if selected_page in pages: pages[selected_page]["func"](api) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index e2f0fbd..43334c3 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -32,7 +32,6 @@ def get_messages_history(history_len: int) -> List[Dict]: def dialogue_page(api: ApiRequest): chat_box.init_session() - chat_box.use_chat_name(st.session_state.selected_page) with st.sidebar: # TODO: 对话模型与会话绑定 @@ -82,67 +81,47 @@ def dialogue_page(api: ApiRequest): chat_box.output_messages() - if st.session_state.chat_list.get(st.session_state.selected_page, {}).get("need_rename"): - chat_input_placeholder = "请输入对话名称" - else: - chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter " - - def on_prompt(): - st.session_state.rename_chat = True + chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter " - if prompt := st.chat_input(chat_input_placeholder, key="prompt", on_submit=on_prompt): - if st.session_state.chat_list.get(st.session_state.selected_page, {}).get("need_rename"): - if prompt in st.session_state.chat_list.keys(): - st.toast("已有同名对话,请重新命名") - else: - cur_chat_name = st.session_state.get("selected_page") - st.session_state.chat_list[prompt] = { - "need_rename": False, - "chat_no": st.session_state.chat_list[cur_chat_name]["chat_no"]} - st.session_state.chat_list.pop(cur_chat_name) - chat_box.del_chat_name(cur_chat_name) - st.session_state.cur_chat_name = prompt - chat_box.use_chat_name(prompt) - st.experimental_rerun() - else: + if prompt := st.chat_input(chat_input_placeholder, key="prompt"): + history = get_messages_history(history_len) + chat_box.user_say(prompt) + if dialogue_mode == "LLM 对话": + chat_box.ai_say("正在思考...") + text = "" + r = api.chat_chat(prompt, history) + for t in r: + text += t + chat_box.update_msg(text) + chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 + elif dialogue_mode == "知识库问答": history = get_messages_history(history_len) - chat_box.user_say(prompt) - if dialogue_mode == "LLM 对话": - chat_box.ai_say("正在思考...") - text = "" - r = api.chat_chat(prompt, history) - for t in r: - text += t - chat_box.update_msg(text) - chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 - elif dialogue_mode == "知识库问答": - history = get_messages_history(history_len) - chat_box.ai_say([ - f"正在查询知识库 `{selected_kb}` ...", - Markdown("...", in_expander=True, title="知识库匹配结果"), - ]) - text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): - text += d["answer"] - chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) - chat_box.update_msg(text, 0, streaming=False) - elif dialogue_mode == "搜索引擎问答": - chat_box.ai_say([ - f"正在执行 `{search_engine}` 搜索...", - Markdown("...", in_expander=True, title="网络搜索结果"), - ]) - text = "" - for d in api.search_engine_chat(prompt, search_engine, se_top_k): - text += d["answer"] - chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) - chat_box.update_msg(text, 0, streaming=False) + chat_box.ai_say([ + f"正在查询知识库 `{selected_kb}` ...", + Markdown("...", in_expander=True, title="知识库匹配结果"), + ]) + text = "" + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): + text += d["answer"] + chat_box.update_msg(text, 0) + chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) + chat_box.update_msg(text, 0, streaming=False) + elif dialogue_mode == "搜索引擎问答": + chat_box.ai_say([ + f"正在执行 `{search_engine}` 搜索...", + Markdown("...", in_expander=True, title="网络搜索结果"), + ]) + text = "" + for d in api.search_engine_chat(prompt, search_engine, se_top_k): + text += d["answer"] + chat_box.update_msg(text, 0) + chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) + chat_box.update_msg(text, 0, streaming=False) now = datetime.now() with st.sidebar: - cols = st.columns(3) + cols = st.columns(2) export_btn = cols[0] if cols[1].button( "清空对话", @@ -151,26 +130,10 @@ def dialogue_page(api: ApiRequest): chat_box.reset_history() st.experimental_rerun() - def on_delete_chat(): - st.session_state.delete_chat = True - - if cols[2].button( - "删除对话", - disabled=len(st.session_state.chat_list) <= 1, - use_container_width=True, - on_click=on_delete_chat - ): - cur_chat_name = st.session_state.get("selected_page") - chat_box.del_chat_name(cur_chat_name) - st.session_state.chat_list.pop(cur_chat_name) - st.session_state.cur_chat_name = list(st.session_state.chat_list.keys())[0] - chat_box.use_chat_name(st.session_state.cur_chat_name) - st.experimental_rerun() - export_btn.download_button( "导出记录", - "".join(chat_box.export2md(st.session_state.selected_page)), - file_name=f"{now:%Y-%m-%d %H.%M}_{st.session_state.selected_page}.md", + "".join(chat_box.export2md()), + file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md", mime="text/markdown", use_container_width=True, )