add history to chat apis

parent c8b6d53ede
commit 2d49746a8d

api_one.py | 157 --------

@@ -1,157 +0,0 @@
-import sys
-import os
-from configs.model_config import (
-    llm_model_dict,
-    LLM_MODEL,
-    EMBEDDING_DEVICE,
-    LLM_DEVICE,
-    LOG_PATH,
-    logger,
-)
-from fastchat.serve.model_worker import heart_beat_worker
-from fastapi import FastAPI
-import threading
-import asyncio
-
-# Bundle fastchat's three servers into a single app:
-# - http://{host_ip}:{port}/controller maps to python -m fastchat.serve.controller
-# - http://{host_ip}:{port}/model_worker maps to python -m fastchat.serve.model_worker ...
-# - http://{host_ip}:{port}/openai maps to python -m fastchat.serve.openai_api_server ...
-
-
-host_ip = "0.0.0.0"
-port = 8888
-base_url = f"http://127.0.0.1:{port}"
-
-
-def create_controller_app(
-        dispatch_method="shortest_queue",
-):
-    from fastchat.serve.controller import app, Controller
-    controller = Controller(dispatch_method)
-    sys.modules["fastchat.serve.controller"].controller = controller
-    logger.info(f"controller dispatch method: {dispatch_method}")
-    return app, controller
-
-
-def create_model_worker_app(
-        model_path,
-        model_names=[LLM_MODEL],
-        device=LLM_DEVICE,
-        load_8bit=False,
-        gptq_ckpt=None,
-        gptq_wbits=16,
-        gpus='',
-        num_gpus=-1,
-        max_gpu_memory=-1,
-        cpu_offloading=None,
-        worker_address=f"{base_url}/model_worker",
-        controller_address=f"{base_url}/controller",
-        limit_model_concurrency=5,
-        stream_interval=2,
-        no_register=True,  # registered manually on startup instead
-):
-    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
-    from fastchat.serve import model_worker
-    if gpus and num_gpus is None:
-        num_gpus = len(gpus.split(','))
-    gptq_config = GptqConfig(
-        ckpt=gptq_ckpt or model_path,
-        wbits=gptq_wbits,
-        groupsize=-1,
-        act_order=None,
-    )
-    worker = ModelWorker(
-        controller_address,
-        worker_address,
-        worker_id,
-        no_register,
-        model_path,
-        model_names,
-        device,
-        num_gpus,
-        max_gpu_memory,
-        load_8bit,
-        cpu_offloading,
-        gptq_config,
-    )
-    sys.modules["fastchat.serve.model_worker"].worker = worker
-    return app, worker
-
-
-def create_openai_api_app(
-        host=host_ip,
-        port=port,
-        controller_address=f"{base_url}/controller",
-        allow_credentials=None,
-        allowed_origins=["*"],
-        allowed_methods=["*"],
-        allowed_headers=["*"],
-        api_keys=[],
-):
-    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
-    app.add_middleware(
-        CORSMiddleware,
-        allow_origins=allowed_origins,
-        allow_credentials=allow_credentials,
-        allow_methods=allowed_methods,
-        allow_headers=allowed_headers,
-    )
-    app_settings.controller_address = controller_address
-    app_settings.api_keys = api_keys
-    sys.modules["fastchat.serve.openai_api_server.app_settings"] = app_settings
-
-    return app
-
-
-LLM_MODEL = 'chatglm-6b'
-model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
-global controller
-
-
-if not model_path:
-    logger.error("local_model_path 不能为空")
-else:
-    logger.info(f"using local model: {model_path}")
-app = FastAPI()
-
-controller_app, controller = create_controller_app()
-app.mount("/controller", controller_app)
-
-model_worker_app, worker = create_model_worker_app(model_path)
-app.mount("/model_worker", model_worker_app)
-
-openai_api_app = create_openai_api_app()
-app.mount("/openai", openai_api_app)
-
-
-@app.on_event("startup")
-async def on_startup():
-    logger.info("Register to controller")
-    controller.register_worker(
-        worker.worker_addr,
-        True,
-        worker.get_status(),
-    )
-    worker.heart_beat_thread = threading.Thread(
-        target=heart_beat_worker, args=(worker,)
-    )
-    worker.heart_beat_thread.start()
-
-    # Registering via a network request hangs:
-    # async def register():
-    #     while True:
-    #         try:
-    #             worker.register_to_controller()
-    #             worker.heart_beat_thread = threading.Thread(
-    #                 target=heart_beat_worker, args=(worker,)
-    #             )
-    #             worker.heart_beat_thread.start()
-    #         except:
-    #             await asyncio.sleep(1)
-    # asyncio.get_event_loop().create_task(register())
-
-
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host=host_ip, port=port)

@@ -1,4 +1,4 @@
-langchain==0.0.237
+langchain==0.0.257
 openai
 sentence_transformers
 chromadb

@@ -11,9 +11,9 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
-from server.knowledge_base import (list_kbs, create_kb, delete_kb,
-                                   list_docs, upload_doc, delete_doc,
-                                   update_doc, recreate_vector_store)
+# from server.knowledge_base import (list_kbs, create_kb, delete_kb,
+#                                    list_docs, upload_doc, delete_doc,
+#                                    update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse

 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@@ -58,52 +58,52 @@ def create_app():
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)

-    # Tag: Knowledge Base Management
-    app.get("/knowledge_base/list_knowledge_bases",
-            tags=["Knowledge Base Management"],
-            response_model=ListResponse,
-            summary="获取知识库列表")(list_kbs)
-
-    app.post("/knowledge_base/create_knowledge_base",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="创建知识库"
-             )(create_kb)
-
-    app.delete("/knowledge_base/delete_knowledge_base",
-               tags=["Knowledge Base Management"],
-               response_model=BaseResponse,
-               summary="删除知识库"
-               )(delete_kb)
-
-    app.get("/knowledge_base/list_docs",
-            tags=["Knowledge Base Management"],
-            response_model=ListResponse,
-            summary="获取知识库内的文件列表"
-            )(list_docs)
-
-    app.post("/knowledge_base/upload_doc",
-             tags=["Knowledge Base Management"],
-             response_model=BaseResponse,
-             summary="上传文件到知识库"
-             )(upload_doc)
-
-    app.delete("/knowledge_base/delete_doc",
-               tags=["Knowledge Base Management"],
-               response_model=BaseResponse,
-               summary="删除知识库内的文件"
-               )(delete_doc)
-
-    # app.post("/knowledge_base/update_doc",
-    #          tags=["Knowledge Base Management"],
-    #          response_model=BaseResponse,
-    #          summary="上传文件到知识库,并删除另一个文件"
-    #          )(update_doc)
-
-    app.post("/knowledge_base/recreate_vector_store",
-             tags=["Knowledge Base Management"],
-             summary="根据content中文档重建向量库,流式输出处理进度。"
-             )(recreate_vector_store)
+    # # Tag: Knowledge Base Management
+    # app.get("/knowledge_base/list_knowledge_bases",
+    #         tags=["Knowledge Base Management"],
+    #         response_model=ListResponse,
+    #         summary="获取知识库列表")(list_kbs)
+    #
+    # app.post("/knowledge_base/create_knowledge_base",
+    #          tags=["Knowledge Base Management"],
+    #          response_model=BaseResponse,
+    #          summary="创建知识库"
+    #          )(create_kb)
+    #
+    # app.delete("/knowledge_base/delete_knowledge_base",
+    #            tags=["Knowledge Base Management"],
+    #            response_model=BaseResponse,
+    #            summary="删除知识库"
+    #            )(delete_kb)
+    #
+    # app.get("/knowledge_base/list_docs",
+    #         tags=["Knowledge Base Management"],
+    #         response_model=ListResponse,
+    #         summary="获取知识库内的文件列表"
+    #         )(list_docs)
+    #
+    # app.post("/knowledge_base/upload_doc",
+    #          tags=["Knowledge Base Management"],
+    #          response_model=BaseResponse,
+    #          summary="上传文件到知识库"
+    #          )(upload_doc)
+    #
+    # app.delete("/knowledge_base/delete_doc",
+    #            tags=["Knowledge Base Management"],
+    #            response_model=BaseResponse,
+    #            summary="删除知识库内的文件"
+    #            )(delete_doc)
+    #
+    # # app.post("/knowledge_base/update_doc",
+    # #          tags=["Knowledge Base Management"],
+    # #          response_model=BaseResponse,
+    # #          summary="上传文件到知识库,并删除另一个文件"
+    # #          )(update_doc)
+    #
+    # app.post("/knowledge_base/recreate_vector_store",
+    #          tags=["Knowledge Base Management"],
+    #          summary="根据content中文档重建向量库,流式输出处理进度。"
+    #          )(recreate_vector_store)

     return app

@@ -7,11 +7,22 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History


-def chat(query: str = Body(..., description="用户输入", example="你好")):
-    async def chat_iterator(message: str) -> AsyncIterable[str]:
+def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"),
+         history: Optional[List[History]] = Body(...,
+                                                  description="历史对话",
+                                                  example=[
+                                                      {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                      {"role": "assistant", "content": "虎头虎脑"}]
+                                                  ),
+         ):
+    async def chat_iterator(query: str,
+                            history: Optional[List[History]]
+                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()

         model = ChatOpenAI(
@@ -23,17 +34,13 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
             model_name=LLM_MODEL
         )

-        # llm = OpenAI(model_name=LLM_MODEL,
-        #              openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-        #              openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-        #              streaming=True)
-
-        prompt = PromptTemplate(input_variables=["input"], template="{input}")
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+        chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
-            chain.acall(message),
+            chain.acall({"input": query}),
             callback.done),
         )
@@ -41,4 +48,6 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
             # Use server-sent-events to stream the response
             yield token
         await task
-    return StreamingResponse(chat_iterator(query), media_type="text/event-stream")
+
+    return StreamingResponse(chat_iterator(query, history),
+                             media_type="text/event-stream")

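With this change the chat endpoint accepts the running conversation next to the query. A minimal client sketch of the new request shape — the host, port, and mount path below are illustrative assumptions, not part of this diff; use whatever route create_app() actually registers:

    # Hypothetical client; only the payload shape comes from this commit.
    import requests

    payload = {
        "query": "恼羞成怒",
        "history": [
            {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
            {"role": "assistant", "content": "虎头虎脑"},
        ],
    }
    with requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True) as resp:
        resp.encoding = "utf-8"
        for token in resp.iter_content(chunk_size=None, decode_unicode=True):
            print(token, end="", flush=True)  # the endpoint streams tokens as SSE
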
@@ -9,7 +9,9 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
 import json

@@ -17,6 +19,14 @@ import json
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                        history: Optional[List[History]] = Body(...,
+                                                                description="历史对话",
+                                                                example=[
+                                                                    {"role": "user",
+                                                                     "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                                    {"role": "assistant",
+                                                                     "content": "虎头虎脑"}]
+                                                                ),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:
@@ -25,6 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
                                            top_k: int,
+                                           history: Optional[List[History]],
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -37,9 +48,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
         )
         docs = kb.search_docs(query, top_k)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
@@ -58,5 +71,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                              "docs": source_documents})
         await task

-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
                              media_type="text/event-stream")

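The same splice drives the RAG prompt: each History entry becomes a message tuple, and the retrieval template becomes the final human message. A rough sketch of the assembly, assuming PROMPT_TEMPLATE keeps the {context} and {question} placeholders the removed PromptTemplate declared (the template text below is a stand-in; the real one lives in the project's configs):

    # Sketch of the prompt assembly; PROMPT_TEMPLATE here is a placeholder.
    from langchain.prompts.chat import ChatPromptTemplate

    PROMPT_TEMPLATE = "已知信息:\n{context}\n\n请回答:{question}"  # stand-in
    history = [("human", "我们来玩成语接龙,我先来,生龙活虎"), ("ai", "虎头虎脑")]
    chat_prompt = ChatPromptTemplate.from_messages(history + [("human", PROMPT_TEMPLATE)])
    messages = chat_prompt.format_messages(context="(joined page_content of the top_k docs)",
                                           question="下一个成语是什么?")
    # -> [HumanMessage, AIMessage, HumanMessage]: prior turns first, grounded question last
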
@@ -10,7 +10,9 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 from langchain.docstore.document import Document
 import json

@@ -58,6 +60,14 @@ def lookup_search_engine(
 def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
                        search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                       history: Optional[List[History]] = Body(...,
+                                                               description="历史对话",
+                                                               example=[
+                                                                   {"role": "user",
+                                                                    "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                                   {"role": "assistant",
+                                                                    "content": "虎头虎脑"}]
+                                                               ),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -65,6 +75,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
+                                          history: Optional[List[History]],
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -78,9 +89,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl

         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)

         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
@@ -99,5 +112,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                              "docs": source_documents})
         await task

-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
                              media_type="text/event-stream")

@@ -1,5 +1,6 @@
 import asyncio
 from typing import Awaitable
+from pydantic import BaseModel, Field


 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -12,3 +13,18 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
     finally:
         # Signal the aiter to stop.
         event.set()
+
+
+class History(BaseModel):
+    """
+    Chat history entry.
+    Can be built from a dict, e.g.
+        h = History(**{"role": "user", "content": "你好"})
+    and converted to a message tuple, e.g.
+        h.to_msg_tuple() == ("human", "你好")
+    """
+    role: str = Field(...)
+    content: str = Field(...)
+
+    def to_msg_tuple(self):
+        return "ai" if self.role == "assistant" else "human", self.content

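A quick sanity check of the new helper, runnable on its own with pydantic installed:

    # History validates role/content and maps roles onto langchain's message names.
    from server.chat.utils import History

    h = History(**{"role": "user", "content": "你好"})
    assert h.to_msg_tuple() == ("human", "你好")
    assert History(role="assistant", content="你好!").to_msg_tuple() == ("ai", "你好!")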