在search_engine_chat中检查Bing KEY,并更新tests
This commit is contained in:
parent
29738c071c
commit
447b370416
|
|
@ -73,6 +73,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
|
||||||
if search_engine_name not in SEARCH_ENGINES.keys():
|
if search_engine_name not in SEARCH_ENGINES.keys():
|
||||||
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
|
||||||
|
|
||||||
|
if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY:
|
||||||
|
return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`")
|
||||||
|
|
||||||
history = [History.from_data(h) for h in history]
|
history = [History.from_data(h) for h in history]
|
||||||
|
|
||||||
async def search_engine_chat_iterator(query: str,
|
async def search_engine_chat_iterator(query: str,
|
||||||
|
|
|
||||||
|
|
@ -201,7 +201,7 @@ def run_model_worker(
|
||||||
):
|
):
|
||||||
import uvicorn
|
import uvicorn
|
||||||
|
|
||||||
kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy()
|
kwargs = FSCHAT_MODEL_WORKERS[model_name].copy()
|
||||||
host = kwargs.pop("host")
|
host = kwargs.pop("host")
|
||||||
port = kwargs.pop("port")
|
port = kwargs.pop("port")
|
||||||
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
model_path = llm_model_dict[model_name].get("local_model_path", "")
|
||||||
|
|
|
||||||
|
|
@ -1,41 +0,0 @@
|
||||||
import requests
|
|
||||||
import json
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
url = 'http://localhost:7861/chat/chat'
|
|
||||||
headers = {
|
|
||||||
'accept': 'application/json',
|
|
||||||
'Content-Type': 'application/json',
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"query": "请用100字左右的文字介绍自己",
|
|
||||||
"history": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": "你好"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"role": "assistant",
|
|
||||||
"content": "你好,我是 ChatGLM"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"stream": True
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.post(url, headers=headers, data=json.dumps(data), stream=True)
|
|
||||||
if response.status_code == 200:
|
|
||||||
for line in response.iter_content(decode_unicode=True):
|
|
||||||
print(line, flush=True)
|
|
||||||
else:
|
|
||||||
print("Error:", response.status_code)
|
|
||||||
|
|
||||||
|
|
||||||
r = requests.post(
|
|
||||||
openai_url + "/chat/completions",
|
|
||||||
json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000})
|
|
||||||
data = r.json()
|
|
||||||
print(f"/chat/completions\n")
|
|
||||||
print(data)
|
|
||||||
assert "choices" in data
|
|
||||||
|
|
||||||
|
|
@ -4,6 +4,7 @@ import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).parent.parent.parent))
|
sys.path.append(str(Path(__file__).parent.parent.parent))
|
||||||
|
from configs.model_config import BING_SUBSCRIPTION_KEY
|
||||||
from configs.server_config import API_SERVER, api_address
|
from configs.server_config import API_SERVER, api_address
|
||||||
|
|
||||||
from pprint import pprint
|
from pprint import pprint
|
||||||
|
|
@ -39,7 +40,7 @@ data = {
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"role": "assistant",
|
"role": "assistant",
|
||||||
"content": "你好,我是 ChatGLM"
|
"content": "你好,我是人工智能大模型"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"stream": True
|
"stream": True
|
||||||
|
|
@ -100,9 +101,30 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
|
||||||
|
|
||||||
|
|
||||||
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
def test_search_engine_chat(api="/chat/search_engine_chat"):
|
||||||
|
global data
|
||||||
|
|
||||||
|
data["query"] = "室温超导最新进展是什么样?"
|
||||||
|
|
||||||
url = f"{api_base_url}{api}"
|
url = f"{api_base_url}{api}"
|
||||||
for se in ["bing", "duckduckgo"]:
|
for se in ["bing", "duckduckgo"]:
|
||||||
dump_input(data, api)
|
data["search_engine_name"] = se
|
||||||
|
dump_input(data, api + f" by {se}")
|
||||||
response = requests.post(url, json=data, stream=True)
|
response = requests.post(url, json=data, stream=True)
|
||||||
dump_output(response, api)
|
if se == "bing" and not BING_SUBSCRIPTION_KEY:
|
||||||
|
data = response.json()
|
||||||
|
assert data["code"] == 404
|
||||||
|
assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
|
||||||
|
|
||||||
|
print("\n")
|
||||||
|
print("=" * 30 + api + " by {se} output" + "="*30)
|
||||||
|
first = True
|
||||||
|
for line in response.iter_content(None, decode_unicode=True):
|
||||||
|
data = json.loads(line)
|
||||||
|
assert "docs" in data and len(data["docs"]) > 0
|
||||||
|
if first:
|
||||||
|
for doc in data.get("docs", []):
|
||||||
|
print(doc)
|
||||||
|
first = False
|
||||||
|
print(data["answer"], end="", flush=True)
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue