
[js/webgpu] ConvTranspose1D slower on WebGPU than Wasm #23273

Open
gianlourbano opened this issue Jan 7, 2025 · 5 comments
Labels
ep:WebGPU (ort-web webgpu provider), platform:web (issues related to ONNX Runtime web; typically submitted using template)

Comments

@gianlourbano

Describe the issue

ConvTranspose1D with input shape [8, 4098, 435], weights [4098, 1, 4096], stride 1024, and padding 0 appears to be slower on WebGPU than Wasm, with the following timings:

| EP | Timing (M1 MacBook Pro) |
| --- | --- |
| wasm | 6 s |
| webgpu (latest Chrome) | 30 s |
| webgpu (Chrome Canary) | 18 s |

Canary is faster due to this bug.
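For scale, the standard ConvTranspose1D output-length formula shows how large this op is (a back-of-the-envelope check added here for illustration, not part of the original report):

```python
# Output length from the standard ConvTranspose1D formula:
#   L_out = (L_in - 1) * stride - 2 * padding + kernel_size
l_out = (435 - 1) * 1024 - 2 * 0 + 4096
print(l_out)  # 448512, i.e. an output tensor of shape [8, 1, 448512]
```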

To reproduce

A simple torch script to generate the conv and convert it to ONNX:

```python
import torch

class ConvTest(torch.nn.Module):
    def __init__(self, weight, stride, padding=0):
        super().__init__()
        self.weight = weight
        self.stride = stride
        self.padding = padding

    def forward(self, x):
        return torch.nn.functional.conv_transpose1d(x, self.weight, stride=self.stride, padding=self.padding)

convtest = ConvTest(weight=torch.randn(4098, 1, 4096), stride=1024)

input = torch.randn(8, 4098, 435)

torch.onnx.export(
    convtest,
    (input,),
    "convtest.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=20,
    dynamo=True,
    do_constant_folding=True,
    keep_initializers_as_inputs=True,
    # report=True,
    external_data=None,
    # verify=True
)
```
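As a sanity check, the exported model can be run on native CPU to confirm the output shape and get a reference timing (a minimal sketch, assuming the `onnxruntime` Python package is installed; this is illustrative and not part of the original report):

```python
import time
import numpy as np
import onnxruntime

# Run the exported model on native CPU to verify the output shape
# and get a rough reference timing.
sess = onnxruntime.InferenceSession("convtest.onnx", providers=["CPUExecutionProvider"])
x = np.random.randn(8, 4098, 435).astype(np.float32)

start = time.perf_counter()
out = sess.run(None, {"input": x})[0]
print(out.shape, f"{time.perf_counter() - start:.3f} s")  # expect (8, 1, 448512)
```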

To test in the browser:

```js
const session = await ort.InferenceSession.create("/convtest.onnx", {
    executionProviders: ["webgpu"],
    // logSeverityLevel: 0
});

const wgpu_profile = [];

ort.env.webgpu.profiling = {
    mode: "default",
    ondata: (data) => {
        wgpu_profile.push(data);
    }
};

const input_dims = [8, 4098, 435];
const size = 8 * 4098 * 435;

const no_chunks = 1;
const chunks = [];

for (let i = 0; i < no_chunks; i++) {
    const chunk = new Float32Array(size);
    chunks.push(chunk);
}

for (let i = 0; i < no_chunks; i++) {
    console.time("onnx step " + i);
    const input = new ort.Tensor("float32", chunks[i], input_dims);
    const output = await session.run({ input });
    console.timeEnd("onnx step " + i);
}

await session.release();

// Sort kernels by duration and print them (profiling timestamps are in nanoseconds).
wgpu_profile.sort((a, b) => (a.endTime - a.startTime) - (b.endTime - b.startTime));

wgpu_profile.forEach((kernel) => {
    console.log(`${kernel.kernelType} (${kernel.kernelName}) took ${(kernel.endTime - kernel.startTime) / 1000 / 1000} ms`);
});
```

Urgency

Urgent

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.21.0-dev.20241224-2d05c4bcd9

Execution Provider

'webgpu' (WebGPU), 'wasm'/'cpu' (WebAssembly CPU)

gianlourbano added the platform:web label Jan 7, 2025
gianlourbano (Author)

@qjia7 @gyagp could you please take a look? Maybe it has something to do with this PR.

The github-actions bot added the ep:WebGPU label Jan 7, 2025
qjia7 (Contributor) commented Jan 8, 2025

@gianlourbano I can reproduce it. Will take a look, thanks.

guschmue pushed a commit that referenced this issue Jan 9, 2025
### Description
BUG #23273

With this change, the ConvTranspose time in that bug drops from ~90 s to ~7 s on my Meteor Lake.

This PR does the following:
1. Use the stride to update the increment in the loop. In the bug, the stride is 1024, which greatly reduces the number of loop iterations (see the sketch after this list).
2. Support components for A to reduce the number of memory accesses.
3. When the output channel count is 1, the components of B can match those of A to further reduce the number of memory accesses.
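To see why stepping by the stride helps, here is a scalar sketch of a gather-style ConvTranspose1D inner loop before and after that change (illustrative Python, not the actual WGSL kernel; all names are invented for this sketch):

```python
def out_point_naive(x, w, o, stride, padding):
    # Visits all len(w) kernel taps; most fail the stride test.
    acc = 0.0
    for k in range(len(w)):
        num = o + padding - k
        if num % stride == 0 and 0 <= num // stride < len(x):
            acc += x[num // stride] * w[k]
    return acc

def out_point_strided(x, w, o, stride, padding):
    # Only taps with k == (o + padding) mod stride can contribute,
    # so step by the stride: ~len(w) / stride iterations instead of len(w).
    acc = 0.0
    k = (o + padding) % stride
    while k < len(w):
        i = (o + padding - k) // stride
        if 0 <= i < len(x):
            acc += x[i] * w[k]
        k += stride  # with stride 1024 this skips 1023 dead taps at a time
    return acc
```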
gianlourbano (Author) commented Jan 10, 2025

Thanks for the help, @qjia7! Do you think there's more room for improvement? The same op in torch/ONNX Python CPU takes about 400–600 ms.

guschmue pushed a commit that referenced this issue Jan 12, 2025
qjia7 (Contributor) commented Jan 13, 2025

> Do you think there's more room for improvement? The same op in torch/ONNX Python CPU takes about 400–600 ms.

Yes. Your shape is very special: the stride is 1024, which is very large, and I can do some specific optimization for such a large stride. The output channel count is only 1, which can also be further optimized. Glad to know the CPU only takes 400–600 ms, which gives the GPU a high target :)

gianlourbano (Author)

Yes, I'm aware. The convolution is part of an implementation of a custom inverse short-time Fourier transform, since conversion of that operator from torch to ONNX still does not work. Thank you for the precious help.
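For context, the overlap-add step of an iSTFT can indeed be phrased as `conv_transpose1d`. Here is a hedged sketch of just that step, with shapes chosen to mirror this issue; the author's real kernel would also fold in the inverse-DFT basis and window, hence the 4098 input channels, so this is illustrative only:

```python
import torch

# Overlap-add of frames via conv_transpose1d with a fixed identity kernel:
# sample c of frame t lands at output position t * hop + c, overlaps summed.
frame_len, hop, n_frames = 4096, 1024, 435
frames = torch.randn(1, frame_len, n_frames)    # [batch, frame_len, n_frames]
eye = torch.eye(frame_len).unsqueeze(1)         # [frame_len, 1, frame_len]
signal = torch.nn.functional.conv_transpose1d(frames, eye, stride=hop)
print(signal.shape)  # [1, 1, (n_frames - 1) * hop + frame_len] = [1, 1, 448512]
```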

guschmue pushed a commit that referenced this issue Jan 22, 2025
BUG #23273

This PR makes the following optimizations:
1. When the output channel count is one: 1) calculate the offset before the in-channel loop to reduce indices-to-offsets calculations; 2) split `inputChannelsPerGroup` into `inputChannelsPerGroupInt` and `inputChannelsRemainder` parts so that we can always access 4 data elements for `inputChannelsPerGroupInt`.
2. Use a precise initial value to reduce useless loop iterations (see the sketch after this list). Thanks to @jiangzhaoming's suggestions on this.

With this PR, ConvTranspose goes from 8.4 s to 3.7 s on Intel Meteor Lake. On an NV RTX 2000 Ada, it goes from 2.7 s to 1.6 s.
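The "precise initial value" point can be pictured on top of the strided sketch from the earlier commit: clamp the loop to taps whose input index is in range up front, instead of bounds-testing every iteration (again illustrative Python, my reading of the optimization, not the actual WGSL):

```python
def out_point_bounded(x, w, o, stride, padding):
    base = o + padding
    # Both candidates below share the residue base % stride, so the max is
    # still stride-aligned; it is the first tap whose input index is < len(x).
    k = max(base % stride, base - (len(x) - 1) * stride)
    k_end = min(len(w) - 1, base)  # taps past `base` would need a negative index
    acc = 0.0
    while k <= k_end:
        acc += x[(base - k) // stride] * w[k]  # no per-iteration bounds check
        k += stride
    return acc
```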
ashrit-ms pushed a commit that referenced this issue Jan 23, 2025