Skip to content

Commit

Permalink
[mlir] Introduce bare ptr calling convention for MemRefs in LLVM dialect
Browse files Browse the repository at this point in the history
Summary:
This patch introduces an alternative calling convention for
MemRef function arguments in LLVM dialect. It converts MemRef
function arguments to LLVM bare pointers to the MemRef element
type instead of creating a MemRef descriptor. Bare pointers are
then promoted to a MemRef descriptors at the beginning of the
function. This calling convention is only enabled with a flag.

This is a stepping stone towards having an alternative and simpler
lowering for MemRefs when dynamic shapes are not needed. It can
also be used to temporarily overcome the issue with passing 'noalias'
attribute for MemRef arguments, discussed in [1, 2], since we can
now convert:

func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
  return
}

into:

llvm.func @check_noalias(%arg0: !llvm<"float*"> {llvm.noalias = true}) {
  %0 = llvm.mlir.undef ...
  %1 = llvm.insertvalue %arg0, %0[0] ...
  %2 = llvm.insertvalue %arg0, %1[1] ...
  ...
  llvm.return
}

Related discussion:
  [1] tensorflow/mlir#309
  [2] tensorflow/mlir#337

WIP: I plan to move all the tests with only static shapes from
convert-memref-ops.mlir to an independent file so that
we can also have coverage for those tests with this
alternative calling convention.

Reviewers: ftynse, bondhugula, nicolasvasilache

Subscribers: jholewinski, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, csigg, arpith-jacob, mgester, lucyrfox, herhut, aartbik, liufengdb, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D72802
  • Loading branch information
dcaballe committed Jan 15, 2020
1 parent d629525 commit 19bad95
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a function type. The arguments and results are converted one by
/// one and results are packed into a wrapped LLVM IR structure type. `result`
/// is populated with argument mapping.
LLVM::LLVMType convertFunctionSignature(FunctionType type, bool isVariadic,
SignatureConversion &result);
virtual LLVM::LLVMType convertFunctionSignature(FunctionType type,
bool isVariadic,
SignatureConversion &result);

/// Convert a non-empty list of types to be returned from a function into a
/// supported LLVM IR type. In particular, if more than one values is
Expand Down Expand Up @@ -81,6 +82,9 @@ class LLVMTypeConverter : public TypeConverter {
llvm::Module *module;
LLVM::LLVMDialect *llvmDialect;

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);

private:
Type convertStandardType(Type type);

Expand Down Expand Up @@ -120,9 +124,24 @@ class LLVMTypeConverter : public TypeConverter {
// Get the LLVM representation of the index type based on the bitwidth of the
// pointer as defined by the data layout of the module.
LLVM::LLVMType getIndexType();
};

// Extract an LLVM IR dialect type.
LLVM::LLVMType unwrap(Type type);
/// Custom LLVMTypeConverter that overrides `convertFunctionSignature` to
/// replace the type of MemRef function arguments with bare pointer to the
/// MemRef element type.
class BarePtrTypeConverter : public mlir::LLVMTypeConverter {
public:
using LLVMTypeConverter::LLVMTypeConverter;

/// Converts function signature following LLVMTypeConverter approach but
/// replacing the type of MemRef arguments with a bare LLVM pointer to
/// the MemRef element type.
mlir::LLVM::LLVMType convertFunctionSignature(
mlir::FunctionType type, bool isVariadic,
mlir::LLVMTypeConverter::SignatureConversion &result) override;

private:
mlir::Type convertMemRefTypeToBarePtr(mlir::MemRefType type);
};

/// Helper class to produce LLVM dialect operations extracting or inserting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,8 @@ using LLVMTypeConverterMaker =
std::function<std::unique_ptr<LLVMTypeConverter>(MLIRContext *)>;

