Skip to content

Commit

Permalink
Add packed 8bit builtin types (#5939)
Browse files Browse the repository at this point in the history
* Add packed bytes builtin type

* fix test
  • Loading branch information
fairywreath authored Dec 27, 2024
1 parent 2ad1f81 commit 7cecc51
Show file tree
Hide file tree
Showing 26 changed files with 273 additions and 0 deletions.
29 changes: 29 additions & 0 deletions source/slang-core-module/slang-embedded-core-module-source.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ enum BaseTypeConversionRank : uint8_t
kBaseTypeConversionRank_Int32,
kBaseTypeConversionRank_IntPtr,
kBaseTypeConversionRank_Int64,

// Packed type conversion ranks where the overall rank order does not apply.
// They must be explicitly casted to another type.
kBaseTypeConversionRank_Int8x4Packed,
kBaseTypeConversionRank_UInt8x4Packed,

kBaseTypeConversionRank_Error,
};

Expand Down Expand Up @@ -150,6 +156,16 @@ static const BaseTypeConversionInfo kBaseTypes[] = {
kBaseTypeConversionKind_Unsigned,
kBaseTypeConversionRank_IntPtr},

{"int8_t4_packed",
BaseType::Int8x4Packed,
0,
kBaseTypeConversionKind_Unsigned,
kBaseTypeConversionRank_Int8x4Packed},
{"uint8_t4_packed",
BaseType::UInt8x4Packed,
0,
kBaseTypeConversionKind_Unsigned,
kBaseTypeConversionRank_UInt8x4Packed},
};

void Session::finalizeSharedASTBuilder()
Expand All @@ -176,6 +192,11 @@ void Session::finalizeSharedASTBuilder()
globalAstBuilder->getBuiltinType(baseType.tag);
}

static bool isConversionRankPackedType(BaseTypeConversionRank rank)
{
return (rank == BaseTypeConversionRank::kBaseTypeConversionRank_Int8x4Packed) ||
(rank == BaseTypeConversionRank::kBaseTypeConversionRank_UInt8x4Packed);
}

