# File-viewer residue from the original paste: 338 lines, 12 KiB, Python
import numpy as np
|
|
import torchvision
|
|
import time
|
|
import math
|
|
import os
|
|
import copy
|
|
import pdb
|
|
import argparse
|
|
import sys
|
|
import cv2
|
|
import skimage.io
|
|
import skimage.transform
|
|
import skimage.color
|
|
import skimage
|
|
import torch
|
|
import model
|
|
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import datasets, models, transforms
|
|
from dataloader import CSVDataset, collater, Resizer, AspectRatioBasedSampler, Augmenter, UnNormalizer, Normalizer, RGB_MEAN, RGB_STD
|
|
from scipy.optimize import linear_sum_assignment
|
|
|
|
# assert torch.__version__.split('.')[1] == '4'
|
|
|
|
# Module-level side effect: report whether PyTorch can see a CUDA device.
print('CUDA available: {}'.format(torch.cuda.is_available()))
|
|
|
|
# Colour palette for drawing; a track picks its entry with
# ``trace_id % len(color_list)`` and the tuple is passed as the ``color``
# argument of the OpenCV drawing calls.
color_list = [
    (0, 0, 255),
    (255, 0, 0),
    (0, 255, 0),
    (255, 0, 255),
    (0, 255, 255),
    (255, 255, 0),
    (128, 0, 255),
    (0, 128, 255),
    (128, 255, 0),
    (0, 255, 128),
    (255, 128, 0),
    (255, 0, 128),
    (128, 128, 255),
    (128, 255, 128),
    (255, 128, 128),
    (128, 128, 0),
    (128, 0, 128),
]
|
|
|
|
class detect_rect:
    """One detection: a box in its own frame plus the paired next-frame box.

    All fields start with placeholder values and are filled in by the caller.
    Boxes are stored as ``[x1, y1, x2, y2]`` arrays.
    """

    def __init__(self):
        self.id = 0
        self.conf = 0
        self.curr_frame = 0
        # Two separate arrays so the boxes never alias each other.
        self.curr_rect = np.array([0, 0, 1, 1])
        self.next_rect = np.array([0, 0, 1, 1])

    @property
    def position(self):
        """Centre of ``curr_rect`` as ``np.array([x, y])``."""
        box = np.asarray(self.curr_rect)
        return (box[:2] + box[2:4]) / 2

    @property
    def size(self):
        """Extent of ``curr_rect`` as ``np.array([w, h])``."""
        box = np.asarray(self.curr_rect)
        return box[2:4] - box[:2]
|
|
|
|
class tracklet:
    """Ordered list of detections that share one track id.

    Attributes:
        id: Track id, copied from the first detection.
        rect_list: Detections in the order they were added.
        rect_num: Number of detections in ``rect_list``.
        last_rect: Most recently added detection.
        last_frame: ``curr_frame`` of ``last_rect``.
        no_match_frame: Miss counter maintained by the caller.
    """

    def __init__(self, det_rect):
        self.id = det_rect.id
        self.rect_list = [det_rect]
        self.rect_num = 1
        self.last_rect = det_rect
        self.last_frame = det_rect.curr_frame
        self.no_match_frame = 0

    def add_rect(self, det_rect):
        """Append ``det_rect`` and make it the track's latest detection."""
        # NOTE(review): no_match_frame is not reset here, so any external miss
        # counter accumulates over the whole track lifetime - confirm intended.
        self.rect_list.append(det_rect)
        self.rect_num = self.rect_num + 1
        self.last_rect = det_rect
        self.last_frame = det_rect.curr_frame

    def _displacement_rate(self, newer, older):
        """Return position change per frame between two ``rect_list`` indices."""
        later, earlier = self.rect_list[newer], self.rect_list[older]
        return (later.position - earlier.position) / (later.curr_frame - earlier.curr_frame)

    @property
    def velocity(self):
        """Estimated motion of the box centre, in position units per frame.

        Returns:
            A length-2 ndarray. It is zero with fewer than two detections, the
            rate between the last two detections with fewer than six, and
            otherwise the mean of three rates, each measured across a gap of
            three detections.
        """
        count = self.rect_num
        if count < 2:
            # Same type as the other branches, so callers always get an array.
            return np.zeros(2)
        if count < 6:
            return self._displacement_rate(count - 1, count - 2)
        v1 = self._displacement_rate(count - 1, count - 4)
        v2 = self._displacement_rate(count - 2, count - 5)
        v3 = self._displacement_rate(count - 3, count - 6)
        return (v1 + v2 + v3) / 3
