Skip to content

Commit

Permalink
nitrain#39 metrics were not printed after each epoch (validation metr…
Browse files Browse the repository at this point in the history
…ics). Refactored for version 0.1.3. TODO update history.batch_metrics
  • Loading branch information
recastrodiaz committed Jul 21, 2017
1 parent 4c64f2d commit 209b830
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 29 deletions.
21 changes: 21 additions & 0 deletions tests/test_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import unittest
import torch
from torch.autograd import Variable

from torchsample.metrics import CategoricalAccuracy

class TestMetrics(unittest.TestCase):

def test_categorical_accuracy(self):
metric = CategoricalAccuracy()
predicted = Variable(torch.eye(10))
expected = Variable(torch.LongTensor(list(range(10))))
self.assertEqual(metric(predicted, expected), 100.0)

# Set 1st column to ones
predicted = Variable(torch.zeros(10, 10))
predicted.data[:, 0] = torch.ones(10)
self.assertEqual(metric(predicted, expected), 55.0)

if __name__ == '__main__':
unittest.main()
9 changes: 7 additions & 2 deletions torchsample/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ def on_epoch_end(self, epoch, logs=None):
for k, v in logs.items():
if k.endswith('metric'):
log_data[k.split('_metric')[0]] = '%.02f' % v
else:
log_data[k] = v
self.progbar.set_postfix(log_data)
self.progbar.update()
self.progbar.close()
Expand Down Expand Up @@ -199,8 +201,11 @@ def on_epoch_begin(self, epoch, logs=None):
self.samples_seen = 0.

def on_epoch_end(self, epoch, logs=None):
for k in self.batch_metrics:
self.epoch_metrics[k].append(self.batch_metrics[k])
#for k in self.batch_metrics:
# k_log = k.split('_metric')[0]
# self.epoch_metrics.update(self.batch_metrics)
# TODO
pass

def on_batch_end(self, batch, logs=None):
for k in self.batch_metrics:
Expand Down
10 changes: 2 additions & 8 deletions torchsample/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,23 +47,22 @@ def __init__(self, container):
def on_epoch_begin(self, epoch_idx, logs):
self.container.reset()


class CategoricalAccuracy(Metric):

def __init__(self, top_k=1):
self.top_k = top_k
self.correct_count = 0
self.total_count = 0
self.accuracy = 0

self._name = 'acc_metric'

def reset(self):
self.correct_count = 0
self.total_count = 0
self.accuracy = 0

def __call__(self, y_pred, y_true):
top_k = y_pred.topk(self.top_k,1)[1]
top_k = y_pred.topk(self.top_k,1)[1]
true_k = y_true.view(len(y_true),1).expand_as(top_k)
self.correct_count += top_k.eq(true_k).float().sum().data[0]
self.total_count += len(y_pred)
Expand All @@ -76,14 +75,12 @@ class BinaryAccuracy(Metric):
def __init__(self):
self.correct_count = 0
self.total_count = 0
self.accuracy = 0

self._name = 'acc_metric'

def reset(self):
self.correct_count = 0
self.total_count = 0
self.accuracy = 0

def __call__(self, y_pred, y_true):
y_pred_round = y_pred.round().long()
Expand All @@ -104,7 +101,6 @@ def __init__(self):
def reset(self):
self.corr_sum = 0.
self.total_count = 0.
self.average = 0.

def __call__(self, y_pred, y_true=None):
"""
Expand All @@ -121,14 +117,12 @@ class ProjectionAntiCorrelation(Metric):
def __init__(self):
self.anticorr_sum = 0.
self.total_count = 0.
self.average = 0.

self._name = 'anticorr_metric'

def reset(self):
self.anticorr_sum = 0.
self.total_count = 0.
self.average = 0.

def __call__(self, y_pred, y_true=None):
"""
Expand Down
51 changes: 32 additions & 19 deletions torchsample/modules/module_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(self, model):
self._loss_fn = None

# other properties
self._in_train_loop = False
self._stop_training = False

def set_loss(self, loss):
Expand Down Expand Up @@ -251,8 +250,8 @@ def fit(self,
if shuffle:
inputs, targets = fit_helper.shuffle_arrays(inputs, targets)

batch_logs = {}
for batch_idx in range(num_batches):
batch_logs = {}
callback_container.on_batch_begin(batch_idx, batch_logs)

input_batch, target_batch = fit_helper.grab_batch(batch_idx, batch_size, inputs, targets)
Expand All @@ -279,14 +278,15 @@ def fit(self,
callback_container.on_batch_end(batch_idx, batch_logs)

if has_val_data:
self._in_train_loop = True
val_epoch_logs = self.evaluate(val_inputs,
val_targets,
batch_size=batch_size,
cuda_device=cuda_device,
verbose=verbose)
self._in_train_loop = False
self.history.batch_metrics.update(val_epoch_logs)
epoch_logs.update(val_epoch_logs)
epoch_logs.update(batch_logs)
# TODO how to fix this?
# self.history.batch_metrics.update(val_epoch_logs)

callback_container.on_epoch_end(epoch_idx, epoch_logs)

Expand Down Expand Up @@ -350,10 +350,10 @@ def fit_loader(self,
epoch_logs = {}
callback_container.on_epoch_begin(epoch_idx, epoch_logs)

batch_logs = {}
loader_iter = iter(loader)
for batch_idx in range(num_batches):

batch_logs = {}
callback_container.on_batch_begin(batch_idx, batch_logs)

input_batch, target_batch = fit_helper.grab_batch_from_loader(loader_iter)
Expand All @@ -376,14 +376,15 @@ def fit_loader(self,

batch_logs['loss'] = loss.data[0]
callback_container.on_batch_end(batch_idx, batch_logs)

if has_val_data:
self._in_train_loop = True
val_epoch_logs = self.evaluate_loader(val_loader,
cuda_device=cuda_device,
verbose=verbose)
self._in_train_loop = False
self.history.batch_metrics.update(val_epoch_logs)
epoch_logs.update(val_epoch_logs)
epoch_logs.update(batch_logs)
# TODO how to fix this?
# self.history.batch_metrics.update(val_epoch_logs)

callback_container.on_epoch_end(epoch_idx, epoch_logs)

Expand All @@ -396,7 +397,7 @@ def predict(self,
batch_size=32,
cuda_device=-1,
verbose=1):
self.model.train(mode=True)
self.model.train(mode=False)
# --------------------------------------------------------
num_inputs, _ = _parse_num_inputs_and_targets(inputs, None)
len_inputs = len(inputs) if not _is_tuple_or_list(inputs) else len(inputs[0])
Expand Down Expand Up @@ -474,6 +475,10 @@ def evaluate(self,
eval_loss_fn = evaluate_helper.get_partial_loss_fn(self._loss_fn)
eval_forward_fn = evaluate_helper.get_partial_forward_fn(self.model)
eval_logs= {'val_loss': 0.}

if self._has_metrics:
metric_container = MetricContainer(self._metrics, prefix='val_')
metric_container.set_helper(evaluate_helper)

samples_seen = 0
for batch_idx in range(num_batches):
Expand All @@ -487,12 +492,14 @@ def evaluate(self,

samples_seen += batch_size
eval_logs['val_loss'] = (samples_seen*eval_logs['val_loss'] + loss.data[0]*batch_size) / (samples_seen+batch_size)

if self._has_metrics:
metric_container.reset()
metrics_logs = metric_container(output_batch, target_batch)
eval_logs.update(metrics_logs)

if self._in_train_loop:
return eval_logs
else:
return eval_logs['val_loss']
self.model.train(mode=True)
return eval_logs

def evaluate_loader(self,
loader,
Expand All @@ -509,6 +516,10 @@ def evaluate_loader(self,
eval_forward_fn = evaluate_helper.get_partial_forward_fn(self.model)
eval_logs= {'val_loss': 0.}
loader_iter = iter(loader)

if self._has_metrics:
metric_container = MetricContainer(self._metrics, prefix='val_')
metric_container.set_helper(evaluate_helper)

samples_seen = 0
for batch_idx in range(num_batches):
Expand All @@ -522,12 +533,14 @@ def evaluate_loader(self,

samples_seen += batch_size
eval_logs['val_loss'] = (samples_seen*eval_logs['val_loss'] + loss.data[0]*batch_size) / (samples_seen+batch_size)

if self._has_metrics:
metric_container.reset()
metrics_logs = metric_container(output_batch, target_batch)
eval_logs.update(metrics_logs)

if self._in_train_loop:
return eval_logs
else:
return eval_logs['val_loss']
self.model.train(mode=True)
return eval_logs

def summary(self, input_size):
def register_hook(module):
Expand Down

0 comments on commit 209b830

Please sign in to comment.