Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Hessian of gamma-distributed samples #21432

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 52 additions & 11 deletions jax/_src/lax/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""

from enum import Enum
from typing import Any
import numpy as np
from functools import partial

Expand All @@ -29,14 +30,30 @@
standard_naryop, standard_unop, sub,
_const, _dtype,
_float, _nary_lower_hlo, _ones, _isnan, _reduce)
from jax._src.lax.control_flow import while_loop
from jax._src.lax.control_flow import cond, scan, while_loop

from jax._src import api
from jax._src import dtypes
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
from jax._src.lib.mlir.dialects import chlo
from jax._src.typing import Array, ArrayLike

def _while_loop_scan(cond_fun, body_fun, init_val, max_iter):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I move this to api_util?

"""Scan-based implementation (jit ok, reverse-mode autodiff ok)."""
def _iter(val):
next_val = body_fun(val)
next_cond = cond_fun(next_val)
return next_val, next_cond

def _fun(tup, it):
val, _cond = tup
# When _cond is met, we start doing no-ops.
return cond(_cond, _iter, lambda x: (x, False), val), it

init = (init_val, cond_fun(init_val))
return scan(_fun, init, None, length=max_iter)[0][0]

def betainc(a: ArrayLike, b: ArrayLike, x: ArrayLike) -> Array:
r"""Elementwise regularized incomplete beta integral."""
return regularized_incomplete_beta_p.bind(a, b, x)
Expand Down Expand Up @@ -250,7 +267,7 @@ def _any(predicates: Array) -> Array:
all_dimensions = tuple(range(len(predicates_shape)))
return reduce(predicates, f, bitwise_or, all_dimensions)

def _igamma_series(ax, x, a, enabled, dtype, mode):
def _igamma_series(ax, x, a, enabled, dtype, mode, *, hessian: bool = False):
def cond_fn(vals):
return _any(vals[0])

Expand Down Expand Up @@ -285,7 +302,9 @@ def body_fn(vals):
full_like(a, 0),
)

vals = while_loop(cond_fn, body_fn, init_vals)
vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256)
if hessian
else while_loop(cond_fn, body_fn, init_vals))
ans = vals[3]
dans_da = vals[6]

Expand Down Expand Up @@ -327,7 +346,9 @@ def igamma_impl(a, x, *, dtype):
full_like(a, float('nan')), output)
return output

def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode):
def _igammac_continued_fraction(ax, x, a, enabled, dtype, mode,
*,
hessian: bool = False):
eps = dtypes.finfo(dtype).eps

def cond_fn(vals):
Expand Down Expand Up @@ -418,7 +439,9 @@ def body_fn(vals):
c, pkm1, qkm1, pkm2, qkm2,
dpkm2_da, dqkm2_da, dpkm1_da, dqkm1_da, dans_da)

vals = while_loop(cond_fn, body_fn, init_vals)
vals = (_while_loop_scan(cond_fn, body_fn, init_vals, 256)
if hessian
else while_loop(cond_fn, body_fn, init_vals))
ans = vals[1]
if mode == IgammaMode.VALUE:
return ans * ax
Expand Down Expand Up @@ -470,7 +493,12 @@ def igamma_grad_a_impl(a, x, *, dtype):
full_like(a, float('nan')), output)
return output

def random_gamma_grad_impl(a, x, *, dtype):
def random_gamma_grad_impl(a: Array,
x: Array,
*,
dtype: Any,
hessian: bool = False
) -> Array:
is_nan = bitwise_or(_isnan(a), _isnan(x))
x_is_zero = eq(x, full_like(x,0))
domain_error = bitwise_or(lt(x, full_like(x,0)), le(a, full_like(a,0)))
Expand All @@ -480,11 +508,13 @@ def random_gamma_grad_impl(a, x, *, dtype):
ax = exp(ax)
enabled = bitwise_not(bitwise_or(bitwise_or(bitwise_or
(x_is_zero, domain_error), underflow), is_nan))
output = select(use_igammac,
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.SAMPLE_DERIVATIVE),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.SAMPLE_DERIVATIVE))
output = select(
use_igammac,
-_igammac_continued_fraction(ax, x, a, bitwise_and(enabled, use_igammac),
dtype, IgammaMode.SAMPLE_DERIVATIVE,
hessian=hessian),
_igamma_series(ax, x, a, bitwise_and(enabled, bitwise_not(use_igammac)),
dtype, IgammaMode.SAMPLE_DERIVATIVE, hessian=hessian))
output = select(x_is_zero, full_like(output,0), output)
output = select(bitwise_or(domain_error, is_nan),
full_like(a, float('nan')), output)
Expand Down Expand Up @@ -653,10 +683,21 @@ def bessel_i0e_impl(x):

ad.defjvp(igammac_p, igammac_grada, igammac_gradx)

def random_gamma_hessian_a(g, a, x, *, dtype):
return api.grad(random_gamma_grad_impl, argnums=0)(a, x, dtype=dtype,
hessian=True)

def random_gamma_hessian_x(g, a, x, *, dtype):
return api.grad(random_gamma_grad_impl, argnums=1)(a, x, dtype=dtype,
hessian=True)

random_gamma_grad_p = standard_naryop([_float, _float], 'random_gamma_grad')
mlir.register_lowering(random_gamma_grad_p,
mlir.lower_fun(_up_and_broadcast(random_gamma_grad_impl),
multiple_results=False))
ad.defjvp(random_gamma_grad_p,
_up_and_broadcast(random_gamma_hessian_a),
_up_and_broadcast(random_gamma_hessian_x))

zeta_p = standard_naryop([_float, _float], 'zeta')
mlir.register_lowering(zeta_p, partial(_nary_lower_hlo, chlo.zeta))
Expand Down
13 changes: 13 additions & 0 deletions tests/random_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1484,5 +1484,18 @@ def f():
jax.random.normal(jax.random.key(0), 1000)
f() # don't crash

class SamplingDerivativeTest(jtu.JaxTestCase):
def test_gamma_hessian(self):
# Regression test for https://github.com/google/jax/issues/16076
def hessian_sample(key: jax.Array) -> jax.Array:
((retval,),) = jax.hessian(random.gamma, argnums=(1,))(key, 0.8)
return retval

keys = random.split(random.key(0), 300)
x = jax.vmap(hessian_sample)(keys)
mean_x = jnp.mean(x, axis=-1)
self.assertArraysAllClose(mean_x, jnp.asarray(0.61), atol=0.1, rtol=0.4)


if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())