
Fix pow recursive and gradient functions #42

Open · wants to merge 1 commit into main
Conversation

@deerlord-tsuno commented on Jan 10, 2025

This PR changes _pow and Pow.backward to fix a couple of bugs and expand their functionality. I've verified the changes work for my use case, but refinements are welcome.

_pow(a, n)

The current implementation only supports positive integer exponents and behaves unexpectedly when the n passed in is not a positive integer. This PR changes _pow to work for all real exponents and aligns the function's code with the other recursive functions.
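A minimal sketch of the idea, assuming _pow is a recursive elementwise helper over nested lists like the library's other recursive functions; the names and structure here are illustrative, not the repository's actual code:

```python
def _pow(a, n):
    # Recurse into nested lists so the exponent is applied elementwise,
    # mirroring the other recursive elementwise helpers.
    if isinstance(a, list):
        return [_pow(x, n) for x in a]
    # Python's ** operator already handles negative and fractional
    # exponents, so no integer-only repeated-multiplication loop is needed.
    return a ** n
```

For example, `_pow([[1.0, 4.0], [9.0]], 0.5)` returns `[[1.0, 2.0], [3.0]]`, which the previous positive-integer-only version could not express.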

Pow.backward(dz, z)

This function handles the backprop gradients for Pow.forward(a, n), which returns z based on the formula $z=a^n$ for constant $n$.

Currently, Pow.backward(dz, z) calculates $\frac{dz}{da}=2a$ regardless of the value of $n$. This is correct when the input tensor is squared, as in a.pow(2), but for any other exponent the derivative is wrong and incorrect gradients are propagated. This PR changes Pow.backward to instead calculate $\frac{dz}{da}=na^{n-1}$ as intended. To achieve this, Pow.forward(a, n) now stores n in the cache during the forward pass, and the backward pass retrieves it.
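A minimal sketch of the shape this takes, assuming a dictionary-style cache shared between forward and backward and scalar values for brevity; the actual cache mechanism and method signatures in the repository may differ:

```python
class Pow:
    @staticmethod
    def forward(a, n, cache):
        # Save both the input and the exponent so backward can recover n;
        # previously only enough was kept to handle the n = 2 case.
        cache["a"] = a
        cache["n"] = n
        return a ** n

    @staticmethod
    def backward(dz, cache):
        a, n = cache["a"], cache["n"]
        # Chain rule: dz/da = n * a^(n-1), scaled by the upstream gradient dz.
        return dz * n * a ** (n - 1)
```

With this change, a.pow(2) still yields the gradient 2a, while e.g. a.pow(3) correctly yields 3a^2 instead of 2a.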
