Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to collect nested dict keys in mcmc #1905

Merged
merged 3 commits into from
Nov 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
fori_collect,
identity,
is_prng_key,
nested_attrgetter,
)

__all__ = [
Expand Down Expand Up @@ -192,7 +193,7 @@ def _collect_fn(collect_fields, remove_sites):
@cached_by(_collect_fn, collect_fields, remove_sites)
def collect(x):
if collect_fields:
fields = attrgetter(*collect_fields)(x[0])
fields = nested_attrgetter(*collect_fields)(x[0])

if remove_sites != ():
fields = [fields] if len(collect_fields) == 1 else list(fields)
Expand Down Expand Up @@ -585,7 +586,10 @@ def warmup(
:param extra_fields: Extra fields (aside from :meth:`~numpyro.infer.mcmc.MCMCKernel.default_fields`)
from the state object (e.g. :data:`numpyro.infer.hmc.HMCState` for HMC) to collect during
the MCMC run. Exclude sample sites from collection with "~`sampler.sample_field`.`sample_site`".
e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler.
e.g. "~z.a" will prevent site "a" from being collected if you're using the NUTS sampler. To
collect samples of a site "a" in the unconstrained space, we can specify the variable here, e.g.
`extra_fields=("z.a",)`.

:type extra_fields: tuple or list
:param bool collect_warmup: Whether to collect samples from the warmup phase. Defaults
to `False`.
Expand Down Expand Up @@ -622,7 +626,8 @@ def run(self, rng_key, *args, extra_fields=(), init_params=None, **kwargs):
during the MCMC run. Note that subfields can be accessed using dots, e.g.
`"adapt_state.step_size"` can be used to collect step sizes at each step. Exclude sample sites from
collection with "~`sampler.sample_field`.`sample_site`". e.g. "~z.a" will prevent site "a" from
being collected if you're using the NUTS sampler.
being collected if you're using the NUTS sampler. To collect samples of a site "a" in the
unconstrained space, we can specify the variable here, e.g. `extra_fields=("z.a",)`.
:type extra_fields: tuple or list of str
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn` provided to the kernel. If the kernel is
Expand Down
25 changes: 25 additions & 0 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -778,3 +778,28 @@ def find_stack_level() -> int:
else:
break
return n


def nested_attrgetter(*collect_fields):
"""
Like attrgetter, but allows for accessing dictionary keys
using the dot notation (e.g., 'x.c.d').
"""

def getter(obj):
results = tuple(_get_nested_attr(obj, field) for field in collect_fields)
return results if len(collect_fields) > 1 else results[0]

return getter


def _get_nested_attr(obj, field):
"""
Helper function to recursively access attributes and dictionary keys.
"""
for attr in field.split("."):
try:
obj = getattr(obj, attr)
except AttributeError:
obj = obj[attr]
return obj
9 changes: 9 additions & 0 deletions test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -1199,3 +1199,12 @@ def model():
samps = mcmc.get_samples()

assert all([site[3:] not in samps for site in remove_sites])


def test_extra_fields_include_unconstrained_samples():
def model():
numpyro.sample("x", dist.HalfNormal(1))

mcmc = MCMC(NUTS(model), num_warmup=10, num_samples=10)
mcmc.run(random.PRNGKey(0), extra_fields=("z.x",))
assert_allclose(mcmc.get_samples()["x"], jnp.exp(mcmc.get_extra_fields()["z.x"]))
Loading