根据新的接口修改ApiRequest和webui,以及测试用例。修改后预期webui中批量知识文件相关操作减少时间

This commit is contained in:
liunux4odoo 2023-09-08 10:22:04 +08:00
parent 661a0e9d72
commit 4cfee9c17c
4 changed files with 265 additions and 60 deletions

View File

@ -8,13 +8,11 @@ sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
api_base_url = api_address()
api = ApiRequest(api_base_url)
kb = "kb_for_api_test"
@ -24,6 +22,8 @@ test_files = {
"test.txt": get_file_path("samples", "test.txt"),
}
print("\n\n直接url访问\n")
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists():

View File

@ -0,0 +1,161 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url)
kb = "kb_for_api_test"
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"test.txt": get_file_path("samples", "test.txt"),
}
print("\n\nApiRquest调用\n")
def test_delete_kb_before():
if not Path(get_kb_path(kb)).exists():
return
data = api.delete_knowledge_base(kb)
pprint(data)
assert data["code"] == 200
assert isinstance(data["data"], list) and len(data["data"]) > 0
assert kb not in data["data"]
def test_create_kb():
print(f"\n尝试用空名称创建知识库:")
data = api.create_knowledge_base(" ")
pprint(data)
assert data["code"] == 404
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
print(f"\n创建新知识库: {kb}")
data = api.create_knowledge_base(kb)
pprint(data)
assert data["code"] == 200
assert data["msg"] == f"已新增知识库 {kb}"
print(f"\n尝试创建同名知识库: {kb}")
data = api.create_knowledge_base(kb)
pprint(data)
assert data["code"] == 404
assert data["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs():
data = api.list_knowledge_bases()
pprint(data)
assert isinstance(data, list) and len(data) > 0
assert kb in data
def test_upload_docs():
files = list(test_files.values())
print(f"\n上传知识文件")
data = {"knowledge_base_name": kb, "override": True}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
print(f"\n尝试重新上传知识文件, 不覆盖")
data = {"knowledge_base_name": kb, "override": False}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == len(test_files)
print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
data = {"knowledge_base_name": kb, "override": True, "docs": json.dumps(docs)}
data = api.upload_kb_docs(files, **data)
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_list_files():
print("\n获取知识库中文件列表:")
data = api.list_kb_docs(knowledge_base_name=kb)
pprint(data)
assert isinstance(data, list)
for name in test_files:
assert name in data
def test_search_docs():
query = "介绍一下langchain-chatchat项目"
print("\n检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_update_docs():
print(f"\n更新知识文件")
data = api.update_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
def test_delete_docs():
print(f"\n删除知识文件")
data = api.delete_kb_docs(knowledge_base_name=kb, file_names=list(test_files))
pprint(data)
assert data["code"] == 200
assert len(data["data"]["failed_files"]) == 0
query = "介绍一下langchain-chatchat项目"
print("\n尝试检索删除后的检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == 0
def test_recreate_vs():
print("\n重建知识库:")
r = api.recreate_vector_store(kb)
for data in r:
assert isinstance(data, dict)
assert data["code"] == 200
print(data["msg"])
query = "本项目支持哪些文件格式?"
print("\n尝试检索重建后的检索知识库:")
print(query)
data = api.search_kb_docs(query, kb)
pprint(data)
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after():
print("\n删除知识库")
data = api.delete_knowledge_base(kb)
pprint(data)
# check kb not exists anymore
print("\n获取知识库列表:")
data = api.list_knowledge_bases()
pprint(data)
assert isinstance(data, list) and len(data) > 0
assert kb not in data

View File

@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
# 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
files = st.file_uploader("上传知识文件",
[i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True,
)
@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True,
disabled=len(files) == 0,
):
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
data[-1]["not_refresh_vs_cache"]=False
for k in data:
ret = api.upload_kb_doc(**k)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
if msg := check_success_msg(ret):
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = []
st.divider()
@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest):
disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True,
):
for row in selected_rows:
api.update_kb_doc(kb, row["file_name"])
file_names = [row["file_name"] for row in selected_rows]
api.update_kb_docs(kb, file_names=file_names)
st.experimental_rerun()
# 将文件从向量库中删除,但不删除文件本身。
@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest):
disabled=not (selected_rows and selected_rows[0]["in_db"]),
use_container_width=True,
):
for row in selected_rows:
api.delete_kb_doc(kb, row["file_name"])
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names)
st.experimental_rerun()
if cols[3].button(
@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest):
type="primary",
use_container_width=True,
):
for row in selected_rows:
ret = api.delete_kb_doc(kb, row["file_name"], True)
st.toast(ret.get("msg", " "))
file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
st.experimental_rerun()
st.divider()

View File

@ -509,13 +509,45 @@ class ApiRequest:
data = self._check_httpx_json_response(response)
return data.get("data", [])
def search_kb_docs(
self,
query: str,
knowledge_base_name: str,
top_k: int = VECTOR_SEARCH_TOP_K,
score_threshold: int = SCORE_THRESHOLD,
no_remote_api: bool = None,
) -> List:
'''
对应api.py/knowledge_base/search_docs接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"query": query,
"knowledge_base_name": knowledge_base_name,
"top_k": top_k,
"score_threshold": score_threshold,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import search_docs
return search_docs(**data)
else:
response = self.post(
"/knowledge_base/search_docs",
json=data,
)
data = self._check_httpx_json_response(response)
return data
def upload_kb_docs(
self,
files: List[Union[str, Path, bytes]],
knowledge_base_name: str,
override: bool = False,
to_vector_store: bool = True,
docs: List[Dict] = [],
docs: Dict = {},
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
@ -525,97 +557,113 @@ class ApiRequest:
if no_remote_api is None:
no_remote_api = self.no_remote_api
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or file.name
def convert_file(file, filename=None):
if isinstance(file, bytes): # raw bytes
file = BytesIO(file)
elif hasattr(file, "read"): # a file io like object
filename = filename or file.name
else: # a local path
file = Path(file).absolute().open("rb")
filename = filename or file.name
return filename, file
files = [convert_file(file) for file in files]
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"to_vector_store": to_vector_store,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import upload_doc
from server.knowledge_base.kb_doc_api import upload_docs
from fastapi import UploadFile
from tempfile import SpooledTemporaryFile
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
response = run_async(upload_doc(
UploadFile(file=temp_file, filename=filename),
knowledge_base_name,
override,
))
upload_files = []
for file, filename in files:
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
temp_file.write(file.read())
temp_file.seek(0)
upload_files.append(UploadFile(file=temp_file, filename=filename))
response = run_async(upload_docs(upload_files, **data))
return response.dict()
else:
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/upload_doc",
data={
"knowledge_base_name": knowledge_base_name,
"override": override,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
files={"file": (filename, file)},
"/knowledge_base/upload_docs",
data=data,
files=[("files", (filename, file)) for filename, file in files],
)
return self._check_httpx_json_response(response)
def delete_kb_doc(
def delete_kb_docs(
self,
knowledge_base_name: str,
doc_name: str,
file_names: List[str],
delete_content: bool = False,
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/delete_doc接口
对应api.py/knowledge_base/delete_docs接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"knowledge_base_name": knowledge_base_name,
"doc_name": doc_name,
"file_names": file_names,
"delete_content": delete_content,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import delete_doc
response = run_async(delete_doc(**data))
from server.knowledge_base.kb_doc_api import delete_docs
response = run_async(delete_docs(**data))
return response.dict()
else:
response = self.post(
"/knowledge_base/delete_doc",
"/knowledge_base/delete_docs",
json=data,
)
return self._check_httpx_json_response(response)
def update_kb_doc(
def update_kb_docs(
self,
knowledge_base_name: str,
file_name: str,
file_names: List[str],
override_custom_docs: bool = False,
docs: Dict = {},
not_refresh_vs_cache: bool = False,
no_remote_api: bool = None,
):
'''
对应api.py/knowledge_base/update_doc接口
对应api.py/knowledge_base/update_docs接口
'''
if no_remote_api is None:
no_remote_api = self.no_remote_api
data = {
"knowledge_base_name": knowledge_base_name,
"file_names": file_names,
"override_custom_docs": override_custom_docs,
"docs": docs,
"not_refresh_vs_cache": not_refresh_vs_cache,
}
if no_remote_api:
from server.knowledge_base.kb_doc_api import update_doc
response = run_async(update_doc(knowledge_base_name, file_name))
from server.knowledge_base.kb_doc_api import update_docs
response = run_async(update_docs(**data))
return response.dict()
else:
if isinstance(data["docs"], dict):
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
response = self.post(
"/knowledge_base/update_doc",
json={
"knowledge_base_name": knowledge_base_name,
"file_name": file_name,
"not_refresh_vs_cache": not_refresh_vs_cache,
},
"/knowledge_base/update_docs",
json=data,
)
return self._check_httpx_json_response(response)