Error: 'tensor.expand_shape' op expected dimension 1 of collapsed type to be static value of 320 #1618

Open
pfultz2 opened this issue Aug 16, 2024 · 1 comment

pfultz2 commented Aug 16, 2024

There is an error:

Invalid MLIR created: Error: 'tensor.expand_shape' op expected dimension 1 of collapsed type to be static value of 320
Note: see current operation: %14 = "tensor.expand_shape"(%12) <{reassociation = [[0], [1, 2], [3], [4]], static_output_shape = array<i64: 2, 32, 10, 128, 128>}> : (tensor<2x128x128x320xf16>) -> tensor<2x32x10x128x128xf16>
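
Reading the verifier message: the reassociation [[0], [1, 2], [3], [4]] expands result dimensions 1 and 2 (32 and 10) out of source dimension 1, so the collapsed (source) type must carry 32 * 10 = 320 in dimension 1. The actual source is tensor<2x128x128x320xf16>, whose dimension 1 is 128; the 320 sits in the last (channel-last) position. For comparison, a minimal stand-alone sketch using the upstream tensor.expand_shape assembly format (the value names here are hypothetical, not taken from the failing lowering):

// This reassociation only verifies when collapsed dim 1 is 320, i.e. against an
// NCHW-shaped source.
%ok = tensor.expand_shape %nchw [[0], [1, 2], [3], [4]] output_shape [2, 32, 10, 128, 128]
    : tensor<2x320x128x128xf16> into tensor<2x32x10x128x128xf16>
// The op in the error instead has a 2x128x128x320 source, so dim 1 is 128 and the
// verifier rejects it with the message above.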

From compiling this MLIR program:

Problem: gfx942:sramecc+:xnack- 304     -t f16 -out_datatype f16 -transA false -transB false -g 1 -m 2 -n 1280 -k 320
module {
  func.func @mlir_convolution_reshape_add(%arg0: !migraphx.shaped<2x32x10x128x128xf16, 0x10x1x0x0>, %arg1: !migraphx.shaped<2x4x128x128xf16, 65536x1x512x4>, %arg2: !migraphx.shaped<320x4x3x3xf16, 36x1x12x4>) -> (!migraphx.shaped<2x320x128x128xf16, 5242880x1x40960x320>, !migraphx.shaped<2x32x10x128x128xf16, 5242880x163840x16384x128x1>) attributes {arch = "gfx942:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 304 : i64} {
    %0 = migraphx.convolution %arg1, %arg2 {dilation = [1, 1], group = 1 : i64, padding = [1, 1, 1, 1], padding_mode = 0 : i64, stride = [1, 1]} : <2x4x128x128xf16, 65536x1x512x4>, <320x4x3x3xf16, 36x1x12x4> -> <2x320x128x128xf16, 5242880x1x40960x320>
    %1 = migraphx.reshape %0 {dims = [2, 32, 10, 128, 128]} : <2x320x128x128xf16, 5242880x1x40960x320> -> <2x32x10x128x128xf16, 5242880x163840x16384x128x1>
    %2 = migraphx.add %1, %arg0 : <2x32x10x128x128xf16, 5242880x163840x16384x128x1>, <2x32x10x128x128xf16, 0x10x1x0x0> -> <2x32x10x128x128xf16, 5242880x163840x16384x128x1>
    return %0, %2 : !migraphx.shaped<2x320x128x128xf16, 5242880x1x40960x320>, !migraphx.shaped<2x32x10x128x128xf16, 5242880x163840x16384x128x1>
  }
}

Which comes from this MIGraphX module:

y1.0 = @param:y1.0 -> half_type, {320, 4, 3, 3}, {36, 1, 12, 4}
y0.0 = @param:y0.0 -> half_type, {2, 4, 128, 128}, {65536, 1, 512, 4}
x2.0 = @param:x2.0 -> half_type, {2, 32, 10, 128, 128}, {0, 10, 1, 0, 0}
@3 = convolution[padding={1, 1, 1, 1},stride={1, 1},dilation={1, 1},group=1,padding_mode=0](y0.0,y1.0) -> half_type, {2, 320, 128, 128}, {5242880, 1, 40960, 320}
@4 = reshape[dims={2, 32, 10, 128, 128}](@3) -> half_type, {2, 32, 10, 128, 128}, {5242880, 163840, 16384, 128, 1}
@5 = add(@4,x2.0) -> half_type, {2, 32, 10, 128, 128}, {5242880, 163840, 16384, 128, 1}
@6 = @return(@3,@5)
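
The convolution result feeding the reshape is channel-last in memory (strides 5242880x1x40960x320), while the reassociation in the failing op was apparently built for the logical NCHW shape, which would explain the 2x128x128x320 collapsed type in the error. Purely as an illustration (this is not the rocMLIR lowering), an expansion consistent with a 2x128x128x320 tensor would have to split the trailing 320 and then permute; %nhwc and %init below are hypothetical placeholders:

// Split the trailing channel dimension of the channel-last tensor ...
%split = tensor.expand_shape %nhwc [[0], [1], [2], [3, 4]] output_shape [2, 128, 128, 32, 10]
    : tensor<2x128x128x320xf16> into tensor<2x128x128x32x10xf16>
// ... then transpose to the logical 2x32x10x128x128 order.
%init = tensor.empty() : tensor<2x32x10x128x128xf16>
%perm = linalg.transpose ins(%split : tensor<2x128x128x32x10xf16>)
                         outs(%init : tensor<2x32x10x128x128xf16>)
                         permutation = [0, 3, 4, 1, 2]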

The backend output shows:

#map = affine_map<(d0, d1) -> (d0 + d1)>
#map1 = affine_map<(d0, d1) -> (d1, d0)>
#map2 = affine_map<(d0, d1) -> (0, d1)>
#map3 = affine_map<(d0, d1) -> (d0 * 320 + d1)>
#map4 = affine_map<(d0, d1, d2) -> (d0 * 2 + d1, d2)>
#map5 = affine_map<(d0, d1, d2) -> (d0 * 320 + d1, d2)>
#map6 = affine_map<(d0, d1) -> (0, d0, d1)>
#map7 = affine_map<(d0, d1) -> (d0, d1)>
#map8 = affine_map<(d0) -> (d0 floordiv 1280, d0 mod 1280)>
#transform_map = #rock.transform_map<#map by [<Unmerge{1280, 1} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1280, 1] -> [1280]>
#transform_map1 = #rock.transform_map<#map1 by [<PassThrough ["dim1", "dim0"] at [0, 1] -> ["dim1", "dim0"] at [1, 0]>] bounds = [1, 1280] -> [1280, 1]>
#transform_map2 = #rock.transform_map<#map2 by [<Broadcast{1} ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>] bounds = [2, 1280] -> [1, 1280]>
#transform_map3 = #rock.transform_map<#map3 by [<Unmerge{1280, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1280, 320] -> [409600]>
#transform_map4 = #rock.transform_map<#map1 by [<PassThrough ["dim1", "dim0"] at [0, 1] -> ["dim1", "dim0"] at [1, 0]>] bounds = [320, 1280] -> [1280, 320]>
#transform_map5 = #rock.transform_map<#map3 by [<Unmerge{1, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>] bounds = [1, 320] -> [320]>
#transform_map6 = #rock.transform_map<#map2 by [<Broadcast{1} ["dim0"] at [0] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [1]>] bounds = [2, 320] -> [1, 320]>
#transform_map7 = #rock.transform_map<#map4 by [<Unmerge{1, 2} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>] bounds = [1, 2, 320] -> [2, 320]>
#transform_map8 = #rock.transform_map<#map5 by [<Unmerge{1, 320} ["exp0", "exp1"] at [0, 1] -> ["dim0"] at [0]>, <PassThrough ["dim1"] at [2] -> ["dim1"] at [1]>] bounds = [1, 320, 1280] -> [320, 1280]>
#transform_map9 = #rock.transform_map<#map6 by [<ConstDim{0, 1} [] at [] -> ["g"] at [0]>, <PassThrough ["d0", "d1"] at [0, 1] -> ["d0", "d1"] at [1, 2]>] bounds = [2, 320] -> [1, 2, 320]>
#transform_map10 = #rock.transform_map<#map6 by [<Merge{1, 1280} ["gd1"] at [1] -> ["g", "d1"] at [0, 2]>, <PassThrough ["d0"] at [0] -> ["d0"] at [1]>] bounds = [320, 1280] -> [1, 320, 1280]>
#transform_map11 = #rock.transform_map<#map6 by [<Merge{1, 1280} ["gd1"] at [1] -> ["g", "d1"] at [0, 2]>, <PassThrough ["d0"] at [0] -> ["d0"] at [1]>] bounds = [2, 1280] -> [1, 2, 1280]>
#transform_map12 = #rock.transform_map<#map6 by [<Merge{1, 2} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>, <PassThrough ["dim1"] at [1] -> ["dim1"] at [2]>] bounds = [2, 1280] -> [1, 2, 1280]>
#transform_map13 = #rock.transform_map<#map8 by [<Merge{2, 1280} ["dim0"] at [0] -> ["col0", "col1"] at [0, 1]>] bounds = [2560] -> [2, 1280]>
module {
  func.func @mlir_dot_add_sigmoid_mul(%arg0: memref<1280xf16>, %arg1: memref<320xf16>, %arg2: memref<409600xf16>, %arg3: memref<2560xf16>) attributes {arch = "gfx942:sramecc+:xnack-", enable_splitk_for_tuning, kernel = "mixr", num_cu = 304 : i64} {
    %cst = arith.constant 1.000000e+00 : f16
    %0 = rock.transform %arg0 by #transform_map : memref<1280xf16> to memref<1280x1xf16>
    %1 = rock.transform %0 by #transform_map1 : memref<1280x1xf16> to memref<1x1280xf16>
    %2 = rock.transform %1 by #transform_map2 : memref<1x1280xf16> to memref<2x1280xf16>
    %3 = rock.transform %arg2 by #transform_map3 : memref<409600xf16> to memref<1280x320xf16>
    %4 = rock.transform %3 by #transform_map4 : memref<1280x320xf16> to memref<320x1280xf16>
    %5 = rock.transform %arg1 by #transform_map5 : memref<320xf16> to memref<1x320xf16>
    %6 = rock.transform %5 by #transform_map6 : memref<1x320xf16> to memref<2x320xf16>
    %7 = rock.transform %6 by #transform_map7 : memref<2x320xf16> to memref<1x2x320xf16>
    %8 = rock.transform %4 by #transform_map8 : memref<320x1280xf16> to memref<1x320x1280xf16>
    %alloc = memref.alloc() {alignment = 64 : i64} : memref<1x2x1280xf16>
    %9 = rock.transform %7 by #transform_map9 : memref<1x2x320xf16> to memref<2x320xf16>
    %10 = rock.transform %8 by #transform_map10 : memref<1x320x1280xf16> to memref<320x1280xf16>
    %11 = rock.transform %alloc by #transform_map11 : memref<1x2x1280xf16> to memref<2x1280xf16>
    rock.gemm %11 = %9 * %10 features =  mfma|dot|atomic_add storeMethod =  set {arch = "gfx942:sramecc+:xnack-", numCU = 304 : i32} : memref<2x1280xf16> = memref<2x320xf16> * memref<320x1280xf16>
    %12 = rock.transform %alloc by #transform_map12 : memref<1x2x1280xf16> to memref<2x1280xf16>
    %alloc_0 = memref.alloc() {alignment = 64 : i64} : memref<2x1280xf16>
    linalg.generic {indexing_maps = [#map7, #map7, #map7], iterator_types = ["parallel", "parallel"]} ins(%12, %2 : memref<2x1280xf16>, memref<2x1280xf16>) outs(%alloc_0 : memref<2x1280xf16>) {
    ^bb0(%in: f16, %in_1: f16, %out: f16):
      %14 = arith.addf %in, %in_1 : f16
      %15 = arith.negf %14 : f16
      %16 = math.exp %15 : f16
      %17 = arith.addf %16, %cst : f16
      %18 = arith.divf %cst, %17 : f16
      %19 = arith.mulf %14, %18 : f16
      linalg.yield %19 : f16
    }
    %13 = rock.transform %alloc_0 by #transform_map13 : memref<2x1280xf16> to memref<2560xf16>
    memref.copy %13, %arg3 : memref<2560xf16> to memref<2560xf16>
    return
  }
}
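
As a reading aid for the dump above (a summary, not part of the IR): the linalg.generic region fused after the rock.gemm computes, per element,

  z = gemm_out + bias                              (%14)
  y = z * (1 / (1 + exp(-z))) = z * sigmoid(z)     (%15 through %19)

which matches the sigmoid_mul part of the kernel name mlir_dot_add_sigmoid_mul.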

krzysz00 (Collaborator) commented:

@pfultz2 I can't reproduce the issue - is this the right input?
