Add missing integer and unsigned integer types. #3734

Open · wants to merge 2 commits into main
4 changes: 2 additions & 2 deletions csrc/codegen.cpp
@@ -207,10 +207,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor {
       // For 64-bit Unix systems, int is 32-bit, long and long long are 64-bit
       // For 64-bit Windows, int and long are 32-bit, long long are 64-bit
       return "LL";
-    case DataType::UInt:
-      return "ULL";
     case DataType::UInt32:
       return "U";
+    case DataType::UInt64:
+      return "ULL";
     case DataType::Index:
       return getLiteralSuffix(kernel_->indexType());
     default:
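Note (illustration, not part of the diff): the suffix chosen here is appended to scalar constants printed into the generated CUDA source. On both LP64 (Linux) and LLP64 (Windows) targets an unsuffixed literal is a 32-bit `int`, so the two unsigned widths need different suffixes:

```cpp
#include <cstdint>

// "U" is enough for a 32-bit unsigned constant; a 64-bit constant needs "ULL"
// so it is parsed as unsigned long long rather than truncated or ill-formed.
constexpr uint32_t u32_max = 0xFFFFFFFFU;
constexpr uint64_t u64_max = 0xFFFFFFFFFFFFFFFFULL;
```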
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/allocation.cpp
@@ -632,7 +632,7 @@ class AllocationInserter : public kir::ExprMutator {
     // create and allocate a memory barrier
     TensorView* mbarrier = TensorViewBuilder()
                                .shape(std::vector<int64_t>{})
-                               .dtype(DataType::UInt)
+                               .dtype(DataType::UInt64)
                                .contiguity(true)
                                .build();
     mbarrier->setMemoryType(MemoryType::Shared);
@@ -697,7 +697,7 @@

     TensorView* mbarrier = TensorViewBuilder()
                                .shape(std::vector<int64_t>{num_mbarriers})
-                               .dtype(DataType::UInt)
+                               .dtype(DataType::UInt64)
                                .contiguity(true)
                                .build();
     mbarrier->setMemoryType(MemoryType::Shared);
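For context on the type choice: a PTX mbarrier object occupies one 64-bit word of shared memory, so `UInt64` is the natural element type for the backing TensorView. A rough CUDA sketch of what such an allocation plus initialization corresponds to (`initMBarrier` is a hypothetical helper, not code from this PR):

```cuda
#include <cstdint>

// Hypothetical helper: initialize a shared-memory mbarrier with an expected
// arrival count; mbarrier.init operates on a 64-bit shared-memory word.
__device__ void initMBarrier(uint64_t* smem_barrier, uint32_t arrive_count) {
  auto addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_barrier));
  asm volatile("mbarrier.init.shared.b64 [%0], %1;\n" ::"r"(addr),
               "r"(arrive_count));
}

__global__ void sketch() {
  // .shape({}) above corresponds to a single barrier object;
  // .shape({num_mbarriers}) would correspond to an array of such 64-bit words.
  __shared__ uint64_t mbarrier;
  if (threadIdx.x == 0) {
    initMBarrier(&mbarrier, blockDim.x);
  }
  __syncthreads();
}
```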
22 changes: 11 additions & 11 deletions csrc/device_lower/pass/index.cpp
@@ -1527,7 +1527,7 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) {
       ldst, mbarrier_index, for_loops_, rotated_loop_);

   // arrive and expect_tx mbarrier
-  Val* state = IrBuilder::create<Val>(DataType::UInt);
+  Val* state = IrBuilder::create<Val>(DataType::UInt64);
   pushBack(IrBuilder::create<kir::Allocate>(
       state, MemoryType::Local, ldst->container()->oneVal()));
   pushBack(IrBuilder::create<kir::MBarrierArriveExpectTx>(
@@ -2160,10 +2160,10 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
 // Reference:
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor
 static Val* matrixDescriptorEncode(Val* x) {
-  auto x_cast = IrBuilder::maybeCastExpr(DataType::UInt, x);
-  auto mask = IrBuilder::create<Val>(0x3FFFF, DataType::UInt);
+  auto x_cast = IrBuilder::maybeCastExpr(DataType::UInt64, x);
+  auto mask = IrBuilder::create<Val>(0x3FFFF, DataType::UInt64);
   auto x_and = IrBuilder::bitwiseAndExpr(x_cast, mask);
-  auto shift = IrBuilder::create<Val>(0x4, DataType::UInt);
+  auto shift = IrBuilder::create<Val>(0x4, DataType::UInt64);
   return IrBuilder::rShiftExpr(x_and, shift);
 }

@@ -2176,15 +2176,15 @@ static Val* constructMatrixDescriptor(
   auto or0 = matrixDescriptorEncode(start_address);
   auto or1 = IrBuilder::lShiftExpr(
       matrixDescriptorEncode(leading_dim_byte_offset),
-      IrBuilder::create<Val>(16, DataType::UInt));
+      IrBuilder::create<Val>(16, DataType::UInt64));
   auto or2 = IrBuilder::lShiftExpr(
       matrixDescriptorEncode(stride_dim_byte_offset),
-      IrBuilder::create<Val>(32, DataType::UInt));
+      IrBuilder::create<Val>(32, DataType::UInt64));
   auto or3 = IrBuilder::lShiftExpr(
-      matrix_base_offset, IrBuilder::create<Val>(49, DataType::UInt));
+      matrix_base_offset, IrBuilder::create<Val>(49, DataType::UInt64));
   auto or4 = IrBuilder::lShiftExpr(
-      IrBuilder::create<Val>((int64_t)swizzle, DataType::UInt),
-      IrBuilder::create<Val>(62, DataType::UInt));
+      IrBuilder::create<Val>((int64_t)swizzle, DataType::UInt64),
+      IrBuilder::create<Val>(62, DataType::UInt64));
   return IrBuilder::bitwiseOrExpr(
       IrBuilder::bitwiseOrExpr(
           IrBuilder::bitwiseOrExpr(IrBuilder::bitwiseOrExpr(or0, or1), or2),
@@ -2444,7 +2444,7 @@ void IndexLowering::handle(const MmaOp* mma) {
         base_addr,
         leading_bytes,
         stride_bytes,
-        IrBuilder::create<Val>(0, DataType::UInt),
+        IrBuilder::create<Val>(0, DataType::UInt64),
         getSwizzleMode(tv));
     a = IrBuilder::create<kir::TensorIndex>(
         tv,
@@ -2476,7 +2476,7 @@ void IndexLowering::handle(const MmaOp* mma) {
         base_addr,
         leading_bytes,
         stride_bytes,
-        IrBuilder::create<Val>(0, DataType::UInt),
+        IrBuilder::create<Val>(0, DataType::UInt64),
         swizzle);
     b = IrBuilder::create<kir::TensorIndex>(
         tv,
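Side note on `constructMatrixDescriptor`: the upper fields of the Hopper warpgroup-MMA shared-memory descriptor sit at bits 49 and 62, so 32-bit arithmetic would silently drop them; all operands therefore need to be 64-bit unsigned. A rough host-side mirror of the encoding (an illustration, not code from this PR):

```cpp
#include <cstdint>

// Hypothetical mirror of matrixDescriptorEncode: keep 18 address bits and
// drop the low 4 alignment bits.
uint64_t encode(uint64_t x) {
  return (x & 0x3FFFF) >> 4;
}

// Hypothetical mirror of constructMatrixDescriptor: matrix_base_offset lands
// at bit 49 and the swizzle mode at bit 62, so the result only fits in 64 bits.
uint64_t makeDescriptor(
    uint64_t start_address,
    uint64_t leading_dim_byte_offset,
    uint64_t stride_dim_byte_offset,
    uint64_t matrix_base_offset,
    uint64_t swizzle) {
  return encode(start_address) | (encode(leading_dim_byte_offset) << 16) |
      (encode(stride_dim_byte_offset) << 32) | (matrix_base_offset << 49) |
      (swizzle << 62);
}
```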
4 changes: 2 additions & 2 deletions csrc/device_lower/pass/predicate.cpp
@@ -237,8 +237,8 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator {
         return IrBuilder::create<Val>(true, DataType::Bool);
       }
       case PredicateType::ElectSync: {
-        Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt);
-        Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt);
+        Val* zero = IrBuilder::create<Val>(0L, PrimDataType::UInt64);
+        Val* warp_size = IrBuilder::create<Val>(32L, PrimDataType::UInt64);
         Val* full_mask_val =
             IrBuilder::create<Val>(0xFFFFFFFF, PrimDataType::UInt32);

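For context, the ElectSync predicate selects a single thread of the first warp; `full_mask_val` stays `UInt32` because the PTX membermask is a 32-bit value. A rough CUDA sketch of the elect-one-thread idiom this predicate maps to (an assumption about the generated code, not taken from this PR):

```cuda
#include <cstdint>

// Hypothetical helper: returns true in exactly one thread of the warp
// identified by the 32-bit membermask (sm_90+ elect.sync).
__device__ bool electSync(uint32_t membermask) {
  uint32_t is_elected;
  asm volatile(
      "{\n"
      "  .reg .pred P;\n"
      "  elect.sync _|P, %1;\n"
      "  selp.b32 %0, 1, 0, P;\n"
      "}\n"
      : "=r"(is_elected)
      : "r"(membermask));
  return is_elected != 0;
}

__global__ void sketch() {
  constexpr uint32_t full_mask = 0xFFFFFFFFU;
  // Predicate: only the first warp participates, and one thread is elected.
  if (threadIdx.x / 32 == 0 && electSync(full_mask)) {
    // ... single-thread work such as issuing an async bulk copy ...
  }
}
```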
8 changes: 4 additions & 4 deletions csrc/kernel_ir.cpp
@@ -100,7 +100,7 @@ TensorIndex::TensorIndex(
       isPointerType(index->dtype()) || index->dtype() == DataType::Index ||
           isStructType(index->dtype()) ||
           index->dtype() ==
-              DataType::UInt /*For matrix descriptor for hopper MMA*/,
+              DataType::UInt64 /*For matrix descriptor for hopper MMA*/,
       "Cannot index with a value other than an int/pointer/struct.");
 }

@@ -577,7 +577,7 @@ MBarrierArrive::MBarrierArrive(
   NVF_ERROR(passkey.ir_container_ != nullptr);
   addInput(mbarrier);
   if (state != nullptr) {
-    NVF_CHECK(state->dtype() == DataType::UInt);
+    NVF_CHECK(state->dtype() == DataType::UInt64);
     addOutput(state);
   }
 }
@@ -606,7 +606,7 @@ MBarrierArriveExpectTx::MBarrierArriveExpectTx(
   addInput(mbarrier);
   addInput(tx_count);
   if (state != nullptr) {
-    NVF_CHECK(state->dtype() == DataType::UInt);
+    NVF_CHECK(state->dtype() == DataType::UInt64);
     addOutput(state);
   }
 }
@@ -627,7 +627,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArriveExpectTx)
 MBarrierWait::MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state)
     : Expr(passkey) {
   NVF_ERROR(passkey.ir_container_ != nullptr);
-  NVF_CHECK(state->dtype() == DataType::UInt);
+  NVF_CHECK(state->dtype() == DataType::UInt64);
   addInput(mbarrier);
   addInput(state);
 }
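The `state` checked in these constructors is the arrival token produced by the mbarrier arrive operation, which PTX defines as a 64-bit value; hence `UInt64` rather than a narrower type. A minimal CUDA sketch of where that token comes from (a hypothetical helper, not code from this PR):

```cuda
#include <cstdint>

// Hypothetical helper: mbarrier.arrive returns an opaque 64-bit state token,
// which the matching mbarrier wait later consumes.
__device__ uint64_t mbarrierArrive(uint64_t* smem_barrier) {
  uint64_t state;
  auto addr = static_cast<uint32_t>(__cvta_generic_to_shared(smem_barrier));
  asm volatile("mbarrier.arrive.shared.b64 %0, [%1];\n"
               : "=l"(state)
               : "r"(addr));
  return state;
}
```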
15 changes: 12 additions & 3 deletions csrc/runtime/allocations.cpp
@@ -207,9 +207,6 @@ void setFillAllocationWithNan(bool value) {

 void fillTensorWithNan(at::Tensor& t) {
   switch (t.scalar_type()) {
-    case at::ScalarType::Byte:
-      t.fill_(0xFF);
-      break;
     case at::ScalarType::Char:
       t.fill_(0x7F);
       break;
@@ -222,6 +219,18 @@ void fillTensorWithNan(at::Tensor& t) {
     case at::ScalarType::Long:
       t.fill_(0x7FFFFFFFFFFFFFFFL);
       break;
+    case at::ScalarType::Byte:
+      t.fill_(0xFF);
+      break;
+    case at::ScalarType::UInt16:
+      t.fill_(0xFFFF);
+      break;
+    case at::ScalarType::UInt32:
+      t.fill_(0xFFFFFFFF);
+      break;
+    case at::ScalarType::UInt64:
+      t.fill_(0xFFFFFFFFFFFFFFFFL);
+      break;
     case at::ScalarType::Bool:
       t.fill_(true);
       break;
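As a quick sanity check of the sentinel values added above (illustration only, not part of the PR; assumes an ATen build where the extended unsigned scalar types are enabled for `empty` and `fill_`):

```cpp
#include <ATen/ATen.h>

// The all-ones bit pattern is the maximum representable value for each
// unsigned width, which makes reads of "uninitialized" memory easy to spot.
void poisonExample() {
  at::Tensor t =
      at::empty({4}, at::TensorOptions().dtype(at::ScalarType::UInt32));
  t.fill_(0xFFFFFFFF);  // each element reads back as 4294967295
}
```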