Demo inference code on webcam/video #28

Open · wants to merge 1 commit into base: master
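This PR has no description, so for context: tools/demo.py (added below) uses the repository's usual command line, a required --cfg experiment file plus optional KEY VALUE config overrides consumed by update_config (see parse_args in the diff). A minimal launch sketch with placeholder paths, not files added by this PR:

import subprocess

# Both paths are placeholders; point them at your own experiment
# config and checkpoint before running.
subprocess.run(
    ['python', 'tools/demo.py',
     '--cfg', 'experiments/your_experiment.yaml',
     'TEST.MODEL_FILE', 'models/your_checkpoint.pth'],
    check=True,
)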
6 changes: 3 additions & 3 deletions lib/utils/vis.py
@@ -18,15 +18,15 @@
 from dataset import VIS_CONFIG
 
 
-def add_joints(image, joints, color, dataset='COCO'):
+def add_joints(image, joints, color, dataset='COCO', kpt_threshold=0.1):
     part_idx = VIS_CONFIG[dataset]['part_idx']
     part_orders = VIS_CONFIG[dataset]['part_orders']
 
     def link(a, b, color):
         if part_idx[a] < joints.shape[0] and part_idx[b] < joints.shape[0]:
             jointa = joints[part_idx[a]]
             jointb = joints[part_idx[b]]
-            if jointa[2] > 0 and jointb[2] > 0:
+            if jointa[2] > kpt_threshold and jointb[2] > kpt_threshold:
                 cv2.line(
                     image,
                     (int(jointa[0]), int(jointa[1])),
@@ -37,7 +37,7 @@ def link(a, b, color):
 
     # add joints
     for joint in joints:
-        if joint[2] > 0:
+        if joint[2] > kpt_threshold:
             cv2.circle(image, (int(joint[0]), int(joint[1])), 1, color, 2)
 
     # add link
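To make the new kpt_threshold argument concrete: each joint row is an (x, y, confidence) triple, and both the circle drawing and the limb linking now skip anything at or below the threshold instead of only skipping zero-confidence joints. A minimal sketch (the frame and keypoints are dummies, and it assumes lib/ is on PYTHONPATH so utils.vis imports):

import numpy as np

from utils.vis import add_joints  # patched signature from this PR

image = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy frame
person = np.array([
    [320.0, 100.0, 0.95],  # confident keypoint: drawn
    [300.0, 150.0, 0.05],  # below the 0.1 default: now skipped
    [340.0, 150.0, 0.60],
])

# The old check (joint[2] > 0) would have drawn all three joints;
# with the default kpt_threshold=0.1 the 0.05 joint is dropped.
add_joints(image, person, color=(0, 255, 0), dataset='COCO')
# A stricter threshold also drops the 0.60 joint:
add_joints(image, person, color=(0, 0, 255), dataset='COCO', kpt_threshold=0.7)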
183 changes: 183 additions & 0 deletions tools/demo.py
@@ -0,0 +1,183 @@
# ------------------------------------------------------------------------------
# Copyright (c) Microsoft
# Licensed under the MIT License.
# Written by Bin Xiao ([email protected])
# Modified by Bowen Cheng ([email protected])
# ------------------------------------------------------------------------------

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import pprint

import torch
import torch.backends.cudnn as cudnn
import torch.nn.parallel
import torch.optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms
import torch.multiprocessing
import numpy as np

import _init_paths
import models

from config import cfg
from config import check_config
from config import update_config
from core.inference import get_multi_stage_outputs
from core.inference import aggregate_results
from core.group import HeatmapParser
from dataset import make_test_dataloader
from fp16_utils.fp16util import network_to_half
from utils.utils import create_logger
from utils.vis import add_joints
from utils.transforms import resize_align_multi_scale
from utils.transforms import get_final_preds
from utils.transforms import get_multi_scale_size
import datetime
import cv2

torch.multiprocessing.set_sharing_strategy('file_system')

def parse_args():
    parser = argparse.ArgumentParser(description='Test keypoints network')
    # general
    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    update_config(cfg, args)
    check_config(cfg)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid'
    )

    logger.info(pprint.pformat(args))

    # cudnn related setting
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED

    model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')(
        cfg, is_train=False
    )

    if cfg.FP16.ENABLED:
        model = network_to_half(model)

    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        model.load_state_dict(torch.load(cfg.TEST.MODEL_FILE), strict=True)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'model_best.pth.tar'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        model.load_state_dict(torch.load(model_state_file))

    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()
    model.eval()

    data_loader, test_dataset = make_test_dataloader(cfg)

    if cfg.MODEL.NAME == 'pose_hourglass':
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
            ]
        )
    else:
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
            ]
        )

    parser = HeatmapParser(cfg)

    vid_file = 0  # or a video file path
    print("Opening Camera " + str(vid_file))
    cap = cv2.VideoCapture(vid_file)

    while True:
        ret, image = cap.read()
        if not ret:
            # stream ended or the camera failed to deliver a frame
            break

        a = datetime.datetime.now()

        # size at scale 1.0
        base_size, center, scale = get_multi_scale_size(
            image, cfg.DATASET.INPUT_SIZE, 1.0, min(cfg.TEST.SCALE_FACTOR)
        )

        with torch.no_grad():
            final_heatmaps = None
            tags_list = []
            for idx, s in enumerate(sorted(cfg.TEST.SCALE_FACTOR, reverse=True)):
                input_size = cfg.DATASET.INPUT_SIZE
                image_resized, center, scale = resize_align_multi_scale(
                    image, input_size, s, min(cfg.TEST.SCALE_FACTOR)
                )
                image_resized = transforms(image_resized)
                image_resized = image_resized.unsqueeze(0).cuda()

                outputs, heatmaps, tags = get_multi_stage_outputs(
                    cfg, model, image_resized, cfg.TEST.FLIP_TEST,
                    cfg.TEST.PROJECT2IMAGE, base_size
                )

                final_heatmaps, tags_list = aggregate_results(
                    cfg, s, final_heatmaps, tags_list, heatmaps, tags
                )

            final_heatmaps = final_heatmaps / float(len(cfg.TEST.SCALE_FACTOR))
            tags = torch.cat(tags_list, dim=4)
            grouped, scores = parser.parse(
                final_heatmaps, tags, cfg.TEST.ADJUST, cfg.TEST.REFINE
            )

            final_results = get_final_preds(
                grouped, center, scale,
                [final_heatmaps.size(3), final_heatmaps.size(2)]
            )

        b = datetime.datetime.now()
        inf_time = (b - a).total_seconds()*1000
        print("Inf time {} ms".format(inf_time))

        # Display the resulting frame
        for person in final_results:
            color = np.random.randint(0, 255, size=3)
            color = [int(i) for i in color]
            add_joints(image, person, color, test_dataset.name, cfg.TEST.DETECTION_THRESHOLD)

        image = cv2.putText(image, "{:.2f} ms / frame".format(inf_time), (40, 40),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2, cv2.LINE_AA)
        cv2.imshow('frame', image)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    # release the capture device and close the display window
    cap.release()
    cv2.destroyAllWindows()


if __name__ == '__main__':
    main()
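As a usage note, the capture source is hard-coded to the default webcam (vid_file = 0); cv2.VideoCapture accepts a file path in the same slot, and persisting the annotated frames only needs a cv2.VideoWriter around the loop. A sketch of that variant, with an illustrative input path and codec (neither is part of this PR):

import cv2

cap = cv2.VideoCapture('input.mp4')  # illustrative path, not shipped with this PR
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # some streams report 0; fall back
width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
writer = cv2.VideoWriter('out.avi', cv2.VideoWriter_fourcc(*'MJPG'),
                         fps, (width, height))

while True:
    ret, image = cap.read()
    if not ret:
        break
    # ... same per-frame inference and add_joints drawing as in demo.py ...
    writer.write(image)

cap.release()
writer.release()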