#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""HTTP API exposing sentence-similarity and knowledge-base operations.

Every endpoint requires an ``Authorization: Bearer <token>`` header. The
expected token is read from the ``ACCESS_TOKEN`` environment variable and
falls back to the literal string ``"ACCESS_TOKEN"`` when it is unset.

The actual work is delegated to a ``KNOW_PASS`` object built from an
``M3E_EMB`` embedding; both names come from the star import of the
project-local ``models`` module.
"""
import datetime
import logging
import os
import secrets
from typing import Optional, List, Dict

import numpy as np
import uvicorn
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import Field, BaseModel, validator

from models import *

# Model directory used when neither an explicit argument nor the
# EMBEDDING_MODEL_PATH environment variable provides one.
DEFAULT_EMBEDDING_MODEL_PATH = "/mnt/d/weiweiwang/intention/models/m3e-base"


def response(code, msg, data=None):
    """Build the JSON envelope returned by every endpoint.

    Args:
        code: Application-level status code placed under ``"code"``.
        msg: Human-readable message placed under ``"message"``.
        data: Payload placed under ``"data"``; ``None`` becomes ``[]``.

    Returns:
        A dict with the keys ``code``, ``message``, ``data`` and ``time``
        (the current local time rendered with ``str``).
    """
    timestamp = str(datetime.datetime.now())
    if data is None:
        data = []
    return {
        "code": code,
        "message": msg,
        "data": data,
        "time": timestamp,
    }


def success(data=None, msg=''):
    """Return a code-200 envelope carrying ``data`` and ``msg``.

    The original body was a bare ``return`` that discarded both arguments;
    it now delegates to :func:`response`.
    """
    return response(200, msg, data)


class QADocs(BaseModel):
    # Request schema with a query string and a list of documents.
    # No endpoint in this file uses it.
    query: Optional[str]
    documents: Optional[List[str]]


class Knows(BaseModel):
    # Request body of POST /v1/load_know; the fields are forwarded
    # positionally to ``know_pass.load_know``.
    know_id: Optional[str]
    contents: Optional[List[str]]
    drop_dup: Optional[bool]
    is_cover: Optional[bool]


class Know_Sim(BaseModel):
    # Request body of POST /v1/know_sim; the fields are forwarded
    # positionally to ``know_pass.get_similarity_know``.
    query: Optional[str]
    know_id: Optional[str]
    top_k: Optional[int] = 10


app = FastAPI()
security = HTTPBearer()

# Read at import time so the configured token also applies when the app is
# served by an external ASGI runner (e.g. ``uvicorn module:app``) and not
# only when this file is executed as a script.
# NOTE(review): the fallback is a well-known literal; set ACCESS_TOKEN in any
# real deployment.
env_bearer_token = os.getenv("ACCESS_TOKEN", "ACCESS_TOKEN")

# Shared KNOW_PASS instance. Assigned eagerly by the ``__main__`` block, or
# lazily by ``_get_know_pass`` on the first authenticated request.
know_pass = None


def init_env(model_path: Optional[str] = None):
    """Load the embedding model and wrap it in a ``KNOW_PASS`` instance.

    Args:
        model_path: Directory of the M3E model. When omitted, the
            ``EMBEDDING_MODEL_PATH`` environment variable is used, then
            ``DEFAULT_EMBEDDING_MODEL_PATH``.

    Returns:
        The ``KNOW_PASS`` object the endpoints delegate to.
    """
    # Initialise the model.
    print("初始模型加载")
    if model_path is None:
        model_path = os.getenv(
            "EMBEDDING_MODEL_PATH", DEFAULT_EMBEDDING_MODEL_PATH
        )
    # The original source kept a commented-out alternative here that built
    # the embedding with BCE_EMB instead of M3E_EMB.
    embedding = M3E_EMB(model_path)
    return KNOW_PASS(embedding)


def _get_know_pass():
    """Return the shared ``KNOW_PASS`` instance, creating it on first use."""
    global know_pass
    if know_pass is None:
        # Without this, importing the module from an ASGI runner would leave
        # ``know_pass`` unset and every request would fail.
        know_pass = init_env()
    return know_pass


def _verify_token(credentials: HTTPAuthorizationCredentials) -> None:
    """Raise ``HTTPException(401)`` unless the bearer token matches.

    The check is skipped when ``env_bearer_token`` is ``None``.
    """
    expected = env_bearer_token
    if expected is None:
        return
    # Compare as bytes: compare_digest rejects non-ASCII ``str`` arguments,
    # and a constant-time comparison avoids leaking the token via timing.
    supplied = credentials.credentials.encode("utf-8")
    if not secrets.compare_digest(supplied, expected.encode("utf-8")):
        raise HTTPException(status_code=401, detail="Invalid token")


# NOTE(review): the handlers are ``async def`` but call the model
# synchronously; if those calls are slow they will stall the event loop.
@app.post('/v1/embedding')
async def handle_post_request1(
        sentences1: List[str],
        sentences2: List[str],
        credentials: HTTPAuthorizationCredentials = Security(security)):
    """Return ``know_pass.get_similarity_pair(sentences1, sentences2)``.

    Raises:
        HTTPException: 401 when the bearer token is wrong.
    """
    _verify_token(credentials)
    similarity = _get_know_pass().get_similarity_pair(sentences1, sentences2)
    return response(200, msg="获取相似性成功", data=similarity)


@app.post('/v1/load_know')
async def handle_post_request2(
        knows: Knows,
        credentials: HTTPAuthorizationCredentials = Security(security)):
    """Forward the request fields to ``know_pass.load_know``.

    Raises:
        HTTPException: 401 when the bearer token is wrong.
    """
    _verify_token(credentials)
    know_id = knows.know_id
    _get_know_pass().load_know(
        know_id, knows.contents, knows.drop_dup, knows.is_cover
    )
    return response(200, msg=f"生成{know_id}知识库成功")


@app.post('/v1/know_sim')
async def handle_post_request3(
        know_sim: Know_Sim,
        credentials: HTTPAuthorizationCredentials = Security(security)):
    """Return ``know_pass.get_similarity_know(query, know_id, top_k)``.

    Raises:
        HTTPException: 401 when the bearer token is wrong.
    """
    _verify_token(credentials)
    similarity = _get_know_pass().get_similarity_know(
        know_sim.query, know_sim.know_id, know_sim.top_k
    )
    return response(200, msg="获取相似性成功", data=similarity)


if __name__ == "__main__":
    try:
        # Default environment: load the model eagerly so a bad model path
        # fails at startup rather than on the first request.
        know_pass = init_env()
        uvicorn.run(app, host='0.0.0.0', port=6007)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
        # Keep the traceback and signal the failure to the caller through a
        # non-zero exit status.
        logging.exception("API startup failed")
        raise SystemExit(1)