# Langchain-Chatchat/server/api.py
# Source listing: 114 lines, 4.1 KiB, Python.
# Timestamp shown in the source view: 2023-07-27 23:22:07 +08:00
import nltk
from configs.model_config import NLTK_DATA_PATH, OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse, StreamingResponse
from server.chat import chat, knowledge_base_chat, openai_chat
from server.knowledge_base import (list_kbs, create_kb, delete_kb,
list_docs, upload_doc, delete_doc, update_doc)
from server.utils import BaseResponse, ListResponse
# Search the NLTK_DATA_PATH directory from configs.model_config before any of
# NLTK's existing search locations.
nltk.data.path = [NLTK_DATA_PATH, *nltk.data.path]
async def document():
    """Redirect the root URL to the interactive Swagger UI served at ``/docs``."""
    target = "/docs"
    return RedirectResponse(url=target)
def _create_app():
    """Build the FastAPI application with optional CORS and every API route.

    Returns:
        FastAPI: the configured application, not yet being served.
    """
    fastapi_app = FastAPI()

    # Permissive CORS is opt-in: set OPEN_CROSS_DOMAIN = True in
    # configs/model_config.py to allow cross-origin requests.
    # NOTE(review): with allow_origins=["*"] plus allow_credentials=True,
    # Starlette echoes the caller's Origin header, so any site can send
    # credentialed requests -- confirm this is intended.
    if OPEN_CROSS_DOMAIN:
        fastapi_app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    # The root path redirects to the Swagger UI.
    # NOTE(review): the handler returns a Response directly, so FastAPI skips
    # response_model validation; BaseResponse only appears in the OpenAPI schema.
    fastapi_app.get("/",
                    response_model=BaseResponse,
                    summary="swagger 文档")(document)

    # --- Chat endpoints ---
    fastapi_app.post("/chat/fastchat",
                     tags=["Chat"],
                     summary="与llm模型对话(直接与fastchat api对话)")(openai_chat)

    fastapi_app.post("/chat/chat",
                     tags=["Chat"],
                     summary="与llm模型对话(通过LLMChain)")(chat)

    fastapi_app.post("/chat/knowledge_base_chat",
                     tags=["Chat"],
                     summary="与知识库对话")(knowledge_base_chat)

    # Bing search chat is currently disabled:
    # app.post("/chat/bing_search_chat", tags=["Chat"], summary="与Bing搜索对话")(bing_search_chat)

    # --- Knowledge base management endpoints ---
    fastapi_app.get("/knowledge_base/list_knowledge_bases",
                    tags=["Knowledge Base Management"],
                    response_model=ListResponse,
                    summary="获取知识库列表")(list_kbs)

    fastapi_app.post("/knowledge_base/create_knowledge_base",
                     tags=["Knowledge Base Management"],
                     response_model=BaseResponse,
                     summary="创建知识库"
                     )(create_kb)

    fastapi_app.delete("/knowledge_base/delete_knowledge_base",
                       tags=["Knowledge Base Management"],
                       response_model=BaseResponse,
                       summary="删除知识库"
                       )(delete_kb)

    fastapi_app.get("/knowledge_base/list_docs",
                    tags=["Knowledge Base Management"],
                    response_model=ListResponse,
                    summary="获取知识库内的文件列表"
                    )(list_docs)

    fastapi_app.post("/knowledge_base/upload_doc",
                     tags=["Knowledge Base Management"],
                     response_model=BaseResponse,
                     summary="上传文件到知识库"
                     )(upload_doc)

    fastapi_app.delete("/knowledge_base/delete_doc",
                       tags=["Knowledge Base Management"],
                       response_model=BaseResponse,
                       summary="删除知识库内的文件"
                       )(delete_doc)

    fastapi_app.post("/knowledge_base/update_doc",
                     tags=["Knowledge Base Management"],
                     response_model=BaseResponse,
                     summary="上传文件到知识库,并删除另一个文件"
                     )(update_doc)

    return fastapi_app


def api_start(host, port, **kwargs):
    """Create the API application and serve it with uvicorn (blocks until shutdown).

    Args:
        host: Address for uvicorn to bind to.
        port: Port for uvicorn to listen on.
        **kwargs: Optional ``ssl_keyfile`` and ``ssl_certfile`` paths. TLS is
            enabled whenever ``ssl_certfile`` is set. ``ssl_keyfile`` may be
            omitted when the certificate file also contains the private key.

    Raises:
        ValueError: if ``ssl_keyfile`` is set without ``ssl_certfile``.
    """
    # Publish the instance as the module-level ``app``.
    global app
    app = _create_app()

    # Empty strings count as "not provided", consistent with a truthiness check.
    ssl_keyfile = kwargs.get("ssl_keyfile") or None
    ssl_certfile = kwargs.get("ssl_certfile") or None
    if ssl_keyfile and not ssl_certfile:
        # Refuse rather than silently serving plain HTTP when the caller
        # supplied incomplete TLS material.
        raise ValueError("ssl_keyfile was provided without ssl_certfile")

    # uvicorn treats None for both SSL options as "no TLS", so one call covers
    # both the plain-HTTP and the HTTPS case.
    uvicorn.run(app, host=host, port=port,
                ssl_keyfile=ssl_keyfile, ssl_certfile=ssl_certfile)
if __name__ == "__main__":
    # Command-line entry point: parse server options and start the API server.
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="address to bind the API server to")
    parser.add_argument("--port", type=int, default=7861,
                        help="port for the API server to listen on")
    parser.add_argument("--ssl_keyfile", type=str,
                        help="path to the TLS private key (PEM)")
    parser.add_argument("--ssl_certfile", type=str,
                        help="path to the TLS certificate (PEM); enables HTTPS")

    args = parser.parse_args()
    api_start(args.host, args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile)