import argparse import os from typing import Literal import uvicorn from fastapi import Body, FastAPI from fastapi.middleware.cors import CORSMiddleware from fastapi.staticfiles import StaticFiles from starlette.responses import RedirectResponse from chatchat import __version__ from chatchat.settings import Settings from chatchat.server.api_server.chat_routes import chat_router from chatchat.server.api_server.kb_routes import kb_router from chatchat.server.api_server.openai_routes import openai_router from chatchat.server.api_server.server_routes import server_router from chatchat.server.api_server.tool_routes import tool_router from chatchat.server.chat.completion import completion from chatchat.server.utils import MakeFastAPIOffline def create_app(run_mode: str = None): app = FastAPI(title="Langchain-Chatchat API Server", version=__version__) MakeFastAPIOffline(app) # Add CORS middleware to allow all origins # 在config.py中设置OPEN_DOMAIN=True,允许跨域 # set OPEN_DOMAIN=True in config.py to allow cross-domain if Settings.basic_settings.OPEN_CROSS_DOMAIN: app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @app.get("/", summary="swagger 文档", include_in_schema=False) async def document(): return RedirectResponse(url="/docs") app.include_router(chat_router) app.include_router(kb_router) app.include_router(tool_router) app.include_router(openai_router) app.include_router(server_router) # 其它接口 app.post( "/other/completion", tags=["Other"], summary="要求llm模型补全(通过LLMChain)", )(completion) # 媒体文件 app.mount("/media", StaticFiles(directory=Settings.basic_settings.MEDIA_PATH), name="media") # 项目相关图片 img_dir = str(Settings.basic_settings.IMG_DIR) app.mount("/img", StaticFiles(directory=img_dir), name="img") return app def run_api(host, port, **kwargs): if kwargs.get("ssl_keyfile") and kwargs.get("ssl_certfile"): uvicorn.run( app, host=host, port=port, ssl_keyfile=kwargs.get("ssl_keyfile"), ssl_certfile=kwargs.get("ssl_certfile"), ) else: uvicorn.run(app, host=host, port=port) app = create_app() if __name__ == "__main__": parser = argparse.ArgumentParser( prog="langchain-ChatGLM", description="About langchain-ChatGLM, local knowledge based ChatGLM with langchain" " | 基于本地知识库的 ChatGLM 问答", ) parser.add_argument("--host", type=str, default="0.0.0.0") parser.add_argument("--port", type=int, default=7861) parser.add_argument("--ssl_keyfile", type=str) parser.add_argument("--ssl_certfile", type=str) # 初始化消息 args = parser.parse_args() args_dict = vars(args) run_api( host=args.host, port=args.port, ssl_keyfile=args.ssl_keyfile, ssl_certfile=args.ssl_certfile, )