2025-02-07 17:29:20 +08:00
|
|
|
|
#!/usr/bin/env python
|
|
|
|
|
|
# -*- coding: utf-8 -*-
|
|
|
|
|
|
import os
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
|
import logging
|
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
import datetime
|
|
|
|
|
|
from fastapi import FastAPI, Security, HTTPException
|
|
|
|
|
|
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
|
|
|
|
|
|
from pydantic import Field, BaseModel, validator
|
|
|
|
|
|
from typing import Optional, List,Dict
|
|
|
|
|
|
|
|
|
|
|
|
from models import *
|
|
|
|
|
|
|
|
|
|
|
|
def response(code, msg, data=None):
    """Build the standard JSON envelope returned by every endpoint.

    Args:
        code: Status code stored under the ``"code"`` key.
        msg: Human-readable message stored under ``"message"``.
        data: Payload stored under ``"data"``. Only ``None`` is replaced
            (by a fresh empty list); other falsy values are kept as-is.

    Returns:
        A dict with keys ``code``, ``message``, ``data`` and ``time`` (in that
        order), where ``time`` is ``str()`` of the current local datetime.
    """
    payload = [] if data is None else data
    return {
        "code": code,
        "message": msg,
        "data": payload,
        "time": str(datetime.datetime.now()),
    }
|
|
|
|
|
|
|
|
|
|
|
|
def success(data=None, msg=''):
    """Build a success envelope (code 200) with an optional payload.

    Previously this was a stub that ignored both arguments and returned
    ``None``; it now delegates to ``response`` with the same 200 code that
    every endpoint in this module uses.

    Args:
        data: Payload for the ``"data"`` key; ``None`` becomes an empty list.
        msg: Message for the ``"message"`` key.

    Returns:
        The dict produced by ``response(200, msg, data)``.
    """
    return response(200, msg, data)
|
|
|
|
|
|
|
|
|
|
|
|
class QADocs(BaseModel):
    """Pydantic model pairing a query string with a list of documents.

    NOTE(review): not referenced by any endpoint in this file; it may be
    used elsewhere or be leftover code — confirm before relying on it.
    """

    # Query text.
    query: Optional[str]
    # Candidate document strings; presumably matched against `query` — confirm.
    documents: Optional[List[str]]
|
|
|
|
|
|
|
|
|
|
|
|
class Knows(BaseModel):
    """Request body for ``POST /v1/load_know``.

    The fields are forwarded unchanged, in this order, to
    ``know_pass.load_know(know_id, contents, drop_dup, is_cover)``.

    NOTE(review): the fields are ``Optional`` with no explicit default. Under
    pydantic v1 that makes them default to ``None``; under pydantic v2 they
    are required (though ``null`` is accepted). Confirm which pydantic version
    is deployed, since it decides whether an empty body is accepted.
    """

    # Knowledge-base identifier; also interpolated into the success message.
    know_id: Optional[str]
    # Strings forwarded as `contents`; presumably the entries to index — confirm.
    contents: Optional[List[str]]
    # Presumably "drop duplicate contents" — semantics live in KNOW_PASS.load_know.
    drop_dup: Optional[bool]
    # Presumably "overwrite an existing knowledge base" — confirm in KNOW_PASS.load_know.
    is_cover: Optional[bool]
|
|
|
|
|
|
|
|
|
|
|
|
class Know_Sim(BaseModel):
    """Request body for ``POST /v1/know_sim``.

    The fields are forwarded as
    ``know_pass.get_similarity_know(query, know_id, top_k)``.
    """

    # Query text to compare against the knowledge base.
    query: Optional[str]
    # Knowledge-base identifier; presumably one created via /v1/load_know.
    know_id: Optional[str]
    # Defaults to 10 when omitted; presumably the number of matches returned.
    top_k: Optional[int] = 10
|
|
|
|
|
|
|
|
|
|
|
|
# FastAPI application that hosts the embedding / knowledge-base endpoints.
app = FastAPI()

# Bearer-token extractor injected into every handler via Security(security).
security = HTTPBearer()