|
|
|
|
|
|
def cal_iou(rect1, rect2):
    """Return the intersection-over-union of two ``(x1, y1, x2, y2)`` boxes.

    Returns ``0`` when the boxes do not overlap with positive area.
    """
    ax1, ay1, ax2, ay2 = rect1
    bx1, by1, bx2, by2 = rect2

    overlap_w = min(ax2, bx2) - max(ax1, bx1)
    overlap_h = min(ay2, by2) - max(ay1, by1)
    if overlap_w <= 0 or overlap_h <= 0:
        return 0

    overlap_area = overlap_w * overlap_h
    area_a = (ax2 - ax1) * (ay2 - ay1)
    area_b = (bx2 - bx1) * (by2 - by1)
    union_area = area_a + area_b - overlap_area
    return float(overlap_area) / union_area
|
|
|
|
def cal_simi(det_rect1, det_rect2):
    """Return the IoU of ``det_rect1.next_rect`` and ``det_rect2.curr_rect``."""
    predicted_box = det_rect1.next_rect
    observed_box = det_rect2.curr_rect
    return cal_iou(predicted_box, observed_box)
|
|
|
|
def cal_simi_track_det(track, det_rect):
    """Return the IoU-based similarity between a track and a later detection.

    If the detection is exactly one frame after the track's last frame, the
    track's stored ``next_rect`` is compared with the detection. For larger
    gaps the track's last box is shifted by ``velocity * gap`` first. A
    detection that is not later than the track prints an error and scores 0.
    """
    frame_gap = det_rect.curr_frame - track.last_frame
    if frame_gap <= 0:
        print("cal_simi_track_det error")
        return 0
    if frame_gap == 1:
        return cal_iou(track.last_rect.next_rect, det_rect.curr_rect)

    # The same (vx, vy) offset applies to both corners of the box.
    shift = np.tile(track.velocity, 2) * frame_gap
    return cal_iou(track.last_rect.curr_rect + shift, det_rect.curr_rect)
|
|
|
|
def track_det_match(tracklet_list, det_rect_list, min_iou = 0.5):
    """Assign detections to tracks with a minimum-cost (maximum-similarity) matching.

    Args:
        tracklet_list: Tracks to match (rows of the cost matrix).
        det_rect_list: Detections to match (columns of the cost matrix).
        min_iou: An assigned pair is kept only if its similarity is at least
            this value.

    Returns:
        ``(matches, unmatched1, unmatched2)``. ``matches`` is a list of
        ``(track_index, det_index)`` pairs. ``unmatched1`` and ``unmatched2``
        hold the track and detection indices left without a partner: first
        those never assigned, then those whose assignment was rejected.
    """
    n_tracks = len(tracklet_list)
    n_dets = len(det_rect_list)

    # Similarity is negated because the solver minimises total cost.
    cost_mat = np.zeros((n_tracks, n_dets))
    for row, track in enumerate(tracklet_list):
        for col, det in enumerate(det_rect_list):
            cost_mat[row, col] = -cal_simi_track_det(track, det)

    row_idx, col_idx = linear_sum_assignment(cost_mat)
    assigned_rows = set(row_idx.tolist())
    assigned_cols = set(col_idx.tolist())

    unmatched1 = [row for row in range(n_tracks) if row not in assigned_rows]
    unmatched2 = [col for col in range(n_dets) if col not in assigned_cols]

    matches = []
    for row, col in zip(row_idx, col_idx):
        if cost_mat[row, col] > -min_iou:
            # Assigned by the solver, but too dissimilar to accept.
            unmatched1.append(row)
            unmatched2.append(col)
        else:
            matches.append((row, col))
    return matches, unmatched1, unmatched2
|
|
|
|
def draw_caption(image, box, caption, color):
    """Draw ``caption`` on ``image`` 8 pixels above the first corner of ``box``."""
    left, top = np.array(box).astype(int)[:2]
    cv2.putText(image, caption, (left, top - 8), cv2.FONT_HERSHEY_PLAIN, 2, color, 2)
