import nltk
import sys
import os

sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from configs.model_config import LLM_MODEL, NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN, HTTPX_DEFAULT_TIMEOUT
from configs import VERSION, logger, log_verbose
import argparse
import uvicorn
from fastapi import Body
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
                         search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                              update_docs, download_doc, recreate_vector_store,
                                              search_docs, DocumentWithScore)
from server.utils import BaseResponse, ListResponse, FastAPI, MakeFastAPIOffline, fschat_controller_address
import httpx
from typing import List

nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path

async def document():
    return RedirectResponse(url="/docs")


def create_app():
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION
    )
    MakeFastAPIOffline(app)

    # Add CORS middleware to allow all origins;
    # set OPEN_CROSS_DOMAIN=True in configs/server_config.py to enable cross-origin requests
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    # Tag: Chat
    app.post("/chat/fastchat",
             tags=["Chat"],
             summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)

    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)")(chat)

    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledge_base_chat)

    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话")(search_engine_chat)

    # Tag: Knowledge Base Management
    app.get("/knowledge_base/list_knowledge_bases",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库列表")(list_kbs)

    app.post("/knowledge_base/create_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="创建知识库"
             )(create_kb)

    app.post("/knowledge_base/delete_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库"
             )(delete_kb)

    app.get("/knowledge_base/list_files",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库内的文件列表"
            )(list_files)

    app.post("/knowledge_base/search_docs",
             tags=["Knowledge Base Management"],
             response_model=List[DocumentWithScore],
             summary="搜索知识库"
             )(search_docs)

    app.post("/knowledge_base/upload_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="上传文件到知识库,并/或进行向量化"
             )(upload_docs)

    app.post("/knowledge_base/delete_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库内指定文件"
             )(delete_docs)

    app.post("/knowledge_base/update_docs",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新现有文件到知识库"
             )(update_docs)

    app.get("/knowledge_base/download_doc",
            tags=["Knowledge Base Management"],
            summary="下载对应的知识文件")(download_doc)

    app.post("/knowledge_base/recreate_vector_store",
             tags=["Knowledge Base Management"],
             summary="根据content中文档重建向量库,流式输出处理进度。"
             )(recreate_vector_store)

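    # Client-side sketch (illustrative only, not executed by the server): with the
    # API running on the argparse defaults at the bottom of this file (port 7861),
    # the knowledge-base endpoints above can be exercised with httpx, e.g.
    #   httpx.get("http://127.0.0.1:7861/knowledge_base/list_knowledge_bases").json()
    # which should return a ListResponse-shaped JSON payload.
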
    # LLM model management endpoints
    @app.post("/llm_model/list_models",
              tags=["LLM Model Management"],
              summary="列出当前已加载的模型")
    def list_models(
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
    ) -> BaseResponse:
        '''
        Fetch the list of currently loaded models from the fastchat controller.
        '''
        try:
            controller_address = controller_address or fschat_controller_address()
            r = httpx.post(controller_address + "/list_models")
            return BaseResponse(data=r.json()["models"])
        except Exception as e:
            logger.error(f'{e.__class__.__name__}: {e}',
                         exc_info=e if log_verbose else None)
            return BaseResponse(
                code=500,
                data=[],
                msg=f"failed to get available models from controller: {controller_address}。错误信息是: {e}")

    @app.post("/llm_model/stop",
              tags=["LLM Model Management"],
              summary="停止指定的LLM模型(Model Worker)",
              )
    def stop_llm_model(
        model_name: str = Body(..., description="要停止的LLM模型名称", examples=[LLM_MODEL]),
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
    ) -> BaseResponse:
        '''
        Ask the fastchat controller to stop an LLM model.
        Note: because of how Fastchat is implemented, this actually stops the model_worker hosting the model.
        '''
        try:
            controller_address = controller_address or fschat_controller_address()
            r = httpx.post(
                controller_address + "/release_worker",
                json={"model_name": model_name},
            )
            return r.json()
        except Exception as e:
            logger.error(f'{e.__class__.__name__}: {e}',
                         exc_info=e if log_verbose else None)
            return BaseResponse(
                code=500,
                msg=f"failed to stop LLM model {model_name} from controller: {controller_address}。错误信息是: {e}")

    @app.post("/llm_model/change",
              tags=["LLM Model Management"],
              summary="切换指定的LLM模型(Model Worker)",
              )
    def change_llm_model(
        model_name: str = Body(..., description="当前运行模型", examples=[LLM_MODEL]),
        new_model_name: str = Body(..., description="要切换的新模型", examples=[LLM_MODEL]),
        controller_address: str = Body(None, description="Fastchat controller服务器地址", examples=[fschat_controller_address()])
    ):
        '''
        Ask the fastchat controller to switch to a different LLM model.
        '''
        try:
            controller_address = controller_address or fschat_controller_address()
            r = httpx.post(
                controller_address + "/release_worker",
                json={"model_name": model_name, "new_model_name": new_model_name},
                timeout=HTTPX_DEFAULT_TIMEOUT,  # wait for the new model_worker to start
            )
            return r.json()
        except Exception as e:
            logger.error(f'{e.__class__.__name__}: {e}',
                         exc_info=e if log_verbose else None)
            return BaseResponse(
                code=500,
                msg=f"failed to switch LLM model from controller: {controller_address}。错误信息是: {e}")

    return app


app = create_app()
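# The module-level `app` above can also be served by an external ASGI runner,
# e.g. (assuming this file lives at server/api.py under the project root):
#   uvicorn server.api:app --host 0.0.0.0 --port 7861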


def run_api(host, port, **kwargs):
    if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"):
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=kwargs.get("ssl_keyfile"),
                    ssl_certfile=kwargs.get("ssl_certfile"),
                    )
    else:
        uvicorn.run(app, host=host, port=port)

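# Usage sketch (file and certificate paths are placeholders): HTTPS is enabled only
# when both SSL arguments are given, otherwise the server falls back to plain HTTP:
#   python server/api.py --host 0.0.0.0 --port 7861 \
#       --ssl_keyfile /path/to/key.pem --ssl_certfile /path/to/cert.pem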
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int, default=7861)
    parser.add_argument("--ssl_keyfile", type=str)
    parser.add_argument("--ssl_certfile", type=str)
    # parse command line arguments
    args = parser.parse_args()
    args_dict = vars(args)

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )