diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index d6399a8..e71edc0 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -41,7 +41,15 @@ def get_kb_details(api: ApiRequest) -> pd.DataFrame: kb_detail["in_folder"] = False result[kb] = kb_detail - df = pd.DataFrame(result.values()) + df = pd.DataFrame(result.values(), columns=[ + "kb_name", + "vs_type", + "embed_model", + "file_count", + "create_time", + "in_folder", + "in_db", + ]) df.insert(0, "No", range(1, len(df) + 1)) return df @@ -74,7 +82,17 @@ def get_kb_doc_details(api: ApiRequest, kb: str) -> pd.DataFrame: doc_detail["in_folder"] = False result[doc] = doc_detail - df = pd.DataFrame(result.values()) + df = pd.DataFrame(result.values(), columns=[ + "kb_name", + "file_name", + "file_ext", + "file_version", + "document_loader", + "text_splitter", + "create_time", + "in_folder", + "in_db", + ]) df.insert(0, "No", range(1, len(df) + 1)) return df @@ -98,24 +116,8 @@ def config_aggrid( def knowledge_base_page(api: ApiRequest): api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) kb_details = get_kb_details(api) - kb_list = list(kb_details.keys()) + kb_list = list(kb_details.kb_name) - def on_new_kb(): - if name := st.session_state.get("new_kb_name"): - if name in kb_list: - st.error(f"名为 {name} 的知识库已经存在!") - else: - ret = api.create_knowledge_base(name) - st.toast(ret["msg"]) - - def on_del_kb(): - if name := st.session_state.get("new_kb_name"): - if name in kb_list: - ret = api.delete_knowledge_base(name) - st.toast(ret["msg"]) - else: - st.error(f"名为 {name} 的知识库不存在!") - cols = st.columns([2, 1, 1]) new_kb_name = cols[0].text_input( "新知识库名称", @@ -123,8 +125,22 @@ def knowledge_base_page(api: ApiRequest): label_visibility="collapsed", key="new_kb_name", ) - cols[1].button("新建", on_click=on_new_kb, disabled=not bool(new_kb_name)) - cols[2].button("删除", on_click=on_del_kb, disabled=not bool(new_kb_name)) + + if cols[1].button("新建", disabled=not bool(new_kb_name)) and new_kb_name: + if new_kb_name in kb_list: + st.error(f"名为 {new_kb_name} 的知识库已经存在!") + else: + ret = api.create_knowledge_base(new_kb_name) + st.toast(ret["msg"]) + st.experimental_rerun() + + if cols[2].button("删除", disabled=not bool(new_kb_name)) and new_kb_name: + if new_kb_name in kb_list: + ret = api.delete_knowledge_base(new_kb_name) + st.toast(ret["msg"]) + st.experimental_rerun() + else: + st.error(f"名为 {new_kb_name} 的知识库不存在!") st.write("知识库:") if kb_list: @@ -177,7 +193,8 @@ def knowledge_base_page(api: ApiRequest): # 知识库详情 st.subheader(f"知识库 {kb} 详情") doc_details = get_kb_doc_details(api, kb) - del doc_details["kb_name"] + doc_details.drop(columns=["kb_name"], inplace=True) + gb = config_aggrid( doc_details, { diff --git a/webui_pages/utils.py b/webui_pages/utils.py index a273d8b..89d1b78 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -3,6 +3,7 @@ from typing import * from pathlib import Path import os from configs.model_config import ( + EMBEDDING_MODEL, KB_ROOT_PATH, LLM_MODEL, llm_model_dict, @@ -320,6 +321,8 @@ class ApiRequest: def create_knowledge_base( self, knowledge_base_name: str, + vector_store_type: str = "faiss", + embed_model: str = EMBEDDING_MODEL, no_remote_api: bool = None, ): ''' @@ -330,12 +333,12 @@ class ApiRequest: if no_remote_api: from server.knowledge_base.kb_api import create_kb - response = run_async(create_kb(knowledge_base_name)) + response = run_async(create_kb(knowledge_base_name, vector_store_type, embed_model)) return response.dict() else: response = self.post( "/knowledge_base/create_knowledge_base", - json={"knowledge_base_name": knowledge_base_name}, + json={"knowledge_base_name": knowledge_base_name, "vector_store_type": vector_store_type, "embed_model": embed_model}, ) return response.json()