From 7920c131d22c851b0700a91c943663df4521b91d Mon Sep 17 00:00:00 2001 From: Jeremy Baier Date: Mon, 21 Oct 2024 03:31:34 +0000 Subject: [PATCH] bug fixes --- src/pint_pal/noise_utils.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/src/pint_pal/noise_utils.py b/src/pint_pal/noise_utils.py index 56f4584..178b4e8 100644 --- a/src/pint_pal/noise_utils.py +++ b/src/pint_pal/noise_utils.py @@ -138,7 +138,7 @@ def analyze_noise( noise_dict: Dictionary of maximum a posterior noise values rn_bf: Savage-Dickey BF for achromatic RN for given pulsar """ - if which_sampler == 'PTMCMCSampler': + if which_sampler == 'PTMCMCSampler' or which_sampler == 'discovery': try: noise_core = co.Core(chaindir=chaindir) except: @@ -339,7 +339,7 @@ def model_noise( run_noise_analysis=True, wb_efac_sigma=0.25, base_op_dir="./", - noise_kwargs={}, + model_kwargs={}, sampler_kwargs={}, return_sampler=False, ): @@ -370,8 +370,12 @@ def model_noise( # get the default settings model_defaults, sampler_defaults = get_model_and_sampler_default_settings() # update with args passed in - model_kwargs = model_defaults.update(noise_kwargs) - sampler_kwargs = sampler_defaults.update(sampler_kwargs) + model_defaults.update(model_kwargs) + sampler_defaults.update(sampler_kwargs) + model_kwargs = model_defaults.copy() + sampler_kwargs = sampler_defaults.copy() + + if not using_wideband: outdir = base_op_dir + mo.PSR.value + "_nb/" @@ -512,7 +516,7 @@ def model_noise( except ImportError: log.error("Please install latest version of jax and/or xarray") ValueError("Please install lastest version of jax and/or xarray") - os.mkdir(outdir, parents=True, exist_ok=True) + os.makedirs(outdir, exist_ok=True) with open(outdir+"model_kwargs.json", "w") as f: json.dump(model_kwargs, f) with open(outdir+"sampler_kwargs.json", "w") as f: @@ -524,7 +528,7 @@ def model_noise( df = log_x.to_df(samp.get_samples()['par']) # convert DataFrame to dictionary samples_dict = df.to_dict(orient='list') - if sampler_kwargs['numpyro_sampler'] is not 'HMC_GIBBS': + if sampler_kwargs['numpyro_sampler'] != 'HMC_GIBBS': ln_like = log_likelihood(numpyro_model, samp.get_samples())['ll'] ln_prior = dist.Normal(0, 10).log_prob(samp.get_samples()['par']).sum(axis=-1) ln_post = ln_like + ln_prior @@ -562,6 +566,7 @@ def add_noise_to_model( rn_bf_thres=1e2, base_dir=None, compare_dir=None, + which_sampler='PTMCMCSampler' ): """ Add WN, RN, DMGP, and parameters to timing model. @@ -599,12 +604,13 @@ def add_noise_to_model( log.info(f"Using existing noise analysis results in {chaindir}") log.info("Adding new noise parameters to model.") - noise_dict, rn_bf = analyze_noise( + noise_core, noise_dict, rn_bf = analyze_noise( chaindir, burn_frac, save_corner, no_corner_plot, chaindir_compare=chaindir_compare, + which_sampler=which_sampler, ) chainfile = chaindir + "chain_1.txt" mtime = Time(os.path.getmtime(chainfile), format="unix") @@ -988,7 +994,7 @@ def setup_discovery_noise(psr, model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp')) elif model_kwargs['dmgp_psd'] == 'free_spectral': model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp')) - if model_kwargs['inc_chrom']: + if model_kwargs['inc_chromgp']: if model_kwargs['rn_psd'] == 'powerlaw': model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['chromgp_nfreqs'], T=time_span, name='dm_gp')) elif model_kwargs['rn_psd'] == 'free_spectral': @@ -1006,7 +1012,7 @@ def numpyro_model(): invhdorf=None, nuts_kwargs={}) sampler = infer.MCMC(gibbs_hmc_kernel, num_warmup=sampler_kwargs['num_warmup'], - num_samples=sampler_kwargs['num_warmup'], + num_samples=sampler_kwargs['num_samples'], num_chains=sampler_kwargs['num_chains'], chain_method=sampler_kwargs['chain_method'], progress_bar=True, @@ -1018,7 +1024,7 @@ def numpyro_model(): nuts_kernel = infer.NUTS(numpyro_model, num_steps=sampler_kwargs['num_steps']) sampler = infer.MCMC(nuts_kernel, num_warmup=sampler_kwargs['num_warmup'], - num_samples=sampler_kwargs['num_warmup'], + num_samples=sampler_kwargs['num_samples'], num_chains=sampler_kwargs['num_chains'], chain_method=sampler_kwargs['chain_method'], progress_bar=True,