Intention/intent_uie_api.py

587 lines
21 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# phddns start
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Tuple,Pad
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
import argparse
import pandas as pd
import time
import datetime
import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from typing import Dict, Any
import sys
from post_api import Similar
sim = Similar()


def load_know(file_path='./knows/project_know.txt'):
    """Load the project knowledge base and register its short names with ``sim``.

    Each non-blank line of *file_path* is one project record. The text before the
    first ``(`` is used as the project's short name; a line without ``(`` is used
    whole as its own short name. The short names are pushed to the similarity
    service under the ``'work_name'`` collection (``drop_dup``/``is_cover`` are
    passed through as-is — NOTE(review): their exact semantics are defined by
    post_api.Similar).

    Args:
        file_path: UTF-8 text file with one project record per line.

    Returns:
        dict mapping short name -> full record line (later duplicates win).
    """
    works = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f.read().split('\n'):
            if not line.strip():
                # Skip blank lines so an empty short name is never registered.
                continue
            # split() keeps the whole line when there is no '('; a find()-based
            # slice would return -1 there and chop off the last character.
            works[line.split('(', 1)[0]] = line
    sim.load_know('work_name', list(works.keys()), drop_dup=True, is_cover=True)
    return works


works = load_know()
def standard_project(work_name, threshold=330):
    """Map an extracted project name to its full knowledge-base record.

    Queries ``sim.know_sim`` against the ``'work_name'`` collection and looks at
    the first hit (NOTE(review): assumes the service orders results best-first).

    Args:
        work_name: project name as extracted from the user query.
        threshold: the top hit's ``relevance_score`` must be strictly greater
            than this to be accepted. NOTE(review): the score scale is defined
            by the similarity service.

    Returns:
        ``works[text]`` for the top hit when its score exceeds *threshold* and
        its text is a known short name; otherwise *work_name* unchanged. The
        latter also applies when the service returns no results.
    """
    import ast

    reply = sim.know_sim(work_name, 'work_name')
    # The reply is parsed as a Python literal. literal_eval accepts the same
    # literal syntax eval() did, but cannot execute code from the response.
    if isinstance(reply, str):
        reply = ast.literal_eval(reply)
    hits = reply['data']['results']
    if not hits:
        return work_name
    top = hits[0]
    text = top['document']['text']
    score = top['document']['relevance_score']
    print(f"standard_project: {work_name!r} -> {text!r} (score={score})")
    if score > threshold and text in works:
        return works[text]
    return work_name
# When True, api_start() installs a permissive CORS middleware.
OPEN_CROSS_DOMAIN = True


def time_consume(func):
    """Decorate *func* to print timestamps before/after each call and the elapsed time.

    The wrapper forwards all positional and keyword arguments and returns
    *func*'s result unchanged. ``functools.wraps`` keeps the wrapped function's
    ``__name__``/``__doc__`` so introspection and logs still see the original.
    """
    import functools

    @functools.wraps(func)
    def inner(*args, **kwargs):  # accept any call signature
        time_start = time.time()
        print("\n接口请求前的时间是:", datetime.datetime.now())
        res = func(*args, **kwargs)
        time_end = time.time()
        print("\n接口请求后的时间是:", datetime.datetime.now())
        print("接口耗时:", round(time_end - time_start, 4), "s")
        return res

    return inner
# Command-line options. parse_known_args() is used so that importing this module
# under a different entry point (test runner, ASGI server, ...) whose argv holds
# options meant for that program does not abort; unknown options are reported.
parser = argparse.ArgumentParser()
parser.add_argument("--device", default="gpu:0", help="Select which device to train model, default to gpu.")
parser.add_argument("--output_file", default="output.txt", type=str)
# Tokenizer length used by int_predict (intent classifier).
parser.add_argument("--max_seq_length",
                    default=128,
                    type=int,
                    help="The maximum total input sequence length "
                    "after tokenization. Sequences longer than this "
                    "will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size",
                    default=64,
                    type=int,
                    help="Batch size per GPU/CPU for training.")
# Tokenizer length used by uie_predict (span extractor).
parser.add_argument("--max_seq_len",
                    default=256,
                    type=int,
                    help="The maximum total input sequence length after tokenization.")
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
args, _unknown_args = parser.parse_known_args()
if _unknown_args:
    print(f"Ignoring unrecognized command-line arguments: {_unknown_args}")
