-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
37 changes: 37 additions & 0 deletions
37
examples/quantization_aware_training/torch/anomalib/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
# Quantization-Aware Training of STFPM PyTorch model from Anomalib | ||
|
||
The anomaly detection domain is one of the domains in which models are used in scenarios where the cost of model error is high and accuracy cannot be sacrificed for better model performance. Quantization-Aware Training (QAT) is perfect for such cases, as it reduces quantization error without model performance degradation by training the model. | ||
|
||
This example demonstrates how to quantize [Student-Teacher Feature Pyramid Matching (STFPM)](https://anomalib.readthedocs.io/en/latest/markdown/guides/reference/models/image/stfpm.html) PyTorch model from [Anomalib](https://github.com/openvinotoolkit/anomalib) using Quantization API from Neural Network Compression Framework (NNCF). At the first step, the model is quanitzed using Post-Training Quantization (PTQ) algorithm to obtain the best initialization of the quantized model. If the accuracy of the quantized model after PTQ does not meet requiremenets, the next step is to train the quantized model using PyTorch framework. | ||
|
||
NNCF provides semiless transition from Post-Training Quantization to Quantization-Aware Training without additional model preparation and transfer of magic parameters. | ||
|
||
The example includes the following steps: | ||
|
||
- Loading the [MVTec (capsule category)](https://www.mvtec.com/company/research/datasets/mvtec-ad) dataset (~4.9 Gb). | ||
- (Optional) Training STFPM PyTorch model from scratch. | ||
- Loading STFPM model pretrained on this dataset. | ||
- Quantizing the model using NNCF Post-Training Quantization algorithm. | ||
- Fine tuning quantized model for one epoch to improve quantized model metrics. | ||
- Output of the following characteristics of the quantized model: | ||
- Accuracy drop of the quantized model (INT8) over the pre-trained model (FP32) | ||
- Compression rate of the quantized model file size relative to the pre-trained model file size | ||
- Performance speed up of the quantized model (INT8) | ||
|
||
## Install requirements | ||
|
||
At this point it is assumed that you have already installed NNCF. You can find information on installation NNCF [here](https://github.com/openvinotoolkit/nncf#user-content-installation). | ||
|
||
To work with the example you should install the corresponding Python package dependencies: | ||
|
||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
|
||
## Run Example | ||
|
||
It's pretty simple. The example does not require additional preparation. It will do the preparation itself, such as loading the dataset and model, etc. | ||
|
||
```bash | ||
python main.py | ||
``` |
161 changes: 161 additions & 0 deletions
161
examples/quantization_aware_training/torch/anomalib/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
# Copyright (c) 2024 Intel Corporation | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import os | ||
import re | ||
import subprocess | ||
from copy import deepcopy | ||
from pathlib import Path | ||
from typing import List | ||
|
||
import torch | ||
from anomalib import TaskType | ||
from anomalib.data import MVTec | ||
from anomalib.deploy import ExportType | ||
from anomalib.engine import Engine | ||
from anomalib.models import Stfpm | ||
|
||
import nncf | ||
|
||
HOME_PATH = Path.home() | ||
DATASET_PATH = HOME_PATH / ".cache/nncf/datasets/mvtec" | ||
CHECKPOINT_PATH = HOME_PATH / ".cache/nncf/models/stfpm_mvtec" | ||
ROOT = Path(__file__).parent.resolve() | ||
FP32_RESULTS_ROOT = ROOT / "fp32" | ||
INT8_RESULTS_ROOT = ROOT / "int8" | ||
CHECKPOINT_URL = "https://huggingface.co/alexsu52/stfpm_mvtec_capsule/resolve/main/qat/model.ckpt" | ||
USE_PRETRAINED = True | ||
|
||
|
||
def run_benchmark(model_path: Path, shape: List[int]) -> float: | ||
command = f"benchmark_app -m {model_path} -d CPU -api async -t 15" | ||
command += f' -shape "[{",".join(str(x) for x in shape)}]"' | ||
cmd_output = subprocess.check_output(command, shell=True) # nosec | ||
print(*str(cmd_output).split("\\n")[-9:-1], sep="\n") | ||
match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output)) | ||
return float(match.group(1)) | ||
|
||
|
||
def get_model_size(ir_path: Path, m_type: str = "Mb") -> float: | ||
xml_size = ir_path.stat().st_size | ||
bin_size = ir_path.with_suffix(".bin").stat().st_size | ||
for t in ["bytes", "Kb", "Mb"]: | ||
if m_type == t: | ||
break | ||
xml_size /= 1024 | ||
bin_size /= 1024 | ||
model_size = xml_size + bin_size | ||
print(f"Model graph (xml): {xml_size:.3f} Mb") | ||
print(f"Model weights (bin): {bin_size:.3f} Mb") | ||
print(f"Model size: {model_size:.3f} Mb") | ||
return model_size | ||
|
||
|
||
def main(): | ||
############################################################################### | ||
# Step 1: Prepare the model and dataset | ||
print(os.linesep + "[Step 1] Prepare the model and dataset") | ||
|
||
model = Stfpm() | ||
datamodule = MVTec(root=DATASET_PATH) | ||
|
||
# Create an engine for the original model | ||
engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=FP32_RESULTS_ROOT) | ||
if USE_PRETRAINED: | ||
# Load the pretrained checkpoint | ||
CHECKPOINT_PATH.mkdir(parents=True, exist_ok=True) | ||
ckpt_path = CHECKPOINT_PATH / "model.ckpt" | ||
torch.hub.download_url_to_file(CHECKPOINT_URL, ckpt_path) | ||
else: | ||
# (Optional) Train the model from scratch | ||
engine.fit(model=model, datamodule=datamodule) | ||
ckpt_path = engine.trainer.checkpoint_callback.best_model_path | ||
|
||
print("Test results for original FP32 model:") | ||
fp32_test_results = engine.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path) | ||
|
||
############################################################################### | ||
# Step 2: Quantize the model | ||
print(os.linesep + "[Step 2] Quantize the model") | ||
|
||
# Create calibration dataset | ||
def transform_fn(data_item): | ||
return data_item["image"] | ||
|
||
test_loader = datamodule.test_dataloader() | ||
calibration_dataset = nncf.Dataset(test_loader, transform_fn) | ||
|
||
# Quantize the inference model using Post-Training Quantization | ||
inference_model = model.model | ||
quantized_inference_model = nncf.quantize(model=inference_model, calibration_dataset=calibration_dataset) | ||
|
||
# Deepcopy the original model and set the quantized inference model | ||
quantized_model = deepcopy(model) | ||
quantized_model.model = quantized_inference_model | ||
|
||
# Create engine for the quantized model | ||
engine = Engine(task=TaskType.SEGMENTATION, default_root_dir=INT8_RESULTS_ROOT, max_epochs=1) | ||
|
||
# Validate the quantized model | ||
print("Test results for INT8 model after PTQ:") | ||
int8_init_test_results = engine.test(model=quantized_model, datamodule=datamodule) | ||
|
||
############################################################################### | ||
# Step 3: Fine tune the quantized model | ||
print(os.linesep + "[Step 3] Fine tune the quantized model") | ||
|
||
engine.fit(model=quantized_model, datamodule=datamodule) | ||
print("Test results for INT8 model after QAT:") | ||
int8_test_results = engine.test(model=quantized_model, datamodule=datamodule) | ||
|
||
############################################################################### | ||
# Step 4: Export models | ||
print(os.linesep + "[Step 4] Export models") | ||
|
||
# Export FP32 model to OpenVINO™ IR | ||
fp32_ir_path = engine.export(model=model, export_type=ExportType.OPENVINO, export_root=FP32_RESULTS_ROOT) | ||
print(f"Original model path: {fp32_ir_path}") | ||
fp32_size = get_model_size(fp32_ir_path) | ||
|
||
# Export INT8 model to OpenVINO™ IR | ||
int8_ir_path = engine.export(model=quantized_model, export_type=ExportType.OPENVINO, export_root=INT8_RESULTS_ROOT) | ||
print(f"Quantized model path: {int8_ir_path}") | ||
int8_size = get_model_size(int8_ir_path) | ||
|
||
############################################################################### | ||
# Step 5: Run benchmarks | ||
print(os.linesep + "[Step 5] Run benchmarks") | ||
|
||
print("Run benchmark for FP32 model (IR)...") | ||
fp32_fps = run_benchmark(fp32_ir_path, shape=[1, 3, 256, 256]) | ||
|
||
print("Run benchmark for INT8 model (IR)...") | ||
int8_fps = run_benchmark(int8_ir_path, shape=[1, 3, 256, 256]) | ||
|
||
############################################################################### | ||
# Step 6: Summary | ||
print(os.linesep + "[Step 6] Summary") | ||
|
||
fp32_f1score = fp32_test_results[0]["image_F1Score"] | ||
int8_init_f1score = int8_init_test_results[0]["image_F1Score"] | ||
int8_f1score = int8_test_results[0]["image_F1Score"] | ||
|
||
print(f"Accuracy drop after PTQ: {fp32_f1score - int8_init_f1score:.3f}") | ||
print(f"Accuracy drop after QAT: {fp32_f1score - int8_f1score:.3f}") | ||
print(f"Model compression rate: {fp32_size / int8_size:.3f}") | ||
# https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html | ||
print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}") | ||
|
||
return fp32_f1score, int8_init_f1score, int8_f1score, fp32_fps, int8_fps, fp32_size, int8_size | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
2 changes: 2 additions & 0 deletions
2
examples/quantization_aware_training/torch/anomalib/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
anomalib | ||
torch |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters