feat: Implement FP32 accumulation for matmul (#3110)
Co-authored-by: Hoonkyung Cho <[email protected]>
Co-authored-by: Zewen (Evan) Li <[email protected]>
1 parent 451179b · commit 38771e6
Showing 22 changed files with 314 additions and 341 deletions.
@@ -0,0 +1,74 @@
.. _mixed_precision:

Compile Mixed Precision models with Torch-TensorRT
===================================================

.. currentmodule:: torch_tensorrt.dynamo

.. automodule:: torch_tensorrt.dynamo
   :members:
   :undoc-members:
   :show-inheritance:

Consider the following PyTorch model, which explicitly casts an intermediate layer to run in FP16.

.. code-block:: python

    class MyModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(10, 10)
            self.linear2 = torch.nn.Linear(10, 30).half()
            self.linear3 = torch.nn.Linear(30, 40)

        def forward(self, x):
            x = self.linear1(x)
            x = x.to(torch.float16)
            x = self.linear2(x)
            x = x.to(torch.float32)
            x = self.linear3(x)
            return x

If we compile the above model using Torch-TensorRT, the layer profiling logs indicate that all layers
run in FP32. This is because TensorRT picks the kernels that give the best overall performance.

.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(ep,
                                               inputs=inputs,
                                               debug=True)

    # Debug log info
    # Layers:
    # Name: __myl_MulSum_myl0_0, LayerType: kgen, Inputs: [ { Name: __mye116_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }], TacticName: __myl_MulSum_0xfa6c1858aea1b13b03f90165d7149ec6, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSum_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye131_dconst, Dimensions: [10,30], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Float }, { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_AddResMulSum_0xb3915d7ebfe48be45b6d49083479e12f, StreamId: 0, Metadata:
    # Name: __myl_AddResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye146_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_AddResMulSumAdd_0xcdd0085ad25f5f45ac5fafb72acbffd6, StreamId: 0, Metadata:

To respect the types specified by the user in the model (e.g., in this case, having the ``linear2`` layer
run in FP16), users can enable the compilation setting ``use_explicit_typing=True``. Compiling with this
option produces the following TensorRT logs.

.. note:: If you enable ``use_explicit_typing=True``, only ``torch.float32`` is supported in ``enabled_precisions``.

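For reference, a compile call that satisfies this constraint could look like the sketch below; the argument
names are the public ``torch_tensorrt.dynamo.compile`` settings, but the combination shown here is
illustrative rather than taken from this commit. The full example that produces the logs follows after it.

.. code-block:: python

    # Minimal sketch: explicit typing together with the only allowed enabled_precisions value
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        enabled_precisions={torch.float32},  # only FP32 may be listed when explicit typing is on
        use_explicit_typing=True,
    )
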
.. code-block:: python

    inputs = [torch.randn((1, 10), dtype=torch.float32).cuda()]
    mod = MyModule().eval().cuda()
    ep = torch.export.export(mod, tuple(inputs))
    with torch_tensorrt.logging.debug():
        trt_gm = torch_tensorrt.dynamo.compile(ep,
                                               inputs=inputs,
                                               use_explicit_typing=True,
                                               debug=True)

    # Debug log info
    # Layers:
    # Name: __myl_MulSumAddCas_myl0_0, LayerType: kgen, Inputs: [ { Name: linear1/addmm_constant_0 _ linear1/addmm_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,10], Format/Datatype: Float }, { Name: __mye112_dconst, Dimensions: [10,10], Format/Datatype: Float }, { Name: x, Dimensions: [10,1], Format/Datatype: Float }], Outputs: [ { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], TacticName: __myl_MulSumAddCas_0xacf8f5dd9be2f3e7bb09cdddeac6c936, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAddCas_myl0_1, LayerType: kgen, Inputs: [ { Name: __mye127_dconst, Dimensions: [10,30], Format/Datatype: Half }, { Name: linear2/addmm_1_constant_0 _ linear2/addmm_1_add_broadcast_to_same_shape_lhs_broadcast_constantHalf, Dimensions: [1,30], Format/Datatype: Half }, { Name: __myln_k_arg__bb1_2, Dimensions: [1,10], Format/Datatype: Half }], Outputs: [ { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], TacticName: __myl_ResMulSumAddCas_0x5a3b318b5a1c97b7d5110c0291481337, StreamId: 0, Metadata:
    # Name: __myl_ResMulSumAdd_myl0_2, LayerType: kgen, Inputs: [ { Name: __mye142_dconst, Dimensions: [30,40], Format/Datatype: Float }, { Name: linear3/addmm_2_constant_0 _ linear3/addmm_2_add_broadcast_to_same_shape_lhs_broadcast_constantFloat, Dimensions: [1,40], Format/Datatype: Float }, { Name: __myln_k_arg__bb1_3, Dimensions: [1,30], Format/Datatype: Float }], Outputs: [ { Name: output0, Dimensions: [1,40], Format/Datatype: Float }], TacticName: __myl_ResMulSumAdd_0x3fad91127c640fd6db771aa9cde67db0, StreamId: 0, Metadata:

Now the ``linear2`` layer runs in FP16, as shown in the above logs.
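Since this commit implements FP32 accumulation for matmul, there is also a compilation setting that makes
matmul layers accumulate partial sums in FP32 even when their inputs are FP16, trading some speed for
numerical accuracy. The snippet below is a minimal sketch that assumes the setting is exposed as
``use_fp32_acc`` on ``torch_tensorrt.dynamo.compile``; check the settings documented for your installed
Torch-TensorRT version before relying on the exact name.

.. code-block:: python

    # Sketch (assumed flag name): FP16 compute with FP32 accumulation for matmuls
    trt_gm = torch_tensorrt.dynamo.compile(
        ep,
        inputs=inputs,
        use_explicit_typing=True,  # keep the user-specified FP16 casts
        use_fp32_acc=True,         # assumed name for the FP32-accumulation setting added here
    )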
@@ -1,4 +1,4 @@
cupy==13.1.0
torch>=2.4.0.dev20240503+cu121
torch-tensorrt>=2.4.0.dev20240503+cu121
triton==2.3.0
diffusers==0.30.3
transformers==4.44.2