Skip to content

Commit

Permalink
more bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy Baier committed Oct 21, 2024
1 parent 7920c13 commit 7a0b140
Showing 1 changed file with 34 additions and 30 deletions.
64 changes: 34 additions & 30 deletions src/pint_pal/noise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@
from enterprise_extensions import model_utils
from enterprise_extensions import deterministic
from enterprise_extensions.timing import timing_block
try:
import xarray as xr
import jax
from jax import numpy as jnp
import numpyro
from numpyro.infer import log_likelihood
from numpyro import distributions as dist
from numpyro import infer
import discovery as ds
from discovery import prior as ds_prior
from discovery.prior import (makelogtransform_uniform,
makelogprior_uniform,
sample_uniform)
from discovery.gibbs import setup_single_psr_hmc_gibbs

except ImportError:
log.error("Please install the latest version of discovery, numpyro, and/or jax")
ValueError("Please install the latest version of discovery, numpyro, and/or jax")

# from enterprise_extensions.blocks import (white_noise_block, red_noise_block)

Expand Down Expand Up @@ -507,15 +525,6 @@ def model_noise(
#################################################################
elif which_sampler == "discovery":
log.info(f"INFO: Running noise analysis with {which_sampler} for {e_psr.name}")
try:
import jax
import xarray as xr
from numpyro import distributions as dist
from numpyro.infer import log_likelihood

except ImportError:
log.error("Please install latest version of jax and/or xarray")
ValueError("Please install lastest version of jax and/or xarray")
os.makedirs(outdir, exist_ok=True)
with open(outdir+"model_kwargs.json", "w") as f:
json.dump(model_kwargs, f)
Expand Down Expand Up @@ -585,7 +594,7 @@ def add_noise_to_model(
Returns
=======
model: New timing model which includes WN and RN parameters
model: New timing model which includes WN and RN (and potentially dmgp, chrom_gp, and solar wind) parameters
"""

# Assume results are in current working directory if not specified
Expand Down Expand Up @@ -959,25 +968,10 @@ def setup_discovery_noise(psr,
"""
Setup the discovery likelihood with numpyro sampling for noise analysis
"""
# check that jax, numpyro and discovery are installed
try:
import discovery as ds
import jax
from jax import numpy as jnp
import numpyro
from numpyro import distributions as dist
from numpyro import infer
from discovery import prior
from discovery.prior import (makelogtransform_uniform,
makelogprior_uniform,
sample_uniform)
from discovery.gibbs import setup_single_psr_hmc_gibbs

except ImportError:
log.error("Please install the latest version of discovery, numpyro, and/or jax")
ValueError("Please install the latest version of discovery, numpyro, and/or jax")
# set up the model
time_span = ds.getspan([psr])
# this updates the ds.stand_priordict object
ds.priordict_standard.update(prior_dictionary_updates())
model_components = [
psr.residuals,
ds.makegp_timing(psr, svd=True),
Expand All @@ -1000,7 +994,7 @@ def setup_discovery_noise(psr,
elif model_kwargs['rn_psd'] == 'free_spectral':
model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['chromgp_nfreqs'], T=time_span, name='dm_gp'))
psl = ds.PulsarLikelihood(model_components)
prior = prior.makelogprior_uniform(psl.logL.params, {})
prior = ds_prior.makelogprior_uniform(psl.logL.params, ds.priordict_standard)
log_x = makelogtransform_uniform(psl.logL)
# x0 = sample_uniform(psl.logL.params)
if sampler_kwargs['numpyro_sampler'] == 'HMC_Gibbs':
Expand All @@ -1021,7 +1015,8 @@ def numpyro_model():
def numpyro_model():
params = jnp.array(numpyro.sample("par", dist.Normal(0,10).expand([len(log_x.params)])))
numpyro.factor("ll", log_x(params))
nuts_kernel = infer.NUTS(numpyro_model, num_steps=sampler_kwargs['num_steps'])
nuts_kernel = infer.NUTS(numpyro_model, max_tree_depth=5, dense_mass=True,
forward_mode_differentiation=False, target_accept_prob=0.99)
sampler = infer.MCMC(nuts_kernel,
num_warmup=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_samples'],
Expand All @@ -1036,7 +1031,7 @@ def numpyro_model():
hmc_kernel = infer.HMC(numpyro_model, num_steps=sampler_kwargs['num_steps'])
sampler = infer.MCMC(hmc_kernel,
num_warmup=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_samples'],
num_chains=sampler_kwargs['num_chains'],
chain_method=sampler_kwargs['chain_method'],
progress_bar=True,
Expand Down Expand Up @@ -1071,6 +1066,15 @@ def test_equad_convention(pars_list):
)
return None


def prior_dictionary_updates():
return {
'(.*_)?dm_gp_log10_A': [-20, -11],
'(.*_)?dm_gp_gamma': [0, 7],
'(.*_)?chrom_gp_log10_A': [-20, -11],
'(.*_)?chrom_gp_gamma': [0, 7],
}

def get_model_and_sampler_default_settings():
model_defaults = {
# acrhomatic red noise
Expand Down

0 comments on commit 7a0b140

Please sign in to comment.