stablehlo-to-tensorrt conversion pass doesn't support stablehlo.reduce with multiple reduction dims #279
New finding: Example:
Then if we do:
And then:
And even:
But if you skip a dim:
I was able to work around this to implement BatchNorm as follows:

```python
import tripy as tp

# Forward pass of a BatchNorm module; `self.eps`, `self.gamma`, and `self.beta`
# are assumed to be attributes of the enclosing module.
def __call__(self, x):
    # Transpose the channel dimension (dim 1) with the batch dimension (dim 0)
    x_transposed = tp.transpose(x, 0, 1)
    # Reshape to combine the batch and spatial dimensions into one axis
    C, N, H, W = x_transposed.shape
    x_reshaped = tp.reshape(x_transposed, (C, N * H * W))
    # Calculate mean and variance across the merged axis for each channel (C),
    # so only a single-axis reduction is needed
    mean = tp.mean(x_reshaped, dim=1, keepdim=True)
    variance = tp.var(x_reshaped, dim=1, keepdim=True)
    # Reshape the per-channel statistics so they broadcast against NCHW input
    mean = tp.reshape(mean, (1, C, 1, 1))
    variance = tp.reshape(variance, (1, C, 1, 1))
    # Transpose back to the original NCHW layout
    x_transposed_back = tp.transpose(x_transposed, 0, 1)
    # Normalize the input
    x_normalized = (x_transposed_back - mean) / tp.sqrt(variance + self.eps)
    # Apply the learned scaling (gamma) and shifting (beta)
    x_scaled = self.gamma * x_normalized + self.beta
    return x_scaled
```
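To sanity-check the transpose-and-reshape workaround independently of Tripy and TensorRT, here is a minimal NumPy sketch (NumPy is an assumption; it is not used in the thread) that compares it against a direct reduction over the non-contiguous axes (0, 2, 3). Both paths use NumPy's default population variance, so this only checks the reshaping logic, not Tripy's `tp.var` defaults.

```python
import numpy as np

def batchnorm_stats_via_transpose(x):
    # Mirror the workaround: swap N and C, merge N, H, W into one axis,
    # then reduce over that single axis.
    x_t = np.transpose(x, (1, 0, 2, 3))          # (C, N, H, W)
    C = x_t.shape[0]
    x_r = x_t.reshape(C, -1)                     # (C, N*H*W)
    mean = x_r.mean(axis=1, keepdims=True).reshape(1, C, 1, 1)
    var = x_r.var(axis=1, keepdims=True).reshape(1, C, 1, 1)
    return mean, var

def batchnorm_stats_direct(x):
    # Reference: one reduction over the non-contiguous axes (0, 2, 3).
    mean = x.mean(axis=(0, 2, 3), keepdims=True)
    var = x.var(axis=(0, 2, 3), keepdims=True)
    return mean, var

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 3, 4, 5))            # NCHW, float64
m1, v1 = batchnorm_stats_via_transpose(x)
m2, v2 = batchnorm_stats_direct(x)
```

If both pairs match, the single-axis formulation computes the same per-channel statistics that a multi-axis reduction over (0, 2, 3) would.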
TensorRT doesn't actually support doing reduction across multiple dimensions. In MLIR-TRT we don't do anything special to work around this limitation. We would have to add a transformation that decomposes …
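A minimal sketch of one possible decomposition, assuming NumPy as a stand-in for a single-axis reduce op: a reduction over several axes is rewritten as a chain of single-axis reductions, and a mean becomes a sum divided by the number of reduced elements. The function names are hypothetical and this is not MLIR-TRT's implementation.

```python
import numpy as np

def reduce_sum_single_axis_only(x, axes, keepdims=False):
    # Hypothetical helper: one single-axis reduction per requested axis.
    # Axes are assumed non-negative. Reducing the highest axis first keeps the
    # remaining (lower) axis indices valid when keepdims=False.
    out = x
    for axis in sorted(set(axes), reverse=True):
        out = np.sum(out, axis=axis, keepdims=keepdims)
    return out

def reduce_mean_single_axis_only(x, axes, keepdims=False):
    # Mean = sum over the reduced axes divided by the number of reduced elements.
    count = 1
    for axis in set(axes):
        count *= x.shape[axis]
    return reduce_sum_single_axis_only(x, axes, keepdims) / count

x = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)
axes = (0, 2, 3)  # skips dim 1, like a BatchNorm reduction over NCHW input
summed = reduce_sum_single_axis_only(x, axes)
mean = reduce_mean_single_axis_only(x, axes, keepdims=True)
```

The same chaining works unchanged for max, min, and product; mean needs the explicit element count.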
@christopherbate TRT does support reduction across multiple dimensions: the axes parameter is a bitset. MLIR-TRT also seems to support this in the case where the reduced axes are contiguous (see #297 for an example of multiple contiguous axes working). It seems to be only when we skip over certain dimensions that compilation fails. TRT should work fine with skipped dimensions, so is this just a lowering bug?
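To illustrate the bitset encoding and the contiguous vs. skipped distinction, here is a small pure-Python sketch; the helper names are hypothetical. For BatchNorm on NCHW input the reduced axes are (0, 2, 3), which skip dim 1.

```python
def axes_to_bitmask(axes):
    # Bit i set means "reduce over dimension i".
    mask = 0
    for axis in axes:
        mask |= 1 << axis
    return mask

def axes_are_contiguous(axes):
    # True when the axes form one unbroken run such as (2, 3) or (1, 2, 3).
    s = sorted(set(axes))
    if not s:
        return True
    return s == list(range(s[0], s[0] + len(s)))

for axes in [(2, 3), (1, 2, 3), (0, 2, 3)]:
    # (2, 3) -> 0b1100, (1, 2, 3) -> 0b1110, (0, 2, 3) -> 0b1101 (not contiguous)
    print(axes, bin(axes_to_bitmask(axes)), axes_are_contiguous(axes))
```

Both kinds of axis set are representable in the bitmask, which is why the failure on skipped dims looks like a lowering restriction rather than an encoding limit.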
I'm aware the axes parameter is a bitset, but IIRC if you actually tried to reduce multiple dimensions, TRT would return an error. This is the reason we have the current restriction, although maybe it has been lifted since TRT 8. We would need to confirm. Right now in StableHLO-to-TRT, we only convert single-axis reductions. The preprocessing pipeline tries to take care of flattening in the case where the reduction is over multiple dims, IIRC.
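A minimal NumPy sketch of flatten-then-reduce, assuming it works roughly as described: permute the reduced axes to the back, merge them with one reshape, and reduce once over the merged axis. This is a hypothetical illustration, not the actual preprocessing pipeline. When the reduced axes are already contiguous and trailing, the permutation is the identity and only the reshape is needed; skipped axes require the transpose, which is what the Tripy workaround above does by hand.

```python
import numpy as np

def reduce_sum_via_flatten(x, axes):
    # Hypothetical sketch of flatten-then-reduce; not the actual MLIR-TRT pass.
    reduced = sorted(set(axes))
    kept = [d for d in range(x.ndim) if d not in reduced]
    # Move reduced axes to the back so they become contiguous.
    # This is the identity permutation when they already trail the kept axes.
    x_perm = np.transpose(x, kept + reduced)
    kept_shape = [x.shape[d] for d in kept]
    # Merge every reduced axis into one trailing axis, then do a single reduction.
    x_flat = x_perm.reshape(kept_shape + [-1])
    return np.sum(x_flat, axis=-1)

x = np.arange(2 * 3 * 4 * 5, dtype=np.float64).reshape(2, 3, 4, 5)
contiguous = reduce_sum_via_flatten(x, (2, 3))   # reshape alone would suffice here
skipped = reduce_sum_via_flatten(x, (0, 2, 3))   # needs the transpose: dim 1 is kept
```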
This issue is related to `tp.mean` and `tp.var` failures when implementing BatchNorm using Tripy for the ResNet50 model.
Attachments: stdout (helper prints for shapes), stderr, MLIR dumps (tripy-mlir-batchnorm.zip)
It seems like the `tp.mean` and `tp.var` reduction operations fail at MLIR compile time.