add api tests
This commit is contained in:
parent
69627a2fa3
commit
956237feac
|
|
@ -0,0 +1,200 @@
|
||||||
|
from doctest import testfile
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
root_path = Path(__file__).parent.parent.parent
|
||||||
|
sys.path.append(str(root_path))
|
||||||
|
from configs.server_config import api_address
|
||||||
|
from configs.model_config import VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
|
api_base_url = api_address()
|
||||||
|
|
||||||
|
kb = "kb_for_api_test"
|
||||||
|
test_files = {
|
||||||
|
"README.MD": str(root_path / "README.MD"),
|
||||||
|
"FAQ.MD": str(root_path / "docs" / "FAQ.MD")
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"):
|
||||||
|
url = api_base_url + api
|
||||||
|
print("\n删除知识库")
|
||||||
|
r = requests.post(url, json=kb)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
|
# check kb not exists anymore
|
||||||
|
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||||
|
print("\n获取知识库列表:")
|
||||||
|
r = requests.get(url)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||||
|
assert kb not in data["data"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_kb(api="/knowledge_base/create_knowledge_base"):
|
||||||
|
url = api_base_url + api
|
||||||
|
|
||||||
|
print(f"\n尝试用空名称创建知识库:")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": " "})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 404
|
||||||
|
assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称"
|
||||||
|
|
||||||
|
print(f"\n创建新知识库: {kb}")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert data["msg"] == f"已新增知识库 {kb}"
|
||||||
|
|
||||||
|
print(f"\n尝试创建同名知识库: {kb}")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 404
|
||||||
|
assert data["msg"] == f"已存在同名知识库 {kb}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_kbs(api="/knowledge_base/list_knowledge_bases"):
|
||||||
|
url = api_base_url + api
|
||||||
|
print("\n获取知识库列表:")
|
||||||
|
r = requests.get(url)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||||
|
assert kb in data["data"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_doc(api="/knowledge_base/upload_doc"):
|
||||||
|
url = api_base_url + api
|
||||||
|
for name, path in test_files.items():
|
||||||
|
print(f"\n上传知识文件: {name}")
|
||||||
|
data = {"knowledge_base_name": kb, "override": True}
|
||||||
|
files = {"file": (name, open(path, "rb"))}
|
||||||
|
r = requests.post(url, data=data, files=files)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert data["msg"] == f"成功上传文件 {name}"
|
||||||
|
|
||||||
|
for name, path in test_files.items():
|
||||||
|
print(f"\n尝试重新上传知识文件: {name}, 不覆盖")
|
||||||
|
data = {"knowledge_base_name": kb, "override": False}
|
||||||
|
files = {"file": (name, open(path, "rb"))}
|
||||||
|
r = requests.post(url, data=data, files=files)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 404
|
||||||
|
assert data["msg"] == f"文件 {name} 已存在。"
|
||||||
|
|
||||||
|
for name, path in test_files.items():
|
||||||
|
print(f"\n尝试重新上传知识文件: {name}, 覆盖")
|
||||||
|
data = {"knowledge_base_name": kb, "override": True}
|
||||||
|
files = {"file": (name, open(path, "rb"))}
|
||||||
|
r = requests.post(url, data=data, files=files)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert data["msg"] == f"成功上传文件 {name}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_docs(api="/knowledge_base/list_docs"):
|
||||||
|
url = api_base_url + api
|
||||||
|
print("\n获取知识库中文件列表:")
|
||||||
|
r = requests.get(url, params={"knowledge_base_name": kb})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert isinstance(data["data"], list)
|
||||||
|
for name in test_files:
|
||||||
|
assert name in data["data"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_docs(api="/knowledge_base/search_docs"):
|
||||||
|
url = api_base_url + api
|
||||||
|
query = "介绍一下langchain-chatchat项目"
|
||||||
|
print("\n检索知识库:")
|
||||||
|
print(query)
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_doc(api="/knowledge_base/update_doc"):
|
||||||
|
url = api_base_url + api
|
||||||
|
for name, path in test_files.items():
|
||||||
|
print(f"\n更新知识文件: {name}")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert data["msg"] == f"成功更新文件 {name}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_doc(api="/knowledge_base/delete_doc"):
|
||||||
|
url = api_base_url + api
|
||||||
|
for name, path in test_files.items():
|
||||||
|
print(f"\n删除知识文件: {name}")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert data["msg"] == f"{name} 文件删除成功"
|
||||||
|
|
||||||
|
url = api_base_url + "/knowledge_base/search_docs"
|
||||||
|
query = "介绍一下langchain-chatchat项目"
|
||||||
|
print("\n尝试检索删除后的检索知识库:")
|
||||||
|
print(query)
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert isinstance(data, list) and len(data) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
|
||||||
|
url = api_base_url + api
|
||||||
|
print("\n重建知识库:")
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
|
||||||
|
for chunk in r.iter_content(None):
|
||||||
|
data = json.loads(chunk)
|
||||||
|
assert isinstance(data, dict)
|
||||||
|
assert data["code"] == 200
|
||||||
|
print(data["msg"])
|
||||||
|
|
||||||
|
url = api_base_url + "/knowledge_base/search_docs"
|
||||||
|
query = "本项目支持哪些文件格式?"
|
||||||
|
print("\n尝试检索重建后的检索知识库:")
|
||||||
|
print(query)
|
||||||
|
r = requests.post(url, json={"knowledge_base_name": kb, "query": query})
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"):
|
||||||
|
url = api_base_url + api
|
||||||
|
print("\n删除知识库")
|
||||||
|
r = requests.post(url, json=kb)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
|
||||||
|
# check kb not exists anymore
|
||||||
|
url = api_base_url + "/knowledge_base/list_knowledge_bases"
|
||||||
|
print("\n获取知识库列表:")
|
||||||
|
r = requests.get(url)
|
||||||
|
data = r.json()
|
||||||
|
pprint(data)
|
||||||
|
assert data["code"] == 200
|
||||||
|
assert isinstance(data["data"], list) and len(data["data"]) > 0
|
||||||
|
assert kb not in data["data"]
|
||||||
|
|
@ -0,0 +1,108 @@
|
||||||
|
import requests
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
from configs.server_config import API_SERVER, api_address
|
||||||
|
|
||||||
|
from pprint import pprint
|
||||||
|
|
||||||
|
|
||||||
|
api_base_url = api_address()
|
||||||
|
|
||||||
|
|
||||||
|
def dump_input(d, title):
|
||||||
|
print("\n")
|
||||||
|
print("=" * 30 + title + " input " + "="*30)
|
||||||
|
pprint(d)
|
||||||
|
|
||||||
|
|
||||||
|
def dump_output(r, title):
|
||||||
|
print("\n")
|
||||||
|
print("=" * 30 + title + " output" + "="*30)
|
||||||
|
for line in r.iter_content(None, decode_unicode=True):
|
||||||
|
print(line, end="", flush=True)
|
||||||
|
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
'accept': 'application/json',
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"query": "请用100字左右的文字介绍自己",
|
||||||
|
"history": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "你好"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "你好,我是 ChatGLM"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_fastchat(api="/chat/fastchat"):
|
||||||
|
url = f"{api_base_url}{api}"
|
||||||
|
data2 = {
|
||||||
|
"stream": True,
|
||||||
|
"messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
|
||||||
|
}
|
||||||
|
dump_input(data2, api)
|
||||||
|
response = requests.post(url, headers=headers, json=data2, stream=True)
|
||||||
|
dump_output(response, api)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_chat_chat(api="/chat/chat"):
|
||||||
|
url = f"{api_base_url}{api}"
|
||||||
|
dump_input(data, api)
|
||||||
|
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||||
|
dump_output(response, api)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||||
|
url = f"{api_base_url}{api}"
|
||||||
|
data = {
|
||||||
|
"query": "如何提问以获得高质量答案",
|
||||||
|
"knowledge_base_name": "samples",
|
||||||
|
"history": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "你好"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "assistant",
|
||||||
|
"content": "你好,我是 ChatGLM"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"stream": True
|
||||||
|
}
|
||||||
|
dump_input(data, api)
|
||||||
|
response = requests.post(url, headers=headers, json=data, stream=True)
|
||||||
|
print("\n")
|
||||||
|
print("=" * 30 + api + " output" + "="*30)
|
||||||
|
first = True
|
||||||
|
for line in response.iter_content(None, decode_unicode=True):
|
||||||
|
data = json.loads(line)
|
||||||
|
if first:
|
||||||
|
for doc in data["docs"]:
|
||||||
|
print(doc)
|
||||||
|
first = False
|
||||||
|
print(data["answer"], end="", flush=True)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
||||||
|
url = f"{api_base_url}{api}"
|
||||||
|
for se in ["bing", "duckduckgo"]:
|
||||||
|
dump_input(data, api)
|
||||||
|
response = requests.post(url, json=data, stream=True)
|
||||||
|
dump_output(response, api)
|
||||||
|
assert response.status_code == 200
|
||||||
Loading…
Reference in New Issue