Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
juanitorduz committed Nov 10, 2024
1 parent b9bacee commit adb8ef2
Showing 1 changed file with 26 additions and 32 deletions.
58 changes: 26 additions & 32 deletions numpyro/infer/elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ class ELBO:
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
"""

"""
Expand All @@ -46,7 +46,7 @@ class ELBO:
"""
can_infer_discrete = False

def __init__(self, num_particles=1, vectorize_particles=True):
def __init__(self, num_particles=1, vectorize_particles=jax.lax.map):
self.num_particles = num_particles
self.vectorize_particles = vectorize_particles

Expand Down Expand Up @@ -121,15 +121,15 @@ class Trace_ELBO(ELBO):
:param num_particles: The number of particles/samples used to form the ELBO
(gradient) estimators.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
:param multi_sample_guide: Whether to make an assumption that the guide proposes
multiple samples.
"""

def __init__(
self, num_particles=1, vectorize_particles=True, multi_sample_guide=False
self, num_particles=1, vectorize_particles=jax.lax.map, multi_sample_guide=False
):
self.multi_sample_guide = multi_sample_guide
super().__init__(
Expand Down Expand Up @@ -228,10 +228,9 @@ def get_model_density(key, latent):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
elbos, mutable_state = self.vectorize_particles(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand Down Expand Up @@ -362,10 +361,9 @@ def single_particle_elbo(rng_key):
return {"loss": -elbo, "mutable_state": mutable_state}
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, mutable_state = vmap(single_particle_elbo)(rng_keys)
else:
elbos, mutable_state = jax.lax.map(single_particle_elbo, rng_keys)
elbos, mutable_state = self.vectorize_particles(
single_particle_elbo, rng_keys
)
return {"loss": -jnp.mean(elbos), "mutable_state": mutable_state}


Expand All @@ -385,9 +383,9 @@ class RenyiELBO(ELBO):
Here :math:`\alpha \neq 1`. Default is 0.
:param num_particles: The number of particles/samples
used to form the objective (gradient) estimator. Default is 2.
:param vectorize_particles: Whether to use `jax.vmap` to compute ELBOs over the
num_particles-many particles in parallel. If False use `jax.lax.map`.
Defaults to True.
:param vectorize_particles: Callable to vectorize the computation of ELBOs over
the num_particles-many particles. Defaults to `jax.lax.map`. Other options are
`jax.vmap` and `jax.pmap`.
Example::
Expand Down Expand Up @@ -504,10 +502,9 @@ def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
)

rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
elbos, common_plate_scale = vmap(single_particle_elbo)(rng_keys)
else:
elbos, common_plate_scale = jax.lax.map(single_particle_elbo, rng_keys)
elbos, common_plate_scale = self.vectorize_particles(
single_particle_elbo, rng_keys
)
assert common_plate_scale.shape == (self.num_particles,)
assert elbos.shape[0] == self.num_particles
scaled_elbos = (1.0 - self.alpha) * elbos
Expand Down Expand Up @@ -853,10 +850,7 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys))


def get_importance_trace_enum(
Expand Down Expand Up @@ -1043,7 +1037,10 @@ class TraceEnum_ELBO(ELBO):
can_infer_discrete = True

def __init__(
self, num_particles=1, max_plate_nesting=float("inf"), vectorize_particles=True
self,
num_particles=1,
max_plate_nesting=float("inf"),
vectorize_particles=jax.lax.map,
):
self.max_plate_nesting = max_plate_nesting
super().__init__(
Expand Down Expand Up @@ -1221,7 +1218,4 @@ def single_particle_elbo(rng_key):
return -single_particle_elbo(rng_key)
else:
rng_keys = random.split(rng_key, self.num_particles)
if self.vectorize_particles:
return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
else:
return -jnp.mean(jax.lax.map(single_particle_elbo, rng_keys))
return -jnp.mean(self.vectorize_particles(single_particle_elbo, rng_keys))

0 comments on commit adb8ef2

Please sign in to comment.