From 956237feac6ebd9c98a76531b3281170c843841f Mon Sep 17 00:00:00 2001 From: liunux4odoo Date: Sat, 19 Aug 2023 15:19:01 +0800 Subject: [PATCH] add api tests --- tests/api/test_kb_api.py | 200 ++++++++++++++++++++++++++++++ tests/api/test_stream_chat_api.py | 108 ++++++++++++++++ 2 files changed, 308 insertions(+) create mode 100644 tests/api/test_kb_api.py create mode 100644 tests/api/test_stream_chat_api.py diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py new file mode 100644 index 0000000..c09d57f --- /dev/null +++ b/tests/api/test_kb_api.py @@ -0,0 +1,200 @@ +from doctest import testfile +import requests +import json +import sys +from pathlib import Path + +root_path = Path(__file__).parent.parent.parent +sys.path.append(str(root_path)) +from configs.server_config import api_address +from configs.model_config import VECTOR_SEARCH_TOP_K + +from pprint import pprint + + +api_base_url = api_address() + +kb = "kb_for_api_test" +test_files = { + "README.MD": str(root_path / "README.MD"), + "FAQ.MD": str(root_path / "docs" / "FAQ.MD") +} + + +def test_delete_kb_before(api="/knowledge_base/delete_knowledge_base"): + url = api_base_url + api + print("\n删除知识库") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] + + +def test_create_kb(api="/knowledge_base/create_knowledge_base"): + url = api_base_url + api + + print(f"\n尝试用空名称创建知识库:") + r = requests.post(url, json={"knowledge_base_name": " "}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == "知识库名称不能为空,请重新填写知识库名称" + + print(f"\n创建新知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"已新增知识库 {kb}" + + print(f"\n尝试创建同名知识库: {kb}") + r = requests.post(url, json={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"已存在同名知识库 {kb}" + + +def test_list_kbs(api="/knowledge_base/list_knowledge_bases"): + url = api_base_url + api + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb in data["data"] + + +def test_upload_doc(api="/knowledge_base/upload_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n上传知识文件: {name}") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 不覆盖") + data = {"knowledge_base_name": kb, "override": False} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 404 + assert data["msg"] == f"文件 {name} 已存在。" + + for name, path in test_files.items(): + print(f"\n尝试重新上传知识文件: {name}, 覆盖") + data = {"knowledge_base_name": kb, "override": True} + files = {"file": (name, open(path, "rb"))} + r = requests.post(url, data=data, files=files) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功上传文件 {name}" + + +def test_list_docs(api="/knowledge_base/list_docs"): + url = api_base_url + api + print("\n获取知识库中文件列表:") + r = requests.get(url, params={"knowledge_base_name": kb}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) + for name in test_files: + assert name in data["data"] + + +def test_search_docs(api="/knowledge_base/search_docs"): + url = api_base_url + api + query = "介绍一下langchain-chatchat项目" + print("\n检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_update_doc(api="/knowledge_base/update_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n更新知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "file_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"成功更新文件 {name}" + + +def test_delete_doc(api="/knowledge_base/delete_doc"): + url = api_base_url + api + for name, path in test_files.items(): + print(f"\n删除知识文件: {name}") + r = requests.post(url, json={"knowledge_base_name": kb, "doc_name": name}) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert data["msg"] == f"{name} 文件删除成功" + + url = api_base_url + "/knowledge_base/search_docs" + query = "介绍一下langchain-chatchat项目" + print("\n尝试检索删除后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == 0 + + +def test_recreate_vs(api="/knowledge_base/recreate_vector_store"): + url = api_base_url + api + print("\n重建知识库:") + r = requests.post(url, json={"knowledge_base_name": kb}, stream=True) + for chunk in r.iter_content(None): + data = json.loads(chunk) + assert isinstance(data, dict) + assert data["code"] == 200 + print(data["msg"]) + + url = api_base_url + "/knowledge_base/search_docs" + query = "本项目支持哪些文件格式?" + print("\n尝试检索重建后的检索知识库:") + print(query) + r = requests.post(url, json={"knowledge_base_name": kb, "query": query}) + data = r.json() + pprint(data) + assert isinstance(data, list) and len(data) == VECTOR_SEARCH_TOP_K + + +def test_delete_kb_after(api="/knowledge_base/delete_knowledge_base"): + url = api_base_url + api + print("\n删除知识库") + r = requests.post(url, json=kb) + data = r.json() + pprint(data) + + # check kb not exists anymore + url = api_base_url + "/knowledge_base/list_knowledge_bases" + print("\n获取知识库列表:") + r = requests.get(url) + data = r.json() + pprint(data) + assert data["code"] == 200 + assert isinstance(data["data"], list) and len(data["data"]) > 0 + assert kb not in data["data"] diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py new file mode 100644 index 0000000..56d3237 --- /dev/null +++ b/tests/api/test_stream_chat_api.py @@ -0,0 +1,108 @@ +import requests +import json +import sys +from pathlib import Path + +sys.path.append(str(Path(__file__).parent.parent.parent)) +from configs.server_config import API_SERVER, api_address + +from pprint import pprint + + +api_base_url = api_address() + + +def dump_input(d, title): + print("\n") + print("=" * 30 + title + " input " + "="*30) + pprint(d) + + +def dump_output(r, title): + print("\n") + print("=" * 30 + title + " output" + "="*30) + for line in r.iter_content(None, decode_unicode=True): + print(line, end="", flush=True) + + +headers = { + 'accept': 'application/json', + 'Content-Type': 'application/json', +} + +data = { + "query": "请用100字左右的文字介绍自己", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True +} + + + +def test_chat_fastchat(api="/chat/fastchat"): + url = f"{api_base_url}{api}" + data2 = { + "stream": True, + "messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}] + } + dump_input(data2, api) + response = requests.post(url, headers=headers, json=data2, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_chat_chat(api="/chat/chat"): + url = f"{api_base_url}{api}" + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200 + + +def test_knowledge_chat(api="/chat/knowledge_base_chat"): + url = f"{api_base_url}{api}" + data = { + "query": "如何提问以获得高质量答案", + "knowledge_base_name": "samples", + "history": [ + { + "role": "user", + "content": "你好" + }, + { + "role": "assistant", + "content": "你好,我是 ChatGLM" + } + ], + "stream": True + } + dump_input(data, api) + response = requests.post(url, headers=headers, json=data, stream=True) + print("\n") + print("=" * 30 + api + " output" + "="*30) + first = True + for line in response.iter_content(None, decode_unicode=True): + data = json.loads(line) + if first: + for doc in data["docs"]: + print(doc) + first = False + print(data["answer"], end="", flush=True) + assert response.status_code == 200 + + +def test_search_engine_chat(api="/chat/search_engine_chat"): + url = f"{api_base_url}{api}" + for se in ["bing", "duckduckgo"]: + dump_input(data, api) + response = requests.post(url, json=data, stream=True) + dump_output(response, api) + assert response.status_code == 200