增加允许跨域调用API功能 (#279)
This commit is contained in:
parent
e1c56edb6f
commit
2987c9cd52
17
api.py
17
api.py
|
|
@ -11,12 +11,14 @@ import pydantic
|
||||||
import uvicorn
|
import uvicorn
|
||||||
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
from starlette.responses import RedirectResponse
|
from starlette.responses import RedirectResponse
|
||||||
from chains.local_doc_qa import LocalDocQA
|
from chains.local_doc_qa import LocalDocQA
|
||||||
from configs.model_config import (VS_ROOT_PATH, EMBEDDING_DEVICE, EMBEDDING_MODEL, LLM_MODEL, UPLOAD_ROOT_PATH,
|
from configs.model_config import (API_UPLOAD_ROOT_PATH, EMBEDDING_DEVICE,
|
||||||
NLTK_DATA_PATH, VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN)
|
EMBEDDING_MODEL, LLM_MODEL, NLTK_DATA_PATH,
|
||||||
|
VECTOR_SEARCH_TOP_K, LLM_HISTORY_LEN, OPEN_CROSS_DOMAIN)
|
||||||
|
|
||||||
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
nltk.data.path = [NLTK_DATA_PATH] + nltk.data.path
|
||||||
|
|
||||||
|
|
@ -310,6 +312,17 @@ def main():
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
# Add CORS middleware to allow all origins
|
||||||
|
# 在config.py中设置OPEN_DOMAIN=True,允许跨域
|
||||||
|
# set OPEN_DOMAIN=True in config.py to allow cross-domain
|
||||||
|
if OPEN_CROSS_DOMAIN:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
|
app.websocket("/chat-docs/stream-chat/{knowledge_base_id}")(stream_chat)
|
||||||
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
app.post("/chat-docs/chat", response_model=ChatMessage)(chat)
|
||||||
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
|
app.post("/chat-docs/chatno", response_model=ChatMessage)(no_knowledge_chat)
|
||||||
|
|
|
||||||
|
|
@ -83,3 +83,7 @@ embedding device: {EMBEDDING_DEVICE}
|
||||||
dir: {os.path.dirname(os.path.dirname(__file__))}
|
dir: {os.path.dirname(os.path.dirname(__file__))}
|
||||||
flagging username: {FLAG_USER_NAME}
|
flagging username: {FLAG_USER_NAME}
|
||||||
""")
|
""")
|
||||||
|
|
||||||
|
# 是否开启跨域,默认为False,如果需要开启,请设置为True
|
||||||
|
# is open cross domain
|
||||||
|
OPEN_CROSS_DOMAIN = False
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue