Skip to content

Commit

Permalink
add 1 test
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Oct 24, 2022
1 parent 12e4fc1 commit b70c3be
Showing 1 changed file with 102 additions and 24 deletions.
126 changes: 102 additions & 24 deletions tests/bn_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from tensorflow.keras.backend import clear_session
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import metrics
import pytest

from qkeras import QConv2DBatchnorm
from qkeras import QConv2D
from qkeras import QDenseBatchnorm
from qkeras import QDense
from qkeras import QActivation
from qkeras import QDepthwiseConv2D
Expand Down Expand Up @@ -110,7 +112,7 @@ def get_qconv2d_batchnorm_model(input_shape, kernel_size, folding_mode,
return model


def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
def get_conv2d_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):

x_shape = (2, 2, 1)
loss_fn = tf.keras.losses.MeanSquaredError()
Expand Down Expand Up @@ -164,6 +166,60 @@ def get_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):
return (unfold_model, fold_model)


def get_dense_models_with_one_layer(kernel_quantizer, folding_mode, ema_freeze_delay):

x_shape = (4,)
loss_fn = tf.keras.losses.MeanSquaredError()
optimizer = get_sgd_optimizer(learning_rate=1e-3)

# define a model with seperate conv2d and bn layers
x = x_in = layers.Input(x_shape, name="input")
x = QDense(
2,
kernel_initializer="ones",
bias_initializer="zeros", use_bias=False,
kernel_quantizer=kernel_quantizer, bias_quantizer=None,
name="conv2d")(x)
x = layers.BatchNormalization(
axis=-1,
momentum=0.99,
epsilon=0.001,
center=True,
scale=True,
beta_initializer="zeros",
gamma_initializer="ones",
moving_mean_initializer="zeros",
moving_variance_initializer="ones",
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
fused=None,
trainable=True,
virtual_batch_size=None,
adjustment=None,
name="bn")(x)
unfold_model = Model(inputs=[x_in], outputs=[x])
unfold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc")

x = x_in = layers.Input(x_shape, name="input")
x = QDenseBatchnorm(
2,
kernel_initializer="ones", bias_initializer="zeros", use_bias=False,
kernel_quantizer=kernel_quantizer, beta_initializer="zeros",
gamma_initializer="ones", moving_mean_initializer="zeros",
moving_variance_initializer="ones", folding_mode=folding_mode,
ema_freeze_delay=ema_freeze_delay,
name="foldconv2d")(x)
fold_model = Model(inputs=[x_in], outputs=[x])
fold_model.compile(loss=loss_fn, optimizer=optimizer, metrics="acc")

return (unfold_model, fold_model)


def get_debug_model(model):
layer_output_list = []
for layer in model.layers:
Expand All @@ -181,10 +237,7 @@ def generate_dataset(train_size=10,
output_shape=None):
"""create tf.data.Dataset with shape: (N,) + input_shape."""

x_train = np.random.randint(
4, size=(train_size, input_shape[0], input_shape[1], input_shape[2]))
x_train = np.random.rand(
train_size, input_shape[0], input_shape[1], input_shape[2])
x_train = np.random.rand(*(train_size,) + input_shape)

if output_shape:
y_train = np.random.random_sample((train_size,) + output_shape)
Expand Down Expand Up @@ -397,31 +450,48 @@ def test_loading():
assert_equal(weight1[1], weight2[1])


def test_same_training_and_prediction():
@pytest.mark.parametrize("model_name", ["conv2d", "dense"])
def test_same_training_and_prediction(model_name):
"""test if fold/unfold layer has the same training and prediction output."""

epochs = 5
loss_fn = tf.keras.losses.MeanSquaredError()
loss_metric = metrics.Mean()
optimizer = get_sgd_optimizer(learning_rate=1e-3)

x_shape = (2, 2, 1)
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
if model_name == "conv2d":
x_shape = (2, 2, 1)
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
elif model_name == "dense":
x_shape = (4,)
kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
iteration = np.array(-1)

train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape,
num_class=2)

(unfold_model, fold_model_batch) = get_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)
if model_name == "conv2d":
(unfold_model, fold_model_batch) = get_conv2d_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_conv2d_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)
elif model_name == "dense":
(unfold_model, fold_model_batch) = get_dense_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_dense_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)

unfold_model.layers[1].set_weights([kernel])
unfold_model.layers[2].set_weights(
Expand Down Expand Up @@ -455,12 +525,20 @@ def test_same_training_and_prediction():
# models should be different, but the two folding modes should be the same
epochs = 5
iteration = np.array(8)
(unfold_model, fold_model_batch) = get_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)
if model_name == "conv2d":
(unfold_model, fold_model_batch) = get_conv2d_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_conv2d_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)
elif model_name == "dense":
(unfold_model, fold_model_batch) = get_dense_models_with_one_layer(
kernel_quantizer=None, folding_mode="batch_stats_folding",
ema_freeze_delay=10)
(_, fold_model_ema) = get_dense_models_with_one_layer(
kernel_quantizer=None, folding_mode="ema_stats_folding",
ema_freeze_delay=10)
unfold_model.layers[1].set_weights([kernel])
unfold_model.layers[2].set_weights(
[gamma, beta, moving_mean, moving_variance])
Expand Down

0 comments on commit b70c3be

Please sign in to comment.