增加允许跨域调用API功能 (#279)

This commit is contained in:
akou 2023-05-11 09:32:58 +08:00 committed by GitHub
parent e1c56edb6f
commit 2987c9cd52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 2 deletions

17
api.py
View File

@ -11,12 +11,14 @@ import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.openapi.utils import get_openapi
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing_extensions import Annotated
from starlette.responses import RedirectResponse
from chains.local_doc_qa import LocalDocQA
from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
@ -310,6 +312,17 @@ def main():
args = parser.parse_args()
app = FastAPI()
# Add CORS middleware to allow all origins
# 在config.py中设置OPEN_DOMAIN=True允许跨域
# set OPEN_DOMAIN=True in config.py to allow cross-domain
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)

View File

@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE}
dir: {os.path.dirname(os.path.dirname(__file__))}
flagging username: {FLAG_USER_NAME}
""")
# 是否开启跨域默认为False如果需要开启请设置为True
# is open cross domain
OPEN_CROSS_DOMAIN = False