diff --git a/requirements.txt b/requirements.txt index a81fa1a..9e60611 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,8 +9,6 @@ fastapi-offline nltk~=3.8.1 uvicorn~=0.23.1 starlette~=0.27.0 -numpy~=1.24.4 -pandas~=2.0.3 pydantic~=1.10.11 unstructured[all-docs] python-magic-bin; sys_platform == 'win32' diff --git a/requirements_webui.txt b/requirements_webui.txt index b4a954e..0f7e253 100644 --- a/requirements_webui.txt +++ b/requirements_webui.txt @@ -1,6 +1,8 @@ +numpy~=1.24.4 +pandas~=2.0.3 streamlit>=1.25.0 -streamlit-option-menu -streamlit-antd-components +streamlit-option-menu>=0.3.6 +streamlit-antd-components>=0.1.11 streamlit-chatbox>=1.1.6 -streamlit-aggrid -httpx +streamlit-aggrid>=0.3.4.post3 +httpx~=0.24.1 diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py index 0249691..b467d96 100644 --- a/server/knowledge_base/kb_doc_api.py +++ b/server/knowledge_base/kb_doc_api.py @@ -1,12 +1,13 @@ import os import urllib from fastapi import File, Form, Body, UploadFile +from configs.model_config import DEFAULT_VS_TYPE, EMBEDDING_MODEL from server.utils import BaseResponse, ListResponse -from server.knowledge_base.utils import validate_kb_name +from server.knowledge_base.utils import validate_kb_name, list_docs_from_folder, KnowledgeFile from fastapi.responses import StreamingResponse import json -from server.knowledge_base.utils import KnowledgeFile, list_docs_from_folder from server.knowledge_base.kb_service.base import KBServiceFactory +from typing import List async def list_docs( @@ -100,7 +101,7 @@ async def update_doc( kb.update_doc(kb_file) return BaseResponse(code=200, msg=f"成功更新文件 {kb_file.filename}") else: - return BaseResponse(code=500, msg=f"{kb_file.filename} 文件上传失败,报错信息为: {e}") + return BaseResponse(code=500, msg=f"{kb_file.filename} 文件更新失败") async def download_doc(): @@ -111,7 +112,8 @@ async def download_doc(): async def recreate_vector_store( knowledge_base_name: str = Body(..., examples=["samples"]), allow_empty_kb: bool = Body(True), - vs_type: str = Body("faiss"), + vs_type: str = Body(DEFAULT_VS_TYPE), + embed_model: str = Body(EMBEDDING_MODEL), ): ''' recreate vector store from the content. @@ -119,31 +121,24 @@ async def recreate_vector_store( by default, get_service_by_name only return knowledge base in the info.db and having document files in it. set allow_empty_kb to True make it applied on empty knowledge base which it not in the info.db or having no documents. ''' - kb = KBServiceFactory.get_service_by_name(knowledge_base_name) - if kb is None: - if allow_empty_kb: - kb = KBServiceFactory.get_service(knowledge_base_name, vs_type) - else: - return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") + kb = KBServiceFactory.get_service(knowledge_base_name, vs_type, embed_model) + if not kb.exists() and not allow_empty_kb: + return BaseResponse(code=404, msg=f"未找到知识库 {knowledge_base_name}") async def output(kb): kb.create_kb() kb.clear_vs() - print(f"start to recreate vector store of {kb.kb_name}") docs = list_docs_from_folder(knowledge_base_name) - print(docs) - for i, filename in enumerate(docs): - yield json.dumps({ - "total": len(docs), - "finished": i, - "doc": filename, - }) + for i, doc in enumerate(docs): try: - kb_file = KnowledgeFile(filename=filename, - knowledge_base_name=kb.kb_name) - print(f"processing {kb_file.filepath} to vector store.") + kb_file = KnowledgeFile(doc, knowledge_base_name) + yield json.dumps({ + "total": len(docs), + "finished": i, + "doc": doc, + }, ensure_ascii=False) kb.add_doc(kb_file) - except ValueError as e: + except Exception as e: print(e) return StreamingResponse(output(kb), media_type="text/event-stream") diff --git a/server/knowledge_base/kb_service/base.py b/server/knowledge_base/kb_service/base.py index c22bac7..295196b 100644 --- a/server/knowledge_base/kb_service/base.py +++ b/server/knowledge_base/kb_service/base.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod import os -import pandas as pd from langchain.embeddings.base import Embeddings from langchain.docstore.document import Document @@ -20,7 +19,7 @@ from server.knowledge_base.utils import ( get_kb_path, get_doc_path, load_embeddings, KnowledgeFile, list_kbs_from_folder, list_docs_from_folder, ) -from typing import List, Union +from typing import List, Union, Dict class SupportedVSType: @@ -221,7 +220,7 @@ class KBServiceFactory: return KBServiceFactory.get_service("default", SupportedVSType.DEFAULT) -def get_kb_details() -> pd.DataFrame: +def get_kb_details() -> List[Dict]: kbs_in_folder = list_kbs_from_folder() kbs_in_db = KBService.list_kbs() result = {} @@ -247,20 +246,15 @@ def get_kb_details() -> pd.DataFrame: kb_detail["in_folder"] = False result[kb] = kb_detail - df = pd.DataFrame(result.values(), columns=[ - "kb_name", - "vs_type", - "embed_model", - "file_count", - "create_time", - "in_folder", - "in_db", - ]) - df.insert(0, "No", range(1, len(df) + 1)) - return df + data = [] + for i, v in enumerate(result.values()): + v['No'] = i + 1 + data.append(v) + + return data -def get_kb_doc_details(kb_name: str) -> pd.DataFrame: +def get_kb_doc_details(kb_name: str) -> List[Dict]: kb = KBServiceFactory.get_service_by_name(kb_name) docs_in_folder = list_docs_from_folder(kb_name) docs_in_db = kb.list_docs() @@ -289,17 +283,9 @@ def get_kb_doc_details(kb_name: str) -> pd.DataFrame: doc_detail["in_folder"] = False result[doc] = doc_detail - df = pd.DataFrame(result.values(), columns=[ - "kb_name", - "file_name", - "file_ext", - "file_version", - "document_loader", - "text_splitter", - "create_time", - "in_folder", - "in_db", - ]) - df.insert(0, "No", range(1, len(df) + 1)) - return df - + data = [] + for i, v in enumerate(result.values()): + v['No'] = i + 1 + data.append(v) + + return data diff --git a/server/knowledge_base/migrate.py b/server/knowledge_base/migrate.py index 29f274e..1c023fa 100644 --- a/server/knowledge_base/migrate.py +++ b/server/knowledge_base/migrate.py @@ -4,7 +4,7 @@ from server.knowledge_base.kb_service.base import KBServiceFactory from server.db.repository.knowledge_file_repository import add_doc_to_db from server.db.base import Base, engine import os -from typing import Literal +from typing import Literal, Callable, Any def create_tables(): @@ -21,6 +21,8 @@ def folder2db( mode: Literal["recreate_vs", "fill_info_only", "update_in_db", "increament"], vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_model: str = EMBEDDING_MODEL, + callback_before: Callable = None, + callback_after: Callable = None, ): ''' use existed files in local folder to populate database and/or vector store. @@ -35,34 +37,53 @@ def folder2db( if mode == "recreate_vs": kb.clear_vs() - for doc in list_docs_from_folder(kb_name): + docs = list_docs_from_folder(kb_name) + for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, kb_name) + if callable(callback_before): + callback_before(kb_file, i, docs) kb.add_doc(kb_file) + if callable(callback_after): + callback_after(kb_file, i, docs) except Exception as e: print(e) elif mode == "fill_info_only": - for doc in list_docs_from_folder(kb_name): + docs = list_docs_from_folder(kb_name) + for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, kb_name) + if callable(callback_before): + callback_before(kb_file, i, docs) add_doc_to_db(kb_file) + if callable(callback_after): + callback_after(kb_file, i, docs) except Exception as e: print(e) elif mode == "update_in_db": - for doc in kb.list_docs(): + docs = kb.list_docs() + for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, kb_name) + if callable(callback_before): + callback_before(kb_file, i, docs) kb.update_doc(kb_file) + if callable(callback_after): + callback_after(kb_file, i, docs) except Exception as e: print(e) elif mode == "increament": db_docs = kb.list_docs() folder_docs = list_docs_from_folder(kb_name) docs = list(set(folder_docs) - set(db_docs)) - for doc in docs: + for i, doc in enumerate(docs): try: kb_file = KnowledgeFile(doc, kb_name) + if callable(callback_before): + callback_before(kb_file, i, docs) kb.add_doc(kb_file) + if callable(callback_after): + callback_after(kb_file, i, docs) except Exception as e: print(e) else: @@ -72,12 +93,13 @@ def folder2db( def recreate_all_vs( vs_type: Literal["faiss", "milvus", "pg", "chromadb"] = DEFAULT_VS_TYPE, embed_mode: str = EMBEDDING_MODEL, + **kwargs: Any, ): ''' used to recreate a vector store or change current vector store to another type or embed_model ''' for kb_name in list_kbs_from_folder(): - folder2db(kb_name, "recreate_vs", vs_type, embed_mode) + folder2db(kb_name, "recreate_vs", vs_type, embed_mode, **kwargs) def prune_db_docs(kb_name: str): diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py index a790bcf..53a86a3 100644 --- a/webui_pages/knowledge_base/knowledge_base.py +++ b/webui_pages/knowledge_base/knowledge_base.py @@ -1,9 +1,10 @@ +from pydoc import doc import streamlit as st from webui_pages.utils import * from st_aggrid import AgGrid from st_aggrid.grid_options_builder import GridOptionsBuilder import pandas as pd -from server.knowledge_base.utils import get_file_path +from server.knowledge_base.utils import get_file_path, LOADER_DICT from server.knowledge_base.kb_service.base import get_kb_details, get_kb_doc_details from typing import Literal, Dict, Tuple @@ -28,17 +29,14 @@ def config_aggrid( return gb -# kb_box = ChatBox(session_key="kb_messages") - def knowledge_base_page(api: ApiRequest): # api = ApiRequest(base_url="http://127.0.0.1:7861", no_remote_api=True) - kb_details = get_kb_details() - kb_list = list(kb_details.kb_name) + kb_list = get_kb_details() - cols = st.columns([3, 1, 1]) + cols = st.columns([3, 1, 1, 3]) new_kb_name = cols[0].text_input( "新知识库名称", - placeholder="新知识库名称,暂不支持中文命名", + placeholder="新知识库名称,不支持中文命名", label_visibility="collapsed", key="new_kb_name", ) @@ -67,134 +65,128 @@ def knowledge_base_page(api: ApiRequest): else: st.error(f"名为 {new_kb_name} 的知识库不存在!") - st.write("知识库列表:") - st.info("请选择知识库") - if kb_list: + selected_kb = cols[3].selectbox( + "请选择知识库:", + kb_list, + format_func=lambda s: f"{s['kb_name']} ({s['vs_type']} @ {s['embed_model']})", + label_visibility="collapsed" + ) + + if selected_kb: + kb = selected_kb["kb_name"] + + # 知识库详情 + st.write(f"知识库 `{kb}` 详情:") + # st.info("请选择文件,点击按钮进行操作。") + doc_details = pd.DataFrame(get_kb_doc_details(kb)) + doc_details.drop(columns=["kb_name"], inplace=True) + doc_details = doc_details[[ + "No", "file_name", "document_loader", "text_splitter", "in_folder", "in_db", + ]] + gb = config_aggrid( - kb_details, + doc_details, { - ("kb_name", "知识库名称"): {}, - ("vs_type", "知识库类型"): {}, - ("embed_model", "嵌入模型"): {}, - ("file_count", "文档数量"): {}, - ("create_time", "创建时间"): {}, + ("file_name", "文档名称"): {}, + # ("file_ext", "文档类型"): {}, + # ("file_version", "文档版本"): {}, + ("document_loader", "文档加载器"): {}, + ("text_splitter", "分词器"): {}, + # ("create_time", "创建时间"): {}, ("in_folder", "文件夹"): {}, ("in_db", "数据库"): {}, - } + }, + "multiple", ) - kb_grid = AgGrid( - kb_details, + + doc_grid = AgGrid( + doc_details, gb.build(), columns_auto_size_mode="FIT_CONTENTS", theme="alpine", + custom_css={ + "#gridToolBar": {"display": "none"}, + }, ) - # st.write(kb_grid) - if kb_grid.selected_rows: - # st.session_state.selected_rows = [x["nIndex"] for x in kb_grid.selected_rows] - kb = kb_grid.selected_rows[0]["kb_name"] - with st.sidebar: - # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) - files = st.file_uploader("上传知识文件", - ["docx", "txt", "md", "csv", "xlsx", "pdf"], - accept_multiple_files=True, - ) - if st.button( - "添加文件到知识库", - help="请先上传文件,再点击添加", - use_container_width=True, - disabled=len(files) == 0, - ): - for f in files: - ret = api.upload_kb_doc(f, kb) - if ret["code"] == 200: - st.toast(ret["msg"], icon="✔") - else: - st.toast(ret["msg"], icon="❌") - st.session_state.files = [] + cols = st.columns(3) + selected_rows = doc_grid.get("selected_rows", []) - # if st.button( - # "重建知识库", - # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", - # use_container_width=True, - # disabled=True, - # ): - # progress = st.progress(0.0, "") - # for d in api.recreate_vector_store(kb): - # progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") - - # 知识库详情 - st.write(f"知识库 `{kb}` 详情:") - st.info("请选择文件") - doc_details = get_kb_doc_details(kb) - doc_details.drop(columns=["kb_name"], inplace=True) - - gb = config_aggrid( - doc_details, - { - ("file_name", "文档名称"): {}, - ("file_ext", "文档类型"): {}, - ("file_version", "文档版本"): {}, - ("document_loader", "文档加载器"): {}, - ("text_splitter", "分词器"): {}, - ("create_time", "创建时间"): {}, - ("in_folder", "文件夹"): {}, - ("in_db", "数据库"): {}, - }, - "multiple", - ) - - doc_grid = AgGrid( - doc_details, - gb.build(), - columns_auto_size_mode="FIT_CONTENTS", - theme="alpine", - ) - - cols = st.columns(3) - selected_rows = doc_grid.get("selected_rows", []) - - cols = st.columns(4) - if selected_rows: - file_name = selected_rows[0]["file_name"] - file_path = get_file_path(kb, file_name) - with open(file_path, "rb") as fp: - cols[0].download_button( - "下载选中文档", - fp, - file_name=file_name, - use_container_width=True,) - else: + cols = st.columns(4) + if selected_rows: + file_name = selected_rows[0]["file_name"] + file_path = get_file_path(kb, file_name) + with open(file_path, "rb") as fp: cols[0].download_button( "下载选中文档", - "", - disabled=True, + fp, + file_name=file_name, use_container_width=True,) + else: + cols[0].download_button( + "下载选中文档", + "", + disabled=True, + use_container_width=True,) - if cols[1].button( - "入库", - disabled=len(selected_rows) == 0, - use_container_width=True, - ): - for row in selected_rows: - api.update_kb_doc(kb, row["file_name"]) - st.experimental_rerun() + if cols[1].button( + "入库", + disabled=len(selected_rows) == 0, + use_container_width=True, + help="将文件分词并加载到向量库中", + ): + for row in selected_rows: + api.update_kb_doc(kb, row["file_name"]) + st.experimental_rerun() - if cols[2].button( - "出库", - disabled=len(selected_rows) == 0, - use_container_width=True, - ): - for row in selected_rows: - api.delete_kb_doc(kb, row["file_name"]) - st.experimental_rerun() + if cols[2].button( + "出库", + disabled=len(selected_rows) == 0, + use_container_width=True, + help="将文件从向量库中删除,但不删除文件本身。" + ): + for row in selected_rows: + api.delete_kb_doc(kb, row["file_name"]) + st.experimental_rerun() - if cols[3].button( - "删除选中文档!", - type="primary", - use_container_width=True, - ): - for row in selected_rows: - ret = api.delete_kb_doc(kb, row["file_name"], True) - st.toast(ret["msg"]) - st.experimental_rerun() + if cols[3].button( + "删除选中文档!", + type="primary", + use_container_width=True, + ): + for row in selected_rows: + ret = api.delete_kb_doc(kb, row["file_name"], True) + st.toast(ret["msg"]) + st.experimental_rerun() + + st.divider() + # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) + files = st.file_uploader("上传知识文件", + [i for ls in LOADER_DICT.values() for i in ls], + accept_multiple_files=True, + ) + cols = st.columns([3, 1]) + if cols[0].button( + "添加文件到知识库", + help="请先上传文件,再点击添加", + use_container_width=True, + disabled=len(files) == 0, + ): + for f in files: + ret = api.upload_kb_doc(f, kb) + if ret["code"] == 200: + st.toast(ret["msg"], icon="✔") + else: + st.toast(ret["msg"], icon="❌") + st.session_state.files = [] + + # todo: freezed + # if cols[1].button( + # "重建知识库", + # help="无需上传文件,通过其它方式将文档拷贝到对应知识库content目录下,点击本按钮即可重建知识库。", + # use_container_width=True, + # type="primary", + # ): + # progress = st.progress(0.0, "") + # for d in api.recreate_vector_store(kb): + # progress.progress(d["finished"] / d["total"], f"正在处理: {d['doc']}") diff --git a/webui_pages/utils.py b/webui_pages/utils.py index 1432326..9780dc6 100644 --- a/webui_pages/utils.py +++ b/webui_pages/utils.py @@ -3,6 +3,7 @@ from typing import * from pathlib import Path from configs.model_config import ( EMBEDDING_MODEL, + DEFAULT_VS_TYPE, KB_ROOT_PATH, LLM_MODEL, llm_model_dict, @@ -88,7 +89,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any, ) -> Union[httpx.Response, None]: - rl = self._parse_url(url) + url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) async with httpx.AsyncClient() as client: while retry > 0: @@ -130,7 +131,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, None]: - rl = self._parse_url(url) + url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) async with httpx.AsyncClient() as client: while retry > 0: @@ -171,7 +172,7 @@ class ApiRequest: stream: bool = False, **kwargs: Any ) -> Union[httpx.Response, None]: - rl = self._parse_url(url) + url = self._parse_url(url) kwargs.setdefault("timeout", self.timeout) async with httpx.AsyncClient() as client: while retry > 0: @@ -534,6 +535,9 @@ class ApiRequest: def recreate_vector_store( self, knowledge_base_name: str, + allow_empty_kb: bool = True, + vs_type: str = DEFAULT_VS_TYPE, + embed_model: str = EMBEDDING_MODEL, no_remote_api: bool = None, ): ''' @@ -542,14 +546,22 @@ class ApiRequest: if no_remote_api is None: no_remote_api = self.no_remote_api + data = { + "knowledge_base_name": knowledge_base_name, + "allow_empty_kb": allow_empty_kb, + "vs_type": vs_type, + "embed_model": embed_model, + } + if no_remote_api: from server.knowledge_base.kb_doc_api import recreate_vector_store - response = run_async(recreate_vector_store(knowledge_base_name)) + response = run_async(recreate_vector_store(**data)) return self._fastapi_stream2generator(response, as_json=True) else: response = self.post( "/knowledge_base/recreate_vector_store", - json={"knowledge_base_name": knowledge_base_name}, + json=data, + stream=True, ) return self._httpx_stream2generator(response, as_json=True)