forked from amd/RyzenAI-SW
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathresnet_quantize.py
117 lines (89 loc) · 4.77 KB
/
resnet_quantize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
from onnxruntime.quantization.calibrate import CalibrationDataReader
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from torchvision.datasets import CIFAR10
import onnx
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader, QuantType, QuantFormat, CalibrationMethod, quantize_static
import vai_q_onnx
class CIFAR10DataSet:
    """Load the CIFAR-10 training split from ``data_dir`` in two views.

    * ``train_dataset`` applies random augmentation (pad by 4, random
      horizontal flip, random 32x32 crop, then ``ToTensor``).
    * ``val_dataset`` applies only ``ToTensor``. It is the source of the
      calibration subset, so it must be deterministic.

    Both views read the *training* split (``train=True``) from disk; the data
    must already be present because ``download=False``.

    Args:
        data_dir: Root directory passed to ``torchvision.datasets.CIFAR10``.
        **kwargs: Accepted for backward compatibility and ignored.
    """

    def __init__(self, data_dir, **kwargs):
        super().__init__()
        self.train_path = data_dir
        self.vld_path = data_dir
        self.setup("fit")

    def setup(self, stage: str):
        """Create ``self.train_dataset`` and ``self.val_dataset``.

        Args:
            stage: Accepted for API compatibility; not used by this method.
        """
        train_transform = transforms.Compose(
            [
                transforms.Pad(4),
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32),
                transforms.ToTensor(),
            ]
        )
        # Calibration data must not be randomly augmented. Random flips and
        # zero-padded random crops would change the activation statistics
        # collected by the quantizer from run to run.
        # NOTE(review): assumes inference feeds plain ToTensor() images
        # (no normalization) — confirm against the evaluation script.
        eval_transform = transforms.Compose([transforms.ToTensor()])
        self.train_dataset = CIFAR10(
            root=self.train_path, train=True, transform=train_transform, download=False
        )
        # Drawn from the training split on purpose: create_dataloader carves a
        # calibration subset out of these images.
        self.val_dataset = CIFAR10(
            root=self.vld_path, train=True, transform=eval_transform, download=False
        )
class PytorchResNetDataset(Dataset):
    """Adapter that exposes an indexable dataset as ``(input, label)`` pairs.

    Each item of the wrapped ``dataset`` is indexed at positions 0 and 1.
    Those two fields are returned as a tuple; any further fields are dropped.
    """

    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        item = self.dataset[index]
        return item[0], item[1]
def create_dataloader(data_dir, batch_size, num_samples=1000, seed=None):
    """Build a DataLoader over a random subset of ``CIFAR10DataSet.val_dataset``.

    Args:
        data_dir: Root directory of the CIFAR-10 data (see ``CIFAR10DataSet``).
        batch_size: Images per batch. A trailing incomplete batch is dropped
            (``drop_last=True``).
        num_samples: Number of images to draw. Defaults to 1000, which on the
            50 000-image split reproduces the former hard-coded 49000/1000 split.
        seed: Optional integer seed for the subset selection. ``None`` (default)
            uses PyTorch's default RNG, as before. Pass an int to get a
            reproducible subset.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(images, labels)`` batches.

    Raises:
        ValueError: If ``num_samples`` is not in ``[1, len(dataset)]``.
    """
    cifar10_dataset = CIFAR10DataSet(data_dir)
    full_set = cifar10_dataset.val_dataset
    total = len(full_set)
    if not 0 < num_samples <= total:
        raise ValueError(f"num_samples must be in [1, {total}], got {num_samples}")
    split_kwargs = {}
    if seed is not None:
        split_kwargs["generator"] = torch.Generator().manual_seed(seed)
    # random_split requires the lengths to sum to len(full_set). Deriving the
    # remainder avoids assuming the dataset holds exactly 50 000 items.
    _, val_set = torch.utils.data.random_split(
        full_set, [total - num_samples, num_samples], **split_kwargs
    )
    return DataLoader(PytorchResNetDataset(val_set), batch_size=batch_size, drop_last=True)
