Switch all chat endpoints to EventSourceResponse; update ApiRequest accordingly

parent c8fef3380c
commit 472a97a616
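The change is mechanical across every streaming endpoint: fastapi.responses.StreamingResponse is replaced by sse_starlette's EventSourceResponse, which wraps each payload the iterator yields in a Server-Sent-Events frame ("data: <payload>" followed by a blank line) and supplies the text/event-stream media type itself. A minimal sketch of the new pattern, assuming fastapi and sse-starlette are installed (the endpoint name and payloads are illustrative, not the project's):

from fastapi import FastAPI
from sse_starlette.sse import EventSourceResponse

app = FastAPI()

@app.get("/demo/stream")
async def demo_stream():
    async def event_iterator():
        # Each yielded string reaches the client framed as an SSE event,
        # e.g. b'data: {"answer": "hello"}\r\n\r\n'
        yield '{"answer": "hello"}'
        yield '{"answer": " world"}'
    # No media_type argument: EventSourceResponse emits text/event-stream.
    return EventSourceResponse(event_iterator())

Because of that framing, every consumer of these endpoints (the API tests and the ApiRequest client) must now strip the "data: " prefix before json.loads; that is the second half of this commit.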
@@ -25,6 +25,8 @@ if not os.path.exists(LOG_PATH):
 # 临时文件目录,主要用于文件对话
 BASE_TEMP_DIR = os.path.join(tempfile.gettempdir(), "chatchat")
-if os.path.isdir(BASE_TEMP_DIR):
+try:
     shutil.rmtree(BASE_TEMP_DIR)
+except Exception:
+    pass
 os.makedirs(BASE_TEMP_DIR, exist_ok=True)

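Unrelated to SSE but included in the commit: the temp-directory cleanup is wrapped in try/except, so a failing shutil.rmtree (for example, when another process still holds a file open, which is common on Windows) no longer aborts startup before os.makedirs can run.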
@@ -6,7 +6,7 @@ from server.agent.callbacks import CustomAsyncIteratorCallbackHandler, Status
 from langchain.agents import LLMSingleActionAgent, AgentExecutor
 from server.agent.custom_template import CustomOutputParser, CustomPromptTemplate
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE, HISTORY_LEN, Agent_MODEL
 from server.utils import wrap_done, get_ChatOpenAI, get_prompt_template
 from langchain.chains import LLMChain
@@ -180,8 +180,8 @@ async def agent_chat(query: str = Body(..., description="用户输入", examples
             yield json.dumps({"answer": answer, "final_answer": final_answer}, ensure_ascii=False)
         await task

-    return StreamingResponse(agent_chat_iterator(query=query,
+    return EventSourceResponse(agent_chat_iterator(query=query,
                                                   history=history,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )
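EventSourceResponse defaults to the text/event-stream media type, which is why the explicit media_type argument is simply dropped; the same two-line substitution repeats in the chat, completion, file-chat, search-engine, and knowledge-base endpoints below.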
@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_ChatOpenAI
 from langchain.chains import LLMChain
@@ -100,4 +100,4 @@ async def chat(query: str = Body(..., description="用户输入", examples=["恼

         await task

-    return StreamingResponse(chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(chat_iterator())

@@ -1,5 +1,5 @@
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import LLM_MODELS, TEMPERATURE
 from server.utils import wrap_done, get_OpenAI
 from langchain.chains import LLMChain
@@ -63,7 +63,7 @@ async def completion(query: str = Body(..., description="用户输入", examples

         await task

-    return StreamingResponse(completion_iterator(query=query,
+    return EventSourceResponse(completion_iterator(query=query,
                                                   model_name=model_name,
                                                   prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )

@@ -1,5 +1,5 @@
 from fastapi import Body, File, Form, UploadFile
-from fastapi.responses import StreamingResponse
+from sse_starlette.sse import EventSourceResponse
 from configs import (LLM_MODELS, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD, TEMPERATURE,
                      CHUNK_SIZE, OVERLAP_SIZE, ZH_TITLE_ENHANCE)
 from server.utils import (wrap_done, get_ChatOpenAI,
@@ -170,4 +170,4 @@ async def file_chat(query: str = Body(..., description="用户输入", examples=
                              ensure_ascii=False)
         await task

-    return StreamingResponse(knowledge_base_chat_iterator(), media_type="text/event-stream")
+    return EventSourceResponse(knowledge_base_chat_iterator())

@@ -4,7 +4,7 @@ from configs import (BING_SEARCH_URL, BING_SUBSCRIPTION_KEY, METAPHOR_API_KEY,
                      LLM_MODELS, SEARCH_ENGINE_TOP_K, TEMPERATURE,
                      TEXT_SPLITTER_NAME, OVERLAP_SIZE)
 from fastapi import Body
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 from fastapi.concurrency import run_in_threadpool
 from server.utils import wrap_done, get_ChatOpenAI
 from server.utils import BaseResponse, get_prompt_template
@@ -197,10 +197,10 @@ async def search_engine_chat(query: str = Body(..., description="用户输入",
                              ensure_ascii=False)
         await task

-    return StreamingResponse(search_engine_chat_iterator(query=query,
+    return EventSourceResponse(search_engine_chat_iterator(query=query,
                                                           search_engine_name=search_engine_name,
                                                           top_k=top_k,
                                                           history=history,
                                                           model_name=model_name,
                                                           prompt_name=prompt_name),
-                             media_type="text/event-stream")
+                             )
@@ -8,7 +8,8 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
 from server.utils import BaseResponse, ListResponse, run_in_thread_pool
 from server.knowledge_base.utils import (validate_kb_name, list_files_from_folder, get_file_path,
                                          files2docs_in_thread, KnowledgeFile)
-from fastapi.responses import StreamingResponse, FileResponse
+from fastapi.responses import FileResponse
+from sse_starlette import EventSourceResponse
 from pydantic import Json
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
@@ -97,20 +98,6 @@ def _save_files_in_thread(files: List[UploadFile],
         yield result


-# 似乎没有单独增加一个文件上传API接口的必要
-# def upload_files(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
-#                  knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
-#                  override: bool = Form(False, description="覆盖已有文件")):
-#     '''
-#     API接口:上传文件。流式返回保存结果:{"code":200, "msg": "xxx", "data": {"knowledge_base_name":"xxx", "file_name": "xxx"}}
-#     '''
-#     def generate(files, knowledge_base_name, override):
-#         for result in _save_files_in_thread(files, knowledge_base_name=knowledge_base_name, override=override):
-#             yield json.dumps(result, ensure_ascii=False)
-
-#     return StreamingResponse(generate(files, knowledge_base_name=knowledge_base_name, override=override), media_type="text/event-stream")
-
-
 # TODO: 等langchain.document_loaders支持内存文件的时候再开通
 # def files2docs(files: List[UploadFile] = File(..., description="上传文件,支持多文件"),
 #                knowledge_base_name: str = Form(..., description="知识库名称", examples=["samples"]),
@@ -397,4 +384,4 @@ def recreate_vector_store(
         if not not_refresh_vs_cache:
             kb.save_vector_store()

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())
@@ -3,7 +3,7 @@ from configs import (DEFAULT_VS_TYPE, EMBEDDING_MODEL,
                     OVERLAP_SIZE,
                     logger, log_verbose, )
 from server.knowledge_base.utils import (list_files_from_folder)
-from fastapi.responses import StreamingResponse
+from sse_starlette import EventSourceResponse
 import json
 from server.knowledge_base.kb_service.base import KBServiceFactory
 from typing import List, Optional
@@ -90,7 +90,7 @@ def recreate_summary_vector_store(
             })
             i += 1

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())


 def summary_file_to_vector_store(
@@ -163,7 +163,7 @@ def summary_file_to_vector_store(
                 "msg": msg,
             })

-    return StreamingResponse(output(), media_type="text/event-stream")
+    return EventSourceResponse(output())


 def summary_doc_ids_to_vector_store(
@@ -181,7 +181,7 @@ def test_recreate_vs(api="/knowledge_base/recreate_vector_store"):
     print("\n重建知识库:")
     r = requests.post(url, json={"knowledge_base_name": kb}, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
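The test clients mirror the framing change: each received chunk now begins with the 6-character "data: " prefix, which the tests slice off before json.loads. A slightly more defensive variant of that parse, sketched with an illustrative helper name that is not part of the commit, would check for the prefix instead of assuming it:

import json

def parse_sse_chunk(chunk: str) -> dict:
    # SSE data lines look like 'data: {...}'; strip the prefix if present
    # rather than unconditionally slicing six characters.
    line = chunk.strip()
    if line.startswith("data: "):
        line = line[len("data: "):]
    return json.loads(line)

The same one-line change repeats in the summary and chat test files below.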
@@ -25,7 +25,7 @@ def test_summary_file_to_vector_store(api="/knowledge_base/kb_summary_api/summar
         "file_name": file_name
     }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data["msg"])
@@ -38,7 +38,7 @@ def test_summary_doc_ids_to_vector_store(api="/knowledge_base/kb_summary_api/sum
         "doc_ids": doc_ids
     }, stream=True)
     for chunk in r.iter_content(None):
-        data = json.loads(chunk)
+        data = json.loads(chunk[6:])
         assert isinstance(data, dict)
         assert data["code"] == 200
         print(data)

@@ -90,7 +90,7 @@ def test_knowledge_chat(api="/chat/knowledge_base_chat"):
     print("\n")
     print("=" * 30 + api + " output" + "="*30)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         if "answer" in data:
             print(data["answer"], end="", flush=True)
     pprint(data)
@@ -116,7 +116,7 @@ def test_search_engine_chat(api="/chat/search_engine_chat"):
     print("\n")
     print("=" * 30 + api + f" by {se} output" + "="*30)
     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         if "answer" in data:
             print(data["answer"], end="", flush=True)
     assert "docs" in data and len(data["docs"]) > 0

@@ -55,7 +55,7 @@ def knowledge_chat(api="/chat/knowledge_base_chat"):
     response = requests.post(url, headers=headers, json=data, stream=True)

     for line in response.iter_content(None, decode_unicode=True):
-        data = json.loads(line)
+        data = json.loads(line[6:])
         result.append(data)

     return result
@@ -133,8 +133,10 @@ class ApiRequest:
                     continue
                 if as_json:
                     try:
-                        data = json.loads(chunk)
-                        # pprint(data, depth=1)
+                        if chunk.startswith("data: "):
+                            data = json.loads(chunk[6:-2])
+                        else:
+                            data = json.loads(chunk)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
@@ -165,8 +167,10 @@ class ApiRequest:
                     continue
                 if as_json:
                     try:
-                        data = json.loads(chunk)
-                        # pprint(data, depth=1)
+                        if chunk.startswith("data: "):
+                            data = json.loads(chunk[6:-2])
+                        else:
+                            data = json.loads(chunk)
                         yield data
                     except Exception as e:
                         msg = f"接口返回json错误: ‘{chunk}’。错误信息是:{e}。"
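In ApiRequest, the slice does two jobs: chunk[6:-2] drops the leading "data: " (6 characters) and the two trailing newline characters of the event frame, while the else branch keeps non-SSE endpoints working unchanged. A worked example of the slicing, with an illustrative payload:

import json

# Illustrative payload only: one SSE chunk as ApiRequest might receive it.
chunk = 'data: {"answer": "hi"}\r\n'
if chunk.startswith("data: "):
    payload = chunk[6:-2]   # strip 'data: ' prefix and trailing '\r\n'
else:
    payload = chunk
assert json.loads(payload) == {"answer": "hi"}

Note the assumption baked into [6:-2]: each chunk must be exactly one complete event ending in a two-character terminator. If the server ever emitted multi-line or concatenated events, a real SSE client parser would be the safer choice.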