Langchain-Chatchat/libs/chatchat-server/tests/api/test_stream_chat_api.py

112 lines
3.2 KiB
Python
Raw Normal View History

2024-12-20 16:04:03 +08:00
import json
import sys
from pathlib import Path
import requests
sys.path.append(str(Path(__file__).parent.parent.parent))
from pprint import pprint
from chatchat.server.utils import api_address
api_base_url = api_address()
api="/chat/chat/completions"
url = f"{api_base_url}{api}"
def dump_input(d, title):
print("\n")
print("=" * 30 + title + " input " + "=" * 30)
pprint(d)
def dump_output(r, title):
print("\n")
print("=" * 30 + title + " output" + "=" * 30)
for line in r.iter_content(None, decode_unicode=True):
print(line, end="", flush=True)
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
def test_llm_chat():
data = {
"model": "qwen1.5-chat",
"messages": [
{"role": "user", "content": "你好"},
{"role": "assistant", "content": "你好,我是人工智能大模型"},
{"role": "user", "content": "请用100字左右的文字介绍自己"},
],
"stream": True,
"temperature": 0.7,
}
dump_input(data, "LLM Chat")
response = requests.post(url, headers=headers, json=data, stream=True)
dump_output(response, "LLM Chat")
assert response.status_code == 200
def test_agent_chat():
tools = list(requests.get(f"{api_base_url}/tools").json()["data"])
data = {
"model": "qwen1.5-chat",
"messages": [
{"role": "user", "content": "37+48="},
],
"stream": True,
"temperature": 0.7,
"tools": tools,
}
dump_input(data, "Agent Chat")
response = requests.post(url, headers=headers, json=data, stream=True)
dump_output(response, "Agent Chat")
assert response.status_code == 200
def test_kb_chat_auto():
data = {
"messages": [
{"role": "user", "content": "如何提问以获得高质量答案"},
],
"model": "qwen1.5-chat",
"tool_choice": "search_local_knowledgebase",
"stream": True,
}
dump_input(data, "KB Chat (auto parameters)")
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + "KB Chat (auto parameters)" + " output" + "=" * 30)
for line in response.iter_content(None, decode_unicode=True):
if line.startswith("data: "):
data = json.loads(line[6:])
pprint(data)
assert response.status_code == 200
def test_kb_chat_mannual():
data = {
"messages": [
{"role": "user", "content": "如何提问以获得高质量答案"},
],
"model": "qwen1.5-chat",
"tool_choice": "search_local_knowledgebase",
"extra_body": {"tool_input": {"database": "samples", "query": "如何提问以获得高质量答案"}},
"stream": True,
}
dump_input(data, "KB Chat (auto parameters)")
response = requests.post(url, headers=headers, json=data, stream=True)
print("\n")
print("=" * 30 + "KB Chat (auto parameters)" + " output" + "=" * 30)
for line in response.iter_content(None, decode_unicode=True):
if line.startswith("data: "):
data = json.loads(line[6:])
pprint(data)
assert response.status_code == 200