增加允许跨域调用API功能 (#279)
This commit is contained in:
parent
e1c56edb6f
commit
2987c9cd52
17
api.py
17
api.py
|
|
@ -11,12 +11,14 @@ import pydantic
|
|||
import uvicorn
|
||||
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Annotated
|
||||
from starlette.responses import RedirectResponse
|
||||
from chains.local_doc_qa import LocalDocQA
|
||||
from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
|
||||
NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
|
||||
from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
||||
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
||||
|
||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||
|
||||
|
|
@ -310,6 +312,17 @@ def main():
|
|||
args = parser.parse_args()
|
||||
|
||||
app = FastAPI()
|
||||
# Add CORS middleware to allow all origins
|
||||
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
||||
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
||||
if OPEN_CROSS_DOMAIN:
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
|
||||
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
||||
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
|
||||
|
|
|
|||
|
|
@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE}
|
|||
dir: {os.path.dirname(os.path.dirname(__file__))}
|
||||
flagging username: {FLAG_USER_NAME}
|
||||
""")
|
||||
|
||||
# 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||
# is open cross domain
|
||||
OPEN_CROSS_DOMAIN = False
|
||||
|
|
|
|||
Loading…
Reference in New Issue