update webui layout

This commit is contained in:
imClumsyPanda 2023-08-10 23:51:10 +08:00
parent ef098111dc
commit 3c44cf65cd
3 changed files with 115 additions and 66 deletions

View File

@ -41,5 +41,4 @@ if __name__ == "__main__":
icons=[i["icon"] for i in pages.values()], icons=[i["icon"] for i in pages.values()],
menu_icon="chat-quote", menu_icon="chat-quote",
default_index=0) default_index=0)
print(f"root: {api.no_remote_api=}")
pages[selected_page]["func"](api) pages[selected_page]["func"](api)

View File

@ -6,9 +6,9 @@ import streamlit_antd_components as sac
from server.chat.search_engine_chat import SEARCH_ENGINES from server.chat.search_engine_chat import SEARCH_ENGINES
from typing import List, Dict from typing import List, Dict
chat_box = ChatBox() chat_box = ChatBox()
def get_messages_history(history_len: int) -> List[Dict]: def get_messages_history(history_len: int) -> List[Dict]:
def filter(msg): def filter(msg):
''' '''
@ -19,6 +19,7 @@ def get_messages_history(history_len: int) -> List[Dict]:
"role": msg["role"], "role": msg["role"],
"content": content[0] if content else "", "content": content[0] if content else "",
} }
history = chat_box.filter_history(100000, filter) # workaround before upgrading streamlit-chatbox. history = chat_box.filter_history(100000, filter) # workaround before upgrading streamlit-chatbox.
user_count = 0 user_count = 0
i = 1 i = 1
@ -35,7 +36,7 @@ def dialogue_page(api: ApiRequest):
with st.sidebar: with st.sidebar:
with st.expander("会话管理", True): with st.expander("会话管理", True):
col_input, col_btn = st.columns(2) col_input, col_btn = st.columns([2, 1])
new_chat_name = col_input.text_input( new_chat_name = col_input.text_input(
"新会话名称", "新会话名称",
placeholder="新会话名称", placeholder="新会话名称",
@ -48,6 +49,7 @@ def dialogue_page(api: ApiRequest):
if new_chat_name: if new_chat_name:
chat_box.use_chat_name(new_chat_name) chat_box.use_chat_name(new_chat_name)
st.session_state.new_chat_name = "" st.session_state.new_chat_name = ""
col_btn.button("新建会话", on_click=on_btn_new_chat) col_btn.button("新建会话", on_click=on_btn_new_chat)
chat_list = chat_box.get_chat_names() chat_list = chat_box.get_chat_names()
@ -56,10 +58,17 @@ def dialogue_page(api: ApiRequest):
cols = st.columns(3) cols = st.columns(3)
export_btn = cols[0] export_btn = cols[0]
if cols[1].button("Clear"): if cols[1].button(
"Clear",
use_container_width=True,
):
chat_box.reset_history() chat_box.reset_history()
if cols[2].button("Delete", disabled=len(chat_list) <= 1): if cols[2].button(
"Delete",
disabled=len(chat_list) <= 1,
use_container_width=True,
):
chat_box.del_chat_name(cur_chat_name) chat_box.del_chat_name(cur_chat_name)
st.experimental_rerun() st.experimental_rerun()
@ -82,6 +91,7 @@ def dialogue_page(api: ApiRequest):
key="dialogue_mode", key="dialogue_mode",
) )
history_len = st.slider("历史对话轮数:", 1, 10, 3) history_len = st.slider("历史对话轮数:", 1, 10, 3)
# todo: support history len # todo: support history len
def on_kb_change(): def on_kb_change():
@ -97,9 +107,9 @@ def dialogue_page(api: ApiRequest):
key="selected_kb", key="selected_kb",
) )
kb_top_k = st.slider("匹配知识条数:", 1, 20, 3) kb_top_k = st.slider("匹配知识条数:", 1, 20, 3)
score_threshold = st.slider("知识匹配分数阈值:", 0, 1000, 0, disabled=True) # score_threshold = st.slider("知识匹配分数阈值:", 0, 1, 0, disabled=True)
chunk_content = st.checkbox("关联上下文", False, disabled=True) # chunk_content = st.checkbox("关联上下文", False, disabled=True)
chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True) # chunk_size = st.slider("关联长度:", 0, 500, 250, disabled=True)
elif dialogue_mode == "搜索引擎问答": elif dialogue_mode == "搜索引擎问答":
search_engine = sac.buttons(SEARCH_ENGINES.keys(), 0) search_engine = sac.buttons(SEARCH_ENGINES.keys(), 0)
se_top_k = st.slider("匹配搜索结果条数:", 1, 20, 3) se_top_k = st.slider("匹配搜索结果条数:", 1, 20, 3)
@ -148,4 +158,5 @@ def dialogue_page(api: ApiRequest):
"".join(chat_box.export2md(cur_chat_name)), "".join(chat_box.export2md(cur_chat_name)),
file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md", file_name=f"{now:%Y-%m-%d %H.%M}_{cur_chat_name}.md",
mime="text/markdown", mime="text/markdown",
use_container_width=True,
) )

View File

@ -8,7 +8,6 @@ from server.knowledge_base.utils import get_file_path
# from streamlit_chatbox import * # from streamlit_chatbox import *
from typing import Literal, Dict, Tuple from typing import Literal, Dict, Tuple
SENTENCE_SIZE = 100 SENTENCE_SIZE = 100
@ -33,19 +32,23 @@ def config_aggrid(
# kb_box = ChatBox(session_key="kb_messages") # kb_box = ChatBox(session_key="kb_messages")
def knowledge_base_page(api: ApiRequest): def knowledge_base_page(api: ApiRequest):
api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True)
kb_details = get_kb_details(api) kb_details = get_kb_details(api)
kb_list = list(kb_details.kb_name) kb_list = list(kb_details.kb_name)
cols = st.columns([2, 1, 1]) cols = st.columns([3, 1, 1])
new_kb_name = cols[0].text_input( new_kb_name = cols[0].text_input(
"新知识库名称", "新知识库名称",
placeholder="新知识库名称", placeholder="新知识库名称,暂不支持中文命名",
label_visibility="collapsed", label_visibility="collapsed",
key="new_kb_name", key="new_kb_name",
) )
if cols[1].button("新建", disabled=not bool(new_kb_name)) and new_kb_name: if cols[1].button(
"新建",
disabled=not bool(new_kb_name),
use_container_width=True,
) and new_kb_name:
if new_kb_name in kb_list: if new_kb_name in kb_list:
st.error(f"名为 {new_kb_name} 的知识库已经存在!") st.error(f"名为 {new_kb_name} 的知识库已经存在!")
else: else:
@ -53,7 +56,11 @@ def knowledge_base_page(api: ApiRequest):
st.toast(ret["msg"]) st.toast(ret["msg"])
st.experimental_rerun() st.experimental_rerun()
if cols[2].button("删除", disabled=not bool(new_kb_name)) and new_kb_name: if cols[2].button(
"删除",
disabled=not bool(new_kb_name),
use_container_width=True,
) and new_kb_name:
if new_kb_name in kb_list: if new_kb_name in kb_list:
ret = api.delete_knowledge_base(new_kb_name) ret = api.delete_knowledge_base(new_kb_name)
st.toast(ret["msg"]) st.toast(ret["msg"])
@ -62,27 +69,33 @@ def knowledge_base_page(api: ApiRequest):
st.error(f"名为 {new_kb_name} 的知识库不存在!") st.error(f"名为 {new_kb_name} 的知识库不存在!")
st.write("知识库列表:") st.write("知识库列表:")
st.info("请选择知识库")
if kb_list: if kb_list:
gb = config_aggrid( gb = config_aggrid(
kb_details, kb_details,
{ {
("kb_name", "知识库名称"): {"maxWidth": 150}, ("kb_name", "知识库名称"): {},
("vs_type", "知识库类型"): {"maxWidth": 100}, ("vs_type", "知识库类型"): {},
("embed_model", "嵌入模型"): {"maxWidth": 100}, ("embed_model", "嵌入模型"): {},
("file_count", "文档数量"): {"maxWidth": 60}, ("file_count", "文档数量"): {},
("create_time", "创建时间"): {"maxWidth": 150}, ("create_time", "创建时间"): {},
("in_folder", "文件夹"): {"maxWidth": 50}, ("in_folder", "文件夹"): {},
("in_db", "数据库"): {"maxWidth": 50}, ("in_db", "数据库"): {},
} }
) )
kb_grid = AgGrid(kb_details, gb.build()) kb_grid = AgGrid(
kb_details,
gb.build(),
columns_auto_size_mode="FIT_CONTENTS",
theme="alpine",
)
# st.write(kb_grid) # st.write(kb_grid)
if kb_grid.selected_rows: if kb_grid.selected_rows:
# st.session_state.selected_rows = [x["nIndex"] for x in kb_grid.selected_rows] # st.session_state.selected_rows = [x["nIndex"] for x in kb_grid.selected_rows]
kb = kb_grid.selected_rows[0]["kb_name"] kb = kb_grid.selected_rows[0]["kb_name"]
with st.sidebar: with st.sidebar:
sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件", files = st.file_uploader("上传知识文件",
["docx", "txt", "md", "csv", "xlsx", "pdf"], ["docx", "txt", "md", "csv", "xlsx", "pdf"],
accept_multiple_files=True, accept_multiple_files=True,
@ -91,7 +104,7 @@ def knowledge_base_page(api: ApiRequest):
"添加文件到知识库", "添加文件到知识库",
help="请先上传文件,再点击添加", help="请先上传文件,再点击添加",
use_container_width=True, use_container_width=True,
disabled=len(files)==0, disabled=len(files) == 0,
): ):
for f in files: for f in files:
ret = api.upload_kb_doc(f, kb) ret = api.upload_kb_doc(f, kb)
@ -101,37 +114,43 @@ def knowledge_base_page(api: ApiRequest):
st.toast(ret["msg"], icon="") st.toast(ret["msg"], icon="")
st.session_state.files = [] st.session_state.files = []
if st.button( # if st.button(
"重建知识库", # "重建知识库",
help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。", # help="无需上传文件通过其它方式将文档拷贝到对应知识库content目录下点击本按钮即可重建知识库。",
use_container_width=True, # use_container_width=True,
disabled=True, # disabled=True,
): # ):
progress = st.progress(0.0, "") # progress = st.progress(0.0, "")
for d in api.recreate_vector_store(kb): # for d in api.recreate_vector_store(kb):
progress.progress(d["finished"] / d["t]otal"], f"正在处理: {d['doc']}") # progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}")
# 知识库详情 # 知识库详情
st.write(f"知识库 {kb} 详情:") st.write(f"知识库 `{kb}` 详情:")
st.info("请选择文件")
doc_details = get_kb_doc_details(api, kb) doc_details = get_kb_doc_details(api, kb)
doc_details.drop(columns=["kb_name"], inplace=True) doc_details.drop(columns=["kb_name"], inplace=True)
gb = config_aggrid( gb = config_aggrid(
doc_details, doc_details,
{ {
("file_name", "文档名称"): {"maxWidth": 150}, ("file_name", "文档名称"): {},
("file_ext", "文档类型"): {"maxWidth": 50}, ("file_ext", "文档类型"): {},
("file_version", "文档版本"): {"maxWidth": 50}, ("file_version", "文档版本"): {},
("document_loader", "文档加载器"): {"maxWidth": 150}, ("document_loader", "文档加载器"): {},
("text_splitter", "分词器"): {"maxWidth": 150}, ("text_splitter", "分词器"): {},
("create_time", "创建时间"): {"maxWidth": 150}, ("create_time", "创建时间"): {},
("in_folder", "文件夹"): {"maxWidth": 50}, ("in_folder", "文件夹"): {},
("in_db", "数据库"): {"maxWidth": 50}, ("in_db", "数据库"): {},
}, },
"multiple", "multiple",
) )
doc_grid = AgGrid(doc_details, gb.build()) doc_grid = AgGrid(
doc_details,
gb.build(),
columns_auto_size_mode="FIT_CONTENTS",
theme="alpine",
)
cols = st.columns(3) cols = st.columns(3)
selected_rows = doc_grid.get("selected_rows", []) selected_rows = doc_grid.get("selected_rows", [])
@ -141,21 +160,41 @@ def knowledge_base_page(api: ApiRequest):
file_name = selected_rows[0]["file_name"] file_name = selected_rows[0]["file_name"]
file_path = get_file_path(kb, file_name) file_path = get_file_path(kb, file_name)
with open(file_path, "rb") as fp: with open(file_path, "rb") as fp:
cols[0].download_button("下载选中文档", fp, file_name=file_name) cols[0].download_button(
"下载选中文档",
fp,
file_name=file_name,
use_container_width=True,)
else: else:
cols[0].download_button("下载选中文档", "", disabled=True) cols[0].download_button(
"下载选中文档",
"",
disabled=True,
use_container_width=True,)
if cols[1].button("入库", disabled=len(selected_rows)==0): if cols[1].button(
"入库",
disabled=len(selected_rows) == 0,
use_container_width=True,
):
for row in selected_rows: for row in selected_rows:
api.update_kb_doc(kb, row["file_name"]) api.update_kb_doc(kb, row["file_name"])
st.experimental_rerun() st.experimental_rerun()
if cols[2].button("出库", disabled=len(selected_rows)==0): if cols[2].button(
"出库",
disabled=len(selected_rows) == 0,
use_container_width=True,
):
for row in selected_rows: for row in selected_rows:
api.delete_kb_doc(kb, row["file_name"]) api.delete_kb_doc(kb, row["file_name"])
st.experimental_rerun() st.experimental_rerun()
if cols[3].button("删除选中文档!", type="primary"): if cols[3].button(
"删除选中文档!",
type="primary",
use_container_width=True,
):
for row in selected_rows: for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True) ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret["msg"]) st.toast(ret["msg"])