All chat endpoints now use EventSourceResponse; ApiRequest updated to match

This commit is contained in:
liunux4odoo 2023-12-13 16:08:58 +08:00
parent c8fef3380c
commit 472a97a616
13 changed files with 36 additions and 43 deletions
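The swap matters because the two response classes frame chunks differently: StreamingResponse writes whatever the generator yields, byte for byte, while sse_starlette's EventSourceResponse wraps each yielded string in a Server-Sent-Events frame ("data: ..." terminated by a blank line) and sets the text/event-stream media type itself. That is why the explicit media_type arguments disappear in the diffs below, and why every client in this commit now strips a "data: " prefix. A minimal sketch of the difference (the demo app and routes are illustrative, not part of this commit):

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

async def numbers():
    for i in range(3):
        yield f'{{"n": {i}}}'

@app.get("/plain")
async def plain():
    # Chunks reach the client exactly as yielded: {"n": 0}{"n": 1}{"n": 2}
    return StreamingResponse(numbers(), media_type="text/event-stream")

@app.get("/sse")
async def sse():
    # Each yield becomes one SSE frame, e.g. b'data: {"n": 0}\r\n\r\n',
    # and Content-Type: text/event-stream is set automatically.
    return EventSourceResponse(numbers())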

View File

@@ -25,6 +25,8 @@ if not os.path.exists(LOG_PATH):
 # Temporary file directory, mainly used for file chat
 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
-if os.path.isdir(BASE_TEMP_DIR):
+try:
     shutil.rmtree(BASE_TEMP_DIR)
+except Exception:
+    pass
 os.makedirs(BASE_TEMP_DIR, exist_ok=True)
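The pattern of this change is worth noting: the os.path.isdir guard is replaced by EAFP, since the check can pass and shutil.rmtree still fail (a file held open on Windows, a permissions problem), so failures are now swallowed and os.makedirs(exist_ok=True) recreates the directory either way.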

View File

@@ -6,7 +6,7 @@ from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
 from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from langchain.chains import LLMChain
@@ -180,8 +180,8 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
         yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
         await task

-    return StreamingResponse(agent_chat_iterator(query=query,
+    return EventSourceResponse(agent_chat_iterator(query=query,
                                                  history=history,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )

View File

@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain.chains import LLMChain
@@ -100,4 +100,4 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼
         await task

-    return StreamingResponse(chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(chat_iterator())

View File

@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
@@ -63,7 +63,7 @@ async def completion(query: str = Body(..., description="用户输入", examples
         await task

-    return StreamingResponse(completion_iterator(query=query,
+    return EventSourceResponse(completion_iterator(query=query,
                                                  model_name=model_name,
                                                  prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )

View File

@@ -1,5 +1,5 @@
 from fastapi import Body, File, Form, UploadFile
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
 from server.utils import (wrap_done, get_ChatOpenAI,
@@ -170,4 +170,4 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                              ensure_ascii=False)
         await task

-    return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(knowledge_base_chat_iterator())

View File

@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
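Both import spellings seen in this commit are valid: sse_starlette re-exports EventSourceResponse at the package top level, so `from sse_starlette import EventSourceResponse` and `from sse_starlette.sse import EventSourceResponse` name the same class.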
@@ -197,10 +197,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                               ensure_ascii=False)
         await task

-    return StreamingResponse(search_engine_chat_iterator(query=query,
+    return EventSourceResponse(search_engine_chat_iterator(query=query,
                                                          search_engine_name=search_engine_name,
                                                          top_k=top_k,
                                                          history=history,
                                                          model_name=model_name,
                                                          prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )

View File

@@ -8,7 +8,8 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)
-from fastapi.responses import StreamingResponse, FileResponse
+from fastapi.responses import FileResponse
+from sse_starlette import EventSourceResponse
 from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
@@ -97,20 +98,6 @@ def _save_files_in_thread(files: List[UploadFile],
         yield result


-# There seems to be no need for a separate file-upload API endpoint
-# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
-#                  knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
-#                  override: bool = Form(False, description="覆盖已有文件")):
-#     '''
-#     Upload files via the API; streams back save results {"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
-#     '''
-#     def generate(files, knowledge_base_name, override):
-#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
-#             yield json.dumps(result, ensure_ascii=False)
-#
-#     return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")


 # TODO: enable this once langchain.document_loaders supports in-memory files
 # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
 #                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
@@ -397,4 +384,4 @@ def recreate_vector_store(
             if not not_refresh_vs_cache:
                 kb.save_vector_store()

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())

View File

@@ -3,7 +3,7 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                      OVERLAP_SIZE,
                      logger, log_verbose, )
 from server.knowledge_base.utils import (list_files_from_folder)
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from typing import List, Optional
from typing import List, Optional
@@ -90,7 +90,7 @@ def recreate_summary_vector_store(
                 })
             i += 1

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())


 def summary_file_to_vector_store(
@@ -163,7 +163,7 @@ def summary_file_to_vector_store(
                 "msg": msg,
             })

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())


 def summary_doc_ids_to_vector_store(

View File

@@ -181,7 +181,7 @@ def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
     print("\n重建知识库:")
     r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
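The tests get away with `chunk[6:]`, which keeps the trailing CR/LF pair, because json.loads ignores surrounding whitespace; only the 6-character "data: " prefix has to go. A hedged sketch of the pattern the updated tests rely on, assuming each received chunk carries exactly one event (host, port, and knowledge-base name are assumptions, not taken from the test file):

import json
import requests

r = requests.post("http://127.0.0.1:7861/knowledge_base/recreate_vector_store",
                  json={"knowledge_base_name": "samples"}, stream=True)
for chunk in r.iter_content(None, decode_unicode=True):
    if not chunk:
        continue
    # Each event arrives as 'data: {...}\r\n\r\n'; drop the prefix and let
    # json.loads tolerate the trailing newlines.
    data = json.loads(chunk[6:])
    print(data["code"], data["msg"])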

View File

@@ -25,7 +25,7 @@ def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summar
         "file_name": file_name
     }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
@@ -38,7 +38,7 @@ def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/sum
         "doc_ids": doc_ids
     }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data)

View File

@@ -90,7 +90,7 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     print("\n")
     print("=" * 30 + api + " output" + "=" * 30)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         if "answer" in data:
             print(data["answer"], end="", flush=True)
     pprint(data)
@@ -116,7 +116,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
     print("\n")
     print("=" * 30 + api + f" by {se} output" + "=" * 30)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         if "answer" in data:
             print(data["answer"], end="", flush=True)
     assert "docs" in data and len(data["docs"]) > 0

View File

@@ -55,7 +55,7 @@ def knowledge_chat(api="/chat/knowledge_base_chat"):
     response = requests.post(url, headers=headers, json=data, stream=True)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         result.append(data)
     return result

View File

@@ -133,8 +133,10 @@ class ApiRequest:
                     continue
                 if as_json:
                     try:
-                        data = json.loads(chunk)
-                        # pprint(data, depth=1)
+                        if chunk.startswith("data: "):
+                            data = json.loads(chunk[6:-2])
+                        else:
+                            data = json.loads(chunk)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误 '{chunk}'。错误信息是:{e}"
@@ -165,8 +167,10 @@ class ApiRequest:
                     continue
                 if as_json:
                     try:
-                        data = json.loads(chunk)
-                        # pprint(data, depth=1)
+                        if chunk.startswith("data: "):
+                            data = json.loads(chunk[6:-2])
+                        else:
+                            data = json.loads(chunk)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误 '{chunk}'。错误信息是:{e}"
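The webui-side handling is deliberately defensive: SSE frames are parsed with `chunk[6:-2]` (dropping "data: " and the trailing "\r\n"), while anything without the prefix falls through to plain json.loads, so ApiRequest keeps working against both old and new servers. A rough usage sketch (base_url and the yielded keys are assumptions based on the surrounding code, not verified here):

# Assumes the API server is running locally on its default port.
from webui_pages.utils import ApiRequest

api = ApiRequest(base_url="http://127.0.0.1:7861")
# chat_chat() yields parsed dicts; the SSE "data: " framing has already
# been stripped by the generator patched above.
for d in api.chat_chat("你好"):
    print(d.get("answer", ""), end="", flush=True)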