Hi @ntessore. The … For example:

```python
import jax
import jrng

def normal(rng, loc, scale):
    return loc + rng.standard_normal() * scale

jitted_normal = jax.jit(normal)
rng = jrng.JRNG(20241216)
jitted_normal(rng, 1., 2.)
```

gives an error.
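Because `jrng` may not be installed, here is a minimal self-contained reproduction. It uses a hypothetical `ToyRNG` class (not part of `jrng` or JAX). As an assumption about how such a generator might work, `ToyRNG` keeps a `jax.random` key as a Python attribute and advances it on every draw. The sketch shows the plain `jax.jit` call failing with a `TypeError`, and the `static_argnames` workaround returning the same sample on repeated calls.

```python
import jax


class ToyRNG:
    """Hypothetical stand-in for a stateful generator such as jrng.JRNG.

    The PRNG state is a JAX key stored on the object and replaced on each
    draw, i.e. the state lives in a Python attribute rather than being
    passed in and out explicitly.
    """

    def __init__(self, seed):
        self.key = jax.random.PRNGKey(seed)

    def standard_normal(self):
        # Advance the stored key and draw from the split-off subkey.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.normal(subkey)


def normal(rng, loc, scale):
    return loc + rng.standard_normal() * scale


# 1. Passing the object as an ordinary (traced) argument fails: it is neither
#    an array nor a registered pytree, so JAX cannot turn it into an abstract value.
try:
    jax.jit(normal)(ToyRNG(20241216), 1.0, 2.0)
    plain_jit_error = None
except TypeError as err:
    plain_jit_error = err

# 2. Marking rng as static makes the call work, but the draw only happens
#    while tracing; the second call hits the compilation cache for the same
#    rng object, so both calls return the identical sample.
jitted_normal_static = jax.jit(normal, static_argnames=("rng",))
rng = ToyRNG(20241216)
first = jitted_normal_static(rng, 1.0, 2.0)
second = jitted_normal_static(rng, 1.0, 2.0)
print(plain_jit_error is not None, bool(first == second))
```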
We could try to address this using the suggested approach of marking the `rng` argument as static:

```python
jitted_normal_static = jax.jit(normal, static_argnames=("rng",))
```

This function is now callable. However, it will always return the same value (or a shifted / scaled version if the `loc` and `scale` arguments change):

```python
assert jitted_normal_static(rng, 1., 2.) == jitted_normal_static(rng, 1., 2.)
```

This is because a static argument is treated as a compile-time constant: `rng.standard_normal()` is only evaluated while tracing, and later calls with the same `rng` reuse the cached compiled function, so the sample drawn during tracing is baked in as a constant.

One option would be for functions that take a random number generator state object to also always return the (updated) state, which is what the JAX docs recommend for stateful computations. For example, for the `normal` function above:

```python
def normal(rng, loc=0.0, scale=1.0):
    return loc + scale * rng.standard_normal(), rng
```

To allow using the …
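The explicit state-passing pattern can be sketched with plain `jax.random` keys standing in for the generator object. This illustrates the general approach from the JAX docs, not the `jrng` API: the key is split on each draw, and the new key is returned alongside the sample so the caller threads it into the next call.

```python
import jax


def normal(key, loc=0.0, scale=1.0):
    """Draw one normal sample and return it together with the advanced key.

    The PRNG state goes in as an argument and the updated state comes back
    as an output, so nothing is mutated and jax.jit can trace the function
    like any other pure function.
    """
    key, subkey = jax.random.split(key)
    return loc + scale * jax.random.normal(subkey), key


jitted_normal = jax.jit(normal)

key = jax.random.PRNGKey(20241216)
x1, key = jitted_normal(key, 1.0, 2.0)
x2, key = jitted_normal(key, 1.0, 2.0)  # threaded key, so a fresh sample
```

Because nothing is mutated, the same seed always reproduces the same sequence, and each call made with the key returned by the previous call yields a new sample without retracing.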
What are the cases where the numpy.random.Generator interface is inadequate for sampling?