add history to chat apis
parent c8b6d53ede
commit 2d49746a8d

api_one.py (deleted)
@@ -1,157 +0,0 @@
import sys
import os
from configs.model_config import (
    llm_model_dict,
    LLM_MODEL,
    EMBEDDING_DEVICE,
    LLM_DEVICE,
    LOG_PATH,
    logger,
)
from fastchat.serve.model_worker import heart_beat_worker
from fastapi import FastAPI
import threading
import asyncio

# Combine fastchat's three servers into a single app:
# - http://{host_ip}:{port}/controller maps to python -m fastchat.serve.controller
# - http://{host_ip}:{port}/model_worker maps to python -m fastchat.serve.model_worker ...
# - http://{host_ip}:{port}/openai maps to python -m fastchat.serve.openai_api_server ...


host_ip = "0.0.0.0"
port = 8888
base_url = f"http://127.0.0.1:{port}"


def create_controller_app(
        dispatch_method="shortest_queue",
):
    from fastchat.serve.controller import app, Controller
    controller = Controller(dispatch_method)
    sys.modules["fastchat.serve.controller"].controller = controller
    logger.info(f"controller dispatch method: {dispatch_method}")
    return app, controller


def create_model_worker_app(
        model_path,
        model_names=[LLM_MODEL],
        device=LLM_DEVICE,
        load_8bit=False,
        gptq_ckpt=None,
        gptq_wbits=16,
        gpus='',
        num_gpus=-1,
        max_gpu_memory=-1,
        cpu_offloading=None,
        worker_address=f"{base_url}/model_worker",
        controller_address=f"{base_url}/controller",
        limit_model_concurrency=5,
        stream_interval=2,
        no_register=True,  # manually register
):
    from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id
    from fastchat.serve import model_worker
    if gpus and num_gpus is None:
        num_gpus = len(gpus.split(','))
    gptq_config = GptqConfig(
        ckpt=gptq_ckpt or model_path,
        wbits=gptq_wbits,
        groupsize=-1,
        act_order=None,
    )
    worker = ModelWorker(
        controller_address,
        worker_address,
        worker_id,
        no_register,
        model_path,
        model_names,
        device,
        num_gpus,
        max_gpu_memory,
        load_8bit,
        cpu_offloading,
        gptq_config,
    )
    sys.modules["fastchat.serve.model_worker"].worker = worker
    return app, worker


def create_openai_api_app(
        host=host_ip,
        port=port,
        controller_address=f"{base_url}/controller",
        allow_credentials=None,
        allowed_origins=["*"],
        allowed_methods=["*"],
        allowed_headers=["*"],
        api_keys=[],
):
    from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=allowed_methods,
        allow_headers=allowed_headers,
    )
    app_settings.controller_address = controller_address
    app_settings.api_keys = api_keys
    sys.modules["fastchat.serve.openai_api_server.app_settings"] = app_settings

    return app


LLM_MODEL = 'chatglm-6b'
model_path = llm_model_dict[LLM_MODEL]["local_model_path"]
global controller


if not model_path:
    logger.error("local_model_path 不能为空")
else:
    logger.info(f"using local model: {model_path}")
    app = FastAPI()

    controller_app, controller = create_controller_app()
    app.mount("/controller", controller_app)

    model_worker_app, worker = create_model_worker_app(model_path)
    app.mount("/model_worker", model_worker_app)

    openai_api_app = create_openai_api_app()
    app.mount("/openai", openai_api_app)


    @app.on_event("startup")
    async def on_startup():
        logger.info("Register to controller")
        controller.register_worker(
            worker.worker_addr,
            True,
            worker.get_status(),
        )
        worker.heart_beat_thread = threading.Thread(
            target=heart_beat_worker, args=(worker,)
        )
        worker.heart_beat_thread.start()

        # Registering via a network request would hang, so it is done in-process above.
        # async def register():
        #     while True:
        #         try:
        #             worker.register_to_controller()
        #             worker.heart_beat_thread = threading.Thread(
        #                 target=heart_beat_worker, args=(worker,)
        #             )
        #             worker.heart_beat_thread.start()
        #         except:
        #             await asyncio.sleep(1)
        # asyncio.get_event_loop().create_task(register())


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host=host_ip, port=port)

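For context on what is removed here: api_one.py mounted fastchat's controller, model worker, and OpenAI API server under one FastAPI app, so a client could reach the OpenAI-compatible endpoints through the /openai mount. A minimal client sketch, assuming fastchat's usual /v1/chat/completions route and the port configured above:

    import requests

    # The /openai prefix comes from app.mount("/openai", ...) in api_one.py;
    # /v1/chat/completions is fastchat's standard OpenAI-style route (assumed here).
    resp = requests.post(
        "http://127.0.0.1:8888/openai/v1/chat/completions",
        json={
            "model": "chatglm-6b",
            "messages": [{"role": "user", "content": "你好"}],
        },
    )
    print(resp.json())
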
@@ -1,4 +1,4 @@
-langchain==0.0.237
+langchain==0.0.257
 openai
 sentence_transformers
 chromadb

@@ -11,9 +11,9 @@ from fastapi.middleware.cors import CORSMiddleware
 from starlette.responses import RedirectResponse
 from server.chat import (chat, knowledge_base_chat, openai_chat,
                          search_engine_chat)
-# from server.knowledge_base import (list_kbs, create_kb, delete_kb,
-#                                    list_docs, upload_doc, delete_doc,
-#                                    update_doc, recreate_vector_store)
+from server.knowledge_base import (list_kbs, create_kb, delete_kb,
+                                   list_docs, upload_doc, delete_doc,
+                                   update_doc, recreate_vector_store)
 from server.utils import BaseResponse, ListResponse
 
 nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

@@ -58,52 +58,52 @@ def create_app():
              tags=["Chat"],
              summary="与搜索引擎对话")(search_engine_chat)
 
-    # # Tag: Knowledge Base Management
-    # app.get("/knowledge_base/list_knowledge_bases",
-    #         tags=["Knowledge Base Management"],
-    #         response_model=ListResponse,
-    #         summary="获取知识库列表")(list_kbs)
-    #
-    # app.post("/knowledge_base/create_knowledge_base",
-    #          tags=["Knowledge Base Management"],
-    #          response_model=BaseResponse,
-    #          summary="创建知识库"
-    #          )(create_kb)
-    #
-    # app.delete("/knowledge_base/delete_knowledge_base",
-    #            tags=["Knowledge Base Management"],
-    #            response_model=BaseResponse,
-    #            summary="删除知识库"
-    #            )(delete_kb)
-    #
-    # app.get("/knowledge_base/list_docs",
-    #         tags=["Knowledge Base Management"],
-    #         response_model=ListResponse,
-    #         summary="获取知识库内的文件列表"
-    #         )(list_docs)
-    #
-    # app.post("/knowledge_base/upload_doc",
-    #          tags=["Knowledge Base Management"],
-    #          response_model=BaseResponse,
-    #          summary="上传文件到知识库"
-    #          )(upload_doc)
-    #
-    # app.delete("/knowledge_base/delete_doc",
-    #            tags=["Knowledge Base Management"],
-    #            response_model=BaseResponse,
-    #            summary="删除知识库内的文件"
-    #            )(delete_doc)
-    #
-    # # app.post("/knowledge_base/update_doc",
-    # #          tags=["Knowledge Base Management"],
-    # #          response_model=BaseResponse,
-    # #          summary="上传文件到知识库,并删除另一个文件"
-    # #          )(update_doc)
-    #
-    # app.post("/knowledge_base/recreate_vector_store",
-    #          tags=["Knowledge Base Management"],
-    #          summary="根据content中文档重建向量库,流式输出处理进度。"
-    #          )(recreate_vector_store)
+    # Tag: Knowledge Base Management
+    app.get("/knowledge_base/list_knowledge_bases",
+            tags=["Knowledge Base Management"],
+            response_model=ListResponse,
+            summary="获取知识库列表")(list_kbs)
+
+    app.post("/knowledge_base/create_knowledge_base",
+             tags=["Knowledge Base Management"],
+             response_model=BaseResponse,
+             summary="创建知识库"
+             )(create_kb)
+
+    app.delete("/knowledge_base/delete_knowledge_base",
+               tags=["Knowledge Base Management"],
+               response_model=BaseResponse,
+               summary="删除知识库"
+               )(delete_kb)
+
+    app.get("/knowledge_base/list_docs",
+            tags=["Knowledge Base Management"],
+            response_model=ListResponse,
+            summary="获取知识库内的文件列表"
+            )(list_docs)
+
+    app.post("/knowledge_base/upload_doc",
+             tags=["Knowledge Base Management"],
+             response_model=BaseResponse,
+             summary="上传文件到知识库"
+             )(upload_doc)
+
+    app.delete("/knowledge_base/delete_doc",
+               tags=["Knowledge Base Management"],
+               response_model=BaseResponse,
+               summary="删除知识库内的文件"
+               )(delete_doc)
+
+    # app.post("/knowledge_base/update_doc",
+    #          tags=["Knowledge Base Management"],
+    #          response_model=BaseResponse,
+    #          summary="上传文件到知识库,并删除另一个文件"
+    #          )(update_doc)
+
+    app.post("/knowledge_base/recreate_vector_store",
+             tags=["Knowledge Base Management"],
+             summary="根据content中文档重建向量库,流式输出处理进度。"
+             )(recreate_vector_store)
     return app

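With the knowledge-base routes now registered on the app, they can be exercised directly. A small sketch; the host and port are assumptions, while the path comes from the registration above:

    import requests

    # Returns a ListResponse naming the available knowledge bases.
    resp = requests.get("http://127.0.0.1:7861/knowledge_base/list_knowledge_bases")
    print(resp.json())
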
@@ -7,11 +7,22 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
 from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 
 
-def chat(query: str = Body(..., description="用户输入", example="你好")):
-    async def chat_iterator(message: str) -> AsyncIterable[str]:
+def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"),
+         history: Optional[List[History]] = Body(...,
+                                                 description="历史对话",
+                                                 example=[
+                                                     {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                     {"role": "assistant", "content": "虎头虎脑"}]
+                                                 ),
+         ):
+    async def chat_iterator(query: str,
+                            history: Optional[List[History]]
+                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
 
         model = ChatOpenAI(

@@ -23,17 +34,13 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
             model_name=LLM_MODEL
         )
 
-        # llm = OpenAI(model_name=LLM_MODEL,
-        #              openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
-        #              openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
-        #              streaming=True)
-        prompt = PromptTemplate(input_variables=["input"], template="{input}")
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", "{input}")])
+        chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
-            chain.acall(message),
+            chain.acall({"input": query}),
             callback.done),
         )

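With the example history above, the message list handed to ChatPromptTemplate.from_messages works out to the following (illustrative), since History.to_msg_tuple() maps "assistant" to "ai" and any other role to "human":

    [("human", "我们来玩成语接龙,我先来,生龙活虎"),
     ("ai", "虎头虎脑"),
     ("human", "{input}")]
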
@@ -41,4 +48,6 @@ def chat(query: str = Body(..., description="用户输入", example="你好")):
         # Use server-sent-events to stream the response
         yield token
         await task
-    return StreamingResponse(chat_iterator(query), media_type="text/event-stream")
+
+    return StreamingResponse(chat_iterator(query, history),
+                             media_type="text/event-stream")

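A minimal client sketch for the updated chat endpoint. The route path and port are not shown in this diff, so "/chat/chat" on 127.0.0.1:7861 is an assumption:

    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/chat",  # assumed host, port, and route
        json={
            "query": "虎背熊腰",
            "history": [
                {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
                {"role": "assistant", "content": "虎头虎脑"},
            ],
        },
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")  # tokens stream back as server-sent events
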
@@ -9,7 +9,9 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 from server.knowledge_base.kb_service.base import KBService, KBServiceFactory
 import json

@@ -17,6 +19,14 @@ import json
 def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                         knowledge_base_name: str = Body(..., description="知识库名称", example="samples"),
                         top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"),
+                        history: Optional[List[History]] = Body(...,
+                                                                description="历史对话",
+                                                                example=[
+                                                                    {"role": "user",
+                                                                     "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                                    {"role": "assistant",
+                                                                     "content": "虎头虎脑"}]
+                                                                ),
                         ):
     kb = KBServiceFactory.get_service_by_name(knowledge_base_name)
     if kb is None:

@@ -25,6 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
     async def knowledge_base_chat_iterator(query: str,
                                            kb: KBService,
                                            top_k: int,
+                                           history: Optional[List[History]],
                                            ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(

@@ -37,9 +48,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
         )
         docs = kb.search_docs(query, top_k)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(

@@ -58,5 +71,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"),
                                    "docs": source_documents})
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
                              media_type="text/event-stream")

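knowledge_base_chat (and search_engine_chat below) accepts the same history field alongside its own parameters. A sketch of the request body, with the route path again an assumption:

    import requests

    resp = requests.post(
        "http://127.0.0.1:7861/chat/knowledge_base_chat",  # assumed host, port, and route
        json={
            "query": "你好",
            "knowledge_base_name": "samples",
            "top_k": 3,     # optional; defaults to VECTOR_SEARCH_TOP_K
            "history": [],  # required by Body(...), but may be empty
        },
        stream=True,
    )
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="")  # streamed JSON carrying the answer and the matched "docs"
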
@@ -10,7 +10,9 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 from langchain.docstore.document import Document
 import json

@@ -58,6 +60,14 @@ def lookup_search_engine(
 def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
                        search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                       history: Optional[List[History]] = Body(...,
+                                                               description="历史对话",
+                                                               example=[
+                                                                   {"role": "user",
+                                                                    "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                                   {"role": "assistant",
+                                                                    "content": "虎头虎脑"}]
+                                                               ),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")

@@ -65,6 +75,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
+                                          history: Optional[List[History]],
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(

@@ -78,9 +89,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
 
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(

@@ -99,5 +112,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
                                    "docs": source_documents})
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
                              media_type="text/event-stream")

@@ -1,5 +1,6 @@
 import asyncio
 from typing import Awaitable
+from pydantic import BaseModel, Field
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):

@@ -11,4 +12,19 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
         print(f"Caught exception: {e}")
     finally:
         # Signal the aiter to stop.
         event.set()
+
+
+class History(BaseModel):
+    """
+    Chat history entry.
+    Can be built from a dict, e.g.
+        h = History(**{"role": "user", "content": "你好"})
+    and converted to a message tuple, e.g.
+        h.to_msg_tuple() == ("human", "你好")
+    """
+    role: str = Field(...)
+    content: str = Field(...)
+
+    def to_msg_tuple(self):
+        return "ai" if self.role == "assistant" else "human", self.content

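A quick illustration of the new model, using the values from its docstring:

    from server.chat.utils import History

    h = History(**{"role": "user", "content": "你好"})
    assert h.to_msg_tuple() == ("human", "你好")

    h = History(role="assistant", content="虎头虎脑")
    assert h.to_msg_tuple() == ("ai", "虎头虎脑")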