Skip to content

Commit

Permalink
Add zero_from_primal
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Dec 23, 2024
1 parent f338a30 commit 62bcc19
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 3 additions & 1 deletion tjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
NumpyIntegralNumeric, NumpyRealArray, NumpyRealNumeric, PyTree,
RealArray, RealNumeric, Shape, ShapeLike, SliceLike)
from ._src.cotangent_tools import (copy_cotangent, cotangent_combinator, print_cotangent,
replace_cotangent, reverse_scale_cotangent, scale_cotangent)
replace_cotangent, reverse_scale_cotangent, scale_cotangent,
zero_from_primal)
from ._src.display.display_generic import display_generic
from ._src.display.internal import internal_print_generic
from ._src.display.print_generic import print_generic
Expand Down Expand Up @@ -116,4 +117,5 @@
'scale_cotangent',
'softplus',
'tree_allclose',
'zero_from_primal',
]
5 changes: 5 additions & 0 deletions tjax/_src/cotangent_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import jax.numpy as jnp
from jax import tree, vjp
from jax.custom_derivatives import zero_from_primal as jax_zero_from_primal

from .annotations import JaxRealArray, RealNumeric
from .display.print_generic import print_generic
Expand All @@ -17,6 +18,10 @@
Y = TypeVar('Y')


def zero_from_primal(x: X, /, *, symbolic_zeros: bool = False) -> X:
return jax_zero_from_primal(x, symbolic_zeros=symbolic_zeros)


# scale_cotangent ----------------------------------------------------------------------------------
def scale_cotangent(x: X,
scalar_scale: RealNumeric | None = None,
Expand Down

0 comments on commit 62bcc19

Please sign in to comment.