diff --git a/api_one.py b/api_one.py deleted file mode 100644 index 542a8f0..0000000 --- a/api_one.py +++ /dev/null @@ -1,157 +0,0 @@ -import sys -import os -from configs.model_config import ( - llm_model_dict, - LLM_MODEL, - EMBEDDING_DEVICE, - LLM_DEVICE, - LOG_PATH, - logger, -) -from fastchat.serve.model_worker import heart_beat_worker -from fastapi import FastAPI -import threading -import asyncio - -# 把fastchat的3个服务端整合在一起,分别是: -# - http://{host_ip}:{port}/controller 对应 python -m fastchat.serve.controller -# - http://{host_ip}:{port}/model_worker 对应 python -m fastchat.serve.model_worker ... -# - http://{host_ip}:{port}/openai 对应 python -m fastchat.serve.openai_api_server ... - - -host_ip = "0.0.0.0" -port = 8888 -base_url = f"http://127.0.0.1:{port}" - - -def create_controller_app( - dispatch_method="shortest_queue", -): - from fastchat.serve.controller import app, Controller - controller = Controller(dispatch_method) - sys.modules["fastchat.serve.controller"].controller = controller - logger.info(f"controller dispatch method: {dispatch_method}") - return app, controller - - -def create_model_worker_app( - model_path, - model_names=[LLM_MODEL], - device=LLM_DEVICE, - load_8bit=False, - gptq_ckpt=None, - gptq_wbits=16, - gpus='', - num_gpus=-1, - max_gpu_memory=-1, - cpu_offloading=None, - worker_address=f"{base_url}/model_worker", - controller_address=f"{base_url}/controller", - limit_model_concurrency=5, - stream_interval=2, - no_register=True, # mannually register -): - from fastchat.serve.model_worker import app, GptqConfig, ModelWorker, worker_id - from fastchat.serve import model_worker - if gpus and num_gpus is None: - num_gpus = len(gpus.split(',')) - gptq_config = GptqConfig( - ckpt=gptq_ckpt or model_path, - wbits=gptq_wbits, - groupsize=-1, - act_order=None, - ) - worker = ModelWorker( - controller_address, - worker_address, - worker_id, - no_register, - model_path, - model_names, - device, - num_gpus, - max_gpu_memory, - load_8bit, - cpu_offloading, - gptq_config, - ) - sys.modules["fastchat.serve.model_worker"].worker = worker - return app, worker - - -def create_openai_api_app( - host=host_ip, - port=port, - controller_address=f"{base_url}/controller", - allow_credentials=None, - allowed_origins=["*"], - allowed_methods=["*"], - allowed_headers=["*"], - api_keys=[], -): - from fastchat.serve.openai_api_server import app, CORSMiddleware, app_settings - app.add_middleware( - CORSMiddleware, - allow_origins=allowed_origins, - allow_credentials=allow_credentials, - allow_methods=allowed_methods, - allow_headers=allowed_headers, - ) - app_settings.controller_address = controller_address - app_settings.api_keys = api_keys - sys.modules["fastchat.serve.openai_api_server.app_settings"] = app_settings - - return app - - -LLM_MODEL = 'chatglm-6b' -model_path = llm_model_dict[LLM_MODEL]["local_model_path"] -global controller - - -if not model_path: - logger.error("local_model_path 不能为空") -else: - logger.info(f"using local model: {model_path}") - app = FastAPI() - - controller_app, controller = create_controller_app() - app.mount("/controller", controller_app) - - model_woker_app, worker = create_model_worker_app(model_path) - app.mount("/model_worker", model_woker_app) - - openai_api_app = create_openai_api_app() - app.mount("/openai", openai_api_app) - - - @app.on_event("startup") - async def on_startup(): - logger.info("Register to controller") - controller.register_worker( - worker.worker_addr, - True, - worker.get_status(), - ) - worker.heart_beat_thread = threading.Thread( - 
target=heart_beat_worker, args=(worker,) - ) - worker.heart_beat_thread.start() - - # 通过网络请求注册会卡死 - # async def register(): - # while True: - # try: - # worker.register_to_controller() - # worker.heart_beat_thread = threading.Thread( - # target=heart_beat_worker, args=(worker,) - # ) - # worker.heart_beat_thread.start() - # except: - # await asyncio.sleep(1) - # asyncio.get_event_loop().create_task(register()) - - -if __name__ == "__main__": - import uvicorn - uvicorn.run(app, host=host_ip, port=port) diff --git a/requirements.txt b/requirements.txt index 1fb3519..a46259d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,4 @@ -langchain==0.0.237 +langchain==0.0.257 openai sentence_transformers chromadb diff --git a/server/api.py b/server/api.py index ee3f152..520cbb6 100644 --- a/server/api.py +++ b/server/api.py @@ -11,9 +11,9 @@ from fastapi.middleware.cors import CORSMiddleware from starlette.responses import RedirectResponse from server.chat import (chat, knowledge_base_chat, openai_chat, search_engine_chat) -from server.knowledge_base import (list_kbs, create_kb, delete_kb, - list_docs, upload_doc, delete_doc, - update_doc, recreate_vector_store) +# from server.knowledge_base import (list_kbs, create_kb, delete_kb, +# list_docs, upload_doc, delete_doc, +# update_doc, recreate_vector_store) from server.utils import BaseResponse, ListResponse nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -58,52 +58,52 @@ def create_app(): tags=["Chat"], summary="与搜索引擎对话")(search_engine_chat) - # Tag: Knowledge Base Management - app.get("/knowledge_base/list_knowledge_bases", - tags=["Knowledge Base Management"], - response_model=ListResponse, - summary="获取知识库列表")(list_kbs) - - app.post("/knowledge_base/create_knowledge_base", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="创建知识库" - )(create_kb) - - app.delete("/knowledge_base/delete_knowledge_base", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库" - )(delete_kb) - - app.get("/knowledge_base/list_docs", - tags=["Knowledge Base Management"], - response_model=ListResponse, - summary="获取知识库内的文件列表" - )(list_docs) - - app.post("/knowledge_base/upload_doc", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="上传文件到知识库" - )(upload_doc) - - app.delete("/knowledge_base/delete_doc", - tags=["Knowledge Base Management"], - response_model=BaseResponse, - summary="删除知识库内的文件" - )(delete_doc) - - # app.post("/knowledge_base/update_doc", + # # Tag: Knowledge Base Management + # app.get("/knowledge_base/list_knowledge_bases", + # tags=["Knowledge Base Management"], + # response_model=ListResponse, + # summary="获取知识库列表")(list_kbs) + # + # app.post("/knowledge_base/create_knowledge_base", # tags=["Knowledge Base Management"], # response_model=BaseResponse, - # summary="上传文件到知识库,并删除另一个文件" - # )(update_doc) - - app.post("/knowledge_base/recreate_vector_store", - tags=["Knowledge Base Management"], - summary="根据content中文档重建向量库,流式输出处理进度。" - )(recreate_vector_store) + # summary="创建知识库" + # )(create_kb) + # + # app.delete("/knowledge_base/delete_knowledge_base", + # tags=["Knowledge Base Management"], + # response_model=BaseResponse, + # summary="删除知识库" + # )(delete_kb) + # + # app.get("/knowledge_base/list_docs", + # tags=["Knowledge Base Management"], + # response_model=ListResponse, + # summary="获取知识库内的文件列表" + # )(list_docs) + # + # app.post("/knowledge_base/upload_doc", + # tags=["Knowledge Base Management"], + # response_model=BaseResponse, + # 
summary="上传文件到知识库" + # )(upload_doc) + # + # app.delete("/knowledge_base/delete_doc", + # tags=["Knowledge Base Management"], + # response_model=BaseResponse, + # summary="删除知识库内的文件" + # )(delete_doc) + # + # # app.post("/knowledge_base/update_doc", + # # tags=["Knowledge Base Management"], + # # response_model=BaseResponse, + # # summary="上传文件到知识库,并删除另一个文件" + # # )(update_doc) + # + # app.post("/knowledge_base/recreate_vector_store", + # tags=["Knowledge Base Management"], + # summary="根据content中文档重建向量库,流式输出处理进度。" + # )(recreate_vector_store) return app diff --git a/server/chat/chat.py b/server/chat/chat.py index 3239cda..27d880e 100644 --- a/server/chat/chat.py +++ b/server/chat/chat.py @@ -7,11 +7,22 @@ from langchain import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio -from langchain.prompts import PromptTemplate +from langchain.prompts.chat import ChatPromptTemplate +from typing import List, Optional +from server.chat.utils import History -def chat(query: str = Body(..., description="用户输入", example="你好")): - async def chat_iterator(message: str) -> AsyncIterable[str]: +def chat(query: str = Body(..., description="用户输入", example="恼羞成怒"), + history: Optional[List[History]] = Body(..., + description="历史对话", + example=[ + {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", "content": "虎头虎脑"}] + ), + ): + async def chat_iterator(query: str, + history: Optional[List[History]] + ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() model = ChatOpenAI( @@ -23,17 +34,13 @@ def chat(query: str = Body(..., description="用户输入", example="你好")): model_name=LLM_MODEL ) - # llm = OpenAI(model_name=LLM_MODEL, - # openai_api_key=llm_model_dict[LLM_MODEL]["api_key"], - # openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"], - # streaming=True) - - prompt = PromptTemplate(input_variables=["input"], template="{input}") - chain = LLMChain(prompt=prompt, llm=model) + chat_prompt = ChatPromptTemplate.from_messages( + [i.to_msg_tuple() for i in history] + [("human", "{input}")]) + chain = LLMChain(prompt=chat_prompt, llm=model) # Begin a task that runs in the background. 
task = asyncio.create_task(wrap_done( - chain.acall(message), + chain.acall({"input": query}), callback.done), ) @@ -41,4 +48,6 @@ def chat(query: str = Body(..., description="用户输入", example="你好")): # Use server-sent-events to stream the response yield token await task - return StreamingResponse(chat_iterator(query), media_type="text/event-stream") \ No newline at end of file + + return StreamingResponse(chat_iterator(query, history), + media_type="text/event-stream") diff --git a/server/chat/knowledge_base_chat.py b/server/chat/knowledge_base_chat.py index 5723d08..1a00abf 100644 --- a/server/chat/knowledge_base_chat.py +++ b/server/chat/knowledge_base_chat.py @@ -9,7 +9,9 @@ from langchain import LLMChain from langchain.callbacks import AsyncIteratorCallbackHandler from typing import AsyncIterable import asyncio -from langchain.prompts import PromptTemplate +from langchain.prompts.chat import ChatPromptTemplate +from typing import List, Optional +from server.chat.utils import History from server.knowledge_base.kb_service.base import KBService, KBServiceFactory import json @@ -17,6 +19,14 @@ import json def knowledge_base_chat(query: str = Body(..., description="用户输入", example="你好"), knowledge_base_name: str = Body(..., description="知识库名称", example="samples"), top_k: int = Body(VECTOR_SEARCH_TOP_K, description="匹配向量数"), + history: Optional[List[History]] = Body(..., + description="历史对话", + example=[ + {"role": "user", + "content": "我们来玩成语接龙,我先来,生龙活虎"}, + {"role": "assistant", + "content": "虎头虎脑"}] + ), ): kb = KBServiceFactory.get_service_by_name(knowledge_base_name) if kb is None: @@ -25,6 +35,7 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp async def knowledge_base_chat_iterator(query: str, kb: KBService, top_k: int, + history: Optional[List[History]], ) -> AsyncIterable[str]: callback = AsyncIteratorCallbackHandler() model = ChatOpenAI( @@ -37,9 +48,11 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp ) docs = kb.search_docs(query, top_k) context = "\n".join([doc.page_content for doc in docs]) - prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"]) - chain = LLMChain(prompt=prompt, llm=model) + chat_prompt = ChatPromptTemplate.from_messages( + [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)]) + + chain = LLMChain(prompt=chat_prompt, llm=model) # Begin a task that runs in the background. 
         task = asyncio.create_task(wrap_done(
@@ -58,5 +71,5 @@ def knowledge_base_chat(query: str = Body(..., description="用户输入", examp
                                             "docs": source_documents})
         await task
 
-    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k),
+    return StreamingResponse(knowledge_base_chat_iterator(query, kb, top_k, history),
                              media_type="text/event-stream")
diff --git a/server/chat/search_engine_chat.py b/server/chat/search_engine_chat.py
index 4689d58..70e7a3a 100644
--- a/server/chat/search_engine_chat.py
+++ b/server/chat/search_engine_chat.py
@@ -10,7 +10,9 @@ from langchain import LLMChain
 from langchain.callbacks import AsyncIteratorCallbackHandler
 from typing import AsyncIterable
 import asyncio
-from langchain.prompts import PromptTemplate
+from langchain.prompts.chat import ChatPromptTemplate
+from typing import List, Optional
+from server.chat.utils import History
 from langchain.docstore.document import Document
 import json
 
@@ -58,6 +60,14 @@ def lookup_search_engine(
 def search_engine_chat(query: str = Body(..., description="用户输入", example="你好"),
                        search_engine_name: str = Body(..., description="搜索引擎名称", example="duckduckgo"),
                        top_k: int = Body(SEARCH_ENGINE_TOP_K, description="检索结果数量"),
+                       history: Optional[List[History]] = Body(...,
+                                                               description="历史对话",
+                                                               example=[
+                                                                   {"role": "user",
+                                                                    "content": "我们来玩成语接龙,我先来,生龙活虎"},
+                                                                   {"role": "assistant",
+                                                                    "content": "虎头虎脑"}]
+                                                               ),
                        ):
     if search_engine_name not in SEARCH_ENGINES.keys():
         return BaseResponse(code=404, msg=f"未支持搜索引擎 {search_engine_name}")
@@ -65,6 +75,7 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
     async def search_engine_chat_iterator(query: str,
                                           search_engine_name: str,
                                           top_k: int,
+                                          history: Optional[List[History]],
                                           ) -> AsyncIterable[str]:
         callback = AsyncIteratorCallbackHandler()
         model = ChatOpenAI(
@@ -78,9 +89,11 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
         docs = lookup_search_engine(query, search_engine_name, top_k)
         context = "\n".join([doc.page_content for doc in docs])
 
-        prompt = PromptTemplate(template=PROMPT_TEMPLATE, input_variables=["context", "question"])
-        chain = LLMChain(prompt=prompt, llm=model)
+        chat_prompt = ChatPromptTemplate.from_messages(
+            [i.to_msg_tuple() for i in history] + [("human", PROMPT_TEMPLATE)])
+
+        chain = LLMChain(prompt=chat_prompt, llm=model)
 
         # Begin a task that runs in the background.
         task = asyncio.create_task(wrap_done(
@@ -99,5 +112,5 @@ def search_engine_chat(query: str = Body(..., description="用户输入", exampl
                                              "docs": source_documents})
         await task
 
-    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k),
+    return StreamingResponse(search_engine_chat_iterator(query, search_engine_name, top_k, history),
                              media_type="text/event-stream")
diff --git a/server/chat/utils.py b/server/chat/utils.py
index 32d8152..f8afb10 100644
--- a/server/chat/utils.py
+++ b/server/chat/utils.py
@@ -1,5 +1,6 @@
 import asyncio
 from typing import Awaitable
+from pydantic import BaseModel, Field
 
 
 async def wrap_done(fn: Awaitable, event: asyncio.Event):
@@ -11,4 +12,19 @@ async def wrap_done(fn: Awaitable, event: asyncio.Event):
         print(f"Caught exception: {e}")
     finally:
         # Signal the aiter to stop.
-        event.set()
\ No newline at end of file
+        event.set()
+
+
+class History(BaseModel):
+    """
+    对话历史
+    可从dict生成,如
+    h = History(**{"role":"user","content":"你好"})
+    也可转换为tuple,如
+    h.to_msg_tuple() == ("human", "你好")
+    """
+    role: str = Field(...)
+    content: str = Field(...)
+ + def to_msg_tuple(self): + return "ai" if self.role=="assistant" else "human", self.content
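
Note (not part of the diff): the history handling above is spread across several files, so here is a minimal, self-contained sketch of how the new `History` model and the `ChatPromptTemplate.from_messages` wiring fit together. It only builds and formats the prompt, so no model call or API key is needed, and it assumes langchain==0.0.257 as pinned in requirements.txt.

```python
# Hedged sketch only -- it mirrors the diff above rather than adding anything to the repo.
# Assumes langchain==0.0.257 (as pinned in requirements.txt) and pydantic v1.
from typing import List

from langchain.prompts.chat import ChatPromptTemplate
from pydantic import BaseModel, Field


class History(BaseModel):
    """Mirror of server/chat/utils.History: one prior turn of the dialogue."""
    role: str = Field(...)
    content: str = Field(...)

    def to_msg_tuple(self):
        # from_messages accepts ("ai" | "human", text) tuples for chat messages.
        return "ai" if self.role == "assistant" else "human", self.content


history: List[History] = [
    History(role="user", content="我们来玩成语接龙,我先来,生龙活虎"),
    History(role="assistant", content="虎头虎脑"),
]

# Same construction as the new chat.py: prior turns first, then the new user input.
chat_prompt = ChatPromptTemplate.from_messages(
    [h.to_msg_tuple() for h in history] + [("human", "{input}")]
)
print(chat_prompt.format_messages(input="恼羞成怒"))
# -> [HumanMessage(...), AIMessage(...), HumanMessage(content="恼羞成怒")]
```

In knowledge_base_chat.py and search_engine_chat.py the final tuple is ("human", PROMPT_TEMPLATE) instead of ("human", "{input}"), so the retrieved context and the question are injected into that last message while the prior turns stay ahead of it.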
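
A hypothetical client call against the updated chat endpoint is sketched below; the /chat/chat path and 127.0.0.1:7861 are assumptions, since the route registrations and server port are not shown in this diff. The `history` field matches the new Body parameter, and the answer streams back token by token.

```python
# Hypothetical client call; adjust the URL to however server/api.py mounts the chat routes.
import requests

payload = {
    "query": "恼羞成怒",
    "history": [
        {"role": "user", "content": "我们来玩成语接龙,我先来,生龙活虎"},
        {"role": "assistant", "content": "虎头虎脑"},
    ],
}

with requests.post("http://127.0.0.1:7861/chat/chat", json=payload, stream=True) as resp:
    resp.raise_for_status()
    # The endpoint streams tokens via StreamingResponse (text/event-stream).
    for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
        print(chunk, end="", flush=True)
```

knowledge_base_chat and search_engine_chat accept the same `history` field alongside their existing parameters (knowledge_base_name/top_k and search_engine_name/top_k respectively).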