diff --git a/webui.py b/webui.py index ad86687..d15dc59 100644 --- a/webui.py +++ b/webui.py @@ -12,7 +12,7 @@ from webui_pages import * api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=False) if __name__ == "__main__": - st.set_page_config("langchain-chatglm WebUI") + st.set_page_config("langchain-chatglm WebUI", initial_sidebar_state="expanded") if not chat_box.chat_inited: st.toast( @@ -63,8 +63,11 @@ if __name__ == "__main__": st.session_state["cur_chat_name"] = new_chat_name st.session_state[key] = new_chat_name elif st.session_state[key] not in ["新建对话", "知识库管理"]: - st.session_state["cur_chat_name"] = st.session_state[key] - + if st.session_state.get("prompt"): + st.session_state["cur_chat_name"] = st.session_state.get("prompt") + else: + st.session_state["cur_chat_name"] = st.session_state[key] + with st.sidebar: selected_page = option_menu( "langchain-chatglm", @@ -75,4 +78,5 @@ if __name__ == "__main__": on_change=on_page_change, ) - pages[selected_page]["func"](api) + if selected_page == "知识库管理" or selected_page in pages: + pages[selected_page]["func"](api) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index fe3e9da..8b4e7c4 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -82,15 +82,16 @@ def dialogue_page(api: ApiRequest): chat_box.output_messages() - if (st.session_state.cur_chat_name == "新建对话" - or st.session_state.chat_list.get(st.session_state.cur_chat_name, {}).get("need_rename")): + if st.session_state.chat_list.get(st.session_state.cur_chat_name, {}).get("need_rename"): chat_input_placeholder = "请输入对话名称" else: chat_input_placeholder = "请输入对话内容,换行请使用Ctrl+Enter " + + def on_prompt(): + st.session_state["selected_page"] = prompt - if prompt := st.chat_input(chat_input_placeholder, key="prompt"): - if (st.session_state.cur_chat_name == "新建对话" - or st.session_state.chat_list.get(st.session_state.cur_chat_name, {}).get("need_rename")): + if prompt := st.chat_input(chat_input_placeholder, key="prompt", on_submit=on_prompt): + if st.session_state.chat_list.get(st.session_state.cur_chat_name, {}).get("need_rename"): if prompt in st.session_state.chat_list.keys(): st.toast("已有同名对话,请重新命名") else: