Skip to content

Commit

Permalink
(#6) Attack: Separate attack.py
Browse files Browse the repository at this point in the history
  • Loading branch information
betarixm committed Mar 31, 2022
1 parent c1a8fdf commit 75f2eb2
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 15 deletions.
64 changes: 64 additions & 0 deletions src/attack/attack.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from typings.models import Attack
from models import Fgsm
from defense.models import Reformer
from victim.models import Classifier

from utils.dataset import Mnist, Cifar10

import argparse
import tensorflow as tf


keras = tf.keras

if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Attack and defense pretrained models."
)

parser.add_argument(
"--dataset",
"-d",
metavar="DATASET",
type=str,
help="Dataset for training",
required=True,
choices=["mnist", "cifar10"],
)

parser.add_argument(
"--defense",
"-f",
help="Use defense model",
action="store_true",
)

args = parser.parse_args()

a: Attack

if args.dataset == "mnist":
a = Fgsm(
Classifier(name="victim_classifier_mnist", input_shape=(28, 28, 1)),
Mnist(),
defense_model=Reformer("defense_reformer_mnist", input_shape=(28, 28, 1))
if args.defense
else None,
)
elif args.dataset == "cifar10":
a = Fgsm(
Classifier(name="victim_classifier_cifar10", input_shape=(32, 32, 3)),
Cifar10(),
defense_model=Reformer("defense_reformer_cifar10", input_shape=(32, 32, 3))
if args.defense
else None,
)

acc, acc_under_attack, acc_with_defense = a.attack()

print(f"[*] Attack {args.dataset.upper()} {'with defense' if args.defense else ''}")
print(f" - Normal: {acc.result()}")
print(f" - Under Attack: {acc_under_attack.result()}")

if args.defense:
print(f" - With Defense: {acc_with_defense.result()}")
15 changes: 0 additions & 15 deletions src/attack/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@

from cleverhans.tf2.attacks.fast_gradient_method import fast_gradient_method

from victim.models import Classifier
from defense.models import Reformer

from utils.dataset import Cifar10

import numpy as np
import tensorflow as tf

Expand All @@ -16,13 +11,3 @@
class Fgsm(Attack):
def add_perturbation(self, x: np.array) -> np.array:
return fast_gradient_method(self.victim_model.model(), x, 0.05, np.inf)


if __name__ == "__main__":
f = Fgsm(
Classifier(name="victim_classifier_cifar10", input_shape=(32, 32, 3)),
Cifar10(),
defense_model=Reformer("defense_reformer_cifar10", input_shape=(32, 32, 3)),
)
acc, acc_under_attack, acc_with_defense = f.attack()
print(acc.result(), acc_under_attack.result(), acc_with_defense.result())

0 comments on commit 75f2eb2

Please sign in to comment.