Amortizing overhead in .run #1249

Open
bkj opened this issue Dec 10, 2021 · 1 comment
Labels: enhancement (New feature or request), help wanted (Extra attention is needed), jax (This issue is specific to JAX)

Comments

bkj commented Dec 10, 2021

Hello --

New to numpyro, but really excited to be looking at it!

Question: When I call something like

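```python
from numpyro.infer import MCMC, NUTS

# model, rng_key_ and data are defined elsewhere in my program
nuts = NUTS(model)
mcmc = MCMC(nuts, num_warmup=500, num_samples=2000)
mcmc.run(rng_key_, data)
```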

there's a startup overhead, where the progress bar appears but doesn't move for ~5 seconds and then moves very fast.

I'm guessing this is compilation overhead? If so, is there a way to cache the compiled program so that I don't have to pay that overhead every time I run my program?

Thanks!
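A minimal way to confirm that the stall is compilation, sketched in plain JAX rather than NumPyro (`run_chain` is a hypothetical stand-in for a sampler loop, not NumPyro code): the first call to a `jax.jit`-wrapped function traces and compiles it, and later calls with the same input shapes and dtypes reuse the compiled executable kept in memory. That in-memory cache disappears when the Python process exits, which is why every fresh run of a program pays the cost again.

```python
import time

import jax
import jax.numpy as jnp


def run_chain(x0):
    # Hypothetical stand-in for a sampler loop: 1000 cheap "steps".
    def body(i, x):
        return 0.5 * x + jnp.sin(x)

    return jax.lax.fori_loop(0, 1000, body, x0)


run_chain_jit = jax.jit(run_chain)
x0 = jnp.ones(3)

t0 = time.perf_counter()
first = run_chain_jit(x0).block_until_ready()   # trace + XLA compile + run
t1 = time.perf_counter()
second = run_chain_jit(x0).block_until_ready()  # reuses the in-memory executable
t2 = time.perf_counter()

print(f"first call:  {t1 - t0:.3f}s (includes compilation)")
print(f"second call: {t2 - t1:.3f}s")
```

The gap between the two timings is roughly the one-off compile cost; the exact numbers depend on the machine and JAX version.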

fehiepsi (Member)

Good timing, @bkj! This was not possible in the past, but it now seems to be possible with the experimental compilation cache submodule added in recent JAX releases. I haven't tested it yet (so I'm not sure how robust it is), but I'm happy to help if someone wants to tackle this issue, which would have a huge impact. I'm excited to see this feature available in NumPyro.
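A sketch of turning on JAX's persistent (on-disk) compilation cache, assuming a recent JAX release that exposes it through the `jax_compilation_cache_dir` and `jax_persistent_cache_min_compile_time_secs` config options (around Dec 2021 the feature lived in `jax.experimental.compilation_cache` with a different API). The cache directory path and the toy `log_density` target are hypothetical. Whether entries are actually written depends on the JAX version and backend, and whether this shortens the stall before `MCMC.run` starts sampling has not been verified here.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Assumption: a fixed directory so that later Python processes can find the
# same cache entries. Set these options before anything is compiled.
cache_dir = os.path.join(tempfile.gettempdir(), "jax_compilation_cache")
jax.config.update("jax_compilation_cache_dir", cache_dir)
# Cache every compilation, even fast ones (by default short compiles are skipped).
jax.config.update("jax_persistent_cache_min_compile_time_secs", 0)


@jax.jit
def log_density(theta):
    # Hypothetical toy target: standard normal log-density, up to a constant.
    return -0.5 * jnp.sum(theta ** 2)


value = log_density(jnp.arange(3.0))
print(value)
```

The same two `jax.config.update` calls placed at the top of a NumPyro script, before building `NUTS`/`MCMC`, might let later runs reuse compiled programs, provided the model and input shapes stay the same.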

Currently, we have a `_compile` method that separates the compiling job from the sampling job. But it only works within a single program (i.e. we can't compile to a file and then load it). For this feature we could add new methods, named e.g. `compile_to_file` and `get_compilation(...)`. We would then pass some information to the `fori_collect` method, where we would compile to a file in this branch and load the compilation in this branch (rather than using `jit` there).
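For comparison, JAX's ahead-of-time API gives a similar compile/run split in plain JAX: `jax.jit(f).lower(...)` accepts abstract `jax.ShapeDtypeStruct` inputs, so compilation can happen before any data exists, and `.compile()` returns an executable that is then called directly. This is a sketch, not NumPyro's `_compile` or `fori_collect`; `sample_loop` is a hypothetical stand-in for the loop body. The compiled object still lives in process memory, so on its own it has the same single-program limitation; writing it to a file, as the proposed `compile_to_file` would, is not shown.

```python
import jax
import jax.numpy as jnp


def sample_loop(x0):
    # Hypothetical stand-in for the loop that fori_collect jit-compiles.
    def body(i, x):
        return x + 0.01 * jnp.cos(x)

    return jax.lax.fori_loop(0, 500, body, x0)


# Step 1: compile ahead of time from shapes/dtypes only (no sampling yet).
abstract_x0 = jax.ShapeDtypeStruct((4,), jnp.float32)
compiled = jax.jit(sample_loop).lower(abstract_x0).compile()

# Step 2: run the already-compiled executable; no tracing or compiling here.
x0 = jnp.zeros(4, dtype=jnp.float32)
result = compiled(x0)
print(result)
```

Splitting the steps this way makes the compile cost show up in step 1 instead of as a frozen progress bar, which is roughly what `_compile` does for `MCMC` today.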

fehiepsi added the enhancement, help wanted, and jax labels on Dec 10, 2021.