587 lines
21 KiB
Python
587 lines
21 KiB
Python
# phddns start
|
||
import paddle
|
||
import paddle.nn.functional as F
|
||
from paddlenlp.data import Tuple,Pad
|
||
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
|
||
import argparse
|
||
import pandas as pd
|
||
import time
|
||
import datetime
|
||
import pydantic
|
||
import uvicorn
|
||
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
|
||
from fastapi.middleware.cors import CORSMiddleware
|
||
from pydantic import BaseModel
|
||
from typing import List, Optional
|
||
from typing import Dict, Any
|
||
|
||
import sys
|
||
from post_api import Similar
|
||
# Shared similarity-search client: load_know() registers the knowledge keys
# on it and standard_project() queries it.
sim = Similar()
|
||
|
||
def load_know():
    """Load the project knowledge file and register its keys for similarity search.

    Every non-blank line of ``./knows/project_know.txt`` is stored under the
    text that precedes its first ``(``.  A line that contains no ``(`` is
    stored under the whole line (slicing with the ``-1`` returned by
    ``str.find`` would silently drop its last character).

    The collected keys are pushed to the shared ``sim`` client under the
    knowledge name ``'work_name'``.

    Returns:
        dict: mapping of key (text before ``(``) to the complete line.
    """
    file_path = './knows/project_know.txt'
    # Context manager guarantees the handle is closed even if reading fails.
    with open(file_path, 'r', encoding='utf-8') as f:
        lines = f.read().split('\n')

    works = {}
    for line in lines:
        if not line.strip():
            # Blank lines (e.g. the one produced by a trailing newline) would
            # otherwise register an empty key in the knowledge base.
            continue
        head, sep, _ = line.partition('(')
        works[head if sep else line] = line

    sim.load_know('work_name', list(works), drop_dup=True, is_cover=True)
    return works
|
||
# Knowledge base loaded once at import time: key (text before "(") -> full line.
works = load_know()
|
||
# print(works)
|
||
def standard_project(work_name, threshold=330):
    """Map a free-form work name onto its standard entry in ``works``.

    The name is sent to the similarity service; when the best hit scores
    above ``threshold`` and is a known key of ``works``, the standard entry
    is returned.  In every other case (no hits, unknown key, low score) the
    input is returned unchanged.

    Args:
        work_name: work/project name extracted from the user query.
        threshold: minimum ``relevance_score`` (exclusive) for accepting the
            best hit.  Defaults to 330, the value previously hard-coded.

    Returns:
        The standard entry from ``works`` or ``work_name`` itself.
    """
    raw = sim.know_sim(work_name, 'work_name')
    # NOTE(security): eval() executes whatever text the similarity service
    # returns.  Kept because the response format is not visible here; switch
    # to json.loads / ast.literal_eval once the format is confirmed.
    parsed = eval(raw)

    hits = parsed['data']['results']
    if not hits:
        # Nothing matched: fall back to the raw name instead of IndexError.
        return work_name

    best = hits[0]
    text = best['document']['text']
    score = best['document']['relevance_score']
    # .get() so an unexpected hit text cannot raise KeyError while logging.
    standard = works.get(text)

    print(text)
    print(standard)
    print(score)
    print(best)

    if standard is not None and score > threshold:
        return standard
    return work_name
|
||
|
||
|
||
|
||
# Read by api_start(): when true, the CORS middleware is installed on the app.
OPEN_CROSS_DOMAIN = True
|
||
def time_consume(func):
    """Decorator that prints wall-clock timestamps and the elapsed time of a call.

    Args:
        func: the callable to wrap.

    Returns:
        A wrapper that forwards all positional and keyword arguments to
        ``func`` and returns its result unchanged.
    """
    # Local import keeps this block self-contained.
    import functools

    # wraps() preserves __name__/__doc__ of the decorated function, which the
    # bare closure used to replace with those of ``inner``.
    @functools.wraps(func)
    def inner(*args, **kwargs):
        time_start = time.time()
        print("\n接口请求前的时间是:", datetime.datetime.now())
        res = func(*args, **kwargs)
        time_end = time.time()
        print("\n接口请求后的时间是:", datetime.datetime.now())
        t = time_end - time_start
        print("接口耗时:", round(t, 4), "s")
        return res

    return inner