# Filesystem locations of the exported models and of the directories their
# tokenizers are loaded from (consumed by get_model in the __main__ block).
model_names = dict(
    intent_token="/mnt/d/weiweiwang/intention/models/zero_int",
    intent_model="/mnt/d/weiweiwang/intention/models/zero_int/model",
    uie_token="/mnt/d/weiweiwang/intention/models/zero_uie",
    uie_model="/mnt/d/weiweiwang/intention/models/zero_uie/model",
)
# Label vocabulary files, keyed by the name passed to get_label().
label_names = dict(int="./int/data/label.txt")
def get_label(label_name):
    """Read the label vocabulary registered under *label_name* in ``label_names``.

    Returns:
        The stripped, non-blank lines of the file in file order. int_predict
        indexes this list with the classifier's argmax, so line order must
        match the model's output columns.

    Raises:
        KeyError: if *label_name* is not a key of ``label_names``.
    """
    # Labels are Chinese text. Decode explicitly instead of relying on the
    # platform default encoding (e.g. GBK on Chinese Windows). "utf-8-sig" also
    # strips a leading BOM, which strip() would otherwise leave on the first label.
    with open(label_names[label_name], encoding='utf-8-sig') as inf:
        return [line.strip() for line in inf if line.strip()]
def get_model(token_path, model_path):
    """Load a ``paddle.jit`` model and its tokenizer.

    The Paddle device is set from ``--device`` on every call, before loading.

    Args:
        token_path: directory passed to ``AutoTokenizer.from_pretrained``.
        model_path: path passed to ``paddle.jit.load``.

    Returns:
        ``(model, tokenizer)``.
    """
    paddle.set_device(args.device)
    loaded = paddle.jit.load(model_path)
    return loaded, AutoTokenizer.from_pretrained(token_path)
@paddle.no_grad()
def int_predict(model, tokenizer, data, label_list):
    """Classify each query in *data* and return its intent id with a confidence.

    Args:
        model: ``paddle.jit``-loaded classifier called as
            ``model(input_ids, token_type_ids)`` and returning logits.
        tokenizer: matching PaddleNLP tokenizer.
        data: list of query strings.
        label_list: label names; index ``i`` names logit column ``i``.

    Returns:
        One ``(intent_id, probability)`` tuple per query. ``intent_id`` is the
        integer code from the table at the end of this function. ``probability``
        is the softmax probability of the winning label, rounded to 5 decimals
        and returned as a string.

    Raises:
        KeyError: if a predicted label name has no entry in that table.
    """
    # Tokenize every query up front; the classifier consumes (input_ids, token_type_ids).
    features = [
        (enc['input_ids'], enc['token_type_ids'])
        for enc in (tokenizer(text=query, max_seq_length=args.max_seq_length) for query in data)
    ]
    step = args.batch_size
    batch_starts = range(0, len(features), step)
    # Right-pads each field of a batch to the batch's longest sequence.
    collate = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    )
    model.eval()
    scored = []
    for begin in batch_starts:
        ids, type_ids = collate(features[begin:begin + step])
        logits = model(paddle.to_tensor(ids), paddle.to_tensor(type_ids))
        probs = F.softmax(logits, axis=1)
        winners = paddle.argmax(probs, axis=1).numpy().tolist()
        for winner, row in zip(winners, probs.tolist()):
            scored.append((label_list[winner], str(round(row[winner], 5))))
    intent_ids = {
        "天气查询": 1, "互联网查询": 2, "页面切换": 3,
        "日计划数量查询": 4, "周计划数量查询": 5, "日计划作业内容": 6, "周计划作业内容": 7,
        "施工人数": 8, "作业考勤人数": 9, "知识问答": 10,
    }
    return [(intent_ids[name], prob) for name, prob in scored]
@paddle.no_grad()
def uie_predict(model, tokenizer, data):
    """Run the UIE span extractor on (prompt, content) pairs.

    Args:
        model: ``paddle.jit``-loaded UIE model called as
            ``model(input_ids, token_type_ids, position_ids, attention_mask)``
            and returning ``(start_prob, end_prob)``.
        tokenizer: matching PaddleNLP tokenizer.
        data: list of dicts with ``'prompt'`` and ``'content'`` keys.

    Returns:
        One tuple per item of *data*:
        ``(span_text, start_prob, end_prob, start_index, end_index)``.
        The indices are token positions in the encoded prompt+content sequence.
        ``span_text`` joins tokens ``start_index..end_index`` inclusive (empty
        when end < start). The probabilities are the per-row maxima.
    """
    examples = []
    for item in data:
        encoded = tokenizer(text=[item['prompt']],
                            text_pair=[item['content']],
                            truncation=True,
                            max_seq_length=args.max_seq_len,
                            pad_to_max_seq_len=True,
                            return_attention_mask=True,
                            return_position_ids=True,
                            return_dict=False,
                            return_offsets_mapping=True)[0]
        examples.append((encoded['input_ids'], encoded['token_type_ids'],
                         encoded['position_ids'], encoded['attention_mask']))
    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Pad(axis=0, pad_val=0),
        Pad(axis=0, pad_val=0),
    )
    results = []
    model.eval()
    for begin in range(0, len(examples), args.batch_size):
        batch = examples[begin:begin + args.batch_size]
        input_ids, token_type_ids, position_ids, attention_mask = batchify_fn(batch)
        start_prob, end_prob = model(paddle.to_tensor(input_ids),
                                     paddle.to_tensor(token_type_ids),
                                     paddle.to_tensor(position_ids),
                                     paddle.to_tensor(attention_mask))
        # Reduce per row so each example gets its own span. A whole-batch
        # max/argmax would mix examples and decode with the wrong example's tokens.
        start_rows = start_prob.numpy().reshape(len(batch), -1)
        end_rows = end_prob.numpy().reshape(len(batch), -1)
        for example, start_row, end_row in zip(batch, start_rows, end_rows):
            st_idx = int(start_row.argmax())
            et_idx = int(end_row.argmax())
            tokens = tokenizer.convert_ids_to_tokens(example[0])
            span = ''.join(tokens[st_idx:et_idx + 1])
            results.append((span, start_row[st_idx].item(), end_row[et_idx].item(), st_idx, et_idx))
    return results
def get_slot(int_res, query):
    """Extract slot values for intent *int_res* from *query* with the UIE model.

    Only intents 1 and 3-9 have slots; any other intent yields ``{}``. Each
    candidate slot name (Chinese prompt) gets one uie_predict call. A span is
    accepted when both its start and end probability are >= 0.8.

    An accepted span is then discarded if its token range overlaps another
    prompt's span whose start AND end probabilities are both at least as high.
    On an exact tie, both spans are discarded.

    For intent 5 a '周时间' (week) span becomes the '时间' slot; for other
    intents '周时间' is dropped. Values are cleaned (remove '#', 'kv' -> 'kV',
    remove '模块'/'页面'). '工程' values are canonicalized via
    standard_project().

    Returns:
        dict keyed by English slot names (date, project, area, ...).
    """
    if int_res not in (1, 3, 4, 5, 6, 7, 8, 9):
        return {}
    prompt_sets = {
        1: ['时间', '地点'],  # weather query
        3: ['应用', '模块', '页面'],  # page switch
    }
    prompts = prompt_sets.get(int_res, ['时间', '项目', '项目经理', '公司', '工程', '风险等级', '周时间', '施工状态'])

    # Pass 1: one UIE call per prompt. Token indices are shifted back by
    # len(prompt) so spans found under different prompts are comparable.
    # NOTE(review): assumes one token per prompt character, as with these
    # Chinese prompts.
    accepted = {}
    candidates = {}
    for prompt in prompts:
        prediction = uie_predict(uie_model, uie_tokenizer, [{"prompt": prompt, "content": query}])
        print(f"prompt:{prompt} uie_res:{prediction}")
        top = list(prediction[0])
        shift = len(prompt)
        top[3] -= shift
        top[4] -= shift
        prediction[0] = top
        if top[1] >= 0.8 and top[2] >= 0.8:
            accepted[prompt] = top[0]
        candidates[prompt] = prediction

    # Pass 2: resolve overlaps in favour of the more confident prompt.
    losers = []
    for name in accepted:
        mine = candidates[name][0]
        for other_name, other in candidates.items():
            if other_name == name:
                continue
            theirs = other[0]
            overlaps = not (mine[4] < theirs[3] or mine[3] > theirs[4])
            if overlaps and mine[1] <= theirs[1] and mine[2] <= theirs[2]:
                print(f"1:{candidates[name]}")
                print(f"2:{other}")
                losers.append(name)
                break
    for name in losers:
        del accepted[name]

    # Pass 3: the week span only counts as the date for intent 5.
    if '周时间' in accepted:
        week_value = accepted.pop('周时间')
        if int_res == 5:
            accepted['时间'] = week_value

    # Pass 4: clean values and translate slot names to English keys.
    cn_en = {"时间": "date", "项目": "project", "项目经理": "people",
             "公司": "company", "工程": "work", "风险等级": "risk_level",
             "施工状态": "work_status", "应用": "app", "模块": "module", "页面": "page", "地点": "area"}
    normalized = {}
    for name, value in accepted.items():
        cleaned = value.replace('#', '').replace('kv', 'kV').replace("模块", "").replace("页面", "")
        if name == '工程':
            print(cleaned)
            cleaned = standard_project(cleaned)
        normalized[cn_en[name]] = cleaned
    return normalized
