在search_engine_chat中检查Bing KEY,并更新tests

This commit is contained in:
liunux4odoo 2023-08-25 10:58:40 +08:00
parent 29738c071c
commit 447b370416
4 changed files with 29 additions and 45 deletions

View File

@ -73,6 +73,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
if search_engine_name not in SEARCH_ENGINES.keys(): if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
return BaseResponse(code=404, msg=f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`")
history = [History.from_data(h) for h in history] history = [History.from_data(h) for h in history]
async def search_engine_chat_iterator(query: str, async def search_engine_chat_iterator(query: str,

View File

@ -201,7 +201,7 @@ def run_model_worker(
): ):
import uvicorn import uvicorn
kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy() kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
host = kwargs.pop("host") host = kwargs.pop("host")
port = kwargs.pop("port") port = kwargs.pop("port")
model_path = llm_model_dict[model_name].get("local_model_path", "") model_path = llm_model_dict[model_name].get("local_model_path", "")

View File

@ -1,41 +0,0 @@
import requests
import json
if __name__ == "__main__":
url = 'http://localhost:7861/chat/chat'
headers = {
'accept': 'application/json',
'Content-Type': 'application/json',
}
data = {
"query": "请用100字左右的文字介绍自己",
"history": [
{
"role": "user",
"content": "你好"
},
{
"role": "assistant",
"content": "你好,我是 ChatGLM"
}
],
"stream": True
}
response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
if response.status_code == 200:
for line in response.iter_content(decode_unicode=True):
print(line, flush=True)
else:
print("Error:", response.status_code)
r = requests.post(
openai_url + "/chat/completions",
json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000})
data = r.json()
print(f"/chat/completions\n")
print(data)
assert "choices" in data

View File

@ -4,6 +4,7 @@ import sys
from pathlib import Path from pathlib import Path
sys.path.append(str(Path(__file__).parent.parent.parent)) sys.path.append(str(Path(__file__).parent.parent.parent))
from configs.model_config import BING_SUBSCRIPTION_KEY
from configs.server_config import API_SERVER, api_address from configs.server_config import API_SERVER, api_address
from pprint import pprint from pprint import pprint
@ -39,7 +40,7 @@ data = {
}, },
{ {
"role": "assistant", "role": "assistant",
"content": "你好,我是 ChatGLM" "content": "你好,我是人工智能大模型"
} }
], ],
"stream": True "stream": True
@ -100,9 +101,30 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
def test_search_engine_chat(api="/chat/search_engine_chat"): def test_search_engine_chat(api="/chat/search_engine_chat"):
global data
data["query"] = "室温超导最新进展是什么样?"
url = f"{api_base_url}{api}" url = f"{api_base_url}{api}"
for se in ["bing", "duckduckgo"]: for se in ["bing", "duckduckgo"]:
dump_input(data, api) data["search_engine_name"] = se
dump_input(data, api + f" by {se}")
response = requests.post(url, json=data, stream=True) response = requests.post(url, json=data, stream=True)
dump_output(response, api) if se == "bing" and not BING_SUBSCRIPTION_KEY:
data = response.json()
assert data["code"] == 404
assert data["msg"] == f"要使用Bing搜索引擎需要设置 `BING_SUBSCRIPTION_KEY`"
print("\n")
print("=" * 30 + api + " by {se} output" + "="*30)
first = True
for line in response.iter_content(None, decode_unicode=True):
data = json.loads(line)
assert "docs" in data and len(data["docs"]) > 0
if first:
for doc in data.get("docs", []):
print(doc)
first = False
print(data["answer"], end="", flush=True)
assert response.status_code == 200 assert response.status_code == 200