根据新的接口修改ApiRequest和webui,以及测试用例。修改后预期webui中批量知识文件相关操作减少时间
This commit is contained in:
parent
661a0e9d72
commit
4cfee9c17c
|
|
@ -8,13 +8,11 @@ sys.path.append(str(root_path))
|
|||
from server.utils import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
|
|
@ -24,6 +22,8 @@ test_files = {
|
|||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\n直接url访问\n")
|
||||
|
||||
|
||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
|
|
|
|||
|
|
@ -0,0 +1,161 @@
|
|||
import requests
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
root_path = Path(__file__).parent.parent.parent
|
||||
sys.path.append(str(root_path))
|
||||
from server.utils import api_address
|
||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||
from webui_pages.utils import ApiRequest
|
||||
|
||||
from pprint import pprint
|
||||
|
||||
|
||||
api_base_url = api_address()
|
||||
api: ApiRequest = ApiRequest(api_base_url)
|
||||
|
||||
|
||||
kb = "kb_for_api_test"
|
||||
test_files = {
|
||||
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
|
||||
"README.MD": str(root_path / "README.MD"),
|
||||
"test.txt": get_file_path("samples", "test.txt"),
|
||||
}
|
||||
|
||||
print("\n\nApiRquest调用\n")
|
||||
|
||||
|
||||
def test_delete_kb_before():
|
||||
if not Path(get_kb_path(kb)).exists():
|
||||
return
|
||||
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||
assert kb not in data["data"]
|
||||
|
||||
|
||||
def test_create_kb():
|
||||
print(f"\n尝试用空名称创建知识库:")
|
||||
data = api.create_knowledge_base(" ")
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||
|
||||
print(f"\n创建新知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert data["msg"] == f"已新增知识库 {kb}"
|
||||
|
||||
print(f"\n尝试创建同名知识库: {kb}")
|
||||
data = api.create_knowledge_base(kb)
|
||||
pprint(data)
|
||||
assert data["code"] == 404
|
||||
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||
|
||||
|
||||
def test_list_kbs():
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb in data
|
||||
|
||||
|
||||
def test_upload_docs():
|
||||
files = list(test_files.values())
|
||||
|
||||
print(f"\n上传知识文件")
|
||||
data = {"knowledge_base_name": kb, "override": True}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 不覆盖")
|
||||
data = {"knowledge_base_name": kb, "override": False}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == len(test_files)
|
||||
|
||||
print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
|
||||
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
|
||||
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
|
||||
data = api.upload_kb_docs(files, **data)
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_list_files():
|
||||
print("\n获取知识库中文件列表:")
|
||||
data = api.list_kb_docs(knowledge_base_name=kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list)
|
||||
for name in test_files:
|
||||
assert name in data
|
||||
|
||||
|
||||
def test_search_docs():
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_update_docs():
|
||||
print(f"\n更新知识文件")
|
||||
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
|
||||
def test_delete_docs():
|
||||
print(f"\n删除知识文件")
|
||||
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
|
||||
pprint(data)
|
||||
assert data["code"] == 200
|
||||
assert len(data["data"]["failed_files"]) == 0
|
||||
|
||||
query = "介绍一下langchain-chatchat项目"
|
||||
print("\n尝试检索删除后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == 0
|
||||
|
||||
|
||||
def test_recreate_vs():
|
||||
print("\n重建知识库:")
|
||||
r = api.recreate_vector_store(kb)
|
||||
for data in r:
|
||||
assert isinstance(data, dict)
|
||||
assert data["code"] == 200
|
||||
print(data["msg"])
|
||||
|
||||
query = "本项目支持哪些文件格式?"
|
||||
print("\n尝试检索重建后的检索知识库:")
|
||||
print(query)
|
||||
data = api.search_kb_docs(query, kb)
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||
|
||||
|
||||
def test_delete_kb_after():
|
||||
print("\n删除知识库")
|
||||
data = api.delete_knowledge_base(kb)
|
||||
pprint(data)
|
||||
|
||||
# check kb not exists anymore
|
||||
print("\n获取知识库列表:")
|
||||
data = api.list_knowledge_bases()
|
||||
pprint(data)
|
||||
assert isinstance(data, list) and len(data) > 0
|
||||
assert kb not in data
|
||||
|
|
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
|
||||
# 上传文件
|
||||
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
||||
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
|
||||
files = st.file_uploader("上传知识文件",
|
||||
[i for ls in LOADER_DICT.values() for i in ls],
|
||||
accept_multiple_files=True,
|
||||
)
|
||||
|
|
@ -138,10 +138,7 @@ def knowledge_base_page(api: ApiRequest):
|
|||
# use_container_width=True,
|
||||
disabled=len(files) == 0,
|
||||
):
|
||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
||||
data[-1]["not_refresh_vs_cache"]=False
|
||||
for k in data:
|
||||
ret = api.upload_kb_doc(**k)
|
||||
ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
|
||||
if msg := check_success_msg(ret):
|
||||
st.toast(msg, icon="✔")
|
||||
elif msg := check_error_msg(ret):
|
||||
|
|
@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
disabled=not file_exists(kb, selected_rows)[0],
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
api.update_kb_doc(kb, row["file_name"])
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.update_kb_docs(kb, file_names=file_names)
|
||||
st.experimental_rerun()
|
||||
|
||||
# 将文件从向量库中删除,但不删除文件本身。
|
||||
|
|
@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
disabled=not (selected_rows and selected_rows[0]["in_db"]),
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
api.delete_kb_doc(kb, row["file_name"])
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names)
|
||||
st.experimental_rerun()
|
||||
|
||||
if cols[3].button(
|
||||
|
|
@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest):
|
|||
type="primary",
|
||||
use_container_width=True,
|
||||
):
|
||||
for row in selected_rows:
|
||||
ret = api.delete_kb_doc(kb, row["file_name"], True)
|
||||
st.toast(ret.get("msg", " "))
|
||||
file_names = [row["file_name"] for row in selected_rows]
|
||||
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
|
||||
st.experimental_rerun()
|
||||
|
||||
st.divider()
|
||||
|
|
|
|||
|
|
@ -509,13 +509,45 @@ class ApiRequest:
|
|||
data = self._check_httpx_json_response(response)
|
||||
return data.get("data", [])
|
||||
|
||||
def search_kb_docs(
|
||||
self,
|
||||
query: str,
|
||||
knowledge_base_name: str,
|
||||
top_k: int = VECTOR_SEARCH_TOP_K,
|
||||
score_threshold: int = SCORE_THRESHOLD,
|
||||
no_remote_api: bool = None,
|
||||
) -> List:
|
||||
'''
|
||||
对应api.py/knowledge_base/search_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"query": query,
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"top_k": top_k,
|
||||
"score_threshold": score_threshold,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import search_docs
|
||||
return search_docs(**data)
|
||||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/search_docs",
|
||||
json=data,
|
||||
)
|
||||
data = self._check_httpx_json_response(response)
|
||||
return data
|
||||
|
||||
def upload_kb_docs(
|
||||
self,
|
||||
files: List[Union[str, Path, bytes]],
|
||||
knowledge_base_name: str,
|
||||
override: bool = False,
|
||||
to_vector_store: bool = True,
|
||||
docs: List[Dict] = [],
|
||||
docs: Dict = {},
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
|
|
@ -525,6 +557,7 @@ class ApiRequest:
|
|||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
def convert_file(file, filename=None):
|
||||
if isinstance(file, bytes): # raw bytes
|
||||
file = BytesIO(file)
|
||||
elif hasattr(file, "read"): # a file io like object
|
||||
|
|
@ -532,90 +565,105 @@ class ApiRequest:
|
|||
else: # a local path
|
||||
file = Path(file).absolute().open("rb")
|
||||
filename = filename or file.name
|
||||
return filename, file
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import upload_doc
|
||||
from fastapi import UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
response = run_async(upload_doc(
|
||||
UploadFile(file=temp_file, filename=filename),
|
||||
knowledge_base_name,
|
||||
override,
|
||||
))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/upload_doc",
|
||||
files = [convert_file(file) for file in files]
|
||||
data={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"override": override,
|
||||
"to_vector_store": to_vector_store,
|
||||
"docs": docs,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
files={"file": (filename, file)},
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import upload_docs
|
||||
from fastapi import UploadFile
|
||||
from tempfile import SpooledTemporaryFile
|
||||
|
||||
upload_files = []
|
||||
for file, filename in files:
|
||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||
temp_file.write(file.read())
|
||||
temp_file.seek(0)
|
||||
upload_files.append(UploadFile(file=temp_file, filename=filename))
|
||||
|
||||
response = run_async(upload_docs(upload_files, **data))
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||
response = self.post(
|
||||
"/knowledge_base/upload_docs",
|
||||
data=data,
|
||||
files=[("files", (filename, file)) for filename, file in files],
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def delete_kb_doc(
|
||||
def delete_kb_docs(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
doc_name: str,
|
||||
file_names: List[str],
|
||||
delete_content: bool = False,
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/delete_doc接口
|
||||
对应api.py/knowledge_base/delete_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"doc_name": doc_name,
|
||||
"file_names": file_names,
|
||||
"delete_content": delete_content,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import delete_doc
|
||||
response = run_async(delete_doc(**data))
|
||||
from server.knowledge_base.kb_doc_api import delete_docs
|
||||
response = run_async(delete_docs(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
response = self.post(
|
||||
"/knowledge_base/delete_doc",
|
||||
"/knowledge_base/delete_docs",
|
||||
json=data,
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
def update_kb_doc(
|
||||
def update_kb_docs(
|
||||
self,
|
||||
knowledge_base_name: str,
|
||||
file_name: str,
|
||||
file_names: List[str],
|
||||
override_custom_docs: bool = False,
|
||||
docs: Dict = {},
|
||||
not_refresh_vs_cache: bool = False,
|
||||
no_remote_api: bool = None,
|
||||
):
|
||||
'''
|
||||
对应api.py/knowledge_base/update_doc接口
|
||||
对应api.py/knowledge_base/update_docs接口
|
||||
'''
|
||||
if no_remote_api is None:
|
||||
no_remote_api = self.no_remote_api
|
||||
|
||||
data = {
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_names": file_names,
|
||||
"override_custom_docs": override_custom_docs,
|
||||
"docs": docs,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
}
|
||||
if no_remote_api:
|
||||
from server.knowledge_base.kb_doc_api import update_doc
|
||||
response = run_async(update_doc(knowledge_base_name, file_name))
|
||||
from server.knowledge_base.kb_doc_api import update_docs
|
||||
response = run_async(update_docs(**data))
|
||||
return response.dict()
|
||||
else:
|
||||
if isinstance(data["docs"], dict):
|
||||
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||
response = self.post(
|
||||
"/knowledge_base/update_doc",
|
||||
json={
|
||||
"knowledge_base_name": knowledge_base_name,
|
||||
"file_name": file_name,
|
||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||
},
|
||||
"/knowledge_base/update_docs",
|
||||
json=data,
|
||||
)
|
||||
return self._check_httpx_json_response(response)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue