根据新的接口修改ApiRequest和webui,以及测试用例。修改后预期webui中批量知识文件相关操作减少时间
This commit is contained in:
parent
661a0e9d72
commit
4cfee9c17c
|
|
@ -8,13 +8,11 @@ sys.path.append(str(root_path))
|
||||||
from server.utils import api_address
|
from server.utils import api_address
|
||||||
from configs.model_config import VECTOR_SEARCH_TOP_K
|
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||||
from server.knowledge_base.utils import get_kb_path, get_file_path
|
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||||
from webui_pages.utils import ApiRequest
|
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
api_base_url = api_address()
|
api_base_url = api_address()
|
||||||
api = ApiRequest(api_base_url)
|
|
||||||
|
|
||||||
|
|
||||||
kb = "kb_for_api_test"
|
kb = "kb_for_api_test"
|
||||||
|
|
@ -24,6 +22,8 @@ test_files = {
|
||||||
"test.txt": get_file_path("samples", "test.txt"),
|
"test.txt": get_file_path("samples", "test.txt"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print("\n\n直接url访问\n")
|
||||||
|
|
||||||
|
|
||||||
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||||
if not Path(get_kb_path(kb)).exists():
|
if not Path(get_kb_path(kb)).exists():
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,161 @@
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root_path = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(root_path))
|
||||||
|
from server.utils import api_address
|
||||||
|
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||||
|
from server.knowledge_base.utils import get_kb_path, get_file_path
|
||||||
|
from webui_pages.utils import ApiRequest
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
|
# Client under test, constructed from the address reported by api_address().
api_base_url = api_address()
api: ApiRequest = ApiRequest(api_base_url)

# Name of the knowledge base that the tests below create and delete.
kb = "kb_for_api_test"

# File name -> local path of each document uploaded by the tests.
test_files = {
    "FAQ.MD": str(root_path / "docs" / "FAQ.MD"),
    "README.MD": str(root_path / "README.MD"),
    "test.txt": get_file_path("samples", "test.txt"),
}

# Banner printed at import time; the class name was misspelled as "ApiRquest".
print("\n\nApiRequest调用\n")
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_kb_before():
    """Delete the test knowledge base if a previous run left it on disk.

    Returns immediately when the knowledge base directory does not exist.
    Otherwise the delete response must have code 200 and a non-empty
    ``data`` list that no longer contains the test knowledge base.
    """
    kb_dir = Path(get_kb_path(kb))
    if kb_dir.exists():
        resp = api.delete_knowledge_base(kb)
        pprint(resp)
        assert resp["code"] == 200
        assert isinstance(resp["data"], list) and len(resp["data"]) > 0
        assert kb not in resp["data"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_kb():
    """Check knowledge base creation for a blank, a new and a duplicate name.

    A blank name and a duplicate name must both be rejected with code 404
    and their specific messages; the first creation with the real name must
    succeed with code 200.
    """
    # Case 1: a whitespace-only name is rejected.
    print(f"\n尝试用空名称创建知识库:")
    blank_resp = api.create_knowledge_base(" ")
    pprint(blank_resp)
    assert blank_resp["code"] == 404
    assert blank_resp["msg"] == "知识库名称不能为空,请重新填写知识库名称"

    # Case 2: creating the test knowledge base succeeds.
    print(f"\n创建新知识库: {kb}")
    created_resp = api.create_knowledge_base(kb)
    pprint(created_resp)
    assert created_resp["code"] == 200
    assert created_resp["msg"] == f"已新增知识库 {kb}"

    # Case 3: creating it a second time is rejected as a duplicate.
    print(f"\n尝试创建同名知识库: {kb}")
    duplicate_resp = api.create_knowledge_base(kb)
    pprint(duplicate_resp)
    assert duplicate_resp["code"] == 404
    assert duplicate_resp["msg"] == f"已存在同名知识库 {kb}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_kbs():
    """Check that the knowledge base listing is a non-empty list containing the test one."""
    kb_names = api.list_knowledge_bases()
    pprint(kb_names)
    assert isinstance(kb_names, list) and len(kb_names) > 0
    assert kb in kb_names
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_docs():
    """Upload the test files three ways and check the reported failures.

    1. Upload with override: no file may fail.
    2. Upload again without override: every file must be reported as failed.
    3. Upload with override and custom docs for ``FAQ.MD``: no file may fail.
    """
    paths = list(test_files.values())

    print(f"\n上传知识文件")
    first_resp = api.upload_kb_docs(paths, knowledge_base_name=kb, override=True)
    pprint(first_resp)
    assert first_resp["code"] == 200
    assert len(first_resp["data"]["failed_files"]) == 0

    print(f"\n尝试重新上传知识文件, 不覆盖")
    second_resp = api.upload_kb_docs(paths, knowledge_base_name=kb, override=False)
    pprint(second_resp)
    assert second_resp["code"] == 200
    assert len(second_resp["data"]["failed_files"]) == len(test_files)

    print(f"\n尝试重新上传知识文件, 覆盖,自定义docs")
    custom_docs = {"FAQ.MD": [{"page_content": "custom docs", "metadata": {}}]}
    # The custom docs are passed as a JSON string, not as a dict.
    third_resp = api.upload_kb_docs(
        paths,
        knowledge_base_name=kb,
        override=True,
        docs=json.dumps(custom_docs),
    )
    pprint(third_resp)
    assert third_resp["code"] == 200
    assert len(third_resp["data"]["failed_files"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_files():
    """Check that every uploaded test file appears in the knowledge base's file list."""
    print("\n获取知识库中文件列表:")
    listed = api.list_kb_docs(knowledge_base_name=kb)
    pprint(listed)
    assert isinstance(listed, list)
    for file_name in test_files:
        assert file_name in listed
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_docs():
    """Search the test knowledge base and expect exactly VECTOR_SEARCH_TOP_K hits."""
    question = "介绍一下langchain-chatchat项目"
    print("\n检索知识库:")
    print(question)

    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_docs():
    """Update all test files in one batch call and expect no failed files."""
    print(f"\n更新知识文件")
    names = list(test_files)
    resp = api.update_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_docs():
    """Delete all test files in one batch call, then confirm search finds nothing."""
    print(f"\n删除知识文件")
    names = list(test_files)
    resp = api.delete_kb_docs(knowledge_base_name=kb, file_names=names)
    pprint(resp)
    assert resp["code"] == 200
    assert len(resp["data"]["failed_files"]) == 0

    # After deletion the same query must return an empty list.
    question = "介绍一下langchain-chatchat项目"
    print("\n尝试检索删除后的检索知识库:")
    print(question)
    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_recreate_vs():
    """Rebuild the vector store, checking each progress item, then search again.

    Every item yielded by ``recreate_vector_store`` must be a dict with
    code 200. Afterwards a search must return VECTOR_SEARCH_TOP_K results.
    """
    print("\n重建知识库:")
    progress = api.recreate_vector_store(kb)
    for item in progress:
        assert isinstance(item, dict)
        assert item["code"] == 200
        print(item["msg"])

    question = "本项目支持哪些文件格式?"
    print("\n尝试检索重建后的检索知识库:")
    print(question)
    hits = api.search_kb_docs(question, kb)
    pprint(hits)
    assert isinstance(hits, list) and len(hits) == VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_kb_after():
    """Delete the test knowledge base and confirm it is gone from the listing.

    The delete response is only printed; the check is made on the listing,
    which must be a non-empty list that does not contain the test one.
    """
    print("\n删除知识库")
    delete_resp = api.delete_knowledge_base(kb)
    pprint(delete_resp)

    # Verify via the listing that the knowledge base no longer exists.
    print("\n获取知识库列表:")
    kb_names = api.list_knowledge_bases()
    pprint(kb_names)
    assert isinstance(kb_names, list) and len(kb_names) > 0
    assert kb not in kb_names
|
||||||
|
|
@ -127,7 +127,7 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
|
|
||||||
# 上传文件
|
# 上传文件
|
||||||
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
# sentence_size = st.slider("文本入库分句长度限制", 1, 1000, SENTENCE_SIZE, disabled=True)
|
||||||
files = st.file_uploader("上传知识文件(暂不支持扫描PDF)",
|
files = st.file_uploader("上传知识文件",
|
||||||
[i for ls in LOADER_DICT.values() for i in ls],
|
[i for ls in LOADER_DICT.values() for i in ls],
|
||||||
accept_multiple_files=True,
|
accept_multiple_files=True,
|
||||||
)
|
)
|
||||||
|
|
@ -138,14 +138,11 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
# use_container_width=True,
|
# use_container_width=True,
|
||||||
disabled=len(files) == 0,
|
disabled=len(files) == 0,
|
||||||
):
|
):
|
||||||
data = [{"file": f, "knowledge_base_name": kb, "not_refresh_vs_cache": True} for f in files]
|
ret = api.upload_kb_docs(files, knowledge_base_name=kb, override=True)
|
||||||
data[-1]["not_refresh_vs_cache"]=False
|
if msg := check_success_msg(ret):
|
||||||
for k in data:
|
st.toast(msg, icon="✔")
|
||||||
ret = api.upload_kb_doc(**k)
|
elif msg := check_error_msg(ret):
|
||||||
if msg := check_success_msg(ret):
|
st.toast(msg, icon="✖")
|
||||||
st.toast(msg, icon="✔")
|
|
||||||
elif msg := check_error_msg(ret):
|
|
||||||
st.toast(msg, icon="✖")
|
|
||||||
st.session_state.files = []
|
st.session_state.files = []
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
@ -217,8 +214,8 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
disabled=not file_exists(kb, selected_rows)[0],
|
disabled=not file_exists(kb, selected_rows)[0],
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
):
|
):
|
||||||
for row in selected_rows:
|
file_names = [row["file_name"] for row in selected_rows]
|
||||||
api.update_kb_doc(kb, row["file_name"])
|
api.update_kb_docs(kb, file_names=file_names)
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
# 将文件从向量库中删除,但不删除文件本身。
|
# 将文件从向量库中删除,但不删除文件本身。
|
||||||
|
|
@ -227,8 +224,8 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
disabled=not (selected_rows and selected_rows[0]["in_db"]),
|
disabled=not (selected_rows and selected_rows[0]["in_db"]),
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
):
|
):
|
||||||
for row in selected_rows:
|
file_names = [row["file_name"] for row in selected_rows]
|
||||||
api.delete_kb_doc(kb, row["file_name"])
|
api.delete_kb_docs(kb, file_names=file_names)
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
if cols[3].button(
|
if cols[3].button(
|
||||||
|
|
@ -236,9 +233,8 @@ def knowledge_base_page(api: ApiRequest):
|
||||||
type="primary",
|
type="primary",
|
||||||
use_container_width=True,
|
use_container_width=True,
|
||||||
):
|
):
|
||||||
for row in selected_rows:
|
file_names = [row["file_name"] for row in selected_rows]
|
||||||
ret = api.delete_kb_doc(kb, row["file_name"], True)
|
api.delete_kb_docs(kb, file_names=file_names, delete_content=True)
|
||||||
st.toast(ret.get("msg", " "))
|
|
||||||
st.experimental_rerun()
|
st.experimental_rerun()
|
||||||
|
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
|
||||||
|
|
@ -509,13 +509,45 @@ class ApiRequest:
|
||||||
data = self._check_httpx_json_response(response)
|
data = self._check_httpx_json_response(response)
|
||||||
return data.get("data", [])
|
return data.get("data", [])
|
||||||
|
|
||||||
|
def search_kb_docs(
    self,
    query: str,
    knowledge_base_name: str,
    top_k: int = VECTOR_SEARCH_TOP_K,
    score_threshold: int = SCORE_THRESHOLD,
    no_remote_api: bool = None,
) -> List:
    '''
    Corresponds to the api.py /knowledge_base/search_docs endpoint.

    When no_remote_api is None, the instance-level self.no_remote_api is
    used. In local mode search_docs is imported and called directly;
    otherwise the request is POSTed as JSON and the checked JSON response
    is returned.
    '''
    payload = dict(
        query=query,
        knowledge_base_name=knowledge_base_name,
        top_k=top_k,
        score_threshold=score_threshold,
    )

    call_locally = self.no_remote_api if no_remote_api is None else no_remote_api
    if not call_locally:
        response = self.post("/knowledge_base/search_docs", json=payload)
        return self._check_httpx_json_response(response)

    # Imported lazily so the server package is only needed in local mode.
    from server.knowledge_base.kb_doc_api import search_docs
    return search_docs(**payload)
|
||||||
|
|
||||||
def upload_kb_docs(
|
def upload_kb_docs(
|
||||||
self,
|
self,
|
||||||
files: List[Union[str, Path, bytes]],
|
files: List[Union[str, Path, bytes]],
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
override: bool = False,
|
override: bool = False,
|
||||||
to_vector_store: bool = True,
|
to_vector_store: bool = True,
|
||||||
docs: List[Dict] = [],
|
docs: Dict = {},
|
||||||
not_refresh_vs_cache: bool = False,
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
|
|
@ -525,97 +557,113 @@ class ApiRequest:
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
if isinstance(file, bytes): # raw bytes
|
def convert_file(file, filename=None):
|
||||||
file = BytesIO(file)
|
if isinstance(file, bytes): # raw bytes
|
||||||
elif hasattr(file, "read"): # a file io like object
|
file = BytesIO(file)
|
||||||
filename = filename or file.name
|
elif hasattr(file, "read"): # a file io like object
|
||||||
else: # a local path
|
filename = filename or file.name
|
||||||
file = Path(file).absolute().open("rb")
|
else: # a local path
|
||||||
filename = filename or file.name
|
file = Path(file).absolute().open("rb")
|
||||||
|
filename = filename or file.name
|
||||||
|
return filename, file
|
||||||
|
|
||||||
|
files = [convert_file(file) for file in files]
|
||||||
|
data={
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"override": override,
|
||||||
|
"to_vector_store": to_vector_store,
|
||||||
|
"docs": docs,
|
||||||
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.knowledge_base.kb_doc_api import upload_doc
|
from server.knowledge_base.kb_doc_api import upload_docs
|
||||||
from fastapi import UploadFile
|
from fastapi import UploadFile
|
||||||
from tempfile import SpooledTemporaryFile
|
from tempfile import SpooledTemporaryFile
|
||||||
|
|
||||||
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
upload_files = []
|
||||||
temp_file.write(file.read())
|
for file, filename in files:
|
||||||
temp_file.seek(0)
|
temp_file = SpooledTemporaryFile(max_size=10 * 1024 * 1024)
|
||||||
response = run_async(upload_doc(
|
temp_file.write(file.read())
|
||||||
UploadFile(file=temp_file, filename=filename),
|
temp_file.seek(0)
|
||||||
knowledge_base_name,
|
upload_files.append(UploadFile(file=temp_file, filename=filename))
|
||||||
override,
|
|
||||||
))
|
response = run_async(upload_docs(upload_files, **data))
|
||||||
return response.dict()
|
return response.dict()
|
||||||
else:
|
else:
|
||||||
|
if isinstance(data["docs"], dict):
|
||||||
|
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/upload_doc",
|
"/knowledge_base/upload_docs",
|
||||||
data={
|
data=data,
|
||||||
"knowledge_base_name": knowledge_base_name,
|
files=[("files", (filename, file)) for filename, file in files],
|
||||||
"override": override,
|
|
||||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
|
||||||
},
|
|
||||||
files={"file": (filename, file)},
|
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._check_httpx_json_response(response)
|
||||||
|
|
||||||
def delete_kb_doc(
|
def delete_kb_docs(
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
doc_name: str,
|
file_names: List[str],
|
||||||
delete_content: bool = False,
|
delete_content: bool = False,
|
||||||
not_refresh_vs_cache: bool = False,
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/delete_doc接口
|
对应api.py/knowledge_base/delete_docs接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
data = {
|
data = {
|
||||||
"knowledge_base_name": knowledge_base_name,
|
"knowledge_base_name": knowledge_base_name,
|
||||||
"doc_name": doc_name,
|
"file_names": file_names,
|
||||||
"delete_content": delete_content,
|
"delete_content": delete_content,
|
||||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
}
|
}
|
||||||
|
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.knowledge_base.kb_doc_api import delete_doc
|
from server.knowledge_base.kb_doc_api import delete_docs
|
||||||
response = run_async(delete_doc(**data))
|
response = run_async(delete_docs(**data))
|
||||||
return response.dict()
|
return response.dict()
|
||||||
else:
|
else:
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/delete_doc",
|
"/knowledge_base/delete_docs",
|
||||||
json=data,
|
json=data,
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._check_httpx_json_response(response)
|
||||||
|
|
||||||
def update_kb_doc(
|
def update_kb_docs(
|
||||||
self,
|
self,
|
||||||
knowledge_base_name: str,
|
knowledge_base_name: str,
|
||||||
file_name: str,
|
file_names: List[str],
|
||||||
|
override_custom_docs: bool = False,
|
||||||
|
docs: Dict = {},
|
||||||
not_refresh_vs_cache: bool = False,
|
not_refresh_vs_cache: bool = False,
|
||||||
no_remote_api: bool = None,
|
no_remote_api: bool = None,
|
||||||
):
|
):
|
||||||
'''
|
'''
|
||||||
对应api.py/knowledge_base/update_doc接口
|
对应api.py/knowledge_base/update_docs接口
|
||||||
'''
|
'''
|
||||||
if no_remote_api is None:
|
if no_remote_api is None:
|
||||||
no_remote_api = self.no_remote_api
|
no_remote_api = self.no_remote_api
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"knowledge_base_name": knowledge_base_name,
|
||||||
|
"file_names": file_names,
|
||||||
|
"override_custom_docs": override_custom_docs,
|
||||||
|
"docs": docs,
|
||||||
|
"not_refresh_vs_cache": not_refresh_vs_cache,
|
||||||
|
}
|
||||||
if no_remote_api:
|
if no_remote_api:
|
||||||
from server.knowledge_base.kb_doc_api import update_doc
|
from server.knowledge_base.kb_doc_api import update_docs
|
||||||
response = run_async(update_doc(knowledge_base_name, file_name))
|
response = run_async(update_docs(**data))
|
||||||
return response.dict()
|
return response.dict()
|
||||||
else:
|
else:
|
||||||
|
if isinstance(data["docs"], dict):
|
||||||
|
data["docs"] = json.dumps(data["docs"], ensure_ascii=False)
|
||||||
response = self.post(
|
response = self.post(
|
||||||
"/knowledge_base/update_doc",
|
"/knowledge_base/update_docs",
|
||||||
json={
|
json=data,
|
||||||
"knowledge_base_name": knowledge_base_name,
|
|
||||||
"file_name": file_name,
|
|
||||||
"not_refresh_vs_cache": not_refresh_vs_cache,
|
|
||||||
},
|
|
||||||
)
|
)
|
||||||
return self._check_httpx_json_response(response)
|
return self._check_httpx_json_response(response)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue