diff --git a/train/train.py b/train/train.py index aebc7e4..cb04d8f 100644 --- a/train/train.py +++ b/train/train.py @@ -127,7 +127,7 @@ def train(): ## network logits, end_points, pyramid_map = network.get_network(FLAGS.network, image, - weight_decay=FLAGS.weight_decay) + weight_decay=FLAGS.weight_decay, is_training=True) outputs = pyramid_network.build(end_points, im_shape[1], im_shape[2], pyramid_map, num_classes=81, base_anchors=9,