Commit e468a9a

QAT anomaly example

alexsu52 committed Mar 27, 2024
1 parent f2f3bb7

Showing 5 changed files with 241 additions and 0 deletions.
37 changes: 37 additions & 0 deletions examples/quantization_aware_training/torch/anomalib/README.md
@@ -0,0 +1,37 @@
# Quantization-Aware Training of STFPM PyTorch model from Anomalib

Anomaly detection is one of the domains in which models are deployed in scenarios where the cost of a model error is high and accuracy cannot be sacrificed for faster inference. Quantization-Aware Training (QAT) is a perfect fit for such cases: it reduces the quantization error by training the model, so the speedup from quantization comes without an accuracy penalty.

This example demonstrates how to quantize the [Student-Teacher Feature Pyramid Matching (STFPM)](https://anomalib.readthedocs.io/en/latest/markdown/guides/reference/models/image/stfpm.html) PyTorch model from [Anomalib](https://github.com/openvinotoolkit/anomalib) using the Quantization API from the Neural Network Compression Framework (NNCF). In the first step, the model is quantized using the Post-Training Quantization (PTQ) algorithm to obtain the best initialization of the quantized model. If the accuracy of the quantized model after PTQ does not meet requirements, the next step is to train the quantized model using the PyTorch framework.

NNCF provides a seamless transition from Post-Training Quantization to Quantization-Aware Training, without additional model preparation or the transfer of magic parameters, as sketched below.
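
Below is a minimal sketch of this PTQ-to-QAT workflow, mirroring `main.py` from this example. Here `model` is assumed to be a trained PyTorch model, `calibration_dataset` an `nncf.Dataset`, and `train_one_epoch`/`train_loader` hypothetical stand-ins for a regular PyTorch training loop and data loader:

```python
import nncf

# PTQ: obtain a well-initialized INT8 model from the trained FP32 model.
quantized_model = nncf.quantize(model=model, calibration_dataset=calibration_dataset)

# QAT: if the PTQ accuracy is not sufficient, fine-tune the quantized model
# with an ordinary PyTorch training loop -- no extra preparation is needed.
train_one_epoch(quantized_model, train_loader)  # hypothetical training helper
```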

The example includes the following steps:

- Loading the [MVTec (capsule category)](https://www.mvtec.com/company/research/datasets/mvtec-ad) dataset (~4.9 Gb).
- (Optional) Training the STFPM PyTorch model from scratch.
- Loading the STFPM model pretrained on this dataset.
- Quantizing the model using the NNCF Post-Training Quantization algorithm.
- Fine-tuning the quantized model for one epoch to improve its metrics.
- Output of the following characteristics of the quantized model:
  - Accuracy drop of the quantized model (INT8) compared to the pre-trained model (FP32)
  - Compression rate of the quantized model file size relative to the pre-trained model file size
  - Performance speedup of the quantized model (INT8)

## Install requirements

At this point, it is assumed that you have already installed NNCF. Information on how to install NNCF can be found [here](https://github.com/openvinotoolkit/nncf#user-content-installation).

To work with the example, install the corresponding Python package dependencies:

```bash
pip install -r requirements.txt
```

## Run Example

The example is self-contained and requires no additional preparation: it downloads the dataset and the model checkpoint on its own.

```bash
python main.py
```
161 changes: 161 additions & 0 deletions examples/quantization_aware_training/torch/anomalib/main.py
@@ -0,0 +1,161 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import re
import subprocess
from copy import deepcopy
from pathlib import Path
from typing import List

import torch
from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.deploy import ExportType
from anomalib.engine import Engine
from anomalib.models import Stfpm

import nncf

HOME_PATH = Path.home()
DATASET_PATH = HOME_PATH / ".cache/nncf/datasets/mvtec"
CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/stfpm_mvtec"
ROOT = Path(__file__).parent.resolve()
FP32_RESULTS_ROOT = ROOT / "fp32"
INT8_RESULTS_ROOT = ROOT / "int8"
CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/qat/model.ckpt"
USE_PRETRAINED = True
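# Set USE_PRETRAINED = False to train the FP32 model from scratch instead of downloading the checkpoint.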


def run_benchmark(model_path: Path, shape: List[int]) -> float:
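    # Invoke OpenVINO's benchmark_app CLI on the exported IR model and parse the reported throughput (FPS).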
    command = f"benchmark_app -m {model_path} -d CPU -api async -t 15"
    command += f' -shape "[{",".join(str(x) for x in shape)}]"'
    cmd_output = subprocess.check_output(command, shell=True)  # nosec
    print(*str(cmd_output).split("\\n")[-9:-1], sep="\n")
    match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
    return float(match.group(1))


def get_model_size(ir_path: Path, m_type: str = "Mb") -> float:
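    # An IR model is a pair of files: .xml (graph) and .bin (weights); sum their sizes in the requested unit.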
    xml_size = ir_path.stat().st_size
    bin_size = ir_path.with_suffix(".bin").stat().st_size
    for t in ["bytes", "Kb", "Mb"]:
        if m_type == t:
            break
        xml_size /= 1024
        bin_size /= 1024
    model_size = xml_size + bin_size
    print(f"Model graph (xml): {xml_size:.3f} {m_type}")
    print(f"Model weights (bin): {bin_size:.3f} {m_type}")
    print(f"Model size: {model_size:.3f} {m_type}")
    return model_size


def main():
    ###############################################################################
    # Step 1: Prepare the model and dataset
    print(os.linesep + "[Step 1] Prepare the model and dataset")

    model = Stfpm()
    datamodule = MVTec(root=DATASET_PATH)

    # Create an engine for the original model
    engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=FP32_RESULTS_ROOT)
    if USE_PRETRAINED:
        # Load the pretrained checkpoint
        CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True)
        ckpt_path = CHECKPOINT_PATH / "model.ckpt"
        torch.hub.download_url_to_file(CHECKPOINT_URL, ckpt_path)
    else:
        # (Optional) Train the model from scratch
        engine.fit(model=model, datamodule=datamodule)
        ckpt_path = engine.trainer.checkpoint_callback.best_model_path

    print("Test results for original FP32 model:")
    fp32_test_results = engine.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)

    ###############################################################################
    # Step 2: Quantize the model
    print(os.linesep + "[Step 2] Quantize the model")

    # Create calibration dataset
    def transform_fn(data_item):
        return data_item["image"]

    test_loader = datamodule.test_dataloader()
    calibration_dataset = nncf.Dataset(test_loader, transform_fn)

    # Quantize the inference model using Post-Training Quantization
    inference_model = model.model
    quantized_inference_model = nncf.quantize(model=inference_model, calibration_dataset=calibration_dataset)

    # Deepcopy the original model and set the quantized inference model
    quantized_model = deepcopy(model)
    quantized_model.model = quantized_inference_model

    # Create engine for the quantized model
    engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=INT8_RESULTS_ROOT, max_epochs=1)

    # Validate the quantized model
    print("Test results for INT8 model after PTQ:")
    int8_init_test_results = engine.test(model=quantized_model, datamodule=datamodule)

    ###############################################################################
    # Step 3: Fine tune the quantized model
    print(os.linesep + "[Step 3] Fine tune the quantized model")

    engine.fit(model=quantized_model, datamodule=datamodule)
    print("Test results for INT8 model after QAT:")
    int8_test_results = engine.test(model=quantized_model, datamodule=datamodule)

    ###############################################################################
    # Step 4: Export models
    print(os.linesep + "[Step 4] Export models")

    # Export FP32 model to OpenVINO™ IR
    fp32_ir_path = engine.export(model=model, export_type=ExportType.OPENVINO, export_root=FP32_RESULTS_ROOT)
    print(f"Original model path: {fp32_ir_path}")
    fp32_size = get_model_size(fp32_ir_path)

    # Export INT8 model to OpenVINO™ IR
    int8_ir_path = engine.export(model=quantized_model, export_type=ExportType.OPENVINO, export_root=INT8_RESULTS_ROOT)
    print(f"Quantized model path: {int8_ir_path}")
    int8_size = get_model_size(int8_ir_path)

    ###############################################################################
    # Step 5: Run benchmarks
    print(os.linesep + "[Step 5] Run benchmarks")

    print("Run benchmark for FP32 model (IR)...")
    fp32_fps = run_benchmark(fp32_ir_path, shape=[1, 3, 256, 256])

    print("Run benchmark for INT8 model (IR)...")
    int8_fps = run_benchmark(int8_ir_path, shape=[1, 3, 256, 256])

    ###############################################################################
    # Step 6: Summary
    print(os.linesep + "[Step 6] Summary")

    fp32_f1score = fp32_test_results[0]["image_F1Score"]
    int8_init_f1score = int8_init_test_results[0]["image_F1Score"]
    int8_f1score = int8_test_results[0]["image_F1Score"]

    print(f"Accuracy drop after PTQ: {fp32_f1score - int8_init_f1score:.3f}")
    print(f"Accuracy drop after QAT: {fp32_f1score - int8_f1score:.3f}")
    print(f"Model compression rate: {fp32_size / int8_size:.3f}")
    # https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html
    print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}")

    return fp32_f1score, int8_init_f1score, int8_f1score, fp32_fps, int8_fps, fp32_size, int8_size


if __name__ == "__main__":
    main()
2 changes: 2 additions & 0 deletions examples/quantization_aware_training/torch/anomalib/requirements.txt
@@ -0,0 +1,2 @@
anomalib
torch
22 changes: 22 additions & 0 deletions tests/cross_fw/examples/example_scope.json
@@ -207,5 +207,27 @@
            "ratio": 1.0,
            "group_size": 64
        }
    },
    "quantization_aware_training_torch_anomalib": {
        "backend": "torch",
        "requirements": "examples/quantization_aware_training/torch/anomalib/requirements.txt",
        "cpu": "Intel(R) Core(TM) i9-10980XE CPU @ 3.00GHz",
        "accuracy_tolerance": 0.02,
        "accuracy_metrics": {
            "fp32_f1score": 0.9919999837875366,
            "int8_init_f1score": 0.9767441749572754,
            "int8_f1score": 0.9919999837875366,
            "accuracy_drop": 0.0
        },
        "performance_metrics": {
            "fp32_fps": 316.28,
            "int8_fps": 879.84,
            "performance_speed_up": 2.7818388769444797
        },
        "model_size_metrics": {
            "fp32_model_size": 21.37990665435791,
            "int8_model_size": 5.677968978881836,
            "model_compression_rate": 3.7654144877995197
        }
    }
}
19 changes: 19 additions & 0 deletions tests/cross_fw/examples/run_example.py
@@ -188,6 +188,25 @@ def quantization_aware_training_torch_resnet18():
    }


def quantization_aware_training_torch_anomalib():
    from examples.quantization_aware_training.torch.anomalib.main import main as anomalib_main

    results = anomalib_main()

    return {
        "fp32_f1score": float(results[0]),
        "int8_init_f1score": float(results[1]),
        "int8_f1score": float(results[2]),
        "accuracy_drop": float(results[0] - results[2]),
        "fp32_fps": results[3],
        "int8_fps": results[4],
        "performance_speed_up": results[4] / results[3],
        "fp32_model_size": results[5],
        "int8_model_size": results[6],
        "model_compression_rate": results[5] / results[6],
    }


def main(argv):
    parser = ArgumentParser()
    parser.add_argument("--name", help="Example name", required=True)