Skip to content

Commit

Permalink
Support specialization constant on WGSL and Metal. (#5780)
Browse files Browse the repository at this point in the history
  • Loading branch information
csyonghe authored Dec 6, 2024
1 parent 22b64a4 commit 8ce7c6f
Show file tree
Hide file tree
Showing 12 changed files with 227 additions and 75 deletions.
19 changes: 18 additions & 1 deletion docs/user-guide/a2-02-metal-target-specific.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,21 @@ The HLSL `:register()` semantic is respected when emitting Metal code.

Since metal does not differentiate a constant buffer, a shader resource (read-only) buffer and an unordered access buffer, Slang will map `register(tN)`, `register(uN)` and `register(bN)` to `[[buffer(N)]]` when such `register` semantic is declared on a buffer typed parameter.

`spaceN` specifiers inside `register` semantics are ignored.
`spaceN` specifiers inside `register` semantics are ignored.

## Specialization Constants

Specialization constants declared with the `[SpecializationConstant]` or `[vk::constant_id]` attribute will be translated into a `function_constant` when generating Metal source.
For example:

```csharp
[vk::constant_id(7)]
const int a = 2;
```

Translates to:

```metal
constant int fc_a_0 [[function_constant(7)]];
constant int a_0 = is_function_constant_defined(fc_a_0) ? fc_a_0 : 2;
```
18 changes: 17 additions & 1 deletion docs/user-guide/a2-03-wgsl-target-specific.md
Original file line number Diff line number Diff line change
Expand Up @@ -161,4 +161,20 @@ Since the WGSL matrix multiplication convention is the normal one, where inner p

The `[vk::binding(index,set)]` attribute is respected when emitting WGSL code, and will translate to `@binding(index) @group(set)` in WGSL.

If the `[vk::binding()]` attribute is not specified by a `:register()` semantic is present, Slang will derive the binding from the `register` semantic the same way as the SPIRV and GLSL backends.
If the `[vk::binding()]` attribute is not specified by a `:register()` semantic is present, Slang will derive the binding from the `register` semantic the same way as the SPIRV and GLSL backends.

## Specialization Constants

Specialization constants declared with the `[SpecializationConstant]` or `[vk::constant_id]` attribute will be translated into a global `override` declaration when generating WGSL source.
For example:

```csharp
[vk::constant_id(7)]
const int a = 2;
```

Translates to:

```wgsl
@id(7) override a : i32 = 2;
```
2 changes: 2 additions & 0 deletions docs/user-guide/toc.html
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@
<li data-link="metal-target-specific#conservative-rasterization"><span>Conservative Rasterization</span></li>
<li data-link="metal-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
<li data-link="metal-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
<li data-link="metal-target-specific#specialization-constants"><span>Specialization Constants</span></li>
</ul>
</li>
<li data-link="wgsl-target-specific"><span>WGSL specific functionalities</span>
Expand All @@ -251,6 +252,7 @@
<li data-link="wgsl-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
<li data-link="wgsl-target-specific#matrix-type-translation"><span>Matrix type translation</span></li>
<li data-link="wgsl-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
<li data-link="wgsl-target-specific#specialization-constants"><span>Specialization Constants</span></li>
</ul>
</li>
</ul>
Expand Down
46 changes: 45 additions & 1 deletion source/slang/slang-emit-metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,17 +1331,60 @@ bool MetalSourceEmitter::_emitUserSemantic(
return false;
}

bool MetalSourceEmitter::tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType)
{
auto layout = getVarLayout(varDecl);
if (!layout)
return false;
if (auto specConstLayout = layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
{
// Emit specialization constant.
auto name = getName(varDecl);
auto prefixName = "fc_" + name;
auto defaultVal = varDecl->findDecoration<IRDefaultValueDecoration>();

m_writer->emit("constant ");
emitType(varType, prefixName);
m_writer->emit(" ");
m_writer->emit("[[function_constant(");
m_writer->emit(specConstLayout->getOffset());
m_writer->emit(")]];\n");

m_writer->emit("constant ");
emitType(varType, name);
m_writer->emit(" = ");
if (defaultVal)
{
m_writer->emit("is_function_constant_defined(");
m_writer->emit(prefixName);
m_writer->emit(") ? ");
m_writer->emit(prefixName);
m_writer->emit(" : ");
emitVal(defaultVal->getOperand(0), getInfo(EmitOp::General));
}
else
{
m_writer->emit(prefixName);
}
m_writer->emit(";\n");
return true;
}
return false;
}

void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets)
{
SLANG_UNUSED(allowOffsets);

auto varLayout = findVarLayout(inst);

if (inst->getOp() == kIROp_StructKey)
{
// Only emit [[attribute(n)]] on struct keys.

if (maybeEmitSystemSemantic(inst))
return;

auto varLayout = findVarLayout(inst);
bool hasSemantic = false;

if (varLayout)
Expand Down Expand Up @@ -1378,6 +1421,7 @@ void MetalSourceEmitter::emitSemanticsImpl(IRInst* inst, bool allowOffsets)
semanticDecor->getSemanticIndex());
}
}
return;
}
}

Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-emit-metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class MetalSourceEmitter : public CLikeSourceEmitter

virtual void emitRateQualifiersAndAddressSpaceImpl(IRRate* rate, AddressSpace addressSpace)
SLANG_OVERRIDE;
virtual bool tryEmitGlobalParamImpl(IRGlobalParam* varDecl, IRType* varType) SLANG_OVERRIDE;
virtual void emitSemanticsImpl(IRInst* inst, bool allowOffsets) SLANG_OVERRIDE;
virtual void emitSimpleFuncParamImpl(IRParam* param) SLANG_OVERRIDE;
virtual void emitPostDeclarationAttributesForType(IRInst* type) SLANG_OVERRIDE;
Expand Down
33 changes: 32 additions & 1 deletion source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,21 @@ void WGSLSourceEmitter::emitSimpleTypeImpl(IRType* type)
}
}

void WGSLSourceEmitter::emitGlobalParamDefaultVal(IRGlobalParam* varDecl)
{
auto layout = getVarLayout(varDecl);
if (!layout)
return;
if (layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
{
if (auto defaultValDecor = varDecl->findDecoration<IRDefaultValueDecoration>())
{
m_writer->emit(" = ");
emitInstExpr(defaultValDecor->getOperand(0), EmitOpInfo());
}
}
}

void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
{

Expand All @@ -668,6 +683,14 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
m_writer->emit(space);
m_writer->emit(") ");

return;
}
else if (kind == LayoutResourceKind::SpecializationConstant)
{
m_writer->emit("@id(");
m_writer->emit(attr->getOffset());
m_writer->emit(") ");

return;
}
}
Expand Down Expand Up @@ -708,7 +731,15 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
case kIROp_GlobalParam:
case kIROp_GlobalVar:
case kIROp_Var:
m_writer->emit("var");
{
auto layout = getVarLayout(varDecl);
if (layout && layout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
{
m_writer->emit("override");
break;
}
m_writer->emit("var");
}
break;
default:
if (isStaticConst(varDecl))
Expand Down
1 change: 1 addition & 0 deletions source/slang/slang-emit-wgsl.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class WGSLSourceEmitter : public CLikeSourceEmitter
UnownedStringSlice intrinsicDefinition,
IRInst* intrinsicInst,
EmitOpInfo const& inOuterPrec) SLANG_OVERRIDE;
virtual void emitGlobalParamDefaultVal(IRGlobalParam* varDecl) SLANG_OVERRIDE;

void emit(const AddressSpace addressSpace);

Expand Down
24 changes: 21 additions & 3 deletions source/slang/slang-ir-explicit-global-context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "slang-ir-clone.h"
#include "slang-ir-insts.h"
#include "slang-ir-util.h"

namespace Slang
{
Expand Down Expand Up @@ -39,7 +40,8 @@ struct IntroduceExplicitGlobalContextPass
class ExplicitContextPolicy
{
public:
ExplicitContextPolicy(CodeGenTarget target)
ExplicitContextPolicy(CodeGenTarget inTarget)
: target(inTarget)
{
switch (target)
{
Expand Down Expand Up @@ -74,7 +76,7 @@ struct IntroduceExplicitGlobalContextPass
return (UInt)hoistableGlobalObjectKind & (UInt)hoistable;
}

bool canHoistGlobalVar(IRGlobalVar* inst)
bool canHoistGlobalVar(IRInst* inst)
{
if (!((UInt)hoistGlobalVarOptions & (UInt)HoistGlobalVarOptions::SharedGlobal) &&
as<IRGroupSharedRate>(inst->getRate()))
Expand All @@ -99,6 +101,19 @@ struct IntroduceExplicitGlobalContextPass
}
}

// Do not move specialization constants to context.
switch (target)
{
case CodeGenTarget::Metal:
case CodeGenTarget::MetalLib:
case CodeGenTarget::MetalLibAssembly:
{
auto varLayout = findVarLayout(inst);
if (varLayout &&
varLayout->findOffsetAttr(LayoutResourceKind::SpecializationConstant))
return false;
}
}
return true;
}

Expand All @@ -111,6 +126,7 @@ struct IntroduceExplicitGlobalContextPass
GlobalObjectKind hoistableGlobalObjectKind = GlobalObjectKind::All;
bool requiresFuncTypeCorrectionPass = false;
AddressSpace addressSpaceOfLocals = AddressSpace::ThreadLocal;
CodeGenTarget target;
};

IntroduceExplicitGlobalContextPass(IRModule* module, CodeGenTarget target)
Expand All @@ -134,7 +150,7 @@ struct IntroduceExplicitGlobalContextPass

bool canHoistType(GlobalObjectKind hoistable) { return m_options.canHoistType(hoistable); }

bool canHoistGlobalVar(IRGlobalVar* inst) { return m_options.canHoistGlobalVar(inst); }
bool canHoistGlobalVar(IRInst* inst) { return m_options.canHoistGlobalVar(inst); }

void processModule()
{
Expand Down Expand Up @@ -183,6 +199,8 @@ struct IntroduceExplicitGlobalContextPass
//
auto globalParam = cast<IRGlobalParam>(inst);

if (!canHoistGlobalVar(globalParam))
continue;

// One detail we need to be careful about is that as a result
// of legalizing the varying parameters of compute kernels to
Expand Down
Loading

0 comments on commit 8ce7c6f

Please sign in to comment.