
Incorrect IR generated for Vector Mode AD #235

Open
avik-pal opened this issue Jan 12, 2025 · 4 comments
@avik-pal (Collaborator)

module {
  func.func private @"Const{typeof(rosenbrock)}(Main.rosenbrock)_autodiff"(%arg0: tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>) {
    %cst = stablehlo.constant dense<1.000000e+02> : tensor<f64>
    %cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<f64>
    %0 = stablehlo.slice %arg0 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
    %1 = stablehlo.reshape %0 : (tensor<1xf64>) -> tensor<f64>
    %2 = stablehlo.subtract %cst_0, %1 : tensor<f64>
    %3 = stablehlo.multiply %2, %2 : tensor<f64>
    %4 = stablehlo.slice %arg0 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
    %5 = stablehlo.reshape %4 : (tensor<1xf64>) -> tensor<f64>
    %6 = stablehlo.multiply %1, %1 : tensor<f64>
    %7 = stablehlo.subtract %5, %6 : tensor<f64>
    %8 = stablehlo.multiply %7, %7 : tensor<f64>
    %9 = stablehlo.multiply %cst, %8 : tensor<f64>
    %10 = stablehlo.add %3, %9 : tensor<f64>
    return %10, %arg0 : tensor<f64>, tensor<2xf64>
  }
  func.func @main(%arg0: tensor<2xf64>, %arg1: tensor<2xf64>, %arg2: tensor<2xf64>) -> (tensor<1xf64>, tensor<1xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>) {
    %0 = stablehlo.concatenate %arg1, %arg2, dim = 0 : (tensor<2xf64>, tensor<2xf64>) -> tensor<4xf64>
    %1 = stablehlo.reshape %0 : (tensor<4xf64>) -> tensor<2x2xf64>
    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xf64>) -> tensor<2x2xf64>
    %3:3 = enzyme.fwddiff @"Const{typeof(rosenbrock)}(Main.rosenbrock)_autodiff"(%arg0, %2) {activity = [#enzyme<activity enzyme_dup>], ret_activity = [#enzyme<activity enzyme_dupnoneed>, #enzyme<activity enzyme_dup>]} : (tensor<2xf64>, tensor<2x2xf64>) -> (tensor<2xf64>, tensor<2xf64>, tensor<2x2xf64>)
    %4 = stablehlo.slice %3#0 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
    %5 = stablehlo.slice %3#0 [1:2] : (tensor<2xf64>) -> tensor<1xf64>
    return %4, %5, %3#1, %arg1, %arg2 : tensor<1xf64>, tensor<1xf64>, tensor<2xf64>, tensor<2xf64>, tensor<2xf64>
  }
}
envs/nested.mlir:40:11: error: the number of elements in start_indices (1) does not match the rank of the operand (0)
    %14 = stablehlo.slice %13 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
          ^
envs/nested.mlir:40:11: error: 'stablehlo.slice' op failed to infer returned types
    %14 = stablehlo.slice %13 [0:1] : (tensor<2xf64>) -> tensor<1xf64>
          ^
envs/nested.mlir:40:11: note: see current operation: %4 = "stablehlo.slice"(%3#0) <{limit_indices = array<i64: 1>, start_indices = array<i64: 0>, strides = array<i64: 1>}> : (tensor<f64>) -> tensor<1xf64>
@wsmoses (Member) commented Jan 12, 2025

Actually, I think this is invalid IR, since the primal and shadow inputs should be the same type?

We clearly should have a better error message, though.
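
(For comparison, a non-batched call where the shadow matches the primal type would look roughly like the sketch below; %darg0 is a hypothetical single tangent vector, not taken from the reproducer above.)

// Rough sketch (not from the reproducer): one tangent vector, same type as the primal input.
%darg0 = stablehlo.constant dense<[1.000000e+00, 0.000000e+00]> : tensor<2xf64>
// dupnoneed on the first result yields only its derivative (tensor<f64>);
// dup on the second yields its primal and derivative (both tensor<2xf64>).
%r:3 = enzyme.fwddiff @"Const{typeof(rosenbrock)}(Main.rosenbrock)_autodiff"(%arg0, %darg0) {activity = [#enzyme<activity enzyme_dup>], ret_activity = [#enzyme<activity enzyme_dupnoneed>, #enzyme<activity enzyme_dup>]} : (tensor<2xf64>, tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>, tensor<2xf64>)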

@wsmoses (Member) commented Jan 12, 2025

Alternatively, this would be valid for batch size == 2, but perhaps we're not passing the batch size in correctly?

That said, I'm also fairly certain that the custom C++ (i.e. not TableGen) rules don't support batching yet.

@wsmoses (Member) commented Jan 12, 2025

cc @jumerckx

avik-pal changed the title from "Incorrect SliceOp adjoint for scalars" to "Incorrect IR generated for Vector Mode AD" on Jan 12, 2025

@jumerckx (Collaborator)

> Alternatively, this would be valid for batch size == 2, but perhaps we're not passing the batch size in correctly?

Yeah, if the batch width is passed correctly, there should be a width attribute on the fwddiff op, I believe? (Rough sketch below.)

> That said, I'm also fairly certain that the custom C++ (i.e. not TableGen) rules don't support batching yet.

👍

I can have a look tomorrow at how to properly pass in the batch size, if no one beats me to it.
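
If that's the mechanism, the batched call from the reproducer would presumably need to come out roughly like the sketch below (assuming fwddiff does accept such a width attribute; the operand and result types are the ones from the IR above):

// Hypothetical sketch: width = 2 declares that the tensor<2x2xf64> shadow packs two
// tangent vectors, and the dupnoneed/dup shadow results gain the matching batch dimension.
%3:3 = enzyme.fwddiff @"Const{typeof(rosenbrock)}(Main.rosenbrock)_autodiff"(%arg0, %2) {activity = [#enzyme<activity enzyme_dup>], ret_activity = [#enzyme<activity enzyme_dupnoneed>, #enzyme<activity enzyme_dup>], width = 2 : i64} : (tensor<2xf64>, tensor<2x2xf64>) -> (tensor<2xf64>, tensor<2xf64>, tensor<2x2xf64>)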
