From deb29280ff5047d69af5763a8910f45805408f92 Mon Sep 17 00:00:00 2001 From: beta Date: Sun, 1 May 2022 21:26:42 +0900 Subject: [PATCH] (#3) Defense: Fix SLQ Layer to compat 1-channel images --- src/utils/layers.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/utils/layers.py b/src/utils/layers.py index d525b43..b8bf06d 100644 --- a/src/utils/layers.py +++ b/src/utils/layers.py @@ -28,6 +28,9 @@ def compress(image): n, m, c = image.shape + if c == 1: + image = tf.image.grayscale_to_rgb(image) + patch_n = tf.cast(n / self.patch_size, dtype=tf.int32) + tf.cond( tf.constant(n % self.patch_size > 0, tf.bool), lambda: one, @@ -72,8 +75,13 @@ def compress(image): name="compressed_images", ) - return keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)( + result = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255)( tf.gather_nd(x_compressed_stack, indices, name="final_image") ) + if c == 1: + result = tf.image.rgb_to_grayscale(result) + + return result + return tf.map_fn(fn=compress, elems=inputs)