2025-02-07 17:29:20 +08:00
|
|
|
|
# phddns start
|
|
|
|
|
|
import paddle
|
|
|
|
|
|
import paddle.nn.functional as F
|
|
|
|
|
|
from paddlenlp.data import Tuple,Pad
|
|
|
|
|
|
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
|
import argparse
|
|
|
|
|
|
import pandas as pd
|
|
|
|
|
|
import time
|
|
|
|
|
|
import datetime
|
|
|
|
|
|
import pydantic
|
|
|
|
|
|
import uvicorn
|
|
|
|
|
|
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
|
|
|
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
from pydantic import BaseModel
|
|
|
|
|
|
from typing import List, Optional
|
|
|
|
|
|
from typing import Dict, Any
|
|
|
|
|
|
|
|
|
|
|
|
import sys
|
|
|
|
|
|
from post_api import Similar
|
|
|
|
|
|
# Shared `Similar` client (imported from post_api). load_know() registers the
# work-name knowledge base on it and standard_project() queries it.
sim = Similar()
|
|
|
|
|
|
|
|
|
|
|
|
def load_know():
    """Load the project knowledge file and register its names for similarity search.

    Each non-blank line of ``./knows/project_know.txt`` is stored under the
    text that precedes its first ``(``. When a line has no ``(`` the whole
    line is used as the key (previously the last character was silently
    chopped off because ``str.find`` returned -1).

    Returns:
        dict: short work name -> full knowledge line.
    """
    file_path = './knows/project_know.txt'
    # Context manager guarantees the handle is closed even if reading fails.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')

    works = {}
    for line in lines:
        if not line.strip():
            # Blank lines (e.g. the trailing newline) used to register an
            # empty key in the knowledge base.
            continue
        name, _, _ = line.partition('(')
        works[name] = line

    # Register the short names under the 'work_name' knowledge base.
    sim.load_know('work_name', list(works), drop_dup=True, is_cover=True)
    return works
|
|
|
|
|
|
# Built once at import time: work-name key -> full knowledge line.
# standard_project() reads this table to resolve matched names.
works = load_know()
|
|
|
|
|
|
# print(works)
|
|
|
|
|
|
def standard_project(work_name):
    """Map a free-text project name onto its standard knowledge-base entry.

    Queries the ``work_name`` knowledge base through ``sim.know_sim`` and, when
    the best hit's ``relevance_score`` exceeds 330 and the hit is present in
    ``works``, returns the full knowledge line for it.

    Args:
        work_name: project name extracted from the user query.

    Returns:
        The standard entry from ``works``, or ``work_name`` unchanged when
        there is no hit, the hit is unknown, or the score is too low.
    """
    import ast
    import json

    raw = sim.know_sim(work_name, 'work_name')
    # SECURITY: the reply used to be passed to eval(), which executes
    # arbitrary expressions coming from the similarity service. Parse it as
    # data instead: JSON first, then a Python literal as a fallback.
    try:
        payload = json.loads(raw)
    except (TypeError, ValueError):
        payload = ast.literal_eval(raw)

    results = payload['data']['results']
    if not results:
        # Nothing matched: keep the caller's text instead of raising IndexError.
        return work_name

    best = results[0]
    text = best['document']['text']
    score = best['document']['relevance_score']
    # .get() so that an unknown hit falls back instead of raising KeyError
    # before the threshold is even checked.
    standard = works.get(text)
    print(text)
    print(standard)
    print(score)
    print(best)

    if score > 330 and standard is not None:
        return standard
    return work_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Feature flag read by api_start(): when truthy, CORSMiddleware is installed
# with wildcard origins, methods and headers.
OPEN_CROSS_DOMAIN = True
|
|
|
|
|
|
def time_consume(func):
    """Decorator that prints wall-clock timestamps and the elapsed time of a call.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper with the same call signature that forwards all arguments
        and returns ``func``'s result unchanged.
    """
    import functools

    # functools.wraps keeps __name__/__doc__ of the wrapped function, which
    # the bare closure used to replace with those of ``inner``.
    @functools.wraps(func)
    def inner(*args, **kwargs):  # accept arbitrary arguments
        time_start = time.time()
        print("\n接口请求前的时间是:", datetime.datetime.now())
        res = func(*args, **kwargs)
        time_end = time.time()
        print("\n接口请求后的时间是:", datetime.datetime.now())
        t = time_end - time_start
        print("接口耗时:", round(t, 4), "s")
        return res

    return inner
