VectorToXeGPU: Allows lowering vector.transfer_read and vector.transfer_write to XeGPU #773

Open · wants to merge 4 commits into base: main
1 change: 1 addition & 0 deletions include/imex/Conversion/CMakeLists.txt
@@ -9,3 +9,4 @@ add_subdirectory(DistToStandard)
add_subdirectory(DropRegions)
add_subdirectory(XeTileToXeGPU)
add_subdirectory(XeGPUToVC)
add_subdirectory(VectorToXeGPU)
1 change: 1 addition & 0 deletions include/imex/Conversion/Passes.h
@@ -26,6 +26,7 @@
#include <imex/Conversion/XeGPUToSPIRV/XeGPUToSPIRV.h>
#include <imex/Conversion/XeGPUToVC/XeGPUToVC.h>
#include <imex/Conversion/XeTileToXeGPU/XeTileToXeGPU.h>
#include <imex/Conversion/VectorToXeGPU/VectorToXeGPU.h>

namespace imex {

30 changes: 30 additions & 0 deletions include/imex/Conversion/Passes.td
@@ -441,4 +441,34 @@ def ConvertXeGPUToVC : Pass<"convert-xegpu-to-vc", "::mlir::gpu::GPUModuleOp"> {
let constructor = "imex::createConvertXeGPUToVCPass()";
}

//===----------------------------------------------------------------------===//
// VectorToXeGPU
//===----------------------------------------------------------------------===//

def ConvertVectorToXeGPU: Pass<"convert-vector-to-xegpu", "::mlir::ModuleOp"> {
let summary = "Convert from the Vector dialect to the XeGPU dialect.";
let description = [{
Converts Vector dialect operations into XeGPU dialect operations. It lowers `vector.transfer_read` and `vector.transfer_write` to `xegpu.load_nd` and `xegpu.store_nd`, creating the required tensor descriptors along the way.

#### Input IR

%3 = vector.transfer_read %arg1[%0, %2], %arg2 : memref<512x640xf32>, vector<2x32xf32>
%4 = arith.cmpf ugt, %3, %arg3 : vector<2x32xf32>
%5 = arith.select %4, %3, %arg3 : vector<2x32xi1>, vector<2x32xf32>
vector.transfer_write %5, %arg4[%0, %2] : vector<2x32xf32>, memref<512x640xf32>

#### Output IR

%desc = xegpu.create_nd_tdesc %arg1[%0, %2] {mode = vc} : memref<512x640xf32> -> !xegpu.tensor_desc<2x32xf32>
%3 = xegpu.load_nd %desc {mode = vc}: !xegpu.tensor_desc<2x32xf32> -> vector<2x32xf32>
%4 = arith.cmpf ugt, %3, %arg3 : vector<2x32xf32>
%5 = arith.select %4, %3, %arg3 : vector<2x32xi1>, vector<2x32xf32>
%desc2 = xegpu.create_nd_tdesc %arg4[%0, %2] {mode = vc} : memref<512x640xf32> -> !xegpu.tensor_desc<2x32xf32>
xegpu.store_nd %5, %desc2 {mode = vc} : vector<2x32xf32>, !xegpu.tensor_desc<2x32xf32>
}];
let constructor = "::imex::createConvertVectorToXeGPUPass()";
let dependentDialects = ["::mlir::xegpu::XeGPUDialect"];
let options = [];
}

#endif // _IMEX_CONVERSION_PASSES_TD_INCLUDED_
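For context, a minimal sketch (not part of this patch) of how the new pass could be driven from C++; the wrapper name `lowerVectorToXeGPU` is hypothetical, and the pass should equally be runnable through the project's opt tool with `--convert-vector-to-xegpu`:

#include <imex/Conversion/VectorToXeGPU/VectorToXeGPU.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/Pass/PassManager.h>
#include <mlir/Support/LogicalResult.h>

// Run only the Vector-to-XeGPU conversion on a module that already contains
// vector.transfer_read/vector.transfer_write ops like the ones shown above.
static mlir::LogicalResult lowerVectorToXeGPU(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext(), mlir::ModuleOp::getOperationName());
  pm.addPass(imex::createConvertVectorToXeGPUPass()); // anchored on ModuleOp
  return pm.run(module);
}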
Empty file.
37 changes: 37 additions & 0 deletions include/imex/Conversion/VectorToXeGPU/VectorToXeGPU.h
@@ -0,0 +1,37 @@
//===- VectorToXeGPU.h - VectorToXeGPU conversion -------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file defines the VectorToXeGPU conversion, converting the Vector
/// dialect to the XeGPU dialect.
///
//===----------------------------------------------------------------------===//

#ifndef _VectorToXeGPU_H_INCLUDED_
#define _VectorToXeGPU_H_INCLUDED_

#include <mlir/IR/PatternMatch.h>
#include <mlir/Transforms/DialectConversion.h>

namespace mlir {
class LLVMTypeConverter;
class MLIRContext;
class ModuleOp;
template <typename T>
class OperationPass;
class RewritePatternSet;
} // namespace mlir

namespace imex {
/// Create a pass to convert the Vector dialect to the XeGPU dialect.
std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>> createConvertVectorToXeGPUPass();

} // namespace imex

#endif // _VectorToXeGPU_H_INCLUDED_
1 change: 1 addition & 0 deletions lib/Conversion/CMakeLists.txt
@@ -7,3 +7,4 @@ add_subdirectory(GPUXToLLVM)
add_subdirectory(XeGPUToSPIRV)
add_subdirectory(XeTileToXeGPU)
add_subdirectory(XeGPUToVC)
add_subdirectory(VectorToXeGPU)
12 changes: 12 additions & 0 deletions lib/Conversion/VectorToXeGPU/CMakeLists.txt
@@ -0,0 +1,12 @@
add_imex_conversion_library(IMEXVectorToXeGPU
VectorToXeGPU.cpp

ADDITIONAL_HEADER_DIRS
${MLIR_MAIN_INCLUDE_DIR}/imex/Conversion/VectorToXeGPU

DEPENDS
IMEXConversionPassIncGen

LINK_LIBS PUBLIC
MLIRXeGPUDialect
)
224 changes: 224 additions & 0 deletions lib/Conversion/VectorToXeGPU/VectorToXeGPU.cpp
@@ -0,0 +1,224 @@
//===- VectorToXeGPU.cpp - VectorToXeGPU conversion -------*- C++ -*-===//
//
// Copyright 2022 Intel Corporation
// Part of the IMEX Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
///
/// \file
/// This file implements the VectorToXeGPU conversion, converting the Vector
/// dialect to the XeGPU dialect.
///
//===----------------------------------------------------------------------===//

#include <imex/Conversion/VectorToXeGPU/VectorToXeGPU.h>
#include <imex/Utils/PassWrapper.h>
#include <mlir/Dialect/Vector/IR/VectorOps.h>
#include <mlir/Dialect/XeGPU/IR/XeGPU.h>

#include <mlir/IR/BuiltinOps.h>

#include "../PassDetail.h"
#include "imex/Conversion/XeTileToXeGPU/XeTileToXeGPUConversion.h"
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

namespace imex {

namespace {

class MyPatternRewriter : public PatternRewriter {
public:
MyPatternRewriter(MLIRContext *ctx) : PatternRewriter(ctx) {}

/// Override the necessary PatternRewriter hooks here.
};

struct MyTarget : public ConversionTarget {
MyTarget(MLIRContext &ctx) : ConversionTarget(ctx) {

/// Mark `vector.transfer_read` as illegal.
addIllegalOp<vector::TransferReadOp>(); //, vector::TransferWriteOp
}
};

// *******************************
// ***** Individual patterns *****
// *******************************

// Goal: vector.transfer_read -> xegpu.create_nd_tdesc + xegpu.load_nd
// E.g. translate
//   %3 = vector.transfer_read %arg1[%0, %2], %arg2
//       : memref<512x640xf32>, vector<32xf32>
// into
//   %desc = xegpu.create_nd_tdesc %arg1[%0, %2] {mode = vc}
//       : memref<512x640xf32> -> !xegpu.tensor_desc<1x32xf32>
//   %4 = xegpu.load_nd %desc {mode = vc} : !xegpu.tensor_desc<1x32xf32> -> vector<1x32xf32>
//   %5 = vector.shape_cast %4 : vector<1x32xf32> to vector<32xf32>

struct TransferReadOpConverter
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferReadOp read,
PatternRewriter &rewriter) const override {
auto ctx = read->getContext();
auto resultTile = read.getResult();
auto resTileType = resultTile.getType();
auto resTileShape = resTileType.getShape();
auto rank = resTileType.getRank();
auto source = read.getSource();

// Use owning storage for the shape: assigning a brace list to an ArrayRef
// would leave it pointing at a dead temporary.
mlir::SmallVector<int64_t> loadShape;
if (rank == 1)
loadShape = {1, resTileShape[0]};
else
loadShape.assign(resTileShape.begin(), resTileShape.end());
auto loadType = VectorType::get(loadShape, resTileType.getElementType());
auto tDescTy =
xegpu::TensorDescType::get(loadShape, resTileType.getElementType());
mlir::SmallVector<mlir::OpFoldResult> tDescOffsets{read->getOperand(1),
read->getOperand(2)};
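// Note: the two index operands are forwarded verbatim as 2-D descriptor
// offsets, so a 2-D source memref is assumed here.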
rewriter.setInsertionPoint(read);
mlir::Value desc;
// Use dyn_cast so non-memref sources fall through to the failure path below.
if (auto MemRefTypedSource =
mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
desc = rewriter.create<mlir::xegpu::CreateNdDescOp>(
[Review comment from a Contributor] Shouldn't there be a check for memref rank here? XeGPU supports limited ranks.

read.getLoc(), tDescTy, MemRefTypedSource, tDescOffsets);
} else {
return mlir::failure();
}

mlir::IntegerAttr vnniAxisAttr;
mlir::DenseI64ArrayAttr transposeAttr;
mlir::IntegerAttr transposeBitWidthAttr;
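// No VNNI axis, transpose, or bit-width change is requested; all three
// cache levels default to CACHED for the load.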
auto CACHED = mlir::xegpu::CachePolicy::CACHED;
auto L1 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
auto L2 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
auto L3 = mlir::xegpu::CachePolicyAttr::get(ctx, CACHED);
Operation *payload = rewriter.create<xegpu::LoadNdOp>(
read.getLoc(), loadType, desc, vnniAxisAttr, transposeAttr,
transposeBitWidthAttr, L1, L2, L3);

if (rank == 1) {
// xegpu currently don't support 1d vector load. We need to cast it to 2d
auto cast = rewriter.create<vector::ShapeCastOp>(
read.getLoc(), resTileType, payload->getResults());
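// A permutation map with a single constant result encodes a broadcast:
// replicate that constant lane across the vector with a shuffle.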
if (auto map = read.getPermutationMap(); map.isSingleConstant()) {
SmallVector<int64_t> mask(resTileShape[0],
map.getSingleConstantResult());
payload =
rewriter.create<vector::ShuffleOp>(read.getLoc(), cast, cast, mask);
} else {
AffineExpr d0, d1;
bindDims(read.getContext(), d0, d1);
auto mp = AffineMap::get(map.getNumDims(), 0, {d1}, read.getContext());
// (d0, d1) -> (d1)
if (map != mp) {
// Unsupported permutation map
return ::mlir::failure();
}
payload = cast;
}
}
rewriter.replaceOp(read, payload->getResults());

return ::mlir::success();
}
};

// Goal: vector.transfer_write -> xegpu.create_nd_tdesc + xegpu.store_nd
// E.g. translate
//   vector.transfer_write %5, %arg4[%0, %2]
//       : vector<32xf32>, memref<512x640xf32>
// into
//   %6 = vector.shape_cast %5 : vector<32xf32> to vector<1x32xf32>
//   %desc2 = xegpu.create_nd_tdesc %arg4[%0, %2] {mode = vc}
//       : memref<512x640xf32> -> !xegpu.tensor_desc<1x32xf32>
//   xegpu.store_nd %6, %desc2 {mode = vc} : vector<1x32xf32>, !xegpu.tensor_desc<1x32xf32>

struct TransferWriteOpConverter
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(vector::TransferWriteOp write,
PatternRewriter &rewriter) const override {
auto ctx = write->getContext();
auto resultTile = write->getOperand(0); //%5
auto source = write.getSource(); // memref<512x640xi32>
auto resTileType = dyn_cast<VectorType>(resultTile.getType());
auto resTileShape = resTileType.getShape();
auto rank = resTileType.getRank();
auto intermediateType =
VectorType::get({1, resTileShape[0]}, resTileType.getElementType());

// Same as the read pattern: use owning storage so the ArrayRef-style
// assignment of a brace list cannot dangle.
mlir::SmallVector<int64_t> loadShape;
if (rank == 1)
loadShape = {1, resTileShape[0]};
else
loadShape.assign(resTileShape.begin(), resTileShape.end());
auto tDescTy =
xegpu::TensorDescType::get(loadShape, resTileType.getElementType());
mlir::SmallVector<mlir::OpFoldResult> tDescOffsets{write->getOperand(2),
write->getOperand(3)};
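// As in the read pattern, the two index operands become the 2-D descriptor
// offsets, assuming a 2-D source memref.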
rewriter.setInsertionPoint(write);
mlir::Value payload = write.getOperand(0);
if (rank == 1) {
payload = rewriter.create<vector::ShapeCastOp>(
write.getLoc(), intermediateType, write->getOperand(0));
}
mlir::Value desc;
// Use dyn_cast so non-memref sources fall through to the failure path below.
if (auto MemRefTypedSource =
mlir::dyn_cast<mlir::TypedValue<mlir::MemRefType>>(source)) {
desc = rewriter.create<mlir::xegpu::CreateNdDescOp>(
[Review comment from a Contributor] Same as my comment above: Check rank and return failure if unsupported shape.

write.getLoc(), tDescTy /*resultTy*/, MemRefTypedSource /*source*/,
tDescOffsets /*offsets*/);
} else {
return mlir::failure();
}

auto WRITE_BACK = mlir::xegpu::CachePolicy::WRITE_BACK;
auto L1 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
auto L2 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
auto L3 = mlir::xegpu::CachePolicyAttr::get(ctx, WRITE_BACK);
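// Store the (possibly shape-cast) payload with write-back caching at all
// three levels.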
rewriter.create<xegpu::StoreNdOp>(write.getLoc(), payload, desc, L1, L2,
L3);
rewriter.eraseOp(write);

return ::mlir::success();
}
};

// *******************************
// ***** Pass infrastructure *****
// *******************************

// Full Pass
struct ConvertVectorToXeGPUPass // convert Vector to XeGPU
: public ::imex::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
ConvertVectorToXeGPUPass() = default;

void runOnOperation() override {
auto *ctx = &getContext();
mlir::RewritePatternSet patterns(ctx);

patterns.insert<TransferReadOpConverter, TransferWriteOpConverter>(ctx);

(void)mlir::applyPatternsAndFoldGreedily(getOperation(),
std::move(patterns));
}
};

} // namespace

/// Create a pass that converts Vector to XeGPU
std::unique_ptr<::mlir::OperationPass<::mlir::ModuleOp>>
createConvertVectorToXeGPUPass() {
return std::make_unique<ConvertVectorToXeGPUPass>();
}

} // namespace imex
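Both review comments above ask for a rank check before the descriptor is built. A minimal sketch of such a guard, under the assumption that only 1-D and 2-D source memrefs should be accepted (the exact ranks XeGPU block descriptors allow are not spelled out in this patch), with a hypothetical helper name:

#include <mlir/IR/BuiltinTypes.h>
#include <mlir/IR/Value.h>
#include <mlir/Support/LogicalResult.h>

// Return success only for source memrefs whose rank the nd-descriptor
// creation in the patterns above is prepared to handle.
static mlir::LogicalResult checkSourceRank(mlir::Value source) {
  auto srcType = mlir::dyn_cast<mlir::MemRefType>(source.getType());
  if (!srcType)
    return mlir::failure();
  int64_t rank = srcType.getRank();
  return (rank == 1 || rank == 2) ? mlir::success() : mlir::failure();
}

Each pattern could call this right after reading the source operand and bail out early instead of hitting descriptor creation with an unsupported shape.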
2 changes: 2 additions & 0 deletions lib/Utils/XeCommon.cpp
@@ -112,6 +112,8 @@ encodeVectorType(mlir::ConversionPatternRewriter &rewriter,
} else if (elemType == rewriter.getBF16Type()) {
str += "i32";
elemType = rewriter.getI32Type();
} else if (elemType == rewriter.getI32Type()) {
str += "i32";
} else
assert(0 && "add more support");
auto newType = mlir::VectorType::get(size, elemType);