|
|
|
|
|
|
def run_each_dataset(model_dir, retinanet, dataset_path, subset, cur_dataset):
    """Detect and track objects in the second half of one sequence, then save results.

    Frames are read from ``<dataset_path>/<subset>/<cur_dataset>/img1``. Only
    frames from index ``int(img_len / 2)`` onwards are tracked. Outputs go to
    ``<model_dir>/results``: ``<cur_dataset>.txt`` (one line per box),
    ``<cur_dataset>.mp4`` and a ``<cur_dataset>/`` folder of annotated JPEGs.

    Args:
        model_dir: Directory that receives the ``results`` folder.
        retinanet: Network called as ``retinanet(image, last_feat=...)``; it
            must return ``(scores, transformed_anchors, last_feat)``.
        dataset_path: Dataset root directory.
        subset: Sub-directory of ``dataset_path`` holding the sequence.
        cur_dataset: Sequence directory name, also used to name the outputs.
    """
    print(cur_dataset)

    img_dir = os.path.join(dataset_path, subset, cur_dataset, 'img1')
    img_list = sorted(
        os.path.join(img_dir, name)
        for name in os.listdir(img_dir)
        if ('jpg' in name) or ('png' in name)
    )

    img_len = len(img_list)
    first_frame = int(img_len / 2)
    last_feat = None

    confidence_threshold = 0.4
    iou_threshold = 0.5
    retention_threshold = 10
    max_draw_len = 100
    draw_interval = 5
    fps = 30
    # Fallback video size; replaced by the size of the frames actually read.
    img_width, img_height = 1920, 1080

    det_list_all = [[] for _ in range(img_len)]
    tracklet_all = []
    max_id = 0

    for idx in range(first_frame, img_len + 1):
        i = idx - 1
        print('tracking: ', i)
        with torch.no_grad():
            # The final iteration re-reads the last image (index clamped).
            data_path1 = img_list[min(idx, img_len - 1)]
            img_origin1 = skimage.io.imread(data_path1)
            img_h, img_w, _ = img_origin1.shape
            img_height, img_width = img_h, img_w

            # Zero-pad to the next multiple of 32 in each dimension, then normalise.
            resize_h, resize_w = math.ceil(img_h / 32) * 32, math.ceil(img_w / 32) * 32
            img1 = np.zeros((resize_h, resize_w, 3), dtype=img_origin1.dtype)
            img1[:img_h, :img_w, :] = img_origin1
            img1 = (img1.astype(np.float32) / 255.0 - np.array([[RGB_MEAN]])) / np.array([[RGB_STD]])
            img1 = torch.from_numpy(img1).permute(2, 0, 1).view(1, 3, resize_h, resize_w)
            scores, transformed_anchors, last_feat = retinanet(img1.cuda().float(), last_feat=last_feat)

            # The first pass only primes ``last_feat``; its detections are not used.
            if idx <= first_frame:
                continue

            # NOTE(review): assumes ``scores`` can be consumed by np.where
            # directly - confirm against the model's output type/device.
            idxs = np.where(scores > 0.1)
            for anchor_idx in idxs[0]:
                det_conf = float(scores[anchor_idx])
                if det_conf > confidence_threshold:
                    # Each row holds two boxes: values 0-3 become ``curr_rect``
                    # and values 4-7 become ``next_rect``.
                    bbox = transformed_anchors[anchor_idx, :]
                    coords = [int(bbox[k]) for k in range(8)]
                    det_rect = detect_rect()
                    det_rect.curr_frame = idx
                    det_rect.curr_rect = np.array(coords[:4])
                    det_rect.next_rect = np.array(coords[4:])
                    det_rect.conf = det_conf
                    det_list_all[i].append(det_rect)

            frame_dets = det_list_all[i]

            # First tracked frame: every detection starts its own track.
            if i == first_frame:
                for j, det in enumerate(frame_dets):
                    det.id = j + 1
                    max_id = max(max_id, j + 1)
                    tracklet_all.append(tracklet(det))
                continue

            matches, unmatched1, unmatched2 = track_det_match(tracklet_all, frame_dets, iou_threshold)

            for track_idx, det_idx in matches:
                det = frame_dets[det_idx]
                det.id = tracklet_all[track_idx].id
                tracklet_all[track_idx].add_rect(det)

            # Age unmatched tracks and drop those that reached the limit.
            expired = set()
            for track_idx in unmatched1:
                track = tracklet_all[track_idx]
                track.no_match_frame = track.no_match_frame + 1
                if track.no_match_frame >= retention_threshold:
                    expired.add(int(track_idx))
            # Filter in place of a set difference so track order is preserved.
            tracklet_all = [t for k, t in enumerate(tracklet_all) if k not in expired]

            # Every detection without a track starts a new one with a fresh id.
            for det_idx in unmatched2:
                max_id = max_id + 1
                det = frame_dets[det_idx]
                det.id = max_id
                tracklet_all.append(tracklet(det))

    #**************visualize tracking result and save evaluate file****************

    results_dir = os.path.join(model_dir, 'results')
    save_img_dir = os.path.join(results_dir, cur_dataset)
    os.makedirs(save_img_dir, exist_ok=True)

    out_video = os.path.join(results_dir, cur_dataset + '.mp4')
    videoWriter = cv2.VideoWriter(out_video, cv2.VideoWriter_fourcc('m', 'p', '4', 'v'), fps, (img_width, img_height))

    # Per-track history of (centre-x, bottom-y) points used to draw trails.
    id_dict = {}

    try:
        with open(os.path.join(results_dir, cur_dataset + '.txt'), 'w') as fout_tracking:
            for i in range(first_frame, img_len):
                print('saving: ', i)
                img = cv2.imread(img_list[i])

                for det in det_list_all[i]:
                    # Plain Python ints keep the OpenCV calls independent of
                    # how a given build treats NumPy integer scalars.
                    x1, y1, x2, y2 = det.curr_rect.astype(int).tolist()
                    trace_id = det.id
                    color = color_list[trace_id % len(color_list)]

                    trail = id_dict.setdefault(str(trace_id), [])
                    trail.append((int((x1 + x2) / 2), y2))

                    draw_caption(img, (x1, y1, x2, y2), str(trace_id), color=color)
                    cv2.rectangle(img, (x1, y1), (x2, y2), color=color, thickness=2)

                    # Draw the most recent part of the trail as segments that
                    # each span ``draw_interval`` history points.
                    trace_len = len(trail)
                    trace_len_draw = min(max_draw_len, trace_len)
                    for k in range(0, trace_len_draw - draw_interval, draw_interval):
                        draw_point1 = trail[trace_len - k - 1]
                        draw_point2 = trail[trace_len - k - 1 - draw_interval]
                        cv2.line(img, draw_point1, draw_point2, color=color, thickness=2)

                    # Line layout: frame (1-based), id, x, y, w, h, then four -1 fields.
                    fout_tracking.write(str(i+1) + ',' + str(trace_id) + ',' + str(x1) + ',' + str(y1) + ',' + str(x2 - x1) + ',' + str(y2 - y1) + ',-1,-1,-1,-1\n')

                cv2.imwrite(os.path.join(save_img_dir, str(i + 1).zfill(6) + '.jpg'), img)
                videoWriter.write(img)
    finally:
        # Release the writer even if a frame fails.
        videoWriter.release()
|
|
|
|
def run_from_train(model_dir, root_path):
    """Load ``model_final.pt`` from ``model_dir`` and evaluate both MOT17 splits.

    Creates ``<model_dir>/results`` if it is missing, moves the model to the
    GPU in eval mode, then runs ``run_each_dataset`` on the train sequences
    followed by the test sequences.

    Args:
        model_dir: Directory holding ``model_final.pt``; results are written
            beneath it.
        root_path: Dataset root passed through to ``run_each_dataset``.
    """
    results_dir = os.path.join(model_dir, 'results')
    if not os.path.exists(results_dir):
        os.makedirs(results_dir)

    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    checkpoint_path = os.path.join(model_dir, 'model_final.pt')
    retinanet = torch.load(checkpoint_path)
    retinanet = retinanet.cuda()
    retinanet.eval()

    splits = (
        ('train', (2, 4, 5, 9, 10, 11, 13)),
        ('test', (1, 3, 6, 7, 8, 12, 14)),
    )
    for subset, seq_nums in splits:
        for seq_num in seq_nums:
            run_each_dataset(model_dir, retinanet, root_path, subset, 'MOT17-{:02d}'.format(seq_num))
|
|
|
|
def main(args=None):
    """Command-line entry point: load a saved model and run the train sequences.

    Args:
        args: Optional argument list for ``argparse``; ``None`` means
            ``sys.argv``.
    """
    arg_parser = argparse.ArgumentParser(description='Simple script for testing a CTracker network.')
    arg_parser.add_argument('--dataset_path', default='/dockerdata/home/jeromepeng/data/MOT/MOT17/', type=str, help='Dataset path, location of the images sequence.')
    arg_parser.add_argument('--model_dir', default='./trained_model/', help='Directory under which the results/ folder is written.')
    arg_parser.add_argument('--model_path', default='./trained_model/model_final.pth', help='Path to the saved model (.pth) file.')
    opts = arg_parser.parse_args(args)

    os.makedirs(os.path.join(opts.model_dir, 'results'), exist_ok=True)

    # NOTE(review): pretrained=True presumably fetches backbone weights that
    # are overwritten just below - confirm whether it is still needed.
    retinanet = model.resnet50(num_classes=1, pretrained=True)
    # NOTE(review): torch.load unpickles the file - only use trusted checkpoints.
    retinanet_save = torch.load(opts.model_path)

    # Strip a leading "module." from parameter names so they match the bare
    # model. Keys without the prefix are kept as they are.
    prefix = 'module.'
    state_dict = retinanet_save.state_dict()
    for key in list(state_dict.keys()):
        if key.startswith(prefix):
            state_dict[key[len(prefix):]] = state_dict.pop(key)

    retinanet.load_state_dict(state_dict)
    retinanet = retinanet.cuda()
    retinanet.eval()

    for seq_num in [2, 4, 5, 9, 10, 11, 13]:
        run_each_dataset(opts.model_dir, retinanet, opts.dataset_path, 'train', 'MOT17-{:02d}'.format(seq_num))
    # The test-split loop (sequences 1, 3, 6, 7, 8, 12, 14) is disabled here.
|
|
|
|
# Script entry point: run the evaluation only when executed directly.
if __name__ == '__main__':
    main()
|