Showing 7 changed files with 847 additions and 4 deletions.
@@ -0,0 +1,34 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------

EXP_DIR=exps/e2e_motr_r50_dance
python3 submit_dance.py \
    --meta_arch motr \
    --dataset_file e2e_joint \
    --mot_path /data/datasets \
    --epoch 200 \
    --with_box_refine \
    --lr_drop 100 \
    --lr 2e-4 \
    --lr_backbone 2e-5 \
    --output_dir ${EXP_DIR} \
    --batch_size 1 \
    --sample_mode 'random_interval' \
    --sample_interval 10 \
    --sampler_steps 50 90 150 \
    --sampler_lengths 2 3 4 5 \
    --update_query_pos \
    --merger_dropout 0 \
    --dropout 0 \
    --random_drop 0.1 \
    --fp_ratio 0.3 \
    --query_interaction_layer 'QIM' \
    --extra_track_attn \
    --data_txt_path_train ./datasets/data_path/joint.train \
    --data_txt_path_val ./datasets/data_path/mot17.train \
    --resume ${EXP_DIR}/checkpoint.pth \
    --exp_name tracker
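
The --sampler_steps and --sampler_lengths flags implement a progressive clip-length schedule: training starts with short clips and lengthens them at the listed epochs. A minimal Python sketch of that mapping, mirroring the DetMOTDetection.set_epoch logic in the dataset code later in this diff (the function name here is ours, for illustration only):

def clip_length_for_epoch(epoch, steps=(50, 90, 150), lengths=(2, 3, 4, 5)):
    # The active period is the last schedule step the epoch has passed.
    period_idx = 0
    for i, step in enumerate(steps):
        if epoch >= step:
            period_idx = i + 1
    return lengths[period_idx]

assert clip_length_for_epoch(0) == 2     # before epoch 50: 2-frame clips
assert clip_length_for_epoch(60) == 3    # past epoch 50
assert clip_length_for_epoch(149) == 4   # past epoch 90
assert clip_length_for_epoch(150) == 5   # past epoch 150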
@@ -0,0 +1,39 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------

# Train MOTR on DanceTrack, starting from a Deformable-DETR checkpoint.

PRETRAIN=r50_deformable_detr_plus_iterative_bbox_refinement-checkpoint.pth
EXP_DIR=exps/e2e_motr_r50_dance
python3 -m torch.distributed.launch --nproc_per_node=8 \
    --use_env main.py \
    --meta_arch motr \
    --use_checkpoint \
    --dataset_file e2e_dance \
    --epoch 20 \
    --with_box_refine \
    --lr_drop 10 \
    --lr 2e-4 \
    --lr_backbone 2e-5 \
    --pretrained ${PRETRAIN} \
    --output_dir ${EXP_DIR} \
    --batch_size 1 \
    --sample_mode 'random_interval' \
    --sample_interval 10 \
    --sampler_steps 5 9 15 \
    --sampler_lengths 2 3 4 5 \
    --update_query_pos \
    --merger_dropout 0 \
    --dropout 0 \
    --random_drop 0.1 \
    --fp_ratio 0.3 \
    --query_interaction_layer 'QIM' \
    --extra_track_attn \
    --data_txt_path_train ./datasets/data_path/joint.train \
    --data_txt_path_val ./datasets/data_path/mot17.train \
    |& tee ${EXP_DIR}/output.log
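
Before committing 8 GPUs to a run, it can save time to verify the pretrained weights load at all. A hypothetical sanity check, assuming the checkpoint follows the usual Deformable-DETR save layout with weights under a 'model' key (adjust if yours differs):

import torch

# Illustrative only (not part of the commit): inspect the checkpoint on CPU.
ckpt = torch.load(
    'r50_deformable_detr_plus_iterative_bbox_refinement-checkpoint.pth',
    map_location='cpu')
state_dict = ckpt.get('model', ckpt)  # fall back to the raw dict if no 'model' key
print(sorted(ckpt.keys()))   # typically includes 'args', 'epoch', 'model', ...
print(len(state_dict))       # number of parameter tensors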
@@ -0,0 +1,257 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
MOT dataset which returns image_id for evaluation.
"""
from collections import defaultdict
import json
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.utils.data
import os.path as osp
from PIL import Image, ImageDraw
import copy
import datasets.transforms as T
from models.structures import Instances

from random import choice, randint


class DetMOTDetection:
    def __init__(self, args, data_txt_path: str, seqs_folder, dataset2transform):
        self.args = args
        self.dataset2transform = dataset2transform
        self.num_frames_per_batch = max(args.sampler_lengths)
        self.sample_mode = args.sample_mode
        self.sample_interval = args.sample_interval
        self.video_dict = {}
        self.split_dir = os.path.join(args.mot_path, "DanceTrack", "train")

        self.labels_full = defaultdict(lambda: defaultdict(list))
        for vid in os.listdir(self.split_dir):
            if 'DPM' in vid or 'FRCNN' in vid:
                print(f'filter {vid}')
                continue
            gt_path = os.path.join(self.split_dir, vid, 'gt', 'gt.txt')
            for l in open(gt_path):
                t, i, *xywh, mark, label = l.strip().split(',')[:8]
                t, i, mark, label = map(int, (t, i, mark, label))
                if mark == 0:
                    continue
                if label in [3, 4, 5, 6, 9, 10, 11]:  # non-person classes
                    continue
                crowd = False
                x, y, w, h = map(float, xywh)
                self.labels_full[vid][t].append([x, y, w, h, i, crowd])
        vid_files = list(self.labels_full.keys())

        # Enumerate every (video, start frame) pair that leaves room for a
        # full clip of num_frames_per_batch frames.
        self.indices = []
        self.vid_tmax = {}
        for vid in vid_files:
            self.video_dict[vid] = len(self.video_dict)
            t_min = min(self.labels_full[vid].keys())
            t_max = max(self.labels_full[vid].keys()) + 1
            self.vid_tmax[vid] = t_max - 1
            for t in range(t_min, t_max - self.num_frames_per_batch):
                self.indices.append((vid, t))

        self.sampler_steps: list = args.sampler_steps
        self.lengths: list = args.sampler_lengths
        print("sampler_steps={} lengths={}".format(self.sampler_steps, self.lengths))
        self.period_idx = 0

    def set_epoch(self, epoch):
        self.current_epoch = epoch
        if self.sampler_steps is None or len(self.sampler_steps) == 0:
            # fixed sampling length.
            return

        for i in range(len(self.sampler_steps)):
            if epoch >= self.sampler_steps[i]:
                self.period_idx = i + 1
        print("set epoch: epoch {} period_idx={}".format(epoch, self.period_idx))
        self.num_frames_per_batch = self.lengths[self.period_idx]

    def step_epoch(self):
        # one epoch finishes.
        print("Dataset: epoch {} finishes".format(self.current_epoch))
        self.set_epoch(self.current_epoch + 1)

    @staticmethod
    def _targets_to_instances(targets: dict, img_shape) -> Instances:
        gt_instances = Instances(tuple(img_shape))
        gt_instances.boxes = targets['boxes']
        gt_instances.labels = targets['labels']
        gt_instances.obj_ids = targets['obj_ids']
        gt_instances.area = targets['area']
        return gt_instances

    def load_crowd(self):
        # Note: expects self.crowd_gts, a list of (path, boxes, crowd) tuples
        # (e.g. CrowdHuman annotations), to be populated elsewhere; it is not
        # initialized in this file.
        path, boxes, crowd = choice(self.crowd_gts)
        img = Image.open(path)

        w, h = img.size
        boxes = torch.tensor(boxes, dtype=torch.float32)
        areas = boxes[..., 2:].prod(-1)
        boxes[:, 2:] += boxes[:, :2]  # convert x, y, w, h -> x1, y1, x2, y2
        target = {
            'boxes': boxes,
            'labels': torch.zeros((len(boxes), ), dtype=torch.long),
            'iscrowd': torch.as_tensor(crowd),
            'image_id': torch.tensor([0]),
            'area': areas,
            'obj_ids': torch.arange(len(boxes)),
            'size': torch.as_tensor([h, w]),
            'orig_size': torch.as_tensor([h, w]),
            'dataset': "CrowdHuman",
        }
        return [img], [target]

    def _pre_single_frame(self, vid, idx: int):
        img_path = os.path.join(self.split_dir, vid, 'img1', f'{idx:08d}.jpg')
        img = Image.open(img_path)
        targets = {}
        w, h = img.size
        assert w > 0 and h > 0, "invalid image {} with shape {} {}".format(img_path, w, h)
        obj_idx_offset = self.video_dict[vid] * 100000  # 100000 unique ids is enough for a video.

        targets['dataset'] = 'MOT17'
        targets['boxes'] = []
        targets['area'] = []
        targets['iscrowd'] = []
        targets['labels'] = []
        targets['obj_ids'] = []
        targets['image_id'] = torch.as_tensor(idx)
        targets['size'] = torch.as_tensor([h, w])
        targets['orig_size'] = torch.as_tensor([h, w])
        for *xywh, id, crowd in self.labels_full[vid][idx]:
            targets['boxes'].append(xywh)
            targets['area'].append(xywh[2] * xywh[3])
            targets['iscrowd'].append(crowd)
            targets['labels'].append(0)
            targets['obj_ids'].append(id + obj_idx_offset)

        targets['area'] = torch.as_tensor(targets['area'])
        targets['iscrowd'] = torch.as_tensor(targets['iscrowd'])
        targets['labels'] = torch.as_tensor(targets['labels'])
        targets['obj_ids'] = torch.as_tensor(targets['obj_ids'], dtype=torch.float64)
        targets['boxes'] = torch.as_tensor(targets['boxes'], dtype=torch.float32).reshape(-1, 4)
        targets['boxes'][:, 2:] += targets['boxes'][:, :2]  # x, y, w, h -> x1, y1, x2, y2
        return img, targets

    def _get_sample_range(self, start_idx):
        # take default sampling method for normal dataset.
        assert self.sample_mode in ['fixed_interval', 'random_interval'], \
            'invalid sample mode: {}'.format(self.sample_mode)
        if self.sample_mode == 'fixed_interval':
            sample_interval = self.sample_interval
        elif self.sample_mode == 'random_interval':
            # np.random.randint excludes the upper bound, so this draws
            # uniformly from [1, sample_interval].
            sample_interval = np.random.randint(1, self.sample_interval + 1)
        default_range = start_idx, start_idx + (self.num_frames_per_batch - 1) * sample_interval + 1, sample_interval
        return default_range

    def pre_continuous_frames(self, vid, indices):
        return zip(*[self._pre_single_frame(vid, i) for i in indices])

    def sample_indices(self, vid, f_index):
        assert self.sample_mode == 'random_interval'
        # random.randint is inclusive on both ends (unlike np.random.randint),
        # so the upper bound here is sample_interval itself, matching the
        # [1, sample_interval] range used in _get_sample_range above.
        rate = randint(1, self.sample_interval)
        tmax = self.vid_tmax[vid]
        ids = [f_index + rate * i for i in range(self.num_frames_per_batch)]
        return [min(i, tmax) for i in ids]
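
    # Worked example (illustrative, not part of the original file): with
    # sample_interval=10 and num_frames_per_batch=5, a draw of rate=7 at
    # f_index=3 yields ids = [3, 10, 17, 24, 31]; indices past the last
    # annotated frame are clamped to tmax, so clips near the end of a
    # sequence repeat their final frame.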

    def __getitem__(self, idx):
        vid, f_index = self.indices[idx]
        indices = self.sample_indices(vid, f_index)
        images, targets = self.pre_continuous_frames(vid, indices)
        dataset_name = targets[0]['dataset']
        transform = self.dataset2transform[dataset_name]
        if transform is not None:
            images, targets = transform(images, targets)
        gt_instances = []
        for img_i, targets_i in zip(images, targets):
            gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3])
            gt_instances.append(gt_instances_i)
        return {
            'imgs': images,
            'gt_instances': gt_instances,
        }

    def __len__(self):
        return len(self.indices)
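

# Illustrative usage (not part of the commit); assumes DanceTrack is unpacked
# under /data/datasets/DanceTrack/train, and the argument values below are
# examples only:
#
#   from argparse import Namespace
#   args = Namespace(mot_path='/data/datasets', sample_mode='random_interval',
#                    sample_interval=10, sampler_steps=[5, 9, 15],
#                    sampler_lengths=[2, 3, 4, 5])
#   dataset2transform = build_dataset2transform(args, 'train')
#   dataset = DetMOTDetection(args, data_txt_path='', seqs_folder=args.mot_path,
#                             dataset2transform=dataset2transform)
#   batch = dataset[0]  # {'imgs': [tensor, ...], 'gt_instances': [Instances, ...]}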


class DetMOTDetectionValidation(DetMOTDetection):
    def __init__(self, args, seqs_folder, dataset2transform):
        args.data_txt_path = args.val_data_txt_path
        # The parent __init__ takes data_txt_path as its second positional
        # argument, so pass it explicitly.
        super().__init__(args, args.data_txt_path, seqs_folder, dataset2transform)


def make_transforms_for_mot17(image_set, args=None):

    normalize = T.MotCompose([
        T.MotToTensor(),
        T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]

    if image_set == 'train':
        return T.MotCompose([
            T.MotRandomHorizontalFlip(),
            T.MotRandomSelect(
                T.MotRandomResize(scales, max_size=1536),
                T.MotCompose([
                    T.MotRandomResize([800, 1000, 1200]),
                    T.FixedMotRandomCrop(800, 1200),
                    T.MotRandomResize(scales, max_size=1536),
                ])
            ),
            normalize,
        ])

    if image_set == 'val':
        return T.MotCompose([
            T.MotRandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')


def build_dataset2transform(args, image_set):
    mot17_train = make_transforms_for_mot17('train', args)
    mot17_test = make_transforms_for_mot17('val', args)

    dataset2transform_train = {'MOT17': mot17_train}
    dataset2transform_val = {'MOT17': mot17_test}
    if image_set == 'train':
        return dataset2transform_train
    elif image_set == 'val':
        return dataset2transform_val
    else:
        raise NotImplementedError()


def build(image_set, args):
    root = Path(args.mot_path)
    assert root.exists(), f'provided MOT path {root} does not exist'
    dataset2transform = build_dataset2transform(args, image_set)
    if image_set == 'train':
        data_txt_path = args.data_txt_path_train
        dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
    elif image_set == 'val':
        data_txt_path = args.data_txt_path_val
        dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
    return dataset