Skip to content

Commit

Permalink
Merge pull request fastai#171 from cbarokk/master
Browse files Browse the repository at this point in the history
allow to pass a stepper class to fit()
  • Loading branch information
jph00 authored Feb 23, 2018
2 parents 1e74dde + a79dd2b commit 34ea07e
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions fastai/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def set_train_mode(m):
else: m.train()


def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, **kwargs):
def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, stepper=Stepper, **kwargs):
""" Fits a model
Arguments:
Expand All @@ -72,7 +72,7 @@ def fit(model, data, epochs, opt, crit, metrics=None, callbacks=None, **kwargs):
epochs(int): number of epochs
crit: loss function to optimize. Example: F.cross_entropy
"""
stepper = Stepper(model, opt, crit, **kwargs)
stepper = stepper(model, opt, crit, **kwargs)
metrics = metrics or []
callbacks = callbacks or []
avg_mom=0.98
Expand Down

0 comments on commit 34ea07e

Please sign in to comment.