Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expectation Over Transformation Wrapper #719

Merged
merged 4 commits into from
Mar 4, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 129 additions & 0 deletions examples/eot_attack_pytorch_resnet18.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
A simple example that demonstrates how to run Expectation over Transformation
coupled with any attack, on a Resnet-18 PyTorch model.
"""
from typing import Any

import torch
from torch import Tensor
import torchvision.models as models
import torchvision.transforms as transforms
import eagerpy as ep
from foolbox import PyTorchModel, accuracy, samples
from foolbox.attacks import LinfPGD
from foolbox.models import ExpectationOverTransformationWrapper


class RandomizedResNet18(torch.nn.Module):
def __init__(self) -> None:

super().__init__()

# base model
self.model = models.resnet18(pretrained=True)

# random apply rotation
self.transforms = transforms.RandomRotation(degrees=25)

def forward(self, x: Tensor) -> Any:

# random transform
x = self.transforms(x)

return self.model(x)


def main() -> None:
# instantiate a model (could also be a TensorFlow or JAX model)
model = models.resnet18(pretrained=True).eval()
preprocessing = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], axis=-3)
fmodel = PyTorchModel(model, bounds=(0, 1), preprocessing=preprocessing)

# get data and test the model
# wrapping the tensors with ep.astensors is optional, but it allows
# us to work with EagerPy tensors in the following
images, labels = ep.astensors(*samples(fmodel, dataset="imagenet", batchsize=16))

print("Testing attack on the base model (no transformations applied)")
clean_acc = accuracy(fmodel, images, labels)
print(f"clean accuracy: {clean_acc * 100:.1f} %")

# apply an attack with different eps
attack = LinfPGD()
epsilons = [
0.0,
0.0002,
0.0005,
0.0008,
0.001,
0.0015,
0.002,
0.003,
0.01,
0.02,
0.03,
0.1,
0.3,
0.5,
1.0,
]

raw_advs, clipped_advs, success = attack(fmodel, images, labels, epsilons=epsilons)

# calculate and report the robust accuracy (the accuracy of the model when
# it is attacked)
robust_accuracy = 1 - success.float32().mean(axis=-1)
print("robust accuracy for perturbations with")
for eps, acc in zip(epsilons, robust_accuracy):
print(f" Linf norm ≤ {eps:<6}: {acc.item() * 100:4.1f} %")

# Let's apply the same LinfPGD attack, but on a model with random transformations
rand_model = RandomizedResNet18().eval()
fmodel = PyTorchModel(rand_model, bounds=(0, 1), preprocessing=preprocessing)
seed = 1111

print("#" * 40)
print("Testing attack on the randomized model (random rotation applied)")

# Note: accuracy may slightly decrease, depending on seed
torch.manual_seed(seed)
clean_acc = accuracy(fmodel, images, labels)
print(f"clean accuracy: {clean_acc * 100:.1f} %")

# test the base attack on the randomized model
print("robust accuracy for perturbations with")
for eps in epsilons:

# reset seed to have the same perturbations in each attack
torch.manual_seed(seed)
_, _, success = attack(fmodel, images, labels, epsilons=eps)

# calculate and report the robust accuracy
# the attack is performing worse on the randomized models, since gradient computation is affected!
robust_accuracy = 1 - success.float32().mean(axis=-1)
print(f" Linf norm ≤ {eps:<6}: {robust_accuracy.item() * 100:4.1f} %")

# Now, Let's use Expectation Over Transformation to counter the randomization
eot_model = ExpectationOverTransformationWrapper(fmodel, n_steps=16)

print("#" * 40)
print("Testing EoT attack on the randomized model (random crop applied)")
torch.manual_seed(seed)
clean_acc = accuracy(eot_model, images, labels)
print(f"clean accuracy: {clean_acc * 100:.1f} %")

print("robust accuracy for perturbations with")
for eps in epsilons:
# reset seed to have the same perturbations in each attack
torch.manual_seed(seed)
_, _, success = attack(eot_model, images, labels, epsilons=eps)

# calculate and report the robust accuracy
# with EoT, the base attack is working again!
robust_accuracy = 1 - success.float32().mean(axis=-1)
print(f" Linf norm ≤ {eps:<6}: {robust_accuracy.item() * 100:4.1f} %")


if __name__ == "__main__":
main()
1 change: 1 addition & 0 deletions foolbox/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,4 @@
from .numpy import NumPyModel # noqa: F401

from .wrappers import ThresholdingWrapper # noqa: F401
from .wrappers import ExpectationOverTransformationWrapper # noqa: F401
26 changes: 26 additions & 0 deletions foolbox/models/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,29 @@ def __call__(self, inputs: T) -> T:
y = ep.where(x < self._threshold, min_, max_).astype(x.dtype)
z = self._model(y)
return restore_type(z)


class ExpectationOverTransformationWrapper(Model):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not mentioning this earlier, but can you please add docstrings to the class & module like its done for the other classes in the project?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added it to the models.rst file, taking the other wrapper module as an example.
However, I am not sure how to fix the version error for sphinx. In my local machine, I tried to build html docs with sphinx version 5 and it works, but the requirements.txt file in the docs folder wants:

sphinx==4.5.0
sphinx-autobuild==2021.3.14
sphinx_rtd_theme==1.0.0
sphinx-typlog-theme==0.8.0

def __init__(self, model: Model, n_steps: int = 16):
self._model = model
self._n_steps = n_steps

@property
def bounds(self) -> Bounds:
return self._model.bounds

def __call__(self, inputs: T) -> T:

x, restore_type = ep.astensor_(inputs)

for i in range(self._n_steps):
z_t = self._model(x)

if i == 0:
z = z_t.expand_dims(0)
else:
z = ep.concatenate([z, z_t.expand_dims(0)], axis=0)

z = z.mean(0)

return restore_type(z)
53 changes: 53 additions & 0 deletions tests/test_eot_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import pytest

import eagerpy as ep

from foolbox import accuracy
from foolbox.attacks import (
LinfBasicIterativeAttack,
L1BasicIterativeAttack,
L2BasicIterativeAttack,
)
from foolbox.models import ExpectationOverTransformationWrapper
from foolbox.types import L2, Linf

from conftest import ModeAndDataAndDescription


def test_eot_wrapper(
fmodel_and_data_ext_for_attacks: ModeAndDataAndDescription,
) -> None:

(fmodel, x, y), real, low_dimensional_input = fmodel_and_data_ext_for_attacks

if isinstance(x, ep.NumPyTensor):
pytest.skip()

# test clean accuracy when wrapping EoT
x = (x - fmodel.bounds.lower) / (fmodel.bounds.upper - fmodel.bounds.lower)
fmodel = fmodel.transform_bounds((0, 1))
acc = accuracy(fmodel, x, y)

rand_model = ExpectationOverTransformationWrapper(fmodel, n_steps=4)
rand_acc = accuracy(rand_model, x, y)
assert acc - rand_acc == 0

# test with base attacks
# (accuracy should not change, since fmodel is not random)
attacks = (
L1BasicIterativeAttack(),
L2BasicIterativeAttack(),
LinfBasicIterativeAttack(),
)
epsilons = (5000.0, L2(50.0), Linf(1.0))

for attack, eps in zip(attacks, epsilons):

# acc on standard model
advs, _, _ = attack(fmodel, x, y, epsilons=eps)
adv_acc = accuracy(fmodel, advs, y)

# acc on eot model
advs, _, _ = attack(rand_model, x, y, epsilons=eps)
r_adv_acc = accuracy(rand_model, advs, y)
assert adv_acc - r_adv_acc == 0
Loading