From 57c1f30a5d5e3d789dc220971cef6e1d451fe77e Mon Sep 17 00:00:00 2001
From: Aymeric DUJARDIN
Date: Wed, 18 Mar 2020 18:24:46 +0100
Subject: [PATCH] Adding demo (inference on video or webcam) + kp threshold for display

---
 lib/utils/vis.py |   6 +-
 tools/demo.py    | 189 +++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 192 insertions(+), 3 deletions(-)
 create mode 100644 tools/demo.py

diff --git a/lib/utils/vis.py b/lib/utils/vis.py
index 69a1f77..5c41425 100755
--- a/lib/utils/vis.py
+++ b/lib/utils/vis.py
@@ -18,7 +18,7 @@
 from dataset import VIS_CONFIG
 
 
-def add_joints(image, joints, color, dataset='COCO'):
+def add_joints(image, joints, color, dataset='COCO', kpt_threshold=0.1):
     part_idx = VIS_CONFIG[dataset]['part_idx']
     part_orders = VIS_CONFIG[dataset]['part_orders']
 
@@ -26,7 +26,7 @@ def link(a, b, color):
         if part_idx[a] < joints.shape[0] and part_idx[b] < joints.shape[0]:
             jointa = joints[part_idx[a]]
             jointb = joints[part_idx[b]]
-            if jointa[2] > 0 and jointb[2] > 0:
+            if jointa[2] > kpt_threshold and jointb[2] > kpt_threshold:
                 cv2.line(
                     image,
                     (int(jointa[0]), int(jointa[1])),
@@ -37,7 +37,7 @@
 
     # add joints
     for joint in joints:
-        if joint[2] > 0:
+        if joint[2] > kpt_threshold:
             cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2)
 
     # add link
diff --git a/tools/demo.py b/tools/demo.py
new file mode 100644
index 0000000..63a0e78
--- /dev/null
+++ b/tools/demo.py
@@ -0,0 +1,189 @@
+# ------------------------------------------------------------------------------
+# Copyright (c) Microsoft
+# Licensed under the MIT License.
+# Written by Bin Xiao (leoxiaobin@gmail.com)
+# Modified by Bowen Cheng (bcheng9@illinois.edu)
+# ------------------------------------------------------------------------------
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import os
+import pprint
+
+import torch
+import torch.backends.cudnn as cudnn
+import torch.nn.parallel
+import torch.optim
+import torch.utils.data
+import torch.utils.data.distributed
+import torchvision.transforms
+import torch.multiprocessing
+import numpy as np
+
+import _init_paths
+import models
+
+from config import cfg
+from config import check_config
+from config import update_config
+from core.inference import get_multi_stage_outputs
+from core.inference import aggregate_results
+from core.group import HeatmapParser
+from dataset import make_test_dataloader
+from fp16_utils.fp16util import network_to_half
+from utils.utils import create_logger
+from utils.vis import add_joints
+from utils.transforms import resize_align_multi_scale
+from utils.transforms import get_final_preds
+from utils.transforms import get_multi_scale_size
+import datetime
+import cv2
+
+torch.multiprocessing.set_sharing_strategy('file_system')
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Test keypoints network')
+    # general
+    parser.add_argument('--cfg',
+                        help='experiment configure file name',
+                        required=True,
+                        type=str)
+
+    parser.add_argument('opts',
+                        help="Modify config options using the command-line",
+                        default=None,
+                        nargs=argparse.REMAINDER)
+
+    args = parser.parse_args()
+
+    return args
+
+
+def main():
+    args = parse_args()
+    update_config(cfg, args)
+    check_config(cfg)
+
+    logger, final_output_dir, tb_log_dir = create_logger(
+        cfg, args.cfg, 'valid'
+    )
+
+    logger.info(pprint.pformat(args))
+
+    # cudnn related setting
+    cudnn.benchmark = cfg.CUDNN.BENCHMARK
+    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED
+
+    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
+        cfg, is_train=False
+    )
+
+    if cfg.FP16.ENABLED:
+        model = network_to_half(model)
+
+    if cfg.TEST.MODEL_FILE:
+        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
+        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
+    else:
+        model_state_file = os.path.join(
+            final_output_dir, 'model_best.pth.tar'
+        )
+        logger.info('=> loading model from {}'.format(model_state_file))
+        model.load_state_dict(torch.load(model_state_file))
+
+    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
+    model.eval()
+
+    data_loader, test_dataset = make_test_dataloader(cfg)
+
+    if cfg.MODEL.NAME == 'pose_hourglass':
+        transforms = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.ToTensor(),
+            ]
+        )
+    else:
+        transforms = torchvision.transforms.Compose(
+            [
+                torchvision.transforms.ToTensor(),
+                torchvision.transforms.Normalize(
+                    mean=[0.485, 0.456, 0.406],
+                    std=[0.229, 0.224, 0.225]
+                )
+            ]
+        )
+
+    parser = HeatmapParser(cfg)
+
+    vid_file = 0  # 0 opens the default webcam; set to a video file path instead
+    print("Opening Camera " + str(vid_file))
+    cap = cv2.VideoCapture(vid_file)
+
+    while True:
+        ret, image = cap.read()
+        # stop when the stream ends or the camera cannot deliver a frame
+        if not ret:
+            break
+
+        a = datetime.datetime.now()
+
+        # size at scale 1.0
+        base_size, center, scale = get_multi_scale_size(
+            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
+        )
+
+        with torch.no_grad():
+            final_heatmaps = None
+            tags_list = []
+            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
+                input_size = cfg.DATASET.INPUT_SIZE
+                image_resized, center, scale = resize_align_multi_scale(
+                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
+                )
+                image_resized = transforms(image_resized)
+                image_resized = image_resized.unsqueeze(0).cuda()
+
+                outputs, heatmaps, tags = get_multi_stage_outputs(
+                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
+                    cfg.TEST.PROJECT2IMAGE, base_size
+                )
+
+                final_heatmaps, tags_list = aggregate_results(
+                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
+                )
+
+            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
+            tags = torch.cat(tags_list, dim=4)
+            grouped, scores = parser.parse(
+                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
+            )
+
+            final_results = get_final_preds(
+                grouped, center, scale,
+                [final_heatmaps.size(3), final_heatmaps.size(2)]
+            )
+
+        b = datetime.datetime.now()
+        inf_time = (b - a).total_seconds()*1000
+        print("Inf time {} ms".format(inf_time))
+
+        # Display the resulting frame
+        for person in final_results:
+            color = np.random.randint(0, 255, size=3)
+            color = [int(i) for i in color]
+            add_joints(image, person, color, test_dataset.name, cfg.TEST.DETECTION_THRESHOLD)
+
+        image = cv2.putText(image, "{:.2f} ms / frame".format(inf_time), (40, 40),
+                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
+        cv2.imshow('frame', image)
+        if cv2.waitKey(1) & 0xFF == ord('q'):
+            break
+
+    cap.release()
+    cv2.destroyAllWindows()
+
+if __name__ == '__main__':
+    main()
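
Note (not part of the patch): a minimal, hypothetical sketch of how the new kpt_threshold argument of add_joints could be exercised on a single image, outside the webcam loop. The image path, the (17, 3) joints array and the 0.3 threshold below are illustration values only; in demo.py the joints come from get_final_preds() and the threshold from cfg.TEST.DETECTION_THRESHOLD.

    import cv2
    import numpy as np

    import _init_paths  # as in the other tools/ scripts, puts lib/ on sys.path
    from utils.vis import add_joints

    image = cv2.imread('test.jpg')                # hypothetical input image
    person = np.zeros((17, 3), dtype=np.float32)  # COCO layout: (x, y, score) per joint
    person[0] = [320.0, 120.0, 0.90]              # high-confidence joint -> drawn
    person[1] = [310.0, 110.0, 0.05]              # below threshold -> skipped

    # joints (and the limbs between two joints) with score <= kpt_threshold are not drawn
    add_joints(image, person, (0, 255, 0), dataset='COCO', kpt_threshold=0.3)
    cv2.imwrite('test_vis.jpg', image)

The demo itself would be launched like the other tools/ scripts, e.g. python tools/demo.py --cfg <experiment yaml> TEST.MODEL_FILE <checkpoint>, with the exact YAML and checkpoint paths depending on the local setup.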