|
||
parser = argparse.ArgumentParser()
|
||
parser.add_argument("--device", default="gpu:0", help="Select which device to train model, default to gpu.")
|
||
parser.add_argument("--output_file", default="output.txt",type=str)
|
||
parser.add_argument("--max_seq_length",
|
||
default=128,
|
||
type=int,
|
||
help="The maximum total input sequence length "
|
||
"after tokenization. Sequences longer than this"
|
||
"will be truncated, sequences shorter will be padded.")
|
||
parser.add_argument("--batch_size",
|
||
default=64,
|
||
type=int,
|
||
help="Batch size per GPU/CPU for training.")
|
||
parser.add_argument("--max_seq_len",
|
||
default=256,
|
||
type=int,
|
||
help="The maximum total input sequence length after tokenization.")
|
||
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
|
||
|
||
args = parser.parse_args()
|
||
# Filesystem locations of the tokenizer directory and the exported model
# prefix for the intent classifier and the UIE slot extractor.
model_names = dict(
    intent_token="/mnt/d/weiweiwang/intention/models/zero_int",
    intent_model="/mnt/d/weiweiwang/intention/models/zero_int/model",
    uie_token="/mnt/d/weiweiwang/intention/models/zero_uie",
    uie_model="/mnt/d/weiweiwang/intention/models/zero_uie/model",
)

# Label files, keyed by the short task name passed to get_label().
label_names = dict([("int", "./int/data/label.txt")])
|
||
def get_label(label_name):
    """Read the label list registered under ``label_name`` in ``label_names``.

    Args:
        label_name: key into the module-level ``label_names`` mapping.

    Returns:
        list[str]: the stripped, non-empty lines of the label file, in file
        order.

    Raises:
        KeyError: if ``label_name`` is not a key of ``label_names``.
        OSError: if the label file cannot be opened.
    """
    # Explicit utf-8, matching the other file reads/writes in this module;
    # the platform default encoding is not guaranteed to decode the labels.
    with open(label_names[label_name], encoding='utf-8') as inf:
        stripped = (line.strip() for line in inf)
        return [line for line in stripped if line != '']
