diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index e5ca814a7..d2a93fa95 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -16,6 +16,7 @@ the released changes. - Added an option `linearize_model` to speed up the photon phases calculation within `event_optimize` through the designmatrix. - Added AIC and BIC calculation to be written in the post fit parfile from `event_optimize` - When TCB->TDB conversion info is missing, will print parameter name +- Piecewise-constant model for chromatic variations (CMX) - `add_param` returns the name of the parameter (useful for numbered parameters) ### Fixed - Changed WAVE_OM units from 1/d to rad/d. diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index 67201a259..27c78a4f4 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -25,7 +25,7 @@ from pint.models.binary_dd import BinaryDD, BinaryDDS, BinaryDDGR, BinaryDDH from pint.models.binary_ddk import BinaryDDK from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k -from pint.models.chromatic_model import ChromaticCM +from pint.models.chromatic_model import ChromaticCM, ChromaticCMX from pint.models.cmwavex import CMWaveX from pint.models.dispersion_model import ( DispersionDM, diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index e64b8f421..8f428fea2 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -2,18 +2,20 @@ from warnings import warn import numpy as np import astropy.units as u -from pint.models.timing_model import DelayComponent +from pint.models.timing_model import DelayComponent, MissingParameter, MissingTOAs from pint.models.parameter import floatParameter, prefixParameter, MJDParameter +from pint.toa_select import TOASelect from pint.utils import split_prefixed_name, taylor_horner, taylor_horner_deriv from pint import DMconst from pint.exceptions import MissingParameter from astropy.time import Time +from loguru import logger as log cmu = u.pc / u.cm**3 / u.MHz**2 class Chromatic(DelayComponent): - """A base chromatic timing model.""" + """A base chromatic timing model with a constant chromatic index.""" def __init__(self): super().__init__() @@ -112,11 +114,14 @@ def register_cm_deriv_funcs(self, func, param): class ChromaticCM(Chromatic): - """Simple chromatic delay model. + """Simple chromatic delay model with a constant chromatic index. This model uses Taylor expansion to model CM variation over time. It can also be used for a constant CM. + Fitting for the chromatic index is not supported because the fit is too + unstable when fit simultaneously with the DM. + Parameters supported: .. paramtable:: @@ -301,3 +306,401 @@ def change_cmepoch(self, new_epoch): dt.to(u.yr), cmterms, deriv_order=n + 1 ) self.CMEPOCH.value = new_epoch + + +class ChromaticCMX(Chromatic): + """This class provides a CMX model - piecewise-constant chromatic variations with constant + chromatic index. + + This model lets the user specify time ranges and fit for a different CMX value in each time range. + + It should be used in combination with the `ChromaticCM` model. Specifically, TNCHROMIDX must be + set. + + Parameters supported: + + .. paramtable:: + :class: pint.models.chromatic_model.ChromaticCMX + """ + + register = True + category = "chromatic_cmx" + + def __init__(self): + super().__init__() + + self.add_CMX_range(None, None, cmx=0, frozen=False, index=1) + + self.cm_value_funcs += [self.cmx_cm] + self.set_special_params(["CMX_0001", "CMXR1_0001", "CMXR2_0001"]) + self.delay_funcs_component += [self.CMX_chromatic_delay] + + def add_CMX_range(self, mjd_start, mjd_end, index=None, cmx=0, frozen=True): + """Add CMX range to a chromatic model with specified start/end MJDs and CMX value. + + Parameters + ---------- + + mjd_start : float or astropy.quantity.Quantity or astropy.time.Time + MJD for beginning of CMX event. + mjd_end : float or astropy.quantity.Quantity or astropy.time.Time + MJD for end of CMX event. + index : int, None + Integer label for CMX event. If None, will increment largest used index by 1. + cmx : float or astropy.quantity.Quantity + Change in CM during CMX event. + frozen : bool + Indicates whether CMX will be fit. + + Returns + ------- + + index : int + Index that has been assigned to new CMX event. + """ + + #### Setting up the CMX title convention. If index is None, want to increment the current max CMX index by 1. + if index is None: + dct = self.get_prefix_mapping_component("CMX_") + index = np.max(list(dct.keys())) + 1 + i = f"{int(index):04d}" + + if mjd_end is not None and mjd_start is not None: + if mjd_end < mjd_start: + raise ValueError("Starting MJD is greater than ending MJD.") + elif mjd_start != mjd_end: + raise ValueError("Only one MJD bound is set.") + + if int(index) in self.get_prefix_mapping_component("CMX_"): + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another." + ) + + if isinstance(cmx, u.quantity.Quantity): + cmx = cmx.to_value(cmu) + + if isinstance(mjd_start, Time): + mjd_start = mjd_start.mjd + elif isinstance(mjd_start, u.quantity.Quantity): + mjd_start = mjd_start.value + if isinstance(mjd_end, Time): + mjd_end = mjd_end.mjd + elif isinstance(mjd_end, u.quantity.Quantity): + mjd_end = mjd_end.value + self.add_param( + prefixParameter( + name=f"CMX_{i}", + units=cmu, + value=cmx, + description="Dispersion measure variation", + parameter_type="float", + frozen=frozen, + convert_tcb2tdb=False, + ) + ) + self.add_param( + prefixParameter( + name=f"CMXR1_{i}", + units="MJD", + description="Beginning of CMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_start, + convert_tcb2tdb=False, + ) + ) + self.add_param( + prefixParameter( + name=f"CMXR2_{i}", + units="MJD", + description="End of CMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_end, + convert_tcb2tdb=False, + ) + ) + self.setup() + self.validate() + return index + + def add_CMX_ranges(self, mjd_starts, mjd_ends, indices=None, cmxs=0, frozens=True): + """Add CMX ranges to a dispersion model with specified start/end MJDs and CMXs. + + Parameters + ---------- + + mjd_starts : iterable of float or astropy.quantity.Quantity or astropy.time.Time + MJD for beginning of CMX event. + mjd_end : iterable of float or astropy.quantity.Quantity or astropy.time.Time + MJD for end of CMX event. + indices : iterable of int, None + Integer label for CMX event. If None, will increment largest used index by 1. + cmxs : iterable of float or astropy.quantity.Quantity, or float or astropy.quantity.Quantity + Change in CM during CMX event. + frozens : iterable of bool or bool + Indicates whether CMX will be fit. + + Returns + ------- + + indices : list + Indices that has been assigned to new CMX events + """ + if len(mjd_starts) != len(mjd_ends): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of mjd_end values {len(mjd_ends)}" + ) + if indices is None: + indices = [None] * len(mjd_starts) + cmxs = np.atleast_1d(cmxs) + if len(cmxs) == 1: + cmxs = np.repeat(cmxs, len(mjd_starts)) + if len(cmxs) != len(mjd_starts): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of cmx values {len(cmxs)}" + ) + frozens = np.atleast_1d(frozens) + if len(frozens) == 1: + frozens = np.repeat(frozens, len(mjd_starts)) + if len(frozens) != len(mjd_starts): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of frozen values {len(frozens)}" + ) + + #### Setting up the CMX title convention. If index is None, want to increment the current max CMX index by 1. + dct = self.get_prefix_mapping_component("CMX_") + last_index = np.max(list(dct.keys())) + added_indices = [] + for mjd_start, mjd_end, index, cmx, frozen in zip( + mjd_starts, mjd_ends, indices, cmxs, frozens + ): + if index is None: + index = last_index + 1 + last_index += 1 + elif index in list(dct.keys()): + raise ValueError( + f"Attempting to insert CMX_{index:04d} but it already exists" + ) + added_indices.append(index) + i = f"{int(index):04d}" + + if mjd_end is not None and mjd_start is not None: + if mjd_end < mjd_start: + raise ValueError("Starting MJD is greater than ending MJD.") + elif mjd_start != mjd_end: + raise ValueError("Only one MJD bound is set.") + if int(index) in dct: + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another." + ) + if isinstance(cmx, u.quantity.Quantity): + cmx = cmx.to_value(u.pc / u.cm**3) + if isinstance(mjd_start, Time): + mjd_start = mjd_start.mjd + elif isinstance(mjd_start, u.quantity.Quantity): + mjd_start = mjd_start.value + if isinstance(mjd_end, Time): + mjd_end = mjd_end.mjd + elif isinstance(mjd_end, u.quantity.Quantity): + mjd_end = mjd_end.value + log.trace(f"Adding CMX_{i} from MJD {mjd_start} to MJD {mjd_end}") + self.add_param( + prefixParameter( + name=f"CMX_{i}", + units=cmu, + value=cmx, + description="Dispersion measure variation", + parameter_type="float", + frozen=frozen, + convert_tcb2tdb=False, + ) + ) + self.add_param( + prefixParameter( + name=f"CMXR1_{i}", + units="MJD", + description="Beginning of CMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_start, + convert_tcb2tdb=False, + ) + ) + self.add_param( + prefixParameter( + name=f"CMXR2_{i}", + units="MJD", + description="End of CMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_end, + convert_tcb2tdb=False, + ) + ) + self.setup() + self.validate() + return added_indices + + def remove_CMX_range(self, index): + """Removes all CMX parameters associated with a given index/list of indices. + + Parameters + ---------- + + index : float, int, list, np.ndarray + Number or list/array of numbers corresponding to CMX indices to be removed from model. + """ + + if isinstance(index, (int, float, np.int64)): + indices = [index] + elif isinstance(index, (list, set, np.ndarray)): + indices = index + else: + raise TypeError( + f"index must be a float, int, set, list, or array - not {type(index)}" + ) + for index in indices: + index_rf = f"{int(index):04d}" + for prefix in ["CMX_", "CMXR1_", "CMXR2_"]: + self.remove_param(prefix + index_rf) + self.validate() + + def get_indices(self): + """Returns an array of integers corresponding to CMX parameters. + + Returns + ------- + inds : np.ndarray + Array of CMX indices in model. + """ + inds = [int(p.split("_")[-1]) for p in self.params if "CMX_" in p] + return np.array(inds) + + def setup(self): + super().setup() + # Get CMX mapping. + # Register the CMX derivatives + for prefix_par in self.get_params_of_type("prefixParameter"): + if prefix_par.startswith("CMX_"): + self.register_deriv_funcs(self.d_delay_d_cmparam, prefix_par) + self.register_cm_deriv_funcs(self.d_cm_d_CMX, prefix_par) + + def validate(self): + """Validate the CMX parameters.""" + super().validate() + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + if CMX_mapping.keys() != CMXR1_mapping.keys(): + # FIXME: report mismatch + raise ValueError( + "CMX_ parameters do not " + "match CMXR1_ parameters. " + "Please check your prefixed parameters." + ) + if CMX_mapping.keys() != CMXR2_mapping.keys(): + raise ValueError( + "CMX_ parameters do not " + "match CMXR2_ parameters. " + "Please check your prefixed parameters." + ) + r1 = np.zeros(len(CMX_mapping)) + r2 = np.zeros(len(CMX_mapping)) + indices = np.zeros(len(CMX_mapping), dtype=np.int32) + for j, index in enumerate(CMX_mapping): + if ( + getattr(self, f"CMXR1_{index:04d}").quantity is not None + and getattr(self, f"CMXR2_{index:04d}").quantity is not None + ): + r1[j] = getattr(self, f"CMXR1_{index:04d}").quantity.mjd + r2[j] = getattr(self, f"CMXR2_{index:04d}").quantity.mjd + indices[j] = index + for j, index in enumerate(CMXR1_mapping): + if np.any((r1[j] > r1) & (r1[j] < r2)): + k = np.where((r1[j] > r1) & (r1[j] < r2))[0] + for kk in k.flatten(): + log.warning( + f"Start of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + ) + if np.any((r2[j] > r1) & (r2[j] < r2)): + k = np.where((r2[j] > r1) & (r2[j] < r2))[0] + for kk in k.flatten(): + log.warning( + f"End of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + ) + + def validate_toas(self, toas): + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + bad_parameters = [] + for k in CMXR1_mapping.keys(): + if self._parent[CMX_mapping[k]].frozen: + continue + b = self._parent[CMXR1_mapping[k]].quantity.mjd * u.d + e = self._parent[CMXR2_mapping[k]].quantity.mjd * u.d + mjds = toas.get_mjds() + n = np.sum((b <= mjds) & (mjds < e)) + if n == 0: + bad_parameters.append(CMX_mapping[k]) + if bad_parameters: + raise MissingTOAs(bad_parameters) + + def cmx_cm(self, toas): + condition = {} + tbl = toas.table + if not hasattr(self, "cmx_toas_selector"): + self.cmx_toas_selector = TOASelect(is_range=True) + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + for epoch_ind in CMX_mapping.keys(): + r1 = getattr(self, CMXR1_mapping[epoch_ind]).quantity + r2 = getattr(self, CMXR2_mapping[epoch_ind]).quantity + condition[CMX_mapping[epoch_ind]] = (r1.mjd, r2.mjd) + select_idx = self.cmx_toas_selector.get_select_index( + condition, tbl["mjd_float"] + ) + # Get CMX delays + cm = np.zeros(len(tbl)) * self._parent.CM.units + for k, v in select_idx.items(): + cm[v] += getattr(self, k).quantity + return cm + + def CMX_chromatic_delay(self, toas, acc_delay=None): + """This is a wrapper function for interacting with the TimingModel class""" + return self.chromatic_type_delay(toas) + + def d_cm_d_CMX(self, toas, param_name, acc_delay=None): + condition = {} + tbl = toas.table + if not hasattr(self, "cmx_toas_selector"): + self.cmx_toas_selector = TOASelect(is_range=True) + param = getattr(self, param_name) + cmx_index = param.index + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + r1 = getattr(self, CMXR1_mapping[cmx_index]).quantity + r2 = getattr(self, CMXR2_mapping[cmx_index]).quantity + condition = {param_name: (r1.mjd, r2.mjd)} + select_idx = self.cmx_toas_selector.get_select_index( + condition, tbl["mjd_float"] + ) + + cmx = np.zeros(len(tbl)) + for k, v in select_idx.items(): + cmx[v] = 1.0 + return cmx * (u.pc / u.cm**3) / (u.pc / u.cm**3) + + def print_par(self, format="pint"): + result = "" + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + sorted_list = sorted(CMX_mapping.keys()) + for ii in sorted_list: + result += getattr(self, CMX_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, CMXR1_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, CMXR2_mapping[ii]).as_parfile_line(format=format) + return result diff --git a/src/pint/models/dispersion_model.py b/src/pint/models/dispersion_model.py index 368392dd9..ef6725548 100644 --- a/src/pint/models/dispersion_model.py +++ b/src/pint/models/dispersion_model.py @@ -309,7 +309,7 @@ def change_dmepoch(self, new_epoch): class DispersionDMX(Dispersion): - """This class provides a DMX model - multiple DM values. + """This class provides a DMX model - piecewise-constant DM variations. This model lets the user specify time ranges and fit for a different DM value in each time range. @@ -345,7 +345,7 @@ def __init__(self): self.delay_funcs_component += [self.DMX_dispersion_delay] def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): - """Add DMX range to a dispersion model with specified start/end MJDs and DMX. + """Add DMX range to a dispersion model with specified start/end MJDs and DMX value. Parameters ---------- diff --git a/tests/test_cmx.py b/tests/test_cmx.py new file mode 100644 index 000000000..976581880 --- /dev/null +++ b/tests/test_cmx.py @@ -0,0 +1,132 @@ +from io import StringIO +from pint.models import get_model +from pint.simulation import make_fake_toas_uniform +from pint.fitter import Fitter + +import pytest +import numpy as np +import astropy.units as u + + +@pytest.fixture +def model_and_toas(): + par = """ + RAJ 05:00:00 + DECJ 12:00:00 + F0 101.0 + F1 -1.1e-14 + PEPOCH 55000 + DM 12 + DMEPOCH 55000 + CMEPOCH 55000 + CM 3.5 + TNCHROMIDX 4.0 + EPHEM DE440 + CLOCK TT(BIPM2021) + UNITS TDB + TZRMJD 55000 + TZRFRQ 1400 + TZRSITE gmrt + CMXR1_0001 53999.9 + CMXR2_0001 54500 + CMX_0001 0.5 1 + CMXR1_0002 54500.1 + CMXR2_0002 55000 + CMX_0002 -0.5 1 + CMXR1_0003 55000.1 + CMXR2_0003 55500 + CMX_0003 -0.1 1 + """ + + model = get_model(StringIO(par)) + + freqs = np.linspace(300, 1600, 16) * u.MHz + toas = make_fake_toas_uniform( + startMJD=54000, + endMJD=56000, + ntoas=2000, + model=model, + freq=freqs, + add_noise=True, + obs="gmrt", + include_bipm=True, + multi_freqs_in_epoch=True, + ) + + return model, toas + + +def test_cmx(model_and_toas): + model, toas = model_and_toas + + assert "ChromaticCMX" in model.components + + ftr = Fitter.auto(toas, model) + ftr.fit_toas() + + assert np.abs(model.CMX_0001.value - ftr.model.CMX_0001.value) / ( + 3 * ftr.model.CMX_0001.uncertainty_value + ) + assert np.abs(model.CMX_0002.value - ftr.model.CMX_0002.value) / ( + 3 * ftr.model.CMX_0002.uncertainty_value + ) + assert np.abs(model.CMX_0003.value - ftr.model.CMX_0003.value) / ( + 3 * ftr.model.CMX_0003.uncertainty_value + ) + + assert ftr.resids.chi2_reduced < 1.6 + + assert "CMX_0001" in str(ftr.model) + + +def test_cmx_delay(model_and_toas): + model, toas = model_and_toas + + # Zero delay outside CMX ranges + nocmx_mask = toas.get_mjds().value > 55500 + assert all( + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[nocmx_mask] == 0 + ) + + # The delay is consistent + cmx1_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0001.value, + toas.get_mjds().value <= model.CMXR2_0001.value, + ) + cmx1_freqs = toas.get_freqs()[cmx1_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0001.quantity, model.TNCHROMIDX.quantity, cmx1_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx1_mask], + ) + ) + + cmx2_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0002.value, + toas.get_mjds().value <= model.CMXR2_0002.value, + ) + cmx2_freqs = toas.get_freqs()[cmx2_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0002.quantity, model.TNCHROMIDX.quantity, cmx2_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx2_mask], + ) + ) + + cmx3_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0003.value, + toas.get_mjds().value <= model.CMXR2_0003.value, + ) + cmx3_freqs = toas.get_freqs()[cmx3_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0003.quantity, model.TNCHROMIDX.quantity, cmx3_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx3_mask], + ) + )