diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py index 032d06a..8a2633b 100644 --- a/server/chat/search_engine_chat.py +++ b/server/chat/search_engine_chat.py @@ -73,6 +73,9 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl if search_engine_name not in SEARCH_ENGINES.keys(): return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}") + if search_engine_name == "bing" and not BING_SUBSCRIPTION_KEY: + return BaseResponse(code=404, msg=f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`") + history = [History.from_data(h) for h in history] async def search_engine_chat_iterator(query: str, diff --git a/startup.py b/startup.py index df00851..c8706e2 100644 --- a/startup.py +++ b/startup.py @@ -201,7 +201,7 @@ def run_model_worker( ): import uvicorn - kwargs = FSCHAT_MODEL_WORKERS[LLM_MODEL].copy() + kwargs = FSCHAT_MODEL_WORKERS[model_name].copy() host = kwargs.pop("host") port = kwargs.pop("port") model_path = llm_model_dict[model_name].get("local_model_path", "") diff --git a/tests/api/stream_api_test.py b/tests/api/stream_api_test.py deleted file mode 100644 index 2902c8a..0000000 --- a/tests/api/stream_api_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import requests -import json - -if __name__ == "__main__": - url = 'http://localhost:7861/chat/chat' - headers = { - 'accept': 'application/json', - 'Content-Type': 'application/json', - } - - data = { - "query": "请用100字左右的文字介绍自己", - "history": [ - { - "role": "user", - "content": "你好" - }, - { - "role": "assistant", - "content": "你好,我是 ChatGLM" - } - ], - "stream": True - } - - response = requests.post(url, headers=headers, data=json.dumps(data), stream=True) - if response.status_code == 200: - for line in response.iter_content(decode_unicode=True): - print(line, flush=True) - else: - print("Error:", response.status_code) - - - r = requests.post( - openai_url + "/chat/completions", - json={"model": LLM_MODEL, "messages": "你好", "max_tokens": 1000}) - data = r.json() - print(f"/chat/completions\n") - print(data) - assert "choices" in data - diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py index 56d3237..ad9d3d8 100644 --- a/tests/api/test_stream_chat_api.py +++ b/tests/api/test_stream_chat_api.py @@ -4,6 +4,7 @@ import sys from pathlib import Path sys.path.append(str(Path(__file__).parent.parent.parent)) +from configs.model_config import BING_SUBSCRIPTION_KEY from configs.server_config import API_SERVER, api_address from pprint import pprint @@ -39,7 +40,7 @@ data = { }, { "role": "assistant", - "content": "你好,我是 ChatGLM" + "content": "你好,我是人工智能大模型" } ], "stream": True @@ -100,9 +101,30 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"): def test_search_engine_chat(api="/chat/search_engine_chat"): + global data + + data["query"] = "室温超导最新进展是什么样?" + url = f"{api_base_url}{api}" for se in ["bing", "duckduckgo"]: - dump_input(data, api) + data["search_engine_name"] = se + dump_input(data, api + f" by {se}") response = requests.post(url, json=data, stream=True) - dump_output(response, api) + if se == "bing" and not BING_SUBSCRIPTION_KEY: + data = response.json() + assert data["code"] == 404 + assert data["msg"] == f"要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`" + + print("\n") + print("=" * 30 + api + " by {se} output" + "="*30) + first = True + for line in response.iter_content(None, decode_unicode=True): + data = json.loads(line) + assert "docs" in data and len(data["docs"]) > 0 + if first: + for doc in data.get("docs", []): + print(doc) + first = False + print(data["answer"], end="", flush=True) assert response.status_code == 200 +