From a305ae0760bda5940bbad2d6f22dcfdd3861f23e Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Mon, 3 Jun 2024 12:02:12 +0200 Subject: [PATCH 1/7] ChromaticCMX stub --- src/pint/models/chromatic_model.py | 414 ++++++++++++++++++++++++++++ src/pint/models/dispersion_model.py | 2 +- 2 files changed, 415 insertions(+), 1 deletion(-) diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 83583a5e9..5fab08926 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -300,3 +300,417 @@ def change_cmepoch(self, new_epoch): dt.to(u.yr), cmterms, deriv_order=n + 1 ) self.CMEPOCH.value = new_epoch + + +class ChromaticCMX(Chromatic): + """This class provides a CMX model - piecewise-constant chromatic variations. + + This model lets the user specify time ranges and fit for a different + CMX value in each time range. + + Parameters supported: + + .. paramtable:: + :class: pint.models.dispersion_model.DispersionDMX + """ + + register = True + category = "dispersion_dmx" + + def __init__(self): + super().__init__() + + # DMX is for info output right now + # @abhisrkckl: What exactly is the use of this parameter? + self.add_param( + floatParameter( + name="DMX", + units="pc cm^-3", + value=0.0, + description="Dispersion measure", + convert_tcb2tdb=False, + ) + ) + + self.add_DMX_range(None, None, dmx=0, frozen=False, index=1) + + self.dm_value_funcs += [self.dmx_dm] + self.set_special_params(["DMX_0001", "DMXR1_0001", "DMXR2_0001"]) + self.delay_funcs_component += [self.DMX_dispersion_delay] + + def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): + """Add DMX range to a dispersion model with specified start/end MJDs and DMX. + + Parameters + ---------- + + mjd_start : float or astropy.quantity.Quantity or astropy.time.Time + MJD for beginning of DMX event. + mjd_end : float or astropy.quantity.Quantity or astropy.time.Time + MJD for end of DMX event. + index : int, None + Integer label for DMX event. If None, will increment largest used index by 1. + dmx : float or astropy.quantity.Quantity + Change in DM during DMX event. + frozen : bool + Indicates whether DMX will be fit. + + Returns + ------- + + index : int + Index that has been assigned to new DMX event. + + """ + + #### Setting up the DMX title convention. If index is None, want to increment the current max DMX index by 1. + if index is None: + dct = self.get_prefix_mapping_component("DMX_") + index = np.max(list(dct.keys())) + 1 + i = f"{int(index):04d}" + + if mjd_end is not None and mjd_start is not None: + if mjd_end < mjd_start: + raise ValueError("Starting MJD is greater than ending MJD.") + elif mjd_start != mjd_end: + raise ValueError("Only one MJD bound is set.") + + if int(index) in self.get_prefix_mapping_component("DMX_"): + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another." + ) + + if isinstance(dmx, u.quantity.Quantity): + dmx = dmx.to_value(u.pc / u.cm**3) + if isinstance(mjd_start, Time): + mjd_start = mjd_start.mjd + elif isinstance(mjd_start, u.quantity.Quantity): + mjd_start = mjd_start.value + if isinstance(mjd_end, Time): + mjd_end = mjd_end.mjd + elif isinstance(mjd_end, u.quantity.Quantity): + mjd_end = mjd_end.value + self.add_param( + prefixParameter( + name=f"DMX_{i}", + units="pc cm^-3", + value=dmx, + description="Dispersion measure variation", + parameter_type="float", + frozen=frozen, + tcb2tdb_scale_factor=DMconst, + ) + ) + self.add_param( + prefixParameter( + name=f"DMXR1_{i}", + units="MJD", + description="Beginning of DMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_start, + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.add_param( + prefixParameter( + name=f"DMXR2_{i}", + units="MJD", + description="End of DMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_end, + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.setup() + self.validate() + return index + + def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=True): + """Add DMX ranges to a dispersion model with specified start/end MJDs and DMXs. + + Parameters + ---------- + + mjd_starts : iterable of float or astropy.quantity.Quantity or astropy.time.Time + MJD for beginning of DMX event. + mjd_end : iterable of float or astropy.quantity.Quantity or astropy.time.Time + MJD for end of DMX event. + indices : iterable of int, None + Integer label for DMX event. If None, will increment largest used index by 1. + dmxs : iterable of float or astropy.quantity.Quantity, or float or astropy.quantity.Quantity + Change in DM during DMX event. + frozens : iterable of bool or bool + Indicates whether DMX will be fit. + + Returns + ------- + + indices : list + Indices that has been assigned to new DMX events + + """ + if len(mjd_starts) != len(mjd_ends): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of mjd_end values {len(mjd_ends)}" + ) + if indices is None: + indices = [None] * len(mjd_starts) + dmxs = np.atleast_1d(dmxs) + if len(dmxs) == 1: + dmxs = np.repeat(dmxs, len(mjd_starts)) + if len(dmxs) != len(mjd_starts): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of dmx values {len(dmxs)}" + ) + frozens = np.atleast_1d(frozens) + if len(frozens) == 1: + frozens = np.repeat(frozens, len(mjd_starts)) + if len(frozens) != len(mjd_starts): + raise ValueError( + f"Number of mjd_start values {len(mjd_starts)} must match number of frozen values {len(frozens)}" + ) + + #### Setting up the DMX title convention. If index is None, want to increment the current max DMX index by 1. + dct = self.get_prefix_mapping_component("DMX_") + last_index = np.max(list(dct.keys())) + added_indices = [] + for mjd_start, mjd_end, index, dmx, frozen in zip( + mjd_starts, mjd_ends, indices, dmxs, frozens + ): + if index is None: + index = last_index + 1 + last_index += 1 + elif index in list(dct.keys()): + raise ValueError( + f"Attempting to insert DMX_{index:04d} but it already exists" + ) + added_indices.append(index) + i = f"{int(index):04d}" + + if mjd_end is not None and mjd_start is not None: + if mjd_end < mjd_start: + raise ValueError("Starting MJD is greater than ending MJD.") + elif mjd_start != mjd_end: + raise ValueError("Only one MJD bound is set.") + if int(index) in dct: + raise ValueError( + f"Index '{index}' is already in use in this model. Please choose another." + ) + if isinstance(dmx, u.quantity.Quantity): + dmx = dmx.to_value(u.pc / u.cm**3) + if isinstance(mjd_start, Time): + mjd_start = mjd_start.mjd + elif isinstance(mjd_start, u.quantity.Quantity): + mjd_start = mjd_start.value + if isinstance(mjd_end, Time): + mjd_end = mjd_end.mjd + elif isinstance(mjd_end, u.quantity.Quantity): + mjd_end = mjd_end.value + log.trace(f"Adding DMX_{i} from MJD {mjd_start} to MJD {mjd_end}") + self.add_param( + prefixParameter( + name=f"DMX_{i}", + units="pc cm^-3", + value=dmx, + description="Dispersion measure variation", + parameter_type="float", + frozen=frozen, + tcb2tdb_scale_factor=DMconst, + ) + ) + self.add_param( + prefixParameter( + name=f"DMXR1_{i}", + units="MJD", + description="Beginning of DMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_start, + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.add_param( + prefixParameter( + name=f"DMXR2_{i}", + units="MJD", + description="End of DMX interval", + parameter_type="MJD", + time_scale="utc", + value=mjd_end, + tcb2tdb_scale_factor=u.Quantity(1), + ) + ) + self.setup() + self.validate() + return added_indices + + def remove_DMX_range(self, index): + """Removes all DMX parameters associated with a given index/list of indices. + + Parameters + ---------- + + index : float, int, list, np.ndarray + Number or list/array of numbers corresponding to DMX indices to be removed from model. + """ + + if isinstance(index, (int, float, np.int64)): + indices = [index] + elif isinstance(index, (list, set, np.ndarray)): + indices = index + else: + raise TypeError( + f"index must be a float, int, set, list, or array - not {type(index)}" + ) + for index in indices: + index_rf = f"{int(index):04d}" + for prefix in ["DMX_", "DMXR1_", "DMXR2_"]: + self.remove_param(prefix + index_rf) + self.validate() + + def get_indices(self): + """Returns an array of integers corresponding to DMX parameters. + + Returns + ------- + inds : np.ndarray + Array of DMX indices in model. + """ + inds = [int(p.split("_")[-1]) for p in self.params if "DMX_" in p] + return np.array(inds) + + def setup(self): + super().setup() + # Get DMX mapping. + # Register the DMX derivatives + for prefix_par in self.get_params_of_type("prefixParameter"): + if prefix_par.startswith("DMX_"): + self.register_deriv_funcs(self.d_delay_d_dmparam, prefix_par) + self.register_dm_deriv_funcs(self.d_dm_d_DMX, prefix_par) + + def validate(self): + """Validate the DMX parameters.""" + super().validate() + DMX_mapping = self.get_prefix_mapping_component("DMX_") + DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") + DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + if DMX_mapping.keys() != DMXR1_mapping.keys(): + # FIXME: report mismatch + raise ValueError( + "DMX_ parameters do not " + "match DMXR1_ parameters. " + "Please check your prefixed parameters." + ) + if DMX_mapping.keys() != DMXR2_mapping.keys(): + raise ValueError( + "DMX_ parameters do not " + "match DMXR2_ parameters. " + "Please check your prefixed parameters." + ) + r1 = np.zeros(len(DMX_mapping)) + r2 = np.zeros(len(DMX_mapping)) + indices = np.zeros(len(DMX_mapping), dtype=np.int32) + for j, index in enumerate(DMX_mapping): + if ( + getattr(self, f"DMXR1_{index:04d}").quantity is not None + and getattr(self, f"DMXR2_{index:04d}").quantity is not None + ): + r1[j] = getattr(self, f"DMXR1_{index:04d}").quantity.mjd + r2[j] = getattr(self, f"DMXR2_{index:04d}").quantity.mjd + indices[j] = index + for j, index in enumerate(DMXR1_mapping): + if np.any((r1[j] > r1) & (r1[j] < r2)): + k = np.where((r1[j] > r1) & (r1[j] < r2))[0] + for kk in k.flatten(): + log.warning( + f"Start of DMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with DMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + ) + if np.any((r2[j] > r1) & (r2[j] < r2)): + k = np.where((r2[j] > r1) & (r2[j] < r2))[0] + for kk in k.flatten(): + log.warning( + f"End of DMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with DMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + ) + + def validate_toas(self, toas): + DMX_mapping = self.get_prefix_mapping_component("DMX_") + DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") + DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + bad_parameters = [] + for k in DMXR1_mapping.keys(): + if self._parent[DMX_mapping[k]].frozen: + continue + b = self._parent[DMXR1_mapping[k]].quantity.mjd * u.d + e = self._parent[DMXR2_mapping[k]].quantity.mjd * u.d + mjds = toas.get_mjds() + n = np.sum((b <= mjds) & (mjds < e)) + if n == 0: + bad_parameters.append(DMX_mapping[k]) + if bad_parameters: + raise MissingTOAs(bad_parameters) + + def dmx_dm(self, toas): + condition = {} + tbl = toas.table + if not hasattr(self, "dmx_toas_selector"): + self.dmx_toas_selector = TOASelect(is_range=True) + DMX_mapping = self.get_prefix_mapping_component("DMX_") + DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") + DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + for epoch_ind in DMX_mapping.keys(): + r1 = getattr(self, DMXR1_mapping[epoch_ind]).quantity + r2 = getattr(self, DMXR2_mapping[epoch_ind]).quantity + condition[DMX_mapping[epoch_ind]] = (r1.mjd, r2.mjd) + select_idx = self.dmx_toas_selector.get_select_index( + condition, tbl["mjd_float"] + ) + # Get DMX delays + dm = np.zeros(len(tbl)) * self._parent.DM.units + for k, v in select_idx.items(): + dm[v] += getattr(self, k).quantity + return dm + + def DMX_dispersion_delay(self, toas, acc_delay=None): + """This is a wrapper function for interacting with the TimingModel class""" + return self.dispersion_type_delay(toas) + + def d_dm_d_DMX(self, toas, param_name, acc_delay=None): + condition = {} + tbl = toas.table + if not hasattr(self, "dmx_toas_selector"): + self.dmx_toas_selector = TOASelect(is_range=True) + param = getattr(self, param_name) + dmx_index = param.index + DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") + DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + r1 = getattr(self, DMXR1_mapping[dmx_index]).quantity + r2 = getattr(self, DMXR2_mapping[dmx_index]).quantity + condition = {param_name: (r1.mjd, r2.mjd)} + select_idx = self.dmx_toas_selector.get_select_index( + condition, tbl["mjd_float"] + ) + + try: + bfreq = self._parent.barycentric_radio_freq(toas) + except AttributeError: + warn("Using topocentric frequency for dedispersion!") + bfreq = tbl["freq"] + dmx = np.zeros(len(tbl)) + for k, v in select_idx.items(): + dmx[v] = 1.0 + return dmx * (u.pc / u.cm**3) / (u.pc / u.cm**3) + + def print_par(self, format="pint"): + result = "" + DMX_mapping = self.get_prefix_mapping_component("DMX_") + DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") + DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + result += getattr(self, "DMX").as_parfile_line(format=format) + sorted_list = sorted(DMX_mapping.keys()) + for ii in sorted_list: + result += getattr(self, DMX_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, DMXR1_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, DMXR2_mapping[ii]).as_parfile_line(format=format) + return result diff --git a/src/pint/models/dispersion_model.py b/src/pint/models/dispersion_model.py index e5317b06c..9ddd06ed7 100644 --- a/src/pint/models/dispersion_model.py +++ b/src/pint/models/dispersion_model.py @@ -308,7 +308,7 @@ def change_dmepoch(self, new_epoch): class DispersionDMX(Dispersion): - """This class provides a DMX model - multiple DM values. + """This class provides a DMX model - piecewise-constant DM variations. This model lets the user specify time ranges and fit for a different DM value in each time range. From 2142369383ad760e3cc183bb8d3e111f21c4a9d2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Mon, 3 Jun 2024 13:04:26 +0200 Subject: [PATCH 2/7] cmx --- src/pint/models/chromatic_model.py | 287 ++++++++++++++-------------- src/pint/models/dispersion_model.py | 2 +- 2 files changed, 140 insertions(+), 149 deletions(-) diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 5fab08926..3a0d4f40c 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -2,11 +2,13 @@ from warnings import warn import numpy as np import astropy.units as u -from pint.models.timing_model import DelayComponent, MissingParameter +from pint.models.timing_model import DelayComponent, MissingParameter, MissingTOAs from pint.models.parameter import floatParameter, prefixParameter, MJDParameter +from pint.toa_select import TOASelect from pint.utils import split_prefixed_name, taylor_horner, taylor_horner_deriv from pint import DMconst from astropy.time import Time +from loguru import logger as log cmu = u.pc / u.cm**3 / u.MHz**2 @@ -311,61 +313,49 @@ class ChromaticCMX(Chromatic): Parameters supported: .. paramtable:: - :class: pint.models.dispersion_model.DispersionDMX + :class: pint.models.chromatic_model.ChromaticCMX """ register = True - category = "dispersion_dmx" + category = "chromatic_cmx" def __init__(self): super().__init__() - # DMX is for info output right now - # @abhisrkckl: What exactly is the use of this parameter? - self.add_param( - floatParameter( - name="DMX", - units="pc cm^-3", - value=0.0, - description="Dispersion measure", - convert_tcb2tdb=False, - ) - ) + self.add_CMX_range(None, None, cmx=0, frozen=False, index=1) - self.add_DMX_range(None, None, dmx=0, frozen=False, index=1) + self.cm_value_funcs += [self.cmx_cm] + self.set_special_params(["CMX_0001", "CMXR1_0001", "CMXR2_0001"]) + self.delay_funcs_component += [self.CMX_chromatic_delay] - self.dm_value_funcs += [self.dmx_dm] - self.set_special_params(["DMX_0001", "DMXR1_0001", "DMXR2_0001"]) - self.delay_funcs_component += [self.DMX_dispersion_delay] - - def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): - """Add DMX range to a dispersion model with specified start/end MJDs and DMX. + def add_CMX_range(self, mjd_start, mjd_end, index=None, cmx=0, frozen=True): + """Add CMX range to a chromatic model with specified start/end MJDs and CMX value. Parameters ---------- mjd_start : float or astropy.quantity.Quantity or astropy.time.Time - MJD for beginning of DMX event. + MJD for beginning of CMX event. mjd_end : float or astropy.quantity.Quantity or astropy.time.Time - MJD for end of DMX event. + MJD for end of CMX event. index : int, None - Integer label for DMX event. If None, will increment largest used index by 1. - dmx : float or astropy.quantity.Quantity - Change in DM during DMX event. + Integer label for CMX event. If None, will increment largest used index by 1. + cmx : float or astropy.quantity.Quantity + Change in CM during CMX event. frozen : bool - Indicates whether DMX will be fit. + Indicates whether CMX will be fit. Returns ------- index : int - Index that has been assigned to new DMX event. + Index that has been assigned to new CMX event. """ - #### Setting up the DMX title convention. If index is None, want to increment the current max DMX index by 1. + #### Setting up the CMX title convention. If index is None, want to increment the current max CMX index by 1. if index is None: - dct = self.get_prefix_mapping_component("DMX_") + dct = self.get_prefix_mapping_component("CMX_") index = np.max(list(dct.keys())) + 1 i = f"{int(index):04d}" @@ -375,13 +365,14 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): elif mjd_start != mjd_end: raise ValueError("Only one MJD bound is set.") - if int(index) in self.get_prefix_mapping_component("DMX_"): + if int(index) in self.get_prefix_mapping_component("CMX_"): raise ValueError( f"Index '{index}' is already in use in this model. Please choose another." ) - if isinstance(dmx, u.quantity.Quantity): - dmx = dmx.to_value(u.pc / u.cm**3) + if isinstance(cmx, u.quantity.Quantity): + cmx = cmx.to_value(cmu) + if isinstance(mjd_start, Time): mjd_start = mjd_start.mjd elif isinstance(mjd_start, u.quantity.Quantity): @@ -392,63 +383,63 @@ def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): mjd_end = mjd_end.value self.add_param( prefixParameter( - name=f"DMX_{i}", - units="pc cm^-3", - value=dmx, + name=f"CMX_{i}", + units=cmu, + value=cmx, description="Dispersion measure variation", parameter_type="float", frozen=frozen, - tcb2tdb_scale_factor=DMconst, + convert_tcb2tdb=False, ) ) self.add_param( prefixParameter( - name=f"DMXR1_{i}", + name=f"CMXR1_{i}", units="MJD", - description="Beginning of DMX interval", + description="Beginning of CMX interval", parameter_type="MJD", time_scale="utc", value=mjd_start, - tcb2tdb_scale_factor=u.Quantity(1), + convert_tcb2tdb=False, ) ) self.add_param( prefixParameter( - name=f"DMXR2_{i}", + name=f"CMXR2_{i}", units="MJD", - description="End of DMX interval", + description="End of CMX interval", parameter_type="MJD", time_scale="utc", value=mjd_end, - tcb2tdb_scale_factor=u.Quantity(1), + convert_tcb2tdb=False, ) ) self.setup() self.validate() return index - def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=True): - """Add DMX ranges to a dispersion model with specified start/end MJDs and DMXs. + def add_CMX_ranges(self, mjd_starts, mjd_ends, indices=None, cmxs=0, frozens=True): + """Add CMX ranges to a dispersion model with specified start/end MJDs and CMXs. Parameters ---------- mjd_starts : iterable of float or astropy.quantity.Quantity or astropy.time.Time - MJD for beginning of DMX event. + MJD for beginning of CMX event. mjd_end : iterable of float or astropy.quantity.Quantity or astropy.time.Time - MJD for end of DMX event. + MJD for end of CMX event. indices : iterable of int, None - Integer label for DMX event. If None, will increment largest used index by 1. - dmxs : iterable of float or astropy.quantity.Quantity, or float or astropy.quantity.Quantity - Change in DM during DMX event. + Integer label for CMX event. If None, will increment largest used index by 1. + cmxs : iterable of float or astropy.quantity.Quantity, or float or astropy.quantity.Quantity + Change in CM during CMX event. frozens : iterable of bool or bool - Indicates whether DMX will be fit. + Indicates whether CMX will be fit. Returns ------- indices : list - Indices that has been assigned to new DMX events + Indices that has been assigned to new CMX events """ if len(mjd_starts) != len(mjd_ends): @@ -457,12 +448,12 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru ) if indices is None: indices = [None] * len(mjd_starts) - dmxs = np.atleast_1d(dmxs) - if len(dmxs) == 1: - dmxs = np.repeat(dmxs, len(mjd_starts)) - if len(dmxs) != len(mjd_starts): + cmxs = np.atleast_1d(cmxs) + if len(cmxs) == 1: + cmxs = np.repeat(cmxs, len(mjd_starts)) + if len(cmxs) != len(mjd_starts): raise ValueError( - f"Number of mjd_start values {len(mjd_starts)} must match number of dmx values {len(dmxs)}" + f"Number of mjd_start values {len(mjd_starts)} must match number of cmx values {len(cmxs)}" ) frozens = np.atleast_1d(frozens) if len(frozens) == 1: @@ -472,19 +463,19 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru f"Number of mjd_start values {len(mjd_starts)} must match number of frozen values {len(frozens)}" ) - #### Setting up the DMX title convention. If index is None, want to increment the current max DMX index by 1. - dct = self.get_prefix_mapping_component("DMX_") + #### Setting up the CMX title convention. If index is None, want to increment the current max CMX index by 1. + dct = self.get_prefix_mapping_component("CMX_") last_index = np.max(list(dct.keys())) added_indices = [] - for mjd_start, mjd_end, index, dmx, frozen in zip( - mjd_starts, mjd_ends, indices, dmxs, frozens + for mjd_start, mjd_end, index, cmx, frozen in zip( + mjd_starts, mjd_ends, indices, cmxs, frozens ): if index is None: index = last_index + 1 last_index += 1 elif index in list(dct.keys()): raise ValueError( - f"Attempting to insert DMX_{index:04d} but it already exists" + f"Attempting to insert CMX_{index:04d} but it already exists" ) added_indices.append(index) i = f"{int(index):04d}" @@ -498,8 +489,8 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru raise ValueError( f"Index '{index}' is already in use in this model. Please choose another." ) - if isinstance(dmx, u.quantity.Quantity): - dmx = dmx.to_value(u.pc / u.cm**3) + if isinstance(cmx, u.quantity.Quantity): + cmx = cmx.to_value(u.pc / u.cm**3) if isinstance(mjd_start, Time): mjd_start = mjd_start.mjd elif isinstance(mjd_start, u.quantity.Quantity): @@ -508,52 +499,52 @@ def add_DMX_ranges(self, mjd_starts, mjd_ends, indices=None, dmxs=0, frozens=Tru mjd_end = mjd_end.mjd elif isinstance(mjd_end, u.quantity.Quantity): mjd_end = mjd_end.value - log.trace(f"Adding DMX_{i} from MJD {mjd_start} to MJD {mjd_end}") + log.trace(f"Adding CMX_{i} from MJD {mjd_start} to MJD {mjd_end}") self.add_param( prefixParameter( - name=f"DMX_{i}", - units="pc cm^-3", - value=dmx, + name=f"CMX_{i}", + units=cmu, + value=cmx, description="Dispersion measure variation", parameter_type="float", frozen=frozen, - tcb2tdb_scale_factor=DMconst, + convert_tcb2tdb=False, ) ) self.add_param( prefixParameter( - name=f"DMXR1_{i}", + name=f"CMXR1_{i}", units="MJD", - description="Beginning of DMX interval", + description="Beginning of CMX interval", parameter_type="MJD", time_scale="utc", value=mjd_start, - tcb2tdb_scale_factor=u.Quantity(1), + convert_tcb2tdb=False, ) ) self.add_param( prefixParameter( - name=f"DMXR2_{i}", + name=f"CMXR2_{i}", units="MJD", - description="End of DMX interval", + description="End of CMX interval", parameter_type="MJD", time_scale="utc", value=mjd_end, - tcb2tdb_scale_factor=u.Quantity(1), + convert_tcb2tdb=False, ) ) self.setup() self.validate() return added_indices - def remove_DMX_range(self, index): - """Removes all DMX parameters associated with a given index/list of indices. + def remove_CMX_range(self, index): + """Removes all CMX parameters associated with a given index/list of indices. Parameters ---------- index : float, int, list, np.ndarray - Number or list/array of numbers corresponding to DMX indices to be removed from model. + Number or list/array of numbers corresponding to CMX indices to be removed from model. """ if isinstance(index, (int, float, np.int64)): @@ -566,129 +557,129 @@ def remove_DMX_range(self, index): ) for index in indices: index_rf = f"{int(index):04d}" - for prefix in ["DMX_", "DMXR1_", "DMXR2_"]: + for prefix in ["CMX_", "CMXR1_", "CMXR2_"]: self.remove_param(prefix + index_rf) self.validate() def get_indices(self): - """Returns an array of integers corresponding to DMX parameters. + """Returns an array of integers corresponding to CMX parameters. Returns ------- inds : np.ndarray - Array of DMX indices in model. + Array of CMX indices in model. """ - inds = [int(p.split("_")[-1]) for p in self.params if "DMX_" in p] + inds = [int(p.split("_")[-1]) for p in self.params if "CMX_" in p] return np.array(inds) def setup(self): super().setup() - # Get DMX mapping. - # Register the DMX derivatives + # Get CMX mapping. + # Register the CMX derivatives for prefix_par in self.get_params_of_type("prefixParameter"): - if prefix_par.startswith("DMX_"): + if prefix_par.startswith("CMX_"): self.register_deriv_funcs(self.d_delay_d_dmparam, prefix_par) - self.register_dm_deriv_funcs(self.d_dm_d_DMX, prefix_par) + self.register_dm_deriv_funcs(self.d_dm_d_CMX, prefix_par) def validate(self): - """Validate the DMX parameters.""" + """Validate the CMX parameters.""" super().validate() - DMX_mapping = self.get_prefix_mapping_component("DMX_") - DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") - DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") - if DMX_mapping.keys() != DMXR1_mapping.keys(): + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + if CMX_mapping.keys() != CMXR1_mapping.keys(): # FIXME: report mismatch raise ValueError( - "DMX_ parameters do not " - "match DMXR1_ parameters. " + "CMX_ parameters do not " + "match CMXR1_ parameters. " "Please check your prefixed parameters." ) - if DMX_mapping.keys() != DMXR2_mapping.keys(): + if CMX_mapping.keys() != CMXR2_mapping.keys(): raise ValueError( - "DMX_ parameters do not " - "match DMXR2_ parameters. " + "CMX_ parameters do not " + "match CMXR2_ parameters. " "Please check your prefixed parameters." ) - r1 = np.zeros(len(DMX_mapping)) - r2 = np.zeros(len(DMX_mapping)) - indices = np.zeros(len(DMX_mapping), dtype=np.int32) - for j, index in enumerate(DMX_mapping): + r1 = np.zeros(len(CMX_mapping)) + r2 = np.zeros(len(CMX_mapping)) + indices = np.zeros(len(CMX_mapping), dtype=np.int32) + for j, index in enumerate(CMX_mapping): if ( - getattr(self, f"DMXR1_{index:04d}").quantity is not None - and getattr(self, f"DMXR2_{index:04d}").quantity is not None + getattr(self, f"CMXR1_{index:04d}").quantity is not None + and getattr(self, f"CMXR2_{index:04d}").quantity is not None ): - r1[j] = getattr(self, f"DMXR1_{index:04d}").quantity.mjd - r2[j] = getattr(self, f"DMXR2_{index:04d}").quantity.mjd + r1[j] = getattr(self, f"CMXR1_{index:04d}").quantity.mjd + r2[j] = getattr(self, f"CMXR2_{index:04d}").quantity.mjd indices[j] = index - for j, index in enumerate(DMXR1_mapping): + for j, index in enumerate(CMXR1_mapping): if np.any((r1[j] > r1) & (r1[j] < r2)): k = np.where((r1[j] > r1) & (r1[j] < r2))[0] for kk in k.flatten(): log.warning( - f"Start of DMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with DMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + f"Start of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" ) if np.any((r2[j] > r1) & (r2[j] < r2)): k = np.where((r2[j] > r1) & (r2[j] < r2))[0] for kk in k.flatten(): log.warning( - f"End of DMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with DMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" + f"End of CMX_{index:04d} ({r1[j]}-{r2[j]}) overlaps with CMX_{indices[kk]:04d} ({r1[kk]}-{r2[kk]})" ) def validate_toas(self, toas): - DMX_mapping = self.get_prefix_mapping_component("DMX_") - DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") - DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") bad_parameters = [] - for k in DMXR1_mapping.keys(): - if self._parent[DMX_mapping[k]].frozen: + for k in CMXR1_mapping.keys(): + if self._parent[CMX_mapping[k]].frozen: continue - b = self._parent[DMXR1_mapping[k]].quantity.mjd * u.d - e = self._parent[DMXR2_mapping[k]].quantity.mjd * u.d + b = self._parent[CMXR1_mapping[k]].quantity.mjd * u.d + e = self._parent[CMXR2_mapping[k]].quantity.mjd * u.d mjds = toas.get_mjds() n = np.sum((b <= mjds) & (mjds < e)) if n == 0: - bad_parameters.append(DMX_mapping[k]) + bad_parameters.append(CMX_mapping[k]) if bad_parameters: raise MissingTOAs(bad_parameters) - def dmx_dm(self, toas): + def cmx_dm(self, toas): condition = {} tbl = toas.table - if not hasattr(self, "dmx_toas_selector"): - self.dmx_toas_selector = TOASelect(is_range=True) - DMX_mapping = self.get_prefix_mapping_component("DMX_") - DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") - DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") - for epoch_ind in DMX_mapping.keys(): - r1 = getattr(self, DMXR1_mapping[epoch_ind]).quantity - r2 = getattr(self, DMXR2_mapping[epoch_ind]).quantity - condition[DMX_mapping[epoch_ind]] = (r1.mjd, r2.mjd) - select_idx = self.dmx_toas_selector.get_select_index( + if not hasattr(self, "cmx_toas_selector"): + self.cmx_toas_selector = TOASelect(is_range=True) + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + for epoch_ind in CMX_mapping.keys(): + r1 = getattr(self, CMXR1_mapping[epoch_ind]).quantity + r2 = getattr(self, CMXR2_mapping[epoch_ind]).quantity + condition[CMX_mapping[epoch_ind]] = (r1.mjd, r2.mjd) + select_idx = self.cmx_toas_selector.get_select_index( condition, tbl["mjd_float"] ) - # Get DMX delays + # Get CMX delays dm = np.zeros(len(tbl)) * self._parent.DM.units for k, v in select_idx.items(): dm[v] += getattr(self, k).quantity return dm - def DMX_dispersion_delay(self, toas, acc_delay=None): + def CMX_dispersion_delay(self, toas, acc_delay=None): """This is a wrapper function for interacting with the TimingModel class""" return self.dispersion_type_delay(toas) - def d_dm_d_DMX(self, toas, param_name, acc_delay=None): + def d_dm_d_CMX(self, toas, param_name, acc_delay=None): condition = {} tbl = toas.table - if not hasattr(self, "dmx_toas_selector"): - self.dmx_toas_selector = TOASelect(is_range=True) + if not hasattr(self, "cmx_toas_selector"): + self.cmx_toas_selector = TOASelect(is_range=True) param = getattr(self, param_name) - dmx_index = param.index - DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") - DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") - r1 = getattr(self, DMXR1_mapping[dmx_index]).quantity - r2 = getattr(self, DMXR2_mapping[dmx_index]).quantity + cmx_index = param.index + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + r1 = getattr(self, CMXR1_mapping[cmx_index]).quantity + r2 = getattr(self, CMXR2_mapping[cmx_index]).quantity condition = {param_name: (r1.mjd, r2.mjd)} - select_idx = self.dmx_toas_selector.get_select_index( + select_idx = self.cmx_toas_selector.get_select_index( condition, tbl["mjd_float"] ) @@ -697,20 +688,20 @@ def d_dm_d_DMX(self, toas, param_name, acc_delay=None): except AttributeError: warn("Using topocentric frequency for dedispersion!") bfreq = tbl["freq"] - dmx = np.zeros(len(tbl)) + cmx = np.zeros(len(tbl)) for k, v in select_idx.items(): - dmx[v] = 1.0 - return dmx * (u.pc / u.cm**3) / (u.pc / u.cm**3) + cmx[v] = 1.0 + return cmx * (u.pc / u.cm**3) / (u.pc / u.cm**3) def print_par(self, format="pint"): result = "" - DMX_mapping = self.get_prefix_mapping_component("DMX_") - DMXR1_mapping = self.get_prefix_mapping_component("DMXR1_") - DMXR2_mapping = self.get_prefix_mapping_component("DMXR2_") - result += getattr(self, "DMX").as_parfile_line(format=format) - sorted_list = sorted(DMX_mapping.keys()) + CMX_mapping = self.get_prefix_mapping_component("CMX_") + CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") + CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") + result += getattr(self, "CMX").as_parfile_line(format=format) + sorted_list = sorted(CMX_mapping.keys()) for ii in sorted_list: - result += getattr(self, DMX_mapping[ii]).as_parfile_line(format=format) - result += getattr(self, DMXR1_mapping[ii]).as_parfile_line(format=format) - result += getattr(self, DMXR2_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, CMX_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, CMXR1_mapping[ii]).as_parfile_line(format=format) + result += getattr(self, CMXR2_mapping[ii]).as_parfile_line(format=format) return result diff --git a/src/pint/models/dispersion_model.py b/src/pint/models/dispersion_model.py index 9ddd06ed7..61f68dd93 100644 --- a/src/pint/models/dispersion_model.py +++ b/src/pint/models/dispersion_model.py @@ -344,7 +344,7 @@ def __init__(self): self.delay_funcs_component += [self.DMX_dispersion_delay] def add_DMX_range(self, mjd_start, mjd_end, index=None, dmx=0, frozen=True): - """Add DMX range to a dispersion model with specified start/end MJDs and DMX. + """Add DMX range to a dispersion model with specified start/end MJDs and DMX value. Parameters ---------- From cc6830b5e0c8dfcf4cbc983ad578f52616e710c8 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Tue, 2 Jul 2024 09:44:51 +0200 Subject: [PATCH 3/7] fix cmx --- CHANGELOG-unreleased.md | 1 + src/pint/models/chromatic_model.py | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index 2899b99f8..36c201efc 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -10,5 +10,6 @@ the released changes. ## Unreleased ### Changed ### Added +- Piecewise-constant model for chromatic variations (CMX) ### Fixed ### Removed diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 799440637..c7cf6854f 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -578,8 +578,8 @@ def setup(self): # Register the CMX derivatives for prefix_par in self.get_params_of_type("prefixParameter"): if prefix_par.startswith("CMX_"): - self.register_deriv_funcs(self.d_delay_d_dmparam, prefix_par) - self.register_dm_deriv_funcs(self.d_dm_d_CMX, prefix_par) + self.register_deriv_funcs(self.d_delay_d_cmparam, prefix_par) + self.register_cm_deriv_funcs(self.d_cm_d_CMX, prefix_par) def validate(self): """Validate the CMX parameters.""" @@ -642,7 +642,7 @@ def validate_toas(self, toas): if bad_parameters: raise MissingTOAs(bad_parameters) - def cmx_dm(self, toas): + def cmx_cm(self, toas): condition = {} tbl = toas.table if not hasattr(self, "cmx_toas_selector"): @@ -658,16 +658,16 @@ def cmx_dm(self, toas): condition, tbl["mjd_float"] ) # Get CMX delays - dm = np.zeros(len(tbl)) * self._parent.DM.units + cm = np.zeros(len(tbl)) * self._parent.CM.units for k, v in select_idx.items(): - dm[v] += getattr(self, k).quantity - return dm + cm[v] += getattr(self, k).quantity + return cm - def CMX_dispersion_delay(self, toas, acc_delay=None): + def CMX_chromatic_delay(self, toas, acc_delay=None): """This is a wrapper function for interacting with the TimingModel class""" return self.dispersion_type_delay(toas) - def d_dm_d_CMX(self, toas, param_name, acc_delay=None): + def d_cm_d_CMX(self, toas, param_name, acc_delay=None): condition = {} tbl = toas.table if not hasattr(self, "cmx_toas_selector"): From 73732784fcf618c449800fb774ac1b4105e304e2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 13 Nov 2024 10:13:59 +0100 Subject: [PATCH 4/7] docs --- src/pint/models/chromatic_model.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 5b6b5f298..52592bb14 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -15,7 +15,7 @@ class Chromatic(DelayComponent): - """A base chromatic timing model.""" + """A base chromatic timing model with a constant chromatic index.""" def __init__(self): super().__init__() @@ -114,11 +114,14 @@ def register_cm_deriv_funcs(self, func, param): class ChromaticCM(Chromatic): - """Simple chromatic delay model. + """Simple chromatic delay model with a constant chromatic index. This model uses Taylor expansion to model CM variation over time. It can also be used for a constant CM. + Fitting for the chromatic index is not supported because the fit is too + unstable when fit simultaneously with the DM. + Parameters supported: .. paramtable:: @@ -306,10 +309,13 @@ def change_cmepoch(self, new_epoch): class ChromaticCMX(Chromatic): - """This class provides a CMX model - piecewise-constant chromatic variations. + """This class provides a CMX model - piecewise-constant chromatic variations with constant + chromatic index. + + This model lets the user specify time ranges and fit for a different CMX value in each time range. - This model lets the user specify time ranges and fit for a different - CMX value in each time range. + It should be used in combination with the `ChromaticCM` model. Specifically, TNCHROMIDX must be + set. Parameters supported: From 3c91b51caf6247e13e962a35b69abafb6eadee9e Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 13 Nov 2024 10:36:03 +0100 Subject: [PATCH 5/7] test_cmx --- src/pint/models/__init__.py | 2 +- src/pint/models/chromatic_model.py | 2 +- tests/test_cmx.py | 77 ++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+), 2 deletions(-) create mode 100644 tests/test_cmx.py diff --git a/src/pint/models/__init__.py b/src/pint/models/__init__.py index 67201a259..27c78a4f4 100644 --- a/src/pint/models/__init__.py +++ b/src/pint/models/__init__.py @@ -25,7 +25,7 @@ from pint.models.binary_dd import BinaryDD, BinaryDDS, BinaryDDGR, BinaryDDH from pint.models.binary_ddk import BinaryDDK from pint.models.binary_ell1 import BinaryELL1, BinaryELL1H, BinaryELL1k -from pint.models.chromatic_model import ChromaticCM +from pint.models.chromatic_model import ChromaticCM, ChromaticCMX from pint.models.cmwavex import CMWaveX from pint.models.dispersion_model import ( DispersionDM, diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 52592bb14..8f9dcb815 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -672,7 +672,7 @@ def cmx_cm(self, toas): def CMX_chromatic_delay(self, toas, acc_delay=None): """This is a wrapper function for interacting with the TimingModel class""" - return self.dispersion_type_delay(toas) + return self.chromatic_type_delay(toas) def d_cm_d_CMX(self, toas, param_name, acc_delay=None): condition = {} diff --git a/tests/test_cmx.py b/tests/test_cmx.py new file mode 100644 index 000000000..b92b68dd8 --- /dev/null +++ b/tests/test_cmx.py @@ -0,0 +1,77 @@ +from io import StringIO +from pint.models import get_model +from pint.simulation import make_fake_toas_uniform +from pint.fitter import Fitter + +import pytest +import numpy as np +import astropy.units as u + + +@pytest.fixture +def model_and_toas(): + par = """ + RAJ 05:00:00 + DECJ 12:00:00 + F0 101.0 + F1 -1.1e-14 + PEPOCH 55000 + DM 12 + DMEPOCH 55000 + CMEPOCH 55000 + CM 3.5 + TNCHROMIDX 4.0 + EPHEM DE440 + CLOCK TT(BIPM2021) + UNITS TDB + TZRMJD 55000 + TZRFRQ 1400 + TZRSITE gmrt + CMXR1_0001 53999.9 + CMXR2_0001 54500 + CMX_0001 0.5 1 + CMXR1_0002 54500.1 + CMXR2_0002 55000 + CMX_0002 -0.5 1 + CMXR1_0003 55000.1 + CMXR2_0003 55500 + CMX_0003 -0.1 1 + """ + + model = get_model(StringIO(par)) + + freqs = np.linspace(300, 1600, 16) * u.MHz + toas = make_fake_toas_uniform( + startMJD=54000, + endMJD=55500, + ntoas=2000, + model=model, + freq=freqs, + add_noise=True, + obs="gmrt", + include_bipm=True, + multi_freqs_in_epoch=True, + ) + + return model, toas + + +def test_cmx(model_and_toas): + model, toas = model_and_toas + + assert "ChromaticCMX" in model.components + + ftr = Fitter.auto(toas, model) + ftr.fit_toas() + + assert np.abs(model.CMX_0001.value - ftr.model.CMX_0001.value) / ( + 3 * ftr.model.CMX_0001.uncertainty_value + ) + assert np.abs(model.CMX_0002.value - ftr.model.CMX_0002.value) / ( + 3 * ftr.model.CMX_0002.uncertainty_value + ) + assert np.abs(model.CMX_0003.value - ftr.model.CMX_0003.value) / ( + 3 * ftr.model.CMX_0003.uncertainty_value + ) + + assert ftr.resids.chi2_reduced < 1.6 From 4cab923a4cfa81590b48188ae2af348c24afded5 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 13 Nov 2024 10:39:30 +0100 Subject: [PATCH 6/7] test --- src/pint/models/chromatic_model.py | 1 - tests/test_cmx.py | 2 ++ 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index 8f9dcb815..d745729ce 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -705,7 +705,6 @@ def print_par(self, format="pint"): CMX_mapping = self.get_prefix_mapping_component("CMX_") CMXR1_mapping = self.get_prefix_mapping_component("CMXR1_") CMXR2_mapping = self.get_prefix_mapping_component("CMXR2_") - result += getattr(self, "CMX").as_parfile_line(format=format) sorted_list = sorted(CMX_mapping.keys()) for ii in sorted_list: result += getattr(self, CMX_mapping[ii]).as_parfile_line(format=format) diff --git a/tests/test_cmx.py b/tests/test_cmx.py index b92b68dd8..f62563840 100644 --- a/tests/test_cmx.py +++ b/tests/test_cmx.py @@ -75,3 +75,5 @@ def test_cmx(model_and_toas): ) assert ftr.resids.chi2_reduced < 1.6 + + assert "CMX_0001" in str(ftr.model) From 52645a901f05019b7aad0a707484cb4c33636888 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Wed, 20 Nov 2024 15:01:56 +0100 Subject: [PATCH 7/7] test_cmx_delay --- src/pint/models/chromatic_model.py | 9 +---- tests/test_cmx.py | 55 +++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/src/pint/models/chromatic_model.py b/src/pint/models/chromatic_model.py index d745729ce..8f428fea2 100644 --- a/src/pint/models/chromatic_model.py +++ b/src/pint/models/chromatic_model.py @@ -357,7 +357,6 @@ def add_CMX_range(self, mjd_start, mjd_end, index=None, cmx=0, frozen=True): index : int Index that has been assigned to new CMX event. - """ #### Setting up the CMX title convention. If index is None, want to increment the current max CMX index by 1. @@ -447,7 +446,6 @@ def add_CMX_ranges(self, mjd_starts, mjd_ends, indices=None, cmxs=0, frozens=Tru indices : list Indices that has been assigned to new CMX events - """ if len(mjd_starts) != len(mjd_ends): raise ValueError( @@ -574,7 +572,7 @@ def get_indices(self): Returns ------- inds : np.ndarray - Array of CMX indices in model. + Array of CMX indices in model. """ inds = [int(p.split("_")[-1]) for p in self.params if "CMX_" in p] return np.array(inds) @@ -690,11 +688,6 @@ def d_cm_d_CMX(self, toas, param_name, acc_delay=None): condition, tbl["mjd_float"] ) - try: - bfreq = self._parent.barycentric_radio_freq(toas) - except AttributeError: - warn("Using topocentric frequency for dedispersion!") - bfreq = tbl["freq"] cmx = np.zeros(len(tbl)) for k, v in select_idx.items(): cmx[v] = 1.0 diff --git a/tests/test_cmx.py b/tests/test_cmx.py index f62563840..976581880 100644 --- a/tests/test_cmx.py +++ b/tests/test_cmx.py @@ -43,7 +43,7 @@ def model_and_toas(): freqs = np.linspace(300, 1600, 16) * u.MHz toas = make_fake_toas_uniform( startMJD=54000, - endMJD=55500, + endMJD=56000, ntoas=2000, model=model, freq=freqs, @@ -77,3 +77,56 @@ def test_cmx(model_and_toas): assert ftr.resids.chi2_reduced < 1.6 assert "CMX_0001" in str(ftr.model) + + +def test_cmx_delay(model_and_toas): + model, toas = model_and_toas + + # Zero delay outside CMX ranges + nocmx_mask = toas.get_mjds().value > 55500 + assert all( + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[nocmx_mask] == 0 + ) + + # The delay is consistent + cmx1_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0001.value, + toas.get_mjds().value <= model.CMXR2_0001.value, + ) + cmx1_freqs = toas.get_freqs()[cmx1_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0001.quantity, model.TNCHROMIDX.quantity, cmx1_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx1_mask], + ) + ) + + cmx2_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0002.value, + toas.get_mjds().value <= model.CMXR2_0002.value, + ) + cmx2_freqs = toas.get_freqs()[cmx2_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0002.quantity, model.TNCHROMIDX.quantity, cmx2_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx2_mask], + ) + ) + + cmx3_mask = np.logical_and( + toas.get_mjds().value >= model.CMXR1_0003.value, + toas.get_mjds().value <= model.CMXR2_0003.value, + ) + cmx3_freqs = toas.get_freqs()[cmx3_mask] + assert all( + np.isclose( + model.components["ChromaticCMX"].chromatic_time_delay( + model.CMX_0003.quantity, model.TNCHROMIDX.quantity, cmx3_freqs + ), + model.components["ChromaticCMX"].CMX_chromatic_delay(toas)[cmx3_mask], + ) + )