From 472a97a6165d4267654efb4ee384814b5716d183 Mon Sep 17 00:00:00 2001
From: liunux4odoo
Date: Wed, 13 Dec 2023 16:08:58 +0800
Subject: [PATCH] Change all chat endpoints to EventSourceResponse; update ApiRequest accordingly
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 configs/basic_config.py.example          |  4 +++-
 server/chat/agent_chat.py                |  6 +++---
 server/chat/chat.py                      |  4 ++--
 server/chat/completion.py                |  6 +++---
 server/chat/file_chat.py                 |  4 ++--
 server/chat/search_engine_chat.py        |  6 +++---
 server/knowledge_base/kb_doc_api.py      | 19 +++----------------
 server/knowledge_base/kb_summary_api.py  |  6 +++---
 tests/api/test_kb_api.py                 |  2 +-
 tests/api/test_kb_summary_api.py         |  4 ++--
 tests/api/test_stream_chat_api.py        |  4 ++--
 tests/api/test_stream_chat_api_thread.py |  2 +-
 webui_pages/utils.py                     | 12 ++++++++----
 13 files changed, 36 insertions(+), 43 deletions(-)

diff --git a/configs/basic_config.py.example b/configs/basic_config.py.example
index a22fb97..7b50365 100644
--- a/configs/basic_config.py.example
+++ b/configs/basic_config.py.example
@@ -25,6 +25,8 @@ if not os.path.exists(LOG_PATH):
 
 # Temporary file directory, mainly used for file chat
 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
-if os.path.isdir(BASE_TEMP_DIR):
+try:
     shutil.rmtree(BASE_TEMP_DIR)
+except Exception:
+    pass
 os.makedirs(BASE_TEMP_DIR, exist_ok=True)
diff --git a/server/chat/agent_chat.py b/server/chat/agent_chat.py
index 51d4d1f..03c0c2b 100644
--- a/server/chat/agent_chat.py
+++ b/server/chat/agent_chat.py
@@ -6,7 +6,7 @@ from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
 from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from langchain.chains import LLMChain
@@ -180,8 +180,8 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
         await task
 
-    return StreamingResponse(agent_chat_iterator(query=query,
+    return EventSourceResponse(agent_chat_iterator(query=query,
                                                  history=history,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )
diff --git a/server/chat/chat.py b/server/chat/chat.py
index bac82f8..0c566dd 100644
--- a/server/chat/chat.py
+++ b/server/chat/chat.py
@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain.chains import LLMChain
@@ -100,4 +100,4 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
 
         await task
 
-    return StreamingResponse(chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(chat_iterator())
diff --git a/server/chat/completion.py b/server/chat/completion.py
index 6e5827d..acf1736 100644
--- a/server/chat/completion.py
+++ b/server/chat/completion.py
@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
@@ -63,7 +63,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
 
         await task
 
-    return StreamingResponse(completion_iterator(query=query,
+    return EventSourceResponse(completion_iterator(query=query,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )
diff --git a/server/chat/file_chat.py b/server/chat/file_chat.py
index a58cb29..d9ba5b6 100644
--- a/server/chat/file_chat.py
+++ b/server/chat/file_chat.py
@@ -1,5 +1,5 @@
 from fastapi import Body, File, Form, UploadFile
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
 from server.utils import (wrap_done, get_ChatOpenAI,
@@ -170,4 +170,4 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                              ensure_ascii=False)
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(knowledge_base_chat_iterator())
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 5b77e99..db00dfa 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
@@ -197,10 +197,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              ensure_ascii=False)
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query=query,
+    return EventSourceResponse(search_engine_chat_iterator(query=query,
                                                          search_engine_name=search_engine_name,
                                                          top_k=top_k,
                                                          history=history,
                                                          model_name=model_name,
                                                          prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )
diff --git a/server/knowledge_base/kb_doc_api.py b/server/knowledge_base/kb_doc_api.py
index df574cd..ff68a6b 100644
--- a/server/knowledge_base/kb_doc_api.py
+++ b/server/knowledge_base/kb_doc_api.py
@@ -8,7 +8,8 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)
-from fastapi.responses import StreamingResponse, FileResponse
+from fastapi.responses import FileResponse
+from sse_starlette import EventSourceResponse
 from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
@@ -97,20 +98,6 @@ def _save_files_in_thread(files: List[UploadFile],
         yield result
 
 
-# There seems to be no need for a separate file-upload API endpoint
-# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
-#                  knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
-#                  override: bool = Form(False, description="覆盖已有文件")):
-#     '''
-#     API endpoint: upload files. Streams back the save results: {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
-#     '''
-#     def generate(files, knowledge_base_name, override):
-#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
-#             yield json.dumps(result, ensure_ascii=False)
-
-#     return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
-
-
 # TODO: enable this once langchain.document_loaders supports in-memory files
 # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
 #                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
@@ -397,4 +384,4 @@ def recreate_vector_store(
         if not not_refresh_vs_cache:
             kb.save_vector_store()
 
-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())
diff --git a/server/knowledge_base/kb_summary_api.py b/server/knowledge_base/kb_summary_api.py
index aac4de7..6558f87 100644
--- a/server/knowledge_base/kb_summary_api.py
+++ b/server/knowledge_base/kb_summary_api.py
@@ -3,7 +3,7 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                      OVERLAP_SIZE,
                      logger, log_verbose, )
 from server.knowledge_base.utils import (list_files_from_folder)
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from typing import List, Optional
@@ -90,7 +90,7 @@ def recreate_summary_vector_store(
             })
             i += 1
 
-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())
 
 
 def summary_file_to_vector_store(
@@ -163,7 +163,7 @@ def summary_file_to_vector_store(
             "msg": msg,
         })
 
-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())
 
 
 def summary_doc_ids_to_vector_store(
diff --git a/tests/api/test_kb_api.py b/tests/api/test_kb_api.py
index c404d22..10a07fd 100644
--- a/tests/api/test_kb_api.py
+++ b/tests/api/test_kb_api.py
@@ -181,7 +181,7 @@ def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
     print("\n重建知识库:")
     r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
diff --git a/tests/api/test_kb_summary_api.py b/tests/api/test_kb_summary_api.py
index d59c203..9f85d61 100644
--- a/tests/api/test_kb_summary_api.py
+++ b/tests/api/test_kb_summary_api.py
@@ -25,7 +25,7 @@ def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summar
                           "file_name": file_name
                           }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
@@ -38,7 +38,7 @@ def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/sum
                           "doc_ids": doc_ids
                           }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data)
diff --git a/tests/api/test_stream_chat_api.py b/tests/api/test_stream_chat_api.py
index 7db8276..606eb8d 100644
--- a/tests/api/test_stream_chat_api.py
+++ b/tests/api/test_stream_chat_api.py
@@ -90,7 +90,7 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     print("\n")
     print("=" * 30 + api + " output" + "="*30)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         if "answer" in data:
             print(data["answer"], end="", flush=True)
     pprint(data)
@@ -116,7 +116,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
         print("\n")
         print("=" * 30 + api + f" by {se} output" + "="*30)
         for line in response.iter_content(None, decode_unicode=True):
-            data = json.loads(line)
+            data = json.loads(line[6:])
             if "answer" in data:
                 print(data["answer"], end="", flush=True)
         assert "docs" in data and len(data["docs"]) > 0
diff --git a/tests/api/test_stream_chat_api_thread.py b/tests/api/test_stream_chat_api_thread.py
index bdf7436..7c9a1d1 100644
--- a/tests/api/test_stream_chat_api_thread.py
+++ b/tests/api/test_stream_chat_api_thread.py
@@ -55,7 +55,7 @@ def knowledge_chat(api="/chat/knowledge_base_chat"):
 
     response = requests.post(url, headers=headers, json=data, stream=True)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         result.append(data)
 
     return result
diff --git a/webui_pages/utils.py b/webui_pages/utils.py
index 75e1419..f688949 100644
--- a/webui_pages/utils.py
+++ b/webui_pages/utils.py
@@ -133,8 +133,10 @@ class ApiRequest:
                         continue
                     if as_json:
                         try:
-                            data = json.loads(chunk)
-                            # pprint(data, depth=1)
+                            if chunk.startswith("data: "):
+                                data = json.loads(chunk[6:-2])
+                            else:
+                                data = json.loads(chunk)
                             yield data
                         except Exception as e:
                             msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -165,8 +167,10 @@ class ApiRequest:
                         continue
                     if as_json:
                         try:
-                            data = json.loads(chunk)
-                            # pprint(data, depth=1)
+                            if chunk.startswith("data: "):
+                                data = json.loads(chunk[6:-2])
+                            else:
+                                data = json.loads(chunk)
                             yield data
                         except Exception as e:
                             msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
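
A note on the client-side changes above. sse-starlette's EventSourceResponse wraps each yielded payload in the Server-Sent Events wire format ("data: <payload>" followed by a blank line), whereas StreamingResponse sent the raw bytes. That framing is why the tests now slice off the six-character "data: " prefix with chunk[6:] (json.loads tolerates the leftover trailing newlines), and why ApiRequest trims the prefix plus a trailing line break with chunk[6:-2]. Below is a minimal sketch of the parsing idea, assuming one complete event per chunk; the parse_chunk helper is illustrative and not part of this patch (it uses rstrip instead of the fixed [6:-2] slice):

    import json

    # One event as EventSourceResponse frames it on the wire: a "data: "
    # prefix on the payload, terminated by a blank line.
    raw_event = 'data: {"code": 200, "msg": "ok"}\r\n\r\n'

    # Test-style parsing: drop the 6-character "data: " prefix; json.loads
    # ignores the trailing newline characters.
    assert json.loads(raw_event[6:]) == {"code": 200, "msg": "ok"}

    def parse_chunk(chunk: str) -> dict:
        # Mirror ApiRequest's branch: strip the SSE framing when present,
        # otherwise treat the chunk as bare JSON (non-streaming endpoints).
        if chunk.startswith("data: "):
            return json.loads(chunk[6:].rstrip("\r\n"))
        return json.loads(chunk)

    print(parse_chunk(raw_event))           # {'code': 200, 'msg': 'ok'}
    print(parse_chunk('{"msg": "plain"}'))  # {'msg': 'plain'}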