|
|
|
|
|
|
# Command-line configuration. Parsed once at import time; the resulting
# ``args`` namespace is read by get_model(), int_predict() and uie_predict().
parser = argparse.ArgumentParser()
parser.add_argument("--device", default="gpu:0", help="Select which device to train model, default to gpu.")
parser.add_argument("--output_file", default="output.txt", type=str)
# Sequence limit used by the intent classifier (int_predict).
parser.add_argument(
    "--max_seq_length",
    default=128,
    type=int,
    # Adjacent literals are concatenated verbatim; the explicit trailing
    # spaces fix the previously run-together "than thiswill be truncated".
    help="The maximum total input sequence length "
         "after tokenization. Sequences longer than this "
         "will be truncated, sequences shorter will be padded.",
)
parser.add_argument(
    "--batch_size",
    default=64,
    type=int,
    help="Batch size per GPU/CPU for training.",
)
# Sequence limit used by the UIE slot extractor (uie_predict).
parser.add_argument(
    "--max_seq_len",
    default=256,
    type=int,
    help="The maximum total input sequence length after tokenization.",
)
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")

args = parser.parse_args()

# Tokenizer directories and exported-model path prefixes for both models.
model_names = {
    "intent_token": "/mnt/d/weiweiwang/intention/models/zero_int",
    "intent_model": "/mnt/d/weiweiwang/intention/models/zero_int/model",
    "uie_token": "/mnt/d/weiweiwang/intention/models/zero_uie",
    "uie_model": "/mnt/d/weiweiwang/intention/models/zero_uie/model",
}

# Label files, keyed by the short name accepted by get_label().
label_names = {
    "int": "./int/data/label.txt",
}
|
|
|
|
|
|
def get_label(label_name):
    """Read the label list registered under ``label_name`` in ``label_names``.

    Args:
        label_name: key into the module-level ``label_names`` table.

    Returns:
        list[str]: the stripped, non-empty lines of the label file, in file order.

    Raises:
        KeyError: if ``label_name`` is not registered.
    """
    # The labels are Chinese text, so the encoding is pinned to UTF-8 (as
    # load_know() already does) instead of relying on the platform default.
    with open(label_names[label_name], encoding='utf-8') as inf:
        stripped = (line.strip() for line in inf)
        return [line for line in stripped if line != '']