// Given two base types, we need to be able to compute the cost of converting between them.
ConversionCost getBaseTypeConversionCost(
Expand All @@ -189,6 +210,14 @@ ConversionCost getBaseTypeConversionCost(
return kConversionCost_None;
}

// Handle special case for packed types, where they must be explicitly casted to another type.
bool isToPackedType = isConversionRankPackedType(toInfo.conversionRank);
bool isFromPackedType = isConversionRankPackedType(fromInfo.conversionRank);
if (isToPackedType || isFromPackedType)
{
return kConversionCost_GeneralConversion;
}

// Conversions within the same kind are easist to handle
if (toInfo.conversionKind == fromInfo.conversionKind)
{
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-check-conversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -699,6 +699,8 @@ int getTypeBitSize(Type* t)
return 16;
case BaseType::Int:
case BaseType::UInt:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
return 32;
case BaseType::Int64:
case BaseType::UInt64:
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1939,6 +1939,8 @@ void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
case BaseType::UInt:
case BaseType::UInt64:
case BaseType::UIntPtr:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
break;
default:
getSink()->diagnose(varDecl, Diagnostics::staticConstRequirementMustBeIntOrBool);
Expand Down
9 changes: 9 additions & 0 deletions source/slang/slang-emit-c-like.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,11 @@ void CLikeSourceEmitter::emitSimpleType(IRType* type)
case kIROp_UIntPtrType:
return UnownedStringSlice("uintptr_t");

case kIROp_Int8x4PackedType:
return UnownedStringSlice("int8_t4_packed");
case kIROp_UInt8x4PackedType:
return UnownedStringSlice("uint8_t4_packed");

case kIROp_HalfType:
return UnownedStringSlice("half");

Expand Down Expand Up @@ -1272,6 +1277,8 @@ void CLikeSourceEmitter::emitSimpleValueImpl(IRInst* inst)
return;
}
case BaseType::UInt:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
{
m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
m_writer->emit("U");
Expand Down Expand Up @@ -3896,6 +3903,8 @@ void CLikeSourceEmitter::emitVecNOrScalar(
m_writer->emit("ushort");
break;
case kIROp_UIntType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
m_writer->emit("uint");
break;
case kIROp_UInt64Type:
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit-cpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ static const char s_xyzwNames[] = "xyzw";
case kIROp_UIntPtrType:
return UnownedStringSlice("uintptr_t");

case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
return UnownedStringSlice("uint32_t");

// Not clear just yet how we should handle half... we want all processing as float
// probly, but when reading/writing to memory converting

Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit-cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,10 @@ UnownedStringSlice CUDASourceEmitter::getBuiltinTypeName(IROp op)
case kIROp_UIntPtrType:
return UnownedStringSlice("uint");
#endif
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
return UnownedStringSlice("uint");

case kIROp_HalfType:
return UnownedStringSlice("__half");

Expand Down
16 changes: 16 additions & 0 deletions source/slang/slang-emit-glsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,6 +1236,8 @@ void GLSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
return;
}
case BaseType::UInt:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
{
m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
m_writer->emit("U");
Expand Down Expand Up @@ -1984,6 +1986,8 @@ bool GLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
break;

case BaseType::UInt:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
if (fromType == BaseType::Float)
{
m_writer->emit("floatBitsToUint");
Expand Down Expand Up @@ -3050,6 +3054,18 @@ void GLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
#endif
return;
}
case kIROp_Int8x4PackedType:
{
_requireBaseType(BaseType::Int8x4Packed);
m_writer->emit("uint");
return;
}
case kIROp_UInt8x4PackedType:
{
_requireBaseType(BaseType::UInt8x4Packed);
m_writer->emit("uint");
return;
}
case kIROp_VoidType:
case kIROp_BoolType:
case kIROp_Int8Type:
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-emit-hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,8 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
case BaseType::UInt64:
case BaseType::UIntPtr:
case BaseType::Bool:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
// Because the intermediate type will always
// be an integer type, we can convert to
// another integer type of the same size
Expand Down Expand Up @@ -861,6 +863,8 @@ bool HLSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
case BaseType::UInt:
case BaseType::Int:
case BaseType::Bool:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
break;
case BaseType::UInt16:
case BaseType::Int16:
Expand Down Expand Up @@ -1193,6 +1197,8 @@ void HLSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_Int16Type:
case kIROp_UInt16Type:
case kIROp_HalfType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
{
m_writer->emit(getDefaultBuiltinTypeName(type->getOp()));
return;
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1076,6 +1076,10 @@ void MetalSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_UIntPtrType:
m_writer->emit("ulong");
return;
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
m_writer->emit("uint");
return;
case kIROp_StructType:
m_writer->emit(getName(type));
return;
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1465,6 +1465,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_Int8Type:
case kIROp_IntType:
case kIROp_Int64Type:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
{
const IntInfo i = getIntTypeInfo(as<IRType>(inst));
if (i.width == 16)
Expand Down Expand Up @@ -7366,6 +7368,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
case kIROp_UInt64Type:
case kIROp_UInt8Type:
case kIROp_UIntPtrType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
spvEncoding = 6; // Unsigned
break;
case kIROp_FloatType:
Expand Down
6 changes: 6 additions & 0 deletions source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,10 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
case kIROp_UIntPtrType:
m_writer->emit("u64");
return;
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
m_writer->emit("u32");
return;
case kIROp_StructType:
m_writer->emit(getName(type));
return;
Expand Down Expand Up @@ -940,6 +944,8 @@ void WGSLSourceEmitter::emitSimpleValueImpl(IRInst* inst)
return;
}
case BaseType::UInt:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
{
m_writer->emit("u32(");
m_writer->emit(UInt(uint32_t(litInst->value.intVal)));
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-ir-any-value-marshalling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ struct AnyValueMarshallingContext
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
case kIROp_PtrType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
context->marshalBasicType(builder, dataType, concreteTypedVar);
break;
case kIROp_VectorType:
Expand Down Expand Up @@ -309,6 +311,8 @@ struct AnyValueMarshallingContext
break;
}
case kIROp_UIntType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
#if SLANG_PTR_IS_32
case kIROp_UIntPtrType:
#endif
Expand Down Expand Up @@ -537,6 +541,8 @@ struct AnyValueMarshallingContext
break;
}
case kIROp_UIntType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
{
ensureOffsetAt4ByteBoundary();
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
Expand Down Expand Up @@ -812,6 +818,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
case kIROp_FloatType:
case kIROp_UIntType:
case kIROp_BoolType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
return alignUp(offset, 4) + 4;
case kIROp_UInt64Type:
case kIROp_Int64Type:
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-byte-address-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -840,6 +840,8 @@ struct ByteAddressBufferLegalizationContext
case kIROp_IntType:
case kIROp_FloatType:
case kIROp_BoolType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
// The basic 32-bit types (and `bool`) can be handled by
// loading `uint` values and then bit-casting.
//
Expand Down
3 changes: 3 additions & 0 deletions source/slang/slang-ir-layout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ static Result _calcSizeAndAlignment(
BASE(UIntPtr, kPointerSize);
BASE(Double, 8);

BASE(Int8x4Packed, 4);
BASE(UInt8x4Packed, 4);

// We are currently handling `bool` following the HLSL
// precednet of storing it in 4 bytes.
//
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-ir-lower-bit-cast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,8 @@ struct BitCastLoweringContext
case kIROp_UIntType:
case kIROp_FloatType:
case kIROp_BoolType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
#if SLANG_PTR_IS_32
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
Expand Down
14 changes: 14 additions & 0 deletions source/slang/slang-ir-util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ IROp getTypeStyle(IROp op)
case kIROp_UInt64Type:
case kIROp_IntPtrType:
case kIROp_UIntPtrType:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
{
// All int like
return kIROp_IntType;
Expand Down Expand Up @@ -140,6 +142,8 @@ IROp getTypeStyle(BaseType op)
case BaseType::UInt:
case BaseType::UInt64:
case BaseType::UIntPtr:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
return kIROp_IntType;
case BaseType::Half:
case BaseType::Float:
Expand Down Expand Up @@ -445,6 +449,12 @@ void getTypeNameHint(StringBuilder& sb, IRInst* type)
case kIROp_UIntPtrType:
sb << "uintptr";
break;
case kIROp_Int8x4PackedType:
sb << "int8_t4_packed";
break;
case kIROp_UInt8x4PackedType:
sb << "uint8_t4_packed";
break;
case kIROp_CharType:
sb << "char";
break;
Expand Down Expand Up @@ -1735,6 +1745,10 @@ UnownedStringSlice getBasicTypeNameHint(IRType* basicType)
return UnownedStringSlice::fromLiteral("uint64");
case kIROp_UIntPtrType:
return UnownedStringSlice::fromLiteral("uintptr");
case kIROp_Int8x4PackedType:
return UnownedStringSlice::fromLiteral("int8_t4_packed");
case kIROp_UInt8x4PackedType:
return UnownedStringSlice::fromLiteral("uint8_t4_packed");
case kIROp_FloatType:
return UnownedStringSlice::fromLiteral("float");
case kIROp_HalfType:
Expand Down
8 changes: 8 additions & 0 deletions source/slang/slang-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3798,6 +3798,8 @@ IRInst* IRBuilder::emitDefaultConstruct(IRType* type, bool fallback)
case kIROp_UIntType:
case kIROp_UIntPtrType:
case kIROp_UInt64Type:
case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
case kIROp_CharType:
return getIntValue(type, 0);
case kIROp_BoolType:
Expand Down Expand Up @@ -7421,6 +7423,8 @@ bool isIntegralType(IRType* t)
case BaseType::UInt64:
case BaseType::IntPtr:
case BaseType::UIntPtr:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
return true;
default:
return false;
Expand Down Expand Up @@ -7467,6 +7471,10 @@ IntInfo getIntTypeInfo(const IRType* intType)
case kIROp_Int64Type:
return {64, true};

case kIROp_Int8x4PackedType:
case kIROp_UInt8x4PackedType:
return {32, false};

case kIROp_IntPtrType: // target platform dependent
case kIROp_UIntPtrType: // target platform dependent
default:
Expand Down
2 changes: 2 additions & 0 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4551,6 +4551,8 @@ struct ExprLoweringVisitorBase : public ExprVisitor<Derived, LoweredValInfo>
case BaseType::UInt64:
case BaseType::UIntPtr:
case BaseType::IntPtr:
case BaseType::Int8x4Packed:
case BaseType::UInt8x4Packed:
return LoweredValInfo::simple(getBuilder()->getIntValue(type, 0));

case BaseType::Half:
Expand Down
5 changes: 5 additions & 0 deletions source/slang/slang-mangle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,11 @@ void emitBaseType(ManglingContext* context, BaseType baseType)
case BaseType::IntPtr:
emitRaw(context, "ip");
break;
case BaseType::Int8x4Packed:
emitRaw(context, "c4p");
break;
case BaseType::UInt8x4Packed:
emitRaw(context, "C4p");
break;

default:
Expand Down
Loading

0 comments on commit 7cecc51

Please sign in to comment.