From 2987c9cd5299f79829c97dfc94835f44d87d07b0 Mon Sep 17 00:00:00 2001 From: akou Date: Thu, 11 May 2023 09:32:58 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=85=81=E8=AE=B8=E8=B7=A8?= =?UTF-8?q?=E5=9F=9F=E8=B0=83=E7=94=A8API=E5=8A=9F=E8=83=BD=20(#279)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- api.py | 17 +++++++++++++++-- configs/model_config.py | 4 ++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/api.py b/api.py index 051e26c..eb1848e 100644 --- a/api.py +++ b/api.py @@ -11,12 +11,14 @@ import pydantic import uvicorn from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket from fastapi.openapi.utils import get_openapi +from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel from typing_extensions import Annotated from starlette.responses import RedirectResponse from chains.local_doc_qa import LocalDocQA -from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH, - NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN) +from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE, + EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH, + VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN) nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path @@ -310,6 +312,17 @@ def main(): args = parser.parse_args() app = FastAPI() + # Add CORS middleware to allow all origins + # 在config.py中设置OPEN_DOMAIN=True,允许跨域 + # set OPEN_DOMAIN=True in config.py to allow cross-domain + if OPEN_CROSS_DOMAIN: + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat) app.post("/chat-docs/chat", response_model=ChatMessage)(chat) app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat) diff --git a/configs/model_config.py b/configs/model_config.py index e054444..e9bef7d 100644 --- a/configs/model_config.py +++ b/configs/model_config.py @@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE} dir: {os.path.dirname(os.path.dirname(__file__))} flagging username: {FLAG_USER_NAME} """) + +# 是否开启跨域,默认为False,如果需要开启,请设置为True +# is open cross domain +OPEN_CROSS_DOMAIN = False