根据新的接口修改 ApiRequest、webui 以及测试用例。修改后,预期 webui 中批量知识文件相关操作的耗时会减少。

This commit is contained in:
liunux4odoo 2023-09-08 10:22:04 +08:00
parent 661a0e9d72
commit 4cfee9c17c
4 changed files with 265 additions and 60 deletions

View File

@ -8,13 +8,11 @@ sys.path.append(str(root_path))
from server.utils import api_address from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint from pprint import pprint
api_base_url = api_address() api_base_url = api_address()
api = ApiRequest(api_base_url)
kb = "kb_for_api_test" kb = "kb_for_api_test"
@ -24,6 +22,8 @@ test_files = {
"test.txt": get_file_path("samples", "test.txt"), "test.txt": get_file_path("samples", "test.txt"),
} }
print("\n\n直接url访问\n")
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
if not Path(get_kb_path(kb)).exists(): if not Path(get_kb_path(kb)).exists():

View File

@ -0,0 +1,161 @@
import requests
import json
import sys
from pathlib import Path
root_path = Path(__file__).parent.parent.parent
sys.path.append(str(root_path))
from server.utils import api_address
from configs.model_config import VECTOR_SEARCH_TOP_K
from server.knowledge_base.utils import get_kb_path, get_file_path
from webui_pages.utils import ApiRequest
from pprint import pprint
# Base URL of the API server, as returned by server.utils.api_address().
api_base_url = api_address()
# Client object through which every test below talks to the API.
api: ApiRequest = ApiRequest(api_base_url)
# Name of the knowledge base that the tests create, fill and delete again.
kb = "kb_for_api_test"
# Maps each file name to the local path that test_upload_docs uploads.
test_files = {
"FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
"README.MD": str(root_path / "README.MD"),
"test.txt": get_file_path("samples", "test.txt"),
}
# Banner printed once at import time to mark the output of this module.
print("\n\nApiRquest调用\n")
def test_delete_kb_before():
    """Delete the test knowledge base if its directory already exists.

    Does nothing when the directory is absent. Otherwise the delete call
    must answer with code 200 and a non-empty list that no longer
    contains the test knowledge base.
    """
    kb_dir = Path(get_kb_path(kb))
    if kb_dir.exists():
        resp = api.delete_knowledge_base(kb)
        pprint(resp)
        assert resp["code"] == 200
        assert isinstance(resp["data"], list) and len(resp["data"]) > 0
        assert kb not in resp["data"]
def test_create_kb():
    """Check knowledge base creation with a blank, a new and a duplicate name."""
    # A blank name must be refused.
    print(f"\n尝试用空名称创建知识库:")
    blank_resp = api.create_knowledge_base(" ")
    pprint(blank_resp)
    assert blank_resp["code"] == 404
    assert blank_resp["msg"] == "知识库名称不能为空,请重新填写知识库名称"

    # A fresh name must be accepted.
    print(f"\n创建新知识库: {kb}")
    created_resp = api.create_knowledge_base(kb)
    pprint(created_resp)
    assert created_resp["code"] == 200
    assert created_resp["msg"] == f"已新增知识库 {kb}"

    # Creating the same name a second time must be refused.
    print(f"\n尝试创建同名知识库: {kb}")
    duplicate_resp = api.create_knowledge_base(kb)
    pprint(duplicate_resp)
    assert duplicate_resp["code"] == 404
    assert duplicate_resp["msg"] == f"已存在同名知识库 {kb}"
def test_list_kbs():
    """The knowledge base listing must be a non-empty list containing the test one."""
    names = api.list_knowledge_bases()
    pprint(names)
    assert isinstance(names, list) and len(names) > 0
    assert kb in names
def test_upload_docs():
    """Upload the sample files three times with different options.

    1. With override: no file may fail.
    2. Without override: every file must be reported as failed.
    3. With override and custom docs for FAQ.MD: no file may fail.
    """
    paths = list(test_files.values())

    print(f"\n上传知识文件")
    resp = api.upload_kb_docs(paths, knowledge_base_name=kb, override=True)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0

    print(f"\n尝试重新上传知识文件, 不覆盖")
    resp = api.upload_kb_docs(paths, knowledge_base_name=kb, override=False)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == len(test_files)

    print(f"\n尝试重新上传知识文件, 覆盖自定义docs")
    custom_docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
    resp = api.upload_kb_docs(
        paths,
        knowledge_base_name=kb,
        override=True,
        docs=json.dumps(custom_docs),  # passed as a JSON string, not a dict
    )
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0
def test_list_files():
    """Every uploaded file name must appear in the knowledge base's file listing."""
    print("\n获取知识库中文件列表:")
    listed = api.list_kb_docs(knowledge_base_name=kb)
    pprint(listed)
    assert isinstance(listed, list)
    for file_name in test_files:
        assert file_name in listed
def test_search_docs():
    """A search in the filled knowledge base must return VECTOR_SEARCH_TOP_K hits."""
    question = "介绍一下langchain-chatchat项目"
    print("\n检索知识库:")
    print(question)
    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_update_docs():
    """Updating all sample files in one call must succeed without failed files."""
    print(f"\n更新知识文件")
    names = list(test_files)
    resp = api.update_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0
def test_delete_docs():
    """Delete all sample files in one call, then verify that searching finds nothing."""
    print(f"\n删除知识文件")
    names = list(test_files)
    resp = api.delete_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0

    # With every file removed, the same query must come back empty.
    question = "介绍一下langchain-chatchat项目"
    print("\n尝试检索删除后的检索知识库:")
    print(question)
    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == 0
def test_recreate_vs():
    """Rebuild the vector store and verify that searching returns hits again."""
    print("\n重建知识库:")
    # recreate_vector_store yields one dict per step; each must report code 200.
    for step in api.recreate_vector_store(kb):
        assert isinstance(step, dict)
        assert step["code"] == 200
        print(step["msg"])

    question = "本项目支持哪些文件格式?"
    print("\n尝试检索重建后的检索知识库:")
    print(question)
    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
def test_delete_kb_after():
    """Delete the test knowledge base and verify it is gone from the listing."""
    print("\n删除知识库")
    delete_resp = api.delete_knowledge_base(kb)
    pprint(delete_resp)

    # check kb not exists anymore
    print("\n获取知识库列表:")
    names = api.list_knowledge_bases()
    pprint(names)
    assert isinstance(names, list) and len(names) > 0
    assert kb not in names

View File

@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
# 上传文件 # 上传文件
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True) # sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)", files = st.file_uploader("上传知识文件",
[i for ls in LOADER_DICT.values() for i in ls], [i for ls in LOADER_DICT.values() for i in ls],
accept_multiple_files=True, accept_multiple_files=True,
) )
@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
# use_container_width=True, # use_container_width=True,
disabled=len(files) == 0, disabled=len(files) == 0,
): ):
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files] ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
data[-1]["not_refresh_vs_cache"]=False if msg := check_success_msg(ret):
for k in data: st.toast(msg, icon="")
ret = api.upload_kb_doc(**k) elif msg := check_error_msg(ret):
if msg := check_success_msg(ret): st.toast(msg, icon="")
st.toast(msg, icon="")
elif msg := check_error_msg(ret):
st.toast(msg, icon="")
st.session_state.files = [] st.session_state.files = []
st.divider() st.divider()
@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest):
disabled=not file_exists(kb, selected_rows)[0], disabled=not file_exists(kb, selected_rows)[0],
use_container_width=True, use_container_width=True,
): ):
for row in selected_rows: file_names = [row["file_name"] for row in selected_rows]
api.update_kb_doc(kb, row["file_name"]) api.update_kb_docs(kb, file_names=file_names)
st.experimental_rerun() st.experimental_rerun()
# 将文件从向量库中删除,但不删除文件本身。 # 将文件从向量库中删除,但不删除文件本身。
@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest):
disabled=not (selected_rows and selected_rows[0]["in_db"]), disabled=not (selected_rows and selected_rows[0]["in_db"]),
use_container_width=True, use_container_width=True,
): ):
for row in selected_rows: file_names = [row["file_name"] for row in selected_rows]
api.delete_kb_doc(kb, row["file_name"]) api.delete_kb_docs(kb, file_names=file_names)
st.experimental_rerun() st.experimental_rerun()
if cols[3].button( if cols[3].button(
@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest):
type="primary", type="primary",
use_container_width=True, use_container_width=True,
): ):
for row in selected_rows: file_names = [row["file_name"] for row in selected_rows]
ret = api.delete_kb_doc(kb, row["file_name"], True) api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
st.toast(ret.get("msg", " "))
st.experimental_rerun() st.experimental_rerun()
st.divider() st.divider()

View File

@ -509,13 +509,45 @@ class ApiRequest:
data = self._check_httpx_json_response(response) data = self._check_httpx_json_response(response)
return data.get("data", []) return data.get("data", [])
def search_kb_docs(
    self,
    query: str,
    knowledge_base_name: str,
    top_k: int = VECTOR_SEARCH_TOP_K,
    score_threshold: int = SCORE_THRESHOLD,
    no_remote_api: bool = None,
) -> List:
    '''
    Counterpart of the /knowledge_base/search_docs endpoint in api.py.

    When no_remote_api is None, self.no_remote_api decides the mode.
    In local mode the server-side search_docs function is called directly;
    otherwise the arguments are posted as JSON and the checked response
    is returned.
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    payload = {
        "query": query,
        "knowledge_base_name": knowledge_base_name,
        "top_k": top_k,
        "score_threshold": score_threshold,
    }

    if use_local:
        from server.knowledge_base.kb_doc_api import search_docs
        return search_docs(**payload)

    response = self.post("/knowledge_base/search_docs", json=payload)
    return self._check_httpx_json_response(response)
def upload_kb_docs(
    self,
    files: List[Union[str, Path, bytes]],
    knowledge_base_name: str,
    override: bool = False,
    to_vector_store: bool = True,
    docs: Dict = None,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
):
    '''
    Upload several files in one call; counterpart of the
    /knowledge_base/upload_docs endpoint in api.py.

    files: each item is raw bytes, a file-like object (anything with
        ``read``), or a local path.
    docs: custom documents keyed by file name; ``None`` means ``{}``.
        In remote mode a dict is serialised to a JSON string, any other
        value is sent unchanged.
    no_remote_api: ``None`` falls back to ``self.no_remote_api``.

    Returns the ``dict()`` of the server response in local mode, or the
    result of ``self._check_httpx_json_response`` in remote mode.
    '''
    if no_remote_api is None:
        no_remote_api = self.no_remote_api
    if docs is None:
        # A fresh dict per call instead of one shared mutable default.
        docs = {}

    # Handles this method opened from local paths; closed before returning.
    opened_here = []

    def convert_file(file, filename=None):
        if isinstance(file, bytes):  # raw bytes
            # NOTE(review): raw bytes carry no name, so filename stays None.
            file = BytesIO(file)
        elif hasattr(file, "read"):  # a file io like object
            filename = filename or file.name
        else:  # a local path
            file = Path(file).absolute().open("rb")
            opened_here.append(file)
            # file.name is the absolute path of the opened handle; only
            # the base name is used as the upload file name.
            filename = filename or Path(file.name).name
        return filename, file

    data = {
        "knowledge_base_name": knowledge_base_name,
        "override": override,
        "to_vector_store": to_vector_store,
        "docs": docs,
        "not_refresh_vs_cache": not_refresh_vs_cache,
    }

    try:
        files = [convert_file(file) for file in files]

        if no_remote_api:
            from server.knowledge_base.kb_doc_api import upload_docs
            from fastapi import UploadFile
            from tempfile import SpooledTemporaryFile

            upload_files = []
            # convert_file returns (filename, file): unpack in that order.
            for filename, file in files:
                temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
                temp_file.write(file.read())
                temp_file.seek(0)
                upload_files.append(UploadFile(file=temp_file, filename=filename))

            response = run_async(upload_docs(upload_files, **data))
            return response.dict()
        else:
            if isinstance(data["docs"], dict):
                data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
            response = self.post(
                "/knowledge_base/upload_docs",
                data=data,
                files=[("files", (filename, file)) for filename, file in files],
            )
            return self._check_httpx_json_response(response)
    finally:
        for handle in opened_here:
            handle.close()
def delete_kb_docs(
    self,
    knowledge_base_name: str,
    file_names: List[str],
    delete_content: bool = False,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
):
    '''
    Counterpart of the /knowledge_base/delete_docs endpoint in api.py.

    When no_remote_api is None, self.no_remote_api decides the mode.
    In local mode the server-side delete_docs coroutine is run and the
    ``dict()`` of its response returned; otherwise the arguments are
    posted as JSON and the checked response is returned.
    '''
    use_local = self.no_remote_api if no_remote_api is None else no_remote_api

    payload = {
        "knowledge_base_name": knowledge_base_name,
        "file_names": file_names,
        "delete_content": delete_content,
        "not_refresh_vs_cache": not_refresh_vs_cache,
    }

    if use_local:
        from server.knowledge_base.kb_doc_api import delete_docs
        result = run_async(delete_docs(**payload))
        return result.dict()

    response = self.post("/knowledge_base/delete_docs", json=payload)
    return self._check_httpx_json_response(response)
def update_kb_docs(
    self,
    knowledge_base_name: str,
    file_names: List[str],
    override_custom_docs: bool = False,
    docs: Dict = None,
    not_refresh_vs_cache: bool = False,
    no_remote_api: bool = None,
):
    '''
    Update several knowledge files in one call; counterpart of the
    /knowledge_base/update_docs endpoint in api.py.

    docs: custom documents keyed by file name; ``None`` means ``{}``.
        In remote mode a dict is serialised to a JSON string, any other
        value is sent unchanged.
    no_remote_api: ``None`` falls back to ``self.no_remote_api``.

    Returns the ``dict()`` of the server response in local mode, or the
    result of ``self._check_httpx_json_response`` in remote mode.
    '''
    if no_remote_api is None:
        no_remote_api = self.no_remote_api
    if docs is None:
        # A fresh dict per call: in local mode the object is handed to
        # server code, so one shared default could leak between calls.
        docs = {}

    data = {
        "knowledge_base_name": knowledge_base_name,
        "file_names": file_names,
        "override_custom_docs": override_custom_docs,
        "docs": docs,
        "not_refresh_vs_cache": not_refresh_vs_cache,
    }

    if no_remote_api:
        from server.knowledge_base.kb_doc_api import update_docs
        response = run_async(update_docs(**data))
        return response.dict()
    else:
        if isinstance(data["docs"], dict):
            data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
        response = self.post(
            "/knowledge_base/update_docs",
            json=data,
        )
        return self._check_httpx_json_response(response)