|
||
|
||
def get_model(token_path,model_path):
    """Select the configured device, then load an exported model and its tokenizer.

    Args:
        token_path: path handed to ``AutoTokenizer.from_pretrained``.
        model_path: path handed to ``paddle.jit.load``.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    paddle.set_device(args.device)
    loaded_model = paddle.jit.load(model_path)
    loaded_tokenizer = AutoTokenizer.from_pretrained(token_path)
    return loaded_model, loaded_tokenizer
|
||
|
||
@paddle.no_grad()
def int_predict(model, tokenizer, data, label_list):
    """Classify the intent of every text in ``data``.

    Args:
        model: callable taking ``(input_ids, token_type_ids)`` tensors and
            returning logits.
        tokenizer: tokenizer producing ``input_ids`` / ``token_type_ids``.
        data: iterable of input texts.
        label_list: label names indexed by the model's class index.

    Returns:
        list[tuple]: one ``(intent_id, probability_string)`` per input text,
        where the probability is rounded to 5 decimals and rendered with
        ``str``.

    Raises:
        KeyError: if a predicted label name has no numeric intent id.
    """
    encoded = []
    for sentence in data:
        features = tokenizer(text=sentence, max_seq_length=args.max_seq_length)
        encoded.append((features['input_ids'], features['token_type_ids']))

    # Pads both fields of a batch to the longest sample in that batch.
    collate = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    )

    scored = []
    model.eval()
    step = args.batch_size
    for offset in range(0, len(encoded), step):
        ids, type_ids = collate(encoded[offset: offset + step])
        logits = model(paddle.to_tensor(ids), paddle.to_tensor(type_ids))
        probs = F.softmax(logits, axis=1)
        winners = paddle.argmax(probs, axis=1).numpy().tolist()
        rows = probs.tolist()
        scored.extend(
            (label_list[winner], str(round(row[winner], 5)))
            for winner, row in zip(winners, rows)
        )

    # Label name -> numeric intent id used by the rest of the service.
    intent_ids = {"天气查询":1,"互联网查询":2,"页面切换":3,
                  "日计划数量查询":4,"周计划数量查询":5,"日计划作业内容":6,"周计划作业内容":7,
                  "施工人数":8,"作业考勤人数":9,"知识问答":10}
    return [(intent_ids[name], prob) for name, prob in scored]
|
||
|
||
@paddle.no_grad()
def uie_predict(model, tokenizer, data):
    """Extract one answer span per ``{"prompt", "content"}`` item.

    Each item is encoded as a (prompt, content) pair, the model is run in
    batches of ``args.batch_size`` and, for every example, the span between
    the highest-scoring start position and the highest-scoring end position
    is decoded from that example's own tokens.

    The spans used to be computed once per batch from the flattened batch
    output and decoded with the tokens of the last encoded example, which is
    only correct when ``data`` holds a single item.  Results for
    single-item calls are unchanged.

    Args:
        model: callable taking ``(input_ids, token_type_ids, position_ids,
            attention_mask)`` and returning ``(start_prob, end_prob)``.
            Assumes the first axis of both outputs is the batch axis --
            TODO confirm against the exported model.
        tokenizer: tokenizer used for encoding and ``convert_ids_to_tokens``.
        data: list of dicts with the keys ``'prompt'`` and ``'content'``.

    Returns:
        list[tuple]: one ``(text, start_score, end_score, start_index,
        end_index)`` per input item; indices are token positions in the
        encoded pair.
    """
    examples = []
    for item in data:
        encoded = tokenizer(text=[item['prompt']],
                            text_pair=[item['content']],
                            truncation=True,
                            max_seq_length=args.max_seq_len,
                            pad_to_max_seq_len=True,
                            return_attention_mask=True,
                            return_position_ids=True,
                            return_dict=False,
                            return_offsets_mapping=True)[0]
        examples.append((encoded['input_ids'], encoded['token_type_ids'],
                         encoded['position_ids'], encoded['attention_mask']))

    # Separate the data into batches.
    batches = [
        examples[i: i + args.batch_size]
        for i in range(0, len(examples), args.batch_size)
    ]

    batchify_fn = Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Pad(axis=0, pad_val=0),
        Pad(axis=0, pad_val=0),
    )

    results = []
    model.eval()
    for batch in batches:
        input_ids, token_type_ids, position_ids, attention_mask = batchify_fn(batch)
        input_ids = paddle.to_tensor(input_ids)
        token_type_ids = paddle.to_tensor(token_type_ids)
        position_ids = paddle.to_tensor(position_ids)
        attention_mask = paddle.to_tensor(attention_mask)
        start_prob, end_prob = model(input_ids, token_type_ids, position_ids, attention_mask)

        # Score every example on its own row and decode it with its own ids.
        for row, example in enumerate(batch):
            start_row = start_prob[row]
            end_row = end_prob[row]
            stmax_value = paddle.max(start_row).numpy().item()
            stmax_idx = paddle.argmax(start_row).numpy().item()
            etmax_value = paddle.max(end_row).numpy().item()
            etmax_idx = paddle.argmax(end_row).numpy().item()

            tokens = tokenizer.convert_ids_to_tokens(example[0])
            res = ''.join(tokens[stmax_idx:etmax_idx + 1])
            results.append((res, stmax_value, etmax_value, stmax_idx, etmax_idx))
    return results
|
||
|
||
def get_slot(int_res,query):
    """Extract the slots relevant to intent ``int_res`` from ``query``.

    One UIE query is issued per prompt of the intent.  A span is kept when
    both its start and end scores reach 0.8; a kept span that overlaps
    another candidate whose two scores are both at least as high is dropped.
    The surviving slots are cleaned up and returned under English keys.

    Args:
        int_res: numeric intent id.
        query: the user's utterance.

    Returns:
        dict: English slot name -> slot text; ``{}`` for intents without
        slots.
    """
    if int_res not in (1, 3, 4, 5, 6, 7, 8, 9):
        return {}

    special_prompts = {
        1: ['时间','地点'],            # weather query
        3: ['应用','模块','页面'],      # page switch
    }
    prompts = special_prompts.get(
        int_res,
        ['时间','项目','项目经理','公司','工程','风险等级','周时间','施工状态'],
    )

    slot = {}
    comp = {}
    for prompt in prompts:
        uie_res = uie_predict(uie_model,uie_tokenizer,[{"prompt":prompt,"content":query}])
        print(f"prompt:{prompt} uie_res:{uie_res}")
        text, start_score, end_score, start_idx, end_idx = uie_res[0]
        shift = len(prompt)
        # Span indices are shifted left by the prompt length.
        uie_res[0] = [text, start_score, end_score, start_idx - shift, end_idx - shift]
        if start_score >= 0.8 and end_score >= 0.8:
            slot[prompt] = text
            comp[prompt] = uie_res

    def _dominated(name):
        # True when an overlapping candidate scores at least as high on both ends.
        own = comp[name][0]
        for other_name, other in comp.items():
            if other_name == name:
                continue
            rival = other[0]
            disjoint = own[4] < rival[3] or own[3] > rival[4]
            if not disjoint and own[1] <= rival[1] and own[2] <= rival[2]:
                print(f"1:{comp[name]}")
                print(f"2:{other}")
                return True
        return False

    for name in [candidate for candidate in slot if _dominated(candidate)]:
        del slot[name]

    # The week-time slot replaces the time slot for intent 5 and is discarded
    # for every other intent.
    if "周时间" in slot:
        week_value = slot.pop('周时间')
        if int_res == 5:
            slot['时间'] = week_value

    cn_en = {"时间":"date","项目":"project","项目经理":"people",
             "公司":"company","工程":"work","风险等级":"risk_level",
             "施工状态":"work_status","应用":"app","模块":"module","页面":"page","地点":"area"}

    nslot = {}
    for name, value in slot.items():
        for old, new in (('#',''), ('kv','kV'), ("模块",""), ("页面","")):
            value = value.replace(old, new)
        if name == '工程':
            print(value)
            value = standard_project(value)
        nslot[cn_en[name]] = value
    return nslot
|
||
|
||
def check_lost(int_res,slot):
    """Check whether the required slots of intent ``int_res`` are present.

    Every intent lists one or more alternative sets of required slot names;
    the intent is satisfied as soon as one alternative is fully covered by
    the keys of ``slot``.

    Args:
        int_res: numeric intent id.
        slot: dict of already extracted slots (English keys).

    Returns:
        tuple: ``(0, [])`` for intents without requirements,
        ``(0, present_keys)`` when an alternative is satisfied, otherwise
        ``(1, missing)`` with the missing names of the first alternative
        that has the fewest missing slots.
    """
    required = {
        1:[['date','area']],
        3:[['page'],['app'],['module']],
        4:[['date']],
        5:[['date']],
        6:[['date']],
        7:[['date']],
        8:[[]],
        9:[[]],
    }
    if int_res not in required:
        return 0,[]

    present = list(slot.keys())
    fewest_missing = None
    for needed in required[int_res]:
        missing = [name for name in needed if name not in present]
        if not missing:
            # This alternative is fully covered.
            return 0,present
        if fewest_missing is None or len(missing) < len(fewest_missing):
            fewest_missing = missing
    return 1,fewest_missing
|
||
|
||
@time_consume
def agent(sk,slot,messages,api_key=None,api_base_url=None,model_name=None):
    """Ask an OpenAI-compatible chat model to query the user for missing slots.

    The connection settings used to be read from names that exist only as
    commented-out code, so every call ended in ``NameError``.  They are now
    optional keyword arguments; a value left as ``None`` falls back to a
    module-level variable of the same name when one is defined.

    Args:
        sk: list of required slot names, embedded in the system prompt.
        slot: slots collected so far, embedded in the system prompt.
        messages: chat history (list of ``{"role", "content"}`` dicts)
            appended after the system prompt.
        api_key: API key for the endpoint.
        api_base_url: base URL of the OpenAI-compatible endpoint.
        model_name: name of the chat model.

    Returns:
        str: content of the first choice of the completion.

    Raises:
        RuntimeError: if a connection setting is neither passed nor defined
            at module level.
    """
    from openai import OpenAI

    settings = {}
    for name, given in (("api_key", api_key),
                        ("api_base_url", api_base_url),
                        ("model_name", model_name)):
        value = given if given is not None else globals().get(name)
        if value is None:
            raise RuntimeError(
                f"agent() needs '{name}': pass it as a keyword argument "
                f"or define a module-level variable with that name"
            )
        settings[name] = value

    client = OpenAI(api_key=settings["api_key"],
                    base_url=settings["api_base_url"])

    prompt = f'''
你是一个槽位询问工具,你需要结合用户输入,根据必填槽位列表,以及当前已有槽位,对用户进行槽位问询
询问用户哪些槽位是缺失的,需要补充,待必填槽位都已填充完毕,将所有必填槽位类型及槽位值以json格式返回
请不要对用户输入的槽位进行修改

###以下是必填槽位列表###
need_slot_lists = {sk}
###以上是必填槽位列表###

###以下是当前已有槽位###
present_slot={slot}

###

'''
    message = [{"role":"system","content":prompt}]
    message.extend(messages)
    response = client.chat.completions.create(
        messages=message,
        model=settings["model_name"],
        max_tokens=1000,
        temperature=0.001,
        stream=False
    )
    return response.choices[0].message.content
|
||
|
||
@time_consume
def agent1(int_res,messages):
    """Rule-based slot dialogue: return the collected slots or a follow-up question.

    Slots are re-extracted from every user message (later messages override
    earlier ones) and checked against the requirements of ``int_res``.

    Args:
        int_res: numeric intent id.
        messages: chat history, a list of ``{"role", "content"}`` dicts.

    Returns:
        dict: the collected slots when nothing is missing or when the
        history holds 7 or more messages.
        str: a follow-up question when required slots are missing; it is
        prefixed with an apology when the history holds 3 or 5 messages.
    """
    collected = {}
    for entry in messages:
        if entry["role"] != 'user':
            continue
        collected.update(get_slot(int_res,entry['content']))

    status,sk = check_lost(int_res,collected)

    turns = len(messages)
    if turns >= 7:
        # Stop asking and hand back whatever was collected.
        return collected
    if status != 1:
        return collected

    apologies = {3: "非常抱歉,", 5: "特别抱歉,我没有理解您的意思,"}
    reres = apologies.get(turns, '')

    if int_res == 1:  # weather query
        if sk == ['date','area']:
            return f"{reres}请问您想查询的天气的时间和地点分别是什么?"
        if sk == ['area']:
            return f"{reres}请问您想查询哪里的天气情况?"
        return f"{reres}请问您想查询哪一天的天气情况?"
    if int_res == 3:
        return f"{reres}请问您想查询哪个页面?"
    if int_res in [4,5,6,7,8,9]:
        subject = {4:"日计划数量",5:"周计划数量",6:"日计划作业内容",7:"周计划作业内容",
                   8:"施工人数",9:"作业考勤人数"}
        return f"{reres}请问您想查询哪天的{subject[int_res]}"
|
||
|
||
# @time_consume
|
||
def main(messages):
    """Run one dialogue turn: detect the intent, then fill or ask for slots.

    The intent is always predicted from the first message.  On the first
    turn the slots come from that message alone; on later turns they are
    gathered by ``agent1`` from the whole history.

    Args:
        messages: chat history, a list of ``{"role", "content"}`` dicts.

    Returns:
        dict: ``{"finished": 1, "int": intent, "slot": slots}`` when the turn
        is complete, otherwise ``{"finished": 0, "text": question}``.
    """
    query = messages[0]['content']
    int_res = int_predict(int_model,int_tokenizer,[query],int_label)
    print(f"意图识别结果:{int_res}")
    intent = int_res[0][0]

    if len(messages) == 1:  # first turn
        slot = get_slot(intent,query)
        print(f"槽位获取结果:{slot}")
        # Check whether any required slot is still missing.
        status,sk = check_lost(intent,slot)
        print(f"检测是否槽位缺失:{status} 缺失:{sk}")
        if status != 1:
            return {"finished":1,"int":intent,"slot":slot}
        return {"finished":0,"text":agent1(intent,messages)}

    res = agent1(intent,messages)
    if len(messages) >= 7 or not isinstance(res,str):
        return {"finished":1,"int":intent,"slot":res}
    return {"finished":0,"text":res}
|
||
|
||
# -----------------------------------------------------------------------
|
||
# Request logging
|
||
def log_save(fun_name,params,result,exe_time):
    """Append one call record to the daily log file ``./log/logYYYYMMDD.log``.

    The record is written as the ``str()`` of a dict followed by a newline.

    Args:
        fun_name: short name of the endpoint or function being logged.
        params: the request parameters.
        result: the produced result.
        exe_time: elapsed time of the call.
    """
    import os

    # One timestamp for both the file name and the record, so the two cannot
    # disagree when a call straddles midnight.
    now = datetime.datetime.now()
    log_dir = './log'
    # The directory may not exist on a fresh deployment.
    os.makedirs(log_dir, exist_ok=True)
    log_file = log_dir + '/log' + now.strftime("%Y%m%d") + ".log"

    record = {
        'fun_name': fun_name,
        'params': params,
        'result': result,
        'exe_time': exe_time,
        'time': now.strftime("%Y-%m-%d %H:%M:%S"),
    }
    # Context manager closes the handle even if the write fails.
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(str(record) + '\n')
|
||
|
||
class LabelMessage(BaseModel):
    # Response body of the intent-recognition endpoint.
    # code / msg carry the status; label and probability are required.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    msg: str = pydantic.Field(default='成功', description="msg")
    label: int = pydantic.Field(default=..., description="标签类型")
    probability: float = pydantic.Field(default=..., description="概率")
|
||
|
||
async def get_int_label(
    user_id: str = Body(...,description='鉴权'),
    text: str = Body(..., description="意图识别样本"),
):
    """Endpoint handler: classify the intent of ``text``.

    Args:
        user_id: caller token; must equal the fixed key checked below.
        text: the utterance to classify.

    Returns:
        LabelMessage: code 401 with label/probability -1 for an unknown
        ``user_id``; otherwise code 200 with the predicted intent id and its
        probability.  Successful calls are recorded through ``log_save``.
    """
    stamp = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{stamp} :: INT user_id:{user_id} text:{text}")
    started = time.time()

    authorised = user_id == '3bb66776-1722-4c36-b14a-73dd210fe750'
    if not authorised:
        return LabelMessage(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            label=-1,
            probability=-1,
        )

    # int_label / int_model / int_tokenizer are module globals set at start-up.
    prediction = int_predict(int_model,int_tokenizer,[text],int_label)
    log_save('int',text,prediction,round(time.time() - started,5))

    intent_id, probability = prediction[0][0], prediction[0][1]
    return LabelMessage(
        code=200,
        msg='成功',
        label=intent_id,
        probability=probability,
    )
|
||
|
||
|
||
class LabelMessage1(BaseModel):
    # Response body of the slot-extraction endpoint; slot is required.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    msg: str = pydantic.Field(default='成功', description="msg")
    slot: Dict[str, Any] = pydantic.Field(default=..., description="槽位(JSON 格式)")
|
||
|
||
async def get_uie_label(
    user_id: str = Body(...,description='鉴权'),
    text: str = Body(..., description="槽位填充样本"),
):
    """Endpoint handler: predict the intent of ``text`` and extract its slots.

    Args:
        user_id: caller token; must equal the fixed key checked below.
        text: the utterance to analyse.

    Returns:
        LabelMessage1: code 401 with an empty slot dict for an unknown
        ``user_id``; otherwise code 200 with the extracted slots.  Successful
        calls are recorded through ``log_save``.
    """
    stamp = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{stamp} :: UIE user_id:{user_id} text:{text}")
    started = time.time()

    authorised = user_id == '3bb66776-1722-4c36-b14a-73dd210fe750'
    if not authorised:
        return LabelMessage1(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            slot={},
        )

    # The slot prompts depend on the intent, so the intent is predicted first.
    prediction = int_predict(int_model,int_tokenizer,[text],int_label)
    extracted = get_slot(prediction[0][0],text)
    log_save('uie',text,{"int":prediction,"uie":extracted},round(time.time() - started,5))

    return LabelMessage1(
        code=200,
        msg='成功',
        slot=extracted,
    )
|
||
|
||
class LabelMessage2(BaseModel):
    # Response body of the agent endpoint; answer is required.
    code: int = pydantic.Field(default=200, description="HTTP status code")
    msg: str = pydantic.Field(default='成功', description="msg")
    answer: Dict[str, Any] = pydantic.Field(default=..., description="槽位(JSON 格式)")
|
||
|
||
async def Agent(
    user_id: str = Body(...,description='鉴权'),
    messages: List[Dict[str,str]] = Body(..., description="消息"),
):
    """Endpoint handler: run one dialogue turn over the chat history.

    Args:
        user_id: caller token; must equal the fixed key checked below.
        messages: chat history, a list of ``{"role", "content"}`` dicts.

    Returns:
        LabelMessage2: code 401 with an empty answer for an unknown
        ``user_id``; otherwise code 200 with the dict produced by ``main``.
        Successful calls are recorded through ``log_save``.
    """
    cur_time = datetime.datetime.now().strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: AGENT user_id:{user_id} messages:{messages}")
    st = time.time()
    if user_id != '3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage2(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            # ``answer`` is declared as Dict[str, Any]; the empty string used
            # here before failed model validation, turning the intended 401
            # payload into a server error.
            answer={},
        )
    his = main(messages)
    log_save('agent',messages,his,round(time.time() - st,5))

    return LabelMessage2(
        code=200,
        msg='成功',
        answer=his,
    )
|
||
|
||
# -----------------------------------------------------------------------
|
||
|
||
|
||
def api_start(host,port,**kwargs):
    """Create the FastAPI application, register the endpoints and serve it.

    The application is stored in the module-level ``app``.  The call blocks
    inside ``uvicorn.run``.

    Args:
        host: interface to bind.
        port: TCP port to listen on.
        **kwargs: accepted but not used.
    """
    global app
    app = FastAPI()

    # Module-level OPEN_CROSS_DOMAIN switches on CORS for every origin.
    if OPEN_CROSS_DOMAIN:
        cors_options = dict(
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
        app.add_middleware(CORSMiddleware, **cors_options)

    routes = (
        ("/intent_reco", LabelMessage, "意图识别", get_int_label),
        ("/slot_reco", LabelMessage1, "槽位抽取", get_uie_label),
        ("/agent", LabelMessage2, "Agent", Agent),
    )
    for path, response_model, summary, handler in routes:
        app.post(path, response_model=response_model, summary=summary)(handler)

    uvicorn.run(app, host=host, port=port)
|
||
|
||
if __name__ == "__main__":
    # Load the label list and both models once; the endpoint handlers read
    # these module-level names.
    int_label = get_label("int")
    int_model, int_tokenizer = get_model(
        model_names['intent_token'],
        model_names['intent_model'],
    )
    uie_model, uie_tokenizer = get_model(
        model_names['uie_token'],
        model_names['uie_model'],
    )
    # Listen on all interfaces.
    host, port = "0.0.0.0", 18074
    api_start(host, port)
|
||
|
||
|