/// Collect a set of patterns to convert memory-related operations from the
/// Standard dialect to the LLVM dialect, excluding the memory-related
/// operations.
/// Standard dialect to the LLVM dialect, excluding non-memory-related
/// operations and FuncOp.
void populateStdToLLVMMemoryConversionPatters(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

Expand All @@ -54,10 +54,26 @@ void populateStdToLLVMMemoryConversionPatters(
void populateStdToLLVMNonMemoryConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

/// Collect a set of patterns to convert from the Standard dialect to LLVM.
/// Collect the default pattern to convert a FuncOp to the LLVM dialect.
void populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

/// Collect a set of default patterns to convert from the Standard dialect to
/// LLVM.
void populateStdToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);

/// Collect the pattern to convert a FuncOp to the LLVM dialect using the bare
/// pointer calling convertion for MemRef function arguments.
void populateStdToLLVMBarePtrFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

/// Collect a set of patterns to convert from the Standard dialect to
/// LLVM using the bare pointer calling convention for MemRef function
/// arguments.
void populateStdToLLVMBarePtrConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);

/// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
/// By default stdlib malloc/free are used for allocating MemRef payloads.
/// Specifying `useAlloca-true` emits stack allocations instead. In the future
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@ class ConversionPatternRewriter final : public PatternRewriter {
TypeConverter::SignatureConversion &conversion);

/// Replace all the uses of the block argument `from` with value `to`.
void replaceUsesOfBlockArgument(BlockArgument from, Value to);
void replaceUsesOfWith(Value from, Value to);

/// Return the converted value that replaces 'key'. Return 'key' if there is
/// no such a converted value.
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -667,7 +667,7 @@ struct GPUFuncOpLowering : LLVMOpLowering {

BlockArgument arg = block.getArgument(en.index());
Value loaded = rewriter.create<LLVM::LoadOp>(loc, arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
rewriter.replaceUsesOfWith(arg, loaded);
}
}

Expand Down
171 changes: 167 additions & 4 deletions mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ static llvm::cl::opt<bool>
llvm::cl::desc("Replace emission of malloc/free by alloca"),
llvm::cl::init(false));

static llvm::cl::opt<bool> clUseBarePtrCallConv(
PASS_NAME "-use-bare-ptr-memref-call-conv",
llvm::cl::desc("Replace FuncOp's MemRef arguments with "
"bare pointers to the MemRef element types"),
llvm::cl::init(false));

LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
: llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()) {
assert(llvmDialect && "LLVM IR dialect is not registered");
Expand Down Expand Up @@ -239,6 +245,60 @@ Type LLVMTypeConverter::convertStandardType(Type t) {
.Default([](Type) { return Type(); });
}

// Converts function signature following LLVMTypeConverter approach but
// replacing the type of MemRef arguments with a bare LLVM pointer to
// the MemRef element type.
LLVM::LLVMType BarePtrTypeConverter::convertFunctionSignature(
FunctionType type, bool isVariadic,
LLVMTypeConverter::SignatureConversion &result) {
// Convert argument types one by one and check for errors.
for (auto &en : llvm::enumerate(type.getInputs())) {
Type type = en.value();
Type converted;
if (auto memrefTy = type.dyn_cast<MemRefType>())
converted = convertMemRefTypeToBarePtr(memrefTy)
.dyn_cast_or_null<LLVM::LLVMType>();
else
converted = convertType(type).dyn_cast_or_null<LLVM::LLVMType>();

if (!converted)
return {};
result.addInputs(en.index(), converted);
}

SmallVector<LLVM::LLVMType, 8> argTypes;
argTypes.reserve(llvm::size(result.getConvertedTypes()));
for (Type type : result.getConvertedTypes())
argTypes.push_back(unwrap(type));

// If function does not return anything, create the void result type, if it
// returns on element, convert it, otherwise pack the result types into a
// struct.
LLVM::LLVMType resultType =
type.getNumResults() == 0
? LLVM::LLVMType::getVoidTy(llvmDialect)
: unwrap(packFunctionResults(type.getResults()));
if (!resultType)
return {};
return LLVM::LLVMType::getFunctionTy(resultType, argTypes, isVariadic);
}

// Converts MemRefType to a bare LLVM pointer to the MemRef element type.
Type BarePtrTypeConverter::convertMemRefTypeToBarePtr(MemRefType type) {
int64_t offset;
SmallVector<int64_t, 4> strides;
bool strideSuccess = succeeded(getStridesAndOffset(type, strides, offset));
assert(strideSuccess &&
"Non-strided layout maps must have been normalized away");
(void)strideSuccess;

LLVM::LLVMType elementType = unwrap(convertType(type.getElementType()));
if (!elementType)
return {};
auto ptrTy = elementType.getPointerTo(type.getMemorySpace());
return ptrTy;
}

LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context,
LLVMTypeConverter &lowering_,
PatternBenefit benefit)
Expand Down Expand Up @@ -548,7 +608,84 @@ struct FuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
for (unsigned idx : promotedArgIndices) {
BlockArgument arg = firstBlock->getArgument(idx);
Value loaded = rewriter.create<LLVM::LoadOp>(funcOp.getLoc(), arg);
rewriter.replaceUsesOfBlockArgument(arg, loaded);
rewriter.replaceUsesOfWith(arg, loaded);
}
}