class ResnetCalibrationDataReader(CalibrationDataReader):
    """Supply CIFAR-10 image batches to the ONNX static quantizer.

    Each call to :meth:`get_next` returns one batch as ``{input_name: ndarray}``.
    ``None`` is returned once the underlying DataLoader is exhausted.

    Args:
        data_dir: Root directory of the CIFAR-10 data.
        batch_size: Images per calibration batch (default 16).
        input_name: Feed-dict key for the image batch. Defaults to ``"input"``;
            it must match the ONNX model's input name.
    """

    def __init__(self, data_dir: str, batch_size: int = 16, input_name: str = "input"):
        super().__init__()
        self.input_name = input_name
        self.iterator = iter(create_dataloader(data_dir, batch_size))

    def get_next(self) -> dict:
        """Return the next batch as a feed dict, or ``None`` when exhausted."""
        try:
            images, _ = next(self.iterator)
        except StopIteration:
            # Only exhaustion ends calibration. Real loading errors propagate
            # instead of silently truncating the calibration set.
            return None
        return {self.input_name: images.numpy()}
def resnet_calibration_reader(data_dir, batch_size=16):
    """Create the calibration data reader for the CIFAR-10 data in ``data_dir``.

    Args:
        data_dir: Root directory of the CIFAR-10 calibration data.
        batch_size: Number of images per calibration batch (default 16).

    Returns:
        A ``ResnetCalibrationDataReader`` reading from ``data_dir``.
    """
    reader = ResnetCalibrationDataReader(data_dir=data_dir, batch_size=batch_size)
    return reader
def main(
    input_model_path="models/resnet_trained_for_cifar10.onnx",
    output_model_path="models/resnet.qdq.U8S8.onnx",
    calibration_dataset_path="data/",
):
    """Statically quantize the CIFAR-10 ResNet ONNX model with vai_q_onnx.

    Calling ``main()`` with no arguments uses the original hard-coded paths.

    Args:
        input_model_path: Path to the original, unquantized ONNX model.
        output_model_path: Path where the quantized model will be saved.
        calibration_dataset_path: Root directory of the CIFAR-10 data used for
            calibration during quantization.
    """
    # `dr` (data reader) is a ResnetCalibrationDataReader. It reads the
    # calibration dataset and yields input batches for the quantizer.
    dr = resnet_calibration_reader(calibration_dataset_path)
    # Arguments passed to vai_q_onnx.quantize_static:
    # - input_model_path: path to the original, unquantized model.
    # - output_model_path: path where the quantized model will be saved.
    # - dr: data reader that provides calibration data.
    # - quant_format: format of the quantization operators (QDQ or QOperator);
    #   QDQ here.
    # - calibrate_method: calibration algorithm; vai_q_onnx's
    #   PowerOfTwoMethod.MinMSE here (see the vai_q_onnx docs for details).
    # - activation_type: data type of activation tensors after quantization;
    #   QUInt8 (unsigned 8-bit) here.
    # - weight_type: data type of weight tensors after quantization;
    #   QInt8 (signed 8-bit) here.
    # - enable_dpu: if True, produce a quantized model suitable for the DPU.
    # - extra_options: additional quantizer options; ActivationSymmetric=True
    #   symmetrizes the calibration range of activations.
    vai_q_onnx.quantize_static(
        input_model_path,
        output_model_path,
        dr,
        quant_format=vai_q_onnx.QuantFormat.QDQ,
        calibrate_method=vai_q_onnx.PowerOfTwoMethod.MinMSE,
        activation_type=vai_q_onnx.QuantType.QUInt8,
        weight_type=vai_q_onnx.QuantType.QInt8,
        enable_dpu=True,
        extra_options={'ActivationSymmetric': True},
    )
    print('Calibrated and quantized model saved at:', output_model_path)


if __name__ == '__main__':
    main()
#################################################################################
# License
# Ryzen AI is licensed under `MIT License <https://github.com/amd/ryzen-ai-documentation/blob/main/License>`_ . Refer to the `LICENSE File <https://github.com/amd/ryzen-ai-documentation/blob/main/License>`_ for the full license text and copyright notice.