diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index fe00c1f8105..9571425b7e7 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -207,10 +207,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { // For 64-bit Unix systems, int is 32-bit, long and long long are 64-bit // For 64-bit Windows, int and long are 32-bit, long long are 64-bit return "LL"; - case DataType::UInt: - return "ULL"; case DataType::UInt32: return "U"; + case DataType::UInt64: + return "ULL"; case DataType::Index: return getLiteralSuffix(kernel_->indexType()); default: diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index bee61c46873..e08d1d78711 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -632,7 +632,7 @@ class AllocationInserter : public kir::ExprMutator { // create and allocate a memory barrier TensorView* mbarrier = TensorViewBuilder() .shape(std::vector{}) - .dtype(DataType::UInt) + .dtype(DataType::UInt64) .contiguity(true) .build(); mbarrier->setMemoryType(MemoryType::Shared); @@ -697,7 +697,7 @@ class AllocationInserter : public kir::ExprMutator { TensorView* mbarrier = TensorViewBuilder() .shape(std::vector{num_mbarriers}) - .dtype(DataType::UInt) + .dtype(DataType::UInt64) .contiguity(true) .build(); mbarrier->setMemoryType(MemoryType::Shared); diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 4605eb9eac4..17c34733a79 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1527,7 +1527,7 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) { ldst, mbarrier_index, for_loops_, rotated_loop_); // arrive and expect_tx mbarrier - Val* state = IrBuilder::create(DataType::UInt); + Val* state = IrBuilder::create(DataType::UInt64); pushBack(IrBuilder::create( state, MemoryType::Local, ldst->container()->oneVal())); pushBack(IrBuilder::create( @@ -2160,10 +2160,10 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { // Reference: // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-shared-memory-layout-matrix-descriptor static Val* matrixDescriptorEncode(Val* x) { - auto x_cast = IrBuilder::maybeCastExpr(DataType::UInt, x); - auto mask = IrBuilder::create(0x3FFFF, DataType::UInt); + auto x_cast = IrBuilder::maybeCastExpr(DataType::UInt64, x); + auto mask = IrBuilder::create(0x3FFFF, DataType::UInt64); auto x_and = IrBuilder::bitwiseAndExpr(x_cast, mask); - auto shift = IrBuilder::create(0x4, DataType::UInt); + auto shift = IrBuilder::create(0x4, DataType::UInt64); return IrBuilder::rShiftExpr(x_and, shift); } @@ -2176,15 +2176,15 @@ static Val* constructMatrixDescriptor( auto or0 = matrixDescriptorEncode(start_address); auto or1 = IrBuilder::lShiftExpr( matrixDescriptorEncode(leading_dim_byte_offset), - IrBuilder::create(16, DataType::UInt)); + IrBuilder::create(16, DataType::UInt64)); auto or2 = IrBuilder::lShiftExpr( matrixDescriptorEncode(stride_dim_byte_offset), - IrBuilder::create(32, DataType::UInt)); + IrBuilder::create(32, DataType::UInt64)); auto or3 = IrBuilder::lShiftExpr( - matrix_base_offset, IrBuilder::create(49, DataType::UInt)); + matrix_base_offset, IrBuilder::create(49, DataType::UInt64)); auto or4 = IrBuilder::lShiftExpr( - IrBuilder::create((int64_t)swizzle, DataType::UInt), - IrBuilder::create(62, DataType::UInt)); + IrBuilder::create((int64_t)swizzle, DataType::UInt64), + IrBuilder::create(62, DataType::UInt64)); return IrBuilder::bitwiseOrExpr( IrBuilder::bitwiseOrExpr( IrBuilder::bitwiseOrExpr(IrBuilder::bitwiseOrExpr(or0, or1), or2), @@ -2444,7 +2444,7 @@ void IndexLowering::handle(const MmaOp* mma) { base_addr, leading_bytes, stride_bytes, - IrBuilder::create(0, DataType::UInt), + IrBuilder::create(0, DataType::UInt64), getSwizzleMode(tv)); a = IrBuilder::create( tv, @@ -2476,7 +2476,7 @@ void IndexLowering::handle(const MmaOp* mma) { base_addr, leading_bytes, stride_bytes, - IrBuilder::create(0, DataType::UInt), + IrBuilder::create(0, DataType::UInt64), swizzle); b = IrBuilder::create( tv, diff --git a/csrc/device_lower/pass/predicate.cpp b/csrc/device_lower/pass/predicate.cpp index afff481fb94..015de0930cd 100644 --- a/csrc/device_lower/pass/predicate.cpp +++ b/csrc/device_lower/pass/predicate.cpp @@ -237,8 +237,8 @@ class ConditionalFromPredicateModifier : public kir::ExprMutator { return IrBuilder::create(true, DataType::Bool); } case PredicateType::ElectSync: { - Val* zero = IrBuilder::create(0L, PrimDataType::UInt); - Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt); + Val* zero = IrBuilder::create(0L, PrimDataType::UInt64); + Val* warp_size = IrBuilder::create(32L, PrimDataType::UInt64); Val* full_mask_val = IrBuilder::create(0xFFFFFFFF, PrimDataType::UInt32); diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index a1549ba28f6..49ce8f820c8 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -100,7 +100,7 @@ TensorIndex::TensorIndex( isPointerType(index->dtype()) || index->dtype() == DataType::Index || isStructType(index->dtype()) || index->dtype() == - DataType::UInt /*For matrix descriptor for hopper MMA*/, + DataType::UInt64 /*For matrix descriptor for hopper MMA*/, "Cannot index with a value other than an int/pointer/struct."); } @@ -577,7 +577,7 @@ MBarrierArrive::MBarrierArrive( NVF_ERROR(passkey.ir_container_ != nullptr); addInput(mbarrier); if (state != nullptr) { - NVF_CHECK(state->dtype() == DataType::UInt); + NVF_CHECK(state->dtype() == DataType::UInt64); addOutput(state); } } @@ -606,7 +606,7 @@ MBarrierArriveExpectTx::MBarrierArriveExpectTx( addInput(mbarrier); addInput(tx_count); if (state != nullptr) { - NVF_CHECK(state->dtype() == DataType::UInt); + NVF_CHECK(state->dtype() == DataType::UInt64); addOutput(state); } } @@ -627,7 +627,7 @@ NVFUSER_DEFINE_CLONE_AND_CREATE(MBarrierArriveExpectTx) MBarrierWait::MBarrierWait(IrBuilderPasskey passkey, Val* mbarrier, Val* state) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); - NVF_CHECK(state->dtype() == DataType::UInt); + NVF_CHECK(state->dtype() == DataType::UInt64); addInput(mbarrier); addInput(state); } diff --git a/csrc/runtime/allocations.cpp b/csrc/runtime/allocations.cpp index 45bfe431c3f..e12b4f0ec8d 100644 --- a/csrc/runtime/allocations.cpp +++ b/csrc/runtime/allocations.cpp @@ -207,9 +207,6 @@ void setFillAllocationWithNan(bool value) { void fillTensorWithNan(at::Tensor& t) { switch (t.scalar_type()) { - case at::ScalarType::Byte: - t.fill_(0xFF); - break; case at::ScalarType::Char: t.fill_(0x7F); break; @@ -222,6 +219,18 @@ void fillTensorWithNan(at::Tensor& t) { case at::ScalarType::Long: t.fill_(0x7FFFFFFFFFFFFFFFL); break; + case at::ScalarType::Byte: + t.fill_(0xFF); + break; + case at::ScalarType::UInt16: + t.fill_(0xFFFF); + break; + case at::ScalarType::UInt32: + t.fill_(0xFFFFFFFF); + break; + case at::ScalarType::UInt64: + t.fill_(0xFFFFFFFFFFFFFFFFL); + break; case at::ScalarType::Bool: t.fill_(true); break; diff --git a/csrc/type.cpp b/csrc/type.cpp index 36819404ae9..c3e07fa1d67 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -228,16 +228,24 @@ static std::string data_type2string(DataType t) { return "__e4m3"; case DataType::Float8_e5m2: return "__e5m2"; - case DataType::Int: - return "int64_t"; case DataType::Index: return "nvfuser_index_t"; + case DataType::Char: + return "int8_t"; + case DataType::Short: + return "int16_t"; case DataType::Int32: return "int"; - case DataType::UInt: - return "uint64_t"; + case DataType::Int: + return "int64_t"; + case DataType::Byte: + return "uint8_t"; + case DataType::UInt16: + return "uint16_t"; case DataType::UInt32: return "uint32_t"; + case DataType::UInt64: + return "uint64_t"; case DataType::SMemAddress: return "unsigned"; case DataType::ComplexFloat: @@ -866,64 +874,148 @@ static const char* supported_casts2string(std::pair t) { std::get(t.first.type), std::get(t.second.type))) { case supported_switch_pair(DataType::Index, DataType::Float): - case supported_switch_pair(DataType::Int, DataType::Float): + case supported_switch_pair(DataType::Char, DataType::Float): + case supported_switch_pair(DataType::Short, DataType::Float): case supported_switch_pair(DataType::Int32, DataType::Float): - case supported_switch_pair(DataType::UInt, DataType::Float): + case supported_switch_pair(DataType::Int, DataType::Float): + case supported_switch_pair(DataType::Byte, DataType::Float): + case supported_switch_pair(DataType::UInt16, DataType::Float): case supported_switch_pair(DataType::UInt32, DataType::Float): + case supported_switch_pair(DataType::UInt64, DataType::Float): case supported_switch_pair(DataType::Double, DataType::Float): case supported_switch_pair(DataType::Bool, DataType::Float): return "(float)"; case supported_switch_pair(DataType::ComplexFloat, DataType::Float): case supported_switch_pair(DataType::ComplexDouble, DataType::Float): return "(float)std::real"; - case supported_switch_pair(DataType::Index, DataType::Int): - case supported_switch_pair(DataType::Int32, DataType::Int): - case supported_switch_pair(DataType::UInt, DataType::Int): - case supported_switch_pair(DataType::UInt32, DataType::Int): - case supported_switch_pair(DataType::Float, DataType::Int): - case supported_switch_pair(DataType::Double, DataType::Int): - case supported_switch_pair(DataType::Bool, DataType::Int): - return "(int64_t)"; - case supported_switch_pair(DataType::ComplexFloat, DataType::Int): - case supported_switch_pair(DataType::ComplexDouble, DataType::Int): - return "(int64_t)std::real"; + case supported_switch_pair(DataType::Index, DataType::Char): + case supported_switch_pair(DataType::Short, DataType::Char): + case supported_switch_pair(DataType::Int32, DataType::Char): + case supported_switch_pair(DataType::Int, DataType::Char): + case supported_switch_pair(DataType::Byte, DataType::Char): + case supported_switch_pair(DataType::UInt16, DataType::Char): + case supported_switch_pair(DataType::UInt32, DataType::Char): + case supported_switch_pair(DataType::UInt64, DataType::Char): + case supported_switch_pair(DataType::Float, DataType::Char): + case supported_switch_pair(DataType::Double, DataType::Char): + case supported_switch_pair(DataType::Bool, DataType::Char): + return "(int8_t)"; + case supported_switch_pair(DataType::Index, DataType::Short): + case supported_switch_pair(DataType::Char, DataType::Short): + case supported_switch_pair(DataType::Int32, DataType::Short): + case supported_switch_pair(DataType::Int, DataType::Short): + case supported_switch_pair(DataType::Byte, DataType::Short): + case supported_switch_pair(DataType::UInt16, DataType::Short): + case supported_switch_pair(DataType::UInt32, DataType::Short): + case supported_switch_pair(DataType::UInt64, DataType::Short): + case supported_switch_pair(DataType::Float, DataType::Short): + case supported_switch_pair(DataType::Double, DataType::Short): + case supported_switch_pair(DataType::Bool, DataType::Short): + return "(int16_t)"; case supported_switch_pair(DataType::Index, DataType::Int32): + case supported_switch_pair(DataType::Char, DataType::Int32): + case supported_switch_pair(DataType::Short, DataType::Int32): case supported_switch_pair(DataType::Int, DataType::Int32): - case supported_switch_pair(DataType::UInt, DataType::Int32): + case supported_switch_pair(DataType::Byte, DataType::Int32): + case supported_switch_pair(DataType::UInt16, DataType::Int32): case supported_switch_pair(DataType::UInt32, DataType::Int32): + case supported_switch_pair(DataType::UInt64, DataType::Int32): case supported_switch_pair(DataType::Float, DataType::Int32): case supported_switch_pair(DataType::Double, DataType::Int32): case supported_switch_pair(DataType::Bool, DataType::Int32): return "(int32_t)"; + case supported_switch_pair(DataType::Index, DataType::Int): + case supported_switch_pair(DataType::Char, DataType::Int): + case supported_switch_pair(DataType::Short, DataType::Int): + case supported_switch_pair(DataType::Int32, DataType::Int): + case supported_switch_pair(DataType::Byte, DataType::Int): + case supported_switch_pair(DataType::UInt16, DataType::Int): + case supported_switch_pair(DataType::UInt32, DataType::Int): + case supported_switch_pair(DataType::UInt64, DataType::Int): + case supported_switch_pair(DataType::Float, DataType::Int): + case supported_switch_pair(DataType::Double, DataType::Int): + case supported_switch_pair(DataType::Bool, DataType::Int): + return "(int64_t)"; + case supported_switch_pair(DataType::ComplexFloat, DataType::Char): + case supported_switch_pair(DataType::ComplexDouble, DataType::Char): + return "(int8_t)std::real"; + case supported_switch_pair(DataType::ComplexFloat, DataType::Short): + case supported_switch_pair(DataType::ComplexDouble, DataType::Short): + return "(int16_t)std::real"; case supported_switch_pair(DataType::ComplexFloat, DataType::Int32): case supported_switch_pair(DataType::ComplexDouble, DataType::Int32): return "(int32_t)std::real"; - case supported_switch_pair(DataType::Index, DataType::UInt): - case supported_switch_pair(DataType::Int, DataType::UInt): - case supported_switch_pair(DataType::Int32, DataType::UInt): - case supported_switch_pair(DataType::UInt32, DataType::UInt): - case supported_switch_pair(DataType::Float, DataType::UInt): - case supported_switch_pair(DataType::Double, DataType::UInt): - case supported_switch_pair(DataType::Bool, DataType::UInt): - return "(uint64_t)"; - case supported_switch_pair(DataType::ComplexFloat, DataType::UInt): - case supported_switch_pair(DataType::ComplexDouble, DataType::UInt): - return "(uint64_t)std::real"; + case supported_switch_pair(DataType::ComplexFloat, DataType::Int): + case supported_switch_pair(DataType::ComplexDouble, DataType::Int): + return "(int64_t)std::real"; + case supported_switch_pair(DataType::Index, DataType::Byte): + case supported_switch_pair(DataType::Char, DataType::Byte): + case supported_switch_pair(DataType::Short, DataType::Byte): + case supported_switch_pair(DataType::Int32, DataType::Byte): + case supported_switch_pair(DataType::Int, DataType::Byte): + case supported_switch_pair(DataType::UInt16, DataType::Byte): + case supported_switch_pair(DataType::UInt32, DataType::Byte): + case supported_switch_pair(DataType::UInt64, DataType::Byte): + case supported_switch_pair(DataType::Float, DataType::Byte): + case supported_switch_pair(DataType::Double, DataType::Byte): + case supported_switch_pair(DataType::Bool, DataType::Byte): + return "(uint8_t)"; + case supported_switch_pair(DataType::Index, DataType::UInt16): + case supported_switch_pair(DataType::Char, DataType::UInt16): + case supported_switch_pair(DataType::Short, DataType::UInt16): + case supported_switch_pair(DataType::Int32, DataType::UInt16): + case supported_switch_pair(DataType::Int, DataType::UInt16): + case supported_switch_pair(DataType::Byte, DataType::UInt16): + case supported_switch_pair(DataType::UInt32, DataType::UInt16): + case supported_switch_pair(DataType::UInt64, DataType::UInt16): + case supported_switch_pair(DataType::Float, DataType::UInt16): + case supported_switch_pair(DataType::Double, DataType::UInt16): + case supported_switch_pair(DataType::Bool, DataType::UInt16): + return "(uint16_t)"; case supported_switch_pair(DataType::Index, DataType::UInt32): - case supported_switch_pair(DataType::Int, DataType::UInt32): + case supported_switch_pair(DataType::Char, DataType::UInt32): + case supported_switch_pair(DataType::Short, DataType::UInt32): case supported_switch_pair(DataType::Int32, DataType::UInt32): - case supported_switch_pair(DataType::UInt, DataType::UInt32): + case supported_switch_pair(DataType::Int, DataType::UInt32): + case supported_switch_pair(DataType::Byte, DataType::UInt32): + case supported_switch_pair(DataType::UInt16, DataType::UInt32): + case supported_switch_pair(DataType::UInt64, DataType::UInt32): case supported_switch_pair(DataType::Float, DataType::UInt32): case supported_switch_pair(DataType::Double, DataType::UInt32): case supported_switch_pair(DataType::Bool, DataType::UInt32): return "(uint32_t)"; + case supported_switch_pair(DataType::Index, DataType::UInt64): + case supported_switch_pair(DataType::Char, DataType::UInt64): + case supported_switch_pair(DataType::Short, DataType::UInt64): + case supported_switch_pair(DataType::Int32, DataType::UInt64): + case supported_switch_pair(DataType::Int, DataType::UInt64): + case supported_switch_pair(DataType::Byte, DataType::UInt64): + case supported_switch_pair(DataType::UInt16, DataType::UInt64): + case supported_switch_pair(DataType::UInt32, DataType::UInt64): + case supported_switch_pair(DataType::Float, DataType::UInt64): + case supported_switch_pair(DataType::Double, DataType::UInt64): + case supported_switch_pair(DataType::Bool, DataType::UInt64): + return "(uint64_t)"; + case supported_switch_pair(DataType::ComplexFloat, DataType::Byte): + case supported_switch_pair(DataType::ComplexDouble, DataType::Byte): + return "(uint8_t)std::real"; + case supported_switch_pair(DataType::ComplexFloat, DataType::UInt16): + case supported_switch_pair(DataType::ComplexDouble, DataType::UInt16): + return "(uint16_t)std::real"; case supported_switch_pair(DataType::ComplexFloat, DataType::UInt32): case supported_switch_pair(DataType::ComplexDouble, DataType::UInt32): return "(uint32_t)std::real"; - case supported_switch_pair(DataType::Int, DataType::Index): + case supported_switch_pair(DataType::ComplexFloat, DataType::UInt64): + case supported_switch_pair(DataType::ComplexDouble, DataType::UInt64): + return "(uint64_t)std::real"; + case supported_switch_pair(DataType::Char, DataType::Index): + case supported_switch_pair(DataType::Short, DataType::Index): case supported_switch_pair(DataType::Int32, DataType::Index): - case supported_switch_pair(DataType::UInt, DataType::Index): + case supported_switch_pair(DataType::Int, DataType::Index): + case supported_switch_pair(DataType::Byte, DataType::Index): + case supported_switch_pair(DataType::UInt16, DataType::Index): case supported_switch_pair(DataType::UInt32, DataType::Index): + case supported_switch_pair(DataType::UInt64, DataType::Index): case supported_switch_pair(DataType::Float, DataType::Index): case supported_switch_pair(DataType::Double, DataType::Index): case supported_switch_pair(DataType::Bool, DataType::Index): @@ -932,10 +1024,14 @@ static const char* supported_casts2string(std::pair t) { case supported_switch_pair(DataType::ComplexDouble, DataType::Index): return "(nvfuser_index_t)std::real"; case supported_switch_pair(DataType::Index, DataType::Double): - case supported_switch_pair(DataType::Int, DataType::Double): + case supported_switch_pair(DataType::Char, DataType::Double): + case supported_switch_pair(DataType::Short, DataType::Double): case supported_switch_pair(DataType::Int32, DataType::Double): - case supported_switch_pair(DataType::UInt, DataType::Double): + case supported_switch_pair(DataType::Int, DataType::Double): + case supported_switch_pair(DataType::Byte, DataType::Double): + case supported_switch_pair(DataType::UInt16, DataType::Double): case supported_switch_pair(DataType::UInt32, DataType::Double): + case supported_switch_pair(DataType::UInt64, DataType::Double): case supported_switch_pair(DataType::Float, DataType::Double): case supported_switch_pair(DataType::Bool, DataType::Double): return "(double)"; @@ -945,29 +1041,41 @@ static const char* supported_casts2string(std::pair t) { case supported_switch_pair(DataType::Float, DataType::Bool): case supported_switch_pair(DataType::Double, DataType::Bool): case supported_switch_pair(DataType::Index, DataType::Bool): - case supported_switch_pair(DataType::Int, DataType::Bool): + case supported_switch_pair(DataType::Char, DataType::Bool): + case supported_switch_pair(DataType::Short, DataType::Bool): case supported_switch_pair(DataType::Int32, DataType::Bool): - case supported_switch_pair(DataType::UInt, DataType::Bool): + case supported_switch_pair(DataType::Int, DataType::Bool): + case supported_switch_pair(DataType::Byte, DataType::Bool): + case supported_switch_pair(DataType::UInt16, DataType::Bool): case supported_switch_pair(DataType::UInt32, DataType::Bool): + case supported_switch_pair(DataType::UInt64, DataType::Bool): return "(bool)"; case supported_switch_pair(DataType::ComplexFloat, DataType::Bool): case supported_switch_pair(DataType::ComplexDouble, DataType::Bool): return "(bool)std::real"; case supported_switch_pair(DataType::Index, DataType::ComplexDouble): - case supported_switch_pair(DataType::Int, DataType::ComplexDouble): + case supported_switch_pair(DataType::Char, DataType::ComplexDouble): + case supported_switch_pair(DataType::Short, DataType::ComplexDouble): case supported_switch_pair(DataType::Int32, DataType::ComplexDouble): - case supported_switch_pair(DataType::UInt, DataType::ComplexDouble): + case supported_switch_pair(DataType::Int, DataType::ComplexDouble): + case supported_switch_pair(DataType::Byte, DataType::ComplexDouble): + case supported_switch_pair(DataType::UInt16, DataType::ComplexDouble): case supported_switch_pair(DataType::UInt32, DataType::ComplexDouble): + case supported_switch_pair(DataType::UInt64, DataType::ComplexDouble): case supported_switch_pair(DataType::Double, DataType::ComplexDouble): case supported_switch_pair(DataType::Float, DataType::ComplexDouble): case supported_switch_pair(DataType::Bool, DataType::ComplexDouble): case supported_switch_pair(DataType::ComplexFloat, DataType::ComplexDouble): return "(std::complex)"; case supported_switch_pair(DataType::Index, DataType::ComplexFloat): - case supported_switch_pair(DataType::Int, DataType::ComplexFloat): + case supported_switch_pair(DataType::Char, DataType::ComplexFloat): + case supported_switch_pair(DataType::Short, DataType::ComplexFloat): case supported_switch_pair(DataType::Int32, DataType::ComplexFloat): - case supported_switch_pair(DataType::UInt, DataType::ComplexFloat): + case supported_switch_pair(DataType::Int, DataType::ComplexFloat): + case supported_switch_pair(DataType::Byte, DataType::ComplexFloat): + case supported_switch_pair(DataType::UInt16, DataType::ComplexFloat): case supported_switch_pair(DataType::UInt32, DataType::ComplexFloat): + case supported_switch_pair(DataType::UInt64, DataType::ComplexFloat): case supported_switch_pair(DataType::Double, DataType::ComplexFloat): case supported_switch_pair(DataType::Float, DataType::ComplexFloat): case supported_switch_pair(DataType::Bool, DataType::ComplexFloat): @@ -978,10 +1086,14 @@ static const char* supported_casts2string(std::pair t) { return "__float2half"; case supported_switch_pair(DataType::Double, DataType::Half): return "__double2half"; - case supported_switch_pair(DataType::Int, DataType::Half): + case supported_switch_pair(DataType::Char, DataType::Half): + case supported_switch_pair(DataType::Short, DataType::Half): case supported_switch_pair(DataType::Int32, DataType::Half): - case supported_switch_pair(DataType::UInt, DataType::Half): + case supported_switch_pair(DataType::Int, DataType::Half): + case supported_switch_pair(DataType::Byte, DataType::Half): + case supported_switch_pair(DataType::UInt16, DataType::Half): case supported_switch_pair(DataType::UInt32, DataType::Half): + case supported_switch_pair(DataType::UInt64, DataType::Half): case supported_switch_pair(DataType::Index, DataType::Half): return "__int2half"; case supported_switch_pair(DataType::Bool, DataType::Half): @@ -994,13 +1106,17 @@ static const char* supported_casts2string(std::pair t) { return "__half2float"; case supported_switch_pair(DataType::Half, DataType::Double): return "__half2double"; + case supported_switch_pair(DataType::Half, DataType::Char): + case supported_switch_pair(DataType::Half, DataType::Short): case supported_switch_pair(DataType::Half, DataType::Int32): return "__half2int32"; case supported_switch_pair(DataType::Half, DataType::Int): return "__half2int"; + case supported_switch_pair(DataType::Half, DataType::Byte): + case supported_switch_pair(DataType::Half, DataType::UInt16): case supported_switch_pair(DataType::Half, DataType::UInt32): return "__half2uint32"; - case supported_switch_pair(DataType::Half, DataType::UInt): + case supported_switch_pair(DataType::Half, DataType::UInt64): return "__half2uint"; case supported_switch_pair(DataType::Half, DataType::Index): return "__half2index"; @@ -1017,10 +1133,14 @@ static const char* supported_casts2string(std::pair t) { return "__double2bfloat"; case supported_switch_pair(DataType::Half, DataType::BFloat16): return "__half2bfloat"; - case supported_switch_pair(DataType::Int, DataType::BFloat16): + case supported_switch_pair(DataType::Char, DataType::BFloat16): + case supported_switch_pair(DataType::Short, DataType::BFloat16): case supported_switch_pair(DataType::Int32, DataType::BFloat16): - case supported_switch_pair(DataType::UInt, DataType::BFloat16): + case supported_switch_pair(DataType::Int, DataType::BFloat16): + case supported_switch_pair(DataType::Byte, DataType::BFloat16): + case supported_switch_pair(DataType::UInt16, DataType::BFloat16): case supported_switch_pair(DataType::UInt32, DataType::BFloat16): + case supported_switch_pair(DataType::UInt64, DataType::BFloat16): case supported_switch_pair(DataType::Index, DataType::BFloat16): return "__int2bfloat"; case supported_switch_pair(DataType::Bool, DataType::BFloat16): @@ -1035,13 +1155,17 @@ static const char* supported_casts2string(std::pair t) { return "__bfloat2double"; case supported_switch_pair(DataType::BFloat16, DataType::Half): return "__bfloat2half"; + case supported_switch_pair(DataType::BFloat16, DataType::Char): + case supported_switch_pair(DataType::BFloat16, DataType::Short): case supported_switch_pair(DataType::BFloat16, DataType::Int32): return "__bfloat2int32"; case supported_switch_pair(DataType::BFloat16, DataType::Int): return "__bfloat2int"; + case supported_switch_pair(DataType::BFloat16, DataType::Byte): + case supported_switch_pair(DataType::BFloat16, DataType::UInt16): case supported_switch_pair(DataType::BFloat16, DataType::UInt32): return "__bfloat2uint32"; - case supported_switch_pair(DataType::BFloat16, DataType::UInt): + case supported_switch_pair(DataType::BFloat16, DataType::UInt64): return "__bfloat2uint"; case supported_switch_pair(DataType::BFloat16, DataType::Index): return "__bfloat2index"; @@ -1107,10 +1231,22 @@ DataType aten_to_data_type(const at::ScalarType& scalar_type) { return DataType::Float8_e4m3fn; case at::ScalarType::Float8_e5m2: return DataType::Float8_e5m2; - case at::ScalarType::Long: - return DataType::Int; + case at::ScalarType::Char: + return DataType::Char; + case at::ScalarType::Short: + return DataType::Short; case at::ScalarType::Int: return DataType::Int32; + case at::ScalarType::Long: + return DataType::Int; + case at::ScalarType::Byte: + return DataType::Byte; + case at::ScalarType::UInt16: + return DataType::UInt16; + case at::ScalarType::UInt32: + return DataType::UInt32; + case at::ScalarType::UInt64: + return DataType::UInt64; case at::ScalarType::ComplexFloat: return DataType::ComplexFloat; case at::ScalarType::ComplexDouble: @@ -1136,16 +1272,28 @@ at::ScalarType data_type_to_aten(const DataType& data_type) { return at::ScalarType::Float8_e4m3fn; case DataType::Float8_e5m2: return at::ScalarType::Float8_e5m2; - case DataType::Int: - return at::ScalarType::Long; case DataType::Index: NVF_THROW( "Index is determined at compile time,", " to convert from an aten type you need to have the compiled information. ", "This information is passed to GpuLower at compile time, and then copied to kerned.", "There's also this information in FusionExecutorCache and the Registry system."); + case DataType::Char: + return at::ScalarType::Char; + case DataType::Short: + return at::ScalarType::Short; case DataType::Int32: return at::ScalarType::Int; + case DataType::Int: + return at::ScalarType::Long; + case DataType::Byte: + return at::ScalarType::Byte; + case DataType::UInt16: + return at::ScalarType::UInt16; + case DataType::UInt32: + return at::ScalarType::UInt32; + case DataType::UInt64: + return at::ScalarType::UInt64; case DataType::ComplexFloat: return at::ScalarType::ComplexFloat; case DataType::ComplexDouble: @@ -1371,7 +1519,7 @@ std::string typePrefix(const DataType data_type) { case DataType::Index: case DataType::Int: case DataType::Int32: - case DataType::UInt: + case DataType::UInt64: case DataType::UInt32: case DataType::SMemAddress: return "i"; diff --git a/csrc/type.h b/csrc/type.h index 265f1a939ee..13e84ca1d98 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -72,10 +72,14 @@ enum class PrimDataType { Float8_e4m3fn, Float8_e5m2, // Integral types - Int, + Char, + Short, Int32, - UInt, + Int, + Byte, // Following ATen convention + UInt16, // Following ATen convention UInt32, + UInt64, Index, // Boolean types Bool, @@ -178,11 +182,15 @@ struct DataType { static constexpr PrimDataType Half = PrimDataType::Half; static constexpr PrimDataType Float8_e4m3fn = PrimDataType::Float8_e4m3fn; static constexpr PrimDataType Float8_e5m2 = PrimDataType::Float8_e5m2; - static constexpr PrimDataType Int = PrimDataType::Int; static constexpr PrimDataType Index = PrimDataType::Index; + static constexpr PrimDataType Char = PrimDataType::Char; + static constexpr PrimDataType Short = PrimDataType::Short; static constexpr PrimDataType Int32 = PrimDataType::Int32; - static constexpr PrimDataType UInt = PrimDataType::UInt; + static constexpr PrimDataType Int = PrimDataType::Int; + static constexpr PrimDataType Byte = PrimDataType::Byte; + static constexpr PrimDataType UInt16 = PrimDataType::UInt16; static constexpr PrimDataType UInt32 = PrimDataType::UInt32; + static constexpr PrimDataType UInt64 = PrimDataType::UInt64; static constexpr PrimDataType Bool = PrimDataType::Bool; static constexpr PrimDataType BFloat16 = PrimDataType::BFloat16; static constexpr PrimDataType ComplexFloat = PrimDataType::ComplexFloat; @@ -262,10 +270,14 @@ inline bool isIntegralType(DataType dtype) { if constexpr (std::is_same_v) { switch (dtype) { case DataType::Index: + case DataType::Char: + case DataType::Short: case DataType::Int: case DataType::Int32: - case DataType::UInt: + case DataType::Byte: + case DataType::UInt16: case DataType::UInt32: + case DataType::UInt64: return true; default: return false; @@ -278,7 +290,8 @@ inline bool isIntegralType(DataType dtype) { // Returns if the datatype is an unsigned integer type inline bool isUnsignedIntegralType(DataType dtype) { - return dtype == DataType::UInt || dtype == DataType::UInt32; + return dtype == DataType::Byte || dtype == DataType::UInt16 || + dtype == DataType::UInt32 || dtype == DataType::UInt64; } // Returns if the datatype is a pointer type @@ -386,15 +399,37 @@ DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( at::ScalarType::Float8_e5m2, at::Float8_e5m2); DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( - DataType::Int, - at::ScalarType::Long, - int64_t); + DataType::Char, + at::ScalarType::Char, + int8_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::Short, + at::ScalarType::Short, + int16_t); DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( DataType::Int32, at::ScalarType::Int, int); -DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::UInt, uint64_t); -DEFINE_DATATYPE_TO_NATIVE_TYPE(DataType::UInt32, uint32_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::Int, + at::ScalarType::Long, + int64_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::Byte, + at::ScalarType::Byte, + uint8_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::UInt16, + at::ScalarType::UInt16, + uint16_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::UInt32, + at::ScalarType::UInt32, + uint32_t); +DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( + DataType::UInt64, + at::ScalarType::UInt64, + uint64_t); DEFINE_DATATYPE_TO_ATEN_AND_NATIVE_TYPE( DataType::Bool, at::ScalarType::Bool, @@ -1006,14 +1041,22 @@ constexpr inline size_t primDataTypeSize(PrimDataType type) { return sizeof(at::Float8_e5m2); case DataType::Index: NVF_THROW("The actual type of Index is only known at compile time."); - case DataType::Int: - return sizeof(int64_t); + case DataType::Char: + return sizeof(int8_t); + case DataType::Short: + return sizeof(int16_t); case DataType::Int32: return sizeof(int32_t); - case DataType::UInt: - return sizeof(uint64_t); + case DataType::Int: + return sizeof(int64_t); + case DataType::Byte: + return sizeof(uint8_t); + case DataType::UInt16: + return sizeof(uint16_t); case DataType::UInt32: return sizeof(uint32_t); + case DataType::UInt64: + return sizeof(uint64_t); case DataType::SMemAddress: return sizeof(unsigned); default: diff --git a/csrc/validator_utils.cpp b/csrc/validator_utils.cpp index bf75621ef0f..45cb40c049c 100644 --- a/csrc/validator_utils.cpp +++ b/csrc/validator_utils.cpp @@ -141,8 +141,14 @@ std::pair getTolerance( return {abs_tol * 10.0, abs_tol * 0.01 * 10.0}; } } - case DataType::Int: + case DataType::Char: + case DataType::Short: case DataType::Int32: + case DataType::Int: + case DataType::Byte: + case DataType::UInt16: + case DataType::UInt32: + case DataType::UInt64: case DataType::Index: case DataType::Bool: return {0.0, 0.0}; diff --git a/tests/cpp/test_gpu3.cpp b/tests/cpp/test_gpu3.cpp index 0d6632616ef..4281339994a 100644 --- a/tests/cpp/test_gpu3.cpp +++ b/tests/cpp/test_gpu3.cpp @@ -4534,10 +4534,14 @@ TEST_F(NVFuserTest, FusionCastings_CUDA) { DataType::Double, DataType::Float, DataType::Half, - DataType::Int, + DataType::Char, + DataType::Short, DataType::Int32, - DataType::UInt, + DataType::Int, + DataType::Byte, + DataType::UInt16, DataType::UInt32, + DataType::UInt64, DataType::Bool, DataType::ComplexFloat, DataType::ComplexDouble}; @@ -4548,32 +4552,12 @@ TEST_F(NVFuserTest, FusionCastings_CUDA) { } #endif - // ATen does not support uint32_t and uint64_t as dtype, so we need to - // use int32_t and int64_t as a proxy for these two types. - auto convert_aten_unsupported_dtype = [](DataType dt) -> DataType { - if (dt == DataType::UInt) { - return DataType::Int; - } else if (dt == DataType::UInt32) { - return DataType::Int32; - } - return dt; - }; - for (const auto& input_type : data_types) { - DataType proxy_input_type = convert_aten_unsupported_dtype(input_type); - auto tv_in = makeContigTensor(2, proxy_input_type); + auto tv_in = makeContigTensor(2, input_type); fusion.addInput(tv_in); - if (proxy_input_type != input_type) { - tv_in = bitCastOp(input_type, tv_in); - } - for (const auto& output_type : data_types) { - DataType proxy_output_type = convert_aten_unsupported_dtype(output_type); auto tv_out = castOp(output_type, tv_in); - if (proxy_output_type != output_type) { - tv_out = bitCastOp(proxy_output_type, tv_out); - } fusion.addOutput(tv_out); } } @@ -4583,16 +4567,14 @@ TEST_F(NVFuserTest, FusionCastings_CUDA) { std::vector inputs; std::vector outputs; for (const auto& input_type : data_types) { - DataType proxy_input_type = convert_aten_unsupported_dtype(input_type); at::Tensor t = at::randn({x, y}, options) .relu() // Discard negative numbers so that signed and // unsigned types are equivalent. There is no way // to represent unsigned numbers in PyTorch. - .to(data_type_to_aten(proxy_input_type)); + .to(data_type_to_aten(input_type)); inputs.emplace_back(t); for (const auto& output_type : data_types) { - DataType proxy_output_type = convert_aten_unsupported_dtype(output_type); - outputs.emplace_back(t.to(data_type_to_aten(proxy_output_type))); + outputs.emplace_back(t.to(data_type_to_aten(output_type))); } } diff --git a/tests/cpp/test_mbarrier.cpp b/tests/cpp/test_mbarrier.cpp index f7f9611d895..e643e1a2dc3 100644 --- a/tests/cpp/test_mbarrier.cpp +++ b/tests/cpp/test_mbarrier.cpp @@ -62,7 +62,7 @@ TEST_F(MBarrierTest, Simple) { summary.dynamic_smem_allocations; ASSERT_EQ(dynamic_smem_allocations.size(), 1); - TensorView* mbarrier = makeContigConcreteTensor({}, DataType::UInt); + TensorView* mbarrier = makeContigConcreteTensor({}, DataType::UInt64); mbarrier->setMemoryType(MemoryType::Shared); kir::Allocate* mbarrier_alloc = IrBuilder::create(mbarrier, MemoryType::Shared); @@ -107,7 +107,7 @@ TEST_F(MBarrierTest, Simple) { return expr->isA(); }); ASSERT_NE(sync_it, top_level_exprs.end()); - auto state = IrBuilder::create(DataType::UInt); + auto state = IrBuilder::create(DataType::UInt64); auto alloc_state = IrBuilder::create( state, MemoryType::Local, kernel->oneVal()); auto arrive = IrBuilder::create(state, mbarrier_index);