Skip to content

Commit

Permalink
Implement explciit binding for metal and wgsl. (#5778)
Browse files Browse the repository at this point in the history
* Respect explicit bindings in wgsl emit.

* Implement explciit binding generation for metal and wgsl.

* Update toc.

* Fix warnings in tests.

* Fix tests.

---------

Co-authored-by: Ellie Hermaszewska <[email protected]>
  • Loading branch information
csyonghe and expipiplus1 authored Dec 6, 2024
1 parent ecc5a39 commit 7dabfa7
Show file tree
Hide file tree
Showing 30 changed files with 83 additions and 27 deletions.
8 changes: 8 additions & 0 deletions docs/user-guide/a2-02-metal-target-specific.md
Original file line number Diff line number Diff line change
Expand Up @@ -267,3 +267,11 @@ Metal requires explicit address space qualifiers. Slang automatically assigns ap
| RW/Structured Buffers | `device` |
| Group Shared | `threadgroup` |
| Parameter Blocks | `constant` |

## Explicit Parameter Binding

The HLSL `:register()` semantic is respected when emitting Metal code.

Since metal does not differentiate a constant buffer, a shader resource (read-only) buffer and an unordered access buffer, Slang will map `register(tN)`, `register(uN)` and `register(bN)` to `[[buffer(N)]]` when such `register` semantic is declared on a buffer typed parameter.

`spaceN` specifiers inside `register` semantics are ignored.
6 changes: 6 additions & 0 deletions docs/user-guide/a2-03-wgsl-target-specific.md
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,9 @@ Matrix type translation
A m-row-by-n-column matrix in Slang, represented as float`m`x`n` or matrix<T, m, n>, is translated to `mat[n]x[m]` in WGSL, i.e. a matrix with `n` columns and `m` rows.
The rationale for this inversion of terminology is the same as [the rationale for SPIR-V](a2-01-spirv-target-specific.md#matrix-type-translation).
Since the WGSL matrix multiplication convention is the normal one, where inner products of rows of the matrix on the left are taken with columns of the matrix on the right, the order of matrix products is also reversed in WGSL. This is relying on the fact that the transpose of a matrix product equals the product of the transposed matrix operands in reverse order.

## Explicit Parameter Binding

The `[vk::binding(index,set)]` attribute is respected when emitting WGSL code, and will translate to `@binding(index) @group(set)` in WGSL.

If the `[vk::binding()]` attribute is not specified by a `:register()` semantic is present, Slang will derive the binding from the `register` semantic the same way as the SPIRV and GLSL backends.
2 changes: 2 additions & 0 deletions docs/user-guide/toc.html
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@
<li data-link="metal-target-specific#value-type-conversion"><span>Value Type Conversion</span></li>
<li data-link="metal-target-specific#conservative-rasterization"><span>Conservative Rasterization</span></li>
<li data-link="metal-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
<li data-link="metal-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
</ul>
</li>
<li data-link="wgsl-target-specific"><span>WGSL specific functionalities</span>
Expand All @@ -249,6 +250,7 @@
<li data-link="wgsl-target-specific#pointers"><span>Pointers</span></li>
<li data-link="wgsl-target-specific#address-space-assignment"><span>Address Space Assignment</span></li>
<li data-link="wgsl-target-specific#matrix-type-translation"><span>Matrix type translation</span></li>
<li data-link="wgsl-target-specific#explicit-parameter-binding"><span>Explicit Parameter Binding</span></li>
</ul>
</li>
</ul>
Expand Down
2 changes: 1 addition & 1 deletion source/slang/slang-emit-wgsl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,7 +663,7 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)

EmitVarChain chain = {};
chain.varLayout = layout;
auto space = getBindingSpaceForKinds(&chain, kind);
auto space = getBindingSpaceForKinds(&chain, LayoutResourceKindFlag::make(kind));
m_writer->emit("@group(");
m_writer->emit(space);
m_writer->emit(") ");
Expand Down
19 changes: 17 additions & 2 deletions source/slang/slang-parameter-binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,8 @@ static void addExplicitParameterBindings_HLSL(
//
// For now we do the filtering on target in a very direct fashion:
//
if (!isD3DTarget(context->getTargetRequest()) && !isMetalTarget(context->getTargetRequest()))
bool isMetal = isMetalTarget(context->getTargetRequest());
if (!isD3DTarget(context->getTargetRequest()) && !isMetal)
return;

auto typeLayout = varLayout->typeLayout;
Expand Down Expand Up @@ -1018,13 +1019,27 @@ static void addExplicitParameterBindings_HLSL(
if (kind == LayoutResourceKind::None)
continue;


// TODO: need to special-case when this is a `c` register binding...

// Find the appropriate resource-binding information
// inside the type, to see if we even use any resources
// of the given kind.

auto typeRes = typeLayout->FindResourceInfo(kind);
if (isMetal && !typeRes)
{
// Metal doesn't distinguish a unordered access and a readonly/uniform buffer.
switch (kind)
{
case LayoutResourceKind::UnorderedAccess:
case LayoutResourceKind::ShaderResource:
semanticInfo.kind = LayoutResourceKind::MetalBuffer;
typeRes = typeLayout->FindResourceInfo(LayoutResourceKind::MetalBuffer);
break;
}
}

LayoutSize count = 0;
if (typeRes)
{
Expand Down Expand Up @@ -1073,7 +1088,7 @@ static void addExplicitParameterBindings_GLSL(
// so that we are able to distinguish between
// Vulkan and OpenGL as targets.
//
if (!isKhronosTarget(context->getTargetRequest()))
if (!isKhronosTarget(context->getTargetRequest()) && !isWGPUTarget(context->getTargetRequest()))
return;

auto typeLayout = varLayout->typeLayout;
Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/gh-471.slang
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ int test(int inVal)
return x * 16;
}

RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/gh-775.slang
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/static-method.slang
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ struct S
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

int test(int t)
{
Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/static-var.slang
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
5 changes: 3 additions & 2 deletions tests/bugs/texture2d-gather.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
//TEST_INPUT: Texture2D(size=16, content=chessboard, format=R32_FLOAT):name g_texture
//TEST_INPUT: Sampler :name g_sampler

Texture2D<float> g_texture : register(t0);
SamplerState g_sampler : register(s0);
Texture2D<float> g_texture;

SamplerState g_sampler;

cbuffer Uniforms
{
Expand Down
2 changes: 1 addition & 1 deletion tests/bugs/type-legalize-bug-1.slang
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
//TEST_INPUT:type_conformance A:IFoo=0
//TEST_INPUT:type_conformance B:IFoo=1

RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;
interface IFoo
{
associatedtype T : IFoo;
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/break-stmt.slang
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/continue-stmt.slang
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/default-initializer.slang
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int test(int value)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/explicit-this-expr.slang
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ struct A
};

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer : register(u0);
RWStructuredBuffer<float> outputBuffer;


float test(float inVal)
Expand Down
1 change: 1 addition & 0 deletions tests/compute/generics-constrained.slang
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ float testHelp(T helper)
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
[vk::binding(0, 0)]
RWStructuredBuffer<float> outputBuffer : register(u0);


Expand Down
2 changes: 1 addition & 1 deletion tests/compute/global-init.slang
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/implicit-generic-app.slang
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int test(int val)
}

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/implicit-this-expr.slang
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ struct A
};

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer : register(u0);
RWStructuredBuffer<float> outputBuffer;

float test(float inVal)
{
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/init-list-defaults.slang
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/inout.slang
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/multiple-continue-sites.slang
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/struct-default-init.slang
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[9 9 9 9], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/switch-stmt.slang
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ int test(int inVal)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3 4 5 6 7], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(8, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/this-type.slang
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ int test(int value)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/compute/user-defined-initializer.slang
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ int test(int value)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ int test(int value)
}

//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

[numthreads(4, 1, 1)]
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
Expand Down
2 changes: 1 addition & 1 deletion tests/preprocessor/line-macro.slang
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<int> outputBuffer : register(u0);
RWStructuredBuffer<int> outputBuffer;

#define T(x) x
#define LL T(__LINE__)
Expand Down
2 changes: 1 addition & 1 deletion tests/serialization/std-lib-serialize.slang
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct A
};

//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
RWStructuredBuffer<float> outputBuffer : register(u0);
RWStructuredBuffer<float> outputBuffer;


float test(float inVal)
Expand Down
23 changes: 23 additions & 0 deletions tests/wgsl/explicit-binding.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//TEST:SIMPLE(filecheck=METAL): -target metal
//TEST:SIMPLE(filecheck=CHECK): -target wgsl -entry computeMain -stage compute

// CHECK-DAG: @binding(9) @group(7)
// CHECK-DAG: @binding(3) @group(4)
// CHECK-DAG: @binding(1) @group(2)

// METAL-DAG: buffer(9)
// METAL-DAG: texture(7)

[vk::binding(1, 2)]
Texture2D texA : register(t7);

[vk::binding(3, 4)]
ConstantBuffer<float> cb;

RWStructuredBuffer<float> ob : register(u9, space7);

[numthreads(1,1,1)]
void computeMain()
{
ob[0] = cb + texA.Load(int3(0)).x;
}

0 comments on commit 7dabfa7

Please sign in to comment.