rewriter.eraseOp(op);
return matchSuccess();
}
};

// FuncOp conversion that converts MemRef arguments to bare pointers to the type
// of the MemRef.
struct BarePtrFuncOpConversion : public LLVMLegalizationPattern<FuncOp> {
using LLVMLegalizationPattern<FuncOp>::LLVMLegalizationPattern;

PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto funcOp = cast<FuncOp>(op);
FunctionType type = funcOp.getType();
auto funcLoc = funcOp.getLoc();

// Store the positions of memref-typed arguments so that we can promote them
// to MemRef descriptor structs at the beginning of the function.
SmallVector<std::pair<unsigned, Type>, 4> promotedArgIndices;
promotedArgIndices.reserve(type.getNumInputs());
for (auto en : llvm::enumerate(type.getInputs())) {
if (en.value().isa<MemRefType>())
promotedArgIndices.push_back({en.index(), en.value()});
}

// Convert the original function arguments. MemRef types are lowered to bare
// pointers to the MemRef element type.
auto varargsAttr = funcOp.getAttrOfType<BoolAttr>("std.varargs");
TypeConverter::SignatureConversion result(funcOp.getNumArguments());
auto llvmType = lowering.convertFunctionSignature(
funcOp.getType(), varargsAttr && varargsAttr.getValue(), result);

// Only retain those attributes that are not constructed by build.
SmallVector<NamedAttribute, 4> attributes;
for (const auto &attr : funcOp.getAttrs()) {
if (attr.first.is(SymbolTable::getSymbolAttrName()) ||
attr.first.is(impl::getTypeAttrName()) ||
attr.first.is("std.varargs"))
continue;
attributes.push_back(attr);
}

// Create an LLVM function, use external linkage by default until MLIR
// functions have linkage.
auto newFuncOp =
rewriter.create<LLVM::LLVMFuncOp>(funcLoc, funcOp.getName(), llvmType,
LLVM::Linkage::External, attributes);
rewriter.inlineRegionBefore(funcOp.getBody(), newFuncOp.getBody(),
newFuncOp.end());

// Tell the rewriter to convert the region signature.
rewriter.applySignatureConversion(&newFuncOp.getBody(), result);

// Promote bare pointers from MemRef arguments to a MemRef descriptor struct
// at the beginning of the function so that all the MemRefs in the function
// have a uniform representation.
if (!newFuncOp.getBody().empty()) {
Block *firstBlock = &newFuncOp.getBody().front();
rewriter.setInsertionPoint(firstBlock, firstBlock->begin());
for (auto argIdxTypePair : promotedArgIndices) {
// Replace argument with a placeholder (undef), promote argument to a
// MemRef descriptor and replace placeholder with the last instruction
// of the MemRef descriptor. The placeholder is needed to avoid
// replacing argument uses in the MemRef descriptor instructions.
BlockArgument arg = firstBlock->getArgument(argIdxTypePair.first);
Value placeHolder =
rewriter.create<LLVM::UndefOp>(funcLoc, arg.getType());
rewriter.replaceUsesOfWith(arg, placeHolder);
auto desc = MemRefDescriptor::fromStaticShape(
rewriter, funcLoc, lowering,
argIdxTypePair.second.cast<MemRefType>(), arg);
rewriter.replaceUsesOfWith(placeHolder, desc);
placeHolder.getDefiningOp()->erase();
}
}

Expand Down Expand Up @@ -2126,7 +2263,6 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
// clang-format off
patterns.insert<
DimOpLowering,
FuncOpConversion,
LoadOpLowering,
MemRefCastOpLowering,
StoreOpLowering,
Expand All @@ -2139,8 +2275,26 @@ void mlir::populateStdToLLVMMemoryConversionPatters(
// clang-format on
}

void mlir::populateStdToLLVMDefaultFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<FuncOpConversion>(*converter.getDialect(), converter);
}

