
A PyTorch Implementation of Single Shot MultiBox Detector.

chenyuntc/ssd.pytorch

SSD: Single Shot MultiBox Object Detector, in PyTorch

A PyTorch implementation of Single Shot MultiBox Detector from the 2016 paper by Wei Liu, Dragomir Anguelov, Dumitru Erhan, Christian Szegedy, Scott Reed, Cheng-Yang Fu, and Alexander C. Berg. The official and original Caffe code can be found here.


Installation

  • Install PyTorch by selecting your environment on the PyTorch website (pytorch.org) and running the appropriate command.
  • Clone this repository.
    • Note: We currently only support Python 3+.
  • Then download the dataset by following the instructions below.
  • We now support Visdom for real-time loss visualization during training!
    • To use Visdom in the browser:
    # First install the Python server and client
    pip install visdom
    # Start the server (probably in a screen or tmux)
    python -m visdom.server
    • Then (during training) navigate to http://localhost:8097/ (see the Train section below for training details).
  • Note: For training, we currently only support VOC, but are adding COCO and hopefully ImageNet soon.
  • UPDATE: We have switched from PIL Image support to cv2 as it is more accurate and significantly faster.

Datasets

To make things easy, we provide a simple VOC dataset loader that inherits torch.utils.data.Dataset, making it fully compatible with the torchvision.datasets API.
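Internally, a loader like this has to turn each annotation XML file into box targets. The following is a minimal sketch using only the standard library's xml.etree.ElementTree, assuming the standard VOC XML layout (size/width, size/height, and one object element per box). The helper parse_voc_annotation, the normalized [xmin, ymin, xmax, ymax, label_idx] output, and the class-index order are illustrative assumptions, not necessarily what this repository's loader does.

```python
import xml.etree.ElementTree as ET

# The 20 PASCAL VOC classes; using alphabetical order for label indices is an
# assumption made for this sketch.
VOC_CLASSES = (
    "aeroplane", "bicycle", "bird", "boat", "bottle",
    "bus", "car", "cat", "chair", "cow",
    "diningtable", "dog", "horse", "motorbike", "person",
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",
)
CLASS_TO_IDX = {name: i for i, name in enumerate(VOC_CLASSES)}


def parse_voc_annotation(xml_text, keep_difficult=False):
    """Return [xmin, ymin, xmax, ymax, label_idx] per object, with
    coordinates normalized to [0, 1] by the image width and height."""
    root = ET.fromstring(xml_text)
    size = root.find("size")
    width = float(size.find("width").text)
    height = float(size.find("height").text)

    targets = []
    for obj in root.iter("object"):
        difficult_node = obj.find("difficult")
        difficult = difficult_node is not None and int(difficult_node.text) == 1
        if difficult and not keep_difficult:
            continue  # skip objects flagged as hard to recognize
        name = obj.find("name").text.strip().lower()
        bbox = obj.find("bndbox")
        # VOC pixel coordinates are 1-based; shift to 0-based, then normalize
        xmin = (float(bbox.find("xmin").text) - 1) / width
        ymin = (float(bbox.find("ymin").text) - 1) / height
        xmax = (float(bbox.find("xmax").text) - 1) / width
        ymax = (float(bbox.find("ymax").text) - 1) / height
        targets.append([xmin, ymin, xmax, ymax, CLASS_TO_IDX[name]])
    return targets
```

Skipping difficult objects by default mirrors the VOC evaluation protocol, which excludes them from scoring.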

VOC Dataset

Download VOC2007 trainval & test
# specify a directory for dataset to be downloaded into, else default is ~/data/
sh data/scripts/VOC2007.sh # <directory>
Download VOC2012 trainval
# specify a directory for dataset to be downloaded into, else default is ~/data/
sh data/scripts/VOC2012.sh # <directory>

Ensure the following directory structure (as specified in VOCdevkit):

VOCdevkit/                                  % development kit
VOCdevkit/VOC2007/ImageSets                 % image sets
VOCdevkit/VOC2007/Annotations               % annotation files
VOCdevkit/VOC2007/JPEGImages                % images
VOCdevkit/VOC2007/SegmentationObject        % segmentations by object
VOCdevkit/VOC2007/SegmentationClass         % segmentations by class
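Before training, it can help to confirm that this layout is in place. A small sketch with pathlib; missing_voc_dirs is a hypothetical helper, and ~/data is the download scripts' default directory.

```python
from pathlib import Path

# Subdirectories expected under VOCdevkit/VOC2007, per the layout listed above.
EXPECTED_SUBDIRS = (
    "ImageSets",
    "Annotations",
    "JPEGImages",
    "SegmentationObject",
    "SegmentationClass",
)


def missing_voc_dirs(root, year="VOC2007"):
    """Return the expected subdirectories that do not exist under root/VOCdevkit/<year>."""
    base = Path(root).expanduser() / "VOCdevkit" / year
    return [name for name in EXPECTED_SUBDIRS if not (base / name).is_dir()]


if __name__ == "__main__":
    # ~/data/ is where the download scripts put the dataset by default
    missing = missing_voc_dirs("~/data")
    print("OK" if not missing else f"Missing: {missing}")
```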

Training SSD

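  • First download the fc-reduced VGG-16 base network weights into the weights/ directory, then return to the repository root:
mkdir weights
cd weights
wget https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
cd ..
  • To train SSD, run the train script. Any parameter listed in train.py can be set as a command-line flag or changed directly in the file.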
python train.py
  • Training Parameter Options:
import argparse

parser = argparse.ArgumentParser(description='Single Shot MultiBox Detector Training')
parser.add_argument('--version', default='v2', help='conv11_2 (v2) or pool6 (v1) as last layer')
parser.add_argument('--basenet', default='vgg16_reducedfc.pth', help='pretrained base model')
parser.add_argument('--jaccard_threshold', default=0.5, type=float, help='Min Jaccard index for matching')
parser.add_argument('--batch_size', default=32, type=int, help='Batch size for training')
parser.add_argument('--num_workers', default=4, type=int, help='Number of workers used in dataloading')
parser.add_argument('--iterations', default=120000, type=int, help='Number of training iterations')
# Caution: type=bool turns any non-empty string (even 'False') into True, so to
# disable a boolean option pass an empty string (e.g. --cuda '') or edit its default.
parser.add_argument('--cuda', default=True, type=bool, help='Use CUDA to train model')
parser.add_argument('--lr', '--learning-rate', default=1e-3, type=float, help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float, help='momentum')
parser.add_argument('--weight_decay', default=5e-4, type=float, help='Weight decay for SGD')
parser.add_argument('--gamma', default=0.1, type=float, help='Gamma update for SGD')
parser.add_argument('--log_iters', default=True, type=bool, help='Print the loss at each iteration')
parser.add_argument('--visdom', default=False, type=bool, help='Use Visdom for loss visualization')
parser.add_argument('--save_folder', default='weights/', help='Location to save checkpoint models')
args = parser.parse_args()
  • Note:
    • For training, an NVIDIA GPU is strongly recommended for speed.
    • Currently we only support training on v2 (the newest version).
    • For instructions on Visdom usage/installation, see the Installation section.
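The --jaccard_threshold option controls how default boxes (priors) are assigned to ground-truth objects during training. The sketch below illustrates the matching rule described in the SSD paper in plain Python: a prior becomes positive for the ground truth it overlaps best if that Jaccard index (IoU) reaches the threshold, and each ground truth is additionally paired with its single best prior. The names jaccard and match_priors and the tuple box format are assumptions for illustration; a practical implementation would vectorize this over tensors.

```python
def jaccard(box_a, box_b):
    """Intersection over union (Jaccard index) of two (xmin, ymin, xmax, ymax) boxes."""
    inter_w = max(0.0, min(box_a[2], box_b[2]) - max(box_a[0], box_b[0]))
    inter_h = max(0.0, min(box_a[3], box_b[3]) - max(box_a[1], box_b[1]))
    inter = inter_w * inter_h
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


def match_priors(priors, truths, threshold=0.5):
    """Assign each prior the index of a ground-truth box, or -1 for background."""
    matches = []
    for prior in priors:
        overlaps = [jaccard(prior, truth) for truth in truths]
        if overlaps and max(overlaps) >= threshold:
            matches.append(overlaps.index(max(overlaps)))
        else:
            matches.append(-1)
    # Also pair every ground truth with its single best prior, even below the
    # threshold, so that no object is left without a positive match.
    if priors:
        for t_idx, truth in enumerate(truths):
            best_prior = max(range(len(priors)), key=lambda p: jaccard(priors[p], truth))
            matches[best_prior] = t_idx
    return matches
```

For example, with priors [(0, 0, 1, 1), (0, 0, 0.5, 0.5)] and one ground truth (0, 0, 0.5, 0.5), only the second prior is matched, because the first overlaps it with an IoU of 0.25.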

Evaluation

To evaluate a trained network:

python test.py

The parameters listed in test.py can be passed as command-line flags or changed directly in the file.
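The VOC2007 mAP figures are the mean, over the 20 classes, of each class's average precision (AP). The sketch below computes the 11-point interpolated AP used by the VOC2007 protocol, assuming per-class recall and precision values have already been computed from detections ranked by confidence. voc07_ap is a hypothetical name; test.py's own evaluation code may be organized differently.

```python
def voc07_ap(recall, precision):
    """PASCAL VOC2007 11-point interpolated average precision.

    recall and precision are equal-length sequences, one entry per detection,
    accumulated over detections sorted by descending confidence.
    """
    ap = 0.0
    for i in range(11):
        t = i / 10.0  # recall thresholds 0.0, 0.1, ..., 1.0
        # Interpolated precision: best precision at any recall >= t (0 if none)
        candidates = [p for r, p in zip(recall, precision) if r >= t]
        ap += (max(candidates) if candidates else 0.0) / 11.0
    return ap
```

Computing each threshold as i / 10.0 rather than by repeatedly adding 0.1 keeps the thresholds exact (0.3 instead of 0.30000000000000004).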

Performance (In progress)

VOC2007 Test

mAP

Original | Test (weiliu89 weights) | Train (w/o data aug) and Test*
77.2 %   | 77.26 %                 | 50.8 %*

* Note: trained with a constant learning rate of 1e-3 and the default training parameters. With proper adjustment, this should increase dramatically, even without data augmentation.
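One common form of such adjustment is step decay: hold the learning rate, then multiply it by a fixed factor at chosen iterations. A minimal sketch, assuming --gamma acts as that multiplicative factor; the milestone iterations are hypothetical, not this repository's settings.

```python
def step_lr(base_lr, gamma, iteration, milestones):
    """Return base_lr multiplied by gamma once for each milestone already reached."""
    num_decays = sum(1 for m in milestones if iteration >= m)
    return base_lr * gamma ** num_decays


# Hypothetical milestones; base_lr and gamma match the --lr and --gamma defaults.
MILESTONES = (80000, 100000)

if __name__ == "__main__":
    for it in (0, 79999, 80000, 100000, 119999):
        print(f"iteration {it:>6}: lr = {step_lr(1e-3, 0.1, it, MILESTONES):.0e}")
```

With PyTorch, the returned value would typically be written into every entry of optimizer.param_groups (setting group['lr']) at the start of each iteration.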

FPS

GTX 1060: ~45.45 FPS for detection on a single image
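A per-image FPS figure is the number of single-image forward passes completed per second. One way to measure it, sketched with the standard library's time.perf_counter; measure_fps and its warm-up and iteration counts are assumptions, not how the figure above was produced.

```python
import time


def measure_fps(run_once, warmup=5, iterations=50):
    """Average calls per second of a single-image callable.

    Warm-up calls are excluded so one-time costs (lazy initialization,
    autotuning, cold caches) do not skew the result.
    """
    for _ in range(warmup):
        run_once()
    start = time.perf_counter()
    for _ in range(iterations):
        run_once()
    elapsed = time.perf_counter() - start
    return iterations / elapsed if elapsed > 0 else float("inf")


if __name__ == "__main__":
    # Stand-in workload; replace with a forward pass of the network on one image.
    print(f"{measure_fps(lambda: sum(range(100_000))):.2f} FPS")
```

On a GPU, CUDA kernels run asynchronously, so the measured callable should end with torch.cuda.synchronize(); otherwise the timer can stop before the work is done.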

Demos

Use a pre-trained SSD network for detection

Download a pre-trained network

SSD results on multiple datasets

Try the demo notebook

  • Make sure you have Jupyter Notebook installed.
  • Two alternatives for installing Jupyter Notebook:
    1. If you installed PyTorch with conda (recommended), you should already have it. Just navigate to the cloned ssd.pytorch repo and run: jupyter notebook

    2. If using pip:

# make sure pip is upgraded
pip3 install --upgrade pip
# install jupyter notebook
pip install jupyter
# Run this inside ssd.pytorch
jupyter notebook

Try the webcam demo

  • Works on CPU (you may have to tweak cv2.waitKey for optimal FPS) or on an NVIDIA GPU
  • This demo requires OpenCV 2+ with Python bindings and an onboard webcam
    • You can change the default webcam in live_demo.py
  • Running python live_demo.py opens the webcam and begins detecting!

TODO

We have accumulated the following to-do list, which you can expect to be done in the very near future:

  • In progress:
    • Complete data augmentation (progress in augmentation branch)
    • Produce a purely PyTorch mAP matching the original Caffe result
  • Still to come:
    • Train SSD300 with batch norm
    • Add support for SSD512 training and testing
    • Add support for COCO dataset
    • Create a functional model definition for Sergey Zagoruyko's functional-zoo

