diff --git a/src/defense/process.py b/src/defense/process.py index 7a7d0fc..9b435f9 100644 --- a/src/defense/process.py +++ b/src/defense/process.py @@ -1,5 +1,5 @@ from utils.logging import concat_batch_images -from defense.models import Reformer, Denoiser, Motd +from defense.models import Reformer, Denoiser, Motd, ExMotd import tensorflow as tf @@ -30,7 +30,7 @@ type=str, help="Defense method", required=True, - choices=["reformer", "denoiser", "motd"], + choices=["reformer", "exformer", "denoiser", "motd", "exmotd"], ) parser.add_argument( @@ -59,19 +59,32 @@ input_shape=input_shape, intensity=args.intensity[0], ) + elif args.defense == "exformer": + defense_model = Reformer( + "defense_exformer_cifar10", + input_shape=input_shape, + intensity=args.intensity[0], + ) elif args.defense == "denoiser": defense_model = Denoiser( f"defense_denoiser_{args.dataset}", input_shape=input_shape, intensity=args.intensity[0], ) - else: + elif args.defense == "motd": defense_model = Motd( f"defense_motd_{args.dataset}", input_shape=input_shape, dataset=args.dataset, intensities=args.intensity, ) + else: + defense_model = ExMotd( + f"defense_exmotd_{args.dataset}", + input_shape=input_shape, + dataset=args.dataset, + intensities=args.intensity, + ) defense_model.compile() defense_model.load()