Intention/embedding.py

105 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python
# -*- coding: utf-8 -*-
import datetime
import logging
import os
import secrets
from typing import Dict, List, Optional

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, Security
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from pydantic import BaseModel, Field, validator

from models import *
def response(code, msg, data=None):
    """Build the uniform JSON envelope returned by every endpoint.

    Args:
        code: Status code stored under ``"code"``.
        msg: Human-readable message stored under ``"message"``.
        data: Payload stored under ``"data"``; ``None`` is replaced by ``[]``.

    Returns:
        A dict with keys ``code``, ``message``, ``data`` and ``time`` (in that
        order), where ``time`` is ``str(datetime.datetime.now())``, a naive
        local timestamp.
    """
    stamp = str(datetime.datetime.now())
    payload = [] if data is None else data
    return dict(code=code, message=msg, data=payload, time=stamp)
def success(data=None, msg=''):
    """Build a success (code 200) response envelope.

    Delegates to ``response`` so the result has the same shape the endpoints
    return. The original body was an empty stub that ignored both arguments
    and returned ``None``.

    Args:
        data: Optional payload; ``response`` turns ``None`` into ``[]``.
        msg: Text for the ``"message"`` field.

    Returns:
        The dict produced by ``response(200, msg, data)``.
    """
    return response(200, msg, data)
class QADocs(BaseModel):
    """Request body pairing a query string with a list of documents.

    NOTE(review): not used by any endpoint visible in this file; confirm it is
    referenced elsewhere before relying on or removing it.
    """
    # Query text. Whether an omitted value defaults to None or is rejected
    # depends on the installed pydantic major version (v1 vs v2) — TODO confirm.
    query: Optional[str]
    # Documents associated with the query.
    documents: Optional[List[str]]
class Knows(BaseModel):
    """Request body for ``POST /v1/load_know``.

    Each field is forwarded unchanged to ``know_pass.load_know``.
    """
    # Identifier of the knowledge base to build; also echoed in the success message.
    know_id: Optional[str]
    # Text entries to load into the knowledge base.
    contents: Optional[List[str]]
    # Passed to load_know; presumably "drop duplicate contents" — confirm in models.
    drop_dup: Optional[bool]
    # Passed to load_know; presumably "overwrite an existing know_id" — confirm in models.
    is_cover: Optional[bool]
class Know_Sim(BaseModel):
    """Request body for ``POST /v1/know_sim``.

    Each field is forwarded to ``know_pass.get_similarity_know``.
    """
    # Query text to score against the knowledge base.
    query: Optional[str]
    # Knowledge-base identifier, presumably one built earlier via /v1/load_know.
    know_id: Optional[str]
    # Number of results requested; defaults to 10 when omitted.
    top_k: Optional[int] = 10
app = FastAPI()
# Extracts the token from an "Authorization: Bearer <token>" header.
security = HTTPBearer()

# Expected bearer token. It is read from the ACCESS_TOKEN environment variable
# at import time, so it also applies when the app is served via
# ``uvicorn <module>:app`` (the ``__main__`` block does not run in that case).
# The literal 'ACCESS_TOKEN' is kept as the fallback for backward compatibility.
env_bearer_token = os.getenv("ACCESS_TOKEN")
if env_bearer_token is None:
    # The fallback value is visible in source, so anyone can authenticate with it;
    # warn loudly so it is not deployed by accident.
    logging.getLogger(__name__).warning(
        "ACCESS_TOKEN is not set; falling back to the insecure default bearer token"
    )
    env_bearer_token = 'ACCESS_TOKEN'
@app.post('/v1/embedding')
async def handle_post_request1(sentences1: List[str], sentences2: List[str], credentials: HTTPAuthorizationCredentials = Security(security)):
    """Return the pairwise similarity between two lists of sentences.

    Args:
        sentences1: First list of sentences (JSON body field ``sentences1``).
        sentences2: Second list of sentences (JSON body field ``sentences2``).
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        A ``response(200, ...)`` envelope whose ``data`` is the result of
        ``know_pass.get_similarity_pair(sentences1, sentences2)``.

    Raises:
        HTTPException: 401 when a token is configured and the supplied one differs.
    """
    token = credentials.credentials
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare bytes because compare_digest rejects non-ASCII str.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    # know_pass is the module-level object created in the __main__ block;
    # reading a global needs no `global` statement.
    similarity = know_pass.get_similarity_pair(sentences1, sentences2)
    return response(200, msg="获取相似性成功", data=similarity)
@app.post('/v1/load_know')
async def handle_post_request2(knows: Knows, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Build (or rebuild) a knowledge base from the supplied contents.

    Args:
        knows: Request body; its fields are passed to ``know_pass.load_know``.
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        A ``response(200, ...)`` envelope with message ``生成{know_id}知识库成功``
        and no data payload.

    Raises:
        HTTPException: 401 when a token is configured and the supplied one differs.
    """
    # Authenticate before touching the request payload.
    token = credentials.credentials
    # Constant-time comparison, done on bytes so non-ASCII tokens cannot raise TypeError.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    know_id = knows.know_id
    know_pass.load_know(know_id, knows.contents, knows.drop_dup, knows.is_cover)
    return response(200, msg=f"生成{know_id}知识库成功")
@app.post('/v1/know_sim')
async def handle_post_request3(know_sim:Know_Sim, credentials: HTTPAuthorizationCredentials = Security(security)):
    """Score a query against a knowledge base and return the top matches.

    Args:
        know_sim: Request body carrying ``query``, ``know_id`` and ``top_k``.
        credentials: Bearer credentials extracted by ``HTTPBearer``.

    Returns:
        A ``response(200, ...)`` envelope whose ``data`` is the result of
        ``know_pass.get_similarity_know(query, know_id, top_k)``.

    Raises:
        HTTPException: 401 when a token is configured and the supplied one differs.
    """
    # Authenticate before touching the request payload.
    token = credentials.credentials
    # Constant-time comparison, done on bytes so non-ASCII tokens cannot raise TypeError.
    if env_bearer_token is not None and not secrets.compare_digest(
        token.encode("utf-8"), env_bearer_token.encode("utf-8")
    ):
        raise HTTPException(status_code=401, detail="Invalid token")
    similarity = know_pass.get_similarity_know(
        know_sim.query, know_sim.know_id, know_sim.top_k
    )
    return response(200, msg="获取相似性成功", data=similarity)
def init_env(model_path=None):
    """Load the embedding model and wrap it in a ``KNOW_PASS`` instance.

    Args:
        model_path: Directory of the M3E model. When omitted, the
            ``EMBEDDING_MODEL_PATH`` environment variable is used, falling
            back to the path that was previously hard-coded here.

    Returns:
        The ``KNOW_PASS`` object that the endpoints use as ``know_pass``.
    """
    # Initialise the model
    print("初始模型加载")
    if model_path is None:
        model_path = os.getenv(
            "EMBEDDING_MODEL_PATH", "/mnt/d/weiweiwang/intention/models/m3e-base"
        )
    # An alternative backend, BCE_EMB(model_path), was previously left commented out here.
    embedding = M3E_EMB(model_path)
    return KNOW_PASS(embedding)
if __name__ == "__main__":
    # An ACCESS_TOKEN environment variable overrides the default bearer token.
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token
    try:
        # Default environment: load the model, then serve on all interfaces, port 6007.
        know_pass = init_env()
        uvicorn.run(app, host='0.0.0.0', port=6007)
    except Exception as e:
        print(f"API启动失败\n报错:\n{e}")
        # Re-raise so the traceback is shown and the process exits non-zero,
        # instead of reporting a clean exit after a failed start.
        raise