import os
import re
import subprocess
from pathlib import Path
from typing import List

from anomalib import TaskType
from anomalib.data import MVTec
from anomalib.deploy import ExportType
from anomalib.engine import Engine
from anomalib.models import Stfpm

import nncf

ROOT = Path(__file__).parent.resolve()
# Expand "~" eagerly so the path is valid no matter how the datamodule treats it.
DATASET_PATH = Path("~/.cache/nncf/datasets/mvtec").expanduser()


def run_benchmark(model_path: str, shape: List[int]) -> float:
    """Run OpenVINO ``benchmark_app`` on an IR model and return its throughput.

    :param model_path: Path to the OpenVINO IR (.xml) model.
    :param shape: Input shape to benchmark with, e.g. ``[1, 3, 256, 256]``.
    :return: Measured throughput in FPS.
    :raises RuntimeError: If the throughput line cannot be found in the output.
    """
    # Argument list + shell=False avoids shell-injection and quoting issues.
    command = [
        "benchmark_app",
        "-m", str(model_path),
        "-d", "CPU",
        "-api", "async",
        "-t", "15",
        "-shape", "[{}]".format(",".join(str(x) for x in shape)),
    ]
    cmd_output = subprocess.check_output(command, text=True)  # nosec
    # Show the summary block (last lines) of the benchmark report.
    print(*cmd_output.splitlines()[-8:], sep="\n")
    match = re.search(r"Throughput\: (.+?) FPS", cmd_output)
    if match is None:
        raise RuntimeError("Could not parse throughput from benchmark_app output")
    return float(match.group(1))


def get_model_size(ir_path: str, m_type: str = "Mb") -> float:
    """Return the total size of an OpenVINO IR model (.xml + .bin pair).

    :param ir_path: Path to the IR .xml file; the matching .bin is looked up
        next to it.
    :param m_type: Unit for the result: "bytes", "Kb" or "Mb" (default).
    :return: Combined size of the .xml and .bin files in the requested unit.
    """
    xml_size = os.path.getsize(ir_path)
    bin_size = os.path.getsize(os.path.splitext(ir_path)[0] + ".bin")
    # Divide by 1024 once per unit step until the requested unit is reached.
    for t in ["bytes", "Kb", "Mb"]:
        if m_type == t:
            break
        xml_size /= 1024
        bin_size /= 1024
    return xml_size + bin_size


def main():
    """Train STFPM on MVTec, quantize it with NNCF, fine-tune, export and benchmark.

    Returns the FP32/INT8 test results, throughputs and model sizes so the
    example can be checked programmatically.
    """
    # Create model and datamodule
    model = Stfpm()
    datamodule = MVTec(root=DATASET_PATH)

    # Create engine and fit model
    engine = Engine(task=TaskType.SEGMENTATION, max_epochs=20)
    engine.fit(model=model, datamodule=datamodule)

    # Load best model from checkpoint before evaluating
    print("Test results for original FP32 model:")
    fp32_test_results = engine.test(
        model=model,
        datamodule=datamodule,
        ckpt_path=engine.trainer.checkpoint_callback.best_model_path,
    )

    # Create calibration dataset; NNCF only needs the image tensor.
    def transform_fn(data_item):
        return data_item["image"]

    test_loader = datamodule.test_dataloader()
    calibration_dataset = nncf.Dataset(test_loader, transform_fn)

    # Quantize the model using Post-Training Quantization
    quantized_model = nncf.quantize(model=model, calibration_dataset=calibration_dataset)

    # Create engine for the quantized model
    engine = Engine(task=TaskType.SEGMENTATION, max_epochs=10)

    # Validate the quantized model
    print("Test results for INT8 model after PTQ:")
    engine.test(model=quantized_model, datamodule=datamodule)

    # (optional step) Fine tune the quantized model (QAT)
    engine.fit(model=quantized_model, datamodule=datamodule)
    int8_test_results = engine.test(model=quantized_model, datamodule=datamodule)
    print(int8_test_results)

    # Export FP32 model to OpenVINO(TM) IR
    fp32_ir_path = engine.export(model=model, export_type=ExportType.OPENVINO, export_root=ROOT / "stfpm_fp32")
    print(f"Original model path: {fp32_ir_path}")

    # Export INT8 model to OpenVINO(TM) IR
    int8_ir_path = engine.export(
        model=quantized_model, export_type=ExportType.OPENVINO, export_root=ROOT / "stfpm_int8"
    )
    print(f"Quantized model path: {int8_ir_path}")

    # Run benchmarks
    print("Run benchmark for FP32 model (IR)...")
    fp32_fps = run_benchmark(fp32_ir_path, shape=[1, 3, 256, 256])

    print("Run benchmark for INT8 model (IR)...")
    int8_fps = run_benchmark(int8_ir_path, shape=[1, 3, 256, 256])

    fp32_size = get_model_size(fp32_ir_path)
    int8_size = get_model_size(int8_ir_path)

    print(f"Model compression rate: {fp32_size / int8_size:.3f}")
    # https://docs.openvino.ai/latest/openvino_docs_optimization_guide_dldt_optimization_guide.html
    print(f"Performance speed up (throughput mode): {int8_fps / fp32_fps:.3f}")

    return fp32_test_results, int8_test_results, fp32_fps, int8_fps, fp32_size, int8_size


if __name__ == "__main__":
    main()