Langchain-Chatchat/libs/chatchat-server/chatchat/webui_pages/kb_chat.py

258 lines
11 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
from datetime import datetime
import uuid
from typing import List, Dict
import openai
import streamlit as st
import streamlit_antd_components as sac
from streamlit_chatbox import *
from streamlit_extras.bottom_container import bottom
from chatchat.settings import Settings
from chatchat.server.knowledge_base.utils import LOADER_DICT
from chatchat.server.utils import get_config_models, get_config_platforms, get_default_llm, api_address
from chatchat.webui_pages.dialogue.dialogue import (save_session, restore_session, rerun,
get_messages_history, upload_temp_docs,
add_conv, del_conv, clear_conv)
from chatchat.webui_pages.utils import *
# Module-level chat UI container shared by all functions on this page; it owns
# the per-conversation message history and the rendering of chat bubbles.
chat_box = ChatBox(assistant_avatar=get_img_base64("chatchat_icon_blue_square_v2.png"))
def init_widgets():
    """Seed ``st.session_state`` with the default value for every widget key
    used on this page, without overwriting values that already exist
    (``setdefault`` leaves previously-set state untouched across reruns).
    """
    widget_defaults = {
        "history_len": Settings.model_settings.HISTORY_LEN,
        "selected_kb": Settings.kb_settings.DEFAULT_KNOWLEDGE_BASE,
        "kb_top_k": Settings.kb_settings.VECTOR_SEARCH_TOP_K,
        "se_top_k": Settings.kb_settings.SEARCH_ENGINE_TOP_K,
        "score_threshold": Settings.kb_settings.SCORE_THRESHOLD,
        "search_engine": Settings.kb_settings.DEFAULT_SEARCH_ENGINE,
        "return_direct": False,
        # both conversation-name trackers start at the chat box's current chat
        "cur_conv_name": chat_box.cur_chat_name,
        "last_conv_name": chat_box.cur_chat_name,
        "file_chat_id": None,
    }
    for key, default in widget_defaults.items():
        st.session_state.setdefault(key, default)
