192 lines
7.3 KiB
Python
192 lines
7.3 KiB
Python
"""
Copyright (c) https://github.com/xingyizhou/CenterTrack
Modified by Peize Sun, Rufeng Zhang
"""
|
|
# coding: utf-8
|
|
import torch
|
|
from scipy.optimize import linear_sum_assignment
|
|
from util import box_ops
|
|
import copy
|
|
|
|
class Tracker(object):
    """Link per-frame detections into tracks with two-stage box-cost matching.

    Each frame, detections are split by score.  Those with
    ``score >= score_thresh`` are matched against every current track
    (active or lost) with ``scipy.optimize.linear_sum_assignment`` on the
    cost ``1 - box_ops.generalized_box_iou``.  Tracks that were active but
    found no partner are then matched against detections with
    ``low_thresh <= score < score_thresh``.  Unmatched high-score detections
    with ``score >= high_thresh`` open new tracks.  Unmatched tracks are kept
    with ``active = 0`` and their ``age`` incremented while ``age`` is below
    ``max_age``.

    Tracks are plain dicts with keys ``score``, ``bbox`` (a list in x1y1x2y2
    order), ``tracking_id``, ``active`` and ``age``.
    """

    def __init__(self, score_thresh, max_age=32, low_thresh=0.2, high_thresh=None,
                 match_cost_thresh=1.2, second_match_cost_thresh=0.8):
        """Configure thresholds and start with no tracks.

        Args:
            score_thresh: minimum score for a detection to enter the first
                association stage (and to start a track in ``init_track``).
            max_age: an unmatched track is kept (and aged by one) only while
                its ``age`` is below this value.
            low_thresh: minimum score for a detection to enter the second
                association stage.
            high_thresh: minimum score for an unmatched first-stage detection
                to start a new track in ``step``.  ``None`` means
                ``score_thresh + 0.1``, the previously hard-coded margin.
            match_cost_thresh: first-stage assignments whose cost exceeds this
                are rejected.
            second_match_cost_thresh: the same gate for the second stage.
        """
        self.score_thresh = score_thresh
        self.low_thresh = low_thresh
        self.high_thresh = score_thresh + 0.1 if high_thresh is None else high_thresh
        self.max_age = max_age
        self.match_cost_thresh = match_cost_thresh
        self.second_match_cost_thresh = second_match_cost_thresh
        # reset_all() creates id_count / tracks_dict / tracks / unmatched_tracks.
        self.reset_all()

    def reset_all(self):
        """Forget every track and restart tracking ids (the next id issued is 1)."""
        self.id_count = 0               # last tracking_id handed out
        self.tracks_dict = dict()       # last frame's detection index -> track dict
        self.tracks = list()            # track list produced by the last call
        self.unmatched_tracks = list()  # lost (active == 0) tracks still within max_age

    def init_track(self, results):
        """Start a new track for every detection scoring at least ``score_thresh``.

        Args:
            results: mapping with ``"scores"`` (per-detection score tensor) and
                ``"boxes"`` (tensor of x1y1x2y2 rows, one per detection).

        Returns:
            A deep copy of the newly created tracks.

        Lost tracks left over from earlier calls are discarded, so they cannot
        be re-matched after re-initialisation.  ``id_count`` is not reset; use
        ``reset_all`` for that.
        """
        scores = results["scores"]
        bboxes = results["boxes"]  # x1y1x2y2

        ret = []
        ret_dict = {}
        for idx in range(scores.shape[0]):
            if scores[idx] >= self.score_thresh:
                self.id_count += 1
                obj = {
                    "score": float(scores[idx]),
                    "bbox": bboxes[idx, :].cpu().numpy().tolist(),
                    "tracking_id": self.id_count,
                    "active": 1,
                    "age": 1,
                }
                ret.append(obj)
                # Keyed by detection index so step() can apply track_boxes[idx].
                ret_dict[idx] = obj

        self.tracks = ret
        self.tracks_dict = ret_dict
        # Without this, lost tracks from before the re-init would leak into step().
        self.unmatched_tracks = []
        return copy.deepcopy(ret)

    def step(self, output_results):
        """Advance the tracker by one frame.

        Args:
            output_results: mapping with ``"scores"`` (per-detection score
                tensor), ``"boxes"`` (tensor of x1y1x2y2 rows) and optionally
                ``"track_boxes"``.  Row ``i`` of ``"track_boxes"`` replaces the
                box of the track created from last frame's detection ``i``
                before matching.  Presumably this is that track's position in
                the current frame; confirm against the model output.

        Returns:
            A deep copy of the current tracks, in this order: first-stage
            matches, second-stage matches, new tracks, then surviving
            unmatched tracks (``active == 0``).
        """
        scores = output_results["scores"]
        bboxes = output_results["boxes"]  # x1y1x2y2
        track_bboxes = output_results["track_boxes"] if "track_boxes" in output_results else None  # x1y1x2y2

        if track_bboxes is not None:
            for idx in range(scores.shape[0]):
                if idx in self.tracks_dict:
                    self.tracks_dict[idx]["bbox"] = track_bboxes[idx, :].cpu().numpy().tolist()

        results, results_second, results_dict = self._split_by_score(scores, bboxes)
        tracks = list(self.tracks_dict.values()) + self.unmatched_tracks
        ret = []

        # Stage 1: high-score detections vs. every track, active or lost.
        matches, unmatched_dets, unmatched_tracks = self._associate(
            results, tracks, self.match_cost_thresh)
        for d, t in matches:
            ret.append(self._continue_track(results[d], tracks[t]))

        # Stage 2: low-score detections vs. tracks that were active but got no
        # stage-1 partner.  Lost tracks are not offered low-score detections.
        second_tracks = [tracks[t] for t in unmatched_tracks if tracks[t]["active"] == 1]
        matches_second, _, unmatched_second = self._associate(
            results_second, second_tracks, self.second_match_cost_thresh)
        for d, t in matches_second:
            ret.append(self._continue_track(results_second[d], second_tracks[t]))

        # Unmatched high-score detections open new tracks only above high_thresh;
        # unmatched low-score detections are dropped.
        for d in unmatched_dets:
            det = results[d]
            if det["score"] >= self.high_thresh:
                self.id_count += 1
                det["tracking_id"] = self.id_count
                det["age"] = 1
                det["active"] = 1
                ret.append(det)

        lost = []
        for t in unmatched_tracks:
            track = tracks[t]
            # Active tracks here were offered to stage 2 and are aged below.
            if track["active"] == 0 and track["age"] < self.max_age:
                track["age"] += 1
                track["active"] = 0
                ret.append(track)
                lost.append(track)
        for t in unmatched_second:
            track = second_tracks[t]
            if track["age"] < self.max_age:
                track["age"] += 1
                track["active"] = 0
                ret.append(track)
                lost.append(track)

        self.tracks = ret
        # Only detections that now carry an identity become next frame's tracks_dict.
        self.tracks_dict = {idx: det for idx, det in results_dict.items() if "tracking_id" in det}
        self.unmatched_tracks = lost
        return copy.deepcopy(ret)

    def _split_by_score(self, scores, bboxes):
        """Bucket detections into stage-1 (``>= score_thresh``) and stage-2
        (``>= low_thresh``) lists, discarding lower scores.

        Returns ``(high, low, by_index)``.  ``by_index`` maps each kept
        detection's index to the same dict object stored in ``high`` or ``low``.
        """
        high, low, by_index = [], [], {}
        for idx in range(scores.shape[0]):
            if scores[idx] >= self.score_thresh:
                bucket = high
            elif scores[idx] >= self.low_thresh:
                bucket = low
            else:
                continue
            det = {"score": float(scores[idx]), "bbox": bboxes[idx, :].cpu().numpy().tolist()}
            bucket.append(det)
            by_index[idx] = det
        return high, low, by_index

    @staticmethod
    def _associate(dets, tracks, max_cost):
        """Assign ``dets`` to ``tracks`` minimising the total ``1 - GIoU`` cost.

        Returns ``(matches, unmatched_dets, unmatched_tracks)``, where
        ``matches`` is a list of ``(det_index, track_index)`` pairs.  Each
        unmatched list first holds the indices the solver left unassigned, in
        ascending order.  Indices from pairs rejected because their cost
        exceeds ``max_cost`` are appended after them.  When either side is
        empty, no cost is computed and everything is unmatched.
        """
        n_dets, n_tracks = len(dets), len(tracks)
        if n_dets == 0 or n_tracks == 0:
            return [], list(range(n_dets)), list(range(n_tracks))

        det_box = torch.stack([torch.tensor(obj["bbox"]) for obj in dets], dim=0)  # N x 4
        track_box = torch.stack([torch.tensor(obj["bbox"]) for obj in tracks], dim=0)  # M x 4
        # Assuming generalized_box_iou yields pairwise GIoU in [-1, 1] (verify in
        # util.box_ops), costs lie in [0, 2].
        cost = 1.0 - box_ops.generalized_box_iou(det_box, track_box)  # N x M

        rows, cols = linear_sum_assignment(cost)
        assigned_rows, assigned_cols = set(rows.tolist()), set(cols.tolist())
        unmatched_dets = [d for d in range(n_dets) if d not in assigned_rows]
        unmatched_tracks = [t for t in range(n_tracks) if t not in assigned_cols]

        matches = []
        for r, c in zip(rows, cols):
            if cost[r, c] > max_cost:
                unmatched_dets.append(r)
                unmatched_tracks.append(c)
            else:
                matches.append((r, c))
        return matches, unmatched_dets, unmatched_tracks

    @staticmethod
    def _continue_track(det, track):
        """Give ``det`` the identity of ``track``, mark it active with age 1, and return it."""
        det["tracking_id"] = track["tracking_id"]
        det["age"] = 1
        det["active"] = 1
        return det
|