diff --git a/knowledge_base/samples/vector_store/index.faiss b/knowledge_base/samples/vector_store/index.faiss index 480036d..a3fcd52 100644 Binary files a/knowledge_base/samples/vector_store/index.faiss and b/knowledge_base/samples/vector_store/index.faiss differ diff --git a/knowledge_base/samples/vector_store/index.pkl b/knowledge_base/samples/vector_store/index.pkl index ef82196..a324da1 100644 Binary files a/knowledge_base/samples/vector_store/index.pkl and b/knowledge_base/samples/vector_store/index.pkl differ diff --git a/webui.py b/webui.py index 7aad225..0a54c14 100644 --- a/webui.py +++ b/webui.py @@ -20,21 +20,45 @@ if __name__ == "__main__": f"当前使用模型`{LLM_MODEL}`, 您可以开始提问了." ) - pages = {"对话": {"icon": "chat", + if "chat_list" not in st.session_state: + st.session_state["chat_list"] = ["对话"] + if "cur_chat_name" not in st.session_state: + st.session_state["cur_chat_name"] = "对话" + if "need_chat_name" not in st.session_state: + st.session_state["need_chat_name"] = True + + pages = {i: {"icon": "chat", + "func": dialogue_page, + } for i in st.session_state.chat_list} + pages2 = { + "新建对话": {"icon": "plus-circle", "func": dialogue_page, - }, - "知识库管理": {"icon": "hdd-stack", - "func": knowledge_base_page, - }, - # "模型配置": {"icon": "gear", - # "func": model_config_page, - # } - } + }, + "---": {"icon": None, + "func": None}, + "知识库管理": {"icon": "hdd-stack", + "func": knowledge_base_page, + }, + # "模型配置": {"icon": "gear", + # "func": model_config_page, + # } + } + pages.update(pages2) with st.sidebar: selected_page = option_menu("langchain-chatglm", options=list(pages.keys()), icons=[i["icon"] for i in pages.values()], menu_icon="chat-quote", - default_index=0) - pages[selected_page]["func"](api) + default_index=list(pages.keys()).index(st.session_state["cur_chat_name"])) + if selected_page == "新建对话": + if len(st.session_state.chat_list) > 1 and st.session_state.chat_list[0] == "对话": + st.session_state.chat_list[0] = "对话1" + st.write(st.session_state.chat_list) + new_chat_name = f"对话{len(st.session_state.chat_list) + 1}" + st.session_state.chat_list += [new_chat_name] + st.session_state["cur_chat_name"] = new_chat_name + st.session_state["need_chat_name"] = True + st.experimental_rerun() + else: + pages[selected_page]["func"](api) diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index 7a92128..abf7ab5 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -35,46 +35,32 @@ def dialogue_page(api: ApiRequest): chat_box.init_session() with st.sidebar: - with st.expander("会话管理", True): - col_input, col_btn = st.columns([1.5, 1]) - col_input.text_input( - "新会话名称", - placeholder="新会话名称", - label_visibility="collapsed", - key="new_chat_name", - ) + # with st.expander("会话管理", True): + # col_input, col_btn = st.columns([1.5, 1]) + # col_input.text_input( + # "新会话名称", + # placeholder="新会话名称", + # label_visibility="collapsed", + # key="new_chat_name", + # ) + # + # def on_btn_new_chat(): + # new_chat_name = st.session_state.new_chat_name + # if new_chat_name: + # chat_box.use_chat_name(new_chat_name) + # st.session_state.new_chat_name = "" + # + # col_btn.button( + # "新建会话", + # on_click=on_btn_new_chat, + # use_container_width=True, + # ) + # + # chat_list = chat_box.get_chat_names() + # cur_chat_name = sac.buttons(chat_list, 0) + # chat_box.use_chat_name(cur_chat_name) - def on_btn_new_chat(): - new_chat_name = st.session_state.new_chat_name - if new_chat_name: - chat_box.use_chat_name(new_chat_name) - st.session_state.new_chat_name = "" - col_btn.button( - "新建会话", - on_click=on_btn_new_chat, - use_container_width=True, - ) - - chat_list = chat_box.get_chat_names() - cur_chat_name = sac.buttons(chat_list, 0) - chat_box.use_chat_name(cur_chat_name) - - cols = st.columns(3) - export_btn = cols[0] - if cols[1].button( - "Clear", - use_container_width=True, - ): - chat_box.reset_history() - - if cols[2].button( - "Delete", - disabled=len(chat_list) <= 1, - use_container_width=True, - ): - chat_box.del_chat_name(cur_chat_name) - st.experimental_rerun() def on_mode_change(): mode = st.session_state.dialogue_mode @@ -86,7 +72,7 @@ def dialogue_page(api: ApiRequest): st.toast(text) # sac.alert(text, description="descp", type="success", closable=True, banner=True) - dialogue_mode = st.radio("请选择对话模式", + dialogue_mode = st.selectbox("请选择对话模式", ["LLM 对话", "知识库问答", "搜索引擎问答", @@ -94,7 +80,7 @@ def dialogue_page(api: ApiRequest): on_change=on_mode_change, key="dialogue_mode", ) - history_len = st.slider("历史对话轮数:", 0, 10, 3) + history_len = st.number_input("历史对话轮数:", 0, 10, 3) # todo: support history len def on_kb_change(): @@ -109,57 +95,87 @@ def dialogue_page(api: ApiRequest): on_change=on_kb_change, key="selected_kb", ) - kb_top_k = st.slider("匹配知识条数:", 1, 20, 3) + kb_top_k = st.number_input("匹配知识条数:", 1, 20, 3) # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True) # chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) elif dialogue_mode == "搜索引擎问答": - search_engine = sac.buttons(SEARCH_ENGINES.keys(), 0) - se_top_k = st.slider("匹配搜索结果条数:", 1, 20, 3) + with st.expander("搜索引擎配置", True): + search_engine = st.selectbox("请选择搜索引擎", SEARCH_ENGINES.keys(), 0) + se_top_k = st.number_input("匹配搜索结果条数:", 1, 20, 3) # Display chat messages from history on app rerun - chat_box.output_messages() - if prompt := st.chat_input("请输入对话内容,换行请使用Ctrl+Enter"): - history = get_messages_history(history_len) - chat_box.user_say(prompt) - if dialogue_mode == "LLM 对话": - chat_box.ai_say("正在思考...") - text = "" - r = api.chat_chat(prompt, history) - for t in r: - text += t - chat_box.update_msg(text) - chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 - elif dialogue_mode == "知识库问答": + chat_box.use_chat_name(st.session_state.cur_chat_name) + chat_box.output_messages() + chat_input_placeholder = "请输入对话名称" if st.session_state.need_chat_name else "请输入对话内容,换行请使用Ctrl+Enter" + if prompt := st.chat_input(chat_input_placeholder): + if st.session_state.need_chat_name: + if prompt in st.session_state.chat_list: + st.toast("已有同名对话,请重新命名") + else: + st.session_state.chat_list[st.session_state.chat_list.index(st.session_state.cur_chat_name)] = prompt + st.session_state.need_chat_name = False + st.session_state.cur_chat_name = prompt + chat_box.use_chat_name(prompt) + st.experimental_rerun() + else: history = get_messages_history(history_len) - chat_box.ai_say([ - f"正在查询知识库: `{selected_kb}` ...", - Markdown("...", in_expander=True, title="知识库匹配结果"), - ]) - text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): - text += d["answer"] - chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) - chat_box.update_msg(text, 0, streaming=False) - elif dialogue_mode == "搜索引擎问答": - chat_box.ai_say([ - f"正在执行{search_engine}搜索...", - Markdown("...", in_expander=True, title="网络搜索结果"), - ]) - text = "" - for d in api.bing_search_chat(prompt, search_engine, se_top_k): - text += d["answer"] - chat_box.update_msg(text, 0) - chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) - chat_box.update_msg(text, 0, streaming=False) + chat_box.user_say(prompt) + if dialogue_mode == "LLM 对话": + chat_box.ai_say("正在思考...") + text = "" + r = api.chat_chat(prompt, history) + for t in r: + text += t + chat_box.update_msg(text) + chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 + elif dialogue_mode == "知识库问答": + history = get_messages_history(history_len) + chat_box.ai_say([ + f"正在查询知识库: `{selected_kb}` ...", + Markdown("...", in_expander=True, title="知识库匹配结果"), + ]) + text = "" + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): + text += d["answer"] + chat_box.update_msg(text, 0) + chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) + chat_box.update_msg(text, 0, streaming=False) + elif dialogue_mode == "搜索引擎问答": + chat_box.ai_say([ + f"正在执行{search_engine}搜索...", + Markdown("...", in_expander=True, title="网络搜索结果"), + ]) + text = "" + for d in api.bing_search_chat(prompt, search_engine, se_top_k): + text += d["answer"] + chat_box.update_msg(text, 0) + chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False) + chat_box.update_msg(text, 0, streaming=False) now = datetime.now() + with st.sidebar: + + cols = st.columns(3) + export_btn = cols[0] + if cols[1].button( + "Clear", + use_container_width=True, + ): + chat_box.reset_history() + + if cols[2].button( + "Delete", + disabled=len(st.session_state.chat_list) <= 1, + use_container_width=True, + ): + chat_box.del_chat_name(st.session_state.cur_chat_name) + st.experimental_rerun() export_btn.download_button( "Export", - "".join(chat_box.export2md(cur_chat_name)), - file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md", + "".join(chat_box.export2md(st.session_state.cur_chat_name)), + file_name=f"{now:%Y-%m-%d %H.%M}_{st.session_state.cur_chat_name}.md", mime="text/markdown", use_container_width=True, )