Skip to content

Commit

Permalink
Post-process when no sample sites present.
Browse files Browse the repository at this point in the history
Current post-processing behaviour skips models with only deterministic variables. Applying this change will return consistent samples regardless of whether `sample` sites are present.
  • Loading branch information
hessammehr committed Nov 21, 2024
1 parent 4f2c9b2 commit 057cbf2
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 3 deletions.
4 changes: 1 addition & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,9 +195,7 @@ def collect_and_postprocess(x):
if collect_fields:
fields = nested_attrgetter(*collect_fields)(x[0])
fields = [fields] if len(collect_fields) == 1 else list(fields)
site_values = jax.tree.flatten(fields[0])[0]
if len(site_values) > 0:
fields[0] = postprocess_fn(fields[0], *x[1:])
fields[0] = postprocess_fn(fields[0], *x[1:])

if remove_sites != ():
assert isinstance(fields[0], dict)
Expand Down
24 changes: 24 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1208,3 +1208,27 @@ def model():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))

def test_all_deterministic():
def model1():
numpyro.deterministic("x", 1.0)

def model2():
numpyro.deterministic("x", jnp.array([1.0, 2.0]))

num_samples = 10
shapes = {model1: (), model2: (2,)}

for model, shape in shapes.items():
mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=num_samples)
mcmc.run(random.PRNGKey(0))
assert mcmc.get_samples()["x"].shape == (num_samples,) + shape

def test_empty_summary():
def model():
pass

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0))

mcmc.print_summary()

0 comments on commit 057cbf2

Please sign in to comment.