diff --git a/server/chat/chat.py b/server/chat/chat.py index a0c60b5..2b163d4 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -20,6 +20,7 @@ def chat(query: str = Body(..., description="用户输入", example="恼羞成 {"role": "assistant", "content": "虎头虎脑"}] ), ): + history = [History(**h) if isinstance(h, dict) else h for h in history] async def chat_iterator(query: str, history: List[History] = [], ) -> AsyncIterable[str]: diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 5c49175..17c045d 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -32,6 +32,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp if kb is None: return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + history = [History(**h) if isinstance(h, dict) else h for h in history] async def knowledge_base_chat_iterator(query: str, kb: KBService, top_k: int, diff --git a/webui_pages/dialogue/dialogue.py b/webui_pages/dialogue/dialogue.py index e03a539..87ac9ee 100644 --- a/webui_pages/dialogue/dialogue.py +++ b/webui_pages/dialogue/dialogue.py @@ -4,10 +4,25 @@ from streamlit_chatbox import * from datetime import datetime import streamlit_antd_components as sac from server.chat.search_engine_chat import SEARCH_ENGINES +from typing import List, Dict chat_box = ChatBox() +def get_messages_history(history_len: int) -> List[Dict]: + def filter(msg): + ''' + 针对当前简单文本对话,只返回每条消息的第一个element的内容 + ''' + content = [x._content for x in msg["elements"] if x._output_method in ["markdown", "text"]] + return { + "role": msg["role"], + "content": content[0] if content else "", + } + history = chat_box.filter_history(history_len, filter) + return history + + def dialogue_page(api: ApiRequest): chat_box.init_session() @@ -58,7 +73,7 @@ def dialogue_page(api: ApiRequest): on_change=on_mode_change, key="dialogue_mode", ) - history_len = st.slider("历史对话轮数:", 1, 10, 1, disabled=True) + history_len = st.slider("历史对话轮数:", 1, 10, 3) # todo: support history len def on_kb_change(): @@ -85,22 +100,24 @@ def dialogue_page(api: ApiRequest): chat_box.output_messages() if prompt := st.chat_input("请输入对话内容,换行请使用Ctrl+Enter"): + history = get_messages_history(history_len) chat_box.user_say(prompt) if dialogue_mode == "LLM 对话": chat_box.ai_say("正在思考...") text = "" - r = api.chat_chat(prompt, no_remote_api=True) + r = api.chat_chat(prompt, history) for t in r: text += t chat_box.update_msg(text) chat_box.update_msg(text, streaming=False) # 更新最终的字符串,去除光标 elif dialogue_mode == "知识库问答": + history = get_messages_history(history_len) chat_box.ai_say([ f"正在查询知识库: `{selected_kb}` ...", Markdown("...", in_expander=True, title="知识库匹配结果"), ]) text = "" - for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k): + for d in api.knowledge_base_chat(prompt, selected_kb, kb_top_k, history): text += d["answer"] chat_box.update_msg(text, 0) chat_box.update_msg("\n\n".join(d["docs"]), 1, streaming=False)