Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Jeremy Baier committed Oct 21, 2024
1 parent f0aa78d commit 7920c13
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions src/pint_pal/noise_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def analyze_noise(
noise_dict: Dictionary of maximum a posterior noise values
rn_bf: Savage-Dickey BF for achromatic RN for given pulsar
"""
if which_sampler == 'PTMCMCSampler':
if which_sampler == 'PTMCMCSampler' or which_sampler == 'discovery':
try:
noise_core = co.Core(chaindir=chaindir)
except:
Expand Down Expand Up @@ -339,7 +339,7 @@ def model_noise(
run_noise_analysis=True,
wb_efac_sigma=0.25,
base_op_dir="./",
noise_kwargs={},
model_kwargs={},
sampler_kwargs={},
return_sampler=False,
):
Expand Down Expand Up @@ -370,8 +370,12 @@ def model_noise(
# get the default settings
model_defaults, sampler_defaults = get_model_and_sampler_default_settings()
# update with args passed in
model_kwargs = model_defaults.update(noise_kwargs)
sampler_kwargs = sampler_defaults.update(sampler_kwargs)
model_defaults.update(model_kwargs)
sampler_defaults.update(sampler_kwargs)
model_kwargs = model_defaults.copy()
sampler_kwargs = sampler_defaults.copy()



if not using_wideband:
outdir = base_op_dir + mo.PSR.value + "_nb/"
Expand Down Expand Up @@ -512,7 +516,7 @@ def model_noise(
except ImportError:
log.error("Please install latest version of jax and/or xarray")
ValueError("Please install lastest version of jax and/or xarray")
os.mkdir(outdir, parents=True, exist_ok=True)
os.makedirs(outdir, exist_ok=True)
with open(outdir+"model_kwargs.json", "w") as f:
json.dump(model_kwargs, f)
with open(outdir+"sampler_kwargs.json", "w") as f:
Expand All @@ -524,7 +528,7 @@ def model_noise(
df = log_x.to_df(samp.get_samples()['par'])
# convert DataFrame to dictionary
samples_dict = df.to_dict(orient='list')
if sampler_kwargs['numpyro_sampler'] is not 'HMC_GIBBS':
if sampler_kwargs['numpyro_sampler'] != 'HMC_GIBBS':
ln_like = log_likelihood(numpyro_model, samp.get_samples())['ll']
ln_prior = dist.Normal(0, 10).log_prob(samp.get_samples()['par']).sum(axis=-1)
ln_post = ln_like + ln_prior
Expand Down Expand Up @@ -562,6 +566,7 @@ def add_noise_to_model(
rn_bf_thres=1e2,
base_dir=None,
compare_dir=None,
which_sampler='PTMCMCSampler'
):
"""
Add WN, RN, DMGP, and parameters to timing model.
Expand Down Expand Up @@ -599,12 +604,13 @@ def add_noise_to_model(

log.info(f"Using existing noise analysis results in {chaindir}")
log.info("Adding new noise parameters to model.")
noise_dict, rn_bf = analyze_noise(
noise_core, noise_dict, rn_bf = analyze_noise(
chaindir,
burn_frac,
save_corner,
no_corner_plot,
chaindir_compare=chaindir_compare,
which_sampler=which_sampler,
)
chainfile = chaindir + "chain_1.txt"
mtime = Time(os.path.getmtime(chainfile), format="unix")
Expand Down Expand Up @@ -988,7 +994,7 @@ def setup_discovery_noise(psr,
model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp'))
elif model_kwargs['dmgp_psd'] == 'free_spectral':
model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp'))
if model_kwargs['inc_chrom']:
if model_kwargs['inc_chromgp']:
if model_kwargs['rn_psd'] == 'powerlaw':
model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['chromgp_nfreqs'], T=time_span, name='dm_gp'))
elif model_kwargs['rn_psd'] == 'free_spectral':
Expand All @@ -1006,7 +1012,7 @@ def numpyro_model():
invhdorf=None, nuts_kwargs={})
sampler = infer.MCMC(gibbs_hmc_kernel,
num_warmup=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_samples'],
num_chains=sampler_kwargs['num_chains'],
chain_method=sampler_kwargs['chain_method'],
progress_bar=True,
Expand All @@ -1018,7 +1024,7 @@ def numpyro_model():
nuts_kernel = infer.NUTS(numpyro_model, num_steps=sampler_kwargs['num_steps'])
sampler = infer.MCMC(nuts_kernel,
num_warmup=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_warmup'],
num_samples=sampler_kwargs['num_samples'],
num_chains=sampler_kwargs['num_chains'],
chain_method=sampler_kwargs['chain_method'],
progress_bar=True,
Expand Down

0 comments on commit 7920c13

Please sign in to comment.