diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index 3e371cf..ed4e8b2 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -8,13 +8,11 @@ sys.path.append(str(root_path))
 from server.utils import api_address
 from configs.model_config import VECTOR_SEARCH_TOP_K
 from server.knowledge_base.utils import get_kb_path, get_file_path
-from webui_pages.utils import ApiRequest
 
 from pprint import pprint
 
 
 api_base_url = api_address()
-api = ApiRequest(api_base_url)
 
 
 kb = "kb_for_api_test"
@@ -24,6 +22,8 @@ test_files = {
     "test.txt": get_file_path("samples", "test.txt"),
 }
 
+print("\n\n直接url访问\n")
+
 
 def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
     if not Path(get_kb_path(kb)).exists():
diff --git a/tests/api/test_kb_api_request.py b/tests/api/test_kb_api_request.py
new file mode 100644
index 0000000..a0d15ce
--- /dev/null
+++ b/tests/api/test_kb_api_request.py
@@ -0,0 +1,161 @@
+import requests
+import json
+import sys
+from pathlib import Path
+
+root_path = Path(__file__).parent.parent.parent
+sys.path.append(str(root_path))
+from server.utils import api_address
+from configs.model_config import VECTOR_SEARCH_TOP_K
+from server.knowledge_base.utils import get_kb_path, get_file_path
+from webui_pages.utils import ApiRequest
+
+from pprint import pprint
+
+
+api_base_url = api_address()
+api: ApiRequest = ApiRequest(api_base_url)
+
+
+kb = "kb_for_api_test"
+test_files = {
+    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
+    "README.MD": str(root_path / "README.MD"),
+    "test.txt": get_file_path("samples", "test.txt"),
+}
+
+print("\n\nApiRquest调用\n")
+
+
+def test_delete_kb_before():
+    if not Path(get_kb_path(kb)).exists():
+        return
+
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert isinstance(data["data"], list) and len(data["data"]) > 0
+    assert kb not in data["data"]
+
+
+def test_create_kb():
+    print(f"\n尝试用空名称创建知识库:")
+    data = api.create_knowledge_base(" ")
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
+
+    print(f"\n创建新知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 200
+    assert data["msg"] == f"已新增知识库 {kb}"
+
+    print(f"\n尝试创建同名知识库: {kb}")
+    data = api.create_knowledge_base(kb)
+    pprint(data)
+    assert data["code"] == 404
+    assert data["msg"] == f"已存在同名知识库 {kb}"
+
+
+def test_list_kbs():
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb in data
+
+
+def test_upload_docs():
+    files = list(test_files.values())
+
+    print(f"\n上传知识文件")
+    data = {"knowledge_base_name": kb, "override": True}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    print(f"\n尝试重新上传知识文件, 不覆盖")
+    data = {"knowledge_base_name": kb, "override": False}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == len(test_files)
+
+    print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
+    docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
+    data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
+    data = api.upload_kb_docs(files, **data)
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_list_files():
+    print("\n获取知识库中文件列表:")
+    data = api.list_kb_docs(knowledge_base_name=kb)
+    pprint(data)
+    assert isinstance(data, list)
+    for name in test_files:
+        assert name in data
+
+
+def test_search_docs():
+    query = "介绍一下langchain-chatchat项目"
+    print("\n检索知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_update_docs():
+    print(f"\n更新知识文件")
+    data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+
+def test_delete_docs():
+    print(f"\n删除知识文件")
+    data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
+    pprint(data)
+    assert data["code"] == 200
+    assert len(data["data"]["failed_files"]) == 0
+
+    query = "介绍一下langchain-chatchat项目"
+    print("\n尝试检索删除后的检索知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == 0
+
+
+def test_recreate_vs():
+    print("\n重建知识库:")
+    r = api.recreate_vector_store(kb)
+    for data in r:
+        assert isinstance(data, dict)
+        assert data["code"] == 200
+        print(data["msg"])
+
+    query = "本项目支持哪些文件格式?"
+    print("\n尝试检索重建后的检索知识库:")
+    print(query)
+    data = api.search_kb_docs(query, kb)
+    pprint(data)
+    assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
+
+
+def test_delete_kb_after():
+    print("\n删除知识库")
+    data = api.delete_knowledge_base(kb)
+    pprint(data)
+
+    # check kb not exists anymore
+    print("\n获取知识库列表:")
+    data = api.list_knowledge_bases()
+    pprint(data)
+    assert isinstance(data, list) and len(data) > 0
+    assert kb not in data
diff --git a/webui_pages/knowledge_base/knowledge_base.py b/webui_pages/knowledge_base/knowledge_base.py
index 0889ca5..17b35d4 100644
--- a/webui_pages/knowledge_base/knowledge_base.py
+++ b/webui_pages/knowledge_base/knowledge_base.py
@@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
 
         # 上传文件
         # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
-        files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
+        files = st.file_uploader("上传知识文件",
                                  [i for ls in LOADER_DICT.values() for i in ls],
                                  accept_multiple_files=True,
                                  )
@@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
                 # use_container_width=True,
                 disabled=len(files) == 0,
         ):
-            data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
-            data[-1]["not_refresh_vs_cache"]=False
-            for k in data:
-                ret = api.upload_kb_doc(**k)
-                if msg := check_success_msg(ret):
-                    st.toast(msg, icon="✔")
-                elif msg := check_error_msg(ret):
-                    st.toast(msg, icon="✖")
+            ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
+            if msg := check_success_msg(ret):
+                st.toast(msg, icon="✔")
+            elif msg := check_error_msg(ret):
+                st.toast(msg, icon="✖")
             st.session_state.files = []
 
         st.divider()
@@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest):
                     disabled=not file_exists(kb, selected_rows)[0],
                     use_container_width=True,
             ):
-                for row in selected_rows:
-                    api.update_kb_doc(kb, row["file_name"])
+                file_names = [row["file_name"] for row in selected_rows]
+                api.update_kb_docs(kb, file_names=file_names)
                 st.experimental_rerun()
 
             # 将文件从向量库中删除,但不删除文件本身。
@@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest):
                     disabled=not (selected_rows and selected_rows[0]["in_db"]),
                     use_container_width=True,
             ):
-                for row in selected_rows:
-                    api.delete_kb_doc(kb, row["file_name"])
+                file_names = [row["file_name"] for row in selected_rows]
+                api.delete_kb_docs(kb, file_names=file_names)
                 st.experimental_rerun()
 
             if cols[3].button(
@@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest):
                     type="primary",
                     use_container_width=True,
             ):
-                for row in selected_rows:
-                    ret = api.delete_kb_doc(kb, row["file_name"], True)
-                    st.toast(ret.get("msg", " "))
+                file_names = [row["file_name"] for row in selected_rows]
+                api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
                 st.experimental_rerun()
 
         st.divider()
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index e14df66..27843f8 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -509,13 +509,45 @@ class ApiRequest:
             data = self._check_httpx_json_response(response)
             return data.get("data", [])
 
+    def search_kb_docs(
+        self,
+        query: str,
+        knowledge_base_name: str,
+        top_k: int = VECTOR_SEARCH_TOP_K,
+        score_threshold: int = SCORE_THRESHOLD,
+        no_remote_api: bool = None,
+    ) -> List:
+        '''
+        对应api.py/knowledge_base/search_docs接口
+        '''
+        if no_remote_api is None:
+            no_remote_api = self.no_remote_api
+
+        data = {
+            "query": query,
+            "knowledge_base_name": knowledge_base_name,
+            "top_k": top_k,
+            "score_threshold": score_threshold,
+        }
+
+        if no_remote_api:
+            from server.knowledge_base.kb_doc_api import search_docs
+            return search_docs(**data)
+        else:
+            response = self.post(
+                "/knowledge_base/search_docs",
+                json=data,
+            )
+            data = self._check_httpx_json_response(response)
+            return data
+
     def upload_kb_docs(
         self,
         files: List[Union[str, Path, bytes]],
         knowledge_base_name: str,
         override: bool = False,
         to_vector_store: bool = True,
-        docs: List[Dict] = [],
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
@@ -525,97 +557,113 @@ class ApiRequest:
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
-        if isinstance(file, bytes): # raw bytes
-            file = BytesIO(file)
-        elif hasattr(file, "read"): # a file io like object
-            filename = filename or file.name
-        else: # a local path
-            file = Path(file).absolute().open("rb")
-            filename = filename or file.name
+        def convert_file(file, filename=None):
+            if isinstance(file, bytes): # raw bytes
+                file = BytesIO(file)
+            elif hasattr(file, "read"): # a file io like object
+                filename = filename or file.name
+            else: # a local path
+                file = Path(file).absolute().open("rb")
+                filename = filename or file.name
+            return filename, file
+
+        files = [convert_file(file) for file in files]
+        data={
+            "knowledge_base_name": knowledge_base_name,
+            "override": override,
+            "to_vector_store": to_vector_store,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }
 
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import upload_doc
+            from server.knowledge_base.kb_doc_api import upload_docs
             from fastapi import UploadFile
             from tempfile import SpooledTemporaryFile
 
-            temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
-            temp_file.write(file.read())
-            temp_file.seek(0)
-            response = run_async(upload_doc(
-                UploadFile(file=temp_file, filename=filename),
-                knowledge_base_name,
-                override,
-            ))
+            upload_files = []
+            for filename, file in files:  # convert_file() returns (filename, file), in that order
+                temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
+                temp_file.write(file.read())
+                temp_file.seek(0)
+                upload_files.append(UploadFile(file=temp_file, filename=filename))
+
+            response = run_async(upload_docs(upload_files, **data))
             return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/upload_doc",
-                data={
-                    "knowledge_base_name": knowledge_base_name,
-                    "override": override,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
-                files={"file": (filename, file)},
+                "/knowledge_base/upload_docs",
+                data=data,
+                files=[("files", (filename, file)) for filename, file in files],
             )
             return self._check_httpx_json_response(response)
 
-    def delete_kb_doc(
+    def delete_kb_docs(
         self,
         knowledge_base_name: str,
-        doc_name: str,
+        file_names: List[str],
         delete_content: bool = False,
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/delete_doc接口
+        对应api.py/knowledge_base/delete_docs接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
         data = {
             "knowledge_base_name": knowledge_base_name,
-            "doc_name": doc_name,
+            "file_names": file_names,
             "delete_content": delete_content,
             "not_refresh_vs_cache": not_refresh_vs_cache,
         }
 
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import delete_doc
-            response = run_async(delete_doc(**data))
+            from server.knowledge_base.kb_doc_api import delete_docs
+            response = run_async(delete_docs(**data))
             return response.dict()
         else:
             response = self.post(
-                "/knowledge_base/delete_doc",
+                "/knowledge_base/delete_docs",
                 json=data,
             )
             return self._check_httpx_json_response(response)
 
-    def update_kb_doc(
+    def update_kb_docs(
         self,
         knowledge_base_name: str,
-        file_name: str,
+        file_names: List[str],
+        override_custom_docs: bool = False,
+        docs: Dict = {},
         not_refresh_vs_cache: bool = False,
         no_remote_api: bool = None,
     ):
         '''
-        对应api.py/knowledge_base/update_doc接口
+        对应api.py/knowledge_base/update_docs接口
         '''
         if no_remote_api is None:
             no_remote_api = self.no_remote_api
 
+        data = {
+            "knowledge_base_name": knowledge_base_name,
+            "file_names": file_names,
+            "override_custom_docs": override_custom_docs,
+            "docs": docs,
+            "not_refresh_vs_cache": not_refresh_vs_cache,
+        }
         if no_remote_api:
-            from server.knowledge_base.kb_doc_api import update_doc
-            response = run_async(update_doc(knowledge_base_name, file_name))
+            from server.knowledge_base.kb_doc_api import update_docs
+            response = run_async(update_docs(**data))
             return response.dict()
         else:
+            if isinstance(data["docs"], dict):
+                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
             response = self.post(
-                "/knowledge_base/update_doc",
-                json={
-                    "knowledge_base_name": knowledge_base_name,
-                    "file_name": file_name,
-                    "not_refresh_vs_cache": not_refresh_vs_cache,
-                },
+                "/knowledge_base/update_docs",
+                json=data,
             )
             return self._check_httpx_json_response(response)
 