[js/webgpu] ConvTranspose1D slower on Webgpu than Wasm #23273
Comments
@gianlourbano I can reproduce it. Will take a look, thanks.
### Description

BUG #23273. With this change, the ConvTranspose time in that bug drops from ~90 s to ~7 s on my Meteor Lake. This PR does the following:
1. Use the stride to update the increment in the loop. In this bug the stride is 1024, which greatly reduces the number of loop iterations.
2. Support components for A to reduce the number of memory accesses.
3. When the output channel count is 1, the components of B can match those of A to further reduce memory accesses.
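To illustrate optimization 1, here is a minimal Python sketch of the idea (the actual change is to a WGSL shader; the names below are illustrative, and it assumes a single output channel for brevity). For a given output position `o`, only kernel taps `k` with `(o + pad - k) % stride == 0` contribute, so the loop can start at the first valid tap and step by the stride instead of by 1:

```python
def conv_transpose_1d_point(x, w, o, stride, pad):
    # x: input [in_channels][in_width]; w: weights [in_channels][kernel_size]
    # (single output channel). Returns the output sample at position o.
    in_channels, in_width = len(x), len(x[0])
    kernel_size = len(w[0])
    acc = 0.0
    # first kernel tap satisfying (o + pad - k) % stride == 0
    k = (o + pad) % stride
    while k < kernel_size:             # step by `stride`, not by 1
        i = (o + pad - k) // stride    # the input index this tap reads
        if 0 <= i < in_width:
            for ic in range(in_channels):
                acc += x[ic][i] * w[ic][k]
        k += stride                    # with stride 1024 this skips 1023 taps
    return acc
```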
Thanks for the help @qjia7! Do you think there's more room for improvement? The same op in torch/ONNX on the Python CPU backend takes about 400-600 ms.
Yes. Your shape is very special: the stride is 1024, which is very big, so I can do some specific optimization for such a large stride. The output channel count is also only 1, which can be further optimized. Glad to know the CPU only takes 400-600 ms; that gives the GPU a high target :)
Yes, I'm aware. The convolution is part of an implementation of a custom inverse short-time Fourier transform, since conversion of that operator from torch to ONNX still does not work. Thank you for the precious help!
BUG #23273. This PR makes the following optimizations:
1. When the output channel count is one: 1) calculate the offset before the input-channel loop to reduce index-to-offset calculations; 2) split `inputChannelsPerGroup` into `inputChannelsPerGroupInt` and `inputChannelsRemainder` parts so that the main loop can always access 4 elements at a time for `inputChannelsPerGroupInt`.
2. Use a precise initial value to avoid useless loop iterations. Thanks to @jiangzhaoming for the suggestion on this.

With this PR, ConvTranspose goes from 8.4 s to 3.7 s on Intel Meteor Lake, and from 2.7 s to 1.6 s on an NV RTX 2000 Ada.
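A minimal Python sketch of the loop split in optimization 1.2 (the real code is a WGSL shader reading `vec4` values; the function and variable names here are illustrative): the channel loop is divided into a main part that always consumes 4 elements per iteration and a remainder part for the 0-3 trailing channels:

```python
def dot_over_input_channels(x_col, w_col):
    # x_col, w_col: per-channel input and weight values at one position,
    # each of length inputChannelsPerGroup.
    n = len(x_col)
    input_channels_per_group_int = n - n % 4  # largest multiple of 4 <= n
    acc = 0.0
    # main loop: a fixed 4-element access per iteration (maps to a vec4 read)
    for c in range(0, input_channels_per_group_int, 4):
        acc += (x_col[c] * w_col[c] + x_col[c + 1] * w_col[c + 1] +
                x_col[c + 2] * w_col[c + 2] + x_col[c + 3] * w_col[c + 3])
    # remainder loop: the inputChannelsRemainder (0-3) leftover channels
    for c in range(input_channels_per_group_int, n):
        acc += x_col[c] * w_col[c]
    return acc
```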
Describe the issue
ConvTranspose1D with input shape [8, 4098, 435], weights [4096, 1, 4098], stride 1024, and padding 0 appears to be much slower on WebGPU than on Wasm.
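Some back-of-the-envelope sizing for these shapes helps explain the gap. This sketch assumes the PyTorch weight layout [in_channels=4098, out_channels=1, kernel_size=4096], consistent with the input shape and the comments above; the issue line lists the weight dimensions in a different order:

```python
in_width, stride, kernel_size, pad = 435, 1024, 4096, 0

# ConvTranspose1D output width
out_width = (in_width - 1) * stride + kernel_size - 2 * pad
print(out_width)  # 448512 output samples per batch item

# Per output sample, a naive kernel loop visits every tap...
naive_taps = kernel_size                  # 4096
# ...but only taps with k == (o + pad) mod stride actually contribute.
useful_taps = -(-kernel_size // stride)   # ceil(4096 / 1024) = 4
print(naive_taps // useful_taps)          # 1024x fewer loop iterations possible
```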
To reproduce
Simple torch script to generate the conv and convert it to onnx
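The original snippet was not preserved here; below is a minimal sketch of such a script using the shapes from the description. PyTorch's ConvTranspose1d stores weights as [in_channels, out_channels/groups, kernel_size], so this assumes in_channels=4098, out_channels=1, kernel_size=4096; the file and tensor names are illustrative:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # channel counts and kernel size are inferred from the issue thread
        self.conv = torch.nn.ConvTranspose1d(
            in_channels=4098, out_channels=1, kernel_size=4096,
            stride=1024, padding=0, bias=False)

    def forward(self, x):
        return self.conv(x)

model = Model().eval()
dummy = torch.randn(8, 4098, 435)  # input shape from the issue
torch.onnx.export(model, (dummy,), "conv_transpose.onnx",
                  input_names=["input"], output_names=["output"],
                  opset_version=17)
```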
To test in the browser, load the exported model with onnxruntime-web and time a session run under the 'webgpu' and 'wasm' execution providers.
Urgency
Urgent
ONNX Runtime Installation
Released Package
ONNX Runtime Version or Commit ID
1.21.0-dev.20241224-2d05c4bcd9
Execution Provider
'webgpu' (WebGPU), 'wasm'/'cpu' (WebAssembly CPU)