Add Prefetch support and retain discardable attributes in XeTile Canonicalization. (#853)

* Add Prefetch support and retain discardable attributes in XeTile Canonicalization

* pre-commit issue
charithaintc authored Aug 30, 2024
1 parent 9fd715c commit f217947
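
Aside from the new PrefetchTilePattern, the change repeats one small idiom in each rewritten pattern: snapshot the op's discardable attributes before replacing it, then re-attach them to the replacement op. A minimal sketch of that idiom, written against a hypothetical SomeOp/SomeNewOp pair (not the actual XeTile ops in the diff below), might look like:

#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/SmallVector.h"

// Illustration only: SomeOp/SomeNewOp are placeholders; the real patterns
// below target XeTile and vector ops.
template <typename SomeOp, typename SomeNewOp>
struct RetainDiscardableAttrsPattern final
    : public mlir::OpConversionPattern<SomeOp> {
  using mlir::OpConversionPattern<SomeOp>::OpConversionPattern;

  mlir::LogicalResult
  matchAndRewrite(SomeOp op, typename SomeOp::Adaptor adaptor,
                  mlir::ConversionPatternRewriter &rewriter) const override {
    // Snapshot the discardable attributes before the op is erased.
    llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
        op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
    // Build the replacement op (generic result-types/operands builder).
    auto newOp = rewriter.replaceOpWithNewOp<SomeNewOp>(
        op, op->getResultTypes(), adaptor.getOperands());
    // Re-attach the attributes so later passes still see them.
    newOp->setDiscardableAttrs(discardableAttrs);
    return mlir::success();
  }
};

The diff below applies these same two steps inside the transpose, broadcast, and multi-reduction rewrites, including the intermediate shape-cast and accumulate ops they create.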
Showing 2 changed files with 83 additions and 12 deletions.
60 changes: 49 additions & 11 deletions lib/Dialect/XeTile/Transforms/Canonicalization.cpp
@@ -135,6 +135,20 @@ struct UpdateTileOffsetOpPattern final
}
};

struct PrefetchTilePattern final
: public mlir::OpConversionPattern<imex::xetile::PrefetchTileOp> {
using OpConversionPattern<imex::xetile::PrefetchTileOp>::OpConversionPattern;
mlir::LogicalResult
matchAndRewrite(imex::xetile::PrefetchTileOp prefetchOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
// Create a new prefetch op.
rewriter.replaceOpWithNewOp<imex::xetile::PrefetchTileOp>(
prefetchOp, adaptor.getTile(), prefetchOp.getL1HintAttr(),
prefetchOp.getL2HintAttr(), prefetchOp.getL3HintAttr());
return mlir::success();
}
};

// Pattern for rewriting LoadTileOp to consume row-major tiles.
struct LoadTileOpPattern final
: public mlir::OpConversionPattern<imex::xetile::LoadTileOp> {
@@ -216,9 +230,13 @@ struct VectorTransposeToXetileTransposeOpPattern
mlir::PatternRewriter &rewriter) const override {
if (op.getVector().getType().getRank() != 2)
return mlir::failure();
// Retain discardable attributes if any.
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
// Create an equivalent XeTileTransposeOp
rewriter.replaceOpWithNewOp<imex::xetile::TransposeOp>(
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::TransposeOp>(
op, op.getType(), op.getVector(), op.getPermutation());
newOp->setDiscardableAttrs(discardableAttrs);
return mlir::success();
}
};
@@ -242,6 +260,9 @@ struct VectorBroadcastToXetileBroadcastOpPattern
auto sourceVectorTy = llvm::cast<mlir::VectorType>(op.getSourceType());
auto sourceRank = sourceVectorTy.getRank();
auto sourceShape = sourceVectorTy.getShape();
// Retain the discardable attributes if any.
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
// If the source rank is 1 and result rank is 2, we need to create a shape
// cast to convert source to 2D and then create a xetile.broadcast. In this
// case, broadcast dimension is 0 according to vector.broadcast definition.
@@ -251,14 +272,17 @@ struct VectorBroadcastToXetileBroadcastOpPattern
resultTy.getElementType());
auto source2D = rewriter.create<mlir::vector::ShapeCastOp>(
op.getLoc(), source2DTy, op.getSource());
rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
source2D->setDiscardableAttrs(discardableAttrs);
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
op, resultTy, source2D, llvm::ArrayRef<int64_t>({0}));
newOp->setDiscardableAttrs(discardableAttrs);
return mlir::success();
}
// If ranks are same, inner dimension is stretched in vector.broadcast. So
// broadcast dimension is 1 for this case.
rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
auto newOp = rewriter.replaceOpWithNewOp<imex::xetile::BroadcastOp>(
op, resultTy, op.getSource(), llvm::ArrayRef<int64_t>({1}));
newOp->setDiscardableAttrs(discardableAttrs);
return mlir::success();
}
};
@@ -281,7 +305,9 @@ struct VectorMultiReductionToXeTileReduce
auto reductionDims = op.getReductionDims().getValue();
if (reductionDims.size() != 1)
return mlir::failure();

// Retain discardable attributes if any.
llvm::SmallVector<mlir::NamedAttribute> discardableAttrs(
op->getDiscardableAttrs().begin(), op->getDiscardableAttrs().end());
// Create an equivalent XeTileReduceOp
int64_t reduceDim = llvm::cast<mlir::IntegerAttr>(reductionDims[0])
.getValue()
@@ -294,16 +320,21 @@
auto reduceOp = rewriter.create<imex::xetile::ReductionOp>(
op->getLoc(), xetileResultTy, op.getKind(), op.getSource(),
mlir::ArrayRef<int64_t>({reduceDim}));
reduceOp->setDiscardableAttrs(discardableAttrs);
// Shape cast the result back to original shape.
auto shapeCastOp = rewriter.create<mlir::vector::ShapeCastOp>(
op->getLoc(), resultTy, reduceOp.getResult());
shapeCastOp->setDiscardableAttrs(discardableAttrs);
// Finally add the result to the accumulator.
if (llvm::isa<mlir::IntegerType>(sourceTy.getElementType()))
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, shapeCastOp,
op.getAcc());
else
rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(op, shapeCastOp,
op.getAcc());
if (llvm::isa<mlir::IntegerType>(sourceTy.getElementType())) {
auto accOp = rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(
op, shapeCastOp, op.getAcc());
accOp->setDiscardableAttrs(discardableAttrs);
} else {
auto accOp = rewriter.replaceOpWithNewOp<mlir::arith::AddFOp>(
op, shapeCastOp, op.getAcc());
accOp->setDiscardableAttrs(discardableAttrs);
}
return mlir::success();
}
};
@@ -406,6 +437,12 @@ struct XeTileCanonicalizationPass final
op) {
return op.getType().getOrder().asArrayRef() != mlir::ArrayRef({0, 1});
});
// PrefetchTileOp is legal if it does not consume col-major tiles.
target.addDynamicallyLegalOp<imex::xetile::PrefetchTileOp>(
[&](imex::xetile::PrefetchTileOp op) {
return op.getTile().getType().getOrder().asArrayRef() !=
mlir::ArrayRef({0, 1});
});
// LoadTileOp is legal if it does not consume col-major tiles.
target.addDynamicallyLegalOp<imex::xetile::LoadTileOp>(
[&](imex::xetile::LoadTileOp op) {
@@ -437,7 +474,8 @@ struct XeTileCanonicalizationPass final
});
patterns
.add<InitTileOpPattern, LoadTileOpPattern, UpdateTileOffsetOpPattern,
ScfForOpPattern, ScfYieldOpPattern>(typeConverter, context);
PrefetchTilePattern, ScfForOpPattern, ScfYieldOpPattern>(
typeConverter, context);

if (mlir::failed(mlir::applyPartialConversion(getOperation(), target,
std::move(patterns))))
35 changes: 34 additions & 1 deletion test/Dialect/XeTile/Transforms/canonicalization.mlir
@@ -3,13 +3,15 @@ gpu.module @test_module {
gpu.func @test_static_memref(%arg0 : memref<512x128xf16, strided<[1, 512], offset:0>>, %arg1 : index, %arg2 : index) {
%0 = xetile.init_tile %arg0 [%arg1, %arg2] : memref<512x128xf16, strided<[1, 512], offset:0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
%3 = xetile.load_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
xetile.prefetch_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
// With static offsets
%1 = xetile.init_tile %arg0 [12, %arg1] : memref<512x128xf16, strided<[1, 512], offset:0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
// Update offsets
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%2 = xetile.update_tile_offset %1, [%c32, %c16] : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>, index, index -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
%4 = xetile.load_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
xetile.prefetch_tile %1 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
gpu.return
}
}
@@ -23,22 +25,25 @@ gpu.module @test_module {
// CHECK: %[[T0:.*]] = xetile.init_tile %[[RCAST]][%[[ARG2]], %[[ARG1]]] : memref<128x512xf16, strided<[512, 1]>> -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
// CHECK: %[[T1:.*]] = xetile.load_tile %[[T0]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
// CHECK: %[[T2:.*]] = xetile.transpose %[[T1]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
// CHECK: xetile.prefetch_tile %[[T0]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>
// CHECK: %[[RCAST0:.*]] = memref.reinterpret_cast %[[ARG0]] to offset: [0], sizes: [128, 512], strides: [512, 1] : memref<512x128xf16, strided<[1, 512]>> to memref<128x512xf16, strided<[512, 1]>>
// CHECK: %[[T3:.*]] = xetile.init_tile %[[RCAST0]][%[[ARG1]], 12] : memref<128x512xf16, strided<[512, 1]>> -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
// CHECK: %[[T4:.*]] = xetile.update_tile_offset %[[T3]], [%[[C16]], %[[C32]]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>, index, index -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
// CHECK: %[[T5:.*]] = xetile.load_tile %[[T4]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
// CHECK: %[[T6:.*]] = xetile.transpose %[[T5]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
// CHECK: xetile.prefetch_tile %[[T3]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>

// -----
gpu.module @test_module {
gpu.func @test_dynamic_memref(%arg0 : memref<?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index, %arg4 : index, %arg5 : index, %arg6 : index) {
%0 = xetile.init_tile %arg0 [%arg1, %arg2], [%arg3, %arg4], [%arg5, %arg6] : memref<?x?xf16> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
%1 = xetile.load_tile %0 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
// Update offsets
%c16 = arith.constant 16 : index
%c32 = arith.constant 32 : index
%2 = xetile.update_tile_offset %0, [%c32, %c16] : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>, index, index -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
%3 = xetile.load_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>> -> vector<16x32xf16>
xetile.prefetch_tile %2 : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
gpu.return
}
}
@@ -59,6 +64,7 @@ gpu.module @test_module {
// CHECK: %[[T3:.*]] = xetile.update_tile_offset %[[T0]], [%[[C16]], %[[C32]]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>, index, index -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
// CHECK: %[[T4:.*]] = xetile.load_tile %[[T3]] { padding = 0.000000e+00 : f32 } : !xetile.tile<32x16xf16, #xetile.tile_attr<>> -> vector<32x16xf16>
// CHECK: %[[T5:.*]] = xetile.transpose %[[T4]], [1, 0] : vector<32x16xf16> -> vector<16x32xf16>
// CHECK: xetile.prefetch_tile %[[T3]] : !xetile.tile<32x16xf16, #xetile.tile_attr<>>

// -----
gpu.module @test_module {
@@ -272,10 +278,37 @@ gpu.module @test_module {
}
}

// CHECK-LABEL: @test_multireduction_1
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<64x256xf32>, %[[ARG1:[a-zA-Z0-9]+]]: vector<256xf32>) -> vector<256xf32>
// CHECK: %[[T0:.*]] = xetile.reduction <add>, %[[ARG0]] [0] : vector<64x256xf32> -> vector<1x256xf32>
// CHECK: %[[T1:.*]] = vector.shape_cast %[[T0]] : vector<1x256xf32> to vector<256xf32>
// CHECK: %[[T2:.*]] = arith.addf %[[T1]], %[[ARG1]] : vector<256xf32>
// CHECK: gpu.return %[[T2]] : vector<256xf32>

// -----
gpu.module @test_module {
gpu.func @test_multireduction_2(%arg0 : vector<64x256xi8>, %arg1 : vector<256xi8>) -> vector<256xi8> {
%0 = vector.multi_reduction <add>, %arg0, %arg1 [0] : vector<64x256xi8> to vector<256xi8>
gpu.return %0 : vector<256xi8>
}
}

// CHECK-LABEL: @test_multireduction_2
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<64x256xi8>, %[[ARG1:[a-zA-Z0-9]+]]: vector<256xi8>) -> vector<256xi8>
// CHECK: %[[T0:.*]] = xetile.reduction <add>, %[[ARG0]] [0] : vector<64x256xi8> -> vector<1x256xi8>
// CHECK: %[[T1:.*]] = vector.shape_cast %[[T0]] : vector<1x256xi8> to vector<256xi8>
// CHECK: %[[T2:.*]] = arith.addi %[[T1]], %[[ARG1]] : vector<256xi8>
// CHECK: gpu.return %[[T2]] : vector<256xi8>

// -----
gpu.module @test_module {
gpu.func @test_transpose_1(%arg0 : vector<16x32xf32>) -> vector<32x16xf32> {
%0 = vector.transpose %arg0, [1, 0] : vector<16x32xf32> to vector<32x16xf32>
gpu.return %0 : vector<32x16xf32>
}
}

// CHECK-LABEL: @test_transpose_1
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<16x32xf32>) -> vector<32x16xf32>
// CHECK: %[[T0:.*]] = xetile.transpose %arg0, [1, 0] : vector<16x32xf32> -> vector<32x16xf32>
// CHECK: gpu.return %[[T0]] : vector<32x16xf32>
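
For reference, the effect of the new PrefetchTilePattern can be condensed from the @test_static_memref test case and its CHECK lines above into a before/after sketch in the same IR; value names here are illustrative, not taken from the test:

// Before canonicalization: the prefetch consumes a col-major (order=[0,1]) 16x32 tile.
%t = xetile.init_tile %arg0 [%i, %j] : memref<512x128xf16, strided<[1, 512], offset:0>> -> !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>
xetile.prefetch_tile %t : !xetile.tile<16x32xf16, #xetile.tile_attr<order=[0,1]>>

// After canonicalization: the memref is reinterpreted as row-major, the tile becomes a
// 32x16 row-major tile with swapped offsets, and the prefetch consumes the new tile.
%cast = memref.reinterpret_cast %arg0 to offset: [0], sizes: [128, 512], strides: [512, 1] : memref<512x128xf16, strided<[1, 512]>> to memref<128x512xf16, strided<[512, 1]>>
%t2 = xetile.init_tile %cast[%j, %i] : memref<128x512xf16, strided<[512, 1]>> -> !xetile.tile<32x16xf16, #xetile.tile_attr<>>
xetile.prefetch_tile %t2 : !xetile.tile<32x16xf16, #xetile.tile_attr<>>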
