"""Flask service exposing an ERNIE token-classification (BIO) NER model.

POST /predict with a JSON object ``{"text": "..."}``; the response is a JSON
object mapping each entity type to the entity text decoded from the input.
"""
import json  # NOTE(review): not used in this module
import os

import paddle
from flask import Flask, jsonify, request
from paddlenlp.transformers import ErnieForTokenClassification, ErnieTokenizer
from werkzeug.exceptions import HTTPException

# 1. Load the model and tokenizer.
# The checkpoint directory can be overridden with the UIE_NER_MODEL_PATH
# environment variable; the default is the original local checkpoint path.
model_path = os.environ.get(
    "UIE_NER_MODEL_PATH",
    R"E:\workingSpace\PycharmProjects\Intention_dev\uie\uie_ner\checkpoint-4320",
)
model = ErnieForTokenClassification.from_pretrained(model_path)
tokenizer = ErnieTokenizer.from_pretrained(model_path)
# Inference only: switch to eval mode once at load time, not on every request.
model.eval()

# Upper bound on the tokenized length (special tokens included); longer
# inputs are truncated rather than fed to the model unbounded.
MAX_SEQ_LEN = 512

# Label id -> BIO tag. Ids 1-10 are the B- (begin) tags and 11-20 the
# matching I- (inside) tags for the same entity types; 0 is "outside".
label_map = {
    0: 'O',
    1: 'B-date', 11: 'I-date',
    2: 'B-project_name', 12: 'I-project_name',
    3: 'B-project_type', 13: 'I-project_type',
    4: 'B-construction_unit', 14: 'I-construction_unit',
    5: 'B-implementation_organization', 15: 'I-implementation_organization',
    6: 'B-project_department', 16: 'I-project_department',
    7: 'B-project_manager', 17: 'I-project_manager',
    8: 'B-subcontractor', 18: 'I-subcontractor',
    9: 'B-team_leader', 19: 'I-team_leader',
    10: 'B-risk_level', 20: 'I-risk_level',
}

app = Flask(__name__)


@app.errorhandler(Exception)
def handle_exception(e):
    """Render any exception raised while handling a request as JSON.

    ``HTTPException``s keep their own status code, name and description.
    Any other exception is logged with its traceback and reported as a 500
    ``InternalServerError`` whose message is ``str(e)``.
    """
    if isinstance(e, HTTPException):
        return jsonify({
            "error": {
                "type": e.name,
                "message": e.description,
                "status_code": e.code
            }
        }), e.code
    # A catch-all handler would otherwise swallow the traceback entirely.
    app.logger.error("Unhandled exception while handling request", exc_info=e)
    # NOTE(review): str(e) can expose internal details to clients; consider a
    # generic message once callers no longer depend on it.
    return jsonify({
        "error": {
            "type": "InternalServerError",
            "message": str(e)
        }
    }), 500


@app.route('/')
def hello_world():
    """Example route; returns a fixed Hello World JSON message."""
    return jsonify({"message": "Hello, world!"})


def _decode_entities(tokens, label_ids, special_tokens=()):
    """Merge BIO-tagged tokens into a ``{entity_type: text}`` dict.

    A ``B-<type>`` tag starts a new entity. An ``I-<type>`` tag extends the
    current entity only when ``<type>`` matches it. Anything else closes the
    current entity: ``O``, an unknown label id, a mismatched or orphan
    ``I-`` tag, or a token listed in ``special_tokens``. WordPiece
    continuation prefixes (``##``) are stripped before tokens are joined.

    Args:
        tokens: token strings, aligned one-to-one with ``label_ids``.
        label_ids: predicted label ids (plain or numpy integers).
        special_tokens: tokens (e.g. [CLS]/[SEP]) that never belong to an
            entity, whatever label the model predicted for them.

    Returns:
        dict mapping entity type (tag without ``B-``/``I-``) to entity text.
        If a type occurs more than once, the last occurrence wins.
    """
    entities = {}
    current_pieces = []
    current_label = None

    for token, label_id in zip(tokens, label_ids):
        if token in special_tokens:
            label = "O"
        else:
            # int(): ids may arrive as numpy integers from the argmax output.
            label = label_map.get(int(label_id), "O")
        piece = token[2:] if token.startswith("##") and len(token) > 2 else token

        if label.startswith("B-"):
            # Start a new entity, closing any entity in progress.
            if current_pieces:
                entities[current_label] = "".join(current_pieces)
            current_pieces = [piece]
            current_label = label[2:]
        elif label.startswith("I-") and current_pieces and label[2:] == current_label:
            # Continue the entity of the same type.
            current_pieces.append(piece)
        else:
            # Outside any entity: close the one in progress, if any.
            if current_pieces:
                entities[current_label] = "".join(current_pieces)
            current_pieces = []
            current_label = None

    # Flush an entity that runs to the end of the sequence.
    if current_pieces:
        entities[current_label] = "".join(current_pieces)
    return entities


@app.route('/predict', methods=['POST'])
def predict():
    """Run NER on the request's ``text`` and return the entities as JSON.

    Expects a JSON object body with a non-empty string field ``text``.
    Returns 400 with ``{"error": ...}`` when the body is not a JSON object,
    or when ``text`` is missing, empty or not a string.
    """
    # silent=True: malformed or non-JSON bodies yield None instead of raising,
    # so they get the same 400 response as other invalid input.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400

    text = data.get("text", "")
    if not isinstance(text, str):
        return jsonify({"error": "Field 'text' must be a string"}), 400
    if not text:
        return jsonify({"error": "No text provided"}), 400

    # Bound the sequence length explicitly (max_length + truncation).
    inputs = tokenizer(
        text, max_length=MAX_SEQ_LEN, truncation=True, return_tensors="pd"
    )
    with paddle.no_grad():
        logits = model(**inputs)
    predictions = paddle.argmax(logits, axis=-1)

    # A single text is sent, so take the first (only) row of the batch.
    predicted_labels = predictions.numpy()[0]
    input_ids = inputs["input_ids"][0].numpy().tolist()
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # [CLS]/[SEP]/[PAD] are not part of the input text and must never be
    # merged into an entity.
    special_tokens = {tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token}
    # NOTE(review): tokens come from the tokenizer, so characters mapped to
    # [UNK] (and any case normalisation it applies) show up in entity text;
    # slicing the original text via offset mappings would avoid that.
    entities = _decode_entities(tokens, predicted_labels, special_tokens)

    # Return the extracted entities as JSON.
    return jsonify(entities)


if __name__ == '__main__':
    # The server binds to all interfaces, so the Werkzeug debugger/reloader is
    # opt-in (FLASK_DEBUG=1) rather than always on; the reloader would also
    # load the model a second time in its parent process.
    debug = os.environ.get("FLASK_DEBUG", "0").lower() in ("1", "true")
    app.run(host='0.0.0.0', port=5000, debug=debug)