def check_lost(int_res, slot):
    """Check whether the slots required by an intent are present.

    Each supported intent lists one or more alternative sets of required slot
    keys. The request is satisfied when any one set is fully covered by *slot*.

    Args:
        int_res: intent id (as produced by int_predict).
        slot: mapping of English slot names (e.g. 'date', 'area') to values.

    Returns:
        ``(0, [])`` for intents without requirements.
        ``(0, list(slot))`` when some alternative is fully satisfied.
        Otherwise ``(1, missing)``, where *missing* lists the absent keys of the
        alternative with the fewest gaps (the earliest alternative wins ties).
    """
    required = {
        1: [['date', 'area']],  # weather query: date AND area
        3: [['page'], ['app'], ['module']],  # page switch: any one target
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [[]],  # no required slots
        9: [[]],
    }
    if int_res not in required:
        return 0, []
    present = list(slot.keys())
    gaps = [[key for key in option if key not in present] for option in required[int_res]]
    # min() returns the first of equally short candidates, so earlier alternatives win.
    fewest = min(gaps, key=len)
    if not fewest:
        return 0, present
    return 1, fewest
@time_consume
def agent(sk, slot, messages, *, api_key='EMPTY',
          api_base_url='http://36.33.26.201:27861/v1',
          model_name='qwen2.5-instruct'):
    """Ask an OpenAI-compatible chat model to collect missing slots from the user.

    Args:
        sk: required slot names, interpolated into the system prompt.
        slot: slots already extracted, interpolated into the system prompt.
        messages: prior chat messages (``{"role": ..., "content": ...}`` dicts),
            appended after the system prompt.
        api_key, api_base_url, model_name: endpoint configuration. These names
            were referenced without being defined, so any call raised NameError.
            NOTE(review): the defaults are one of the deployments noted in this
            file; confirm which endpoint is current.

    Returns:
        The assistant reply text.
    """
    from openai import OpenAI
    client = OpenAI(api_key=api_key,
                    base_url=api_base_url)
    prompt = f'''
你是一个槽位询问工具,你需要结合用户输入,根据必填槽位列表,以及当前已有槽位,对用户进行槽位问询
询问用户哪些槽位是缺失的需要补充待必填槽位都已填充完毕将所有必填槽位类型及槽位值以json格式返回
请不要对用户输入的槽位进行修改
###以下是必填槽位列表###
need_slot_lists = {sk}
###以上是必填槽位列表###
###以下是当前已有槽位###
present_slot={slot}
###
'''
    message = [{"role": "system", "content": prompt}]
    message.extend(messages)
    response = client.chat.completions.create(
        messages=message,
        model=model_name,
        max_tokens=1000,
        temperature=0.001,
        stream=False
    )
    return response.choices[0].message.content
