Skip to content

Commit

Permalink
Add raypayload decoration to ray payload structs
Browse files Browse the repository at this point in the history
Closes #6104
  • Loading branch information
expipiplus1 committed Jan 23, 2025
1 parent a9ce752 commit aa8bef7
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 12 deletions.
15 changes: 11 additions & 4 deletions source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
Expand Up @@ -16105,6 +16105,13 @@ __generic<T>
__intrinsic_op($(kIROp_ForceVarIntoStructTemporarily))
Ref<T> __forceVarIntoStructTemporarily(inout T maybeStruct);

// Some functions require a struct type which is decorated with a [raypayload]
// attribute. This will do the same as __forceVarIntoStructTemporarily and also
// ensure that the struct type in question is decorated appropriately.
__generic<T>
__intrinsic_op($(kIROp_ForceVarIntoRayPayloadStructTemporarily))
Ref<T> __forceVarIntoRayPayloadStructTemporarily(inout T maybeStruct);

__generic<payload_t>
[require(hlsl, raytracing)]
void __traceRayHLSL(
Expand Down Expand Up @@ -16189,7 +16196,7 @@ void TraceRay(
MultiplierForGeometryContributionToHitGroupIndex,
MissShaderIndex,
Ray,
__forceVarIntoStructTemporarily(Payload));
__forceVarIntoRayPayloadStructTemporarily(Payload));
return;
case cuda: __intrinsic_asm "traceOptiXRay";
case glsl:
Expand Down Expand Up @@ -16327,7 +16334,7 @@ void TraceMotionRay(
MissShaderIndex,
Ray,
CurrentTime,
__forceVarIntoStructTemporarily(Payload));
__forceVarIntoRayPayloadStructTemporarily(Payload));
return;
case glsl:
{
Expand Down Expand Up @@ -18471,7 +18478,7 @@ struct HitObject
MultiplierForGeometryContributionToHitGroupIndex,
MissShaderIndex,
Ray,
__forceVarIntoStructTemporarily(Payload),
__forceVarIntoRayPayloadStructTemporarily(Payload),
hitObj);
return hitObj;
}
Expand Down Expand Up @@ -18564,7 +18571,7 @@ struct HitObject
MissShaderIndex,
Ray,
CurrentTime,
__forceVarIntoStructTemporarily(Payload));
__forceVarIntoRayPayloadStructTemporarily(Payload));
case glsl:
{
[__vulkanRayPayload]
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-emit-hlsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1579,6 +1579,10 @@ void HLSLSourceEmitter::emitPostKeywordTypeAttributesImpl(IRInst* inst)
{
m_writer->emit("[payload] ");
}
if (const auto payloadDecoration = inst->findDecoration<IRRayPayloadDecoration>())
{
m_writer->emit("[raypayload] ");
}
}

void HLSLSourceEmitter::_emitPrefixTypeAttr(IRAttr* attr)
Expand Down
25 changes: 17 additions & 8 deletions source/slang/slang-ir-hlsl-legalize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,20 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
for (UInt i = 0; i < call->getArgCount(); i++)
{
auto arg = call->getArg(i);
if (arg->getOp() != kIROp_ForceVarIntoStructTemporarily)
const bool isForcedStruct = arg->getOp() == kIROp_ForceVarIntoStructTemporarily;
const bool isForcedRayPayloadStruct =
arg->getOp() == kIROp_ForceVarIntoRayPayloadStructTemporarily;
if (!(isForcedStruct || isForcedRayPayloadStruct))
continue;
auto forceStructArg = arg->getOperand(0);
auto forceStructBaseType =
as<IRType>(forceStructArg->getDataType()->getOperand(0));
IRBuilder builder(call);
if (forceStructBaseType->getOp() == kIROp_StructType)
{
call->setArg(i, arg->getOperand(0));
if (isForcedRayPayloadStruct)
builder.addRayPayloadDecoration(forceStructBaseType);
continue;
}

Expand All @@ -47,14 +53,19 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
// `__forceVarIntoStructTemporarily` is a parameter to a side effect type
// (`ref`, `out`, `inout`) we copy the struct back into our original non-struct
// parameter.
IRBuilder builder(call);

const auto typeNameHint = isForcedRayPayloadStruct
? "RayPayload_t"
: "ForceVarIntoStructTemporarily_t";
const auto varNameHint =
isForcedRayPayloadStruct ? "rayPayload" : "forceVarIntoStructTemporarily";

builder.setInsertBefore(call->getCallee());
auto structType = builder.createStructType();
StringBuilder structName;
builder.addNameHintDecoration(
structType,
UnownedStringSlice("ForceVarIntoStructTemporarily_t"));
builder.addNameHintDecoration(structType, UnownedStringSlice(typeNameHint));
if (isForcedRayPayloadStruct)
builder.addRayPayloadDecoration(structType);

auto elementBufferKey = builder.createStructKey();
builder.addNameHintDecoration(elementBufferKey, UnownedStringSlice("data"));
Expand All @@ -65,9 +76,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in

builder.setInsertBefore(call);
auto structVar = builder.emitVar(structType);
builder.addNameHintDecoration(
structVar,
UnownedStringSlice("forceVarIntoStructTemporarily"));
builder.addNameHintDecoration(structVar, UnownedStringSlice(varNameHint));
builder.emitStore(
builder.emitFieldAddress(
builder.getPtrType(_dataField->getFieldType()),
Expand Down
4 changes: 4 additions & 0 deletions source/slang/slang-ir-inst-defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,9 @@ INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE)
INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE)

INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0)
INST(ForceVarIntoRayPayloadStructTemporarily, ForceVarIntoRayPayloadStructTemporarily, 1, 0)
INST_RANGE(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, ForceVarIntoRayPayloadStructTemporarily)

INST(MetalAtomicCast, MetalAtomicCast, 1, 0)

INST(IsTextureAccess, IsTextureAccess, 1, 0)
Expand Down Expand Up @@ -982,6 +985,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
INST(GLSLLocationDecoration, glslLocation, 1, 0)
INST(GLSLOffsetDecoration, glslOffset, 1, 0)
INST(PayloadDecoration, payload, 0, 0)
INST(RayPayloadDecoration, raypayload, 0, 0)

/* Mesh Shader outputs */
INST(VerticesDecoration, vertices, 1, 0)
Expand Down
7 changes: 7 additions & 0 deletions source/slang/slang-ir-insts.h
Original file line number Diff line number Diff line change
Expand Up @@ -1585,6 +1585,11 @@ struct IRPayloadDecoration : public IRDecoration
IR_LEAF_ISA(PayloadDecoration)
};

struct IRRayPayloadDecoration : public IRDecoration
{
IR_LEAF_ISA(RayPayloadDecoration)
};

// Mesh shader decorations

struct IRMeshOutputDecoration : public IRDecoration
Expand Down Expand Up @@ -5246,6 +5251,8 @@ struct IRBuilder
{
addDecoration(inst, kIROp_EntryPointParamDecoration, entryPointFunc);
}

void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); }
};

// Helper to establish the source location that will be used
Expand Down
28 changes: 28 additions & 0 deletions tests/hlsl/raypayload-attribute-no-struct.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage raygeneration -entry rayGenShaderA

// CHECK: struct [raypayload]

uniform RWTexture2D resultTexture;
uniform RaytracingAccelerationStructure sceneBVH;

[shader("raygeneration")]
void rayGenShaderA()
{
int2 threadIdx = DispatchRaysIndex().xy;

float3 rayDir = float3(0, 0, 1);
float3 rayOrigin = 0;
rayOrigin.x = (threadIdx.x * 2) - 1;
rayOrigin.y = (threadIdx.y * 2) - 1;

// Trace the ray.
RayDesc ray;
ray.Origin = rayOrigin;
ray.Direction = rayDir;
ray.TMin = 0.001;
ray.TMax = 10000.0;
float4 payload = float4(0, 0, 0, 0);
TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload);

resultTexture[threadIdx.xy] = payload;
}
33 changes: 33 additions & 0 deletions tests/hlsl/raypayload-attribute.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -stage raygeneration -entry rayGenShaderA

// CHECK: struct [raypayload]

struct RayPayload
{
float4 color;
};

uniform RWTexture2D resultTexture;
uniform RaytracingAccelerationStructure sceneBVH;

[shader("raygeneration")]
void rayGenShaderA()
{
int2 threadIdx = DispatchRaysIndex().xy;

float3 rayDir = float3(0, 0, 1);
float3 rayOrigin = 0;
rayOrigin.x = (threadIdx.x * 2) - 1;
rayOrigin.y = (threadIdx.y * 2) - 1;

// Trace the ray.
RayDesc ray;
ray.Origin = rayOrigin;
ray.Direction = rayDir;
ray.TMin = 0.001;
ray.TMax = 10000.0;
RayPayload payload = { float4(0, 0, 0, 0) };
TraceRay(sceneBVH, RAY_FLAG_NONE, ~0, 0, 0, 0, ray, payload);

resultTexture[threadIdx.xy] = payload.color;
}

0 comments on commit aa8bef7

Please sign in to comment.