81 lines
2.8 KiB
Python
81 lines
2.8 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class IOUloss(nn.Module):
    """IoU-based regression loss for axis-aligned boxes given as (cx, cy, w, h).

    Args:
        reduction: "mean" or "sum" reduces the per-box losses; any other value
            (default "none") returns the unreduced per-box loss tensor.
        loss_type: "iou" computes ``1 - IoU**2``; "giou" computes
            ``1 - GIoU`` with GIoU clamped to [-1, 1].
    """

    def __init__(self, reduction="none", loss_type="iou"):
        super().__init__()
        self.reduction = reduction
        self.loss_type = loss_type

    def forward(self, pred, target):
        """Compute the IoU / GIoU loss between predicted and target boxes.

        Args:
            pred: Tensor whose elements regroup into rows of 4 values
                (cx, cy, w, h). Its leading dimension must match ``target``'s.
            target: Tensor with the same layout as ``pred``.

        Returns:
            A per-box loss tensor of shape (N,), or a scalar tensor when
            ``reduction`` is "mean" or "sum".

        Raises:
            ValueError: If ``loss_type`` is neither "iou" nor "giou".
        """
        assert pred.shape[0] == target.shape[0]

        # reshape (not view) so non-contiguous inputs are accepted as well.
        pred = pred.reshape(-1, 4)
        target = target.reshape(-1, 4)

        # Convert center/size to corner coordinates once; GIoU reuses them.
        pred_tl = pred[:, :2] - pred[:, 2:] / 2
        pred_br = pred[:, :2] + pred[:, 2:] / 2
        target_tl = target[:, :2] - target[:, 2:] / 2
        target_br = target[:, :2] + target[:, 2:] / 2

        # Intersection rectangle.
        tl = torch.max(pred_tl, target_tl)
        br = torch.min(pred_br, target_br)

        area_p = torch.prod(pred[:, 2:], 1)
        area_g = torch.prod(target[:, 2:], 1)

        # en is 1 only when the boxes overlap on both axes. It zeroes the
        # product of two negative extents, which would otherwise be positive.
        en = (tl < br).to(tl.dtype).prod(dim=1)
        area_i = torch.prod(br - tl, 1) * en
        iou = area_i / (area_p + area_g - area_i + 1e-16)

        if self.loss_type == "iou":
            loss = 1 - iou ** 2
        elif self.loss_type == "giou":
            # Smallest box enclosing both pred and target.
            c_tl = torch.min(pred_tl, target_tl)
            c_br = torch.max(pred_br, target_br)
            area_c = torch.prod(c_br - c_tl, 1)
            giou = iou - (area_c - area_i) / area_c.clamp(1e-16)
            loss = 1 - giou.clamp(min=-1.0, max=1.0)
        else:
            raise ValueError(
                f"unsupported loss_type {self.loss_type!r}; expected 'iou' or 'giou'"
            )

        if self.reduction == "mean":
            loss = loss.mean()
        elif self.reduction == "sum":
            loss = loss.sum()

        return loss
|
|
|
|
|
|
def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    Args:
        inputs: A float tensor of arbitrary shape holding the raw logits
            (before the sigmoid) for each element.
        targets: A float tensor with the same shape as inputs. Stores the binary
            classification label for each element in inputs
            (0 for the negative class and 1 for the positive class).
        num_boxes: Normalizer; the summed loss is divided by this value.
        alpha: Weighting factor in range (0, 1) to balance positive vs
            negative examples. Default = 0.25; pass a negative value to
            disable the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to balance
            easy vs hard examples. Default = 2.

    Returns:
        The element-wise focal loss summed over all elements and divided by
        ``num_boxes``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t is the predicted probability assigned to the ground-truth class.
    p_t = prob * targets + (1 - prob) * (1 - targets)
    # Down-weight well-classified elements (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    return loss.sum() / num_boxes