-
Notifications
You must be signed in to change notification settings - Fork 45
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
4cef352
commit 0f49636
Showing
20 changed files
with
318 additions
and
142 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
# Example checklist | ||
|
||
- Elaboration | ||
- [ ] Example concept is described and agreed upon | ||
- [ ] Example concept is described and agreed upon | ||
- Implementation | ||
- [ ] Example is implemented | ||
- [ ] Example is implemented | ||
- Internal review | ||
- [ ] Internal code review is done | ||
- [ ] Internal code review is done | ||
- External review | ||
- [ ] Upstreaming PR is opened, external review is done | ||
- [ ] Upstreaming PR is opened, external review is done | ||
- Done | ||
- [ ] Example merged to upstream | ||
- [ ] Example merged to upstream |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,18 @@ | ||
## Notes for the reviewer | ||
|
||
_The reviewer should acknowledge all these topics._ | ||
<insert notes> | ||
|
||
## Checklist before merge | ||
|
||
- [ ] CMake support is added | ||
- [ ] Dependencies are copied via `IMPORTED_RUNTIME_ARTIFACTS` if applicable | ||
- [ ] Dependencies are copied via `IMPORTED_RUNTIME_ARTIFACTS` if applicable | ||
- [ ] GNU Make support is added (Linux) | ||
- [ ] Visual Studio project is added for VS2017, 2019, 2022 (Windows) (use [the script](https://projects.streamhpc.com/departments/knowledge/employee-handbook/-/wikis/Projects/AMD/Libraries/examples/Adding-Visual-Studio-Projects-to-new-examples#scripts)) | ||
- [ ] DLL dependencies are copied via `<Content Include` | ||
- [ ] Visual Studio project is added to `ROCm-Examples-vs*.sln` (ROCm) | ||
- [ ] Visual Studio project is added to `ROCm-Examples-Portable-vs*.sln` (ROCm/CUDA) if applicable | ||
- [ ] DLL dependencies are copied via `<Content Include` | ||
- [ ] Visual Studio project is added to `ROCm-Examples-vs*.sln` (ROCm) | ||
- [ ] Visual Studio project is added to `ROCm-Examples-Portable-vs*.sln` (ROCm/CUDA) if applicable | ||
- [ ] Inline code documentation is added | ||
- [ ] README is added according to template | ||
- [ ] Related READMEs, ToC are updated | ||
- [ ] Related READMEs, ToC are updated | ||
- [ ] The CI passes for Linux/ROCm, Linux/CUDA, Windows/ROCm, Windows/CUDA. |
157 changes: 82 additions & 75 deletions
157
AI/MIGraphX/Quantization/Running-Quantized-ResNet50-via-MIGraphX.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,102 +1,109 @@ | ||
|
||
|
||
# Running quantized ResNet50 via MIGraphX | ||
|
||
## Summary | ||
|
||
This example walks through the dynamo Post Training Quantization (PTQ) workflow for running a quantized model using torch_migraphx. | ||
|
||
## Prerequisites | ||
|
||
- You must follow the installation instructions for the torch_migraphx library in [Readme.md](https://github.com/Rmalavally/rocm-examples/blob/develop/AI/MIGraphX/Quantization/Readme.md) before using this example. | ||
- You must follow the installation instructions for the torch_migraphx library in [README.md](README.md) before using this example. | ||
|
||
## Steps for running a quantized model using torch_migraphx | ||
|
||
1. Use torch.export and quantize_pt2e APIs to perform quantization | ||
*Note*: The export API call is considered a prototype feature at the time this tutorial is written. Some call signatures may be modified in the future. | ||
1. Use torch.export and quantize_pt2e APIs to perform quantization. | ||
|
||
**Note**: The export API call is considered a prototype feature at the time this tutorial is written. Some call signatures may be modified in the future. | ||
|
||
```python | ||
import torch | ||
from torchvision import models | ||
from torch._export import capture_pre_autograd_graph | ||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e | ||
``` | ||
|
||
```import torch | ||
from torchvision import models | ||
from torch._export import capture_pre_autograd_graph | ||
from torch.ao.quantization.quantize_pt2e import prepare_pt2e, convert_pt2e | ||
``` | ||
```python | ||
import torch_migraphx | ||
from torch_migraphx.dynamo.quantization import MGXQuantizer | ||
|
||
```import torch_migraphx | ||
from torch_migraphx.dynamo.quantization import MGXQuantizer | ||
model_fp32 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval() | ||
input_fp32 = torch.randn(2, 3, 28, 28) | ||
model_fp32 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval() | ||
input_fp32 = torch.randn(2, 3, 28, 28) | ||
|
||
torch_fp32_out = model_fp32(input_fp32) | ||
``` | ||
torch_fp32_out = model_fp32(input_fp32) | ||
``` | ||
|
||
```python | ||
model_export = capture_pre_autograd_graph(model_fp32, (input_fp32, )) | ||
``` | ||
|
||
``` | ||
model_export = capture_pre_autograd_graph(model_fp32, (input_fp32, )) | ||
``` | ||
Use the pt2e API to prepare, calibrate, and convert the model. Torch-MIGraphX provides a custom Quantizer for performing quantization that is compatible with MIGraphX. | ||
Use the pt2e API to prepare, calibrate, and convert the model. Torch-MIGraphX provides a custom Quantizer for performing quantization that is compatible with MIGraphX. | ||
|
||
``` | ||
quantizer = MGXQuantizer() | ||
m = prepare_pt2e(model_export, quantizer) | ||
# psudo calibrate | ||
with torch.no_grad(): | ||
for _ in range(10): | ||
```python | ||
quantizer = MGXQuantizer() | ||
m = prepare_pt2e(model_export, quantizer) | ||
|
||
# psudo calibrate | ||
with torch.no_grad(): | ||
for _ in range(10): | ||
m(torch.randn(2, 3, 28, 28)) | ||
q_m = convert_pt2e(m) | ||
torch_qout = q_m(input_fp32) | ||
``` | ||
q_m = convert_pt2e(m) | ||
torch_qout = q_m(input_fp32) | ||
``` | ||
|
||
2. Lower Quantized model to MIGraphX. This step is the same as lowering any other model using torch.compile! | ||
|
||
``` | ||
mgx_mod = torch.compile(q_m, backend='migraphx').cuda() | ||
mgx_out = mgx_mod(input_fp32.cuda()) | ||
print(f"PyTorch FP32 (Gold Value):\n{torch_fp32_out}") | ||
print(f"PyTorch INT8 (Fake Quantized):\n{torch_qout}") | ||
print(f"MIGraphX INT8:\n{mgx_out}") | ||
``` | ||
```python | ||
mgx_mod = torch.compile(q_m, backend='migraphx').cuda() | ||
mgx_out = mgx_mod(input_fp32.cuda()) | ||
|
||
print(f"PyTorch FP32 (Gold Value):\n{torch_fp32_out}") | ||
print(f"PyTorch INT8 (Fake Quantized):\n{torch_qout}") | ||
print(f"MIGraphX INT8:\n{mgx_out}") | ||
``` | ||
|
||
3. Performance | ||
|
||
Do a quick test to measure the performance gain from using quantization. | ||
|
||
``` | ||
import copy | ||
import torch._dynamo | ||
# We will use this function to benchmark all modules: | ||
def benchmark_module(model, inputs, iterations=100): | ||
model(*inputs) | ||
torch.cuda.synchronize() | ||
start_event = torch.cuda.Event(enable_timing=True) | ||
end_event = torch.cuda.Event(enable_timing=True) | ||
start_event.record() | ||
for _ in range(iterations): | ||
|
||
Do a quick test to measure the performance gain from using quantization. | ||
|
||
```python | ||
import copy | ||
import torch._dynamo | ||
|
||
# We will use this function to benchmark all modules: | ||
def benchmark_module(model, inputs, iterations=100): | ||
model(*inputs) | ||
end_event.record() | ||
torch.cuda.synchronize() | ||
return start_event.elapsed_time(end_event) / iterations | ||
# Benchmark MIGraphX INT8 | ||
mgx_int8_time = benchmark_module(mgx_mod, [input_fp32.cuda()]) | ||
torch._dynamo.reset() | ||
# Benchmark MIGraphX FP32 | ||
mgx_module_fp32 = torch.compile(copy.deepcopy(model_fp32), backend='migraphx').cuda() | ||
mgx_module_fp32(input_fp32.cuda()) | ||
mgx_fp32_time = benchmark_module(mgx_module_fp32, [input_fp32.cuda()]) | ||
torch._dynamo.reset() | ||
# Benchmark MIGraphX FP16 | ||
mgx_module_fp16 = torch.compile(copy.deepcopy(model_fp32).half(), backend='migraphx').cuda() | ||
input_fp16 = input_fp32.cuda().half() | ||
mgx_module_fp16(input_fp16) | ||
mgx_fp16_time = benchmark_module(mgx_module_fp16, [input_fp16]) | ||
print(f"{mgx_fp32_time=:0.4f}ms") | ||
print(f"{mgx_fp16_time=:0.4f}ms") | ||
print(f"{mgx_int8_time=:0.4f}ms") | ||
``` | ||
Note that these performance gains (or lack of gains) will vary depending on the specific hardware in use. | ||
torch.cuda.synchronize() | ||
|
||
start_event = torch.cuda.Event(enable_timing=True) | ||
end_event = torch.cuda.Event(enable_timing=True) | ||
|
||
start_event.record() | ||
|
||
for _ in range(iterations): | ||
model(*inputs) | ||
end_event.record() | ||
torch.cuda.synchronize() | ||
|
||
return start_event.elapsed_time(end_event) / iterations | ||
|
||
# Benchmark MIGraphX INT8 | ||
mgx_int8_time = benchmark_module(mgx_mod, [input_fp32.cuda()]) | ||
torch._dynamo.reset() | ||
|
||
# Benchmark MIGraphX FP32 | ||
mgx_module_fp32 = torch.compile(copy.deepcopy(model_fp32), backend='migraphx').cuda() | ||
mgx_module_fp32(input_fp32.cuda()) | ||
mgx_fp32_time = benchmark_module(mgx_module_fp32, [input_fp32.cuda()]) | ||
torch._dynamo.reset() | ||
|
||
# Benchmark MIGraphX FP16 | ||
mgx_module_fp16 = torch.compile(copy.deepcopy(model_fp32).half(), backend='migraphx').cuda() | ||
input_fp16 = input_fp32.cuda().half() | ||
mgx_module_fp16(input_fp16) | ||
mgx_fp16_time = benchmark_module(mgx_module_fp16, [input_fp16]) | ||
|
||
print(f"{mgx_fp32_time=:0.4f}ms") | ||
print(f"{mgx_fp16_time=:0.4f}ms") | ||
print(f"{mgx_int8_time=:0.4f}ms") | ||
``` | ||
|
||
Note that these performance gains (or lack of gains) will vary depending on the specific hardware in use. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.