Version 0.3.0 - LearningRateFinder #5
LearningRateFinder
eliorc authored Jun 14, 2019
2 parents d32b53d + 74cfd7a commit ec3e459
Showing 8 changed files with 211 additions and 7 deletions.
2 changes: 1 addition & 1 deletion README.rst
@@ -28,7 +28,7 @@ Tavolo
| tavolo gathers implementations of these useful ideas from the community (by contribution, from `Kaggle`_ etc.)
and makes them accessible in a single PyPI-hosted package that complements the `tf.keras`_ module.
|
-| *Notice: tavolo is developed for TensorFlow 2.0 (right now on alpha), most modules will work with earlier versions but some won't (like LayerNorm)*
+| *Notice: tavolo is developed for TensorFlow 2.0 (right now on beta), most modules will work with earlier versions but some won't (like LayerNorm)*
Documentation
-------------
4 changes: 2 additions & 2 deletions docs/source/index.rst
@@ -11,7 +11,7 @@ Welcome to tavolo's documentation!
.. warning::

-    tavolo is developed for TensorFlow 2.0 (right now on alpha), most modules will work with earlier versions but some won't (like LayerNorm)
+    tavolo is developed for TensorFlow 2.0 (right now on beta), most modules will work with earlier versions but some won't (like LayerNorm)

Showcase
--------
@@ -43,7 +43,7 @@ Showcase


.. toctree::
-   :maxdepth: 1
+   :maxdepth: 2
    :caption: Modules

    embeddings
12 changes: 12 additions & 0 deletions docs/source/learning.rst
@@ -16,3 +16,15 @@ Modules for altering the learning process
++++++++++++++++++++++++++++++

.. automodule:: learning.CyclicLearningRateCallback

-------

.. _`learning_rate_finder`:

``LearningRateFinder``
++++++++++++++++++++++

.. automodule:: learning.LearningRateFinder

.. automethod:: learning.LearningRateFinder::scan

2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import setup

-VERSION = '0.2.1'
+VERSION = '0.3.0'

setup(name='tavolo',
version=VERSION,
2 changes: 1 addition & 1 deletion tavolo/__init__.py
@@ -1,5 +1,5 @@
__name__ = 'tavolo'
-__version__ = '0.2.1'
+__version__ = '0.3.0'

from . import embeddings
from . import normalization
144 changes: 142 additions & 2 deletions tavolo/learning.py
@@ -1,4 +1,7 @@
-from typing import Optional, Callable
+import tempfile
+import math
+from typing import Optional, Callable, Tuple, List
+from contextlib import suppress

import tensorflow as tf
import numpy as np
@@ -159,7 +162,7 @@ def _set_scale_scheme(self):
"""

# Check for supported scale schemes
if self.scale_scheme not in {'triangular', 'triangular2', 'exp_rage'}:
if self.scale_scheme not in {'triangular', 'triangular2', 'exp_range'}:
raise ValueError('{} is not a supported scale scheme'.format(self.scale_scheme))

# Set scheme
@@ -172,3 +175,140 @@ def _set_scale_scheme(self):
        elif self.scale_scheme == 'exp_range':
            self.scale_fn = lambda x: self.gamma ** x
            self.scale_mode = 'iterations'


class LearningRateFinder:
    """
    Learning rate finding utility for conducting the "LR range test", see article reference for more information

    Use the ``scan`` method for finding the loss values for learning rates in the given range

    Arguments
    ---------

    - `model` (``tf.keras.Model``): Model to conduct the test for. Must call ``model.compile`` before using this utility

    Examples
    --------

    Run a learning rate range test over the domain ``[0.0001, 1.0]``

    .. code-block:: python3

        import tensorflow as tf
        import tavolo as tvl

        train_data = ...
        train_labels = ...

        # Build model
        model = tf.keras.Sequential([tf.keras.layers.Input(shape=(784,)),
                                     tf.keras.layers.Dense(128, activation=tf.nn.relu),
                                     tf.keras.layers.Dense(10, activation=tf.nn.softmax)])

        # Must call compile with optimizer before test
        model.compile(optimizer=tf.keras.optimizers.SGD(), loss=tf.keras.losses.CategoricalCrossentropy())

        # Run learning rate range test
        lr_finder = tvl.learning.LearningRateFinder(model=model)

        learning_rates, losses = lr_finder.scan(train_data, train_labels, min_lr=0.0001, max_lr=1.0, batch_size=128)

        ### Plot the results to choose your learning rate

    References
    ----------

    - `Cyclical Learning Rates for Training Neural Networks`_

    .. _`Cyclical Learning Rates for Training Neural Networks`: https://arxiv.org/abs/1506.01186
    """

    def __init__(self, model: tf.keras.Model):
        """
        :param model: Model to conduct the test for. Must call ``model.compile`` before using this utility
        """

        self._model = model
        self._lr_range = None
        self._iteration = None
        self._learning_rates = None
        self._losses = None

    def scan(self, x, y,
             min_lr: float = 0.0001,
             max_lr: float = 1.0,
             batch_size: Optional[int] = None,
             steps: int = 100) -> Tuple[List[float], List[float]]:
        """
        Scans the learning rate range ``[min_lr, max_lr]`` for loss values

        :param x: Input data. It could be:

            - A Numpy array (or array-like), or a list of arrays (in case the model has multiple inputs)
            - A TensorFlow tensor, or a list of tensors (in case the model has multiple inputs)
            - A dict mapping input names to the corresponding arrays/tensors, if the model has named inputs
            - A ``tf.data`` dataset or a dataset iterator. Should return a tuple of either ``(inputs, targets)`` or
              ``(inputs, targets, sample_weights)``
            - A generator or ``keras.utils.Sequence`` returning ``(inputs, targets)`` or ``(inputs, targets, sample_weights)``
        :param y: Target data. Like the input data ``x``,
            it could be either Numpy array(s) or TensorFlow tensor(s).
            It should be consistent with ``x`` (you cannot have Numpy inputs and
            tensor targets, or inversely). If ``x`` is a dataset, dataset
            iterator, generator, or ``tf.keras.utils.Sequence`` instance, ``y`` should
            not be specified (since targets will be obtained from ``x``).
        :param min_lr: Minimum learning rate
        :param max_lr: Maximum learning rate
        :param batch_size: Number of samples per gradient update.
            Do not specify the ``batch_size`` if your data is in the
            form of symbolic tensors, datasets, dataset iterators,
            generators, or ``tf.keras.utils.Sequence`` instances (since they generate batches)
        :param steps: Number of steps to scan between min_lr and max_lr
        :return: Learning rates, and their corresponding loss values
        """

        # Prerequisites
        self._iteration = 0
        self._learning_rates = list()
        self._losses = list()

        # Save initial values
        initial_checkpoint = tempfile.NamedTemporaryFile()
        self._model.save_weights(initial_checkpoint.name)  # Save original weights
        initial_learning_rate = tf.keras.backend.get_value(self._model.optimizer.lr)  # Save original lr

        # Build range
        self._lr_range = np.linspace(start=min_lr, stop=max_lr, num=steps)

        # Scan
        tf.keras.backend.set_value(self._model.optimizer.lr, self._lr_range[self._iteration])
        scan_callback = tf.keras.callbacks.LambdaCallback(
            on_batch_end=lambda batch, logs: self._on_batch_end(batch, logs))

        self._model.fit(x, y, batch_size=batch_size, epochs=1, steps_per_epoch=steps,
                        verbose=0, callbacks=[scan_callback])

        # Restore initial values
        self._model.load_weights(initial_checkpoint.name)  # Restore original weights
        tf.keras.backend.set_value(self._model.optimizer.lr, initial_learning_rate)  # Restore original lr

        return self._learning_rates, self._losses

    def _on_batch_end(self, batch: int, logs: dict):
        # Save learning rate and corresponding loss
        self._learning_rates.append(
            tf.keras.backend.get_value(self._model.optimizer.lr))

        self._losses.append(
            logs['loss'])

        # Stop on exploding gradient
        if math.isnan(logs['loss']):
            self._model.stop_training = True
            return

        # Apply next learning rate
        with suppress(IndexError):
            tf.keras.backend.set_value(self._model.optimizer.lr, self._lr_range[self._iteration + 1])
            self._iteration += 1
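
The class above returns raw (learning rate, loss) pairs and leaves the plotting step to the user, as the docstring's final comment suggests. A minimal sketch of that step, assuming matplotlib is installed (the variable names mirror the docstring example; nothing here is part of tavolo's API):

```python
import matplotlib.pyplot as plt

def plot_lr_scan(learning_rates, losses):
    """Plot loss against learning rate to eyeball a good value."""
    plt.plot(learning_rates, losses)  # scan builds a linear range, so a linear axis fits
    plt.xlabel('Learning rate')
    plt.ylabel('Loss')
    plt.title('LR range test')
    plt.show()

# After: learning_rates, losses = lr_finder.scan(...)
# plot_lr_scan(learning_rates, losses)
```

A common rule of thumb from the referenced paper is to pick a rate slightly below the point where the loss starts climbing again.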
26 changes: 26 additions & 0 deletions tests/learning/cyclic_learning_rate_callback_test.py
@@ -9,6 +9,8 @@
def test_logic():
    """ Test logic on known input """

    # -------- TRIANGULAR --------

    # Input
    input_2d = np.random.normal(size=(1000, 20))
    labels = np.random.randint(low=0, high=2, size=1000)
@@ -28,3 +30,27 @@ def test_logic():
    model.fit(input_2d, labels, batch_size=10, epochs=5, callbacks=[clr], verbose=0)

    assert all(math.isclose(a, b, rel_tol=0.001) for a, b in zip(clr.history['lr'], expected_lr_values))

    # -------- TRIANGULAR2 --------

    # Create model
    model = tf.keras.Sequential([tf.keras.layers.Input(shape=(20,)),
                                 tf.keras.layers.Dense(10, activation=tf.nn.relu),
                                 tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)])
    model.compile(optimizer=tf.keras.optimizers.SGD(), loss='binary_crossentropy')

    clr = CyclicLearningRateCallback(scale_scheme='triangular2')

    model.fit(input_2d, labels, batch_size=10, epochs=5, callbacks=[clr], verbose=0)

    # -------- EXPONENT RANGE --------

    # Create model
    model = tf.keras.Sequential([tf.keras.layers.Input(shape=(20,)),
                                 tf.keras.layers.Dense(10, activation=tf.nn.relu),
                                 tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)])
    model.compile(optimizer=tf.keras.optimizers.SGD(), loss='binary_crossentropy')

    clr = CyclicLearningRateCallback(scale_scheme='exp_range')

    model.fit(input_2d, labels, batch_size=10, epochs=5, callbacks=[clr], verbose=0)
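
For context on what the two added test blocks exercise: the three scale schemes differ only in the multiplier applied to the triangular wave, matching the `scale_fn` lambdas shown in the tavolo/learning.py hunk above. A standalone sketch of the schedules from the referenced CLR paper (the parameter defaults here are illustrative, not tavolo's):

```python
import numpy as np

def cyclic_lr(iteration, base_lr=0.001, max_lr=0.006, step_size=2000,
              scheme='triangular', gamma=0.999):
    """Learning rate at a given iteration under each CLR scale scheme."""
    cycle = np.floor(1 + iteration / (2 * step_size))
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    scale = {'triangular': 1.0,                        # constant amplitude
             'triangular2': 1.0 / (2 ** (cycle - 1)),  # halve amplitude each cycle
             'exp_range': gamma ** iteration}[scheme]  # decay per iteration
    return base_lr + (max_lr - base_lr) * np.maximum(0.0, 1 - x) * scale

# Example: LR at iteration 1000 under each scheme
for scheme in ('triangular', 'triangular2', 'exp_range'):
    print(scheme, cyclic_lr(1000, scheme=scheme))
```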
26 changes: 26 additions & 0 deletions tests/learning/learning_rate_finder_test.py
@@ -0,0 +1,26 @@
import tensorflow as tf
import numpy as np

from tavolo.learning import LearningRateFinder


def test_logic():
""" Test logic on known input """

# Input
input_2d = np.random.normal(size=(10000, 20))
labels = np.random.randint(low=0, high=2, size=10000)

# Create model
model = tf.keras.Sequential([tf.keras.layers.Input(shape=(20,)),
tf.keras.layers.Dense(10, activation=tf.nn.relu),
tf.keras.layers.Dense(1, activation=tf.nn.sigmoid)])
model.compile(optimizer=tf.keras.optimizers.SGD(), loss='binary_crossentropy')

# Learning rate range test
lr_finder = LearningRateFinder(model=model)

# Run model
lrs, losses = lr_finder.scan(input_2d, labels, batch_size=50)

assert len(lrs) == len(losses)
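
The test only checks that ``scan`` records one loss per learning rate. If one wanted to turn the scan into an automatic suggestion, a small heuristic sketch (the smoothing window and the gradient-based pick are assumptions, not part of tavolo):

```python
import numpy as np

def suggest_lr(lrs, losses, window=5):
    """Suggest the learning rate where the smoothed loss falls fastest."""
    smoothed = np.convolve(losses, np.ones(window) / window, mode='valid')  # moving average
    aligned_lrs = np.asarray(lrs)[window - 1:]  # align lrs with the smoothed losses
    return aligned_lrs[int(np.argmin(np.gradient(smoothed)))]  # steepest descent point

# e.g. suggested = suggest_lr(lrs, losses) after lr_finder.scan(...)
```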