@time_consume
def agent1(int_res, messages):
    """Rule-based slot follow-up for a multi-turn dialog.

    Slots are extracted from every user message. Later messages override
    earlier values for the same key.

    Returns:
        The merged slot dict when 7 or more messages have been exchanged, or
        when check_lost reports nothing missing. Otherwise a follow-up question
        string for the intent, prefixed with an apology when the dialog has
        exactly 3 or 5 messages. Returns None for an incomplete intent that has
        no question template (not expected for intents check_lost tracks).
    """
    merged = {}
    for msg in messages:
        if msg["role"] != 'user':
            continue
        merged.update(get_slot(int_res, msg['content']))
    status, missing = check_lost(int_res, merged)

    turn_count = len(messages)
    apology = {3: "非常抱歉,", 5: "特别抱歉,我没有理解您的意思,"}.get(turn_count, '')
    # Give up asking after enough turns and return whatever was collected.
    if turn_count >= 7:
        return merged
    if status != 1:
        return merged

    if int_res == 1:  # weather query
        if missing == ['date', 'area']:
            return f"{apology}请问您想查询的天气的时间和地点分别是什么?"
        if missing == ['area']:
            return f"{apology}请问您想查询哪里的天气情况?"
        return f"{apology}请问您想查询哪一天的天气情况?"
    if int_res == 3:  # page switch
        return f"{apology}请问您想查询哪个页面?"
    subjects = {4: "日计划数量", 5: "周计划数量", 6: "日计划作业内容", 7: "周计划作业内容",
                8: "施工人数", 9: "作业考勤人数"}
    if int_res in subjects:
        return f"{apology}请问您想查询哪天的{subjects[int_res]}"
    return None
def main(messages):
    """Handle one dialog request and report whether slot filling is finished.

    The intent is always classified from the first message. Uses the module
    globals int_model/int_tokenizer/int_label loaded in the ``__main__`` block.

    Returns:
        ``{"finished": 0, "text": question}`` while slots are still missing, or
        ``{"finished": 1, "int": intent_id, "slot": slots}`` once complete. From
        7 messages on, the result is always reported as finished.
    """
    query = messages[0]['content']
    int_res = int_predict(int_model, int_tokenizer, [query], int_label)
    print(f"意图识别结果:{int_res}")
    intent = int_res[0][0]

    if len(messages) != 1:
        # Follow-up turn: re-extract slots from the whole history.
        res = agent1(intent, messages)
        if len(messages) >= 7 or not isinstance(res, str):
            return {"finished": 1, "int": intent, "slot": res}
        return {"finished": 0, "text": res}

    # First turn: extract slots from the opening query only.
    slot = get_slot(intent, query)
    print(f"槽位获取结果:{slot}")
    status, sk = check_lost(intent, slot)
    print(f"检测是否槽位缺失:{status} 缺失:{sk}")
    if status == 1:
        return {"finished": 0, "text": agent1(intent, messages)}
    return {"finished": 1, "int": intent, "slot": slot}