|
|
|
|
|
|
|
|
|
|
|
|
def get_model(token_path,model_path):
    """Load an exported Paddle model together with its tokenizer.

    The compute device is taken from ``args.device`` and is selected before
    anything is loaded.

    Args:
        token_path: directory handed to ``AutoTokenizer.from_pretrained``.
        model_path: path prefix handed to ``paddle.jit.load``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    paddle.set_device(args.device)
    loaded_model = paddle.jit.load(model_path)
    loaded_tokenizer = AutoTokenizer.from_pretrained(token_path)
    return loaded_model, loaded_tokenizer
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
def int_predict(model, tokenizer, data, label_list):
    """Classify the intent of every text in ``data``.

    Args:
        model: callable taking ``(input_ids, token_type_ids)`` tensors and
            returning logits; softmax is applied over axis 1.
        tokenizer: tokenizer exposing ``pad_token_id`` / ``pad_token_type_id``.
        data: iterable of input texts.
        label_list: label names indexed by the predicted class id.

    Returns:
        list[tuple[int, str]]: one ``(intent_id, probability)`` pair per text,
        where the probability is the rounded top score rendered as a string.
    """
    # Tokenize every text; only ids and segment ids are fed to the model.
    encoded = []
    for sentence in data:
        features = tokenizer(text=sentence, max_seq_length=args.max_seq_length)
        encoded.append((features['input_ids'], features['token_type_ids']))

    # Separate the examples into batches.
    step = args.batch_size
    batches = [encoded[start: start + step] for start in range(0, len(encoded), step)]

    # Pads both fields of a batch to a common length.
    collate = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    )

    predictions = []
    model.eval()
    for batch in batches:
        input_ids, token_type_ids = collate(batch)
        input_ids = paddle.to_tensor(input_ids)
        token_type_ids = paddle.to_tensor(token_type_ids)
        probs = F.softmax(model(input_ids, token_type_ids), axis=1)
        best_ids = paddle.argmax(probs, axis=1).numpy().tolist()
        prob_rows = probs.tolist()
        for label_idx, row in zip(best_ids, prob_rows):
            predictions.append((label_list[label_idx], str(round(row[label_idx], 5))))

    # Label name -> numeric intent id used by the rest of the service.
    mapping = {"天气查询":1,"互联网查询":2,"页面切换":3,
               "日计划数量查询":4,"周计划数量查询":5,"日计划作业内容":6,"周计划作业内容":7,
               "施工人数":8,"作业考勤人数":9,"知识问答":10}
    return [(mapping[name], prob) for name, prob in predictions]
|
|
|
|
|
|
|
|
|
|
|
|
@paddle.no_grad()
def uie_predict(model, tokenizer, data):
    """Extract one answer span per ``{"prompt", "content"}`` item with the UIE model.

    Args:
        model: callable taking ``(input_ids, token_type_ids, position_ids,
            attention_mask)`` tensors and returning ``(start_prob, end_prob)``.
        tokenizer: tokenizer matching the model.
        data: iterable of dicts with ``'prompt'`` and ``'content'`` keys.

    Returns:
        list[tuple]: one ``(text, start_score, end_score, start_index,
        end_index)`` tuple per input item. ``text`` is the concatenation of the
        tokens between the best start and end positions (inclusive); the
        indices are token positions in the encoded prompt+content sequence.
    """
    examples = []
    for item in data:
        encoded = tokenizer(text=[item['prompt']],
                            text_pair=[item['content']],
                            truncation=True,
                            max_seq_length=args.max_seq_len,
                            pad_to_max_seq_len=True,
                            return_attention_mask=True,
                            return_position_ids=True,
                            return_dict=False,
                            return_offsets_mapping=True)
        encoded = encoded[0]
        examples.append((encoded['input_ids'], encoded['token_type_ids'],
                         encoded['position_ids'], encoded['attention_mask']))

    # Separate the examples into batches.
    batches = [
        examples[i: i + args.batch_size]
        for i in range(0, len(examples), args.batch_size)
    ]

    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Pad(axis=0, pad_val=0),
        Pad(axis=0, pad_val=0),
    )

    results = []
    model.eval()
    for batch in batches:
        input_ids, token_type_ids, position_ids, attention_mask = batchify_fn(batch)
        input_ids = paddle.to_tensor(input_ids)
        token_type_ids = paddle.to_tensor(token_type_ids)
        position_ids = paddle.to_tensor(position_ids)
        attention_mask = paddle.to_tensor(attention_mask)
        start_prob, end_prob = model(input_ids, token_type_ids, position_ids, attention_mask)

        # One row of scores per example. The previous code took max/argmax
        # over the whole (flattened) batch and decoded tokens from whichever
        # example was tokenized last, so it was only right for one example.
        # NOTE(review): assumes the model emits one score per token for each
        # example, i.e. (batch, seq_len) after reshape — confirm for the export.
        start_rows = start_prob.numpy().reshape(len(batch), -1)
        end_rows = end_prob.numpy().reshape(len(batch), -1)
        for example, start_row, end_row in zip(batch, start_rows, end_rows):
            # Take the max value and its index for both span boundaries.
            start_idx = int(start_row.argmax())
            end_idx = int(end_row.argmax())
            tokens = tokenizer.convert_ids_to_tokens(example[0])
            span = ''.join(tokens[start_idx:end_idx + 1])
            results.append((span,
                            float(start_row[start_idx]),
                            float(end_row[end_idx]),
                            start_idx,
                            end_idx))
    return results
|
|
|
|
|
|
|
|
|
|
|
|
def get_slot(int_res,query):
    """Extract the slots relevant to intent ``int_res`` from ``query``.

    Runs uie_predict() once per prompt, keeps spans whose start and end
    scores are both at least 0.8, resolves overlapping spans in favour of the
    higher-scored one, and returns the survivors under English slot names.

    Args:
        int_res: numeric intent id (as produced by int_predict()).
        query: the user utterance.

    Returns:
        dict: English slot name -> extracted text. Empty for intents that
        carry no slots.
    """
    if int_res not in (1, 3, 4, 5, 6, 7, 8, 9):
        return {}

    if int_res == 1:  # weather query
        prompts = ['时间','地点']
    elif int_res == 3:  # page switch
        prompts = ['应用','模块','页面']
    else:
        prompts = ['时间','项目','项目经理','公司','工程','风险等级','周时间','施工状态']

    slot = {}  # prompt -> accepted span text
    comp = {}  # prompt -> full prediction (with shifted indices) of accepted spans
    for prompt in prompts:
        uie_res = uie_predict(uie_model,uie_tokenizer,[{"prompt":prompt,"content":query}])
        print(f"prompt:{prompt} uie_res:{uie_res}")
        text, start_score, end_score, start_idx, end_idx = uie_res[0]
        # Shift both token indices by the prompt length so that spans found
        # with different prompts can be compared against each other.
        uie_res[0] = [text, start_score, end_score,
                      start_idx - len(prompt), end_idx - len(prompt)]
        if start_score >= 0.8 and end_score >= 0.8:
            slot[prompt] = text
            comp[prompt] = uie_res

    # A span is dropped when it overlaps another span whose start and end
    # scores are both at least as high.
    drop = []
    for name in slot:
        mine = comp[name][0]
        for other_name, other in comp.items():
            if other_name == name:
                continue
            theirs = other[0]
            disjoint = mine[4] < theirs[3] or mine[3] > theirs[4]
            if disjoint:
                continue
            if mine[1] <= theirs[1] and mine[2] <= theirs[2]:
                print(f"1:{comp[name]}")
                print(f"2:{other}")
                drop.append(name)
                break
    for name in drop:
        del slot[name]

    # The week-time slot replaces the time slot for intent 5 and is discarded
    # for every other intent.
    week_time = slot.pop("周时间", None)
    if int_res == 5 and week_time is not None:
        slot['时间'] = week_time

    cn_en = {"时间":"date","项目":"project","项目经理":"people",
             "公司":"company","工程":"work","风险等级":"risk_level",
             "施工状态":"work_status","应用":"app","模块":"module","页面":"page","地点":"area"}
    nslot = {}
    for name, value in slot.items():
        cleaned = value.replace('#','').replace('kv','kV').replace("模块","").replace("页面","")
        if name == '工程':
            # Project names are normalised against the knowledge base.
            print(cleaned)
            cleaned = standard_project(cleaned)
        nslot[cn_en[name]] = cleaned
    return nslot
|
|
|
|
|
|
|
|
|
|
|
|
def check_lost(int_res,slot):
    """Check whether the mandatory slots of intent ``int_res`` are filled.

    Each intent lists one or more alternative sets of required slot names;
    the intent is satisfied as soon as any one alternative is fully present.

    Args:
        int_res: numeric intent id.
        slot: dict of the slots collected so far (only its keys are used).

    Returns:
        tuple: ``(0, present_keys)`` when an alternative is satisfied,
        ``(0, [])`` when the intent has no slot requirements registered, or
        ``(1, missing)`` where ``missing`` lists the absent slot names of the
        alternative with the fewest gaps (the first one on ties).
    """
    # intent id -> alternative lists of required slot names
    required_by_intent = {
        1: [['date', 'area']],
        3: [['page'], ['app'], ['module']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [[]],
        9: [[]],
    }
    alternatives = required_by_intent.get(int_res)
    if alternatives is None:
        return 0, []

    present = list(slot.keys())
    fewest_missing = None
    for required in alternatives:
        missing = [name for name in required if name not in present]
        if not missing:
            # Fully satisfied alternative. The old extra test
            # ``len(more) >= 0`` was always true and has been removed.
            return 0, present
        # Strict '<' keeps the earliest alternative on ties.
        if fewest_missing is None or len(missing) < len(fewest_missing):
            fewest_missing = missing
    return 1, fewest_missing
|
|
|
|
|
|
|
|
|
|
|
|
@time_consume
def agent(sk,slot,messages):
    """Ask an OpenAI-compatible chat model to question the user about missing slots.

    Args:
        sk: list of mandatory slot names, interpolated into the system prompt.
        slot: slots collected so far, interpolated into the system prompt.
        messages: prior chat messages appended after the system prompt.

    Returns:
        The content of the first completion choice.
    """
    from openai import OpenAI

    # NOTE(review): api_key, api_base_url and model_name are not assigned
    # anywhere in the visible part of this file (only commented-out samples
    # existed here), so calling this raises NameError unless they are defined
    # elsewhere — confirm before re-enabling the call sites.
    llm = OpenAI(api_key=api_key,
                 base_url=api_base_url)

    system_prompt = f'''
你是一个槽位询问工具,你需要结合用户输入,根据必填槽位列表,以及当前已有槽位,对用户进行槽位问询
询问用户哪些槽位是缺失的,需要补充,待必填槽位都已填充完毕,将所有必填槽位类型及槽位值以json格式返回
请不要对用户输入的槽位进行修改

###以下是必填槽位列表###
need_slot_lists = {sk}
###以上是必填槽位列表###

###以下是当前已有槽位###
present_slot={slot}

###

'''
    # System prompt first, then the running conversation.
    conversation = [{"role":"system","content":system_prompt}, *messages]
    completion = llm.chat.completions.create(
        messages=conversation,
        model=model_name,
        max_tokens=1000,
        temperature=0.001,
        stream=False
    )
    return completion.choices[0].message.content
|
|
|
|
|
|
|
|
|
|
|
|
@time_consume
def agent1(int_res,messages):
    """Rule-based slot dialogue: return a follow-up question or the final slots.

    Slots are re-extracted from every user message (later messages override
    earlier ones) and checked against the intent's mandatory slots.

    Args:
        int_res: numeric intent id.
        messages: chat history; dicts with ``'role'`` and ``'content'``.

    Returns:
        dict of collected slots when nothing is missing or the history has at
        least 7 messages; otherwise a question string for the user. ``None``
        when slots are missing for an intent without a question template.
    """
    collected = {}
    for message in messages:
        if message["role"] != 'user':
            continue
        collected.update(get_slot(int_res,message['content']))

    status, sk = check_lost(int_res,collected)

    turns = len(messages)
    # Apology prefix grows with the number of failed rounds.
    apologies = {3: "非常抱歉,", 5: "特别抱歉,我没有理解您的意思,"}
    reres = apologies.get(turns, '')

    # Give up asking once the conversation is long enough.
    if turns >= 7:
        return collected
    if status != 1:
        return collected

    if int_res == 1:  # weather query
        if sk == ['date','area']:
            return f"{reres}请问您想查询的天气的时间和地点分别是什么?"
        if sk == ['area']:
            return f"{reres}请问您想查询哪里的天气情况?"
        return f"{reres}请问您想查询哪一天的天气情况?"
    if int_res == 3:
        return f"{reres}请问您想查询哪个页面?"
    if int_res in [4,5,6,7,8,9]:
        mapping = {4:"日计划数量",5:"周计划数量",6:"日计划作业内容",7:"周计划作业内容",
                   8:"施工人数",9:"作业考勤人数"}
        return f"{reres}请问您想查询哪天的{mapping[int_res]}"
    return None
|
|
|
|
|
|
|
|
|
|
|
|
# @time_consume
|
|
|
|
|
|
def main(messages):
    """Run one dialogue step: detect the intent, fill slots, ask if needed.

    The intent is always predicted from the first message. On the first turn
    the slots are extracted from that message directly; on later turns
    agent1() re-collects them from the whole history.

    Args:
        messages: chat history; dicts with ``'role'`` and ``'content'``.

    Returns:
        dict: ``{"finished": 0, "text": question}`` while slots are missing,
        otherwise ``{"finished": 1, "int": intent_id, "slot": slots}``.
    """
    query = messages[0]['content']
    int_res = int_predict(int_model,int_tokenizer,[query],int_label)
    print(f"意图识别结果:{int_res}")
    intent = int_res[0][0]

    if len(messages) == 1:  # first turn
        slot = get_slot(intent,query)
        print(f"槽位获取结果:{slot}")
        # Check whether any mandatory slot is still missing.
        status, sk = check_lost(intent,slot)
        print(f"检测是否槽位缺失:{status} 缺失:{sk}")
        if status == 1:
            # Let the rule-based agent phrase the follow-up question.
            return {"finished":0,"text":agent1(intent,messages)}
        return {"finished":1,"int":intent,"slot":slot}

    res = agent1(intent,messages)
    # Long conversations are always closed, whatever agent1 produced.
    if len(messages) < 7 and isinstance(res,str):
        return {"finished":0,"text":res}
    return {"finished":1,"int":intent,"slot":res}
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------
|
|
|
|
|
|
# 日志记录
|
|
|
|
|
|
def log_save(fun_name,params,result,exe_time):
    """Append one call record to the daily log file ``./log/logYYYYMMDD.log``.

    Args:
        fun_name: name of the endpoint or function being logged.
        params: the call input.
        result: the call output.
        exe_time: elapsed time of the call.

    Each record is the ``str()`` of a dict with the keys ``fun_name``,
    ``params``, ``result``, ``exe_time`` and ``time``, written on its own line.
    """
    import os

    # One timestamp for both the file name and the record, so a call that
    # straddles midnight cannot land in the wrong day's file.
    now = datetime.datetime.now()
    log_file = './log/log' + now.strftime("%Y%m%d") + ".log"
    # The log directory may not exist on a fresh deployment.
    os.makedirs(os.path.dirname(log_file), exist_ok=True)

    record = {
        'fun_name': fun_name,
        'params': params,
        'result': result,
        'exe_time': exe_time,
        'time': now.strftime("%Y-%m-%d %H:%M:%S"),
    }
    # Context manager closes the handle even if the write fails.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(str(record) + '\n')
|
|
|
|
|
|
|
|
|
|
|
|
# Response body of the intent-recognition endpoint.
class LabelMessage(BaseModel):
    # Status code carried in the body; defaults to 200.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    # Human-readable status message.
    msg: str = pydantic.Field(default='成功', description="msg")
    # Numeric intent id (required).
    label: int = pydantic.Field(default=..., description="标签类型")
    # Score of the predicted intent (required).
    probability: float = pydantic.Field(default=..., description="概率")
|
|
|
|
|
|
|
|
|
|
|
|
async def get_int_label(
    user_id: str = Body(...,description='鉴权'),
    text: str = Body(..., description="意图识别样本"),
):
    """Intent-recognition endpoint.

    Rejects callers whose ``user_id`` is not the expected token with a
    ``code=401`` body; otherwise classifies ``text``, logs the call through
    log_save() and returns the top intent id and its probability.
    """
    stamp = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{stamp} :: INT user_id:{user_id} text:{text}")
    started = time.time()

    authorized = user_id == '3bb66776-1722-4c36-b14a-73dd210fe750'
    if not authorized:
        return LabelMessage(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            label=-1,
            probability=-1
        )

    # int_model / int_tokenizer / int_label are module globals set at start-up.
    res = int_predict(int_model,int_tokenizer,[text],int_label)
    log_save('int',text,res,round(time.time() - started,5))

    top_label, top_probability = res[0][0], res[0][1]
    return LabelMessage(
        code=200,
        msg='成功',
        label=top_label,
        probability=top_probability
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Response body of the slot-extraction endpoint.
class LabelMessage1(BaseModel):
    # Status code carried in the body; defaults to 200.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    # Human-readable status message.
    msg: str = pydantic.Field(default='成功', description="msg")
    # Extracted slots as a JSON object (required).
    slot: Dict[str, Any] = pydantic.Field(default=..., description="槽位(JSON 格式)")
|
|
|
|
|
|
|
|
|
|
|
|
async def get_uie_label(
    user_id: str = Body(...,description='鉴权'),
    text: str = Body(..., description="槽位填充样本"),
):
    """Slot-extraction endpoint.

    Rejects callers whose ``user_id`` is not the expected token with a
    ``code=401`` body; otherwise predicts the intent of ``text``, extracts
    the slots for that intent, logs the call and returns the slots.
    """
    stamp = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{stamp} :: UIE user_id:{user_id} text:{text}")
    started = time.time()

    authorized = user_id == '3bb66776-1722-4c36-b14a-73dd210fe750'
    if not authorized:
        return LabelMessage1(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            slot={}
        )

    # The intent decides which prompts get_slot() will run.
    res = int_predict(int_model,int_tokenizer,[text],int_label)
    slot = get_slot(res[0][0],text)
    log_save('uie',text,{"int":res,"uie":slot},round(time.time() - started,5))

    return LabelMessage1(
        code=200,
        msg='成功',
        slot=slot
    )
|
|
|
|
|
|
|
|
|
|
|
|
# Response body of the agent endpoint.
class LabelMessage2(BaseModel):
    # Status code carried in the body; defaults to 200.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    # Human-readable status message.
    msg: str = pydantic.Field(default='成功', description="msg")
    # Dialogue step result as a JSON object (required).
    answer: Dict[str, Any] = pydantic.Field(default=..., description="槽位(JSON 格式)")
|
|
|
|
|
|
|
|
|
|
|
|
async def Agent(
    user_id: str = Body(...,description='鉴权'),
    messages: List[Dict[str,str]] = Body(..., description="消息"),
):
    """Dialogue endpoint: run one step of main() over the chat history.

    Args:
        user_id: caller token; anything but the expected value is rejected.
        messages: chat history as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        LabelMessage2 with ``code=401`` and an empty ``answer`` on failed
        authentication, otherwise ``code=200`` and the dict produced by main().
    """
    cur_time = datetime.datetime.now()
    cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: AGENT user_id:{user_id} messages:{messages}")
    st = time.time()
    if user_id != '3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage2(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            # ``answer`` is declared as Dict[str, Any]; the previous '' failed
            # model validation, so the 401 body was never actually returned.
            answer={},
        )
    his = main(messages)
    log_save('agent',messages,his,round(time.time() - st,5))

    return LabelMessage2(
        code=200,
        msg='成功',
        answer=his,
    )
|
|
|
|
|
|
|
|
|
|
|
|
# -----------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def api_start(host,port,**kwargs):
    """Build the FastAPI application, register the routes and serve it.

    Args:
        host: interface for uvicorn to bind.
        port: TCP port for uvicorn to bind.
        **kwargs: accepted but not used.

    The application object is published as the module-level ``app``. The call
    returns only when ``uvicorn.run`` returns.
    """
    global app
    app = FastAPI()

    # Allow cross-domain requests from any origin when OPEN_CROSS_DOMAIN is set.
    if OPEN_CROSS_DOMAIN:
        cors_options = dict(
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        app.add_middleware(CORSMiddleware, **cors_options)

    # (path, response model, summary, handler) for every POST route.
    routes = (
        ("/intent_reco", LabelMessage, "意图识别", get_int_label),
        ("/slot_reco", LabelMessage1, "槽位抽取", get_uie_label),
        ("/agent", LabelMessage2, "Agent", Agent),
    )
    for path, response_model, summary, handler in routes:
        app.post(path, response_model=response_model, summary=summary)(handler)

    uvicorn.run(app, host=host, port=port)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Load the label list and both models once, as module globals, before
    # the server starts; the endpoints and get_slot() read them by name.
    int_label = get_label("int")
    int_model, int_tokenizer = get_model(
        model_names['intent_token'],
        model_names['intent_model'],
    )
    uie_model, uie_tokenizer = get_model(
        model_names['uie_token'],
        model_names['uie_model'],
    )

    # Listen on every interface.
    host, port = "0.0.0.0", 18074
    api_start(host, port)
|
|
|
|
|
|
|
|
|
|
|
|
|