From 8fd746bd1580eee29b326bb825e52fa2cb0ff6a3 Mon Sep 17 00:00:00 2001 From: Jeremy Baier Date: Sat, 19 Oct 2024 15:28:10 -0700 Subject: [PATCH] commiting what i have rn --- src/pint_pal/noise_utils.py | 407 ++++++++++++++++++++++++++++-------- 1 file changed, 317 insertions(+), 90 deletions(-) diff --git a/src/pint_pal/noise_utils.py b/src/pint_pal/noise_utils.py index 112e162..769560c 100644 --- a/src/pint_pal/noise_utils.py +++ b/src/pint_pal/noise_utils.py @@ -1,4 +1,4 @@ -import numpy as np, os +import numpy as np, os, json import arviz as az from astropy import log from astropy.time import Time @@ -29,6 +29,7 @@ from enterprise.signals import deterministic_signals from enterprise import constants as const +from enterprise_extensions.sampler import group_from_params, get_parameter_groups from enterprise_extensions import model_utils from enterprise_extensions import deterministic from enterprise_extensions.timing import timing_block @@ -41,12 +42,82 @@ from enterprise.signals import gp_priors as gpp +def setup_sampling_groups(pta, + write_groups=True, + outdir='./'): + """ + Sets sampling groups for PTMCMCSampler. + The sampling groups can help ensure the sampler does not get stuck. + The idea is to group parameters which are more highly correlated. + + Params + ------ + pta: the enterprise pta object + write_groups: bool, write the groups to a file + outdir: str, directory to write the groups to + + returns + ------- + groups: list of lists of indices corresponding to parameter groups + + """ + + # groups + pnames = pta.param_names + groups = get_parameter_groups(pta) + # add per-backend white noise + backends = np.unique([p[p.index('_')+1:p.index('efac')-1] for p in pnames if 'efac' in p]) + for be in backends: + groups.append(group_from_params(pta,[be])) + # group red noise parameters + exclude = ['linear_timing_model','sw_r2','sw_4p39','measurement_noise', + 'ecorr_sherman-morrison', 'ecorr_fast-sherman-morrison'] + red_signals = [p[p.index('_')+1:] for p in list(pta.signals.keys()) + if not p[p.index('_')+1:] in exclude] + rn_ct = 0 + for rs in red_signals: + if len(group_from_params(pta,[rs])) > 0: + rn_ct += 1 + groups.append(group_from_params(pta,[rs])) + if rn_ct > 1: + groups.append(group_from_params(pta,red_signals)) + # add cross chromatic groups + if 'n_earth' in pnames or 'log10_sigma_ne' in pnames: + # cross SW and chrom groups + dmgp_sw = [idx for idx, nm in enumerate(pnames) + if any([flag in nm for flag in ['dm_gp','n_earth', 'log10_sigma_ne']])] + groups.append(dmgp_sw) + if np.any(['chrom' in param for param in pnames]): + chromgp_sw = [idx for idx, nm in enumerate(pnames) + if any([flag in nm for flag in ['chrom_gp','n_earth', 'log10_sigma_ne']])] + dmgp_chromgp_sw = [idx for idx, nm in enumerate(pnames) + if any([flag in nm for flag in ['dm_gp','chrom','n_earth', 'log10_sigma_ne']])] + groups.append(chromgp_sw) + groups.append(dmgp_chromgp_sw) + if np.any(['chrom' in param for param in pnames]): + # cross dmgp and chromgp group + dmgp_chromgp = [idx for idx, nm in enumerate(pnames) + if any([flag in nm for flag in ['dm_gp','chrom']])] + groups.append(dmgp_chromgp) + # everything + groups.append([i for i in range(len(pnames))]) + # save list of params corresponding to groups + if write_groups is True: + with open(f'{outdir}/groups.txt', 'w') as fi: + for group in groups: + line = np.array(pnames)[np.array(group)] + fi.write("[" + " ".join(line) + "]\n") + # return the groups to be passed to the sampler + return groups + + def analyze_noise( chaindir="./noise_run_chains/", burn_frac=0.25, save_corner=True, no_corner_plot=False, chaindir_compare=None, + which_sampler = 'PTMCMCSampler', ): """ Reads enterprise chain file; produces and saves corner plot; returns WN dictionary and RN (SD) BF @@ -56,38 +127,41 @@ def analyze_noise( chaindir: path to enterprise noise run chain; Default: './noise_run_chains/' burn_frac: fraction of chain to use for burn-in; Default: 0.25 save_corner: Flag to toggle saving of corner plots; Default: True - chaindir_compare: path to enterprise noise run chain wish to plot in corner plot for comparison; Default: None + no_corner_plot: Flag to toggle saving of corner plots; Default: False + chaindir_compare: path to noise run chain wish to plot in corner plot for comparison; Default: None + which_sampler: choose from ['PTMCMCSampler' or 'GibbsSampler' or 'discovery'] Returns ======= - wn_dict: Dictionary of maximum likelihood WN values - rn_bf: Savage-Dickey BF for RN for given pulsar + noise_core: la_forge.core object which contains noise chains and run metadata + noise_dict: Dictionary of maximum a posterior noise values + rn_bf: Savage-Dickey BF for achromatic RN for given pulsar """ - #### replacing this with la_forge to be more flexible - # chainfile = chaindir + "chain_1.txt" - # chain = np.loadtxt(chainfile) - # burn = int(burn_frac * chain.shape[0]) - # pars = np.loadtxt(chaindir + "pars.txt", dtype=str) try: noise_core = co.Core(chaindir=chaindir) except: log.error(f"Could not load noise run from {chaindir}") - return None - noise_core.set_burn(burn_frac) + ValueError(f"Could not load noise run from {chaindir}") + if which_sampler == 'PTMCMCSampler': + noise_core.set_burn(burn_frac) + elif which_sampler == 'discovery': + noise_core.set_burn(0) + else: + noise_core.set_burn(burn_frac) chain = noise_core.chain psr_name = noise_core.params[0].split("_")[0] - pars = np.array(noise_core.params) - if chain.shape[1] != len(pars): - a = -4 - elif chain.shape[1] == len(pars): - a = len(chain.shape[1]) + pars = np.array([p for p in noise_core.params if p not in ['lnlike', 'lnpost']]) + if len(pars)+2 != chain.shape[1]: + chain = chain[:, :len(pars)+2] # load in same for comparison noise model if chaindir_compare is not None: - chainfile_compare = chaindir_compare + "chain_1.txt" - chain_compare = np.loadtxt(chainfile_compare) - burn_compare = int(burn_frac * chain_compare.shape[0]) - pars_compare = np.loadtxt(chaindir_compare + "pars.txt", dtype=str) + compare_core = co.Core(chaindir=chaindir) + compare_core.set_burn(noise_core.burn) + chain_compare = compare_core.chain + pars_compare = np.array([p for p in compare_core.params if p not in ['lnlike', 'lnpost']]) + if len(pars_compare)+2 != chain_compare.shape[1]: + chain_compare = chain_compare[:, :len(pars_compare)+2] psr_name_compare = pars_compare[0].split("_")[0] if psr_name_compare != psr_name: @@ -105,22 +179,22 @@ def analyze_noise( compare_pars_short = [p.split("_", 1)[1] for p in pars_compare] log.info(f"Comparison chain parameter names are {compare_pars_short}") log.info( - f"Comparison chain parameter convention: {test_equad_convention(compare_pars_short)}" + f"Comparison chain parameter convention: {test_equad_convention(compare_pars_short)}" ) # don't plot comparison if the parameter names don't match if compare_pars_short != pars_short: log.warning( - "Parameter names for comparison noise chains do not match, not plotting the compare-noise-dir chains" + "Parameter names for comparison noise chains do not match, not plotting the compare-noise-dir chains" ) chaindir_compare = None else: normalization_factor = ( - np.ones(len(chain_compare[:, :a])) - * len(chain[:, :a]) - / len(chain_compare[:, :a]) + np.ones(len(chain_compare)) + * len(chain) + / len(chain_compare) ) fig = corner.corner( - chain_compare[:, :a], + chain_compare, color="orange", alpha=0.5, weights=normalization_factor, @@ -128,10 +202,10 @@ def analyze_noise( ) # normal corner plot corner.corner( - chain[:, :a], fig=fig, color="black", labels=pars_short + chain, fig=fig, color="black", labels=pars_short ) if chaindir_compare is None: - corner.corner(chain[:, :a], labels=pars_short) + corner.corner(chain, labels=pars_short) if "_wb" in chaindir: figname = f"./{psr_name}_noise_corner_wb.pdf" @@ -174,9 +248,9 @@ def analyze_noise( chaindir_compare = None else: normalization_factor = ( - np.ones(len(chain_compare[:, :a])) - * len(chain[:, :a]) - / len(chain_compare[:, :a]) + np.ones(len(chain_compare)) + * len(chain) + / len(chain_compare) ) # Set the shape of the subplots @@ -189,9 +263,10 @@ def analyze_noise( nrows = 5 # number of rows per page - mp_idx = np.argmax(chain[:, a]) + mp_idx = noise_core.map_idx + #mp_idx = np.argmax(chain[:, a]) if chaindir_compare is not None: - mp_compare_idx = np.argmax(chain_compare[:, a]) + mp_compare_idx = compare_core.map_idx nbins = 20 pp = 0 @@ -235,17 +310,14 @@ def analyze_noise( # Wasn't working before, but how do I implement a legend? # ax[nr][nc].legend(loc = 'best') pl.show() - - ml_idx = np.argmax(chain[:, a]) - - wn_vals = chain[:, :a][ml_idx] - - wn_dict = dict(zip(pars, wn_vals)) + + noise_dict = noise_core.get_map_dict() # Print bayes factor for red noise in pulsar - rn_bf = model_utils.bayes_fac(chain[:, -5], ntol=1, logAmax=-11, logAmin=-20)[0] + rn_amp_nm = psr_name+"_red_noise_log10_A" + rn_bf = model_utils.bayes_fac(noise_core(rn_amp_nm), ntol=1, logAmax=-11, logAmin=-20)[0] - return wn_dict, rn_bf + return noise_core, noise_dict, rn_bf def model_noise( @@ -261,6 +333,7 @@ def model_noise( base_op_dir="./", noise_kwargs={}, sampler_kwargs={}, + return_sampler=False, ): """ Setup enterprise PTA and perform MCMC noise analysis @@ -272,17 +345,19 @@ def model_noise( sampler: choose from ['PTMCMCSampler' or 'GibbsSampler' or 'discovery'] PTMCMCSampler -- MCMC sampling with the Enterprise likelihood GibbsSampler -- enterprise_extension's GibbsSampler with PTMCMC and Enterprise white noise - discovery -- blocked Gibbs-Hamiltonian MC in numpyro with a discovery likelihood + discovery -- various numpyro samplers with a discovery likelihood red_noise: include red noise in the model n_iter: number of MCMC iterations; Default: 1e5; Recommended > 5e4 using_wideband: Flag to toggle between narrowband and wideband datasets; Default: False run_noise_analysis: Flag to toggle execution of noise modeling; Default: True noise_kwargs: dictionary of noise model parameters; Default: {} sampler_kwargs: dictionary of sampler parameters; Default: {} + return_sampler: Flag to return the sampler object; Default: False Returns ======= - None + None or + samp: sampler object """ if not using_wideband: @@ -318,11 +393,7 @@ def model_noise( ) # Create enterprise Pulsar object for supplied pulsar timing model (mo) and toas (to) - if which_sampler == "discovery": - # discovery requires feathered pulsars - f_psr = Pulsar(mo, to) - elif which_sampler == "GibbsSampler" or which_sampler == "PTMCMCSampler": - e_psr = Pulsar(mo, to) + e_psr = Pulsar(mo, to) if which_sampler == "PTMCMCSampler": log.info(f"INFO: Running noise analysis with {which_sampler} for {e_psr.name}") @@ -359,11 +430,14 @@ def model_noise( ) dmjump_params[dmjump_param_name] = dmjump_param.value pta.set_default_params(dmjump_params) - # FIXME: set groups here + # set groups here + groups = setup_sampling_groups(pta, write_groups=True, outdir=outdir) ####### # setup sampler using enterprise_extensions - samp = sampler.setup_sampler(pta, outdir=outdir, resume=resume) - + samp = sampler.setup_sampler(pta, + outdir=outdir, + resume=resume, + groups=groups) # Initial sample x0 = np.hstack([p.sample() for p in pta.params]) # Start sampling @@ -371,6 +445,11 @@ def model_noise( x0, n_iter, SCAMweight=30, AMweight=15, DEweight=50, **sampler_kwargs ) elif which_sampler == "GibbsSampler": + try: + from enterprise_extensions import GibbsSampler + except: + log.error("Please install the latest enterprise_extensions") + ValueError("Please install the latest enterprise_extensions") log.info(f"INFO: Running noise analysis with {which_sampler} for {e_psr.name}") samp = GibbsSampler( e_psr, @@ -386,22 +465,42 @@ def model_noise( except ImportError: log.error("Please install latest version of jax and/or xarray") ValueError("Please install lastest version of jax and/or xarray") - samp, log_x = setup_discovery_noise(f_psr) + # get the default settings + model_defaults, sampler_defaults = get_model_and_sampler_default_settings() + # update with args passed in + model_kwargs = model_defaults.update(noise_kwargs) + sampler_kwargs = sampler_defaults.update(sampler_kwargs) + os.mkdir(outdir, parents=True, exist_ok=True) + with open(outdir+"model_kwargs.json", "w") as f: + json.dump(model_kwargs, f) + with open(outdir+"sampler_kwargs.json", "w") as f: + json.dump(sampler_kwargs, f) + samp, log_x, numpyro_model = setup_discovery_noise(e_psr, model_kwargs, sampler_kwargs) # run the sampler samp.run(jax.random.key(42)) # convert to a DataFrame df = log_x.to_df(samp.get_samples()['par']) # convert DataFrame to dictionary samples_dict = df.to_dict(orient='list') + if sampler_kwargs['numpyro_sampler'] is not 'HMC_GIBBS': + ln_like = log_likelihood(numpyro_model, samp.get_samples())['ll'] + ln_prior = dist.Normal(0, 10).log_prob(samp.get_samples()['par']).sum(axis=-1) + ln_post = ln_like + ln_prior + samples_dict['lnlike'] = ln_like + samples_dict['lnpost'] = ln_post + else: + samples_dict['lnlike'] = None + samples_dict['lnpost'] = None # convert dictionary to ArviZ InferenceData object inference_data = az.from_dict(samples_dict) # Save to NetCDF file which can be loaded into la_forge - os.mkdir(outdir, parents=True, exist_ok=True) inference_data.to_netcdf(outdir+"chain.nc") else: log.error( "Invalid sampler specified. Please use 'PTMCMCSampler' or 'GibbsSampler' or 'discovery' " ) + if return_sampler: + return samp def convert_to_RNAMP(value): @@ -458,7 +557,7 @@ def add_noise_to_model( log.info(f"Using existing noise analysis results in {chaindir}") log.info("Adding new noise parameters to model.") - wn_dict, rn_bf = analyze_noise( + noise_dict, rn_bf = analyze_noise( chaindir, burn_frac, save_corner, @@ -472,9 +571,6 @@ def add_noise_to_model( # Create the maskParameter for EFACS efac_params = [] equad_params = [] - rn_params = [] - dm_gp_params = [] - chrom_gp_params = [] ecorr_params = [] dmefac_params = [] dmequad_params = [] @@ -485,7 +581,7 @@ def add_noise_to_model( dmefac_idx = 1 dmequad_idx = 1 - for key, val in wn_dict.items(): + for key, val in noise_dict.items(): psr_name = key.split("_")[0] @@ -617,14 +713,14 @@ def add_noise_to_model( # Test EQUAD convention and decide whether to convert convert_equad_to_t2 = False - if test_equad_convention(wn_dict.keys()) == "tnequad": + if test_equad_convention(noise_dict.keys()) == "tnequad": log.info( "WN paramaters use temponest convention; EQUAD values will be converted once added to model" ) convert_equad_to_t2 = True - if np.any(["_equad" in p for p in wn_dict.keys()]): + if np.any(["_equad" in p for p in noise_dict.keys()]): log.info("WN parameters generated using enterprise pre-v3.3.0") - elif test_equad_convention(wn_dict.keys()) == "t2equad": + elif test_equad_convention(noise_dict.keys()) == "t2equad": log.info("WN parameters use T2 convention; no conversion necessary") # Create white noise components and add them to the model @@ -662,11 +758,75 @@ def add_noise_to_model( # Add the ML RN parameters to their component rn_comp = pm.PLRedNoise() - rn_keys = np.array([key for key, val in wn_dict.items() if "_red_" in key]) + rn_keys = np.array([key for key, val in noise_dict.items() if "_red_" in key]) rn_comp.RNAMP.quantity = convert_to_RNAMP( - wn_dict[psr_name + "_red_noise_log10_A"] + noise_dict[psr_name + "_red_noise_log10_A"] ) - rn_comp.RNIDX.quantity = -1 * wn_dict[psr_name + "_red_noise_gamma"] + rn_comp.RNIDX.quantity = -1 * noise_dict[psr_name + "_red_noise_gamma"] + + # Add red noise to the timing model + model.add_component(rn_comp, validate=True, force=True) + else: + log.info("Not including red noise for this pulsar") + + # Check to see if dm noise is present + dm_pars = [key for key in list(noise_dict.keys()) if "_dm_gp" in key] + if len(dm_pars) > 0: + ###### POWERLAW DM NOISE ###### + if f'{psr_name}_dm_gp_log10_A' in dm_pars: + #dm_bf = model_utils.bayes_fac(noise_core(rn_amp_nm), ntol=1, logAmax=-11, logAmin=-20)[0] + #log.info(f"The SD Bayes factor for dm noise in this pulsar is: {dm_bf}") + log.info('Adding Powerlaw DM GP noise as PLDMNoise to par file') + # Add the ML RN parameters to their component + dm_comp = pm.PLDMNoise() + dm_keys = np.array([key for key, val in noise_dict.items() if "_red_" in key]) + dm_comp.TNDMAMP.quantity = convert_to_RNAMP( + noise_dict[psr_name + "_dm_gp_log10_A"] + ) + dm_comp.TNDMIDX.quantity = -1 * noise_dict[psr_name + "_dm_gp_gamma"] + ##### FIXMEEEEEEE : need to figure out some way to softcode this + dm_comp.TNDMC.quantitity = 100 + # Add red noise to the timing model + model.add_component(dm_comp, validate=True, force=True) + ###### FREE SPECTRAL (WaveX) DM NOISE ###### + elif f'{psr_name}_dm_gp_log10_rho_0' in dm_pars: + log.info('Adding Free Spectral DM GP as DMWaveXnoise to par file') + NotImplementedError('DMWaveXNoise not yet implemented') + + # Check to see if higher order chromatic noise is present + chrom_pars = [key for key in list(noise_dict.keys()) if "_chrom_gp" in key] + if len(chrom_pars) > 0: + ###### POWERLAW CHROMATIC NOISE ###### + if f'{psr_name}_chrom_gp_log10_A' in chrom_pars: + log.info('Adding Powerlaw CHROM GP noise as PLCMNoise to par file') + # Add the ML RN parameters to their component + chrom_comp = pm.PLCMNoise() + chrom_keys = np.array([key for key, val in noise_dict.items() if "_chrom_gp_" in key]) + dm_comp.TNDMAMP.quantity = convert_to_RNAMP( + noise_dict[psr_name + "_chrom_gp_log10_A"] + ) + chrom_comp.TNCMIDX.quantity = -1 * noise_dict[psr_name + "_dm_gp_gamma"] + ##### FIXMEEEEEEE : need to figure out some way to softcode this + chrom_comp.TNCMC.quantitity = 100 + # Add red noise to the timing model + model.add_component(dm_comp, validate=True, force=True) + ###### FREE SPECTRAL (WaveX) DM NOISE ###### + elif f'{psr_name}_chrom_gp_log10_rho_0' in chrom_pars: + log.info('Adding Free Spectral CHROM GP as CMWaveXnoise to par file') + NotImplementedError('CMWaveXNoise not yet implemented') + + log.info(f"The SD Bayes factor for dm noise in this pulsar is: {rn_bf}") + if (rn_bf >= rn_bf_thres or np.isnan(rn_bf)) and (not ignore_red_noise): + + log.info("Including red noise for this pulsar") + # Add the ML RN parameters to their component + rn_comp = pm.PLRedNoise() + + rn_keys = np.array([key for key, val in noise_dict.items() if "_red_" in key]) + rn_comp.RNAMP.quantity = convert_to_RNAMP( + noise_dict[psr_name + "_red_noise_log10_A"] + ) + rn_comp.RNIDX.quantity = -1 * noise_dict[psr_name + "_red_noise_gamma"] # Add red noise to the timing model model.add_component(rn_comp, validate=True, force=True) @@ -696,15 +856,16 @@ def setup_gibbs_sampler(): except ImportError: log.error("Please install the latest version of enterprise_extensions") return None - - pass + NotImplementedError("Gibbs sampler not yet implemented") -def setup_discovery_noise(psr): +def setup_discovery_noise(psr, + model_kwargs={}, + sampler_kwargs={}): """ - Setup the discovery sampler for noise analysis from enterprise extensions + Setup the discovery likelihood with numpyro sampling for noise analysis """ - # check that a sufficiently up-to-date version of enterprise_extensions is installed + # check that jax, numpyro and discovery are installed try: import discovery as ds import jax @@ -716,38 +877,78 @@ def setup_discovery_noise(psr): from discovery.prior import (makelogtransform_uniform, makelogprior_uniform, sample_uniform) + from discovery.gibbs import setup_single_psr_hmc_gibbs except ImportError: log.error("Please install the latest version of discovery, numpyro, and/or jax") ValueError("Please install the latest version of discovery, numpyro, and/or jax") - + # set up the model time_span = ds.getspan([psr]) - args = ( + model_components = [ + psr.residuals, + ds.makegp_timing(psr, svd=True), ds.makenoise_measurement(psr), ds.makegp_ecorr(psr), - ds.makegp_timing(psr, svd=True), - ds.makegp_fourier(psr, ds.powerlaw, 30, T=time_span, name='red_noise'), - psr.residuals - ) - psl = ds.PulsarLikelihood(args) + ] + if model_kwargs['inc_rn']: + if model_kwargs['rn_psd'] == 'powerlaw': + model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['rn_nfreqs'], T=time_span, name='red_noise')) + elif model_kwargs['rn_psd'] == 'free_spectral': + model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['rn_nfreqs'], T=time_span, name='red_noise')) + if model_kwargs['inc_dmgp']: + if model_kwargs['dmgp_psd'] == 'powerlaw': + model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp')) + elif model_kwargs['dmgp_psd'] == 'free_spectral': + model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['dmgp_nfreqs'], T=time_span, name='dm_gp')) + if model_kwargs['inc_chrom']: + if model_kwargs['rn_psd'] == 'powerlaw': + model_components.append(ds.makegp_fourier(psr, ds.powerlaw, model_kwargs['chromgp_nfreqs'], T=time_span, name='dm_gp')) + elif model_kwargs['rn_psd'] == 'free_spectral': + model_components.append(ds.makegp_fourier(psr, ds.free_spectral, model_kwargs['chromgp_nfreqs'], T=time_span, name='dm_gp')) + psl = ds.PulsarLikelihood(model_components) prior = prior.makelogprior_uniform(psl.logL.params, {}) log_x = makelogtransform_uniform(psl.logL) # x0 = sample_uniform(psl.logL.params) - def numpyro_model(): - params = jnp.array(numpyro.sample("par", dist.Normal(0,10).expand([len(log_x.params)]))) - numpyro.factor("ll", log_x(params)) - - sampler = infer.MCMC( - infer.NUTS(numpyro_model), - num_warmup=250, - num_samples=4096, - num_chains=4, - progress_bar=True, - chain_method='vectorized' - ) + if sampler_kwargs['numpyro_sampler'] == 'HMC_Gibbs': + def numpyro_model(): + return None + gibbs_hmc_kernel = setup_single_psr_hmc_gibbs( + psrl=psl, psrs=psr, + priordict=ds.priordict_standard, + invhdorf=None, nuts_kwargs={}) + sampler = infer.MCMC(gibbs_hmc_kernel, + num_warmup=sampler_kwargs['num_warmup'], + num_samples=sampler_kwargs['num_warmup'], + num_chains=sampler_kwargs['num_chains'], + chain_method=sampler_kwargs['chain_method'], + progress_bar=True, + ) + elif sampler_kwargs['numpyro_sampler'] == 'NUTS': + def numpyro_model(): + params = jnp.array(numpyro.sample("par", dist.Normal(0,10).expand([len(log_x.params)]))) + numpyro.factor("ll", log_x(params)) + nuts_kernel = infer.NUTS(numpyro_model, num_steps=sampler_kwargs['num_steps']) + sampler = infer.MCMC(nuts_kernel, + num_warmup=sampler_kwargs['num_warmup'], + num_samples=sampler_kwargs['num_warmup'], + num_chains=sampler_kwargs['num_chains'], + chain_method=sampler_kwargs['chain_method'], + progress_bar=True, + ) + elif sampler_kwargs['numpyro_sampler'] == 'HMC': + def numpyro_model(): + params = jnp.array(numpyro.sample("par", dist.Normal(0,10).expand([len(log_x.params)]))) + numpyro.factor("ll", log_x(params)) + hmc_kernel = infer.HMC(numpyro_model, num_steps=sampler_kwargs['num_steps']) + sampler = infer.MCMC(hmc_kernel, + num_warmup=sampler_kwargs['num_warmup'], + num_samples=sampler_kwargs['num_warmup'], + num_chains=sampler_kwargs['num_chains'], + chain_method=sampler_kwargs['chain_method'], + progress_bar=True, + ) - return sampler, log_x - + return sampler, log_x, numpyro_model def test_equad_convention(pars_list): @@ -775,3 +976,29 @@ def test_equad_convention(pars_list): "EQUADs not present in parameter list (or something strange is going on)." ) return None + +def get_model_and_sampler_default_settings(): + model_defaults = { + 'inc_rn': True, + 'rn_psd': 'powerlaw', + 'rn_nfreqs': 30, + 'inc_dmgp': False, + 'dmgp_psd': 'powerlaw', + 'dmgp_nfreqs': 100, + 'inc_chromgp': False, + 'chromgp_psd': 'powerlaw', + 'chromgp_nfreqs': 100, + 'vary_chrom_idx': False, + 'inc_swgp': False, + 'swgp_psd': 'powerlaw', + 'swgp_nfreqs': 100, + } + sampler_defaults = { + 'numpyro_sampler': 'HMC', + 'num_steps': 5, + 'num_warmup': 500, + 'num_samples': 2500, + 'num_chains': 4, + 'chain_method': 'vectorized', + } + return model_defaults, sampler_defaults \ No newline at end of file