From 6d936965f5b5ed2e14e89fc80e60e737c3170f5e Mon Sep 17 00:00:00 2001 From: Justin Lebar Date: Fri, 14 Jun 2024 11:43:11 -0700 Subject: [PATCH] s/ConversionPatternRewriter/RewriterBase/ in utils (NFC) (#4140) You can't create a ConversionPatternRewriter in C++ unit tests, so this means that anything which is touched from a C++ unit test cannot transitively touch anything which uses a ConversionPatternRewriter. This is a simple search+replace, there's no functional difference between these types for our purposes. --- .../TritonGPUToLLVM/TargetInfoBase.h | 36 +++++------ .../Conversion/TritonGPUToLLVM/Utility.h | 56 +++++++++--------- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 37 ++++++------ .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp | 49 ++++++++------- .../amd/lib/TritonAMDGPUToLLVM/TargetInfo.h | 49 +++++++-------- .../amd/lib/TritonAMDGPUToLLVM/Utility.cpp | 29 ++++----- .../amd/lib/TritonAMDGPUToLLVM/Utility.h | 26 ++++---- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp | 59 +++++++++---------- .../lib/TritonNVIDIAGPUToLLVM/TargetInfo.h | 47 +++++++-------- .../lib/TritonNVIDIAGPUToLLVM/Utility.cpp | 25 ++++---- .../lib/TritonNVIDIAGPUToLLVM/Utility.h | 22 +++---- 11 files changed, 205 insertions(+), 230 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h index f977d30c0259..61f8fb87b4a6 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h +++ b/include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h @@ -10,34 +10,34 @@ class TargetInfoBase { virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0; - virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const = 0; + virtual Value ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const = 0; - virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const = 0; - virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + virtual void storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const = 0; + virtual Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const = 0; - virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const = 0; - virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const = 0; + virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const = 0; + virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const = 0; - virtual Value programId(ConversionPatternRewriter &rewriter, Location loc, + virtual Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const = 0; - virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, + virtual bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const = 0; virtual bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -48,11 +48,11 @@ class TargetInfoBase { // format from the device. |formatStrStart| is the pointer to the start of // the format string global variable; |args| are the arguments to fill // placeholders in the format string. - virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + virtual void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const = 0; // Emits LLVM code with |rewriter| to perform assertion failure with the given // |message| from the given |func| in |file|. - virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc, + virtual void assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const = 0; diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 6ad63e37967f..54f79ab0a257 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -202,9 +202,9 @@ T getLinearIndex(llvm::ArrayRef multiDimIndex, llvm::ArrayRef shape, namespace gpu { Type getFunctionType(Type resultType, ValueRange operands); -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, StringRef libname = "", +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, + StringRef libname = "", StringRef libpath = ""); } // namespace gpu @@ -305,7 +305,7 @@ struct SharedMemoryObject { } Value getBaseBeforeSlice(int order, Location loc, - ConversionPatternRewriter &rewriter) const { + RewriterBase &rewriter) const { Value cSwizzleOffset = getCSwizzleOffset(order); Value offset = sub(i32_val(0), cSwizzleOffset); Type type = base.getType(); @@ -313,9 +313,10 @@ struct SharedMemoryObject { } }; -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter); +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter); // Convert an \param index to a multi-dim coordinate given \param shape and // \param order. @@ -329,15 +330,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, SmallVector delinearize(RewriterBase &rewriter, Location loc, Value linear, ArrayRef shape); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order); -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape); +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape); -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content); +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content); // Given an elemId which represents the index of an element from the list of // elements that are in the thread's registers (i.e. total of @@ -346,7 +346,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, // when converting distributed to distributed layout. Also, a replica is the // smallest CTA tile that is common between input and output layouts. SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -355,15 +355,15 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, // Given a multiDimOffset, this function wraps around each dimension to be // within shape. SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA); + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA); inline bool isKernel(FunctionOpInterface funcOp) { return funcOp.getVisibility() == SymbolTable::Visibility::Public; } -inline Value getStackPointer(PatternRewriter &rewriter, +inline Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) { auto mod = funcOp->getParentOfType(); LLVM::GlobalOp globalBase = nullptr; @@ -378,8 +378,7 @@ inline Value getStackPointer(PatternRewriter &rewriter, return funcOp.getArgument(funcOp.getNumArguments() - 1); } -inline Value getSharedMemoryBase(Location loc, - ConversionPatternRewriter &rewriter, +inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter, Operation *op) { auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3); FunctionOpInterface func = @@ -1566,9 +1565,9 @@ inline void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy, } } -inline Value -getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, - ConversionPatternRewriter &rewriter) { +inline Value getStructFromSharedMemoryObject(Location loc, + const SharedMemoryObject &smemObj, + RewriterBase &rewriter) { auto elems = smemObj.getElems(); auto types = smemObj.getTypes(); auto structTy = @@ -1582,9 +1581,8 @@ getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj, return llvmStruct; } -inline SmallVector -unpackLLElements(Location loc, Value llvmStruct, - ConversionPatternRewriter &rewriter) { +inline SmallVector unpackLLElements(Location loc, Value llvmStruct, + RewriterBase &rewriter) { assert(bool(llvmStruct) && "can not unpack null values"); if (llvmStruct.getType().isIntOrIndexOrFloat() || isa(llvmStruct.getType()) || @@ -1602,8 +1600,8 @@ unpackLLElements(Location loc, Value llvmStruct, inline Value packLLElements(Location loc, const LLVMTypeConverter *typeConverter, - ValueRange resultVals, - ConversionPatternRewriter &rewriter, Type type) { + ValueRange resultVals, RewriterBase &rewriter, + Type type) { auto structType = dyn_cast(typeConverter->convertType(type)); if (!structType) { diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 7811a2ef56d0..eaaa690c0b19 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -12,7 +12,7 @@ using CoordTy = SmallVector; using ValueTable = std::map, std::pair>; static SmallVector -getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter, +getMNCoords(Value thread, Location loc, RewriterBase &rewriter, ArrayRef wpt, const NvidiaMmaEncodingAttr &mmaLayout, ArrayRef shape, bool isARow, bool isBRow, bool isAVec4, bool isBVec4) { @@ -120,9 +120,8 @@ Type getFunctionType(Type resultType, ValueRange operands) { return LLVM::LLVMFunctionType::get(resultType, operandTypes); } -LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter, - Operation *op, StringRef funcName, - Type funcType, +LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op, + StringRef funcName, Type funcType, StringRef libname /*= ""*/, StringRef libpath /*= ""*/) { using LLVM::LLVMFuncOp; @@ -496,9 +495,10 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width, builder.getIntegerAttr(ty, value)); } -SharedMemoryObject -getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy, - ConversionPatternRewriter &rewriter) { +SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc, + Value llvmStruct, + Type elemTy, + RewriterBase &rewriter) { ArrayRef types = cast(llvmStruct.getType()).getBody(); SmallVector elems(types.size()); @@ -580,15 +580,14 @@ SmallVector delinearize(RewriterBase &rewriter, Location loc, return multiDim; } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape, - ArrayRef order) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape, ArrayRef order) { return linearize(rewriter, loc, applyPermutation(multiDim, order), applyPermutation(shape, order)); } -Value linearize(ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDim, ArrayRef shape) { +Value linearize(RewriterBase &rewriter, Location loc, ArrayRef multiDim, + ArrayRef shape) { auto rank = multiDim.size(); Value linear = i32_val(0); if (rank > 0) { @@ -602,8 +601,8 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc, return linear; } -Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, - StringRef key, StringRef content) { +Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key, + StringRef content) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto ctx = moduleOp.getContext(); unsigned stringNumber = 0; @@ -619,7 +618,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, LLVM::GlobalOp global; { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); global = rewriter.create( UnknownLoc::get(ctx), globalType, @@ -637,7 +636,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter, } SmallVector getMultiDimOffset(Attribute layout, Location loc, - ConversionPatternRewriter &rewriter, + RewriterBase &rewriter, const TargetInfoBase &targetInfo, unsigned elemId, RankedTensorType type, ArrayRef multiDimCTAInRepId, @@ -791,9 +790,9 @@ SmallVector getMultiDimOffset(Attribute layout, Location loc, } SmallVector getWrappedMultiDimOffset( - ConversionPatternRewriter &rewriter, Location loc, - ArrayRef multiDimOffset, ArrayRef shape, - SmallVector shapePerCTATile, SmallVector shapePerCTA) { + RewriterBase &rewriter, Location loc, ArrayRef multiDimOffset, + ArrayRef shape, SmallVector shapePerCTATile, + SmallVector shapePerCTA) { unsigned rank = shape.size(); SmallVector multiDimOffsetWrapped(rank); for (unsigned d = 0; d < rank; ++d) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp index 55e3609da137..527b89d30549 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.cpp @@ -10,12 +10,11 @@ namespace mlir::triton::AMD { namespace { template LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, - ConversionPatternRewriter &rewriter, - StringRef name, + RewriterBase &rewriter, StringRef name, LLVM::LLVMFunctionType type) { LLVM::LLVMFuncOp ret; if (!(ret = moduleOp.template lookupSymbol(name))) { - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); ret = rewriter.create(loc, name, type, LLVM::Linkage::External); @@ -24,7 +23,7 @@ LLVM::LLVMFuncOp getOrInsertFunction(T &moduleOp, const Location loc, } // Extend all values to 64-bit per printf call requirements. -Value printfPromoteValue(ConversionPatternRewriter &rewriter, Value value) { +Value printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto loc = UnknownLoc::get(context); auto type = value.getType(); @@ -68,8 +67,8 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { return rewriter.create(loc, 0, 32); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { auto stringAttr = rewriter.getStringAttr("llvm.amdgcn.ballot"); SmallVector operands = {cmp}; Value asmResult = @@ -78,12 +77,12 @@ Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, return asmResult; } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { +void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const { mlir::LLVM::AMD::llStore(rewriter, loc, ptr, val, pred); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const { Value falseVal = rewriter.create( @@ -91,32 +90,32 @@ Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, return mlir::LLVM::AMD::llLoad(rewriter, loc, ptr, elemTy, pred, falseVal); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::AMD::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::AMD::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { @@ -124,7 +123,7 @@ bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, } bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -133,8 +132,7 @@ bool TargetInfo::processReplicaUsingStMatrix( } void TargetInfo::printfImpl(Value formatStrStart, int formatStrByteCount, - ValueRange args, - ConversionPatternRewriter &rewriter, + ValueRange args, RewriterBase &rewriter, bool useStdErr) const { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); auto *ctx = rewriter.getContext(); @@ -205,14 +203,13 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int formatStrByteCount, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int formatStrByteCount, ValueRange args) const { return printfImpl(formatStrStart, formatStrByteCount, args, rewriter, /*useStdError=*/false); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { // Compose and print an assert message. diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h index 4e86beb3ca38..2e7c604cf1b0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h @@ -18,51 +18,52 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const override; + Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, + Value smemBase, SmallVector &vals, + RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, + unsigned accumNumReplicates, + int swizzleByteWidth) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; bool enableLinearLayout() const override { return false; } private: void printfImpl(Value formatStrStart, int formatStrByteCount, ValueRange args, - ConversionPatternRewriter &rewriter, bool useStdErr) const; + RewriterBase &rewriter, bool useStdErr) const; std::string arch; }; diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp index 111045d134cb..f4325bbcf9c7 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp @@ -38,9 +38,8 @@ std::string mangleFunc(std::string name, Type type) { } // namespace namespace mlir::LLVM::AMD { -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, int strideInt, ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, int strideInt, ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); // On AMD, the ds_swizzle_b32 and ds_permute_b32 instructions work on @@ -126,30 +125,26 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return Value(); } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), i, ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, 0, ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -160,8 +155,8 @@ Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, return rewriter.create(loc, i32_ty, blockId); } -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal) { +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal) { Type funcType = getFunctionType(elemTy, ValueRange({ptr, pred, falseVal})); auto parent = ptr.getParentRegion()->getParentOfType(); auto funcName = mangleFunc(mlir::LLVM::AMD::Predicated_Load, funcType); @@ -174,8 +169,8 @@ Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, return loadVal; } -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) { +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) { auto ctx = ptr.getContext(); Type funcType = getFunctionType(void_ty(ctx), ValueRange({ptr, val, pred})); auto parent = ptr.getParentRegion()->getParentOfType(); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h index c60d53f4b456..b8aa25475f7d 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h @@ -13,26 +13,22 @@ namespace mlir::LLVM::AMD { const char Predicated_Load[] = "__predicated_load"; const char Predicated_Store[] = "__predicated_store"; -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); - -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); // Loads from shared or global memory with predication. // `otherElems` is used to mask out the elements that are not loaded -Value llLoad(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Type elemTy, Value pred, Value falseVal); +Value llLoad(RewriterBase &rewriter, Location loc, Value ptr, Type elemTy, + Value pred, Value falseVal); // Stores to shared or global memory with predication. -void llStore(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred); +void llStore(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred); } // namespace mlir::LLVM::AMD #endif diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp index 2afa2383c829..902951cc94ff 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.cpp @@ -14,8 +14,7 @@ using ::mlir::triton::gpu::getShapePerCTA; using ::mlir::triton::gpu::getShapePerCTATile; namespace { Value computeStMatrixAddr(Value laneId, int matStride, Location loc, - ConversionPatternRewriter &rewriter, - int swizzleByteWidth) { + RewriterBase &rewriter, int swizzleByteWidth) { Value rowInMat = urem(laneId, i32_val(8)); // row in the 8x8 matrix // linear index of the matrix in the 2x2 matrices // Decompose matIndex => s_0, s_1, that is the coordinate in 2x2 matrices in @@ -34,7 +33,7 @@ Value computeStMatrixAddr(Value laneId, int matStride, Location loc, void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, Value smemBase, Type elemTy, Location loc, - ConversionPatternRewriter &rewriter) { + RewriterBase &rewriter) { SmallVector inputs; auto prTy = ptr_ty(rewriter.getContext(), 3); // Pack the input into 2xf16 @@ -53,8 +52,8 @@ void stMatrixm8n8x4(Value offset, ArrayRef vals, int indexOffset, void storeDistributedToSharedWithStMatrix( RankedTensorType tensorTy, Type elemTy, SmallVector &inVals, Value smemBase, ArrayRef paddedRepShape, - ArrayRef origRepShape, Location loc, - ConversionPatternRewriter &rewriter, int swizzlingByteWidth) { + ArrayRef origRepShape, Location loc, RewriterBase &rewriter, + int swizzlingByteWidth) { auto shapePerCTA = getShapePerCTA(tensorTy); auto mmaLayout = mlir::cast(tensorTy.getEncoding()); auto order = triton::gpu::getOrder(mmaLayout); @@ -140,7 +139,7 @@ bool isStMatrixCompatible(RankedTensorType tensorTy, int swizzlingByteWidth) { } // declare vprintf(i8*, i8*) as external function -LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getVprintfDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("vprintf"); Operation *funcOp = moduleOp.lookupSymbol(funcName); @@ -152,7 +151,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(context), ptr_ty(context)}; auto funcType = LLVM::LLVMFunctionType::get(i32_ty, argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); return rewriter.create(UnknownLoc::get(context), funcName, @@ -161,8 +160,7 @@ LLVM::LLVMFuncOp getVprintfDeclaration(ConversionPatternRewriter &rewriter) { // extend integer to int32, extend float to float64 // this comes from vprintf alignment requirements. -std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, - Value value) { +std::pair printfPromoteValue(RewriterBase &rewriter, Value value) { auto *context = rewriter.getContext(); auto type = value.getType(); Value newOp = value; @@ -186,7 +184,7 @@ std::pair printfPromoteValue(ConversionPatternRewriter &rewriter, return {newType, newOp}; } -LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { +LLVM::LLVMFuncOp getAssertfailDeclaration(RewriterBase &rewriter) { auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); StringRef funcName("__assertfail"); { @@ -200,7 +198,7 @@ LLVM::LLVMFuncOp getAssertfailDeclaration(ConversionPatternRewriter &rewriter) { SmallVector argsType{ptr_ty(ctx), ptr_ty(ctx), i32_ty, ptr_ty(ctx), rewriter.getIntegerType(sizeof(size_t) * 8)}; auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType); - ConversionPatternRewriter::InsertionGuard guard(rewriter); + RewriterBase::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(moduleOp.getBody()); auto funcOp = rewriter.create(UnknownLoc::get(ctx), funcName, funcType); @@ -260,13 +258,13 @@ Value TargetInfo::getClusterCTAId(RewriterBase &rewriter, Location loc) const { rewriter.getI32Type()); } -Value TargetInfo::ballot(ConversionPatternRewriter &rewriter, Location loc, - Type type, Value cmp) const { +Value TargetInfo::ballot(RewriterBase &rewriter, Location loc, Type type, + Value cmp) const { Value threadMask = int_val(type.getIntOrFloatBitWidth(), -1); return rewriter.create(loc, type, threadMask, cmp); } -void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, - Value ptr, Value val, Value pred) const { +void TargetInfo::storeShared(RewriterBase &rewriter, Location loc, Value ptr, + Value val, Value pred) const { MLIRContext *ctx = rewriter.getContext(); unsigned bits = std::max(8u, val.getType().getIntOrFloatBitWidth()); const char *c = bits == 64 ? "l" : (bits == 16 ? "h" : "r"); @@ -279,7 +277,7 @@ void TargetInfo::storeShared(ConversionPatternRewriter &rewriter, Location loc, builder.launch(rewriter, loc, void_ty(ctx)); } -Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const { MLIRContext *ctx = rewriter.getContext(); @@ -297,31 +295,31 @@ Value TargetInfo::loadShared(ConversionPatternRewriter &rewriter, Location loc, return builder.launch(rewriter, loc, elemTy); } -Value TargetInfo::shuffleXor(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleXor(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleXor(loc, rewriter, val, i); } -Value TargetInfo::shuffleUp(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleUp(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleUp(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, int i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + int i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, - Value val, Value i) const { +Value TargetInfo::shuffleIdx(RewriterBase &rewriter, Location loc, Value val, + Value i) const { return LLVM::NVIDIA::shuffleIdx(loc, rewriter, val, i); } -Value TargetInfo::programId(ConversionPatternRewriter &rewriter, Location loc, +Value TargetInfo::programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, int axis) const { return LLVM::NVIDIA::llGetPid(loc, rewriter, moduleOp, axis); } -bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, +bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, triton::ReduceOp op, unsigned numLaneToReduce, unsigned interleave) const { @@ -362,7 +360,7 @@ bool TargetInfo::warpReduce(ConversionPatternRewriter &rewriter, Location loc, return false; } bool TargetInfo::processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, + RewriterBase &rewriter, Location loc, Value smemBase, SmallVector &vals, RankedTensorType srcTy, Type elemTy, ArrayRef paddedRepShape, ArrayRef origRepShape, ArrayRef outOrd, unsigned accumNumReplicates, @@ -383,9 +381,8 @@ std::string TargetInfo::getMulhiFuncName(Type resultElementTy) const { return funcName; } -void TargetInfo::printf(ConversionPatternRewriter &rewriter, - Value formatStrStart, int /*formatStrByteCount*/, - ValueRange args) const { +void TargetInfo::printf(RewriterBase &rewriter, Value formatStrStart, + int /*formatStrByteCount*/, ValueRange args) const { auto *ctx = rewriter.getContext(); Type ptr = ptr_ty(ctx); auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType(); @@ -426,7 +423,7 @@ void TargetInfo::printf(ConversionPatternRewriter &rewriter, call(funcOp, operands); } -void TargetInfo::assertFail(ConversionPatternRewriter &rewriter, Location loc, +void TargetInfo::assertFail(RewriterBase &rewriter, Location loc, StringRef message, StringRef file, StringRef func, int line) const { auto funcOp = getAssertfailDeclaration(rewriter); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h index 9b59993e6a06..8572feeb6fc0 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TargetInfo.h @@ -13,45 +13,46 @@ class TargetInfo : public mlir::triton::TargetInfoBase { Value getClusterCTAId(RewriterBase &rewriter, Location loc) const override; - Value ballot(ConversionPatternRewriter &rewriter, Location loc, Type type, + Value ballot(RewriterBase &rewriter, Location loc, Type type, Value cmp) const override; - void storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr, - Value val, Value pred) const override; - Value loadShared(ConversionPatternRewriter &rewriter, Location loc, + void storeShared(RewriterBase &rewriter, Location loc, Value ptr, Value val, + Value pred) const override; + Value loadShared(RewriterBase &rewriter, Location loc, const TypeConverter *converter, Value ptr, Type elemTy, Value pred) const override; - Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleXor(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleUp(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, int i) const override; - Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc, Value val, + Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val, Value i) const override; - Value programId(ConversionPatternRewriter &rewriter, Location loc, - ModuleOp moduleOp, int axis) const override; + Value programId(RewriterBase &rewriter, Location loc, ModuleOp moduleOp, + int axis) const override; - bool warpReduce(ConversionPatternRewriter &rewriter, Location loc, - SmallVector &acc, triton::ReduceOp op, - unsigned numLaneToReduce, unsigned interleave) const override; + bool warpReduce(RewriterBase &rewriter, Location loc, SmallVector &acc, + triton::ReduceOp op, unsigned numLaneToReduce, + unsigned interleave) const override; - bool processReplicaUsingStMatrix( - ConversionPatternRewriter &rewriter, Location loc, Value smemBase, - SmallVector &vals, RankedTensorType srcTy, Type elemTy, - ArrayRef paddedRepShape, ArrayRef origRepShape, - ArrayRef outOrd, unsigned accumNumReplicates, - int swizzleByteWidth) const override; + bool processReplicaUsingStMatrix(RewriterBase &rewriter, Location loc, + Value smemBase, SmallVector &vals, + RankedTensorType srcTy, Type elemTy, + ArrayRef paddedRepShape, + ArrayRef origRepShape, + ArrayRef outOrd, + unsigned accumNumReplicates, + int swizzleByteWidth) const override; std::string getMulhiFuncName(Type resultElementTy) const override; - void printf(ConversionPatternRewriter &rewriter, Value formatStrStart, + void printf(RewriterBase &rewriter, Value formatStrStart, int formatStrByteCount, ValueRange args) const override; - void assertFail(ConversionPatternRewriter &rewriter, Location loc, - StringRef message, StringRef file, StringRef func, - int line) const override; + void assertFail(RewriterBase &rewriter, Location loc, StringRef message, + StringRef file, StringRef func, int line) const override; private: int computeCapability; diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp index 37c5b6ec7c50..685b83620821 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.cpp @@ -8,9 +8,8 @@ namespace LLVM { namespace NVIDIA { using namespace mlir::triton; -static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, - Value val, Value i, NVVM::ShflKind mode, - Value clamp) { +static Value shuffleCommon(Location loc, RewriterBase &rewriter, Value val, + Value i, NVVM::ShflKind mode, Value clamp) { unsigned bits = val.getType().getIntOrFloatBitWidth(); if (bits == 64) { @@ -42,31 +41,27 @@ static Value shuffleCommon(Location loc, ConversionPatternRewriter &rewriter, return result; } -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::bfly, i32_val(0x1f)); } -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleCommon(loc, rewriter, val, i32_val(i), NVVM::ShflKind::up, i32_val(0x0)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i) { return shuffleIdx(loc, rewriter, val, i32_val(i)); } -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i) { +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i) { return shuffleCommon(loc, rewriter, val, i, NVVM::ShflKind::idx, i32_val(0x1f)); } -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis) { +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis) { assert(axis >= 0); assert(axis < 3); assert(moduleOp); @@ -92,8 +87,8 @@ Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr) { return val; } -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask) { +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask) { PTXBuilder builder; auto &prmt = builder.create("prmt")->o("b32"); auto *destOpr = builder.newOperand("=r"); diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h index bb4e9dd33646..3d3eeb1affb9 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/Utility.h @@ -40,19 +40,15 @@ namespace LLVM { namespace NVIDIA { Value getSRegValue(OpBuilder &b, Location loc, const std::string &sRegStr); -Value shuffleXor(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleUp(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - int i); -Value shuffleIdx(Location loc, ConversionPatternRewriter &rewriter, Value val, - Value i); -Value permute(Location loc, ConversionPatternRewriter &rewriter, Value a, - Value b, Value mask); - -Value llGetPid(Location loc, ConversionPatternRewriter &rewriter, - ModuleOp moduleOp, int axis); +Value shuffleXor(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleUp(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, int i); +Value shuffleIdx(Location loc, RewriterBase &rewriter, Value val, Value i); +Value permute(Location loc, RewriterBase &rewriter, Value a, Value b, + Value mask); + +Value llGetPid(Location loc, RewriterBase &rewriter, ModuleOp moduleOp, + int axis); /// Usage of macro load_dsmem /// (1) load_dsmem(addr, ctaId)