Skip to content

Commit

Permalink
s/ConversionPatternRewriter/RewriterBase/ in utils (NFC) (triton-lang…
Browse files Browse the repository at this point in the history
…#4140)

You can't create a ConversionPatternRewriter in C++ unit tests,
so this means that anything which is touched from a C++ unit
test cannot transitively touch anything which uses a
ConversionPatternRewriter.

This is a simple search+replace, there's no functional difference
between these types for our purposes.
  • Loading branch information
jlebar authored Jun 14, 2024
1 parent 3aeb266 commit 6d93696
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 230 deletions.
36 changes: 18 additions & 18 deletions include/triton/Conversion/TritonGPUToLLVM/TargetInfoBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,34 +10,34 @@ class TargetInfoBase {

virtual Value getClusterCTAId(RewriterBase &rewriter, Location loc) const = 0;

virtual Value ballot(ConversionPatternRewriter &rewriter, Location loc,
Type type, Value cmp) const = 0;
virtual Value ballot(RewriterBase &rewriter, Location loc, Type type,
Value cmp) const = 0;

virtual void storeShared(ConversionPatternRewriter &rewriter, Location loc,
Value ptr, Value val, Value pred) const = 0;
virtual Value loadShared(ConversionPatternRewriter &rewriter, Location loc,
virtual void storeShared(RewriterBase &rewriter, Location loc, Value ptr,
Value val, Value pred) const = 0;
virtual Value loadShared(RewriterBase &rewriter, Location loc,
const TypeConverter *converter, Value ptr,
Type elemTy, Value pred) const = 0;

virtual Value shuffleXor(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleUp(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc,
Value val, int i) const = 0;
virtual Value shuffleIdx(ConversionPatternRewriter &rewriter, Location loc,
Value val, Value i) const = 0;
virtual Value shuffleXor(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleUp(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
int i) const = 0;
virtual Value shuffleIdx(RewriterBase &rewriter, Location loc, Value val,
Value i) const = 0;

virtual Value programId(ConversionPatternRewriter &rewriter, Location loc,
virtual Value programId(RewriterBase &rewriter, Location loc,
ModuleOp moduleOp, int axis) const = 0;

virtual bool warpReduce(ConversionPatternRewriter &rewriter, Location loc,
virtual bool warpReduce(RewriterBase &rewriter, Location loc,
SmallVector<Value> &acc, triton::ReduceOp op,
unsigned numLaneToReduce,
unsigned interleave) const = 0;

virtual bool processReplicaUsingStMatrix(
ConversionPatternRewriter &rewriter, Location loc, Value smemBase,
RewriterBase &rewriter, Location loc, Value smemBase,
SmallVector<Value> &vals, RankedTensorType srcTy, Type elemTy,
ArrayRef<unsigned> paddedRepShape, ArrayRef<unsigned> origRepShape,
ArrayRef<unsigned> outOrd, unsigned accumNumReplicates,
Expand All @@ -48,11 +48,11 @@ class TargetInfoBase {
// format from the device. |formatStrStart| is the pointer to the start of
// the format string global variable; |args| are the arguments to fill
// placeholders in the format string.
virtual void printf(ConversionPatternRewriter &rewriter, Value formatStrStart,
virtual void printf(RewriterBase &rewriter, Value formatStrStart,
int formatStrByteCount, ValueRange args) const = 0;
// Emits LLVM code with |rewriter| to perform assertion failure with the given
// |message| from the given |func| in |file|.
virtual void assertFail(ConversionPatternRewriter &rewriter, Location loc,
virtual void assertFail(RewriterBase &rewriter, Location loc,
StringRef message, StringRef file, StringRef func,
int line) const = 0;

Expand Down
56 changes: 27 additions & 29 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,9 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
namespace gpu {
Type getFunctionType(Type resultType, ValueRange operands);

LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter,
Operation *op, StringRef funcName,
Type funcType, StringRef libname = "",
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
StringRef funcName, Type funcType,
StringRef libname = "",
StringRef libpath = "");
} // namespace gpu

Expand Down Expand Up @@ -305,17 +305,18 @@ struct SharedMemoryObject {
}

Value getBaseBeforeSlice(int order, Location loc,
ConversionPatternRewriter &rewriter) const {
RewriterBase &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(order);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, baseElemType, base, offset);
}
};

SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy,
ConversionPatternRewriter &rewriter);
SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
RewriterBase &rewriter);

// Convert an \param index to a multi-dim coordinate given \param shape and
// \param order.
Expand All @@ -329,15 +330,14 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
Value linear, ArrayRef<unsigned> shape);

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order);
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order);

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape);
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape);

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content);

// Given an elemId which represents the index of an element from the list of
// elements that are in the thread's registers (i.e. total of
Expand All @@ -346,7 +346,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
// when converting distributed to distributed layout. Also, a replica is the
// smallest CTA tile that is common between input and output layouts.
SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
RewriterBase &rewriter,
const TargetInfoBase &targetInfo,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
Expand All @@ -355,15 +355,15 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
// Given a multiDimOffset, this function wraps around each dimension to be
// within shape.
SmallVector<Value> getWrappedMultiDimOffset(
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDimOffset, ArrayRef<unsigned> shape,
SmallVector<unsigned> shapePerCTATile, SmallVector<int64_t> shapePerCTA);
RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDimOffset,
ArrayRef<unsigned> shape, SmallVector<unsigned> shapePerCTATile,
SmallVector<int64_t> shapePerCTA);

inline bool isKernel(FunctionOpInterface funcOp) {
return funcOp.getVisibility() == SymbolTable::Visibility::Public;
}

inline Value getStackPointer(PatternRewriter &rewriter,
inline Value getStackPointer(RewriterBase &rewriter,
FunctionOpInterface funcOp) {
auto mod = funcOp->getParentOfType<ModuleOp>();
LLVM::GlobalOp globalBase = nullptr;
Expand All @@ -378,8 +378,7 @@ inline Value getStackPointer(PatternRewriter &rewriter,
return funcOp.getArgument(funcOp.getNumArguments() - 1);
}

inline Value getSharedMemoryBase(Location loc,
ConversionPatternRewriter &rewriter,
inline Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
Operation *op) {
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 3);
FunctionOpInterface func =
Expand Down Expand Up @@ -1566,9 +1565,9 @@ inline void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
}
}

inline Value
getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj,
ConversionPatternRewriter &rewriter) {
inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
Expand All @@ -1582,9 +1581,8 @@ getStructFromSharedMemoryObject(Location loc, const SharedMemoryObject &smemObj,
return llvmStruct;
}

inline SmallVector<Value>
unpackLLElements(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
isa<triton::PointerType>(llvmStruct.getType()) ||
Expand All @@ -1602,8 +1600,8 @@ unpackLLElements(Location loc, Value llvmStruct,

inline Value packLLElements(Location loc,
const LLVMTypeConverter *typeConverter,
ValueRange resultVals,
ConversionPatternRewriter &rewriter, Type type) {
ValueRange resultVals, RewriterBase &rewriter,
Type type) {
auto structType =
dyn_cast<LLVM::LLVMStructType>(typeConverter->convertType(type));
if (!structType) {
Expand Down
37 changes: 18 additions & 19 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ using CoordTy = SmallVector<Value>;
using ValueTable = std::map<std::pair<int, int>, std::pair<Value, Value>>;

static SmallVector<CoordTy>
getMNCoords(Value thread, Location loc, ConversionPatternRewriter &rewriter,
getMNCoords(Value thread, Location loc, RewriterBase &rewriter,
ArrayRef<unsigned int> wpt, const NvidiaMmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow, bool isBRow, bool isAVec4,
bool isBVec4) {
Expand Down Expand Up @@ -120,9 +120,8 @@ Type getFunctionType(Type resultType, ValueRange operands) {
return LLVM::LLVMFunctionType::get(resultType, operandTypes);
}

LLVM::LLVMFuncOp appendOrGetExternFuncOp(ConversionPatternRewriter &rewriter,
Operation *op, StringRef funcName,
Type funcType,
LLVM::LLVMFuncOp appendOrGetExternFuncOp(RewriterBase &rewriter, Operation *op,
StringRef funcName, Type funcType,
StringRef libname /*= ""*/,
StringRef libpath /*= ""*/) {
using LLVM::LLVMFuncOp;
Expand Down Expand Up @@ -496,9 +495,10 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
builder.getIntegerAttr(ty, value));
}

SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct, Type elemTy,
ConversionPatternRewriter &rewriter) {
SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
RewriterBase &rewriter) {
ArrayRef<Type> types =
cast<LLVM::LLVMStructType>(llvmStruct.getType()).getBody();
SmallVector<Value> elems(types.size());
Expand Down Expand Up @@ -580,15 +580,14 @@ SmallVector<Value> delinearize(RewriterBase &rewriter, Location loc,
return multiDim;
}

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape,
ArrayRef<unsigned> order) {
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape, ArrayRef<unsigned> order) {
return linearize(rewriter, loc, applyPermutation(multiDim, order),
applyPermutation(shape, order));
}

Value linearize(ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) {
Value linearize(RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDim,
ArrayRef<unsigned> shape) {
auto rank = multiDim.size();
Value linear = i32_val(0);
if (rank > 0) {
Expand All @@ -602,8 +601,8 @@ Value linearize(ConversionPatternRewriter &rewriter, Location loc,
return linear;
}

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content) {
Value addStringToModule(Location loc, RewriterBase &rewriter, StringRef key,
StringRef content) {
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto ctx = moduleOp.getContext();
unsigned stringNumber = 0;
Expand All @@ -619,7 +618,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,

LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
RewriterBase::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(ctx), globalType,
Expand All @@ -637,7 +636,7 @@ Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
}

SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
ConversionPatternRewriter &rewriter,
RewriterBase &rewriter,
const TargetInfoBase &targetInfo,
unsigned elemId, RankedTensorType type,
ArrayRef<unsigned> multiDimCTAInRepId,
Expand Down Expand Up @@ -791,9 +790,9 @@ SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
}

SmallVector<Value> getWrappedMultiDimOffset(
ConversionPatternRewriter &rewriter, Location loc,
ArrayRef<Value> multiDimOffset, ArrayRef<unsigned> shape,
SmallVector<unsigned> shapePerCTATile, SmallVector<int64_t> shapePerCTA) {
RewriterBase &rewriter, Location loc, ArrayRef<Value> multiDimOffset,
ArrayRef<unsigned> shape, SmallVector<unsigned> shapePerCTATile,
SmallVector<int64_t> shapePerCTA) {
unsigned rank = shape.size();
SmallVector<Value> multiDimOffsetWrapped(rank);
for (unsigned d = 0; d < rank; ++d) {
Expand Down
Loading

0 comments on commit 6d93696

Please sign in to comment.