49 lines
1.3 KiB
Python
49 lines
1.3 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
|
|
|
|
import torch.nn as nn
|
|
|
|
from .yolo_head import YOLOXHead
|
|
from .yolo_pafpn import YOLOPAFPN
|
|
|
|
|
|
class YOLOX(nn.Module):
|
|
"""
|
|
YOLOX model module. The module list is defined by create_yolov3_modules function.
|
|
The network returns loss values from three YOLO layers during training
|
|
and detection results during test.
|
|
"""
|
|
|
|
def __init__(self, backbone=None, head=None):
|
|
super().__init__()
|
|
if backbone is None:
|
|
backbone = YOLOPAFPN()
|
|
if head is None:
|
|
head = YOLOXHead(80)
|
|
|
|
self.backbone = backbone
|
|
self.head = head
|
|
|
|
def forward(self, x, targets=None):
|
|
# fpn output content features of [dark3, dark4, dark5]
|
|
fpn_outs = self.backbone(x)
|
|
|
|
if self.training:
|
|
assert targets is not None
|
|
loss, iou_loss, conf_loss, cls_loss, l1_loss, num_fg = self.head(
|
|
fpn_outs, targets, x
|
|
)
|
|
outputs = {
|
|
"total_loss": loss,
|
|
"iou_loss": iou_loss,
|
|
"l1_loss": l1_loss,
|
|
"conf_loss": conf_loss,
|
|
"cls_loss": cls_loss,
|
|
"num_fg": num_fg,
|
|
}
|
|
else:
|
|
outputs = self.head(fpn_outs)
|
|
|
|
return outputs
|