diff --git a/src/training/_build_model.py b/src/training/_build_model.py index 66fe9bb..67c831e 100644 --- a/src/training/_build_model.py +++ b/src/training/_build_model.py @@ -3,7 +3,7 @@ from keras.api.layers import Input, Conv2D, BatchNormalization, LeakyReLU, GlobalAveragePooling2D, Dense, Dropout, TimeDistributed, GRU, Add, LayerNormalization, MultiHeadAttention from keras.api.optimizers import Adam from keras.api.regularizers import l2 -from keras.api.metrics import Accuracy, BinaryCrossentropy, AUC, Precision, Recall, MeanSquaredError +from keras.api.metrics import Accuracy, CategoricalCrossentropy, AUC, Precision, Recall, MeanSquaredError from src.utils.path_utils import find_project_directory from src.training._load_dataset import MAX_MOVES, BATCH_SIZE @@ -67,7 +67,7 @@ def build_model(): model.compile( optimizer=Adam(learning_rate=1e-4), loss='category_crossentropy', - metrics=[Accuracy(), BinaryCrossentropy(), AUC(), Precision(), + metrics=[Accuracy(), CategoricalCrossentropy(), AUC(), Precision(), Recall(), MeanSquaredError()] )