# -----------------------------------------------------------------------
# Logging
def log_save(fun_name, params, result, exe_time, log_dir='./log'):
    """Append one request record to the day's log file.

    The record is written as ``str(dict)`` (a Python literal) on a single line
    to ``<log_dir>/log<YYYYMMDD>.log``. The directory is created if missing.

    Args:
        fun_name: endpoint tag (e.g. 'int', 'uie', 'agent').
        params: request payload.
        result: response payload.
        exe_time: elapsed seconds.
        log_dir: directory for the daily log files.
    """
    import os

    # One timestamp for both the file name and the record, so the two cannot
    # disagree across midnight.
    now = datetime.datetime.now()
    os.makedirs(log_dir, exist_ok=True)
    log_file = os.path.join(log_dir, 'log' + now.strftime("%Y%m%d") + ".log")
    record = {
        'fun_name': fun_name,
        'params': params,
        'result': result,
        'exe_time': exe_time,
        'time': now.strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(str(record) + '\n')
class LabelMessage(BaseModel):
    # Response body of the /intent_reco endpoint.
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    label: int = pydantic.Field(..., description="标签类型")
    probability: float = pydantic.Field(..., description="概率")


async def get_int_label(
    user_id: str = Body(..., description='鉴权'),
    text: str = Body(..., description="意图识别样本"),
):
    """Classify the intent of *text* after checking the caller's key.

    Returns:
        LabelMessage with the intent id and its probability. On a bad key it
        returns code 401 with label/probability -1 (still sent with HTTP 200).
    """
    import hmac

    cur_time = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: INT user_id:{user_id} text:{text}")
    st = time.time()
    # NOTE(review): hard-coded shared key; consider loading it from configuration.
    expected = '3bb66776-1722-4c36-b14a-73dd210fe750'
    # Constant-time comparison so response timing does not leak the key.
    # Compare bytes because compare_digest rejects non-ASCII str input.
    if not hmac.compare_digest(user_id.encode('utf-8'), expected.encode('utf-8')):
        return LabelMessage(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            label=-1,
            probability=-1
        )
    res = int_predict(int_model, int_tokenizer, [text], int_label)
    log_save('int', text, res, round(time.time() - st, 5))
    return LabelMessage(
        code=200,
        msg='成功',
        label=res[0][0],
        probability=res[0][1]
    )
class LabelMessage1(BaseModel):
    # Response body of the /slot_reco endpoint; `slot` is the dict from get_slot().
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    slot: Dict[str, Any] = pydantic.Field(..., description="槽位JSON 格式)")


async def get_uie_label(
    user_id: str = Body(..., description='鉴权'),
    text: str = Body(..., description="槽位填充样本"),
):
    """Classify *text*'s intent, then extract that intent's slots.

    Returns:
        LabelMessage1 whose ``slot`` holds the English-keyed slot dict. On a bad
        key it returns code 401 with an empty slot dict.
    """
    import hmac

    cur_time = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: UIE user_id:{user_id} text:{text}")
    st = time.time()
    # NOTE(review): hard-coded shared key; consider loading it from configuration.
    expected = '3bb66776-1722-4c36-b14a-73dd210fe750'
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not hmac.compare_digest(user_id.encode('utf-8'), expected.encode('utf-8')):
        return LabelMessage1(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            slot={}
        )
    res = int_predict(int_model, int_tokenizer, [text], int_label)
    slot = get_slot(res[0][0], text)
    log_save('uie', text, {"int": res, "uie": slot}, round(time.time() - st, 5))
    return LabelMessage1(
        code=200,
        msg='成功',
        slot=slot
    )
class LabelMessage2(BaseModel):
    # Response body of the /agent endpoint; `answer` is the dict returned by main().
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    answer: Dict[str, Any] = pydantic.Field(..., description="槽位JSON 格式)")


async def Agent(
    user_id: str = Body(..., description='鉴权'),
    messages: List[Dict[str, str]] = Body(..., description="消息"),
):
    """Advance the slot-filling dialog for *messages* after checking the caller's key.

    Returns:
        LabelMessage2 whose ``answer`` is main()'s result dict. On a bad key it
        returns code 401 with an empty answer dict.
    """
    import hmac

    cur_time = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: AGENT user_id:{user_id} messages:{messages}")
    st = time.time()
    # NOTE(review): hard-coded shared key; consider loading it from configuration.
    expected = '3bb66776-1722-4c36-b14a-73dd210fe750'
    # Constant-time comparison on bytes (compare_digest rejects non-ASCII str).
    if not hmac.compare_digest(user_id.encode('utf-8'), expected.encode('utf-8')):
        # `answer` is declared as a dict. Pass {} rather than '' so the 401
        # response validates; pydantic v2 rejects a str here.
        return LabelMessage2(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            answer={},
        )
    his = main(messages)
    log_save('agent', messages, his, round(time.time() - st, 5))
    return LabelMessage2(
        code=200,
        msg='成功',
        answer=his,
    )
# -----------------------------------------------------------------------
def api_start(host, port, **kwargs):
    """Build the FastAPI app, register the three POST endpoints, and serve it.

    The app is also stored in the module-level ``app`` global.

    Args:
        host: bind address for uvicorn.
        port: bind port for uvicorn.
        **kwargs: extra keyword arguments forwarded to ``uvicorn.run``
            (e.g. ``log_level``).
    """
    global app
    app = FastAPI()
    # Allow cross-origin requests from any origin when the module constant
    # OPEN_CROSS_DOMAIN is True.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    app.post("/intent_reco", response_model=LabelMessage, summary="意图识别")(get_int_label)
    app.post("/slot_reco", response_model=LabelMessage1, summary="槽位抽取")(get_uie_label)
    app.post("/agent", response_model=LabelMessage2, summary="Agent")(Agent)
    uvicorn.run(app, host=host, port=port, **kwargs)
if __name__ == "__main__":
    # Load the label list and both models into the module globals that the
    # request handlers read, then start the HTTP server.
    int_label = get_label("int")
    int_model, int_tokenizer = get_model(model_names['intent_token'], model_names['intent_model'])
    uie_model, uie_tokenizer = get_model(model_names['uie_token'], model_names['uie_model'])
    host, port = "0.0.0.0", 18074
    api_start(host=host, port=port)