-
Notifications
You must be signed in to change notification settings - Fork 240
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
107 additions
and
0 deletions.
There are no files selected for viewing
107 changes: 107 additions & 0 deletions
107
examples/quantization_aware_training/torch/anomaly/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import os | ||
import re | ||
import subprocess | ||
from pathlib import Path | ||
from typing import List | ||
|
||
from anomalib import TaskType | ||
from anomalib.data import MVTec | ||
from anomalib.deploy import ExportType | ||
from anomalib.engine import Engine | ||
from anomalib.models import Stfpm | ||
|
||
import nncf | ||
|
||
# Directory of this script; exported IR models are written under it.
ROOT = Path(__file__).parent.resolve()
# Cache location where the MVTec dataset is stored/downloaded.
DATASET_PATH = "~/.cache/nncf/datasets/mvtec"
|
||
|
||
def run_benchmark(model_path: str, shape: List[int]) -> float:
    """Benchmark an OpenVINO IR model with ``benchmark_app`` and return its throughput.

    :param model_path: Path to the model ``.xml`` file.
    :param shape: Input shape to benchmark with, e.g. ``[1, 3, 256, 256]``.
    :return: Measured throughput in FPS.
    :raises RuntimeError: If the throughput line cannot be found in the tool output.
    """
    command = f"benchmark_app -m {model_path} -d CPU -api async -t 15"
    command += f' -shape "[{",".join(str(x) for x in shape)}]"'
    cmd_output = subprocess.check_output(command, shell=True)  # nosec
    # Decode the raw bytes instead of parsing their repr, so "\n" splits real lines.
    output_text = cmd_output.decode(errors="replace")
    # Print the summary tail of benchmark_app's report.
    print(*output_text.splitlines()[-9:-1], sep="\n")
    match = re.search(r"Throughput\: (.+?) FPS", output_text)
    if match is None:
        raise RuntimeError("Could not parse throughput from benchmark_app output")
    return float(match.group(1))
|
||
|
||
def get_model_size(ir_path: str, m_type: str = "Mb") -> float:
    """Return the total size of an OpenVINO IR model (``.xml`` plus companion ``.bin``).

    :param ir_path: Path to the model ``.xml`` file; the ``.bin`` with the same stem
        is assumed to sit next to it.
    :param m_type: Unit for the result: ``"bytes"``, ``"Kb"`` or ``"Mb"``.
    :return: Combined file size expressed in the requested unit.
    :raises ValueError: If ``m_type`` is not a supported unit (the original code
        silently fell through to Mb scaling in that case).
    """
    units = {"bytes": 1, "Kb": 1024, "Mb": 1024 ** 2}
    if m_type not in units:
        raise ValueError(f"Unsupported unit {m_type!r}; expected one of {sorted(units)}")
    xml_size = os.path.getsize(ir_path)
    bin_size = os.path.getsize(os.path.splitext(ir_path)[0] + ".bin")
    return (xml_size + bin_size) / units[m_type]
|
||
|
||
def main():
    """Train an STFPM anomaly model, quantize it (PTQ + optional fine-tuning),
    export FP32/INT8 OpenVINO IRs and compare their accuracy, speed and size.

    :return: Tuple of (fp32_test_results, int8_test_results, fp32_fps, int8_fps,
        fp32_size, int8_size).
    """
    # Create model and datamodule
    model = Stfpm()
    datamodule = MVTec(root=DATASET_PATH)

    # Create engine and fit model
    engine = Engine(task=TaskType.SEGMENTATION, max_epochs=20)
    engine.fit(model=model, datamodule=datamodule)

    # Load best model from checkpoint before evaluating
    print("Test results for original FP32 model:")
    fp32_test_results = engine.test(
        model=model,
        datamodule=datamodule,
        ckpt_path=engine.trainer.checkpoint_callback.best_model_path,
    )

    # Calibration samples are the raw image tensors from each test batch.
    def transform_fn(data_item):
        return data_item["image"]

    test_loader = datamodule.test_dataloader()
    calibration_dataset = nncf.Dataset(test_loader, transform_fn)

    # Quantize the model using Post-Training Quantization
    quantized_model = nncf.quantize(model=model, calibration_dataset=calibration_dataset)

    # Fresh engine so the quantized model is validated/tuned from a clean trainer state.
    engine = Engine(task=TaskType.SEGMENTATION, max_epochs=10)

    # Validate the quantized model
    print("Test results for INT8 model after PTQ:")
    engine.test(model=quantized_model, datamodule=datamodule)

    # (optional step) Fine tune the quantized model
    engine.fit(model=quantized_model, datamodule=datamodule)
    int8_test_results = engine.test(model=quantized_model, datamodule=datamodule)
    print(int8_test_results)

    # Export FP32 model to OpenVINO™ IR
    fp32_ir_path = engine.export(model=model, export_type=ExportType.OPENVINO, export_root=ROOT / "stfpm_fp32")
    print(f"Original model path: {fp32_ir_path}")

    # Export INT8 model to OpenVINO™ IR
    int8_ir_path = engine.export(
        model=quantized_model, export_type=ExportType.OPENVINO, export_root=ROOT / "stfpm_int8"
    )
    print(f"Quantized model path: {int8_ir_path}")

    # Run benchmarks
    print("Run benchmark for FP32 model (IR)...")
    fp32_fps = run_benchmark(fp32_ir_path, shape=[1, 3, 256, 256])

    print("Run benchmark for INT8 model (IR)...")
    int8_fps = run_benchmark(int8_ir_path, shape=[1, 3, 256, 256])

    fp32_size = get_model_size(fp32_ir_path)
    int8_size = get_model_size(int8_ir_path)

    # Throughput-mode rationale:
    # https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html
    print(f"Model compression rate: {fp32_size / int8_size:.3f}")
    print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}")

    return fp32_test_results, int8_test_results, fp32_fps, int8_fps, fp32_size, int8_size
|
||
|
||
# Run the full example pipeline when executed as a script.
if __name__ == "__main__":
    main()