Delete files

This commit is contained in:
parent f89ef84f54
commit 3cea60d4fd
emb.log (23 lines deleted)
@@ -1,23 +0,0 @@
nohup: ignoring input
初始模型加载
INFO: Started server process [2712]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:6006 (Press CTRL+C to quit)
已更新work_name知识库,现在包含知识条数:2054
INFO: 127.0.0.1:54240 - "POST /v1/load_know HTTP/1.1" 200 OK
INFO: 127.0.0.1:50842 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:32800 - "POST /v1/know_sim HTTP/1.1" 200 OK
清空work_name知识库
已更新work_name知识库,现在包含知识条数:2054
INFO: 127.0.0.1:41464 - "POST /v1/load_know HTTP/1.1" 200 OK
INFO: 127.0.0.1:41478 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:41494 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:37822 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:54392 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:49690 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:39220 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:47842 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:37910 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:58032 - "POST /v1/know_sim HTTP/1.1" 200 OK
INFO: 127.0.0.1:43468 - "POST /v1/know_sim HTTP/1.1" 200 OK
embedding.py (104 lines deleted)
@@ -1,104 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import numpy as np
import logging
import uvicorn
import datetime
from fastapi import FastAPI, Security, HTTPException
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from pydantic import Field, BaseModel, validator
from typing import Optional, List, Dict

from models import *


def response(code, msg, data=None):
    time = str(datetime.datetime.now())
    if data is None:
        data = []
    result = {
        "code": code,
        "message": msg,
        "data": data,
        "time": time
    }
    return result


def success(data=None, msg=''):
    # Unused stub; returns None.
    return


class QADocs(BaseModel):
    query: Optional[str]
    documents: Optional[List[str]]


class Knows(BaseModel):
    know_id: Optional[str]
    contents: Optional[List[str]]
    drop_dup: Optional[bool]
    is_cover: Optional[bool]


class Know_Sim(BaseModel):
    query: Optional[str]
    know_id: Optional[str]
    top_k: Optional[int] = 10


app = FastAPI()
security = HTTPBearer()
env_bearer_token = 'ACCESS_TOKEN'


@app.post('/v1/embedding')
async def handle_post_request1(sentences1: List[str], sentences2: List[str], credentials: HTTPAuthorizationCredentials = Security(security)):
    global know_pass
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    similarity = know_pass.get_similarity_pair(sentences1, sentences2)
    return response(200, msg="获取相似性成功", data=similarity)


@app.post('/v1/load_know')
async def handle_post_request2(knows: Knows, credentials: HTTPAuthorizationCredentials = Security(security)):
    know_id = knows.know_id
    contents = knows.contents
    drop_dup = knows.drop_dup
    is_cover = knows.is_cover
    global know_pass
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    know_pass.load_know(know_id, contents, drop_dup, is_cover)
    return response(200, msg=f"生成{know_id}知识库成功")


@app.post('/v1/know_sim')
async def handle_post_request3(know_sim: Know_Sim, credentials: HTTPAuthorizationCredentials = Security(security)):
    query = know_sim.query
    know_id = know_sim.know_id
    top_k = know_sim.top_k
    global know_pass
    token = credentials.credentials
    if env_bearer_token is not None and token != env_bearer_token:
        raise HTTPException(status_code=401, detail="Invalid token")
    similarity = know_pass.get_similarity_know(query, know_id, top_k)
    return response(200, msg="获取相似性成功", data=similarity)


def init_env():
    # Initialize the embedding model.
    print("初始模型加载")
    # embedding = BCE_EMB(EMBEDDING_MODEL_PATH)
    EMBEDDING_MODEL_PATH = "/home/workplace/models/m3e-base"
    embedding = M3E_EMB(EMBEDDING_MODEL_PATH)
    know_pass = KNOW_PASS(embedding)
    return know_pass


if __name__ == "__main__":
    token = os.getenv("ACCESS_TOKEN")
    if token is not None:
        env_bearer_token = token
    try:
        # Default environment.
        know_pass = init_env()
        uvicorn.run(app, host='0.0.0.0', port=6006)
    except Exception as e:
        print(f"API启动失败!\n报错:\n{e}")
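Reviewer note: a minimal smoke test for the service deleted above, in case anyone needs to exercise it before removal. This is a sketch, not part of the repo: it assumes the server is running locally on port 6006 with the default 'ACCESS_TOKEN' bearer token, and the know_id and contents are invented.

import requests

BASE = "http://127.0.0.1:6006"
HEADERS = {"Authorization": "Bearer ACCESS_TOKEN"}

# Build (or overwrite) a small knowledge base via /v1/load_know.
resp = requests.post(f"{BASE}/v1/load_know", headers=HEADERS, json={
    "know_id": "demo",
    "contents": ["天气查询", "页面切换"],
    "drop_dup": True,
    "is_cover": True,
})
print(resp.json())  # e.g. {"code": 200, "message": "生成demo知识库成功", ...}

# Retrieve the most similar entry for a query via /v1/know_sim.
resp = requests.post(f"{BASE}/v1/know_sim", headers=HEADERS, json={
    "query": "打开页面",
    "know_id": "demo",
    "top_k": 1,
})
print(resp.json())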
@@ -1,11 +0,0 @@
天气查询
互联网查询
页面切换
日计划数量查询
周计划数量查询
日计划作业内容
周计划作业内容
施工人数
作业考勤人数
知识问答
int/pred.sh (12 lines deleted)
@@ -1,12 +0,0 @@
python train.py \
    --do_eval \
    --debug True \
    --device gpu \
    --model_name_or_path checkpoint \
    --output_dir checkpoint \
    --per_device_eval_batch_size 32 \
    --max_length 256 \
    --label_path ./data/label.txt \
    --train_path ./data/train.txt \
    --test_path ./data/test.txt \
    --dev_path ./data/dev.txt
int/train.py (231 lines deleted)
@@ -1,231 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import json
import os
import shutil
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import numpy as np
import paddle
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    precision_recall_fscore_support,
)
from utils import log_metrics_debug, preprocess_function, read_local_dataset

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.trainer import (
    CompressionArguments,
    EarlyStoppingCallback,
    PdArgumentParser,
    Trainer,
)
from paddlenlp.transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    export_model,
)
from paddlenlp.utils.log import logger

SUPPORTED_MODELS = [
    "ernie-1.0-large-zh-cw",
    "ernie-1.0-base-zh-cw",
    "ernie-3.0-xbase-zh",
    "ernie-3.0-base-zh",
    "ernie-3.0-medium-zh",
    "ernie-3.0-micro-zh",
    "ernie-3.0-mini-zh",
    "ernie-3.0-nano-zh",
    "ernie-3.0-tiny-base-v2-zh",
    "ernie-3.0-tiny-medium-v2-zh",
    "ernie-3.0-tiny-micro-v2-zh",
    "ernie-3.0-tiny-mini-v2-zh",
    "ernie-3.0-tiny-nano-v2-zh",
    "ernie-3.0-tiny-pico-v2-zh",
    "ernie-2.0-large-en",
    "ernie-2.0-base-en",
    "ernie-3.0-tiny-mini-v2-en",
    "ernie-m-base",
    "ernie-m-large",
]


# yapf: disable
@dataclass
class DataArguments:
    max_length: int = field(default=128, metadata={"help": "Maximum number of tokens for the model."})
    early_stopping: bool = field(default=False, metadata={"help": "Whether apply early stopping strategy."})
    early_stopping_patience: int = field(default=4, metadata={"help": "Stop training when the specified metric worsens for early_stopping_patience evaluation calls"})
    debug: bool = field(default=False, metadata={"help": "Whether choose debug mode."})
    train_path: str = field(default='./data/train.txt', metadata={"help": "Train dataset file path."})
    dev_path: str = field(default='./data/dev.txt', metadata={"help": "Dev dataset file path."})
    test_path: str = field(default='./data/dev.txt', metadata={"help": "Test dataset file path."})
    label_path: str = field(default='./data/label.txt', metadata={"help": "Label file path."})
    bad_case_path: str = field(default='./data/bad_case.txt', metadata={"help": "Bad case file path."})


@dataclass
class ModelArguments:
    model_name_or_path: str = field(default="ernie-3.0-tiny-medium-v2-zh", metadata={"help": "Build-in pretrained model name or the path to local model."})
    export_model_dir: Optional[str] = field(default=None, metadata={"help": "Path to directory to store the exported inference model."})
# yapf: enable


def main():
    """
    Training a binary or multi classification model
    """

    parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    if training_args.do_compress:
        training_args.strategy = "dynabert"
    if training_args.do_train or training_args.do_compress:
        training_args.print_config(model_args, "Model")
        training_args.print_config(data_args, "Data")
    paddle.set_device(training_args.device)

    # Define id2label
    id2label = {}
    label2id = {}
    with open(data_args.label_path, "r", encoding="utf-8") as f:
        for i, line in enumerate(f):
            l = line.strip()
            id2label[i] = l
            label2id[l] = i

    # Define model & tokenizer
    if os.path.isdir(model_args.model_name_or_path):
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path, label2id=label2id, id2label=id2label
        )
    elif model_args.model_name_or_path in SUPPORTED_MODELS:
        model = AutoModelForSequenceClassification.from_pretrained(
            model_args.model_name_or_path, num_classes=len(label2id), label2id=label2id, id2label=id2label
        )
    else:
        raise ValueError(
            f"{model_args.model_name_or_path} is not a supported model type. Either use a local model path or select a model from {SUPPORTED_MODELS}"
        )
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)

    # Load and preprocess dataset
    train_ds = load_dataset(read_local_dataset, path=data_args.train_path, label2id=label2id, lazy=False)
    dev_ds = load_dataset(read_local_dataset, path=data_args.dev_path, label2id=label2id, lazy=False)
    trans_func = functools.partial(preprocess_function, tokenizer=tokenizer, max_length=data_args.max_length)
    train_ds = train_ds.map(trans_func)
    dev_ds = dev_ds.map(trans_func)

    if data_args.debug:
        test_ds = load_dataset(read_local_dataset, path=data_args.test_path, label2id=label2id, lazy=False)
        test_ds = test_ds.map(trans_func)

    # Define the metric function.
    def compute_metrics(eval_preds):
        pred_ids = np.argmax(eval_preds.predictions, axis=-1)
        metrics = {}
        metrics["accuracy"] = accuracy_score(y_true=eval_preds.label_ids, y_pred=pred_ids)
        for average in ["micro", "macro"]:
            precision, recall, f1, _ = precision_recall_fscore_support(
                y_true=eval_preds.label_ids, y_pred=pred_ids, average=average
            )
            metrics[f"{average}_precision"] = precision
            metrics[f"{average}_recall"] = recall
            metrics[f"{average}_f1"] = f1
        return metrics

    def compute_metrics_debug(eval_preds):
        pred_ids = np.argmax(eval_preds.predictions, axis=-1)
        metrics = classification_report(eval_preds.label_ids, pred_ids, output_dict=True)
        return metrics

    # Define the early-stopping callback.
    if data_args.early_stopping:
        callbacks = [EarlyStoppingCallback(early_stopping_patience=data_args.early_stopping_patience)]
    else:
        callbacks = None

    # Define Trainer
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        criterion=paddle.nn.loss.CrossEntropyLoss(),
        train_dataset=train_ds,
        eval_dataset=dev_ds,
        callbacks=callbacks,
        data_collator=DataCollatorWithPadding(tokenizer),
        compute_metrics=compute_metrics_debug if data_args.debug else compute_metrics,
    )

    # Training
    if training_args.do_train:
        train_result = trainer.train()
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        for checkpoint_path in Path(training_args.output_dir).glob("checkpoint-*"):
            shutil.rmtree(checkpoint_path)

    # Evaluate and test model
    if training_args.do_eval:
        if data_args.debug:
            output = trainer.predict(test_ds)
            log_metrics_debug(output, id2label, test_ds, data_args.bad_case_path)
        else:
            eval_metrics = trainer.evaluate()
            trainer.log_metrics("eval", eval_metrics)

    # Export inference model
    if training_args.do_export:
        if model.init_config["init_class"] in ["ErnieMForSequenceClassification"]:
            input_spec = [paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids")]
        else:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype="int64", name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype="int64", name="token_type_ids"),
            ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
        export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
        tokenizer.save_pretrained(model_args.export_model_dir)
        id2label_file = os.path.join(model_args.export_model_dir, "id2label.json")
        with open(id2label_file, "w", encoding="utf-8") as f:
            json.dump(id2label, f, ensure_ascii=False)
        logger.info(f"id2label file saved in {id2label_file}")

    # Compress
    if training_args.do_compress:
        trainer.compress()
        for width_mult in training_args.width_mult_list:
            pruned_infer_model_dir = os.path.join(training_args.output_dir, "width_mult_" + str(round(width_mult, 2)))
            tokenizer.save_pretrained(pruned_infer_model_dir)
            id2label_file = os.path.join(pruned_infer_model_dir, "id2label.json")
            with open(id2label_file, "w", encoding="utf-8") as f:
                json.dump(id2label, f, ensure_ascii=False)
            logger.info(f"id2label file saved in {id2label_file}")

    for path in Path(training_args.output_dir).glob("runs"):
        shutil.rmtree(path)


if __name__ == "__main__":
    main()
int/train.sh (24 lines deleted)
@@ -1,24 +0,0 @@
python train.py \
    --do_train \
    --do_eval \
    --do_export \
    --model_name_or_path /home/workplace/models/ernie-3.0-tiny-base-v2-zh \
    --output_dir checkpoint \
    --device gpu:0 \
    --train_path ./data/train.txt \
    --dev_path ./data/dev.txt \
    --label_path ./data/label.txt \
    --num_train_epochs 100 \
    --early_stopping True \
    --early_stopping_patience 5 \
    --learning_rate 3e-5 \
    --max_length 256 \
    --per_device_eval_batch_size 16 \
    --per_device_train_batch_size 16 \
    --metric_for_best_model accuracy \
    --load_best_model_at_end \
    --logging_steps 5 \
    --evaluation_strategy epoch \
    --save_strategy epoch \
    --save_total_limit 1
int/utils.py (89 lines deleted)
@@ -1,89 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

from paddlenlp.utils.log import logger


def preprocess_function(examples, tokenizer, max_length, is_test=False):
    """
    Builds model inputs from a sequence for sequence classification tasks
    by concatenating and adding special tokens.
    """
    result = tokenizer(examples["text"], max_length=max_length, truncation=True)
    if not is_test:
        result["labels"] = np.array([examples["label"]], dtype="int64")
    return result


def read_local_dataset(path, label2id=None, is_test=False):
    """
    Read dataset.
    """
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if is_test:
                sentence = line.strip()
                yield {"text": sentence}
            else:
                items = line.strip().split("\t")
                if len(items) != 2:
                    continue
                yield {"text": items[0], "label": label2id[items[1]]}


def log_metrics_debug(output, id2label, dev_ds, bad_case_path):
    """
    Log metrics in debug mode.
    """
    predictions, label_ids, metrics = output
    pred_ids = np.argmax(predictions, axis=-1)
    logger.info("-----Evaluate model-------")
    logger.info("Dev dataset size: {}".format(len(dev_ds)))
    logger.info("Accuracy in dev dataset: {:.2f}%".format(metrics["test_accuracy"] * 100))
    logger.info(
        "Macro average | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
            metrics["test_macro avg"]["precision"] * 100,
            metrics["test_macro avg"]["recall"] * 100,
            metrics["test_macro avg"]["f1-score"] * 100,
        )
    )
    for i in id2label:
        l = id2label[i]
        logger.info("Class name: {}".format(l))
        i = "test_" + str(i)
        if i in metrics:
            logger.info(
                "Evaluation examples in dev dataset: {}({:.1f}%) | precision: {:.2f} | recall: {:.2f} | F1 score {:.2f}".format(
                    metrics[i]["support"],
                    100 * metrics[i]["support"] / len(dev_ds),
                    metrics[i]["precision"] * 100,
                    metrics[i]["recall"] * 100,
                    metrics[i]["f1-score"] * 100,
                )
            )
        else:
            logger.info("Evaluation examples in dev dataset: 0 (0%)")
        logger.info("----------------------------")

    with open(bad_case_path, "w", encoding="utf-8") as f:
        f.write("Text\tLabel\tPrediction\n")
        for i, (p, l) in enumerate(zip(pred_ids, label_ids)):
            p, l = int(p), int(l)
            if p != l:
                f.write(dev_ds.data[i]["text"] + "\t" + id2label[l] + "\t" + id2label[p] + "\n")

    logger.info("Bad case in dev dataset saved in {}".format(bad_case_path))
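For context, read_local_dataset above expects one "text<TAB>label" pair per line. A sketch of writing a conforming file (the sample rows and path are illustrative; labels must appear in label.txt):

# Illustrative only: produce a train.txt in the layout read_local_dataset parses.
rows = [
    ("今天合肥天气怎么样", "天气查询"),
    ("打开日计划管理", "页面切换"),
]
with open("./data/train.txt", "w", encoding="utf-8") as f:
    for text, label in rows:
        f.write(f"{text}\t{label}\n")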
@@ -1,586 +0,0 @@
# phddns start
import paddle
import paddle.nn.functional as F
from paddlenlp.data import Tuple, Pad
from paddlenlp.transformers import AutoModelForSequenceClassification, AutoTokenizer
import argparse
import pandas as pd
import time
import datetime
import pydantic
import uvicorn
from fastapi import Body, FastAPI, File, Form, Query, UploadFile, WebSocket
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Optional
from typing import Dict, Any

import sys
from post_api import Similar
sim = Similar()


def load_know():
    file_path = './knows/project_know.txt'
    f = open(file_path, 'r', encoding='utf-8')
    data = f.read()
    data = data.split('\n')
    works = {}
    for d in data:
        wk = d[:d.find('(')]
        works[wk] = d
    keys = list(works.keys())
    sim.load_know('work_name', keys, drop_dup=True, is_cover=True)
    return works

works = load_know()
# print(works)


def standard_project(work_name):
    res = sim.know_sim(work_name, 'work_name')
    res = eval(res)  # parses the service's JSON text; json.loads would be safer
    res = res['data']['results'][0]
    t = res['document']['text']
    s = res['document']['relevance_score']
    print(t)
    print(works[t])
    print(s)
    print(res)
    if s > 330:
        return works[t]
    else:
        return work_name


OPEN_CROSS_DOMAIN = True


def time_consume(func):
    def inner(*args, **kwargs):  # accept arbitrary arguments
        time_start = time.time()
        print("\n接口请求前的时间是:", datetime.datetime.now())
        res = func(*args, **kwargs)
        time_end = time.time()
        print("\n接口请求后的时间是:", datetime.datetime.now())
        t = time_end - time_start
        print("接口耗时:", round(t, 4), "s")
        return res
    return inner


parser = argparse.ArgumentParser()
parser.add_argument("--device", default="gpu:0", help="Select which device to train model, default to gpu.")
parser.add_argument("--output_file", default="output.txt", type=str)
parser.add_argument("--max_seq_length",
                    default=128,
                    type=int,
                    help="The maximum total input sequence length "
                         "after tokenization. Sequences longer than this "
                         "will be truncated, sequences shorter will be padded.")
parser.add_argument("--batch_size",
                    default=64,
                    type=int,
                    help="Batch size per GPU/CPU for training.")
parser.add_argument("--max_seq_len",
                    default=256,
                    type=int,
                    help="The maximum total input sequence length after tokenization.")
parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for schema.")

args = parser.parse_args()
model_names = {
    "intent_token": "/home/workplace/models/zero_int",
    "intent_model": "/home/workplace/models/zero_int/model",
    "uie_token": "/home/workplace/models/zero_uie",
    "uie_model": "/home/workplace/models/zero_uie/model",
}
label_names = {
    "int": "./int/data/label.txt",
}


def get_label(label_name):
    label = []
    with open(label_names[label_name]) as inf:
        for idx, line in enumerate(inf):
            line = line.strip()
            if line != '':
                label.append(line)
    return label


def get_model(token_path, model_path):
    paddle.set_device(args.device)
    model = paddle.jit.load(model_path)
    tokenizer = AutoTokenizer.from_pretrained(token_path)
    return model, tokenizer


@paddle.no_grad()
def int_predict(model, tokenizer, data, label_list):
    examples = []
    for text in data:
        result = tokenizer(text=text, max_seq_length=args.max_seq_length)
        examples.append((result['input_ids'], result['token_type_ids']))

    # Separate data into batches
    batches = [
        examples[i: i + args.batch_size]
        for i in range(0, len(examples), args.batch_size)
    ]

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
    ): fn(samples)

    results = []
    model.eval()
    for batch in batches:
        input_ids, token_type_ids = batchify_fn(batch)
        input_ids = paddle.to_tensor(input_ids)
        token_type_ids = paddle.to_tensor(token_type_ids)
        logits = model(input_ids, token_type_ids)
        probs = F.softmax(logits, axis=1)
        idx = paddle.argmax(probs, axis=1).numpy()
        idx = idx.tolist()
        pros = probs.tolist()
        labels = []
        for lbi, pro in zip(idx, pros):
            labels.append((label_list[lbi], str(round(pro[lbi], 5))))
        results.extend(labels)

    # Map label strings to the intent ids used throughout this service.
    mapping = {"天气查询": 1, "互联网查询": 2, "页面切换": 3,
               "日计划数量查询": 4, "周计划数量查询": 5, "日计划作业内容": 6, "周计划作业内容": 7,
               "施工人数": 8, "作业考勤人数": 9, "知识问答": 10}
    nresults = []
    for res in results:
        nresults.append((mapping[res[0]], res[1]))

    return nresults


@paddle.no_grad()
def uie_predict(model, tokenizer, data):
    examples = []
    for text in data:
        result = tokenizer(text=[text['prompt']],
                           text_pair=[text['content']],
                           truncation=True,
                           max_seq_length=args.max_seq_len,
                           pad_to_max_seq_len=True,
                           return_attention_mask=True,
                           return_position_ids=True,
                           return_dict=False,
                           return_offsets_mapping=True)
        result = result[0]
        tokens = tokenizer.convert_ids_to_tokens(result["input_ids"])
        examples.append((result['input_ids'], result['token_type_ids'], result['position_ids'], result['attention_mask']))

    # Separate data into batches
    batches = [
        examples[i: i + args.batch_size]
        for i in range(0, len(examples), args.batch_size)
    ]

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),
        Pad(axis=0, pad_val=0),
        Pad(axis=0, pad_val=0),
    ): fn(samples)

    results = []
    model.eval()
    for batch in batches:
        input_ids, token_type_ids, position_ids, attention_mask = batchify_fn(batch)
        input_ids = paddle.to_tensor(input_ids)
        token_type_ids = paddle.to_tensor(token_type_ids)
        position_ids = paddle.to_tensor(position_ids)
        attention_mask = paddle.to_tensor(attention_mask)
        start_prob, end_prob = model(input_ids, token_type_ids, position_ids, attention_mask)
        # Max probability and index of the predicted span start/end.
        stmax_value = paddle.max(start_prob)
        stmax_value = stmax_value.numpy().item()
        stmax_idx = paddle.argmax(start_prob)
        etmax_value = paddle.max(end_prob)
        etmax_value = etmax_value.numpy().item()
        etmax_idx = paddle.argmax(end_prob)
        # The prompt is prepended, so span indices include the prompt tokens.
        # NOTE: `result` still holds the *last* tokenized example, so the span
        # text below is only correct when each batch contains a single example.
        tokens = tokenizer.convert_ids_to_tokens(result["input_ids"])
        res = ''.join(tokens[stmax_idx:etmax_idx + 1])
        results.append((res, stmax_value, etmax_value, stmax_idx.numpy().item(), etmax_idx.numpy().item()))
    return results


def get_slot(int_res, query):
    slot_int = [1, 3, 4, 5, 6, 7, 8, 9]
    if int_res not in slot_int:
        return {}
    if int_res == 1:  # weather query (天气查询)
        prompts = ['时间', '地点']
    elif int_res == 3:  # page switch (页面切换)
        prompts = ['应用', '模块', '页面']
    else:
        prompts = ['时间', '项目', '项目经理', '公司', '工程', '风险等级', '周时间', '施工状态']
    slot = {}
    comp = {}
    for prompt in prompts:
        uie_res = uie_predict(uie_model, uie_tokenizer, [{"prompt": prompt, "content": query}])
        print(f"prompt:{prompt} uie_res:{uie_res}")
        uie_res[0] = list(uie_res[0])
        uie_res[0][3] = uie_res[0][3] - len(prompt)
        uie_res[0][4] = uie_res[0][4] - len(prompt)
        if uie_res[0][1] >= 0.8 and uie_res[0][2] >= 0.8:
            slot[prompt] = uie_res[0][0]
            comp[prompt] = uie_res
    # Drop overlapping spans, keeping the one with the higher start/end probabilities.
    drop = []
    for k, v in slot.items():
        flag = 1
        for kk, vv in comp.items():
            if kk == k:
                continue
            if comp[k][0][4] < vv[0][3] or comp[k][0][3] > vv[0][4]:
                continue
            elif comp[k][0][1] <= vv[0][1] and comp[k][0][2] <= vv[0][2]:
                print(f"1:{comp[k]}")
                print(f"2:{vv}")
                flag = 0
                break
        if flag == 0:
            drop.append(k)
    for d in drop:
        del slot[d]

    if int_res == 5:  # weekly plan query (周计划查询)
        if slot.__contains__("周时间"):
            slot['时间'] = slot['周时间']
            del slot['周时间']
    else:
        if slot.__contains__("周时间"):
            del slot['周时间']
    cn_en = {"时间": "date", "项目": "project", "项目经理": "people",
             "公司": "company", "工程": "work", "风险等级": "risk_level",
             "施工状态": "work_status", "应用": "app", "模块": "module", "页面": "page", "地点": "area"}
    nslot = {}
    for k, v in slot.items():
        v = v.replace('#', '').replace('kv', 'kV').replace("模块", "").replace("页面", "")
        if k == '工程':
            print(v)
            v = standard_project(v)
        nslot[cn_en[k]] = v
    return nslot


def check_lost(int_res, slot):
    # Required-slot combinations per intent id (see the intent mapping above).
    mapping = {
        1: [['date', 'area']],
        3: [['page'], ['app'], ['module']],
        4: [['date']],
        5: [['date']],
        6: [['date']],
        7: [['date']],
        8: [[]],
        9: [[]],
    }
    if not mapping.__contains__(int_res):
        return 0, []
    cur_k = list(slot.keys())
    idx = -1
    idx_len = 99
    for i in range(len(mapping[int_res])):
        sk = mapping[int_res][i]
        left = [x for x in sk if x not in cur_k]
        more = [x for x in cur_k if x not in sk]
        if len(more) >= 0 and len(left) == 0:  # len(more) >= 0 is always true; only `left` matters
            idx = i
            idx_len = 0
            break
        if len(left) < idx_len:
            idx = i
            idx_len = len(left)

    if idx_len == 0:  # all required slots are filled
        return 0, cur_k
    left = [x for x in mapping[int_res][idx] if x not in cur_k]
    return 1, left


@time_consume
def agent(sk, slot, messages):
    from openai import OpenAI

    # api_base_url = 'http://36.33.26.201:27861/v1'
    # api_key = 'EMPTY'
    # model_name = 'qwen2.5-instruct'
    # api_base_url = "http://36.33.26.201:27861/chat"
    # api_base_url = "http://58.57.119.80:45291/v1"
    # model_name = "qwen2.5-instruct"

    # NOTE: api_key, api_base_url and model_name are only defined in the
    # commented-out settings above, so this function raises NameError as written;
    # agent1() below is the path actually used.
    client = OpenAI(api_key=api_key,
                    base_url=api_base_url)

    prompt = f'''
    你是一个槽位询问工具,你需要结合用户输入,根据必填槽位列表,以及当前已有槽位,对用户进行槽位问询
    询问用户哪些槽位是缺失的,需要补充,待必填槽位都已填充完毕,将所有必填槽位类型及槽位值以json格式返回
    请不要对用户输入的槽位进行修改

    ###以下是必填槽位列表###
    need_slot_lists = {sk}
    ###以上是必填槽位列表###

    ###以下是当前已有槽位###
    present_slot={slot}

    ###

    '''
    # prompt = prompt.replace("{sk}", str(sk)).replace("{slot}", slot)
    message = [{"role": "system", "content": prompt}]
    message.extend(messages)
    response = client.chat.completions.create(
        messages=message,
        model=model_name,
        max_tokens=1000,
        temperature=0.001,
        stream=False
    )
    res = response.choices[0].message.content
    return res


@time_consume
def agent1(int_res, messages):
    lslot = {}
    for message in messages:
        if message["role"] == 'user':
            slot = get_slot(int_res, message['content'])
            for k, v in slot.items():
                lslot[k] = v
    status, sk = check_lost(int_res, lslot)
    reres = ''
    if len(messages) == 3:
        reres = "非常抱歉,"
    elif len(messages) == 5:
        reres = "特别抱歉,我没有理解您的意思,"
    if len(messages) >= 7:
        return lslot
    if status == 1:
        if int_res == 1:  # weather query
            if sk == ['date', 'area']:
                return f"{reres}请问您想查询的天气的时间和地点分别是什么?"
            elif sk == ['area']:
                return f"{reres}请问您想查询哪里的天气情况?"
            else:
                return f"{reres}请问您想查询哪一天的天气情况?"
        if int_res == 3:
            return f"{reres}请问您想查询哪个页面?"
        if int_res in [4, 5, 6, 7, 8, 9]:
            mapping = {4: "日计划数量", 5: "周计划数量", 6: "日计划作业内容", 7: "周计划作业内容",
                       8: "施工人数", 9: "作业考勤人数"}
            return f"{reres}请问您想查询哪天的{mapping[int_res]}"
    else:
        return lslot


# @time_consume
def main(messages):
    if len(messages) == 1:  # first turn
        query = messages[0]['content']
        int_res = int_predict(int_model, int_tokenizer, [query], int_label)
        print(f"意图识别结果:{int_res}")
        slot = get_slot(int_res[0][0], query)
        print(f"槽位获取结果:{slot}")
        # Check whether any required slot is missing.
        status, sk = check_lost(int_res[0][0], slot)
        print(f"检测是否槽位缺失:{status} 缺失:{sk}")
        if status == 1:
            # Ask the user to supply the missing slots.
            # res = agent(sk, slot, messages)
            res = agent1(int_res[0][0], messages)
            return {"finished": 0, "text": res}
        else:
            return {"finished": 1, "int": int_res[0][0], "slot": slot}
    else:
        query = messages[0]['content']
        int_res = int_predict(int_model, int_tokenizer, [query], int_label)
        print(f"意图识别结果:{int_res}")
        # res = agent(sk, slot, messages)
        res = agent1(int_res[0][0], messages)
        if len(messages) >= 7:
            return {"finished": 1, "int": int_res[0][0], "slot": res}
        if isinstance(res, str):
            return {"finished": 0, "text": res}
        else:
            return {"finished": 1, "int": int_res[0][0], "slot": res}


# -----------------------------------------------------------------------
# Logging
def log_save(fun_name, params, result, exe_time):
    cur_time = datetime.datetime.now()
    cur_time = cur_time.strftime("%Y%m%d")
    log_file = './log/log' + cur_time + ".log"
    f = open(log_file, 'a', encoding='utf-8')
    tmp = {}
    tmp['fun_name'] = fun_name
    tmp['params'] = params
    tmp['result'] = result
    tmp['exe_time'] = exe_time
    tmp['time'] = (datetime.datetime.now()).strftime("%Y-%m-%d %H:%M:%S")
    f.write(str(tmp) + '\n')
    f.close()


class LabelMessage(BaseModel):
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    label: int = pydantic.Field(..., description="标签类型")
    probability: float = pydantic.Field(..., description="概率")


async def get_int_label(
    user_id: str = Body(..., description='鉴权'),
    text: str = Body(..., description="意图识别样本"),
):
    cur_time = datetime.datetime.now()
    cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: INT user_id:{user_id} text:{text}")
    st = time.time()
    if user_id != '3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            label=-1,
            probability=-1
        )
    global int_label, int_model, int_tokenizer
    res = int_predict(int_model, int_tokenizer, [text], int_label)
    log_save('int', text, res, round(time.time() - st, 5))

    return LabelMessage(
        code=200,
        msg='成功',
        label=res[0][0],
        probability=res[0][1]
    )


class LabelMessage1(BaseModel):
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    slot: Dict[str, Any] = pydantic.Field(..., description="槽位(JSON 格式)")


async def get_uie_label(
    user_id: str = Body(..., description='鉴权'),
    text: str = Body(..., description="槽位填充样本"),
):
    cur_time = datetime.datetime.now()
    cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: UIE user_id:{user_id} text:{text}")
    st = time.time()
    if user_id != '3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage1(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            slot={}
        )
    global int_label, int_model, int_tokenizer
    res = int_predict(int_model, int_tokenizer, [text], int_label)
    slot = get_slot(res[0][0], text)
    log_save('uie', text, {"int": res, "uie": slot}, round(time.time() - st, 5))

    return LabelMessage1(
        code=200,
        msg='成功',
        slot=slot
    )


class LabelMessage2(BaseModel):
    code: int = pydantic.Field(200, description="HTTP status code")
    msg: str = pydantic.Field('成功', description="msg")
    answer: Dict[str, Any] = pydantic.Field(..., description="槽位(JSON 格式)")


async def Agent(
    user_id: str = Body(..., description='鉴权'),
    messages: List[Dict[str, str]] = Body(..., description="消息"),
):
    cur_time = datetime.datetime.now()
    cur_time = cur_time.strftime("%Y:%m:%d:%H:%M:%S")
    print(f"{cur_time} :: AGENT user_id:{user_id} messages:{messages}")
    st = time.time()
    if user_id != '3bb66776-1722-4c36-b14a-73dd210fe750':
        return LabelMessage2(
            code=401,
            msg='权限验证失败,请联系接口开发人员',
            answer='',
        )
    his = main(messages)
    log_save('agent', messages, his, round(time.time() - st, 5))

    return LabelMessage2(
        code=200,
        msg='成功',
        answer=his,
    )


# -----------------------------------------------------------------------


def api_start(host, port, **kwargs):
    global app

    app = FastAPI()
    # Add CORS middleware to allow all origins;
    # set OPEN_CROSS_DOMAIN = True above to enable cross-domain requests.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    app.post("/intent_reco", response_model=LabelMessage, summary="意图识别")(get_int_label)
    app.post("/slot_reco", response_model=LabelMessage1, summary="槽位抽取")(get_uie_label)
    app.post("/agent", response_model=LabelMessage2, summary="Agent")(Agent)

    uvicorn.run(app, host=host, port=port)


if __name__ == "__main__":
    int_label = get_label("int")
    int_model, int_tokenizer = get_model(model_names['intent_token'], model_names['intent_model'])
    uie_model, uie_tokenizer = get_model(model_names['uie_token'], model_names['uie_model'])
    host = "0.0.0.0"
    port = 8072
    api_start(host, port)
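To make the slot-completeness check above concrete, a worked example (hypothetical inputs; intent 1 is the weather query, which requires both date and area):

# check_lost returns (1, missing_slots) while required slots are absent,
# and (0, current_slots) once every required slot is filled.
status, missing = check_lost(1, {"date": "今天"})
# status == 1, missing == ['area'] -> agent1 replies 请问您想查询哪里的天气情况?
status, filled = check_lost(1, {"date": "今天", "area": "合肥"})
# status == 0 -> main() returns {"finished": 1, ...} with the extracted slots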
File diff suppressed because it is too large
@@ -1,15 +0,0 @@
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 2.97077, 'time': '2025-02-06 21:51:56'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.35625, 'time': '2025-02-06 21:52:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.48126, 'time': '2025-02-06 21:52:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 2.88186, 'time': '2025-02-06 21:54:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.42058, 'time': '2025-02-06 21:54:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.35122, 'time': '2025-02-06 21:54:33'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.41923, 'time': '2025-02-06 21:54:34'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.32405, 'time': '2025-02-06 21:55:15'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.42524, 'time': '2025-02-06 21:55:15'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.31932, 'time': '2025-02-06 21:55:19'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.41869, 'time': '2025-02-06 21:55:19'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.37248, 'time': '2025-02-06 21:55:59'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.44125, 'time': '2025-02-06 21:56:00'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天临关5374线跨越明巢高速公路迁改工程多少项工作计划'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '今天', 'work': '500kV临昭5373/临关5374线跨越明巢高速公路迁改工程(PROJ-2021-0013)'}}, 'exe_time': 0.30526, 'time': '2025-02-06 21:57:13'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '送变电一公司正在施工的作业内容是什么'}, {'role': 'assistant', 'content': '请问您想查询哪天的周计划作业内容'}, {'role': 'user', 'content': '本周'}], 'result': {'finished': 1, 'int': 7, 'slot': {'company': '送变电一公司', 'work_status': '正在施工', 'date': '本周'}}, 'exe_time': 0.43103, 'time': '2025-02-06 21:57:14'}
@@ -1,14 +0,0 @@
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '110kv线路工程一般档,弧垂允许的最大偏差值'}], 'result': {'finished': 1, 'int': 10, 'slot': {}}, 'exe_time': 0.31451, 'time': '2025-02-07 15:05:33'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '打开风险管控页面'}], 'result': {'finished': 0, 'text': '请问您想查询哪个页面?'}, 'exe_time': 0.3189, 'time': '2025-02-07 15:06:14'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '打开日计划管理'}], 'result': {'finished': 1, 'int': 3, 'slot': {'module': '日计划管理'}}, 'exe_time': 0.12494, 'time': '2025-02-07 15:06:32'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程有多少项作业计划'}], 'result': {'finished': 1, 'int': 4, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.76934, 'time': '2025-02-07 15:07:30'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '本周1号项目部多少项一级风险作业计划正在施工'}], 'result': {'finished': 1, 'int': 5, 'slot': {'date': '本周', 'project': '1号项目部', 'risk_level': '一级风险', 'work_status': '正在施工'}}, 'exe_time': 0.28309, 'time': '2025-02-07 15:07:43'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天天气怎么样'}], 'result': {'finished': 0, 'text': '请问您想查询哪里的天气情况?'}, 'exe_time': 0.14188, 'time': '2025-02-07 15:08:28'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '今天', 'area': '合肥'}}, 'exe_time': 0.09992, 'time': '2025-02-07 15:08:38'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '明天合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '明天', 'area': '合肥'}}, 'exe_time': 0.09429, 'time': '2025-02-07 15:08:50'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '明日合肥天气怎么样'}], 'result': {'finished': 1, 'int': 1, 'slot': {'date': '明日', 'area': '合肥'}}, 'exe_time': 0.09947, 'time': '2025-02-07 15:08:56'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.31921, 'time': '2025-02-07 15:09:26'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.28888, 'time': '2025-02-07 15:31:01'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.33831, 'time': '2025-02-07 15:54:12'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.52175, 'time': '2025-02-07 15:54:52'}
{'fun_name': 'agent', 'params': [{'role': 'user', 'content': '今天1号工程的作业内容是什么'}], 'result': {'finished': 1, 'int': 6, 'slot': {'date': '今天', 'work': '1号工程'}}, 'exe_time': 0.33173, 'time': '2025-02-07 15:58:55'}
models.py (97 lines deleted)
@@ -1,97 +0,0 @@
from typing import List


def return_extra(scores, documents, top_k=None):
    new_docs = []
    for idx, score in enumerate(scores):
        # new_docs.append({"index": idx, "text": documents[idx], "score": 1 / (1 + np.exp(-score))})
        new_docs.append({"index": idx, "text": documents[idx], "score": score})
    results = [{"document": {"text": doc["text"], "index": doc["index"], "relevance_score": doc["score"]}}
               for doc in list(sorted(new_docs, key=lambda x: x["score"], reverse=True))]
    if top_k:
        return {"results": results[:top_k]}
    else:
        return {"results": results}


class Singleton(type):
    def __call__(cls, *args, **kwargs):
        # Check the same attribute that is set below, so repeated calls
        # actually return the one cached instance.
        if not hasattr(cls, '_instance'):
            cls._instance = super().__call__(*args, **kwargs)
        return cls._instance


class Embedding(metaclass=Singleton):
    def __init__(self, emb):
        self.embedding = emb

    def compute_similarity(self, sentences1: List[str], sentences2: List[str]):
        if len(sentences1) > 0 and len(sentences2) > 0:
            embeddings1 = self.encode(sentences1)
            embeddings2 = self.encode(sentences2)
            similarity = embeddings1 @ embeddings2.T
            return similarity
        else:
            return None

    def encode(self, sentences):
        return self.embedding.encode(sentences)

    def get_similarity(self, em1, em2):
        return em1 @ em2.T


class M3E_EMB(Embedding):
    def __init__(self, model_path):
        from sentence_transformers import SentenceTransformer
        self.embedding = SentenceTransformer(model_path)
        super(M3E_EMB, self).__init__(self.embedding)


class KNOW_PASS():
    def __init__(self, embedding: Embedding):
        self.embedding = embedding
        self.knowbase = {}  # in-memory knowledge bases, keyed by know_id
        self.knownames = []

    def load_know(self, know_id: str, contents: List[str], drop_dup=True, is_cover=False):
        if self.knowbase.__contains__(know_id):
            cur = self.knowbase[know_id]
            if cur['contents'] == contents:
                print(f"已经存在{know_id}知识库,包含知识条数:{len(cur['contents'])}")
                return
            if is_cover:
                print(f"清空{know_id}知识库")
                self.knowbase[know_id] = {}
            else:
                contents.extend(cur['contents'])
        else:
            self.knownames.append(know_id)
            # Evict the 100 oldest knowledge bases once more than 500 exist.
            if len(self.knownames) > 500:
                for knowname in self.knownames[:100]:
                    del self.knowbase[knowname]
                self.knownames = self.knownames[100:]
            self.knowbase[know_id] = {}
        if drop_dup:
            contents = list(set(contents))
        em = self.embedding.encode(contents)
        self.knowbase[know_id]['contents'] = contents
        self.knowbase[know_id]['embedding'] = em
        print(f"已更新{know_id}知识库,现在包含知识条数:{len(contents)}")
        return know_id

    def get_similarity_pair(self, sentences1: List[str], sentences2: List[str]):
        similarity = self.embedding.compute_similarity(sentences1, sentences2)
        # Check for None (empty input) before converting to a list.
        if similarity is None:
            return None
        return {"results": similarity.tolist()}

    def get_similarity_know(self, query: str, know_id: str, top_k: int = 10):
        if not self.knowbase.__contains__(know_id):
            print(f"当前知识库中不包含{know_id},当前知识库中包括:{self.knowbase.keys()},请确定知识库名称是否正确或者创建知识库")
            return None
        em = self.embedding.encode([query])
        similarity = self.embedding.get_similarity(em, self.knowbase[know_id]['embedding'])
        if similarity is None:
            return None
        similarity = similarity.tolist()
        return return_extra(similarity[0], self.knowbase[know_id]['contents'], top_k)
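A self-contained sketch of the KNOW_PASS flow above, with a stub encoder standing in for the m3e SentenceTransformer (the stub class and its vectors are invented for illustration):

import numpy as np

class StubEncoder:
    # Stand-in for SentenceTransformer: deterministic 2-d unit vectors,
    # so the example runs without downloading model weights.
    def encode(self, sentences):
        rng = np.random.default_rng(0)
        v = rng.normal(size=(len(sentences), 2))
        return v / np.linalg.norm(v, axis=1, keepdims=True)

kp = KNOW_PASS(Embedding(StubEncoder()))
kp.load_know("work_name", ["工程A", "工程B"], drop_dup=True, is_cover=True)
print(kp.get_similarity_know("工程A", "work_name", top_k=1))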
post_api.py (40 lines deleted)
@@ -1,40 +0,0 @@
import requests


class Similar():
    def __init__(self):
        self.url = 'http://127.0.0.1:6006'

    def post(self, url, json_p):
        headers = {
            "Authorization": "Bearer ACCESS_TOKEN"
        }
        response = requests.post(url, json=json_p, headers=headers)
        return response.text

    def similar(self, texts1, texts2):
        url = self.url + '/v1/embedding'
        json_p = {
            "sentences1": texts1,
            "sentences2": texts2,
        }
        return self.post(url, json_p)

    def load_know(self, know_id, contents, drop_dup, is_cover=False):
        url = self.url + '/v1/load_know'
        json_p = {
            "know_id": know_id,
            "contents": contents,
            "drop_dup": drop_dup,
            "is_cover": is_cover,
        }
        return self.post(url, json_p)

    def know_sim(self, query, know_id, top_k=10):
        url = self.url + '/v1/know_sim'
        json_p = {
            "query": query,
            "know_id": know_id,
            "top_k": top_k
        }
        return self.post(url, json_p)
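Usage of the wrapper above, mirroring how the agent service calls it (a sketch; it assumes the embedding service from embedding.py is listening on 127.0.0.1:6006, and the contents are illustrative):

# Load a knowledge base and query it through the Similar client.
sim = Similar()
sim.load_know("work_name", ["工程A", "工程B"], drop_dup=True, is_cover=True)
print(sim.know_sim("工程A", "work_name", top_k=1))  # JSON text from the service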
run.sh (7 lines deleted)
@@ -1,7 +0,0 @@
curl -X POST http://127.0.0.1:8072/agent \
  -H "Content-Type: application/json" \
  -d '{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","messages":[{"role":"user","content":"今天临关5374线跨越明巢高速公路迁改工程多少项工作计划"}]}'

curl -X POST http://192.168.0.21:8072/agent \
  -H "Content-Type: application/json" \
  -d '{"user_id":"3bb66776-1722-4c36-b14a-73dd210fe750","messages":[{"role":"user","content":"送变电一公司正在施工的作业内容是什么"},{"role":"assistant","content":"请问您想查询哪天的周计划作业内容"},{"role":"user","content":"本周"}]}'
143
uie/evaluate.py
143
uie/evaluate.py
|
|
@ -1,143 +0,0 @@
|
|||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from functools import partial

import paddle
from utils import convert_example, create_data_loader, reader

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import MapDataset, load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer
from paddlenlp.utils.ie_utils import get_relation_type_dict, unify_prompt_name
from paddlenlp.utils.log import logger


@paddle.no_grad()
def evaluate(model, metric, data_loader, multilingual=False):
    """
    Given a dataset, evaluate the model and compute the metric.
    Args:
        model(obj:`paddle.nn.Layer`): A model to classify texts.
        metric(obj:`paddle.metric.Metric`): The evaluation metric.
        data_loader(obj:`paddle.io.DataLoader`): The dataset loader which generates batches.
        multilingual(bool): Whether the model is a multilingual model.
    """
    model.eval()
    metric.reset()
    for batch in data_loader:
        if multilingual:
            start_prob, end_prob = model(batch["input_ids"], batch["position_ids"])
        else:
            start_prob, end_prob = model(
                batch["input_ids"], batch["token_type_ids"], batch["position_ids"], batch["attention_mask"]
            )

        start_ids = paddle.cast(batch["start_positions"], "float32")
        end_ids = paddle.cast(batch["end_positions"], "float32")
        num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
        metric.update(num_correct, num_infer, num_label)
    precision, recall, f1 = metric.accumulate()
    model.train()
    return precision, recall, f1


def do_eval():
    paddle.set_device(args.device)

    if args.model_path in ["uie-m-base", "uie-m-large"]:
        args.multilingual = True
    tokenizer = AutoTokenizer.from_pretrained(args.model_path)
    if args.multilingual:
        model = UIEM.from_pretrained(args.model_path)
    else:
        model = UIE.from_pretrained(args.model_path)

    test_ds = load_dataset(reader, data_path=args.test_path, max_seq_len=args.max_seq_len, lazy=False)
    class_dict = {}
    relation_data = []
    if args.debug:
        for data in test_ds:
            class_name = unify_prompt_name(data["prompt"])
            # Only positive examples are evaluated in debug mode
            if len(data["result_list"]) != 0:
                p = "的" if args.schema_lang == "ch" else " of "
                if p not in data["prompt"]:
                    class_dict.setdefault(class_name, []).append(data)
                else:
                    relation_data.append((data["prompt"], data))

        relation_type_dict = get_relation_type_dict(relation_data, schema_lang=args.schema_lang)
    else:
        class_dict["all_classes"] = test_ds

    trans_fn = partial(
        convert_example, tokenizer=tokenizer, max_seq_len=args.max_seq_len, multilingual=args.multilingual
    )

    for key in class_dict.keys():
        if args.debug:
            test_ds = MapDataset(class_dict[key])
        else:
            test_ds = class_dict[key]
        test_ds = test_ds.map(trans_fn)

        data_collator = DataCollatorWithPadding(tokenizer)

        test_data_loader = create_data_loader(test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator)

        metric = SpanEvaluator()
        precision, recall, f1 = evaluate(model, metric, test_data_loader, args.multilingual)
        logger.info("-----------------------------")
        logger.info("Class Name: %s" % key)
        logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))

    if args.debug and len(relation_type_dict.keys()) != 0:
        for key in relation_type_dict.keys():
            test_ds = MapDataset(relation_type_dict[key])
            test_ds = test_ds.map(trans_fn)
            test_data_loader = create_data_loader(
                test_ds, mode="test", batch_size=args.batch_size, trans_fn=data_collator
            )

            metric = SpanEvaluator()
            # Pass the multilingual flag here as well, so UIEM checkpoints
            # are not fed token_type_ids they were never given.
            precision, recall, f1 = evaluate(model, metric, test_data_loader, args.multilingual)
            logger.info("-----------------------------")
            if args.schema_lang == "ch":
                logger.info("Class Name: X的%s" % key)
            else:
                logger.info("Class Name: %s of X" % key)
            logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" % (precision, recall, f1))


if __name__ == "__main__":
    # yapf: disable
    parser = argparse.ArgumentParser()

    parser.add_argument("--model_path", type=str, default=None, help="The path of the saved model that you want to load.")
    parser.add_argument("--test_path", type=str, default=None, help="The path of the test set.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument("--device", type=str, default="gpu", choices=["gpu", "cpu", "npu"], help="Device selected for evaluation.")
    parser.add_argument("--max_seq_len", type=int, default=512, help="The maximum total input sequence length after tokenization.")
    parser.add_argument("--debug", action='store_true', help="Precision, recall and F1 scores are calculated for each class separately if this option is enabled.")
    parser.add_argument("--multilingual", action='store_true', help="Whether the model is a multilingual model.")
    parser.add_argument("--schema_lang", choices=["ch", "en"], default="ch", help="Select the language type for the schema.")

    args = parser.parse_args()
    # yapf: enable

    do_eval()
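For reference, SpanEvaluator scores exact start/end span matches. A minimal standalone sketch of that idea (plain Python with hypothetical helper names, not the paddlenlp implementation): spans are decoded by thresholding the start/end probabilities and pairing each start with the nearest valid end, then precision/recall/F1 are computed over exact matches.

def decode_spans(start_prob, end_prob, threshold=0.5):
    # Indices whose probability clears the threshold become span boundaries.
    starts = [i for i, p in enumerate(start_prob) if p > threshold]
    ends = [i for i, p in enumerate(end_prob) if p > threshold]
    spans = set()
    for s in starts:
        candidates = [e for e in ends if e >= s]
        if candidates:
            spans.add((s, min(candidates)))  # pair start with the nearest end
    return spans

def span_f1(pred_spans, gold_spans):
    num_correct = len(pred_spans & gold_spans)
    precision = num_correct / len(pred_spans) if pred_spans else 0.0
    recall = num_correct / len(gold_spans) if gold_spans else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# Example: the single predicted span matches the gold span exactly.
pred = decode_spans([0.1, 0.9, 0.2], [0.1, 0.2, 0.8])  # {(1, 2)}
print(span_f1(pred, {(1, 2)}))                          # (1.0, 1.0, 1.0)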
244
uie/finetune.py
@@ -1,244 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import os
from dataclasses import dataclass, field
from functools import partial
from typing import List, Optional

import paddle
from utils import convert_example, reader

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.datasets import load_dataset
from paddlenlp.metrics import SpanEvaluator
from paddlenlp.trainer import (
    CompressionArguments,
    PdArgumentParser,
    Trainer,
    get_last_checkpoint,
)
from paddlenlp.transformers import UIE, UIEM, AutoTokenizer, export_model
from paddlenlp.utils.ie_utils import compute_metrics, uie_loss_func
from paddlenlp.utils.log import logger


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.
    Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
    specify them on the command line.
    """

    train_path: str = field(
        default=None, metadata={"help": "The path of the training set."}
    )

    dev_path: str = field(
        default=None, metadata={"help": "The path of the development set."}
    )

    max_seq_length: Optional[int] = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )

    dynamic_max_length: Optional[List[int]] = field(
        default=None,
        metadata={"help": "Dynamic max length from batch; it can be an array of lengths, e.g.: 16 32 64 128."},
    )


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    model_name_or_path: Optional[str] = field(
        default="uie-base",
        metadata={
            "help": "Path to pretrained model, such as 'uie-base', 'uie-tiny', "
            "'uie-medium', 'uie-mini', 'uie-micro', 'uie-nano', 'uie-base-en', "
            "'uie-m-base', 'uie-m-large', or a finetuned model path."
        },
    )
    export_model_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the directory to store the exported inference model."},
    )
    multilingual: bool = field(default=False, metadata={"help": "Whether the model is a multilingual model."})


def main():
    parser = PdArgumentParser((ModelArguments, DataArguments, CompressionArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    training_args.label_names = ["start_positions", "end_positions"]

    if model_args.model_name_or_path in ["uie-m-base", "uie-m-large"]:
        model_args.multilingual = True
    elif os.path.exists(os.path.join(model_args.model_name_or_path, "model_config.json")):
        with open(os.path.join(model_args.model_name_or_path, "model_config.json")) as f:
            init_class = json.load(f)["init_class"]
        if init_class == "UIEM":
            model_args.multilingual = True

    # Log model and data config
    training_args.print_config(model_args, "Model")
    training_args.print_config(data_args, "Data")

    paddle.set_device(training_args.device)

    # Log a small summary on each process:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, world_size: {training_args.world_size}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )

    # Detect the last checkpoint.
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    if model_args.multilingual:
        model = UIEM.from_pretrained(model_args.model_name_or_path)
    else:
        model = UIE.from_pretrained(model_args.model_name_or_path)

    train_ds = load_dataset(reader, data_path=data_args.train_path, max_seq_len=data_args.max_seq_length, lazy=False)
    dev_ds = load_dataset(reader, data_path=data_args.dev_path, max_seq_len=data_args.max_seq_length, lazy=False)

    trans_fn = partial(
        convert_example,
        tokenizer=tokenizer,
        max_seq_len=data_args.max_seq_length,
        multilingual=model_args.multilingual,
        dynamic_max_length=data_args.dynamic_max_length,
    )

    train_ds = train_ds.map(trans_fn)
    dev_ds = dev_ds.map(trans_fn)

    if training_args.device == "npu":
        data_collator = DataCollatorWithPadding(tokenizer, padding="longest")
    else:
        data_collator = DataCollatorWithPadding(tokenizer)

    trainer = Trainer(
        model=model,
        criterion=uie_loss_func,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_ds if training_args.do_train or training_args.do_compress else None,
        eval_dataset=dev_ds if training_args.do_eval or training_args.do_compress else None,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    trainer.optimizer = paddle.optimizer.AdamW(
        learning_rate=training_args.learning_rate, parameters=model.parameters()
    )
    checkpoint = None
    if training_args.resume_from_checkpoint is not None:
        checkpoint = training_args.resume_from_checkpoint
    elif last_checkpoint is not None:
        checkpoint = last_checkpoint

    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=checkpoint)
        metrics = train_result.metrics
        trainer.save_model()
        trainer.log_metrics("train", metrics)
        trainer.save_metrics("train", metrics)
        trainer.save_state()

    # Evaluate and test the model
    if training_args.do_eval:
        eval_metrics = trainer.evaluate()
        trainer.log_metrics("eval", eval_metrics)

    # Export the inference model
    if training_args.do_export:
        # You can also load from a certain checkpoint:
        # trainer.load_state_dict_from_checkpoint("/path/to/checkpoint/")
        if training_args.device == "npu":
            # npu will transform int64 to int32 for internal calculation.
            # To reduce useless transformation, we feed int32 inputs.
            input_spec_dtype = "int32"
        else:
            input_spec_dtype = "int64"
        if model_args.multilingual:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
            ]
        else:
            input_spec = [
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="input_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="token_type_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="position_ids"),
                paddle.static.InputSpec(shape=[None, None], dtype=input_spec_dtype, name="attention_mask"),
            ]
        if model_args.export_model_dir is None:
            model_args.export_model_dir = os.path.join(training_args.output_dir, "export")
        export_model(model=trainer.model, input_spec=input_spec, path=model_args.export_model_dir)
    if training_args.do_compress:

        @paddle.no_grad()
        def custom_evaluate(self, model, data_loader):
            metric = SpanEvaluator()
            model.eval()
            metric.reset()
            for batch in data_loader:
                if model_args.multilingual:
                    logits = model(input_ids=batch["input_ids"], position_ids=batch["position_ids"])
                else:
                    logits = model(
                        input_ids=batch["input_ids"],
                        token_type_ids=batch["token_type_ids"],
                        position_ids=batch["position_ids"],
                        attention_mask=batch["attention_mask"],
                    )
                start_prob, end_prob = logits
                start_ids, end_ids = batch["start_positions"], batch["end_positions"]
                num_correct, num_infer, num_label = metric.compute(start_prob, end_prob, start_ids, end_ids)
                metric.update(num_correct, num_infer, num_label)
            precision, recall, f1 = metric.accumulate()
            # Log recall in the recall slot (the original logged f1 twice).
            logger.info("f1: %s, precision: %s, recall: %s" % (f1, precision, recall))
            model.train()
            return f1

        trainer.compress(custom_evaluate=custom_evaluate)


if __name__ == "__main__":
    main()
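A note on the training objective: uie_loss_func, passed above as the Trainer criterion, comes from paddlenlp.utils.ie_utils. Conceptually it is binary cross-entropy over the per-token start/end probabilities; a minimal sketch of that idea, assuming (start_prob, end_prob) outputs and 0/1 labels of shape [batch, seq_len] (not the exact library code):

import paddle
import paddle.nn.functional as F

def uie_loss_sketch(outputs, labels):
    start_prob, end_prob = outputs  # probabilities in [0, 1], shape [batch, seq_len]
    start_ids, end_ids = labels     # 0/1 targets of the same shape
    start_ids = paddle.cast(start_ids, "float32")
    end_ids = paddle.cast(end_ids, "float32")
    # Average the BCE of the start pointer and the end pointer.
    loss_start = F.binary_cross_entropy(start_prob, start_ids)
    loss_end = F.binary_cross_entropy(end_prob, end_ids)
    return (loss_start + loss_end) / 2.0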
@@ -1,8 +0,0 @@
python evaluate.py \
    --model_path ./checkpoint \
    --test_path ./data/data.txt \
    --batch_size 16 \
    --device gpu \
    --debug \
    --max_seq_len 256 \
    --schema_lang ch
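The file passed as --test_path is expected to be JSON lines in the shape consumed by reader() in utils.py: each line carries content, prompt, and a result_list of character-offset spans. An illustrative, made-up line follows; reader() relies only on the start/end offsets, and the "text" field is shown for readability:

{"content": "北京市海淀区人民法院", "prompt": "法院名称", "result_list": [{"text": "北京市海淀区人民法院", "start": 0, "end": 10}]}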
23
uie/train.sh
@@ -1,23 +0,0 @@
python finetune.py \
    --device gpu \
    --logging_steps 10 \
    --eval_steps 2000 \
    --save_steps 2000 \
    --seed 1000 \
    --model_name_or_path /home/workplace/models/uie-base \
    --output_dir ./checkpoint \
    --train_path ./data/train.txt \
    --dev_path ./data/dev.txt \
    --max_seq_len 256 \
    --per_device_train_batch_size 4 \
    --per_device_eval_batch_size 8 \
    --num_train_epochs 10 \
    --learning_rate 1e-5 \
    --do_train \
    --do_eval \
    --do_export \
    --overwrite_output_dir \
    --disable_tqdm True \
    --metric_for_best_model eval_f1 \
    --load_best_model_at_end True \
    --save_total_limit 1
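For multi-GPU fine-tuning, the same script can be launched through Paddle's distributed launcher; a hedged sketch (the GPU ids and the abbreviated flag list are placeholders, carry over the full flag set above):

# Hypothetical multi-GPU variant: same flags as train.sh above.
python -m paddle.distributed.launch --gpus "0,1" finetune.py \
    --device gpu \
    --model_name_or_path /home/workplace/models/uie-base \
    --output_dir ./checkpoint \
    --train_path ./data/train.txt \
    --dev_path ./data/dev.txt \
    --do_train --do_eval --do_export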
235
uie/utils.py
@@ -1,235 +0,0 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import random
from typing import List, Optional

import numpy as np
import paddle

from paddlenlp.utils.log import logger


def set_seed(seed):
    paddle.seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def create_data_loader(dataset, mode="train", batch_size=1, trans_fn=None):
    """
    Create a dataloader.
    Args:
        dataset(obj:`paddle.io.Dataset`): Dataset instance.
        mode(obj:`str`, optional, defaults to obj:`train`): If mode is 'train', it will shuffle the dataset randomly.
        batch_size(obj:`int`, optional, defaults to 1): The sample number of a mini-batch.
        trans_fn(obj:`callable`, optional, defaults to `None`): function to convert a data sample to input ids, etc.
    Returns:
        dataloader(obj:`paddle.io.DataLoader`): The dataloader which generates batches.
    """
    if trans_fn:
        dataset = dataset.map(trans_fn)

    shuffle = True if mode == "train" else False
    if mode == "train":
        sampler = paddle.io.DistributedBatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    else:
        sampler = paddle.io.BatchSampler(dataset=dataset, batch_size=batch_size, shuffle=shuffle)
    dataloader = paddle.io.DataLoader(dataset, batch_sampler=sampler, return_list=True)
    return dataloader


def map_offset(ori_offset, offset_mapping):
    """
    Map an original character offset to its token offset.
    """
    for index, span in enumerate(offset_mapping):
        if span[0] <= ori_offset < span[1]:
            return index
    return -1


def reader(data_path, max_seq_len=512):
    """
    Read a JSON-lines file, splitting overlong samples into chunks that fit max_seq_len.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            json_line = json.loads(line)
            content = json_line["content"].strip()
            prompt = json_line["prompt"]
            # The model input looks like: [CLS] Prompt [SEP] Content [SEP],
            # so it includes three special tokens.
            if max_seq_len <= len(prompt) + 3:
                raise ValueError("The value of max_seq_len is too small, please set a larger value")
            max_content_len = max_seq_len - len(prompt) - 3
            if len(content) <= max_content_len:
                yield json_line
            else:
                result_list = json_line["result_list"]
                json_lines = []
                accumulate = 0
                while True:
                    cur_result_list = []
                    for result in result_list:
                        if result["end"] - result["start"] > max_content_len:
                            logger.warning(
                                "result['end'] - result['start'] exceeds max_content_len, which will result in no valid instance being returned"
                            )
                        if (
                            result["start"] + 1 <= max_content_len < result["end"]
                            and result["end"] - result["start"] <= max_content_len
                        ):
                            max_content_len = result["start"]
                            break

                    cur_content = content[:max_content_len]
                    res_content = content[max_content_len:]

                    while True:
                        if len(result_list) == 0:
                            break
                        elif result_list[0]["end"] <= max_content_len:
                            if result_list[0]["end"] > 0:
                                cur_result = result_list.pop(0)
                                cur_result_list.append(cur_result)
                            else:
                                cur_result_list = [result for result in result_list]
                                break
                        else:
                            break

                    json_line = {"content": cur_content, "result_list": cur_result_list, "prompt": prompt}
                    json_lines.append(json_line)

                    for result in result_list:
                        if result["end"] <= 0:
                            break
                        result["start"] -= max_content_len
                        result["end"] -= max_content_len
                    accumulate += max_content_len
                    max_content_len = max_seq_len - len(prompt) - 3
                    if len(res_content) == 0:
                        break
                    elif len(res_content) < max_content_len:
                        json_line = {"content": res_content, "result_list": result_list, "prompt": prompt}
                        json_lines.append(json_line)
                        break
                    else:
                        content = res_content

                for json_line in json_lines:
                    yield json_line


def get_dynamic_max_length(examples, default_max_length: int, dynamic_max_length: List[int]) -> int:
    """Get the max_length from the examples in a batch; it can vary per batch."""
    cur_length = len(examples[0]["input_ids"])
    max_length = default_max_length
    for max_length_option in sorted(dynamic_max_length):
        if cur_length <= max_length_option:
            max_length = max_length_option
            break
    return max_length


def convert_example(
    example, tokenizer, max_seq_len, multilingual=False, dynamic_max_length: Optional[List[int]] = None
):
    """
    example: {
        title
        prompt
        content
        result_list
    }
    """
    if dynamic_max_length is not None:
        temp_encoded_inputs = tokenizer(
            text=[example["prompt"]],
            text_pair=[example["content"]],
            truncation=True,
            max_seq_len=max_seq_len,
            return_attention_mask=True,
            return_position_ids=True,
            return_dict=False,
            return_offsets_mapping=True,
        )
        max_length = get_dynamic_max_length(
            examples=temp_encoded_inputs, default_max_length=max_seq_len, dynamic_max_length=dynamic_max_length
        )
        # Always pad to max_length
        encoded_inputs = tokenizer(
            text=[example["prompt"]],
            text_pair=[example["content"]],
            truncation=True,
            max_seq_len=max_length,
            pad_to_max_seq_len=True,
            return_attention_mask=True,
            return_position_ids=True,
            return_dict=False,
            return_offsets_mapping=True,
        )
        start_ids = [0.0 for x in range(max_length)]
        end_ids = [0.0 for x in range(max_length)]
    else:
        encoded_inputs = tokenizer(
            text=[example["prompt"]],
            text_pair=[example["content"]],
            truncation=True,
            max_seq_len=max_seq_len,
            pad_to_max_seq_len=True,
            return_attention_mask=True,
            return_position_ids=True,
            return_dict=False,
            return_offsets_mapping=True,
        )
        start_ids = [0.0 for x in range(max_seq_len)]
        end_ids = [0.0 for x in range(max_seq_len)]

    encoded_inputs = encoded_inputs[0]
    offset_mapping = [list(x) for x in encoded_inputs["offset_mapping"]]
    bias = 0
    for index in range(1, len(offset_mapping)):
        mapping = offset_mapping[index]
        if mapping[0] == 0 and mapping[1] == 0 and bias == 0:
            bias = offset_mapping[index - 1][1] + 1  # Includes [SEP] token
        if mapping[0] == 0 and mapping[1] == 0:
            continue
        offset_mapping[index][0] += bias
        offset_mapping[index][1] += bias
    for item in example["result_list"]:
        start = map_offset(item["start"] + bias, offset_mapping)
        end = map_offset(item["end"] - 1 + bias, offset_mapping)
        start_ids[start] = 1.0
        end_ids[end] = 1.0
    if multilingual:
        tokenized_output = {
            "input_ids": encoded_inputs["input_ids"],
            "position_ids": encoded_inputs["position_ids"],
            "start_positions": start_ids,
            "end_positions": end_ids,
        }
    else:
        tokenized_output = {
            "input_ids": encoded_inputs["input_ids"],
            "token_type_ids": encoded_inputs["token_type_ids"],
            "position_ids": encoded_inputs["position_ids"],
            "attention_mask": encoded_inputs["attention_mask"],
            "start_positions": start_ids,
            "end_positions": end_ids,
        }
    return tokenized_output
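As a quick sanity check of map_offset above, with made-up offsets: token 0 is a [CLS] placeholder mapped to (0, 0), and tokens 1-2 cover characters 0-2 and 2-5 of the text.

offset_mapping = [(0, 0), (0, 2), (2, 5)]
assert map_offset(1, offset_mapping) == 1   # char 1 falls in span (0, 2)
assert map_offset(4, offset_mapping) == 2   # char 4 falls in span (2, 5)
assert map_offset(9, offset_mapping) == -1  # past the last token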