From 02c363f56b6302cf181f28cff4687b0be4f5fedf Mon Sep 17 00:00:00 2001 From: Siddharth Rao Date: Thu, 20 Jun 2024 21:02:02 -0700 Subject: [PATCH] Changed Metric To Categorical Cross Entropy --- src/training/_build_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/training/_build_model.py b/src/training/_build_model.py index 66fe9bb..67c831e 100644 --- a/src/training/_build_model.py +++ b/src/training/_build_model.py @@ -3,7 +3,7 @@ from keras.api.layers import Input, Conv2D, BatchNormalization, LeakyReLU, GlobalAveragePooling2D, Dense, Dropout, TimeDistributed, GRU, Add, LayerNormalization, MultiHeadAttention from keras.api.optimizers import Adam from keras.api.regularizers import l2 -from keras.api.metrics import Accuracy, BinaryCrossentropy, AUC, Precision, Recall, MeanSquaredError +from keras.api.metrics import Accuracy, CategoricalCrossentropy, AUC, Precision, Recall, MeanSquaredError from src.utils.path_utils import find_project_directory from src.training._load_dataset import MAX_MOVES, BATCH_SIZE @@ -67,7 +67,7 @@ def build_model(): model.compile( optimizer=Adam(learning_rate=1e-4), loss='category_crossentropy', - metrics=[Accuracy(), BinaryCrossentropy(), AUC(), Precision(), + metrics=[Accuracy(), CategoricalCrossentropy(), AUC(), Precision(), Recall(), MeanSquaredError()] )