# Token every request must present. It is read from the ACCESS_TOKEN
# environment variable at import time so the check also applies when the app
# is served with `uvicorn <module>:app`, not only via `python <file>`.
# The fallback keeps the historical literal default.
# NOTE(review): that fallback is a guessable credential — always set
# ACCESS_TOKEN in real deployments.
env_bearer_token = os.getenv("ACCESS_TOKEN", "ACCESS_TOKEN")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.post('/v1/embedding')
async def handle_post_request1(
    sentences1: List[str],
    sentences2: List[str],
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    """Return the similarity between two lists of sentences.

    Args:
        sentences1: First list of sentences (request-body field).
        sentences2: Second list of sentences (request-body field).
        credentials: Bearer credentials extracted by ``security``.

    Returns:
        ``response(200, ...)`` whose ``data`` is the result of
        ``know_pass.get_similarity_pair(sentences1, sentences2)``.

    Raises:
        HTTPException: 401 when the bearer token does not match.
    """
    import secrets  # local import keeps this block self-contained

    token = credentials.credentials
    # Constant-time comparison so response timing cannot leak how much of the
    # token matched; comparing bytes avoids TypeError on non-ASCII input.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    # `know_pass` is the module-level object created by init_env() at startup.
    similarity = know_pass.get_similarity_pair(sentences1, sentences2)
    return response(200, msg="获取相似性成功", data=similarity)
|
|
|
|
|
|
|
|
|
|
|
|
@app.post('/v1/load_know')
async def handle_post_request2(
    knows: Knows,
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    """Load contents into the knowledge base identified by ``knows.know_id``.

    Args:
        knows: Request body; its fields are forwarded to
            ``know_pass.load_know(know_id, contents, drop_dup, is_cover)``.
        credentials: Bearer credentials extracted by ``security``.

    Returns:
        ``response(200, ...)`` with a message naming ``know_id``.

    Raises:
        HTTPException: 401 when the bearer token does not match.
    """
    import secrets  # local import keeps this block self-contained

    # Authenticate before touching the request payload.
    token = credentials.credentials
    # Constant-time comparison so response timing cannot leak how much of the
    # token matched; comparing bytes avoids TypeError on non-ASCII input.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    know_id = knows.know_id
    # `know_pass` is the module-level object created by init_env() at startup.
    know_pass.load_know(know_id, knows.contents, knows.drop_dup, knows.is_cover)
    return response(200, msg=f"生成{know_id}知识库成功")
|
|
|
|
|
|
|
|
|
|
|
|
@app.post('/v1/know_sim')
async def handle_post_request3(
    know_sim: Know_Sim,
    credentials: HTTPAuthorizationCredentials = Security(security),
):
    """Look up entries in a knowledge base that are similar to a query.

    Args:
        know_sim: Request body; forwarded as
            ``know_pass.get_similarity_know(query, know_id, top_k)``.
        credentials: Bearer credentials extracted by ``security``.

    Returns:
        ``response(200, ...)`` whose ``data`` is the lookup result.

    Raises:
        HTTPException: 401 when the bearer token does not match.
    """
    import secrets  # local import keeps this block self-contained

    # Authenticate before touching the request payload.
    token = credentials.credentials
    # Constant-time comparison so response timing cannot leak how much of the
    # token matched; comparing bytes avoids TypeError on non-ASCII input.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")

    # `know_pass` is the module-level object created by init_env() at startup.
    similarity = know_pass.get_similarity_know(
        know_sim.query, know_sim.know_id, know_sim.top_k
    )
    return response(200, msg="获取相似性成功", data=similarity)
|
|
|
|
|
|
|
|
|
|
|
|
def init_env(model_path=None):
    """Load the embedding model and wrap it in a ``KNOW_PASS`` instance.

    Args:
        model_path: Path of the M3E embedding model. When omitted, the
            ``EMBEDDING_MODEL_PATH`` environment variable is used, falling
            back to the original hard-coded developer path.

    Returns:
        The ``KNOW_PASS`` object built around the loaded ``M3E_EMB`` model.
    """
    # Initialise the model.
    print("初始模型加载")
    if model_path is None:
        model_path = os.getenv(
            "EMBEDDING_MODEL_PATH",
            "/mnt/d/weiweiwang/intention/models/m3e-base",
        )
    # NOTE: a BCE_EMB backend (presumably also provided by `models`) was
    # previously used here instead of M3E_EMB.
    embedding = M3E_EMB(model_path)
    return KNOW_PASS(embedding)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # An ACCESS_TOKEN exported in the environment overrides the default token.
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token

    try:
        # Default environment: build the shared model before serving. The
        # handlers read the module-level name `know_pass` assigned here.
        know_pass = init_env()
        uvicorn.run(app, host='0.0.0.0', port=6007)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
        # Keep the full traceback (previously only str(e) was shown) and exit
        # non-zero so process supervisors can see that startup failed.
        logging.exception("API startup failed")
        raise SystemExit(1) from e
|
|
|
|
|
|
|