Fix/autotune error handling (#2670)
nathanielsimard authored Jan 8, 2025
1 parent e588632 commit da8de56
Showing 28 changed files with 158 additions and 706 deletions.
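In short, the diffs below make the conv2d / conv_transpose2d and matmul launch paths in burn-jit fallible: launch functions now return Result<JitTensor<R>, ConvLaunchError> instead of panicking, callers propagate failures with the ? operator, and branches that stay infallible (such as the autotuned ones) are wrapped in Ok(..) at the dispatch site. The following minimal, self-contained Rust sketch shows that pattern in isolation; the types are simplified stand-ins, not burn-jit's real definitions (Tensor stands in for JitTensor<R>, and the error and strategy enums only mirror names visible in the diff).

// Minimal sketch of the error-handling pattern introduced by this commit.
// All types below are stand-ins, not the crate's actual API.

#[derive(Debug)]
pub enum ConvLaunchError {
    Unsupported(&'static str),
}

#[derive(Debug)]
pub struct Tensor; // stand-in for JitTensor<R>

pub enum Conv2dStrategy {
    Direct,
    Autotune,
}

// Individual launch paths are now fallible and return a Result...
fn conv2d_direct(_input: Tensor) -> Result<Tensor, ConvLaunchError> {
    Ok(Tensor)
}

// ...while the autotuned path remains infallible, since it resolves a working kernel itself.
fn conv2d_autotune(_input: Tensor) -> Tensor {
    Tensor
}

// The strategy dispatcher returns a Result; infallible branches are wrapped in Ok(..),
// matching the `Conv2dStrategy::Autotune => Ok(conv2d_autotune(..))` arm in the diff.
pub fn conv2d(input: Tensor, strategy: Conv2dStrategy) -> Result<Tensor, ConvLaunchError> {
    match strategy {
        Conv2dStrategy::Direct => conv2d_direct(input),
        Conv2dStrategy::Autotune => Ok(conv2d_autotune(input)),
    }
}

fn main() -> Result<(), ConvLaunchError> {
    // Callers propagate launch failures with `?` instead of hitting a panic.
    let out = conv2d(Tensor, Conv2dStrategy::Direct)?;
    println!("{out:?}");
    Ok(())
}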
24 changes: 12 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default.

4 changes: 2 additions & 2 deletions Cargo.toml
@@ -153,8 +153,8 @@ ahash = { version = "0.8.11", default-features = false }
portable-atomic-util = { version = "0.2.4", features = ["alloc"] }

### For the main burn branch. ###
- cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" }
- cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "4d6f50f3af4c8dd664619b61e6adf437e4b09e2e" }
+ cubecl = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" }
+ cubecl-common = { git = "https://github.com/tracel-ai/cubecl", default-features = false, rev = "5a3f9ac9f6178c4f76570535bf5e42ef12a19a3d" }
### For local development. ###
# cubecl = { path = "../cubecl/crates/cubecl", default-features = false }
# cubecl-common = { path = "../cubecl/crates/cubecl-common", default-features = false }
3 changes: 2 additions & 1 deletion crates/burn-jit/src/fusion/matmul/optimization.rs
@@ -122,7 +122,8 @@ impl<R: JitRuntime> MatmulOptimization<R> {
rhs_tensor,
None,
matmul::MatmulStrategy::default(),
- );
+ )
+ .unwrap();
(out_tensor, out)
};
context
16 changes: 9 additions & 7 deletions crates/burn-jit/src/kernel/conv/conv2d/base.rs
@@ -1,6 +1,8 @@
use burn_tensor::ops::{ConvOptions, ConvTransposeOptions};

- use crate::{tensor::JitTensor, FloatElement, IntElement, JitRuntime};
+ use crate::{
+     kernel::conv::ConvLaunchError, tensor::JitTensor, FloatElement, IntElement, JitRuntime,
+ };

#[cfg(feature = "autotune")]
use super::{conv2d_autotune, conv_transpose2d_autotune};
@@ -75,11 +77,11 @@ pub fn conv2d<R: JitRuntime, E: FloatElement>(
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
strategy: Conv2dStrategy,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
match strategy {
Conv2dStrategy::Direct => conv2d_direct::<R, E>(input, weight, bias, options),
#[cfg(feature = "autotune")]
- Conv2dStrategy::Autotune => conv2d_autotune::<R, E>(input, weight, bias, options),
+ Conv2dStrategy::Autotune => Ok(conv2d_autotune::<R, E>(input, weight, bias, options)),
Conv2dStrategy::Gemm => conv2d_im2col::<R, E>(input, weight, bias, options),
Conv2dStrategy::ImplicitGemm => conv2d_implicit_gemm::<R, E>(input, weight, bias, options),
Conv2dStrategy::ImplicitGemmComplex => {
@@ -102,15 +104,15 @@ pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
bias: Option<JitTensor<R>>,
options: ConvTransposeOptions<2>,
strategy: ConvTranspose2dStrategy,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
match strategy {
ConvTranspose2dStrategy::Direct => {
conv_transpose2d_direct::<R, E>(input, weight, bias, options)
}
#[cfg(feature = "autotune")]
- ConvTranspose2dStrategy::Autotune => {
-     conv_transpose2d_autotune::<R, E>(input, weight, bias, options)
- }
+ ConvTranspose2dStrategy::Autotune => Ok(conv_transpose2d_autotune::<R, E>(
+     input, weight, bias, options,
+ )),
ConvTranspose2dStrategy::Gemm => {
conv_transpose2d_col2im::<R, E>(input, weight, bias, options)
}
20 changes: 13 additions & 7 deletions crates/burn-jit/src/kernel/conv/conv2d/col2im.rs
@@ -6,6 +6,7 @@ use cubecl::{calculate_cube_count_elemwise, prelude::*};

use crate::{
kernel::{
+ conv::ConvLaunchError,
into_contiguous,
matmul::{matmul, MatmulStrategy},
slice,
@@ -29,7 +30,7 @@ pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement>(
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvTransposeOptions<2>,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims();
let [batch_size, _, input_h, input_w] = input.shape.dims();
let groups = options.groups;
@@ -94,9 +95,12 @@ pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement>(
options.clone(),
kernel_h,
kernel_w,
- );
+ )?;
}
- reshape(image, Shape::new([batch_size, im_channels, im_h, im_w]))
+ Ok(reshape(
+     image,
+     Shape::new([batch_size, im_channels, im_h, im_w]),
+ ))
} else {
let im_shape = Shape::new([batches_per_run, im_channels, im_h, im_w]);
let image = empty_device::<R, E>(input.client.clone(), input.device.clone(), im_shape);
@@ -108,8 +112,8 @@ pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement>(
options,
kernel_h,
kernel_w,
- );
- image
+ )?;
+ Ok(image)
}
}

@@ -135,7 +139,7 @@ fn execute<R: JitRuntime, E: FloatElement>(
options: ConvTransposeOptions<2>,
kernel_h: usize,
kernel_w: usize,
- ) {
+ ) -> Result<(), ConvLaunchError> {
let [batch_size, _, input_h, input_w] = input.shape.dims();
let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims();

@@ -145,12 +149,14 @@ fn execute<R: JitRuntime, E: FloatElement>(
let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
let input = reshape(input, input_shape);

- let columns = matmul::<R, E>(weight, input, None, MatmulStrategy::default());
+ let columns = matmul::<R, E>(weight, input, None, MatmulStrategy::default())?;
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));

col2im::<R, E>(
columns, bias, image, kernel_h, kernel_w, input_h, input_w, options,
);

+ Ok(())
}

#[allow(clippy::too_many_arguments)]
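The col2im diff above also shows the second half of the pattern: internal helpers such as execute now return Result<(), ConvLaunchError>, forward the fallible matmul launch with ?, and finish with Ok(()). A small stand-alone sketch of that shape follows; as before, the tensor type and the matmul/col2im functions are stand-ins rather than burn-jit's real signatures.

// Stand-alone sketch of the fallible-helper shape used in col2im.rs.
// Tensor, matmul, and col2im below are stand-ins, not burn-jit APIs.

#[derive(Debug)]
pub enum ConvLaunchError {
    Matmul(&'static str),
}

pub struct Tensor;

// Stand-in for the now-fallible matmul launch.
fn matmul(_lhs: &Tensor, _rhs: &Tensor) -> Result<Tensor, ConvLaunchError> {
    Ok(Tensor)
}

// Stand-in for the infallible col2im kernel launch.
fn col2im(_columns: &Tensor, _image: &mut Tensor) {}

// Mirrors `execute`: each fallible launch uses `?`, and the helper returns Ok(())
// once all kernels have been issued, so a per-batch caller can simply write
// `execute(..)?;` for every run.
fn execute(weight: &Tensor, input: &Tensor, image: &mut Tensor) -> Result<(), ConvLaunchError> {
    let columns = matmul(weight, input)?;
    col2im(&columns, image);
    Ok(())
}

fn main() -> Result<(), ConvLaunchError> {
    let mut image = Tensor;
    execute(&Tensor, &Tensor, &mut image)
}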
6 changes: 3 additions & 3 deletions crates/burn-jit/src/kernel/conv/conv2d/direct.rs
@@ -5,7 +5,7 @@ use burn_tensor::{
use cubecl::{calculate_cube_count_elemwise, prelude::*};

use crate::{
- kernel::into_contiguous,
+ kernel::{conv::ConvLaunchError, into_contiguous},
ops::{
numeric::{empty_device, zeros_device},
reshape,
@@ -125,7 +125,7 @@ pub fn conv2d_direct<R: JitRuntime, E: FloatElement>(
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
let [batch_size, _, in_height, in_width] = input.shape.dims();
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims();
let channels_per_group = out_channels / options.groups;
@@ -193,5 +193,5 @@ pub fn conv2d_direct<R: JitRuntime, E: FloatElement>(
kernel_w_unroll,
);

- output
+ Ok(output)
}
12 changes: 6 additions & 6 deletions crates/burn-jit/src/kernel/conv/conv2d/gemm/launch.rs
@@ -23,7 +23,7 @@ use crate::{
algorithm::{Algorithm, ImplicitCmmaConv},
base::{ConvolutionLaunch, ConvolutionProblem},
},
- nchw_to_nhwc, Conv2dAutotuneKey,
+ nchw_to_nhwc, Conv2dAutotuneKey, ConvLaunchError,
},
into_contiguous,
},
@@ -44,7 +44,7 @@ pub fn conv2d_gemm_cmma_large_m<R: JitRuntime, F: FloatElement>(
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
conv2d_gemm_cmma_strategy::<R, F, ImplicitCmmaConv, Large>(input, weight, bias, options)
}

@@ -60,7 +60,7 @@ pub fn conv2d_gemm_cmma_balanced<R: JitRuntime, F: FloatElement>(
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
conv2d_gemm_cmma_strategy::<R, F, ImplicitCmmaConv, Balanced>(input, weight, bias, options)
}

@@ -74,7 +74,7 @@ fn conv2d_gemm_cmma_strategy<
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
- ) -> JitTensor<R> {
+ ) -> Result<JitTensor<R>, ConvLaunchError> {
if TypeId::of::<F>() == TypeId::of::<flex32>() {
conv2d_gemm_with_algo::<R, (F, f16, f32), Alg, S>(input, weight, bias, options)
} else if TypeId::of::<F>() == TypeId::of::<bf16>() || TypeId::of::<F>() == TypeId::of::<f16>()
@@ -102,7 +102,7 @@ pub fn conv2d_gemm_with_algo<
weight: JitTensor<R>,
bias: Option<JitTensor<R>>,
options: ConvOptions<2>,
- ) -> JitTensor<R>
+ ) -> Result<JitTensor<R>, ConvLaunchError>
where
SP::EG: JitElement,
{
@@ -221,7 +221,7 @@ where

// Reset to NCHW
let out = reshape(out, Shape::new([batch_size, out_h, out_w, out_channels]));
- permute(out, &[0, 3, 1, 2])
+ Ok(permute(out, &[0, 3, 1, 2]))
}

pub fn problem_from_key<R: JitRuntime, F: FloatElement>(
(Diffs for the remaining changed files are not shown here.)