def kb_chat(api: ApiRequest):
    """Render the RAG chat page.

    Supports three dialogue modes selected in the sidebar: knowledge-base QA,
    temporary-file chat, and search-engine QA. The user prompt is sent to the
    server's OpenAI-compatible endpoint (base_url chosen per mode) and the
    streamed answer plus retrieved docs are rendered through ``chat_box``.

    Args:
        api: project API client used to list knowledge bases and upload files.
    """
    # Per-conversation context stored inside chat_box; survives Streamlit reruns.
    ctx = chat_box.context
    ctx.setdefault("uid", uuid.uuid4().hex)
    ctx.setdefault("file_chat_id", None)
    ctx.setdefault("llm_model", get_default_llm())
    ctx.setdefault("temperature", Settings.model_settings.TEMPERATURE)
    init_widgets()

    # sac on_change callbacks not working since st>=1.34, so detect a
    # conversation switch manually by comparing current vs. last name.
    if st.session_state.cur_conv_name != st.session_state.last_conv_name:
        save_session(st.session_state.last_conv_name)
        restore_session(st.session_state.cur_conv_name)
        st.session_state.last_conv_name = st.session_state.cur_conv_name

    # st.write(chat_box.cur_chat_name)
    # st.write(st.session_state)

    @st.experimental_dialog("模型配置", width="large")
    def llm_model_setting():
        # Modal dialog: pick platform / LLM / temperature / system message.
        # Widgets write straight into st.session_state via their key= bindings.
        cols = st.columns(3)
        platforms = ["所有"] + list(get_config_platforms())
        platform = cols[0].selectbox("选择模型平台", platforms, key="platform")
        llm_models = list(
            get_config_models(
                model_type="llm", platform_name=None if platform == "所有" else platform
            )
        )
        # image2text models are offered alongside plain LLMs
        llm_models += list(
            get_config_models(
                model_type="image2text", platform_name=None if platform == "所有" else platform
            )
        )
        llm_model = cols[1].selectbox("选择LLM模型", llm_models, key="llm_model")
        temperature = cols[2].slider("Temperature", 0.0, 1.0, key="temperature")
        system_message = st.text_area("System Message:", key="system_message")
        if st.button("OK"):
            # values already live in session_state; just close the dialog
            rerun()

    @st.experimental_dialog("重命名会话")
    def rename_conversation():
        # Modal dialog: rename the current conversation and re-sync state.
        name = st.text_input("会话名称")
        if st.button("OK"):
            chat_box.change_chat_name(name)
            restore_session()
            st.session_state["cur_conv_name"] = name
            rerun()

    # ---- sidebar: RAG parameters and conversation management ----
    with st.sidebar:
        tabs = st.tabs(["RAG 配置", "会话设置"])
        with tabs[0]:
            dialogue_modes = ["知识库问答",
                              "文件对话",
                              "搜索引擎问答",
                              ]
            dialogue_mode = st.selectbox("请选择对话模式:",
                                         dialogue_modes,
                                         key="dialogue_mode",
                                         )
            # placeholder is filled below with mode-specific widgets
            placeholder = st.empty()
            st.divider()
            # prompt_templates_kb_list = list(Settings.prompt_settings.rag)
            # prompt_name = st.selectbox(
            #     "请选择Prompt模板",
            #     prompt_templates_kb_list,
            #     key="prompt_name",
            # )
            prompt_name="default"
            history_len = st.number_input("历史对话轮数:", 0, 20, key="history_len")
            kb_top_k = st.number_input("匹配知识条数:", 1, 20, key="kb_top_k")
            # BGE models can score above 1, hence the 2.0 upper bound
            score_threshold = st.slider("知识匹配分数阈值:", 0.0, 2.0, step=0.01, key="score_threshold")
            return_direct = st.checkbox("仅返回检索结果", key="return_direct")

            def on_kb_change():
                st.toast(f"已加载知识库: {st.session_state.selected_kb}")

            # mode-specific widgets rendered into the placeholder reserved above
            with placeholder.container():
                if dialogue_mode == "知识库问答":
                    kb_list = [x["kb_name"] for x in api.list_knowledge_bases()]
                    selected_kb = st.selectbox(
                        "请选择知识库:",
                        kb_list,
                        on_change=on_kb_change,
                        key="selected_kb",
                    )
                elif dialogue_mode == "文件对话":
                    files = st.file_uploader("上传知识文件:",
                                             [i for ls in LOADER_DICT.values() for i in ls],
                                             accept_multiple_files=True,
                                             )
                    if st.button("开始上传", disabled=len(files) == 0):
                        # uploads create a temp KB; its id identifies the chat files
                        st.session_state["file_chat_id"] = upload_temp_docs(files, api)
                elif dialogue_mode == "搜索引擎问答":
                    search_engine_list = list(Settings.tool_settings.search_internet["search_engine_config"])
                    search_engine = st.selectbox(
                        label="请选择搜索引擎",
                        options=search_engine_list,
                        key="search_engine",
                    )
        with tabs[1]:
            # conversation management: switch / create / rename / delete
            cols = st.columns(3)
            conv_names = chat_box.get_chat_names()

            def on_conv_change():
                print(conversation_name, st.session_state.cur_conv_name)
                save_session(conversation_name)
                restore_session(st.session_state.cur_conv_name)

            conversation_name = sac.buttons(
                conv_names,
                label="当前会话:",
                key="cur_conv_name",
                on_change=on_conv_change,
            )
            chat_box.use_chat_name(conversation_name)
            conversation_id = chat_box.context["uid"]
            if cols[0].button("新建", on_click=add_conv):
                ...
            if cols[1].button("重命名"):
                rename_conversation()
            if cols[2].button("删除", on_click=del_conv):
                ...

    # Display chat messages from history on app rerun
    chat_box.output_messages()
    chat_input_placeholder = "请输入对话内容换行请使用Shift+Enter。"
    llm_model = ctx.get("llm_model")

    # ---- chat input row pinned to the bottom of the page ----
    with bottom():
        cols = st.columns([1, 0.2, 15, 1])
        if cols[0].button(":gear:", help="模型配置"):
            # copy stored model config into session_state so the dialog's
            # widgets show the current values
            widget_keys = ["platform", "llm_model", "temperature", "system_message"]
            chat_box.context_to_session(include=widget_keys)
            llm_model_setting()
        if cols[-1].button(":wastebasket:", help="清空对话"):
            chat_box.reset_history()
            rerun()
        # with cols[1]:
        #     mic_audio = audio_recorder("", icon_size="2x", key="mic_audio")
        prompt = cols[2].chat_input(chat_input_placeholder, key="prompt")
    if prompt:
        history = get_messages_history(ctx.get("history_len", 0))
        messages = history + [{"role": "user", "content": prompt}]
        chat_box.user_say(prompt)

        # retrieval parameters forwarded to the server via OpenAI extra_body
        extra_body = dict(
            top_k=kb_top_k,
            score_threshold=score_threshold,
            temperature=ctx.get("temperature"),
            prompt_name=prompt_name,
            return_direct=return_direct,
        )
        # base_url selects the server-side RAG route for the active mode
        api_url = api_address(is_public=True)
        if dialogue_mode == "知识库问答":
            client = openai.Client(base_url=f"{api_url}/knowledge_base/local_kb/{selected_kb}", api_key="NONE")
            chat_box.ai_say([
                # element 0: retrieved-docs expander; element 1: answer text
                Markdown("...", in_expander=True, title="知识库匹配结果", state="running", expanded=return_direct),
                f"正在查询知识库 `{selected_kb}` ...",
            ])
        elif dialogue_mode == "文件对话":
            if st.session_state.get("file_chat_id") is None:
                st.error("请先上传文件再进行对话")
                st.stop()
            knowledge_id=st.session_state.get("file_chat_id")
            client = openai.Client(base_url=f"{api_url}/knowledge_base/temp_kb/{knowledge_id}", api_key="NONE")
            chat_box.ai_say([
                Markdown("...", in_expander=True, title="知识库匹配结果", state="running", expanded=return_direct),
                f"正在查询文件 `{st.session_state.get('file_chat_id')}` ...",
            ])
        else:
            client = openai.Client(base_url=f"{api_url}/knowledge_base/search_engine/{search_engine}", api_key="NONE")
            chat_box.ai_say([
                Markdown("...", in_expander=True, title="知识库匹配结果", state="running", expanded=return_direct),
                f"正在执行 `{search_engine}` 搜索...",
            ])

        text = ""
        first = True
        try:
            # Streamed completion: the first chunk carries the retrieved docs
            # (d.docs); subsequent chunks carry the generated answer deltas.
            for d in client.chat.completions.create(messages=messages, model=llm_model, stream=True, extra_body=extra_body):
                if first:
                    chat_box.update_msg("\n\n".join(d.docs), element_index=0, streaming=False, state="complete")
                    chat_box.update_msg("", streaming=False)
                    first = False
                    continue
                text += d.choices[0].delta.content or ""
                chat_box.update_msg(text.replace("\n", "\n\n"), streaming=True)
            # final write without the streaming cursor
            chat_box.update_msg(text, streaming=False)
        # TODO: search mode raises an error when no API KEY is configured
        except Exception as e:
            st.error(e.body)

    now = datetime.now()
    with tabs[1]:
        cols = st.columns(2)
        export_btn = cols[0]
        if cols[1].button(
                "清空对话",
                use_container_width=True,
        ):
            chat_box.reset_history()
            rerun()

    # export the whole conversation as a timestamped markdown file
    export_btn.download_button(
        "导出记录",
        "".join(chat_box.export2md()),
        file_name=f"{now:%Y-%m-%d %H.%M}_对话记录.md",
        mime="text/markdown",
        use_container_width=True,
    )
    # st.write(chat_box.history)