-
Notifications
You must be signed in to change notification settings - Fork 245
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Correct IR generation for no-diff pointer type (#5976)
* Correct IR generation for no-diff pointer type Close #5805 There is an issue on checking whether a pointer type parameter is no_diff, we should first check whether this parameter is an Attribute type first, then check the data type. In the back-propagate pass, for the pointer type parameter, we should load this parameter to a temp variable, then pass it to the primal function call. Otherwise, the temp variable will no be initialized, which will cause the following calculation wrong.
- Loading branch information
1 parent
e3b71cf
commit d48cd13
Showing
5 changed files
with
69 additions
and
10 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
|
||
[Differentiable] | ||
float sumOfSquares(float x, float y, no_diff float4* test) | ||
{ | ||
return x * x + y * y * (test->x + test->y + test->z); | ||
} | ||
|
||
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type -compile-arg -skip-spirv-validation -emit-spirv-directly | ||
|
||
//TEST_INPUT: set ptr = ubuffer(data=[1.0 2.0 3.0], stride=4) | ||
uniform float* ptr; | ||
|
||
//TEST_INPUT:ubuffer(data=[0.0 0.0 0.0 0.0 0.0], stride=4):out, name outputBuffer | ||
RWStructuredBuffer<float> outputBuffer; | ||
|
||
[shader("compute")] | ||
[numthreads(1, 1, 1)] | ||
void computeMain() | ||
{ | ||
float4* testPtr = (float4*)ptr; | ||
|
||
let result = sumOfSquares(2.0, 3.0, testPtr); | ||
|
||
// Use forward differentiation to compute the gradient of the output w.r.t. x only. | ||
let diffX = fwd_diff(sumOfSquares)(diffPair(2.0, 1.0), diffPair(3.0, 0.0), testPtr); | ||
|
||
// Create a differentiable pair to pass in the primal value and to receive the gradient. | ||
var dpX = diffPair(2.0); | ||
var dpY = diffPair(3.0); | ||
|
||
// Propagate the gradient of the output (1.0f) to the input parameters. | ||
bwd_diff(sumOfSquares)(dpX, dpY, testPtr, 1.0); | ||
|
||
outputBuffer[0] = result; // 2^2 + 3^2 * (1 + 2 + 3) = 58 | ||
outputBuffer[1] = diffX.d; // 2*x * dx + 2*y * dy * (1 + 2 + 3) = 4 | ||
outputBuffer[2] = diffX.p; // 2^2 + 3^2 * (1 + 2 + 3) = 58 | ||
outputBuffer[3] = dpX.d; // 2*x = 4 | ||
|
||
outputBuffer[4] = dpY.d; // 2*y * (1 + 2 +3) = 36 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
type: float | ||
58.000000 | ||
4.000000 | ||
58.000000 | ||
4.000000 | ||
36.000000 |