support for dance track
zyayoung committed Feb 9, 2022
1 parent d87f6a4 commit c245b9a
Showing 7 changed files with 847 additions and 4 deletions.
34 changes: 34 additions & 0 deletions configs/r50_motr_submit_dance.sh
@@ -0,0 +1,34 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------

EXP_DIR=exps/e2e_motr_r50_dance
python3 submit_dance.py \
--meta_arch motr \
--dataset_file e2e_joint \
--mot_path /data/datasets \
--epoch 200 \
--with_box_refine \
--lr_drop 100 \
--lr 2e-4 \
--lr_backbone 2e-5 \
--output_dir ${EXP_DIR} \
--batch_size 1 \
--sample_mode 'random_interval' \
--sample_interval 10 \
--sampler_steps 50 90 150 \
--sampler_lengths 2 3 4 5 \
--update_query_pos \
--merger_dropout 0 \
--dropout 0 \
--random_drop 0.1 \
--fp_ratio 0.3 \
--query_interaction_layer 'QIM' \
--extra_track_attn \
--data_txt_path_train ./datasets/data_path/joint.train \
--data_txt_path_val ./datasets/data_path/mot17.train \
--resume ${EXP_DIR}/checkpoint.pth \
--exp_name tracker
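
# Usage sketch (an assumption: a trained DanceTrack checkpoint already exists
# at ${EXP_DIR}/checkpoint.pth, as the --resume flag above expects):
#   bash configs/r50_motr_submit_dance.sh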
39 changes: 39 additions & 0 deletions configs/r50_motr_train_dance.sh
@@ -0,0 +1,39 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------

# for DanceTrack

PRETRAIN=r50_deformable_detr_plus_iterative_bbox_refinement-checkpoint.pth
EXP_DIR=exps/e2e_motr_r50_dance
python3 -m torch.distributed.launch --nproc_per_node=8 \
--use_env main.py \
--meta_arch motr \
--use_checkpoint \
--dataset_file e2e_dance \
--epoch 20 \
--with_box_refine \
--lr_drop 10 \
--lr 2e-4 \
--lr_backbone 2e-5 \
--pretrained ${PRETRAIN} \
--output_dir ${EXP_DIR} \
--batch_size 1 \
--sample_mode 'random_interval' \
--sample_interval 10 \
--sampler_steps 5 9 15 \
--sampler_lengths 2 3 4 5 \
--update_query_pos \
--merger_dropout 0 \
--dropout 0 \
--random_drop 0.1 \
--fp_ratio 0.3 \
--query_interaction_layer 'QIM' \
--extra_track_attn \
--data_txt_path_train ./datasets/data_path/joint.train \
--data_txt_path_val ./datasets/data_path/mot17.train \
|& tee ${EXP_DIR}/output.log
3 changes: 3 additions & 0 deletions datasets/__init__.py
@@ -13,6 +13,7 @@

from .coco import build as build_coco
from .detmot import build as build_e2e_mot
from .dance import build as build_e2e_dance
from .static_detmot import build as build_e2e_static_mot
from .joint import build as build_e2e_joint
from .torchvision_datasets import CocoDetection
@@ -40,4 +41,6 @@ def build_dataset(image_set, args):
return build_e2e_static_mot(image_set, args)
if args.dataset_file == 'e2e_mot':
return build_e2e_mot(image_set, args)
if args.dataset_file == 'e2e_dance':
return build_e2e_dance(image_set, args)
raise ValueError(f'dataset {args.dataset_file} not supported')
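
For reference, a minimal sketch of how the new 'e2e_dance' branch is reached from this dispatcher; the Namespace fields are trimmed to the ones the dance loader actually reads, and the paths are assumptions:

from argparse import Namespace

from datasets import build_dataset

args = Namespace(
    dataset_file='e2e_dance',
    mot_path='/data/datasets',  # expects DanceTrack/train underneath
    sampler_lengths=[2, 3, 4, 5],
    sampler_steps=[5, 9, 15],
    sample_mode='random_interval',
    sample_interval=10,
    data_txt_path_train='./datasets/data_path/joint.train',
    data_txt_path_val='./datasets/data_path/mot17.train',
)
dataset = build_dataset('train', args)
print(len(dataset), 'training clips')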
257 changes: 257 additions & 0 deletions datasets/dance.py
@@ -0,0 +1,257 @@
# ------------------------------------------------------------------------
# Copyright (c) 2021 megvii-model. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from Deformable DETR (https://github.com/fundamentalvision/Deformable-DETR)
# Copyright (c) 2020 SenseTime. All Rights Reserved.
# ------------------------------------------------------------------------
# Modified from DETR (https://github.com/facebookresearch/detr)
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# ------------------------------------------------------------------------

"""
MOT dataset which returns image_id for evaluation.
"""
from collections import defaultdict
import json
import os
from pathlib import Path
import cv2
import numpy as np
import torch
import torch.utils.data
import os.path as osp
from PIL import Image, ImageDraw
import copy
import datasets.transforms as T
from models.structures import Instances

from random import choice, randint


class DetMOTDetection:
def __init__(self, args, data_txt_path: str, seqs_folder, dataset2transform):
self.args = args
self.dataset2transform = dataset2transform
self.num_frames_per_batch = max(args.sampler_lengths)
self.sample_mode = args.sample_mode
self.sample_interval = args.sample_interval
self.video_dict = {}
self.split_dir = os.path.join(args.mot_path, "DanceTrack", "train")

        self.labels_full = defaultdict(lambda: defaultdict(list))
        for vid in os.listdir(self.split_dir):
            if 'DPM' in vid or 'FRCNN' in vid:  # skip MOT17 detector-specific splits if present
                print(f'filter {vid}')
                continue
            gt_path = os.path.join(self.split_dir, vid, 'gt', 'gt.txt')
            with open(gt_path) as f:
                for line in f:
                    t, i, *xywh, mark, label = line.strip().split(',')[:8]
                    t, i, mark, label = map(int, (t, i, mark, label))
                    if mark == 0:
                        continue
                    if label in [3, 4, 5, 6, 9, 10, 11]:  # non-person classes
                        continue
                    crowd = False
                    x, y, w, h = map(float, xywh)
                    self.labels_full[vid][t].append([x, y, w, h, i, crowd])
vid_files = list(self.labels_full.keys())

self.indices = []
self.vid_tmax = {}
for vid in vid_files:
self.video_dict[vid] = len(self.video_dict)
t_min = min(self.labels_full[vid].keys())
t_max = max(self.labels_full[vid].keys()) + 1
self.vid_tmax[vid] = t_max - 1
for t in range(t_min, t_max - self.num_frames_per_batch):
self.indices.append((vid, t))

self.sampler_steps: list = args.sampler_steps
self.lengths: list = args.sampler_lengths
print("sampler_steps={} lenghts={}".format(self.sampler_steps, self.lengths))
self.period_idx = 0

def set_epoch(self, epoch):
self.current_epoch = epoch
if self.sampler_steps is None or len(self.sampler_steps) == 0:
# fixed sampling length.
return

for i in range(len(self.sampler_steps)):
if epoch >= self.sampler_steps[i]:
self.period_idx = i + 1
print("set epoch: epoch {} period_idx={}".format(epoch, self.period_idx))
self.num_frames_per_batch = self.lengths[self.period_idx]
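        # Worked example with the flags in configs/r50_motr_train_dance.sh
        # (sampler_steps=[5, 9, 15], sampler_lengths=[2, 3, 4, 5]):
        # epochs 0-4 sample 2-frame clips, epochs 5-8 sample 3-frame clips,
        # epochs 9-14 sample 4-frame clips, and epochs 15+ sample 5-frame clips.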

def step_epoch(self):
# one epoch finishes.
print("Dataset: epoch {} finishes".format(self.current_epoch))
self.set_epoch(self.current_epoch + 1)

@staticmethod
def _targets_to_instances(targets: dict, img_shape) -> Instances:
gt_instances = Instances(tuple(img_shape))
gt_instances.boxes = targets['boxes']
gt_instances.labels = targets['labels']
gt_instances.obj_ids = targets['obj_ids']
gt_instances.area = targets['area']
return gt_instances

    def load_crowd(self):
        # NOTE: relies on self.crowd_gts, which is not initialized in this
        # class; it must be populated elsewhere (e.g. with CrowdHuman
        # annotations) before this helper is called.
        path, boxes, crowd = choice(self.crowd_gts)
        img = Image.open(path)

        w, h = img.size
        boxes = torch.tensor(boxes, dtype=torch.float32)
        areas = boxes[..., 2:].prod(-1)
        boxes[:, 2:] += boxes[:, :2]  # xywh -> xyxy
target = {
'boxes': boxes,
'labels': torch.zeros((len(boxes), ), dtype=torch.long),
'iscrowd': torch.as_tensor(crowd),
'image_id': torch.tensor([0]),
'area': areas,
'obj_ids': torch.arange(len(boxes)),
'size': torch.as_tensor([h, w]),
'orig_size': torch.as_tensor([h, w]),
'dataset': "CrowdHuman",
}
return [img], [target]

def _pre_single_frame(self, vid, idx: int):
img_path = os.path.join(self.split_dir, vid, 'img1', f'{idx:08d}.jpg')
img = Image.open(img_path)
targets = {}
        w, h = img.size
assert w > 0 and h > 0, "invalid image {} with shape {} {}".format(img_path, w, h)
obj_idx_offset = self.video_dict[vid] * 100000 # 100000 unique ids is enough for a video.

        targets['dataset'] = 'MOT17'  # tag as MOT17 so DanceTrack frames reuse the MOT17 transform pipeline
targets['boxes'] = []
targets['area'] = []
targets['iscrowd'] = []
targets['labels'] = []
targets['obj_ids'] = []
targets['image_id'] = torch.as_tensor(idx)
targets['size'] = torch.as_tensor([h, w])
targets['orig_size'] = torch.as_tensor([h, w])
        for *xywh, obj_id, crowd in self.labels_full[vid][idx]:
            targets['boxes'].append(xywh)
            targets['area'].append(xywh[2] * xywh[3])
            targets['iscrowd'].append(crowd)
            targets['labels'].append(0)
            targets['obj_ids'].append(obj_id + obj_idx_offset)

targets['area'] = torch.as_tensor(targets['area'])
targets['iscrowd'] = torch.as_tensor(targets['iscrowd'])
targets['labels'] = torch.as_tensor(targets['labels'])
targets['obj_ids'] = torch.as_tensor(targets['obj_ids'], dtype=torch.float64)
targets['boxes'] = torch.as_tensor(targets['boxes'], dtype=torch.float32).reshape(-1, 4)
        targets['boxes'][:, 2:] += targets['boxes'][:, :2]  # xywh -> xyxy
return img, targets

    def _get_sample_range(self, start_idx):
        # take default sampling method for normal dataset.
assert self.sample_mode in ['fixed_interval', 'random_interval'], 'invalid sample mode: {}'.format(self.sample_mode)
if self.sample_mode == 'fixed_interval':
sample_interval = self.sample_interval
elif self.sample_mode == 'random_interval':
sample_interval = np.random.randint(1, self.sample_interval + 1)
default_range = start_idx, start_idx + (self.num_frames_per_batch - 1) * sample_interval + 1, sample_interval
return default_range

def pre_continuous_frames(self, vid, indices):
return zip(*[self._pre_single_frame(vid, i) for i in indices])

def sample_indices(self, vid, f_index):
assert self.sample_mode == 'random_interval'
        rate = randint(1, self.sample_interval)  # random.randint is inclusive on both ends
tmax = self.vid_tmax[vid]
ids = [f_index + rate * i for i in range(self.num_frames_per_batch)]
return [min(i, tmax) for i in ids]
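        # Example: f_index=100, num_frames_per_batch=4, and a drawn rate of 3
        # give frames [100, 103, 106, 109], each clamped to the video's last
        # annotated frame.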

def __getitem__(self, idx):
vid, f_index = self.indices[idx]
indices = self.sample_indices(vid, f_index)
images, targets = self.pre_continuous_frames(vid, indices)
dataset_name = targets[0]['dataset']
transform = self.dataset2transform[dataset_name]
if transform is not None:
images, targets = transform(images, targets)
gt_instances = []
for img_i, targets_i in zip(images, targets):
gt_instances_i = self._targets_to_instances(targets_i, img_i.shape[1:3])
gt_instances.append(gt_instances_i)
return {
'imgs': images,
'gt_instances': gt_instances,
}

def __len__(self):
return len(self.indices)


class DetMOTDetectionValidation(DetMOTDetection):
    def __init__(self, args, seqs_folder, dataset2transform):
        args.data_txt_path = args.val_data_txt_path
        # pass data_txt_path explicitly; the parent __init__ expects it
        super().__init__(args, args.data_txt_path, seqs_folder, dataset2transform)


def make_transforms_for_mot17(image_set, args=None):
normalize = T.MotCompose([
T.MotToTensor(),
T.MotNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
scales = [608, 640, 672, 704, 736, 768, 800, 832, 864, 896, 928, 960, 992]

if image_set == 'train':
return T.MotCompose([
T.MotRandomHorizontalFlip(),
T.MotRandomSelect(
T.MotRandomResize(scales, max_size=1536),
T.MotCompose([
T.MotRandomResize([800, 1000, 1200]),
T.FixedMotRandomCrop(800, 1200),
T.MotRandomResize(scales, max_size=1536),
])
),
normalize,
])

if image_set == 'val':
return T.MotCompose([
T.MotRandomResize([800], max_size=1333),
normalize,
])

raise ValueError(f'unknown {image_set}')


def build_dataset2transform(args, image_set):
mot17_train = make_transforms_for_mot17('train', args)
mot17_test = make_transforms_for_mot17('val', args)

dataset2transform_train = {'MOT17': mot17_train}
dataset2transform_val = {'MOT17': mot17_test}
if image_set == 'train':
return dataset2transform_train
elif image_set == 'val':
return dataset2transform_val
else:
raise NotImplementedError()


def build(image_set, args):
root = Path(args.mot_path)
assert root.exists(), f'provided MOT path {root} does not exist'
dataset2transform = build_dataset2transform(args, image_set)
if image_set == 'train':
data_txt_path = args.data_txt_path_train
dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
    if image_set == 'val':
        data_txt_path = args.data_txt_path_val
        dataset = DetMOTDetection(args, data_txt_path=data_txt_path, seqs_folder=root, dataset2transform=dataset2transform)
return dataset
8 changes: 4 additions & 4 deletions main.py
@@ -219,7 +219,7 @@ def main(args):

batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True)
if args.dataset_file in ['e2e_mot', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
if args.dataset_file in ['e2e_mot', 'e2e_dance', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
collate_fn = utils.mot_collate_fn
else:
collate_fn = utils.collate_fn
@@ -322,7 +322,7 @@ def match_name_keywords(n, name_keywords):
start_time = time.time()

train_func = train_one_epoch
if args.dataset_file in ['e2e_mot', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
if args.dataset_file in ['e2e_mot', 'e2e_dance', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
train_func = train_one_epoch_mot
dataset_train.set_epoch(args.start_epoch)
dataset_val.set_epoch(args.start_epoch)
@@ -346,7 +346,7 @@ def match_name_keywords(n, name_keywords):
'args': args,
}, checkpoint_path)

if args.dataset_file not in ['e2e_mot', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
if args.dataset_file not in ['e2e_mot', 'e2e_dance', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
test_stats, coco_evaluator = evaluate(
model, criterion, postprocessors, data_loader_val, base_ds, device, args.output_dir
)
@@ -370,7 +370,7 @@ def match_name_keywords(n, name_keywords):
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
output_dir / "eval" / name)
if args.dataset_file in ['e2e_mot', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
if args.dataset_file in ['e2e_mot', 'e2e_dance', 'mot', 'ori_mot', 'e2e_static_mot', 'e2e_joint']:
dataset_train.step_epoch()
dataset_val.step_epoch()
total_time = time.time() - start_time
2 changes: 2 additions & 0 deletions models/deformable_detr.py
@@ -471,6 +471,8 @@ def build(args):
num_classes = 1
if args.dataset_file == 'e2e_mot':
num_classes = 1
if args.dataset_file == 'e2e_dance':
num_classes = 1
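# every dataset branch above is single-class (person) detection, hence num_classes = 1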
device = torch.device(args.device)

backbone = build_backbone(args)
(Diff for the seventh changed file was not loaded in this view.)