void mlir::populateStdToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
populateStdToLLVMDefaultFuncOpConversionPattern(converter, patterns);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatters(converter, patterns);
}

void mlir::populateStdToLLVMBarePtrFuncOpConversionPattern(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
patterns.insert<BarePtrFuncOpConversion>(*converter.getDialect(), converter);
}

void mlir::populateStdToLLVMBarePtrConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
populateStdToLLVMBarePtrFuncOpConversionPattern(converter, patterns);
populateStdToLLVMNonMemoryConversionPatterns(converter, patterns);
populateStdToLLVMMemoryConversionPatters(converter, patterns);
}
Expand Down Expand Up @@ -2210,6 +2364,12 @@ makeStandardToLLVMTypeConverter(MLIRContext *context) {
return std::make_unique<LLVMTypeConverter>(context);
}

/// Create an instance of BarePtrTypeConverter in the given context.
static std::unique_ptr<LLVMTypeConverter>
makeStandardToLLVMBarePtrTypeConverter(MLIRContext *context) {
return std::make_unique<BarePtrTypeConverter>(context);
}

namespace {
/// A pass converting MLIR operations into the LLVM IR dialect.
struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
Expand Down Expand Up @@ -2274,6 +2434,9 @@ static PassRegistration<LLVMLoweringPass>
"Standard to the LLVM dialect",
[] {
return std::make_unique<LLVMLoweringPass>(
clUseAlloca.getValue(), populateStdToLLVMConversionPatterns,
makeStandardToLLVMTypeConverter);
clUseAlloca.getValue(),
clUseBarePtrCallConv ? populateStdToLLVMBarePtrConversionPatterns
: populateStdToLLVMConversionPatterns,
clUseBarePtrCallConv ? makeStandardToLLVMBarePtrTypeConverter
: makeStandardToLLVMTypeConverter);
});
3 changes: 1 addition & 2 deletions mlir/lib/Transforms/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -861,8 +861,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
return impl->applySignatureConversion(region, conversion);
}

void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value to) {
void ConversionPatternRewriter::replaceUsesOfWith(Value from, Value to) {
for (auto &u : from.getUses()) {
if (u.getOwner() == to.getDefiningOp())
continue;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: mlir-opt -convert-std-to-llvm -split-input-file -convert-std-to-llvm-use-bare-ptr-memref-call-conv=1 %s | FileCheck %s --check-prefix=BAREPTR

// BAREPTR-LABEL: func @check_noalias
// BAREPTR-SAME: [[ARG:%.*]]: !llvm<"float*"> {llvm.noalias = true}
func @check_noalias(%static : memref<2xf32> {llvm.noalias = true}) {
return
}

// WIP: Move tests with static shapes from convert-memref-ops.mlir here.

0 comments on commit 19bad95

Please sign in to comment.