Skip to content

Commit

Permalink
increase epochs; set seed
Browse files Browse the repository at this point in the history
  • Loading branch information
jmduarte committed Oct 25, 2022
1 parent bf9ccee commit 2a787ff
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions tests/autoqkeras_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import tempfile
import numpy as np
import pytest
import random
from sklearn.datasets import load_iris
from sklearn.preprocessing import MinMaxScaler
import tensorflow.compat.v2 as tf
Expand Down Expand Up @@ -64,8 +65,10 @@ def dense_model():

def test_autoqkeras():
"""Tests AutoQKeras scheduler."""
np.random.seed(42)
tf.random.set_seed(42)
seed = 42
random.seed(seed)
np.random.seed(seed)
tf.random.set_seed(seed)

x_train, y_train = load_iris(return_X_y=True)

Expand Down Expand Up @@ -143,14 +146,14 @@ def test_autoqkeras():
}

autoqk = AutoQKerasScheduler(model, metrics=["acc"], **run_config)
autoqk.fit(x_train, y_train, validation_split=0.1, batch_size=150, epochs=4)
autoqk.fit(x_train, y_train, validation_split=0.1, batch_size=150, epochs=8)

qmodel = autoqk.get_best_model()

optimizer = get_adam_optimizer(learning_rate=0.01)
qmodel.compile(optimizer=optimizer, loss="categorical_crossentropy",
metrics=["acc"])
history = qmodel.fit(x_train, y_train, epochs=5, batch_size=150,
history = qmodel.fit(x_train, y_train, epochs=10, batch_size=150,
validation_split=0.1)

quantized_acc = history.history["acc"][-1]
Expand Down

0 comments on commit 2a787ff

Please sign in to comment.