Delete files

This commit is contained in:
jiang 2025-02-21 16:50:01 +08:00
parent f89ef84f54
commit 3cea60d4fd
20 changed files with 0 additions and 3995 deletions

emb.log
View File

@@ -1,23 +0,0 @@
nohup: ignoring input
初始模型加载
INFO: Started server process [2712]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:6006 (Press CTRL+C to quit)
已更新work_name知识库,现在包含知识条数:2054
INFO: 127.0.0.1:54240 - "POST /v1/load_know HTTP/1.1" 200 OK
INFO: 127.0.0.1:50842 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:32800 - "POST /v1/know_sim HTTP/1.1" 200 OK
清空work_name知识库
已更新work_name知识库,现在包含知识条数:2054
INFO: 127.0.0.1:41464 - "POST /v1/load_know HTTP/1.1" 200 OK
INFO: 127.0.0.1:41478 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:41494 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:37822 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:54392 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:49690 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:39220 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:47842 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:37910 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:58032 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:43468 - "POST /v1/know_sim HTTP/1.1" 200 OK

View File

@@ -1,104 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import logging
import uvicorn
import datetime
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import Field, BaseModel, validator
from typing import Optional, List, Dict
from models import *
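# Uniform response envelope shared by every endpoint: {code, message, data, time}.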
def response(code, msg, data=None):
time = str(datetime.datetime.now())
if data is None:
data = []
result = {
"code": code,
"message": msg,
"data": data,
"time": time
}
return result
def success(data=None, msg=''):
    return response(200, msg, data)
class QADocs(BaseModel):
query: Optional[str]
documents: Optional[List[str]]
class Knows(BaseModel):
know_id: Optional[str]
contents: Optional[List[str]]
drop_dup: Optional[bool]
is_cover: Optional[bool]
class Know_Sim(BaseModel):
query: Optional[str]
know_id: Optional[str]
top_k: Optional[int] = 10
app = FastAPI()
security = HTTPBearer()
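# Default bearer token; overridden by the ACCESS_TOKEN environment variable in __main__.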
env_bearer_token = 'ACCESS_TOKEN'
@app.post('/v1/embedding')
async def handle_post_request1(sentences1: List[str], sentences2: List[str], credentials: HTTPAuthorizationCredentials = Security(security)):
global know_pass
token = credentials.credentials
if env_bearer_token is not None and token != env_bearer_token:
raise HTTPException(status_code=401, detail="Invalid token")
similarity = know_pass.get_similarity_pair(sentences1,sentences2)
return response(200, msg="获取相似性成功",data=similarity)
@app.post('/v1/load_know')
async def handle_post_request2(knows: Knows, credentials: HTTPAuthorizationCredentials = Security(security)):
know_id = knows.know_id
contents = knows.contents
drop_dup = knows.drop_dup
is_cover = knows.is_cover
global know_pass
token = credentials.credentials
if env_bearer_token is not None and token != env_bearer_token:
raise HTTPException(status_code=401, detail="Invalid token")
know_pass.load_know(know_id, contents, drop_dup,is_cover)
return response(200, msg=f"生成{know_id}知识库成功")
@app.post('/v1/know_sim')
async def handle_post_request3(know_sim:Know_Sim, credentials: HTTPAuthorizationCredentials = Security(security)):
query = know_sim.query
know_id = know_sim.know_id
top_k = know_sim.top_k
global know_pass
token = credentials.credentials
if env_bearer_token is not None and token != env_bearer_token:
raise HTTPException(status_code=401, detail="Invalid token")
similarity = know_pass.get_similarity_know(query,know_id,top_k)
return response(200, msg="获取相似性成功",data=similarity)
def init_env():
    # Initialize the embedding model
print("初始模型加载")
# embedding = BCE_EMB(EMBEDDING_MODEL_PATH)
EMBEDDING_MODEL_PATH = "/home/workplace/models/m3e-base"
embedding = M3E_EMB(EMBEDDING_MODEL_PATH)
know_pass = KNOW_PASS(embedding)
return know_pass
if __name__ == "__main__":
token = os.getenv("ACCESS_TOKEN")
if token is not None:
env_bearer_token = token
try:
        # Default environment
know_pass = init_env()
uvicorn.run(app, host='0.0.0.0', port=6006)
except Exception as e:
print(f"API启动失败\n报错:\n{e}")

View File

@@ -1,11 +0,0 @@
天气查询
互联网查询
页面切换
日计划数量查询
周计划数量查询
日计划作业内容
周计划作业内容
施工人数
作业考勤人数
知识问答

View File

@@ -1,12 +0,0 @@
python train.py \
--do_eval \
--debug True \
--device gpu \
--model_name_or_path checkpoint \
--output_dir checkpoint \
--per_device_eval_batch_size 32 \
--max_length 256 \
--label_path ./data/label.txt \
--train_path ./data/train.txt \
--test_path ./data/test.txt \
--dev_path ./data/dev.txt

train.py
View File

@@ -1,231 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import json
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
from sklearn.metrics import (
accuracy_score,
classification_report,
precision_recall_fscore_support,
)
from utils import log_metrics_debug, preprocess_function, read_local_dataset
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import (
CompressionArguments,
EarlyStoppingCallback,
PdArgumentParser,
Trainer,
)
from paddlenlp.transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
export_model,
)
from paddlenlp.utils.log import logger
SUPPORTED_MODELS = [
"ernie-1.0-large-zh-cw",
"ernie-1.0-base-zh-cw",
"ernie-3.0-xbase-zh",
"ernie-3.0-base-zh",
"ernie-3.0-medium-zh",
"ernie-3.0-micro-zh",
"ernie-3.0-mini-zh",
"ernie-3.0-nano-zh",
"ernie-3.0-tiny-base-v2-zh",
"ernie-3.0-tiny-medium-v2-zh",
"ernie-3.0-tiny-micro-v2-zh",
"ernie-3.0-tiny-mini-v2-zh",
"ernie-3.0-tiny-nano-v2-zh ",
"ernie-3.0-tiny-pico-v2-zh",
"ernie-2.0-large-en",
"ernie-2.0-base-en",
"ernie-3.0-tiny-mini-v2-en",
"ernie-m-base",
"ernie-m-large",
]
# yapf: disable
@dataclass
class DataArguments:
max_length: int = field(default=128, metadata={"help": "Maximum number of tokens for the model."})
early_stopping: bool = field(default=False, metadata={"help": "Whether apply early stopping strategy."})
early_stopping_patience: int = field(default=4, metadata={"help": "Stop training when the specified metric worsens for early_stopping_patience evaluation calls"})
debug: bool = field(default=False, metadata={"help": "Whether choose debug mode."})
train_path: str = field(default='./data/train.txt', metadata={"help": "Train dataset file path."})
dev_path: str = field(default='./data/dev.txt', metadata={"help": "Dev dataset file path."})
test_path: str = field(default='./data/dev.txt', metadata={"help": "Test dataset file path."})
label_path: str = field(default='./data/label.txt', metadata={"help": "Label file path."})
bad_case_path: str = field(default='./data/bad_case.txt', metadata={"help": "Bad case file path."})
@dataclass
class ModelArguments:
model_name_or_path: str = field(default="ernie-3.0-tiny-medium-v2-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
export_model_dir: Optional[str] = field(default=None, metadata={"help": "Path to directory to store the exported inference model."})
# yapf: enable
def main():
"""
Training a binary or multi classification model
"""
parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
if training_args.do_compress:
training_args.strategy = "dynabert"
if training_args.do_train or training_args.do_compress:
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
paddle.set_device(training_args.device)
# Define id2label
id2label = {}
label2id = {}
with open(data_args.label_path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
l = line.strip()
id2label[i] = l
label2id[l] = i
# Define model & tokenizer
if os.path.isdir(model_args.model_name_or_path):
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path, label2id=label2id, id2label=id2label
)
elif model_args.model_name_or_path in SUPPORTED_MODELS:
model = AutoModelForSequenceClassification.from_pretrained(
model_args.model_name_or_path, num_classes=len(label2id), label2id=label2id, id2label=id2label
)
else:
raise ValueError(
f"{model_args.model_name_or_path} is not a supported model type. Either use a local model path or select a model from {SUPPORTED_MODELS}"
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
# load and preprocess dataset
train_ds = load_dataset(read_local_dataset, path=data_args.train_path, label2id=label2id, lazy=False)
dev_ds = load_dataset(read_local_dataset, path=data_args.dev_path, label2id=label2id, lazy=False)
trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=data_args.max_length)
train_ds = train_ds.map(trans_func)
dev_ds = dev_ds.map(trans_func)
if data_args.debug:
test_ds = load_dataset(read_local_dataset, path=data_args.test_path, label2id=label2id, lazy=False)
test_ds = test_ds.map(trans_func)
# Define the metric function.
def compute_metrics(eval_preds):
pred_ids = np.argmax(eval_preds.predictions, axis=-1)
metrics = {}
metrics["accuracy"] = accuracy_score(y_true=eval_preds.label_ids, y_pred=pred_ids)
for average in ["micro", "macro"]:
precision, recall, f1, _ = precision_recall_fscore_support(
y_true=eval_preds.label_ids, y_pred=pred_ids, average=average
)
metrics[f"{average}_precision"] = precision
metrics[f"{average}_recall"] = recall
metrics[f"{average}_f1"] = f1
return metrics
def compute_metrics_debug(eval_preds):
pred_ids = np.argmax(eval_preds.predictions, axis=-1)
metrics = classification_report(eval_preds.label_ids, pred_ids, output_dict=True)
return metrics
# Define the early-stopping callback.
if data_args.early_stopping:
callbacks = [EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience)]
else:
callbacks = None
# Define Trainer
trainer = Trainer(
model=model,
tokenizer=tokenizer,
args=training_args,
criterion=paddle.nn.loss.CrossEntropyLoss(),
train_dataset=train_ds,
eval_dataset=dev_ds,
callbacks=callbacks,
data_collator=DataCollatorWithPadding(tokenizer),
compute_metrics=compute_metrics_debug if data_args.debug else compute_metrics,
)
# Training
if training_args.do_train:
train_result = trainer.train()
metrics = train_result.metrics
trainer.save_model()
trainer.log_metrics("train", metrics)
for checkpoint_path in Path(training_args.output_dir).glob("checkpoint-*"):
shutil.rmtree(checkpoint_path)
# Evaluate and tests model
if training_args.do_eval:
if data_args.debug:
output = trainer.predict(test_ds)
log_metrics_debug(output, id2label, test_ds, data_args.bad_case_path)
else:
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
# export inference model
if training_args.do_export:
if model.init_config["init_class"] in ["ErnieMForSequenceClassification"]:
input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
else:
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
]
if model_args.export_model_dir is None:
model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
tokenizer.save_pretrained(model_args.export_model_dir)
id2label_file = os.path.join(model_args.export_model_dir, "id2label.json")
with open(id2label_file, "w", encoding="utf-8") as f:
json.dump(id2label, f, ensure_ascii=False)
logger.info(f"id2label file saved in {id2label_file}")
# compress
if training_args.do_compress:
trainer.compress()
for width_mult in training_args.width_mult_list:
pruned_infer_model_dir = os.path.join(training_args.output_dir, "width_mult_" + str(round(width_mult, 2)))
tokenizer.save_pretrained(pruned_infer_model_dir)
id2label_file = os.path.join(pruned_infer_model_dir, "id2label.json")
with open(id2label_file, "w", encoding="utf-8") as f:
json.dump(id2label, f, ensure_ascii=False)
logger.info(f"id2label file saved in {id2label_file}")
for path in Path(training_args.output_dir).glob("runs"):
shutil.rmtree(path)
if __name__ == "__main__":
main()

View File

@@ -1,24 +0,0 @@
python train.py \
--do_train \
--do_eval \
--do_export \
--model_name_or_path /home/workplace/models/ernie-3.0-tiny-base-v2-zh \
--output_dir checkpoint \
--device gpu:0 \
--train_path ./data/train.txt \
--dev_path ./data/dev.txt \
--label_path ./data/label.txt \
--num_train_epochs 100 \
--early_stopping True \
--early_stopping_patience 5 \
--learning_rate 3e-5 \
--max_length 256 \
--per_device_eval_batch_size 16 \
--per_device_train_batch_size 16 \
--metric_for_best_model accuracy \
--load_best_model_at_end \
--logging_steps 5 \
--evaluation_strategy epoch \
--save_strategy epoch \
--save_total_limit 1

utils.py
View File

@@ -1,89 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from paddlenlp.utils.log import logger
def preprocess_function(examples, tokenizer, max_length, is_test=False):
"""
Builds model inputs from a sequence for sequence classification tasks
by concatenating and adding special tokens.
"""
result = tokenizer(examples["text"], max_length=max_length, truncation=True)
if not is_test:
result["labels"] = np.array([examples["label"]], dtype="int64")
return result
def read_local_dataset(path, label2id=None, is_test=False):
"""
Read dataset.
"""
with open(path, "r", encoding="utf-8") as f:
for line in f:
if is_test:
sentence = line.strip()
yield {"text": sentence}
else:
items = line.strip().split("\t")
if len(items) != 2:
continue
yield {"text": items[0], "label": label2id[items[1]]}
def log_metrics_debug(output, id2label, dev_ds, bad_case_path):
"""
Log metrics in debug mode.
"""
predictions, label_ids, metrics = output
pred_ids = np.argmax(predictions, axis=-1)
logger.info("-----Evaluate model-------")
logger.info("Dev dataset size: {}".format(len(dev_ds)))
logger.info("Accuracy in dev dataset: {:.2f}%".format(metrics["test_accuracy"] * 100))
logger.info(
"Macro average | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
metrics["test_macro avg"]["precision"] * 100,
metrics["test_macro avg"]["recall"] * 100,
metrics["test_macro avg"]["f1-score"] * 100,
)
)
for i in id2label:
l = id2label[i]
logger.info("Class name: {}".format(l))
i = "test_" + str(i)
if i in metrics:
logger.info(
"Evaluation examples in dev dataset: {}({:.1f}%) | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
metrics[i]["support"],
100 * metrics[i]["support"] / len(dev_ds),
metrics[i]["precision"] * 100,
metrics[i]["recall"] * 100,
metrics[i]["f1-score"] * 100,
)
)
else:
logger.info("Evaluation examples in dev dataset: 0 (0%)")
logger.info("----------------------------")
with open(bad_case_path, "w", encoding="utf-8") as f:
f.write("Text\tLabel\tPrediction\n")
for i, (p, l) in enumerate(zip(pred_ids, label_ids)):
p, l = int(p), int(l)
if p != l:
f.write(dev_ds.data[i]["text"] + "\t" + id2label[l] + "\t" + id2label[p] + "\n")
logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))

View File

@@ -1,586 +0,0 @@
# phddns start
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Tuple,Pad
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
import argparse
import pandas as pd
import time
import datetime
import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional, Dict, Any
import json
import sys
from post_api import Similar
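# Client for the embedding / knowledge-base service defined earlier (http://127.0.0.1:6006).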
sim = Similar()
def load_know():
    # Load the project knowledge file; each line is "<project name>(<project code>)".
    file_path = './knows/project_know.txt'
    with open(file_path, 'r', encoding='utf-8') as f:
        data = f.read().split('\n')
    works = {}
    for d in data:
        wk = d[:d.find('(')]  # key on the name part before the fullwidth '('
        works[wk] = d
    keys = list(works.keys())
    sim.load_know('work_name', keys, drop_dup=True, is_cover=True)
    return works
works = load_know()
# print(works)
def standard_project(work_name):
    # Look the extracted name up in the work_name knowledge base and substitute
    # the standardized project entry when the relevance score is high enough.
    res = sim.know_sim(work_name, 'work_name')
    res = json.loads(res)  # the client returns the raw JSON response body
    res = res['data']['results'][0]
    t = res['document']['text']
    s = res['document']['relevance_score']
    print(t, works[t], s)
    if s > 330:  # empirical acceptance threshold
        return works[t]
    else:
        return work_name
OPEN_CROSS_DOMAIN = True
def time_consume(func):
    def inner(*args, **kwargs):  # accept arbitrary arguments
time_start = time.time()
print("\n接口请求前的时间是:",datetime.datetime.now())
res = func(*args, **kwargs)
time_end = time.time()
print("\n接口请求后的时间是:",datetime.datetime.now())
t = time_end - time_start
print("接口耗时:", round(t,4),"s")
return res
return inner
parser = argparse.ArgumentParser()
parser.add_argument("--device", default="gpu:0", help="Select which device to train model, default to gpu.")
parser.add_argument("--output_file", default="output.txt",type=str)
parser.add_argument("--max_seq_length",
default=128,
type=int,
help="The maximum total input sequence length "
"after tokenization. Sequences longer than this"
"will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size",
default=64,
type=int,
help="Batch size per GPU/CPU for training.")
parser.add_argument("--max_seq_len",
default=256,
type=int,
help="The maximum total input sequence length after tokenization.")
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
args = parser.parse_args()
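# Exported static-graph models (loaded with paddle.jit.load) and their tokenizer directories.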
model_names = {
"intent_token": "/home/workplace/models/zero_int",
"intent_model": "/home/workplace/models/zero_int/model",
"uie_token":"/home/workplace/models/zero_uie",
"uie_model":"/home/workplace/models/zero_uie/model",
}
label_names = {
"int" : "./int/data/label.txt",
}
def get_label(label_name):
label = []
with open(label_names[label_name]) as inf:
for idx, line in enumerate(inf):
line = line.strip()
if line != '':
label.append(line)
return label
def get_model(token_path,model_path):
paddle.set_device(args.device)
model = paddle.jit.load(model_path)
tokenizer = AutoTokenizer.from_pretrained(token_path)
return model,tokenizer
@paddle.no_grad()
def int_predict(model, tokenizer, data, label_list):
examples = []
for text in data:
result = tokenizer(text=text, max_seq_length=args.max_seq_length)
examples.append((result['input_ids'], result['token_type_ids']))
    # Separate the data into batches
batches = [
examples[i: i + args.batch_size]
for i in range(0, len(examples), args.batch_size)
]
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id),
Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
): fn(samples)
results = []
model.eval()
for batch in batches:
input_ids, token_type_ids = batchify_fn(batch)
input_ids = paddle.to_tensor(input_ids)
token_type_ids = paddle.to_tensor(token_type_ids)
logits = model(input_ids, token_type_ids)
probs = F.softmax(logits, axis=1)
idx = paddle.argmax(probs, axis=1).numpy()
idx = idx.tolist()
pros = probs.tolist()
labels = []
for lbi,pro in zip(idx,pros):
#print(pro)
#print(pro[lbi])
labels.append((label_list[lbi], str(round(pro[lbi],5))))
results.extend(labels)
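    # Map the ten intent labels (the same classes as the intent label.txt) to the numeric ids used downstream.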
mapping = {"天气查询":1,"互联网查询":2,"页面切换":3,
"日计划数量查询":4,"周计划数量查询":5,"日计划作业内容":6,"周计划作业内容":7,
"施工人数":8,"作业考勤人数":9,"知识问答":10}
nresults = []
for res in results:
nresults.append((mapping[res[0]],res[1]))
return nresults
@paddle.no_grad()
def uie_predict(model, tokenizer, data):
examples = []
for text in data:
result = tokenizer(text=[text['prompt']],
text_pair= [text['content']],
truncation=True,
max_seq_length=args.max_seq_len,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_position_ids=True,
return_dict=False,
return_offsets_mapping=True)
result = result[0]
# print(result)
tokens = tokenizer.convert_ids_to_tokens(result["input_ids"])
# print(tokens)
examples.append((result['input_ids'], result['token_type_ids'],result['position_ids'],result['attention_mask']))
    # Separate the data into batches
batches = [
examples[i: i + args.batch_size]
for i in range(0, len(examples), args.batch_size)
]
batchify_fn = lambda samples, fn=Tuple(
Pad(axis=0, pad_val=tokenizer.pad_token_id),
Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
Pad(axis=0, pad_val=0),
Pad(axis=0, pad_val=0),
): fn(samples)
results = []
model.eval()
for batch in batches:
input_ids, token_type_ids,position_ids,attention_mask = batchify_fn(batch)
input_ids = paddle.to_tensor(input_ids)
token_type_ids = paddle.to_tensor(token_type_ids)
position_ids = paddle.to_tensor(position_ids)
attention_mask = paddle.to_tensor(attention_mask)
start_prob,end_prob = model(input_ids, token_type_ids,position_ids,attention_mask)
        # Max probability and its argmax index for the predicted start position
        stmax_value = paddle.max(start_prob)
        stmax_value = stmax_value.numpy().item()
        stmax_idx = paddle.argmax(start_prob)
        # Max probability and its argmax index for the predicted end position
        etmax_value = paddle.max(end_prob)
        etmax_value = etmax_value.numpy().item()
        etmax_idx = paddle.argmax(end_prob)
        # The prompt is prepended to the content, so the predicted span indices include the prompt tokens
tokens = tokenizer.convert_ids_to_tokens(result["input_ids"])
res = ''.join(tokens[stmax_idx:etmax_idx+1])
results.append((res,stmax_value,etmax_value,stmax_idx.numpy().item(),etmax_idx.numpy().item()))
return results
def get_slot(int_res,query):
slot_int = [1,3,4,5,6,7,8,9]
if int_res not in slot_int:
return {}
if int_res==1: #'天气查询'
prompts = ['时间','地点']
elif int_res==3: #'页面切换'
prompts = ['应用','模块','页面']
else:
prompts = ['时间','项目','项目经理','公司','工程','风险等级','周时间','施工状态']
slot = {}
comp = {}
for prompt in prompts:
uie_res = uie_predict(uie_model,uie_tokenizer,[{"prompt":prompt,"content":query}])
print(f"prompt:{prompt} uie_res:{uie_res}")
uie_res[0] = list(uie_res[0])
uie_res[0][3] = uie_res[0][3] - len(prompt)
uie_res[0][4] = uie_res[0][4] - len(prompt)
if uie_res[0][1] >= 0.8 and uie_res[0][2] >= 0.8:
# print(f"prompt:{prompt},值:{uie_res}")
slot[prompt] = uie_res[0][0]
comp[prompt] = uie_res
drop = []
for k,v in slot.items():
flag = 1
for kk,vv in comp.items():
if kk == k:
continue
if comp[k][0][4] < vv[0][3] or comp[k][0][3] > vv[0][4]:
continue
elif comp[k][0][1]<=vv[0][1] and comp[k][0][2] <=vv[0][2]:
print(f"1:{comp[k]}")
print(f"2:{vv}")
flag = 0
break
if flag==0:
drop.append(k)
for d in drop:
del slot[d]
if int_res == 5: # '周计划查询'
if slot.__contains__("周时间"):
slot['时间'] = slot['周时间']
del slot['周时间']
else:
if slot.__contains__("周时间"):
del slot['周时间']
cn_en = {"时间":"date","项目":"project","项目经理":"people",
"公司":"company","工程":"work","风险等级":"risk_level",
"施工状态":"work_status","应用":"app","模块":"module","页面":"page","地点":"area"}
nslot = {}
for k,v in slot.items():
v = v.replace('#','').replace('kv','kV').replace("模块","").replace("页面","")
if k == '工程':
print(v)
v = standard_project(v)
nslot[cn_en[k]] = v
return nslot
def check_lost(int_res,slot):
# mapping = {
# "页面切换":[['页面','应用']],
# "作业计划数量查询":[['时间']],
# "周计划查询":[['时间']],
# "作业内容":[['时间']],
# "施工人数":[['时间']],
# "作业考勤人数":[['时间']],
# }
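    # Required slot combinations per intent id: intent 3 (page switch) accepts any
    # one of page/app/module; intents 8 and 9 require no slots.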
mapping = {
1:[['date','area']],
3:[['page'],['app'],['module']],
4:[['date']],
5:[['date']],
6:[['date']],
7:[['date']],
8:[[]],
9:[[]],
}
if not mapping.__contains__(int_res):
return 0,[]
cur_k = list(slot.keys())
idx = -1
idx_len = 99
for i in range(len(mapping[int_res])):
sk = mapping[int_res][i]
left = [x for x in sk if x not in cur_k]
more = [x for x in cur_k if x not in sk]
        if len(left) == 0:  # every required slot for this combination is present
idx = i
idx_len = 0
break
if len(left) < idx_len:
idx = i
idx_len =len(left)
    if idx_len == 0:  # match passed: all required slots are present
        return 0, cur_k
    left = [x for x in mapping[int_res][idx] if x not in cur_k]
    return 1, left
@time_consume
def agent(sk,slot,messages):
from openai import OpenAI
# api_base_url='http://36.33.26.201:27861/v1'
# api_key='EMPTY'
# model_name = 'qwen2.5-instruct'
# api_base_url = "http://36.33.26.201:27861/chat"
# api_base_url = "http://58.57.119.80:45291/v1"
# model_name="qwen2.5-instruct"
    # NOTE: api_key, api_base_url and model_name are referenced here but never
    # assigned in this file (the candidate endpoints above are commented out);
    # they must be configured before agent() can run. main() uses agent1() instead.
    client = OpenAI(api_key=api_key,
                    base_url=api_base_url)
prompt = f'''
你是一个槽位询问工具你需要结合用户输入根据必填槽位列表以及当前已有槽位对用户进行槽位问询
询问用户哪些槽位是缺失的需要补充待必填槽位都已填充完毕将所有必填槽位类型及槽位值以json格式返回
请不要对用户输入的槽位进行修改
###以下是必填槽位列表###
need_slot_lists = {sk}
###以上是必填槽位列表###
###以下是当前已有槽位###
present_slot={slot}
###
'''
# prompt = prompt.replace("{sk}",str(sk)).replace("{slot}",slot)
message = [{"role":"system","content":prompt}]
message.extend(messages)
# print(message)
response = client.chat.completions.create(
messages=message,
model=model_name,
max_tokens=1000,
temperature=0.001,
stream=False
)
res = response.choices[0].message.content
return res
@time_consume
def agent1(int_res,messages):
lslot = {}
for message in messages:
if message["role"] == 'user':
# print(int_res)
# print(message['content'])
slot = get_slot(int_res,message['content'])
for k,v in slot.items():
lslot[k] = v
# print(lslot)
status,sk = check_lost(int_res,lslot)
# print(f"检测是否槽位缺失:{status} 缺失:{sk}")
# print(lslot)
# print(4567)
# print(int_res)
reres = ''
if len(messages)==3:
reres="非常抱歉,"
elif len(messages) == 5:
reres="特别抱歉,我没有理解您的意思,"
if len(messages) >= 7:
return lslot
if status == 1:
# print(11122)
if int_res == 1: # 天气查询
if sk==['date','area']:
return f"{reres}请问您想查询的天气的时间和地点分别是什么?"
elif sk ==['area']:
return f"{reres}请问您想查询哪里的天气情况?"
else:
return f"{reres}请问您想查询哪一天的天气情况?"
if int_res == 3:
return f"{reres}请问您想查询哪个页面?"
if int_res in [4,5,6,7,8,9]:
mapping = {4:"日计划数量",5:"周计划数量",6:"日计划作业内容",7:"周计划作业内容",
8:"施工人数",9:"作业考勤人数"}
# print(f"123{reres}请问您想查询哪天的{mapping[int_res]}")
return f"{reres}请问您想查询哪天的{mapping[int_res]}"
else:
return lslot
# @time_consume
def main(messages):
    if len(messages) == 1:  # first turn
query = messages[0]['content']
# query = "请帮我打开风险管理应用的风险勘察申请"
# query = "未来一周内刘思辰500千伏变电站新建工程有多少三级风险项作业计划刚开始施工"
int_res = int_predict(int_model,int_tokenizer,[query],int_label)
# int_res = int_res[0][0]
print(f"意图识别结果:{int_res}")
# slot_int = ['页面切换','作业计划数量查询','周计划查询','作业内容','施工人数','作业考勤人数']
slot = get_slot(int_res[0][0],query)
print(f"槽位获取结果:{slot}")
        # Slots gathered; check whether any required slot is missing
status,sk = check_lost(int_res[0][0],slot)
print(f"检测是否槽位缺失:{status} 缺失:{sk}")
if status == 1:
            # Ask the agent to collect the missing slots
# res = agent(sk,slot,messages)
res = agent1(int_res[0][0],messages)
return {"finished":0,"text":res}
else:
return {"finished":1,"int":int_res[0][0],"slot":slot}
else:
query = messages[0]['content']
int_res = int_predict(int_model,int_tokenizer,[query],int_label)
# int_res = int_res[0][0]
print(f"意图识别结果:{int_res}")
# slot = get_slot(int_res[0][0],query)
# # print(f"槽位获取结果:{slot}")
# # 槽位预备,检查是否有缺失的槽位
# status,sk = check_lost(int_res[0][0],slot)
# res = agent(sk,slot,messages)
res = agent1(int_res[0][0],messages)
if len(messages) >=7:
return {"finished":1,"int":int_res[0][0],"slot":res}
if isinstance(res,str):
return {"finished":0,"text":res}
else:
return {"finished":1,"int":int_res[0][0],"slot":res}
# st = res.find('{')
# et = res.find('}')
# if st != -1 and et != -1:
# res = res[st:et+1]
# res = eval(res)
# return {"finished":1,"int":int_res[0][0],"slot":res}
# else:
# return {"finished":0,"text":res}
# -----------------------------------------------------------------------
# Logging: append one record per API call
def log_save(fun_name, params, result, exe_time):
    # One file per day: ./log/log<YYYYMMDD>.log
    cur_time = datetime.datetime.now().strftime("%Y%m%d")
    log_file = './log/log' + cur_time + ".log"
    tmp = {
        'fun_name': fun_name,
        'params': params,
        'result': result,
        'exe_time': exe_time,
        'time': datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
    }
    with open(log_file, 'a', encoding='utf-8') as f:
        f.write(str(tmp) + '\n')
class LabelMessage(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field('成功', description="msg")
label: int = pydantic.Field(..., description="标签类型")
probability: float = pydantic.Field(..., description="概率")
async def get_int_label(
user_id: str = Body(...,description='鉴权'),
text: str = Body(..., description="意图识别样本"),
):
cur_time =datetime.datetime.now()
cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
print(f"{cur_time} :: INT user_id:{user_id} text:{text}")
st = time.time()
if user_id !='3bb66776-1722-4c36-b14a-73dd210fe750':
return LabelMessage(
code=401,
msg = '权限验证失败,请联系接口开发人员',
label=-1,
probability=-1
)
global int_label,int_model,int_tokenizer
res = int_predict(int_model,int_tokenizer,[text],int_label)
log_save('int',text,res,round(time.time() - st,5))
return LabelMessage(
code=200,
msg = '成功',
label=res[0][0],
probability=res[0][1]
)
class LabelMessage1(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field('成功', description="msg")
slot: Dict[str, Any] = pydantic.Field(..., description="槽位JSON 格式)")
async def get_uie_label(
user_id: str = Body(...,description='鉴权'),
text: str = Body(..., description="槽位填充样本"),
):
cur_time =datetime.datetime.now()
cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
print(f"{cur_time} :: UIE user_id:{user_id} text:{text}")
st = time.time()
if user_id !='3bb66776-1722-4c36-b14a-73dd210fe750':
return LabelMessage1(
code=401,
msg = '权限验证失败,请联系接口开发人员',
slot={}
)
global int_label,int_model,int_tokenizer
res = int_predict(int_model,int_tokenizer,[text],int_label)
slot = get_slot(res[0][0],text)
log_save('uie',text,{"int":res,"uie":slot},round(time.time() - st,5))
return LabelMessage1(
code=200,
msg = '成功',
slot=slot
)
class LabelMessage2(BaseModel):
code: int = pydantic.Field(200, description="HTTP status code")
msg: str = pydantic.Field('成功', description="msg")
answer: Dict[str, Any] = pydantic.Field(..., description="槽位JSON 格式)")
async def Agent(
user_id: str = Body(...,description='鉴权'),
messages: List[Dict[str,str]] = Body(..., description="消息"),
):
cur_time =datetime.datetime.now()
cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
print(f"{cur_time} :: AGENT user_id:{user_id} messages:{messages}")
st = time.time()
if user_id !='3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage2(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            answer={},  # answer is typed Dict[str, Any]; an empty string would fail validation
        )
his = main(messages)
log_save('agent',messages,his,round(time.time() - st,5))
return LabelMessage2(
code=200,
msg = '成功',
answer=his,
)
# -----------------------------------------------------------------------
def api_start(host,port,**kwargs):
global app
app = FastAPI()
    # Add CORS middleware to allow all origins
    # (enabled via the module-level OPEN_CROSS_DOMAIN flag defined above)
if OPEN_CROSS_DOMAIN:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.post("/intent_reco",response_model=LabelMessage, summary="意图识别")(get_int_label)
app.post("/slot_reco",response_model=LabelMessage1, summary="槽位抽取")(get_uie_label)
app.post("/agent",response_model=LabelMessage2, summary="Agent")(Agent)
uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
# messages = []
int_label = get_label("int")
int_model,int_tokenizer = get_model(model_names['intent_token'],model_names['intent_model'])
uie_model,uie_tokenizer = get_model(model_names['uie_token'],model_names['uie_model'])
host = "0.0.0.0"
port = 8072
api_start(host,port)

File diff suppressed because it is too large.

log.log
Binary file not shown.

View File

@@ -1,15 +0,0 @@
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 2.97077, 'time': '2025-02-06 21:51:56'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.35625, 'time': '2025-02-06 21:52:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.48126, 'time': '2025-02-06 21:52:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 2.88186, 'time': '2025-02-06 21:54:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.42058, 'time': '2025-02-06 21:54:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.35122, 'time': '2025-02-06 21:54:33'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.41923, 'time': '2025-02-06 21:54:34'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.32405, 'time': '2025-02-06 21:55:15'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.42524, 'time': '2025-02-06 21:55:15'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.31932, 'time': '2025-02-06 21:55:19'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.41869, 'time': '2025-02-06 21:55:19'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.37248, 'time': '2025-02-06 21:55:59'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.44125, 'time': '2025-02-06 21:56:00'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.30526, 'time': '2025-02-06 21:57:13'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.43103, 'time': '2025-02-06 21:57:14'}

View File

@@ -1,14 +0,0 @@
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '110kv线路工程一般档弧垂允许的最大偏差值'}], 'result': {'finished': 1, 'int': 10, 'slot': {}}, 'exe_time': 0.31451, 'time': '2025-02-07 15:05:33'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '打开风险管控页面'}], 'result': {'finished': 0, 'text': '请问您想查询哪个页面?'}, 'exe_time': 0.3189, 'time': '2025-02-07 15:06:14'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '打开日计划管理'}], 'result': {'finished': 1, 'int': 3, 'slot': {'module': '日计划管理'}}, 'exe_time': 0.12494, 'time': '2025-02-07 15:06:32'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程有多少项作业计划'}], 'result': {'finished': 1, 'int': 4, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.76934, 'time': '2025-02-07 15:07:30'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '本周1号项目部多少项一级风险作业计划正在施工'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '本周', 'project': '1号项目部', 'risk_level': '一级风险', 'work_status': '正在施工'}}, 'exe_time': 0.28309, 'time': '2025-02-07 15:07:43'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天天气怎么样'}], 'result': {'finished': 0, 'text': '请问您想查询哪里的天气情况?'}, 'exe_time': 0.14188, 'time': '2025-02-07 15:08:28'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '今天', 'area': '合肥'}}, 'exe_time': 0.09992, 'time': '2025-02-07 15:08:38'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '明天合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '明天', 'area': '合肥'}}, 'exe_time': 0.09429, 'time': '2025-02-07 15:08:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '明日合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '明日', 'area': '合肥'}}, 'exe_time': 0.09947, 'time': '2025-02-07 15:08:56'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.31921, 'time': '2025-02-07 15:09:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.28888, 'time': '2025-02-07 15:31:01'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.33831, 'time': '2025-02-07 15:54:12'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.52175, 'time': '2025-02-07 15:54:52'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.33173, 'time': '2025-02-07 15:58:55'}

models.py
View File

@@ -1,97 +0,0 @@
from typing import List
def return_extra(scores,documents,top_k = None):
new_docs = []
for idx,score in enumerate(scores):
#new_docs.append({"index": idx, "text": documents[idx], "score": 1 / (1 + np.exp(-score))})
new_docs.append({"index": idx, "text": documents[idx], "score": score})
results = [{"document": {"text": doc["text"], "index": doc["index"], "relevance_score": doc["score"]}} for doc in list(sorted(new_docs, key = lambda x: x["score"], reverse=True))]
if top_k:
return {"results": results[:top_k]}
else:
return {"results":results}
class Singleton(type):
    def __call__(cls, *args, **kwargs):
        # Cache a single instance per class on first instantiation.
        if not hasattr(cls, '_instance'):
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance
class Embedding(metaclass=Singleton):
def __init__(self,emb):
self.embedding = emb
def compute_similarity(self,sentences1: List[str], sentences2: List[str]):
if len(sentences1) > 0 and len(sentences2) > 0:
embeddings1 = self.encode(sentences1)
embeddings2 = self.encode(sentences2)
similarity = embeddings1 @ embeddings2.T
return similarity
else:
return None
def encode(self,sentences):
return self.embedding.encode(sentences)
def get_similarity(self,em1,em2):
return em1 @ em2.T
class M3E_EMB(Embedding):
def __init__(self,model_path):
from sentence_transformers import SentenceTransformer
self.embedding = SentenceTransformer(model_path)
super(M3E_EMB,self).__init__(self.embedding)
class KNOW_PASS():
    def __init__(self, embedding: Embedding):
        self.embedding = embedding
        self.knowbase = {}   # in-memory store: know_id -> {"contents": [...], "embedding": ndarray}
        self.knownames = []  # insertion order; the oldest 100 bases are evicted past 500
    def load_know(self, know_id: str, contents: List[str], drop_dup=True, is_cover=False):
        if know_id in self.knowbase:
cur = self.knowbase[know_id]
if cur['contents'] == contents:
print(f"已经存在{know_id}知识库,包含知识条数:{len(cur['contents'])}")
return
if is_cover:
print(f"清空{know_id}知识库")
self.knowbase[know_id] = {}
else:
contents.extend(cur['contents'])
else:
self.knownames.append(know_id)
if len(self.knownames) > 500:
for knowname in self.knownames[:100]:
del self.knowbase[knowname]
self.knownames = self.knownames[100:]
self.knowbase[know_id] = {}
if drop_dup:
contents = list(set(contents))
em = self.embedding.encode(contents)
self.knowbase[know_id]['contents'] = contents
self.knowbase[know_id]['embedding'] = em
print(f"已更新{know_id}知识库,现在包含知识条数:{len(contents)}")
return know_id
    def get_similarity_pair(self, sentences1: List[str], sentences2: List[str]):
        similarity = self.embedding.compute_similarity(sentences1, sentences2)
        if similarity is None:  # guard before .tolist(); compute_similarity may return None
            return None
        return {"results": similarity.tolist()}
    def get_similarity_know(self, query: str, know_id: str, top_k: int = 10):
        if know_id not in self.knowbase:
            print(f"当前知识库中不包含{know_id},当前知识库中包括:{self.knowbase.keys()},请确定知识库名称是否正确或者创建知识库")
            return None
        em = self.embedding.encode([query])
        similarity = self.embedding.get_similarity(em, self.knowbase[know_id]['embedding'])
        if similarity is None:
            return None
        similarity = similarity.tolist()
        return return_extra(similarity[0], self.knowbase[know_id]['contents'], top_k)
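A minimal local sketch of how these classes fit together (an illustrative addition, not part of the original file; the model path is the one assumed elsewhere in this commit):

if __name__ == "__main__":
    emb = M3E_EMB("/home/workplace/models/m3e-base")  # assumed model path
    kp = KNOW_PASS(emb)
    kp.load_know("demo", ["天气查询", "知识问答"], drop_dup=True, is_cover=True)
    print(kp.get_similarity_know("今天天气", "demo", top_k=1))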

post_api.py
View File

@@ -1,40 +0,0 @@
import requests
class Similar():
def __init__(self):
self.url = 'http://127.0.0.1:6006'
def post(self,url,json_p):
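        # The bearer token below must match env_bearer_token on the embedding
        # service (its ACCESS_TOKEN environment variable, or the literal default).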
headers = {
"Authorization": "Bearer ACCESS_TOKEN"
}
response = requests.post(url,json=json_p,headers=headers)
return response.text
def similar(self,texts1,texts2):
url = self.url + '/v1/embedding'
json_p = {
"sentences1":texts1,
"sentences2":texts2,
}
return self.post(url,json_p)
def load_know(self,know_id,contents,drop_dup,is_cover = False):
url = self.url + '/v1/load_know'
json_p = {
"know_id": know_id,
"contents": contents,
"drop_dup": drop_dup,
"is_cover": is_cover,
}
return self.post(url,json_p)
def know_sim(self,query,know_id,top_k = 10):
url = self.url + '/v1/know_sim'
json_p = {
"query": query,
"know_id": know_id,
"top_k": top_k
}
return self.post(url,json_p)
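A minimal usage sketch for this client (an illustrative addition; it assumes the embedding service above is running on 127.0.0.1:6006 with the default ACCESS_TOKEN bearer token):

if __name__ == "__main__":
    sim = Similar()
    # Build a small knowledge base, then query it, as the dialogue service does at startup.
    sim.load_know('work_name', ['天气查询', '知识问答'], drop_dup=True, is_cover=True)
    print(sim.know_sim('今天天气怎么样', 'work_name', top_k=1))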

run.sh
View File

@@ -1,7 +0,0 @@
curl -X POST http://127.0.0.1:8072/agent \
-H "Content-Type: application/json" \
-d '{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","messages":[{"role":"user","content":"今天临关5374线跨越明巢高速公路迁改工程多少项工作计划"}]}'
curl -X POST http://192.168.0.21:8072/agent \
-H "Content-Type: application/json" \
-d '{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","messages":[{"role":"user","content":"送变电一公司正在施工的作业内容是什么"},{"role":"assistant","content":"请问您想查询哪天的周计划作业内容"},{"role":"user","content":"本周"}]}'

evaluate.py
View File

@@ -1,143 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
from functools import partial
import paddle
from utils import convert_example, create_data_loader, reader
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset, load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer
from paddlenlp.utils.ie_utils import get_relation_type_dict, unify_prompt_name
from paddlenlp.utils.log import logger
@paddle.no_grad()
def evaluate(model, metric, data_loader, multilingual=False):
"""
Given a dataset, it evals model and computes the metric.
Args:
model(obj:`paddle.nn.Layer`): A model to classify texts.
metric(obj:`paddle.metric.Metric`): The evaluation metric.
data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
multilingual(bool): Whether is the multilingual model.
"""
model.eval()
metric.reset()
for batch in data_loader:
if multilingual:
start_prob, end_prob = model(batch["input_ids"], batch["position_ids"])
else:
start_prob, end_prob = model(
batch["input_ids"], batch["token_type_ids"], batch["position_ids"], batch["attention_mask"]
)
start_ids = paddle.cast(batch["start_positions"], "float32")
end_ids = paddle.cast(batch["end_positions"], "float32")
num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
metric.update(num_correct, num_infer, num_label)
precision, recall, f1 = metric.accumulate()
model.train()
return precision, recall, f1
def do_eval():
paddle.set_device(args.device)
if args.model_path in ["uie-m-base", "uie-m-large"]:
args.multilingual = True
tokenizer = AutoTokenizer.from_pretrained(args.model_path)
if args.multilingual:
model = UIEM.from_pretrained(args.model_path)
else:
model = UIE.from_pretrained(args.model_path)
test_ds = load_dataset(reader, data_path=args.test_path, max_seq_len=args.max_seq_len, lazy=False)
class_dict = {}
relation_data = []
if args.debug:
for data in test_ds:
class_name = unify_prompt_name(data["prompt"])
# Only positive examples are evaluated in debug mode
if len(data["result_list"]) != 0:
p = "" if args.schema_lang == "ch" else " of "
if p not in data["prompt"]:
class_dict.setdefault(class_name, []).append(data)
else:
relation_data.append((data["prompt"], data))
relation_type_dict = get_relation_type_dict(relation_data, schema_lang=args.schema_lang)
else:
class_dict["all_classes"] = test_ds
trans_fn = partial(
convert_example, tokenizer=tokenizer, max_seq_len=args.max_seq_len, multilingual=args.multilingual
)
for key in class_dict.keys():
if args.debug:
test_ds = MapDataset(class_dict[key])
else:
test_ds = class_dict[key]
test_ds = test_ds.map(trans_fn)
data_collator = DataCollatorWithPadding(tokenizer)
test_data_loader = create_data_loader(test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator)
metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader, args.multilingual)
logger.info("-----------------------------")
logger.info("Class Name: %s" % key)
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))
if args.debug and len(relation_type_dict.keys()) != 0:
for key in relation_type_dict.keys():
test_ds = MapDataset(relation_type_dict[key])
test_ds = test_ds.map(trans_fn)
test_data_loader = create_data_loader(
test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator
)
metric = SpanEvaluator()
precision, recall, f1 = evaluate(model, metric, test_data_loader)
logger.info("-----------------------------")
if args.schema_lang == "ch":
logger.info("Class Name: X的%s" % key)
else:
logger.info("Class Name: %s of X" % key)
logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))
if __name__ == "__main__":
# yapf: disable
parser = argparse.ArgumentParser()
parser.add_argument("--model_path", type=str, default=None, help="The path of saved model that you want to load.")
parser.add_argument("--test_path", type=str, default=None, help="The path of test set.")
parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU/CPU for training.")
parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu", "npu"], help="Device selected for evaluate.")
parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum total input sequence length after tokenization.")
parser.add_argument("--debug", action='store_true', help="Precision, recall and F1 score are calculated for each class separately if this option is enabled.")
parser.add_argument("--multilingual", action='store_true', help="Whether is the multilingual model.")
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")
args = parser.parse_args()
# yapf: enable
do_eval()

finetune.py
View File

@@ -1,244 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional
import paddle
from utils import convert_example, reader
from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.trainer import (
CompressionArguments,
PdArgumentParser,
Trainer,
get_last_checkpoint,
)
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer, export_model
from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func
from paddlenlp.utils.log import logger
@dataclass
class DataArguments:
"""
Arguments pertaining to what data we are going to input our model for training and eval.
Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
specify them on the command line.
"""
train_path: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
dev_path: str = field(
default=None, metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
max_seq_length: Optional[int] = field(
default=512,
metadata={
"help": "The maximum total input sequence length after tokenization. Sequences longer "
"than this will be truncated, sequences shorter will be padded."
},
)
dynamic_max_length: Optional[List[int]] = field(
default=None,
metadata={"help": "dynamic max length from batch, it can be array of length, eg: 16 32 64 128"},
)
@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
"""
model_name_or_path: Optional[str] = field(
default="uie-base",
metadata={
"help": "Path to pretrained model, such as 'uie-base', 'uie-tiny', "
"'uie-medium', 'uie-mini', 'uie-micro', 'uie-nano', 'uie-base-en', "
"'uie-m-base', 'uie-m-large', or finetuned model path."
},
)
export_model_dir: Optional[str] = field(
default=None,
metadata={"help": "Path to directory to store the exported inference model."},
)
multilingual: bool = field(default=False, metadata={"help": "Whether the model is a multilingual model."})
def main():
parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
training_args.label_names = ["start_positions", "end_positions"]
if model_args.model_name_or_path in ["uie-m-base", "uie-m-large"]:
model_args.multilingual = True
elif os.path.exists(os.path.join(model_args.model_name_or_path, "model_config.json")):
with open(os.path.join(model_args.model_name_or_path, "model_config.json")) as f:
init_class = json.load(f)["init_class"]
if init_class == "UIEM":
model_args.multilingual = True
# Log model and data config
training_args.print_config(model_args, "Model")
training_args.print_config(data_args, "Data")
paddle.set_device(training_args.device)
# Log on each process the small summary:
logger.warning(
f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
+ f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
)
# Detecting last checkpoint.
last_checkpoint = None
if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
last_checkpoint = get_last_checkpoint(training_args.output_dir)
if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
raise ValueError(
f"Output directory ({training_args.output_dir}) already exists and is not empty. "
"Use --overwrite_output_dir to overcome."
)
elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
logger.info(
f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
"the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
)
tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
if model_args.multilingual:
model = UIEM.from_pretrained(model_args.model_name_or_path)
else:
model = UIE.from_pretrained(model_args.model_name_or_path)
train_ds = load_dataset(reader, data_path=data_args.train_path, max_seq_len=data_args.max_seq_length, lazy=False)
dev_ds = load_dataset(reader, data_path=data_args.dev_path, max_seq_len=data_args.max_seq_length, lazy=False)
trans_fn = partial(
convert_example,
tokenizer=tokenizer,
max_seq_len=data_args.max_seq_length,
multilingual=model_args.multilingual,
dynamic_max_length=data_args.dynamic_max_length,
)
train_ds = train_ds.map(trans_fn)
dev_ds = dev_ds.map(trans_fn)
if training_args.device == "npu":
data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
else:
data_collator = DataCollatorWithPadding(tokenizer)
trainer = Trainer(
model=model,
criterion=uie_loss_func,
args=training_args,
data_collator=data_collator,
train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
tokenizer=tokenizer,
compute_metrics=compute_metrics,
)
trainer.optimizer = paddle.optimizer.AdamW(
learning_rate=training_args.learning_rate, parameters=model.parameters()
)
checkpoint = None
if training_args.resume_from_checkpoint is not None:
checkpoint = training_args.resume_from_checkpoint
elif last_checkpoint is not None:
checkpoint = last_checkpoint
# Training
if training_args.do_train:
train_result = trainer.train(resume_from_checkpoint=checkpoint)
metrics = train_result.metrics
trainer.save_model()
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
# Evaluate and tests model
if training_args.do_eval:
eval_metrics = trainer.evaluate()
trainer.log_metrics("eval", eval_metrics)
# export inference model
if training_args.do_export:
# You can also load from certain checkpoint
# trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
if training_args.device == "npu":
# npu will transform int64 to int32 for internal calculation.
# To reduce useless transformation, we feed int32 inputs.
input_spec_dtype = "int32"
else:
input_spec_dtype = "int64"
if model_args.multilingual:
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
]
else:
input_spec = [
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="token_type_ids"),
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="attention_mask"),
]
if model_args.export_model_dir is None:
model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
if training_args.do_compress:
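# Evaluation callback for compression: recomputes span-level F1 with
# SpanEvaluator and returns it as the metric used to select the compressed model.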
@paddle.no_grad()
def custom_evaluate(self, model, data_loader):
metric = SpanEvaluator()
model.eval()
metric.reset()
for batch in data_loader:
if model_args.multilingual:
logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
else:
logits = model(
input_ids=batch["input_ids"],
token_type_ids=batch["token_type_ids"],
position_ids=batch["position_ids"],
attention_mask=batch["attention_mask"],
)
start_prob, end_prob = logits
start_ids, end_ids = batch["start_positions"], batch["end_positions"]
num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
metric.update(num_correct, num_infer, num_label)
precision, recall, f1 = metric.accumulate()
logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, f1))
model.train()
return f1
trainer.compress(custom_evaluate=custom_evaluate)
if __name__ == "__main__":
main()

View File

@ -1,8 +0,0 @@
python evaluate.py \
--model_path ./checkpoint \
--test_path ./data/data.txt \
--batch_size 16 \
--device gpu \
--debug \
--max_seq_len 256 \
--schema_lang ch

View File

@ -1,23 +0,0 @@
python finetune.py \
--device gpu \
--logging_steps 10 \
--eval_steps 2000 \
--save_steps 2000 \
--seed 1000 \
--model_name_or_path /home/workplace/models/uie-base \
--output_dir ./checkpoint \
--train_path ./data/train.txt \
--dev_path ./data/dev.txt \
--max_seq_len 256 \
--per_device_train_batch_size 4 \
--per_device_eval_batch_size 8 \
--num_train_epochs 10 \
--learning_rate 1e-5 \
--do_train \
--do_eval \
--do_export \
--overwrite_output_dir \
--disable_tqdm True \
--metric_for_best_model eval_f1 \
--load_best_model_at_end True \
--save_total_limit 1

View File

@ -1,235 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import random
from typing import List, Optional
import numpy as np
import paddle
from paddlenlp.utils.log import logger
def set_seed(seed):
paddle.seed(seed)
random.seed(seed)
np.random.seed(seed)
def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
"""
Create dataloader.
Args:
dataset(obj:`paddle.io.Dataset`): Dataset instance.
mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
Returns:
dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
"""
if trans_fn:
dataset = dataset.map(trans_fn)
shuffle = mode == "train"
if mode == "train":
sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
else:
sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True)
return dataloader
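# Illustrative usage (a sketch; assumes `ds` comes from
# paddlenlp.datasets.load_dataset and `trans_fn` maps one raw example
# to feature dicts, as convert_example below does):
#     train_loader = create_data_loader(ds, mode="train", batch_size=16, trans_fn=trans_fn)
#     dev_loader = create_data_loader(ds_dev, mode="dev", batch_size=32)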
def map_offset(ori_offset, offset_mapping):
"""
Map a character offset in the original text to the index of the token
whose span covers it; return -1 if no span covers the offset.
"""
for index, span in enumerate(offset_mapping):
if span[0] <= ori_offset < span[1]:
return index
return -1
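# Worked example: with offset_mapping [(0, 2), (2, 5)], map_offset(3, ...)
# returns 1 because 2 <= 3 < 5; an offset covered by no span yields -1.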
def reader(data_path, max_seq_len=512):
"""
Read a JSONL dataset where each line has keys "content", "prompt" and
"result_list" (spans with character-level "start"/"end" offsets into
"content"). Contents longer than max_seq_len - len(prompt) - 3 are split
into multiple examples with the span offsets rebased into each chunk.
"""
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
json_line = json.loads(line)
content = json_line["content"].strip()
prompt = json_line["prompt"]
# The model input looks like: [CLS] Prompt [SEP] Content [SEP]
# It includes three special tokens, hence the "- 3" below.
if max_seq_len <= len(prompt) + 3:
raise ValueError("The value of max_seq_len is too small, please set a larger value")
max_content_len = max_seq_len - len(prompt) - 3
if len(content) <= max_content_len:
yield json_line
else:
result_list = json_line["result_list"]
json_lines = []
accumulate = 0
while True:
cur_result_list = []
for result in result_list:
if result["end"] - result["start"] > max_content_len:
logger.warning(
"result['end'] - result ['start'] exceeds max_content_len, which will result in no valid instance being returned"
)
if (
result["start"] + 1 <= max_content_len < result["end"]
and result["end"] - result["start"] <= max_content_len
):
max_content_len = result["start"]
break
cur_content = content[:max_content_len]
res_content = content[max_content_len:]
while True:
if len(result_list) == 0:
break
elif result_list[0]["end"] <= max_content_len:
if result_list[0]["end"] > 0:
cur_result = result_list.pop(0)
cur_result_list.append(cur_result)
else:
cur_result_list = [result for result in result_list]
break
else:
break
json_line = {"content": cur_content, "result_list": cur_result_list, "prompt": prompt}
json_lines.append(json_line)
for result in result_list:
if result["end"] <= 0:
break
result["start"] -= max_content_len
result["end"] -= max_content_len
accumulate += max_content_len
max_content_len = max_seq_len - len(prompt) - 3
if len(res_content) == 0:
break
elif len(res_content) < max_content_len:
json_line = {"content": res_content, "result_list": result_list, "prompt": prompt}
json_lines.append(json_line)
break
else:
content = res_content
for json_line in json_lines:
yield json_line
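# Worked example of the chunking above: with max_seq_len=8 and a 2-character
# prompt, max_content_len = 8 - 2 - 3 = 3, so a 7-character content is emitted
# as chunks of 3, 3 and 1 characters (assuming no result span straddles a
# split point), with each result's start/end rebased into its chunk.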
def get_dynamic_max_length(examples, default_max_length: int, dynamic_max_length: List[int]) -> int:
"""get max_length by examples which you can change it by examples in batch"""
cur_length = len(examples[0]["input_ids"])
max_length = default_max_length
for max_length_option in sorted(dynamic_max_length):
if cur_length <= max_length_option:
max_length = max_length_option
break
return max_length
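# Example: with default_max_length=512 and dynamic_max_length=[64, 128, 256],
# an encoded example of length 100 gets max_length=128, while one of length
# 300 exceeds every option and falls back to 512.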
def convert_example(
example, tokenizer, max_seq_len, multilingual=False, dynamic_max_length: Optional[List[int]] = None
):
"""
Tokenize one example dict (keys: title, prompt, content, result_list)
into model features, marking start/end token positions derived from the
character offsets in result_list.
"""
if dynamic_max_length is not None:
temp_encoded_inputs = tokenizer(
text=[example["prompt"]],
text_pair=[example["content"]],
truncation=True,
max_seq_len=max_seq_len,
return_attention_mask=True,
return_position_ids=True,
return_dict=False,
return_offsets_mapping=True,
)
max_length = get_dynamic_max_length(
examples=temp_encoded_inputs, default_max_length=max_seq_len, dynamic_max_length=dynamic_max_length
)
# always pad to max_length
encoded_inputs = tokenizer(
text=[example["prompt"]],
text_pair=[example["content"]],
truncation=True,
max_seq_len=max_length,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_position_ids=True,
return_dict=False,
return_offsets_mapping=True,
)
start_ids = [0.0 for x in range(max_length)]
end_ids = [0.0 for x in range(max_length)]
else:
encoded_inputs = tokenizer(
text=[example["prompt"]],
text_pair=[example["content"]],
truncation=True,
max_seq_len=max_seq_len,
pad_to_max_seq_len=True,
return_attention_mask=True,
return_position_ids=True,
return_dict=False,
return_offsets_mapping=True,
)
start_ids = [0.0 for x in range(max_seq_len)]
end_ids = [0.0 for x in range(max_seq_len)]
encoded_inputs = encoded_inputs[0]
offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
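# Offsets of the content (text_pair) restart at 0 after the prompt; shift
# them by the prompt's character length (+1 for [SEP]) so that character
# positions from result_list can be located via map_offset.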
bias = 0
for index in range(1, len(offset_mapping)):
mapping = offset_mapping[index]
if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
bias = offset_mapping[index - 1][1] + 1 # Includes [SEP] token
if mapping[0] == 0 and mapping[1] == 0:
continue
offset_mapping[index][0] += bias
offset_mapping[index][1] += bias
for item in example["result_list"]:
start = map_offset(item["start"] + bias, offset_mapping)
end = map_offset(item["end"] - 1 + bias, offset_mapping)
start_ids[start] = 1.0
end_ids[end] = 1.0
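# Note: if a gold span falls outside the truncated sequence, map_offset
# returns -1 and the 1.0 lands on the last position (negative indexing).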
if multilingual:
tokenized_output = {
"input_ids": encoded_inputs["input_ids"],
"position_ids": encoded_inputs["position_ids"],
"start_positions": start_ids,
"end_positions": end_ids,
}
else:
tokenized_output = {
"input_ids": encoded_inputs["input_ids"],
"token_type_ids": encoded_inputs["token_type_ids"],
"position_ids": encoded_inputs["position_ids"],
"attention_mask": encoded_inputs["attention_mask"],
"start_positions": start_ids,
"end_positions": end_ids,
}
return tokenized_output
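# Illustrative call (a sketch; assumes a PaddleNLP tokenizer such as
# AutoTokenizer.from_pretrained("uie-base") and one example dict as yielded
# by reader above):
#     features = convert_example(example, tokenizer, max_seq_len=512)
#     # features["start_positions"] / ["end_positions"] hold 1.0 at gold token indices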