Skip to content

Commit

Permalink
Update bn_folding_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte authored Oct 25, 2022
1 parent 2a787ff commit 224797b
Showing 1 changed file with 4 additions and 8 deletions.
12 changes: 4 additions & 8 deletions tests/bn_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,17 +462,13 @@ def test_same_training_and_prediction(model_name):
if model_name == "conv2d":
x_shape = (2, 2, 1)
kernel = np.array([[[[1., 1.]], [[1., 0.]]], [[[1., 1.]], [[0., 1.]]]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
elif model_name == "dense":
x_shape = (4,)
kernel = np.array([[1., 1.], [1., 0.], [1., 1.], [0., 1.]])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
gamma = np.array([2., 1.])
beta = np.array([0., 1.])
moving_mean = np.array([1., 1.])
moving_variance = np.array([1., 2.])
iteration = np.array(-1)

train_ds = generate_dataset(train_size=10, batch_size=10, input_shape=x_shape,
Expand Down

0 comments on commit 224797b

Please sign in to comment.