Fix pow recursive and gradient functions #42
This PR applies changes to `_pow` and `Pow.backward` to fix some bugs and expand their functionality. I've verified these changes work at least for my use case, but any refinements are welcome.

**`_pow(a, n)`**

The current implementation only supports positive integers, and has some unexpected behavior when the `n` value passed does not satisfy these properties. This PR changes `_pow` to work for all real values, as well as aligning the function's code with the other recursive functions.
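As a rough illustration of the "all real values" behavior, here is a sketch of what the recursive, element-wise form could look like. It assumes tensors are stored as (possibly nested) Python lists and that the other recursive helpers walk them the same way; that storage detail is my assumption for illustration, not something this PR specifies.

```python
def _pow(a, n):
    if isinstance(a, list):
        # Recurse into nested lists until scalar leaves are reached.
        return [_pow(x, n) for x in a]
    # Python's ** handles any real exponent (negative, fractional, ...),
    # not just positive integers.
    return a ** n
```

Under those assumptions, `_pow([[1.0, 4.0], [9.0]], 0.5)` would return `[[1.0, 2.0], [3.0]]`.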
**`Pow.backward(dz, z)`**

This function handles the backprop gradients for `Pow.forward(a, n)`, which returns $z$ based on the formula $z=a^n$ for constant $n$. Currently, `Pow.backward(dz, z)` calculates $\frac{dz}{da}=2a$ regardless of the value of $n$. This is correct if the input tensor is squared, as in `a.pow(2)`, but for any other exponent the function will calculate the wrong derivative, resulting in incorrect gradients being propagated. This PR changes `Pow.backward` to instead calculate $\frac{dz}{da}=na^{n-1}$ as intended. To achieve this, `Pow.forward(a, n)` now stores `n` in the cache during the forward pass, which is then retrieved in the backward pass.
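For reference, a minimal scalar-level sketch of the cached-exponent approach. The cache shown here (a dict keyed by the forward output) and the exact call signatures are illustrative assumptions, not the library's real API.

```python
_cache = {}  # maps a forward output to the (input, exponent) pair that produced it

class Pow:
    @staticmethod
    def forward(a, n):
        z = a ** n
        _cache[id(z)] = (a, n)  # store the input and exponent for backward
        return z

    @staticmethod
    def backward(dz, z):
        a, n = _cache.pop(id(z))
        # General power rule: dz/da = n * a^(n-1), chained with the
        # upstream gradient dz (previously hard-coded as 2 * a).
        return dz * n * a ** (n - 1)
```

With this sketch, `z = Pow.forward(3.0, 4.0)` followed by `Pow.backward(1.0, z)` gives `108.0`, i.e. $4 \cdot 3^3$, instead of the old $2 \cdot 3 = 6$.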