Add history to chat APIs

This commit is contained in:
imClumsyPanda 2023-08-08 23:54:51 +08:00
parent c8b6d53ede
commit 2d49746a8d
7 changed files with 120 additions and 226 deletions

View File

@ -1,157 +0,0 @@
import sys
import os
from configs.model_config import (
llm_model_dict,
LLM_MODEL,
EMBEDDING_DEVICE,
LLM_DEVICE,
LOG_PATH,
logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio
# Combine fastchat's three servers into a single app; they are:
# - http://{host_ip}:{port}/controller, corresponding to python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker, corresponding to python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai, corresponding to python -m fastchat.serve.openai_api_server ...
host_ip = "0.0.0.0"
port = 8888
base_url = f"http://127.0.0.1:{port}"
def create_controller_app(
dispatch_method="shortest_queue",
):
from fastchat.serve.controller import app, Controller
controller = Controller(dispatch_method)
sys.modules["fastchat.serve.controller"].controller = controller
logger.info(f"controller dispatch method: {dispatch_method}")
return app, controller
def create_model_worker_app(
model_path,
model_names=[LLM_MODEL],
device=LLM_DEVICE,
load_8bit=False,
gptq_ckpt=None,
gptq_wbits=16,
gpus='',
num_gpus=-1,
max_gpu_memory=-1,
cpu_offloading=None,
worker_address=f"{base_url}/model_worker",
controller_address=f"{base_url}/controller",
limit_model_concurrency=5,
stream_interval=2,
no_register=True, # manually register
):
from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
from fastchat.serve import model_worker
if gpus and num_gpus is None:
num_gpus = len(gpus.split(','))
gptq_config = GptqConfig(
ckpt=gptq_ckpt or model_path,
wbits=gptq_wbits,
groupsize=-1,
act_order=None,
)
worker = ModelWorker(
controller_address,
worker_address,
worker_id,
no_register,
model_path,
model_names,
device,
num_gpus,
max_gpu_memory,
load_8bit,
cpu_offloading,
gptq_config,
)
sys.modules["fastchat.serve.model_worker"].worker = worker
return app, worker
def create_openai_api_app(
host=host_ip,
port=port,
controller_address=f"{base_url}/controller",
allow_credentials=None,
allowed_origins=["*"],
allowed_methods=["*"],
allowed_headers=["*"],
api_keys=[],
):
from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
app.add_middleware(
CORSMiddleware,
allow_origins=allowed_origins,
allow_credentials=allow_credentials,
allow_methods=allowed_methods,
allow_headers=allowed_headers,
)
app_settings.controller_address = controller_address
app_settings.api_keys = api_keys
sys.modules["fastchat.serve.openai_api_server.app_settings"] = app_settings
return app
LLM_MODEL = 'chatglm-6b'
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
global controller
if not model_path:
logger.error("local_model_path 不能为空")
else:
logger.info(f"using local model: {model_path}")
app = FastAPI()
controller_app, controller = create_controller_app()
app.mount("/controller", controller_app)
model_worker_app, worker = create_model_worker_app(model_path)
app.mount("/model_worker", model_worker_app)
openai_api_app = create_openai_api_app()
app.mount("/openai", openai_api_app)
@app.on_event("startup")
async def on_startup():
logger.info("Register to controller")
controller.register_worker(
worker.worker_addr,
True,
worker.get_status(),
)
worker.heart_beat_thread = threading.Thread(
target=heart_beat_worker, args=(worker,)
)
worker.heart_beat_thread.start()
# Registering via an HTTP request would hang
# async def register():
# while True:
# try:
# worker.register_to_controller()
# worker.heart_beat_thread = threading.Thread(
# target=heart_beat_worker, args=(worker,)
# )
# worker.heart_beat_thread.start()
# except:
# await asyncio.sleep(1)
# asyncio.get_event_loop().create_task(register())
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host=host_ip, port=port)
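For reference, a minimal client-side sketch of calling the OpenAI-compatible endpoint that this (now removed) launcher exposed. The /openai/v1/chat/completions path and the chatglm-6b model name are assumptions inferred from the /openai mount prefix and LLM_MODEL above, not something this commit guarantees:
import requests
# Hypothetical request against the combined server started by this script;
# assumes fastchat's openai_api_server keeps its /v1/chat/completions path
# under the /openai mount.
resp = requests.post(
    "http://127.0.0.1:8888/openai/v1/chat/completions",
    json={
        "model": "chatglm-6b",
        "messages": [{"role": "user", "content": "你好"}],
    },
    timeout=60,
)
print(resp.json())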

View File

@ -1,4 +1,4 @@
langchain==0.0.237
langchain==0.0.257
openai
sentence_transformers
chromadb

View File

@ -11,9 +11,9 @@ from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
search_engine_chat)
from server.knowledge_base import (list_kbs, create_kb, delete_kb,
list_docs, upload_doc, delete_doc,
update_doc, recreate_vector_store)
# from server.knowledge_base import (list_kbs, create_kb, delete_kb,
# list_docs, upload_doc, delete_doc,
# update_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -58,52 +58,52 @@ def create_app():
tags=["Chat"],
summary="与搜索引擎对话")(search_engine_chat)
# Tag: Knowledge Base Management
app.get("/knowledge_base/list_knowledge_bases",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库列表")(list_kbs)
app.post("/knowledge_base/create_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="创建知识库"
)(create_kb)
app.delete("/knowledge_base/delete_knowledge_base",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库"
)(delete_kb)
app.get("/knowledge_base/list_docs",
tags=["Knowledge Base Management"],
response_model=ListResponse,
summary="获取知识库内的文件列表"
)(list_docs)
app.post("/knowledge_base/upload_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="上传文件到知识库"
)(upload_doc)
app.delete("/knowledge_base/delete_doc",
tags=["Knowledge Base Management"],
response_model=BaseResponse,
summary="删除知识库内的文件"
)(delete_doc)
# app.post("/knowledge_base/update_doc",
# # Tag: Knowledge Base Management
# app.get("/knowledge_base/list_knowledge_bases",
# tags=["Knowledge Base Management"],
# response_model=ListResponse,
# summary="获取知识库列表")(list_kbs)
#
# app.post("/knowledge_base/create_knowledge_base",
# tags=["Knowledge Base Management"],
# response_model=BaseResponse,
# summary="上传文件到知识库,并删除另一个文件"
# )(update_doc)
app.post("/knowledge_base/recreate_vector_store",
tags=["Knowledge Base Management"],
summary="根据content中文档重建向量库流式输出处理进度。"
)(recreate_vector_store)
# summary="创建知识库"
# )(create_kb)
#
# app.delete("/knowledge_base/delete_knowledge_base",
# tags=["Knowledge Base Management"],
# response_model=BaseResponse,
# summary="删除知识库"
# )(delete_kb)
#
# app.get("/knowledge_base/list_docs",
# tags=["Knowledge Base Management"],
# response_model=ListResponse,
# summary="获取知识库内的文件列表"
# )(list_docs)
#
# app.post("/knowledge_base/upload_doc",
# tags=["Knowledge Base Management"],
# response_model=BaseResponse,
# summary="上传文件到知识库"
# )(upload_doc)
#
# app.delete("/knowledge_base/delete_doc",
# tags=["Knowledge Base Management"],
# response_model=BaseResponse,
# summary="删除知识库内的文件"
# )(delete_doc)
#
# # app.post("/knowledge_base/update_doc",
# # tags=["Knowledge Base Management"],
# # response_model=BaseResponse,
# # summary="上传文件到知识库,并删除另一个文件"
# # )(update_doc)
#
# app.post("/knowledge_base/recreate_vector_store",
# tags=["Knowledge Base Management"],
# summary="根据content中文档重建向量库流式输出处理进度。"
# )(recreate_vector_store)
return app
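The registration style used (and here commented out) in create_app works because FastAPI's route decorators are ordinary callables: app.post(path, ...) returns a decorator, so applying it directly to an imported handler registers that function without @-syntax. A standalone sketch of the same pattern, with purely illustrative names:
from fastapi import FastAPI
from pydantic import BaseModel
app = FastAPI()
class EchoResponse(BaseModel):
    msg: str
def echo(msg: str) -> EchoResponse:
    # A plain function defined elsewhere; registered below without decorator syntax.
    return EchoResponse(msg=msg)
# Equivalent to decorating `echo` with @app.get("/echo", ...).
app.get("/echo", response_model=EchoResponse, summary="Echo a message")(echo)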

View File

@ -7,11 +7,22 @@ from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
def chat(query: str = Body(..., description="用户输入", example="你好")):
async def chat_iterator(message: str) -> AsyncIterable[str]:
def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"),
history: Optional[List[History]] = Body(...,
description="历史对话",
example=[
{"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant", "content": "虎头虎脑"}]
),
):
async def chat_iterator(query: str,
history: Optional[List[History]]
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
@ -23,17 +34,13 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
model_name=LLM_MODEL
)
# llm = OpenAI(model_name=LLM_MODEL,
# openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
# openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
# streaming=True)
prompt = PromptTemplate(input_variables=["input"], template="{input}")
chain = LLMChain(prompt=prompt, llm=model)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", "{input}")])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
chain.acall(message),
chain.acall({"input": query}),
callback.done),
)
@ -41,4 +48,6 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
# Use server-sent-events to stream the response
yield token
await task
return StreamingResponse(chat_iterator(query), media_type="text/event-stream")
return StreamingResponse(chat_iterator(query, history),
media_type="text/event-stream")
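A minimal client sketch exercising the new history parameter. The host, port, and /chat/chat path are assumptions (the route registration is not part of this hunk); the payload mirrors the Body fields above:
import requests
# Hypothetical URL; adjust to wherever the chat handler is actually mounted.
url = "http://127.0.0.1:7861/chat/chat"
payload = {
    "query": "恼羞成怒",
    "history": [
        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
        {"role": "assistant", "content": "虎头虎脑"},
    ],
}
# The endpoint streams tokens back as server-sent events.
with requests.post(url, json=payload, stream=True) as resp:
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)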

View File

@ -9,7 +9,9 @@ from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
import json
@ -17,6 +19,14 @@ import json
def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
history: Optional[List[History]] = Body(...,
description="历史对话",
example=[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]
),
):
kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
if kb is None:
@ -25,6 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
async def knowledge_base_chat_iterator(query: str,
kb: KBService,
top_k: int,
history: Optional[List[History]],
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
@ -37,9 +48,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
)
docs = kb.search_docs(query, top_k)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
chain = LLMChain(prompt=prompt, llm=model)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
@ -58,5 +71,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
"docs": source_documents})
await task
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
media_type="text/event-stream")
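As a standalone sketch of how the prompt is assembled here: each history entry becomes a ("human", ...) or ("ai", ...) tuple via History.to_msg_tuple (defined in server/chat/utils.py below), and PROMPT_TEMPLATE is appended as the final human message. The template text in this sketch is a stand-in; the real PROMPT_TEMPLATE lives in the configs and, judging by the removed PromptTemplate(input_variables=["context", "question"]), takes {context} and {question}:
from langchain.prompts.chat import ChatPromptTemplate
# Stand-in for the configs PROMPT_TEMPLATE; the placeholder names are assumptions.
PROMPT_TEMPLATE = "已知信息:\n{context}\n\n请回答:{question}"
history = [("human", "我们来玩成语接龙,我先来,生龙活虎"), ("ai", "虎头虎脑")]
chat_prompt = ChatPromptTemplate.from_messages(history + [("human", PROMPT_TEMPLATE)])
# format_messages fills the placeholders of the final human message only.
messages = chat_prompt.format_messages(context="(检索到的文档内容)", question="接下来该接哪个成语?")
for m in messages:
    print(type(m).__name__, m.content)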

View File

@ -10,7 +10,9 @@ from langchain import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from typing import AsyncIterable
import asyncio
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ChatPromptTemplate
from typing import List, Optional
from server.chat.utils import History
from langchain.docstore.document import Document
import json
@ -58,6 +60,14 @@ def lookup_search_engine(
def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
history: Optional[List[History]] = Body(...,
description="历史对话",
example=[
{"role": "user",
"content": "我们来玩成语接龙,我先来,生龙活虎"},
{"role": "assistant",
"content": "虎头虎脑"}]
),
):
if search_engine_name not in SEARCH_ENGINES.keys():
return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@ -65,6 +75,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
async def search_engine_chat_iterator(query: str,
search_engine_name: str,
top_k: int,
history: Optional[List[History]],
) -> AsyncIterable[str]:
callback = AsyncIteratorCallbackHandler()
model = ChatOpenAI(
@ -78,9 +89,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
docs = lookup_search_engine(query, search_engine_name, top_k)
context = "\n".join([doc.page_content for doc in docs])
prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
chain = LLMChain(prompt=prompt, llm=model)
chat_prompt = ChatPromptTemplate.from_messages(
[i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
chain = LLMChain(prompt=chat_prompt, llm=model)
# Begin a task that runs in the background.
task = asyncio.create_task(wrap_done(
@ -99,5 +112,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
"docs": source_documents})
await task
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
media_type="text/event-stream")
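On the client side, the knowledge-base and search-engine chats stream chunks that appear to carry both the answer and the matched source documents (note the "docs": source_documents yield above). A hedged sketch of consuming such a stream; the URL and the exact chunk schema are assumptions:
import json
import requests
# Hypothetical URL; adjust to wherever search_engine_chat is mounted.
url = "http://127.0.0.1:7861/chat/search_engine_chat"
payload = {"query": "你好", "search_engine_name": "duckduckgo", "top_k": 3,
           "history": []}
with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line:
            continue
        try:
            data = json.loads(line)  # assumed shape: {"answer": ..., "docs": [...]}
            print(data.get("answer", ""), end="", flush=True)
        except json.JSONDecodeError:
            print(line, end="", flush=True)  # fall back to raw token chunks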

View File

@ -1,5 +1,6 @@
import asyncio
from typing import Awaitable
from pydantic import BaseModel, Field
async def wrap_done(fn: Awaitable, event: asyncio.Event):
@ -11,4 +12,19 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
print(f"Caught exception: {e}")
finally:
# Signal the aiter to stop.
event.set()
event.set()
class History(BaseModel):
"""
对话历史
可从dict生成
h = History(**{"role":"user","content":"你好"})
也可转换为tuple
h.to_msy_tuple = ("human", "你好")
"""
role: str = Field(...)
content: str = Field(...)
def to_msg_tuple(self):
return "ai" if self.role=="assistant" else "human", self.content