147 lines
5.0 KiB
Python
147 lines
5.0 KiB
Python
import nltk
import sys
import os

# Put the directory above this file's directory on the import path so the
# project-local `configs` and `server` packages imported below can be found.
# This must run before those imports.
sys.path.append(os.path.dirname(os.path.dirname(__file__)))

from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
# Project-provided application class, imported under the name `FastAPI`.
from server.utils import FastAPIOffline as FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from server.chat import (chat, knowledge_base_chat, openai_chat,
                         search_engine_chat)
from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
from server.knowledge_base.kb_doc_api import (list_docs, upload_doc, delete_doc,
                                              update_doc, download_doc, recreate_vector_store)
from server.utils import BaseResponse, ListResponse

# Prepend the configured NLTK data directory so NLTK searches it before its
# default locations.
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||
|
||
|
||
async def document():
    """Send clients hitting the root URL to the Swagger UI page at ``/docs``."""
    swagger_url = "/docs"
    return RedirectResponse(url=swagger_url)
|
||
|
||
|
||
def create_app():
    """Build the FastAPI application and register every HTTP route.

    Routes are grouped under the "Chat" and "Knowledge Base Management"
    OpenAPI tags. CORS middleware is installed only when
    ``OPEN_CROSS_DOMAIN`` (from ``configs.model_config``) is true.

    Returns:
        The configured application instance.
    """
    app = FastAPI()

    # Cross-origin access is opt-in: set OPEN_CROSS_DOMAIN = True in
    # configs/model_config.py to allow requests from any origin.
    # NOTE(review): allow_origins=["*"] together with allow_credentials=True
    # effectively lets any site make credentialed requests (Starlette echoes
    # the request Origin in that case) -- confirm against the Starlette CORS
    # docs and prefer an explicit origin list if credentials matter.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # The root path only redirects to the Swagger UI. Declare it as a
    # redirect instead of advertising a JSON BaseResponse body it never sends.
    app.get("/",
            response_class=RedirectResponse,
            summary="swagger 文档")(document)

    # Tag: Chat
    app.post("/chat/fastchat",
             tags=["Chat"],
             summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)

    app.post("/chat/chat",
             tags=["Chat"],
             summary="与llm模型对话(通过LLMChain)")(chat)

    app.post("/chat/knowledge_base_chat",
             tags=["Chat"],
             summary="与知识库对话")(knowledge_base_chat)

    app.post("/chat/search_engine_chat",
             tags=["Chat"],
             summary="与搜索引擎对话")(search_engine_chat)

    # Tag: Knowledge Base Management
    app.get("/knowledge_base/list_knowledge_bases",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库列表")(list_kbs)

    app.post("/knowledge_base/create_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="创建知识库"
             )(create_kb)

    app.post("/knowledge_base/delete_knowledge_base",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库"
             )(delete_kb)

    app.get("/knowledge_base/list_docs",
            tags=["Knowledge Base Management"],
            response_model=ListResponse,
            summary="获取知识库内的文件列表"
            )(list_docs)

    app.post("/knowledge_base/upload_doc",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="上传文件到知识库"
             )(upload_doc)

    app.post("/knowledge_base/delete_doc",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="删除知识库内指定文件"
             )(delete_doc)

    app.post("/knowledge_base/update_doc",
             tags=["Knowledge Base Management"],
             response_model=BaseResponse,
             summary="更新现有文件到知识库"
             )(update_doc)

    # No response_model here: this endpoint is registered without one.
    app.get("/knowledge_base/download_doc",
            tags=["Knowledge Base Management"],
            summary="下载对应的知识文件")(download_doc)

    app.post("/knowledge_base/recreate_vector_store",
             tags=["Knowledge Base Management"],
             summary="根据content中文档重建向量库,流式输出处理进度。"
             )(recreate_vector_store)

    return app
|
||
|
||
|
||
# Build the application once at import time; run_api() serves this instance.
app = create_app()
|
||
|
||
|
||
def run_api(host, port, **kwargs):
    """Serve the module-level ``app`` with uvicorn.

    Args:
        host: Interface address to bind.
        port: TCP port to listen on.
        **kwargs: Optional ``ssl_keyfile`` and ``ssl_certfile`` paths. TLS is
            enabled only when BOTH are provided; other keywords are ignored.
    """
    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")

    if ssl_keyfile and ssl_certfile:
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    )
        return

    if ssl_keyfile or ssl_certfile:
        # A half-configured TLS setup used to be dropped silently, leaving the
        # server on plain HTTP. Keep that fallback, but make it visible.
        import warnings
        warnings.warn(
            "Both ssl_keyfile and ssl_certfile are required to enable HTTPS; "
            "only one was given, so the server is starting without TLS.",
            stacklevel=2,
        )
    uvicorn.run(app, host=host, port=port)
|
||
|
||
|
||
if __name__ == "__main__":
|
||
parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
|
||
description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
|
||
' | 基于本地知识库的 ChatGLM 问答')
|
||
parser.add_argument("--host", type=str, default="0.0.0.0")
|
||
parser.add_argument("--port", type=int, default=7861)
|
||
parser.add_argument("--ssl_keyfile", type=str)
|
||
parser.add_argument("--ssl_certfile", type=str)
|
||
# 初始化消息
|
||
args = parser.parse_args()
|
||
args_dict = vars(args)
|
||
run_api(host=args.host,
|
||
port=args.port,
|
||
ssl_keyfile=args.ssl_keyfile,
|
||
ssl_certfile=args.ssl_certfile,
|
||
)
|