109 lines
3.3 KiB
Python
109 lines
3.3 KiB
Python
import json
import os

import paddle
from flask import Flask, jsonify, request
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from werkzeug.exceptions import HTTPException
|
||
|
||
# 1. Load the fine-tuned model and its tokenizer.
# The checkpoint directory can be overridden with the MODEL_PATH environment
# variable; the default is the original local training checkpoint.
model_path = os.environ.get(
    "MODEL_PATH",
    r"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320",
)
model = ErnieForTokenClassification.from_pretrained(model_path)
tokenizer = ErnieTokenizer.from_pretrained(model_path)
# This service only runs inference, so put the model in eval mode once at load.
model.eval()
|
||
|
||
# Label-id -> BIO tag mapping used to decode the model's per-token predictions.
def _build_label_map(entity_types):
    """Return ``{id: tag}`` with ``0 -> 'O'`` and, for the k-th type (1-based),
    ``k -> 'B-<type>'`` and ``k + len(entity_types) -> 'I-<type>'``.

    Ids are inserted in the order 0, 1, 1+n, 2, 2+n, ... (n = number of types).
    """
    mapping = {0: 'O'}
    inside_offset = len(entity_types)
    for begin_id, entity in enumerate(entity_types, start=1):
        mapping[begin_id] = f'B-{entity}'
        mapping[begin_id + inside_offset] = f'I-{entity}'
    return mapping


label_map = _build_label_map((
    'date',
    'project_name',
    'project_type',
    'construction_unit',
    'implementation_organization',
    'project_department',
    'project_manager',
    'subcontractor',
    'team_leader',
    'risk_level',
))
|
||
|
||
# The Flask application object; the error handler and routes register on it.
app = Flask(import_name=__name__)
|
||
|
||
|
||
# Single JSON error handler for every exception raised while serving a request.
@app.errorhandler(Exception)
def handle_exception(e):
    """Render any exception raised during a request as a JSON error response.

    HTTP errors (``werkzeug.exceptions.HTTPException``: 404, 405, 400, ...)
    keep their status code and headers and get a body of
    ``{"error": {"type", "message", "status_code"}}``. Any other exception is
    logged with its traceback and reported as a 500 with
    ``{"error": {"type": "InternalServerError", "message": str(e)}}``.

    Args:
        e: The exception raised while handling the request.

    Returns:
        A Flask response, or a ``(response, 500)`` tuple for non-HTTP errors.
    """
    if isinstance(e, HTTPException):
        # Start from Werkzeug's own response so headers such as ``Allow`` (405)
        # or ``Location`` (redirects) are preserved; only the body becomes JSON.
        response = e.get_response()
        response.data = json.dumps({
            "error": {
                "type": e.name,
                "message": e.description,
                "status_code": e.code,
            }
        })
        response.content_type = "application/json"
        return response

    # Not an HTTP error: record the traceback, which would otherwise be lost.
    app.logger.exception("Unhandled exception while processing request")
    # NOTE(review): str(e) is sent to the client and may expose internal
    # details; consider a generic message for production deployments.
    return jsonify({
        "error": {
            "type": "InternalServerError",
            "message": str(e),
        }
    }), 500
|
||
|
||
|
||
@app.route('/')
def hello_world():
    """Example route: reply with a fixed JSON greeting."""
    return jsonify(message="Hello, world!")
|
||
|
||
|
||
def _decode_entities(tokens, label_ids, labels, skip_tokens=frozenset()):
    """Merge BIO-tagged tokens into a ``{entity_type: text}`` dict.

    A ``B-<type>`` tag starts an entity and following ``I-<type>`` tags of the
    same type extend it. Anything else ends the entity being built: ``O``, an
    ``I-`` tag of another type, an id missing from *labels*, or a token in
    *skip_tokens*. WordPiece continuation pieces lose their ``##`` prefix
    before joining.

    Args:
        tokens: Token strings, aligned one-to-one with *label_ids*.
        label_ids: Predicted label id for each token.
        labels: Mapping from label id to BIO tag; unknown ids count as ``'O'``.
        skip_tokens: Tokens (e.g. ``[CLS]``/``[SEP]``) never added to an entity.

    Returns:
        Dict mapping entity type (tag without ``B-``/``I-``) to its text.
        NOTE: if one type occurs several times, only the last one is kept.
    """
    entities = {}
    pieces = []
    entity_type = None

    for token, label_id in zip(tokens, label_ids):
        tag = "O" if token in skip_tokens else labels.get(label_id, "O")
        # "##xyz" is a WordPiece continuation of the previous piece.
        piece = token[2:] if token.startswith("##") and len(token) > 2 else token

        if tag.startswith("I-") and pieces and tag[2:] == entity_type:
            pieces.append(piece)  # continue the current entity
            continue

        # Any other tag closes the entity being built (if any).
        if pieces:
            entities[entity_type] = "".join(pieces)
        if tag.startswith("B-"):
            pieces = [piece]
            entity_type = tag[2:]  # strip "B-"
        else:
            pieces = []
            entity_type = None

    # Flush an entity that runs to the last token.
    if pieces:
        entities[entity_type] = "".join(pieces)
    return entities


@app.route('/predict', methods=['POST'])
def predict():
    """Run NER on the posted text and return the extracted entities as JSON.

    Expects a JSON object body such as ``{"text": "..."}``. On success it
    responds with an object mapping entity type to entity text. It responds
    with a 400 and ``{"error": ...}`` when the body is not a JSON object,
    ``text`` is missing or empty, or ``text`` is not a string.
    """
    # silent=True: a missing/invalid JSON body yields None instead of raising,
    # so it is reported as a 400 below instead of via the generic error handler.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        data = {}

    text = data.get("text", "")
    if not text:
        return jsonify({"error": "No text provided"}), 400
    if not isinstance(text, str):
        return jsonify({"error": "Field 'text' must be a string"}), 400

    # NOTE(review): confirm `max_len` is honoured by the installed PaddleNLP
    # version (newer releases use `max_length` + `truncation`); otherwise
    # inputs over 512 tokens may reach the model untruncated.
    inputs = tokenizer(text, max_len=512, return_tensors="pd")
    model.eval()

    with paddle.no_grad():
        logits = model(**inputs)
        predictions = paddle.argmax(logits, axis=-1)

    # A single text was encoded, so take row 0 of the batch.
    predicted_labels = predictions.numpy()[0]
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].numpy())

    # Structural special tokens must never become part of an entity's text.
    skip_tokens = {
        getattr(tokenizer, name, None)
        for name in ("cls_token", "sep_token", "pad_token")
    } - {None}

    entities = _decode_entities(tokens, predicted_labels, label_map, skip_tokens)
    return jsonify(entities)
|
||
|
||
|
||
if __name__ == '__main__':
    # Serve on all interfaces, port 5000. Debug mode is opt-in via FLASK_DEBUG:
    # Werkzeug's debugger allows arbitrary code execution and must not be
    # reachable on 0.0.0.0. Its reloader also re-imports this module in a
    # child process, which would load the model a second time.
    debug = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true", "yes")
    app.run(host='0.0.0.0', port=5000, debug=debug)
|