From c96c95011671f18d0c15233068d13d15c8bea737 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 10:24:38 +0100 Subject: [PATCH 1/6] type hints --- CHANGELOG-unreleased.md | 2 ++ src/pint/models/timing_model.py | 56 ++++++++++++++++++++------------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/CHANGELOG-unreleased.md b/CHANGELOG-unreleased.md index a2deae5cd..1cd425292 100644 --- a/CHANGELOG-unreleased.md +++ b/CHANGELOG-unreleased.md @@ -10,6 +10,8 @@ the released changes. ## Unreleased ### Changed ### Added +- Type hints in `pint.models.timing_model` ### Fixed +- Made `TimingModel.is_binary()` more robust. ### Removed diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index d6aeb0525..a650ea666 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -34,6 +34,7 @@ import contextlib from collections import OrderedDict, defaultdict from functools import wraps +from typing import Callable, Dict, List, Literal from warnings import warn from uncertainties import ufloat @@ -239,7 +240,7 @@ class TimingModel: rather than to any particular component. """ - def __init__(self, name="", components=[]): + def __init__(self, name: str = "", components: List["Component"] = []): if not isinstance(name, str): raise ValueError( "First parameter should be the model name, was {!r}".format(name) @@ -381,16 +382,16 @@ def __init__(self, name="", components=[]): for cp in components: self.add_component(cp, setup=False, validate=False) - def __repr__(self): + def __repr__(self) -> str: return "{}(\n {}\n)".format( self.__class__.__name__, ",\n ".join(str(v) for k, v in sorted(self.components.items())), ) - def __str__(self): + def __str__(self) -> str: return self.as_parfile() - def validate(self, allow_tcb=False): + def validate(self, allow_tcb: bool = False) -> None: """Validate component setup. The checks include required parameters and parameter values, and component types. @@ -440,7 +441,7 @@ def validate(self, allow_tcb=False): self.validate_component_types() - def validate_component_types(self): + def validate_component_types(self) -> None: """Physically motivated validation of a timing model. This method checks the compatibility of different model components when used together. @@ -534,7 +535,7 @@ def num_components_of_type(type): # result += str(getattr(cp, pp)) + "\n" # return result - def __getattr__(self, name): + def __getattr__(self, name: str): if name in ["components", "component_types", "search_cmp_attr"]: raise AttributeError if not hasattr(self, "component_types"): @@ -548,7 +549,9 @@ def __getattr__(self, name): f"Attribute {name} not found in TimingModel or any Component" ) - def __setattr__(self, name, value): + def __setattr__( + self, name: str, value: Parameter | prefixParameter | u.Quantity | float + ): """Mostly this just sets ``self.name = value``. But there are a few special cases: * Where they are both :class:`Parameter` instances with different names, @@ -588,7 +591,7 @@ def __setattr__(self, name, value): super().__setattr__(name, value) @property_exists - def params_ordered(self): + def params_ordered(self) -> List[str]: """List of all parameter names in this model and all its components. This is the same as `params`.""" @@ -605,7 +608,7 @@ def params_ordered(self): return self.params @property_exists - def params(self): + def params(self) -> List[str]: """List of all parameter names in this model and all its components, in a sensible order.""" # Define the order of components in the list @@ -646,7 +649,7 @@ def params(self): return pstart + pmid + pend @property_exists - def free_params(self): + def free_params(self) -> List[str]: """List of all the free parameters in the timing model. Can be set to change which are free. @@ -661,7 +664,7 @@ def free_params(self): return [p for p in self.params if not getattr(self, p).frozen] @free_params.setter - def free_params(self, params): + def free_params(self, params: List[str]): params_true = {self.match_param_aliases(p) for p in params} for p in self.params: getattr(self, p).frozen = p not in params_true @@ -672,7 +675,7 @@ def free_params(self, params): ) @property_exists - def fittable_params(self): + def fittable_params(self) -> List[str]: """List of parameters that are fittable, i.e., the parameters which have a derivative implemented. These derivatives are usually accessed via the `d_delay_d_param` and `d_phase_d_param` methods.""" @@ -692,7 +695,7 @@ def fittable_params(self): ) ] - def match_param_aliases(self, alias): + def match_param_aliases(self, alias: str) -> str: """Return PINT parameter name corresponding to this alias. Parameters @@ -722,7 +725,11 @@ def match_param_aliases(self, alias): raise UnknownParameter(f"{alias} is not recognized as a parameter or alias") - def get_params_dict(self, which="free", kind="quantity"): + def get_params_dict( + self, + which: Literal["free", "all"] = "free", + kind: Literal["quantity", "value", "uncertainty"] = "quantity", + ) -> OrderedDict[str, float] | OrderedDict[str, u.Quantity]: """Return a dict mapping parameter names to values. This can return only the free parameters or all; and it can return the @@ -756,7 +763,10 @@ def get_params_dict(self, which="free", kind="quantity"): raise ValueError(f"Unknown kind {kind!r}") return c - def get_params_of_component_type(self, component_type): + def get_params_of_component_type( + self, + component_type: Literal["PhaseComponent", "DelayComponent", "NoiseComponent"], + ) -> List[str]: """Get a list of parameters belonging to a component type. Parameters @@ -776,7 +786,7 @@ def get_params_of_component_type(self, component_type): else: return [] - def set_param_values(self, fitp): + def set_param_values(self, fitp: Dict[str, float]) -> None: """Set the model parameters to the value contained in the input dict. Ex. model.set_param_values({'F0':60.1,'F1':-1.3e-15}) @@ -794,14 +804,14 @@ def set_param_values(self, fitp): else: p.value = v - def set_param_uncertainties(self, fitp): + def set_param_uncertainties(self, fitp: Dict[str, float]) -> None: """Set the model parameters to the value contained in the input dict.""" for k, v in fitp.items(): p = getattr(self, k) p.uncertainty = v if isinstance(v, u.Quantity) else v * p.units @property_exists - def components(self): + def components(self) -> Dict[str, "Component"]: """All the components in a dictionary indexed by name.""" comps = {} for ct in self.component_types: @@ -810,7 +820,7 @@ def components(self): return comps @property_exists - def delay_funcs(self): + def delay_funcs(self) -> List[Callable]: """List of all delay functions.""" dfs = [] for d in self.DelayComponent_list: @@ -818,7 +828,7 @@ def delay_funcs(self): return dfs @property_exists - def phase_funcs(self): + def phase_funcs(self) -> List[Callable]: """List of all phase functions.""" pfs = [] for p in self.PhaseComponent_list: @@ -826,9 +836,11 @@ def phase_funcs(self): return pfs @property_exists - def is_binary(self): + def is_binary(self) -> bool: """Does the model describe a binary pulsar?""" - return any(x.startswith("Binary") for x in self.components.keys()) + from pint.models.pulsar_binary import PulsarBinary + + return any(isinstance(x, PulsarBinary) for x in self.components.values()) def orbital_phase(self, barytimes, anom="mean", radians=True): """Return orbital phase (in radians) at barycentric MJD times. From e8f40f23bab915b39db6d1e01acfd42327acf297 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 10:40:11 +0100 Subject: [PATCH 2/6] type hints --- src/pint/models/timing_model.py | 66 +++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index a650ea666..f5f798ce1 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -34,7 +34,7 @@ import contextlib from collections import OrderedDict, defaultdict from functools import wraps -from typing import Callable, Dict, List, Literal +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from warnings import warn from uncertainties import ufloat @@ -842,12 +842,17 @@ def is_binary(self) -> bool: return any(isinstance(x, PulsarBinary) for x in self.components.values()) - def orbital_phase(self, barytimes, anom="mean", radians=True): + def orbital_phase( + self, + barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + anom: Literal["mean", "eccentric", "true"] = "mean", + radians: bool = True, + ) -> np.ndarray: """Return orbital phase (in radians) at barycentric MJD times. Parameters ---------- - barytimes: Time, TOAs, array-like, or float + barytimes: Time, TOAs, array-like, MJDParameter, or float MJD barycentric time(s). The times to compute the orbital phases. Needs to be a barycentric time in TDB. If a TOAs instance is passed, the barycentering will happen @@ -911,12 +916,14 @@ def orbital_phase(self, barytimes, anom="mean", radians=True): # return with radian units or return as unitless cycles from 0-1 return anoms * u.rad if radians else anoms / (2 * np.pi) - def pulsar_radial_velocity(self, barytimes): + def pulsar_radial_velocity( + self, barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter + ) -> np.ndarray: """Return line-of-sight velocity of the pulsar relative to the system barycenter at barycentric MJD times. Parameters ---------- - barytimes: Time, TOAs, array-like, or float + barytimes: Time, TOAs, array-like, MJDParameter, or float MJD barycentric time(s). The times to compute the orbital phases. Needs to be a barycentric time in TDB. If a TOAs instance is passed, the barycentering will happen @@ -957,7 +964,11 @@ def pulsar_radial_velocity(self, barytimes): * (np.cos(psi) + bbi.ecc() * np.cos(bbi.omega())) ).cgs - def companion_radial_velocity(self, barytimes, massratio): + def companion_radial_velocity( + self, + barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + massratio: float, + ) -> np.ndarray: """Return line-of-sight velocity of the companion relative to the system barycenter at barycentric MJD times. Parameters @@ -993,7 +1004,7 @@ def companion_radial_velocity(self, barytimes, massratio): """ return -self.pulsar_radial_velocity(barytimes) * massratio - def conjunction(self, baryMJD): + def conjunction(self, baryMJD: float | time.Time) -> float | np.ndarray: """Return the time(s) of the first superior conjunction(s) after baryMJD. Args @@ -1054,7 +1065,7 @@ def funct(t): return scs[0] if len(scs) == 1 else np.asarray(scs) @property_exists - def dm_funcs(self): + def dm_funcs(self) -> List[Callable]: """List of all dm value functions.""" dmfs = [] for cp in self.components.values(): @@ -1065,7 +1076,7 @@ def dm_funcs(self): return dmfs @property_exists - def has_correlated_errors(self): + def has_correlated_errors(self) -> bool: """Whether or not this model has correlated errors.""" return ( @@ -1081,7 +1092,7 @@ def has_correlated_errors(self): ) @property_exists - def has_time_correlated_errors(self): + def has_time_correlated_errors(self) -> bool: """Whether or not this model has time-correlated errors.""" return ( @@ -1097,7 +1108,7 @@ def has_time_correlated_errors(self): ) @property_exists - def covariance_matrix_funcs(self): + def covariance_matrix_funcs(self) -> List[Callable]: """List of covariance matrix functions.""" cvfs = [] if "NoiseComponent" in self.component_types: @@ -1106,7 +1117,7 @@ def covariance_matrix_funcs(self): return cvfs @property_exists - def dm_covariance_matrix_funcs(self): + def dm_covariance_matrix_funcs(self) -> List[Callable]: """List of covariance matrix functions.""" cvfs = [] if "NoiseComponent" in self.component_types: @@ -1116,7 +1127,7 @@ def dm_covariance_matrix_funcs(self): # Change sigma to uncertainty to avoid name conflict. @property_exists - def scaled_toa_uncertainty_funcs(self): + def scaled_toa_uncertainty_funcs(self) -> List[Callable]: """List of scaled toa uncertainty functions.""" ssfs = [] if "NoiseComponent" in self.component_types: @@ -1126,7 +1137,7 @@ def scaled_toa_uncertainty_funcs(self): # Change sigma to uncertainty to avoid name conflict. @property_exists - def scaled_dm_uncertainty_funcs(self): + def scaled_dm_uncertainty_funcs(self) -> List[Callable]: """List of scaled dm uncertainty functions.""" ssfs = [] if "NoiseComponent" in self.component_types: @@ -1136,7 +1147,7 @@ def scaled_dm_uncertainty_funcs(self): return ssfs @property_exists - def basis_funcs(self): + def basis_funcs(self) -> List[Callable]: """List of scaled uncertainty functions.""" bfs = [] if "NoiseComponent" in self.component_types: @@ -1145,34 +1156,39 @@ def basis_funcs(self): return bfs @property_exists - def phase_deriv_funcs(self): + def phase_deriv_funcs(self) -> List[Callable]: """List of derivative functions for phase components.""" return self.get_deriv_funcs("PhaseComponent") @property_exists - def delay_deriv_funcs(self): + def delay_deriv_funcs(self) -> List[Callable]: """List of derivative functions for delay components.""" return self.get_deriv_funcs("DelayComponent") @property_exists - def dm_derivs(self): # TODO need to be careful about the name here. + def dm_derivs(self) -> List[Callable]: + # TODO need to be careful about the name here. """List of DM derivative functions.""" return self.get_deriv_funcs("DelayComponent", "dm") @property_exists - def toasigma_derivs(self): + def toasigma_derivs(self) -> List[Callable]: """List of scaled TOA uncertainty derivative functions""" return self.get_deriv_funcs("NoiseComponent", "toasigma") @property_exists - def d_phase_d_delay_funcs(self): + def d_phase_d_delay_funcs(self) -> List[Callable]: """List of d_phase_d_delay functions.""" Dphase_Ddelay = [] for cp in self.PhaseComponent_list: Dphase_Ddelay += cp.phase_derivs_wrt_delay return Dphase_Ddelay - def get_deriv_funcs(self, component_type, derivative_type=""): + def get_deriv_funcs( + self, + component_type: Literal["PhaseComponent", "DelayComponent", "NoiseComponent"], + derivative_type: Literal["", "dm", "toasigma"] = "", + ) -> Dict[str, Callable]: """Return a dictionary of derivative functions. Parameters @@ -1197,7 +1213,7 @@ def get_deriv_funcs(self, component_type, derivative_type=""): deriv_funcs[k] += v return dict(deriv_funcs) - def search_cmp_attr(self, name): + def search_cmp_attr(self, name: str) -> Optional["Component"]: """Search for an attribute in all components. Return the component, or None. @@ -1210,7 +1226,7 @@ def search_cmp_attr(self, name): return cp raise AttributeError(f"{name} not found in any component") - def get_component_type(self, component): + def get_component_type(self, component: "Component") -> str: """Identify the component object's type. Parameters @@ -1243,7 +1259,9 @@ def get_component_type(self, component): comp_type = comp_base[-3].__name__ return comp_type - def map_component(self, component): + def map_component( + self, component: Union[str, "Component"] + ) -> Tuple["Component", int, List["Component"], str]: """Get the location of component. Parameters From dcd69c8ebb0b29aa045e3176b28cf5c3c88fc4b2 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 12:15:36 +0100 Subject: [PATCH 3/6] hints --- src/pint/models/timing_model.py | 251 ++++++++++++++++++++------------ 1 file changed, 154 insertions(+), 97 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index f5f798ce1..2042137d0 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -39,6 +39,7 @@ from uncertainties import ufloat import astropy.time as time +from astropy.table import Table from astropy import units as u, constants as c import numpy as np from astropy.utils.decorators import lazyproperty @@ -81,6 +82,7 @@ UnknownBinaryModel, MissingBinaryError, ) +from pint.types import file_like, time_like __all__ = [ @@ -1297,8 +1299,13 @@ def map_component( return comp, order, host_list, comp_type def add_component( - self, component, order=DEFAULT_ORDER, force=False, setup=True, validate=True - ): + self, + component: "Component", + order: List[str] = DEFAULT_ORDER, + force: bool = False, + setup: bool = True, + validate: bool = True, + ) -> None: """Add a component into TimingModel. Parameters @@ -1352,7 +1359,7 @@ def add_component( if validate: self.validate() - def remove_component(self, component): + def remove_component(self, component: Union[str, "Component"]) -> None: """Remove one component from the timing model. Parameters @@ -1363,7 +1370,7 @@ def remove_component(self, component): cp, co_order, host, cp_type = self.map_component(component) host.remove(cp) - def _locate_param_host(self, param): + def _locate_param_host(self, param: str): """Search for the parameter host component in the timing model. Parameters @@ -1379,6 +1386,10 @@ def _locate_param_host(self, param): second one is the parameter object. If it is a prefix-style parameter, it will return one example of such parameter. """ + + # AS: The return signature of this function is a mess. It is not clear to me + # how exactly this is used, so I am leaving it alone. + result_comp = [] for cp_name, cp in self.components.items(): if param in cp.params: @@ -1396,7 +1407,7 @@ def _locate_param_host(self, param): return result_comp - def get_components_by_category(self): + def get_components_by_category(self) -> Dict[str, List["Component"]]: """Return a dict of this model's component objects keyed by the category name.""" categorydict = defaultdict(list) for cp in self.components.values(): @@ -1404,7 +1415,9 @@ def get_components_by_category(self): # Convert from defaultdict to dict return dict(categorydict) - def add_param_from_top(self, param, target_component, setup=False): + def add_param_from_top( + self, param: Parameter, target_component: str, setup: bool = False + ) -> None: """Add a parameter to a timing model component. Parameters @@ -1427,7 +1440,7 @@ def add_param_from_top(self, param, target_component, setup=False): f"Can not find component '{target_component}' in " "timing model." ) - def remove_param(self, param): + def remove_param(self, param: str) -> None: """Remove a parameter from timing model. Parameters @@ -1446,7 +1459,7 @@ def remove_param(self, param): self.components[target_component].remove_param(param) self.setup() - def get_params_mapping(self): + def get_params_mapping(self) -> Dict[str, str]: """Report which component each parameter name comes from.""" param_mapping = {p: "TimingModel" for p in self.top_level_params} for cp in list(self.components.values()): @@ -1454,13 +1467,14 @@ def get_params_mapping(self): param_mapping[pp] = cp.__class__.__name__ return param_mapping - def get_params_of_type_top(self, param_type): + def get_params_of_type_top(self, param_type: str) -> List[str]: + """Return all parameters in the model that belong to a certain `Parameter` subtype.""" result = [] for cp in self.components.values(): result += cp.get_params_of_type(param_type) return result - def get_prefix_mapping(self, prefix): + def get_prefix_mapping(self, prefix: str) -> Dict[int, str]: """Get the index mapping for the prefix parameters. Parameters @@ -1480,7 +1494,7 @@ def get_prefix_mapping(self, prefix): return mapping raise ValueError(f"Can not find prefix {prefix!r}") - def get_prefix_list(self, prefix, start_index=0): + def get_prefix_list(self, prefix: str, start_index: int = 0) -> List[u.Quantity]: """Return the Quantities associated with a sequence of prefix parameters. Parameters @@ -1529,14 +1543,16 @@ def get_prefix_list(self, prefix, start_index=0): ) return r - def param_help(self): - """Print help lines for all available parameters in model.""" + def param_help(self) -> str: + """Return help strings for all available parameters in model.""" return "".join( "{:<40}{}\n".format(cp, getattr(self, par).help_line()) for par, cp in self.get_params_mapping().items() ) - def delay(self, toas, cutoff_component="", include_last=True): + def delay( + self, toas: TOAs, cutoff_component: str = "", include_last: bool = True + ) -> u.Quantity: """Total delay for the TOAs. Return the total delay which will be subtracted from the given @@ -1569,7 +1585,7 @@ def delay(self, toas, cutoff_component="", include_last=True): delay += df(toas, delay) return delay - def phase(self, toas, abs_phase=None): + def phase(self, toas: TOAs, abs_phase: Optional[bool] = None) -> Phase: """Return the model-predicted pulse phase for the given TOAs. This is the phase as observed at the observatory at the exact moment @@ -1605,7 +1621,7 @@ def phase(self, toas, abs_phase=None): tz_phase += Phase(pf(tz_toa, tz_delay)) return phase - tz_phase - def add_tzr_toa(self, toas): + def add_tzr_toa(self, toas: TOAs) -> None: """Create a TZR TOA for the given TOAs object and add it to the timing model. This corresponds to TOA closest to the PEPOCH.""" from pint.models.absolute_phase import AbsPhase @@ -1614,7 +1630,7 @@ def add_tzr_toa(self, toas): self.make_TZR_toa(toas) self.validate() - def total_dm(self, toas): + def total_dm(self, toas: TOAs) -> u.Quantity: """Calculate dispersion measure from all the dispersion type of components.""" # Here we assume the unit would be the same for all the dm value function. # By doing so, we do not have to hard code an unit here. @@ -1624,13 +1640,14 @@ def total_dm(self, toas): dm += dm_f(toas) return dm - def total_dispersion_slope(self, toas): + def total_dispersion_slope(self, toas: TOAs) -> u.Quantity: """Calculate the dispersion slope from all the dispersion-type components.""" dm_tot = self.total_dm(toas) return dispersion_slope(dm_tot) - def toa_covariance_matrix(self, toas): - """Get the TOA covariance matrix for noise models. + def toa_covariance_matrix(self, toas: TOAs) -> np.ndarray: + """Get the TOA covariance matrix for noise models. The matrix elements + have units of s^2. If there is no noise model component provided, a diagonal matrix with TOAs error as diagonal element will be returned. @@ -1643,12 +1660,14 @@ def toa_covariance_matrix(self, toas): result += nf(toas) return result - def dm_covariance_matrix(self, toas): - """Get the DM covariance matrix for noise models. + def dm_covariance_matrix(self, toas: TOAs) -> np.ndarray: + """Get the DM covariance matrix for noise models. The matrix elements have + units of dmu^2. If there is no noise model component provided, a diagonal matrix with TOAs error as diagonal element will be returned. """ + # TODO: Check if this is correct when PLDMNoise is present. dms, valid_dm = toas.get_flag_value("pp_dm", as_type=float) dmes, valid_dme = toas.get_flag_value("pp_dme", as_type=float) dms = np.array(dms)[valid_dm] @@ -1665,7 +1684,7 @@ def dm_covariance_matrix(self, toas): result += nf(toas) return result - def scaled_toa_uncertainty(self, toas): + def scaled_toa_uncertainty(self, toas: TOAs) -> u.Quantity: """Get the scaled TOA data uncertainties noise models. If there is no noise model component provided, a vector with @@ -1688,7 +1707,7 @@ def scaled_toa_uncertainty(self, toas): result += nf(toas) return result - def scaled_dm_uncertainty(self, toas): + def scaled_dm_uncertainty(self, toas: TOAs) -> u.Quantity: """Get the scaled DM data uncertainties noise models. If there is no noise model component provided, a vector with @@ -1711,19 +1730,21 @@ def scaled_dm_uncertainty(self, toas): result += nf(toas) return result - def noise_model_designmatrix(self, toas): + def noise_model_designmatrix(self, toas: TOAs) -> np.ndarray: + """Returns the joint design/basis matrix for all noise components.""" if len(self.basis_funcs) == 0: return None result = [nf(toas)[0] for nf in self.basis_funcs] return np.hstack(list(result)) - def noise_model_basis_weight(self, toas): + def noise_model_basis_weight(self, toas: TOAs) -> np.ndarray: + """Returns the joint weight vector for all noise components.""" if len(self.basis_funcs) == 0: return None result = [nf(toas)[1] for nf in self.basis_funcs] return np.hstack(list(result)) - def noise_model_dimensions(self, toas): + def noise_model_dimensions(self, toas: TOAs) -> Dict[str, Tuple[int, int]]: """Number of basis functions for each noise model component. Returns a dictionary of correlated-noise components in the noise @@ -1748,7 +1769,7 @@ def noise_model_dimensions(self, toas): return result - def jump_flags_to_params(self, toas): + def jump_flags_to_params(self, toas: TOAs) -> None: """Add JUMP parameters corresponding to tim_jump flags. When a ``.tim`` file contains pairs of JUMP lines, the user's expectation @@ -1825,7 +1846,7 @@ def jump_flags_to_params(self, toas): self.components["PhaseJump"].setup() - def delete_jump_and_flags(self, toa_table, jump_num): + def delete_jump_and_flags(self, toa_table: Optional[list], jump_num: int) -> None: """Delete jump object from PhaseJump and remove its flags from TOA table. This is a helper function for pintk. @@ -1866,14 +1887,16 @@ def delete_jump_and_flags(self, toa_table, jump_num): return self.components["PhaseJump"].setup() - def get_barycentric_toas(self, toas, cutoff_component=""): + def get_barycentric_toas( + self, toas: TOAs, cutoff_component: str = "" + ) -> u.Quantity: """Conveniently calculate the barycentric TOAs. Parameters ---------- toas: TOAs object The TOAs the barycentric corrections are applied on - cutoff_delay: str, optional + cutoff_component: str, optional The cutoff delay component name. If it is not provided, it will search for binary delay and apply all the delay before binary. @@ -1891,8 +1914,11 @@ def get_barycentric_toas(self, toas, cutoff_component=""): corr = self.delay(toas, cutoff_component, False) return tbl["tdbld"] * u.day - corr - def d_phase_d_toa(self, toas, sample_step=None): - """Return the finite-difference derivative of phase wrt TOA. + def d_phase_d_toa( + self, toas: TOAs, sample_step: Optional[float] = None + ) -> u.Quantity: + """Return the finite-difference derivative of phase wrt TOA. This is the same as + the topocentric frequency of the pulsar. Parameters ---------- @@ -1924,14 +1950,14 @@ def d_phase_d_toa(self, toas, sample_step=None): del copy_toas return d_phase_d_toa.to(u.Hz) - def d_phase_d_tpulsar(self, toas): + def d_phase_d_tpulsar(self, toas: TOAs): """Return the derivative of phase wrt time at the pulsar. NOT implemented yet. """ raise NotImplementedError - def d_phase_d_param(self, toas, delay, param): + def d_phase_d_param(self, toas: TOAs, delay: u.Quantity, param: str) -> u.Quantity: """Return the derivative of phase with respect to the parameter. This is the derivative of the phase observed at each TOA with @@ -1987,7 +2013,9 @@ def d_phase_d_param(self, toas, delay, param): result = dpdd_result * d_delay_d_p return result.to(result.unit, equivalencies=u.dimensionless_angles()) - def d_delay_d_param(self, toas, param, acc_delay=None): + def d_delay_d_param( + self, toas: TOAs, param: str, acc_delay: Optional[u.Quantity] = None + ) -> u.Quantity: """Return the derivative of delay with respect to the parameter.""" par = getattr(self, param) result = np.longdouble(np.zeros(toas.ntoas) << (u.s / par.units)) @@ -2003,7 +2031,9 @@ def d_delay_d_param(self, toas, param, acc_delay=None): ) return result - def d_phase_d_param_num(self, toas, param, step=1e-2): + def d_phase_d_param_num( + self, toas: TOAs, param: str, step: float = 1e-2 + ) -> u.Quantity: """Return the derivative of phase with respect to the parameter. Compute the value numerically, using a symmetric finite difference. @@ -2033,7 +2063,9 @@ def d_phase_d_param_num(self, toas, param, step=1e-2): par.value = ori_value return result - def d_delay_d_param_num(self, toas, param, step=1e-2): + def d_delay_d_param_num( + self, toas: TOAs, param: str, step: float = 1e-2 + ) -> u.Quantity: """Return the derivative of delay with respect to the parameter. Compute the value numerically, using a symmetric finite difference. @@ -2060,10 +2092,10 @@ def d_delay_d_param_num(self, toas, param, step=1e-2): par.value = ori_value return d_delay * (u.second / unit) - def d_dm_d_param(self, data, param): + def d_dm_d_param(self, toas: TOAs, param: str) -> u.Quantity: """Return the derivative of DM with respect to the parameter.""" par = getattr(self, param) - result = np.zeros(len(data)) << (u.pc / u.cm**3 / par.units) + result = np.zeros(len(toas)) << (u.pc / u.cm**3 / par.units) dm_df = self.dm_derivs.get(param, None) if dm_df is None: if param not in self.params: # Maybe add differentiable params @@ -2072,15 +2104,15 @@ def d_dm_d_param(self, data, param): return result for df in dm_df: - result += df(data, param).to( + result += df(toas, param).to( result.unit, equivalencies=u.dimensionless_angles() ) return result - def d_toasigma_d_param(self, data, param): + def d_toasigma_d_param(self, toas: TOAs, param: str) -> u.Quantity: """Return the derivative of the scaled TOA uncertainty with respect to the parameter.""" par = getattr(self, param) - result = np.zeros(len(data)) << (u.s / par.units) + result = np.zeros(len(toas)) << (u.s / par.units) sigma_df = self.toasigma_derivs.get(param, None) if sigma_df is None: if param not in self.params: # Maybe add differentiable params @@ -2089,12 +2121,18 @@ def d_toasigma_d_param(self, data, param): return result for df in sigma_df: - result += df(data, param).to( + result += df(toas, param).to( result.unit, equivalencies=u.dimensionless_angles() ) return result - def designmatrix(self, toas, acc_delay=None, incfrozen=False, incoffset=True): + def designmatrix( + self, + toas: TOAs, + acc_delay: u.Quantity = None, + incfrozen: bool = False, + incoffset: bool = True, + ) -> Tuple[np.ndarray, List[str], List[u.Unit]]: """Return the design matrix. The design matrix is the matrix with columns of ``d_phase_d_param/F0`` @@ -2200,15 +2238,15 @@ def designmatrix(self, toas, acc_delay=None, incfrozen=False, incoffset=True): def compare( self, - othermodel, - nodmx=True, - convertcoordinates=True, - threshold_sigma=3.0, - unc_rat_threshold=1.05, - verbosity="max", - usecolor=True, - format="text", - ): + othermodel: "TimingModel", + nodmx: bool = True, + convertcoordinates: bool = True, + threshold_sigma: float = 3.0, + unc_rat_threshold: float = 1.05, + verbosity: Literal["max", "med", "min", "check"] = "max", + usecolor: True = True, + format: Literal["text", "markdown"] = "text", + ) -> str: """Print comparison with another model Parameters @@ -2743,7 +2781,11 @@ def compare( if verbosity != "check": return "\n".join(s) - def use_aliases(self, reset_to_default=True, alias_translation=None): + def use_aliases( + self, + reset_to_default: bool = True, + alias_translation: Optional[Dict[str, str]] = None, + ) -> None: """Set the parameters to use aliases as specified upon writing. Parameters @@ -2770,13 +2812,13 @@ def use_aliases(self, reset_to_default=True, alias_translation=None): def as_parfile( self, - start_order=["astrometry", "spindown", "dispersion"], - last_order=["jump_delay"], + start_order: List[str] = ["astrometry", "spindown", "dispersion"], + last_order: List[str] = ["jump_delay"], *, - include_info=True, - comment=None, - format="pint", - ): + include_info: bool = True, + comment: Optional[str] = None, + format: Literal["tempo", "tempo2", "pint"] = "pint", + ) -> str: """Represent the entire model as a parfile string. See also :func:`pint.models.TimingModel.write_parfile`. @@ -2847,14 +2889,14 @@ def as_parfile( def write_parfile( self, - filename, - start_order=["astrometry", "spindown", "dispersion"], - last_order=["jump_delay"], + filename: file_like, + start_order: List[str] = ["astrometry", "spindown", "dispersion"], + last_order: List[str] = ["jump_delay"], *, - include_info=True, - comment=None, - format="pint", - ): + include_info: bool = True, + comment: Optional[str] = None, + format: Literal["tempo", "tempo2", "pint"] = "pint", + ) -> None: """Write the entire model as a parfile. See also :func:`pint.models.TimingModel.as_parfile`. @@ -2886,7 +2928,7 @@ def write_parfile( ) ) - def validate_toas(self, toas): + def validate_toas(self, toas: TOAs) -> None: """Sanity check to verify that this model is compatible with these toas. This checks that where this model needs TOAs to constrain parameters, @@ -2917,7 +2959,7 @@ def validate_toas(self, toas): if bad_parameters: raise MissingTOAs(bad_parameters) - def find_empty_masks(self, toas, freeze=False): + def find_empty_masks(self, toas: TOAs, freeze: bool = False) -> List[str]: """Find unfrozen mask parameters with no TOAs before trying to fit Parameters @@ -2951,7 +2993,7 @@ def find_empty_masks(self, toas, freeze=False): bad_parameters.append(k) return bad_parameters - def setup(self): + def setup(self) -> None: """Run setup methods on all components.""" for cp in self.components.values(): cp.setup() @@ -2971,19 +3013,19 @@ def __setitem__(self, name, value): # FIXME: This could be the right way to add Parameters? raise NotImplementedError - def keys(self): + def keys(self) -> List[str]: return self.params - def items(self): + def items(self) -> List[str, Parameter]: return [(p, self[p]) for p in self.params] - def __len__(self): + def __len__(self) -> int: return len(self.params) def __iter__(self): yield from self.params - def as_ECL(self, epoch=None, ecl="IERS2010"): + def as_ECL(self, epoch: time_like = None, ecl: str = "IERS2010") -> "TimingModel": """Return TimingModel in PulsarEcliptic frame. Parameters @@ -3033,7 +3075,7 @@ def as_ECL(self, epoch=None, ecl="IERS2010"): return new_model - def as_ICRS(self, epoch=None): + def as_ICRS(self, epoch: time_like = None) -> "TimingModel": """Return TimingModel in ICRS frame. Parameters @@ -3077,7 +3119,12 @@ def as_ICRS(self, epoch=None): return new_model - def get_derived_params(self, rms=None, ntoas=None, returndict=False): + def get_derived_params( + self, + rms: Optional[u.Quantity] = None, + ntoas: Optional[int] = None, + returndict: bool = False, + ) -> Union[str, Tuple[str, Dict[str, u.Quantity]]]: """Return a string with various derived parameters from the fitted model Parameters @@ -3293,7 +3340,6 @@ class ModelMeta(abc.ABCMeta): a class attribute ``component_types``, provided that the class has an attribute ``register``. This makes sure all timing model components are listed in ``Component.component_types``. - """ def __init__(cls, name, bases, dct): @@ -3328,7 +3374,7 @@ def __init__(self): self.deriv_funcs = {} self.component_special_params = [] - def __repr__(self): + def __repr__(self) -> str: return "{}(\n {})".format( self.__class__.__name__, ",\n ".join( @@ -3338,25 +3384,25 @@ def __repr__(self): ), ) - def setup(self): + def setup(self) -> None: """Finalize construction loaded values.""" pass - def validate(self): + def validate(self) -> None: """Validate loaded values.""" pass - def validate_toas(self, toas): + def validate_toas(self, toas) -> None: """Check that this model component has TOAs where needed.""" pass @property_exists - def category(self): + def category(self) -> str: """Category is a feature the class, so delegate.""" return self.__class__.category @property_exists - def free_params_component(self): + def free_params_component(self) -> List[str]: """Return the free parameters in the component. This function collects the non-frozen parameters. @@ -3373,7 +3419,7 @@ def free_params_component(self): return free_param @property_exists - def param_prefixs(self): + def param_prefixs(self) -> Dict[str, List[str]]: prefixs = {} for p in self.params: par = getattr(self, p) @@ -3385,7 +3431,7 @@ def param_prefixs(self): return prefixs @property_exists - def aliases_map(self): + def aliases_map(self) -> Dict[str, str]: """Return all the aliases and map to the PINT parameter name. This property returns a dictionary from the current in timing model @@ -3402,7 +3448,12 @@ def aliases_map(self): ali_map[ali] = p return ali_map - def add_param(self, param, deriv_func=None, setup=False): + def add_param( + self, + param: Parameter, + deriv_func: Optional[Callable] = None, + setup: bool = False, + ): """Add a parameter to the Component. The parameter is stored in an attribute on the Component object. @@ -3467,7 +3518,7 @@ def add_param(self, param, deriv_func=None, setup=False): param._parent = self return param.name - def remove_param(self, param): + def remove_param(self, param: str | Parameter) -> None: """Remove a parameter from the Component. Parameters @@ -3489,7 +3540,7 @@ def remove_param(self, param): self.component_special_params.remove(pn) delattr(self, param) - def set_special_params(self, spcl_params): + def set_special_params(self, spcl_params: List[str]) -> None: als = [] for p in spcl_params: als += getattr(self, p).aliases @@ -3498,15 +3549,15 @@ def set_special_params(self, spcl_params): if sp not in self.component_special_params: self.component_special_params.append(sp) - def param_help(self): + def param_help(self) -> str: """Print help lines for all available parameters in model.""" s = "Available parameters for %s\n" % self.__class__ for par in self.params: s += "%s\n" % getattr(self, par).help_line() return s - def get_params_of_type(self, param_type): - """Get all the parameters in timing model for one specific type.""" + def get_params_of_type(self, param_type: str) -> List[str]: + """Get all the parameters in timing model for one specific `Parameter` subtype.""" result = [] for p in self.params: par = getattr(self, p) @@ -3516,7 +3567,7 @@ def get_params_of_type(self, param_type): result.append(par.name) return result - def get_prefix_mapping_component(self, prefix): + def get_prefix_mapping_component(self, prefix: str) -> Dict[int, str]: """Get the index mapping for the prefix parameters. Parameters @@ -3539,7 +3590,7 @@ def get_prefix_mapping_component(self, prefix): mapping[par.index] = parname return OrderedDict(sorted(mapping.items())) - def match_param_aliases(self, alias): + def match_param_aliases(self, alias: str) -> str: """Return the parameter corresponding to this alias. Parameters @@ -3587,7 +3638,7 @@ def match_param_aliases(self, alias): else: raise UnknownParameter(f"Unknown parameter name or alias {alias}") - def register_deriv_funcs(self, func, param): + def register_deriv_funcs(self, func: Callable, param: str) -> None: """Register the derivative function in to the deriv_func dictionaries. Parameters @@ -3607,7 +3658,7 @@ def register_deriv_funcs(self, func, param): else: self.deriv_funcs[pn] += [func] - def is_in_parfile(self, para_dict): + def is_in_parfile(self, para_dict: Dict) -> bool: """Check if this subclass included in parfile. Parameters @@ -3659,7 +3710,7 @@ def is_in_parfile(self, para_dict): return True - def print_par(self, format="pint"): + def print_par(self, format: Literal["tempo", "tempo2", "pint"] = "pint") -> str: """ Parameters ---------- @@ -3677,12 +3728,18 @@ def print_par(self, format="pint"): class DelayComponent(Component): + """Abstract base class of all delay components. These components implement a `delay` + method.""" + def __init__(self): super().__init__() self.delay_funcs_component = [] class PhaseComponent(Component): + """Abstract base class of all phase components. These components implement a `phase` + method.""" + def __init__(self): super().__init__() self.phase_funcs_component = [] From 25d8b9a7a415b31b5ba99a437e84bea52da00015 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 12:19:27 +0100 Subject: [PATCH 4/6] -- --- src/pint/models/timing_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 2042137d0..c455d3449 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -3016,7 +3016,7 @@ def __setitem__(self, name, value): def keys(self) -> List[str]: return self.params - def items(self) -> List[str, Parameter]: + def items(self) -> List[Tuple[str, Parameter]]: return [(p, self[p]) for p in self.params] def __len__(self) -> int: From 06ab04425c353b65ba3d5a8998d7b232452819dd Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 12:36:26 +0100 Subject: [PATCH 5/6] union --- src/pint/models/timing_model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index c455d3449..39e794792 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -552,7 +552,7 @@ def __getattr__(self, name: str): ) def __setattr__( - self, name: str, value: Parameter | prefixParameter | u.Quantity | float + self, name: str, value: Union[Parameter, prefixParameter, u.Quantity, float] ): """Mostly this just sets ``self.name = value``. But there are a few special cases: @@ -731,7 +731,7 @@ def get_params_dict( self, which: Literal["free", "all"] = "free", kind: Literal["quantity", "value", "uncertainty"] = "quantity", - ) -> OrderedDict[str, float] | OrderedDict[str, u.Quantity]: + ) -> Union[OrderedDict[str, float], OrderedDict[str, u.Quantity]]: """Return a dict mapping parameter names to values. This can return only the free parameters or all; and it can return the @@ -846,7 +846,7 @@ def is_binary(self) -> bool: def orbital_phase( self, - barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + barytimes: Union[time.Time, TOAs, np.ndarray, float, MJDParameter], anom: Literal["mean", "eccentric", "true"] = "mean", radians: bool = True, ) -> np.ndarray: @@ -919,7 +919,7 @@ def orbital_phase( return anoms * u.rad if radians else anoms / (2 * np.pi) def pulsar_radial_velocity( - self, barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter + self, barytimes: Union[time.Time, TOAs, np.ndarray, float, MJDParameter] ) -> np.ndarray: """Return line-of-sight velocity of the pulsar relative to the system barycenter at barycentric MJD times. @@ -968,7 +968,7 @@ def pulsar_radial_velocity( def companion_radial_velocity( self, - barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + barytimes: Union[time.Time, TOAs, np.ndarray, float, MJDParameter], massratio: float, ) -> np.ndarray: """Return line-of-sight velocity of the companion relative to the system barycenter at barycentric MJD times. @@ -1006,7 +1006,7 @@ def companion_radial_velocity( """ return -self.pulsar_radial_velocity(barytimes) * massratio - def conjunction(self, baryMJD: float | time.Time) -> float | np.ndarray: + def conjunction(self, baryMJD: Union[float, time.Time]) -> Union[float, np.ndarray]: """Return the time(s) of the first superior conjunction(s) after baryMJD. Args @@ -3518,7 +3518,7 @@ def add_param( param._parent = self return param.name - def remove_param(self, param: str | Parameter) -> None: + def remove_param(self, param: Union[str, Parameter]) -> None: """Remove a parameter from the Component. Parameters From 527e366eb8ef923a975d22cd0ea28ddf99a1be4f Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 12:45:01 +0100 Subject: [PATCH 6/6] AllComponents --- src/pint/models/timing_model.py | 32 +++++++++++++++++--------------- 1 file changed, 17 insertions(+), 15 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index 39e794792..fe695f3aa 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -34,7 +34,7 @@ import contextlib from collections import OrderedDict, defaultdict from functools import wraps -from typing import Callable, Dict, List, Literal, Optional, Tuple, Union +from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union from warnings import warn from uncertainties import ufloat @@ -3773,7 +3773,7 @@ def __init__(self): self.components[k] = v() @lazyproperty - def param_component_map(self): + def param_component_map(self) -> Dict[str, List[str]]: """Return the parameter to component map. This property returns the all PINT defined parameters to their host @@ -3798,7 +3798,9 @@ def param_component_map(self): p2c_map[ap].append("timing_model") return p2c_map - def _check_alias_conflict(self, alias, param_name, alias_map): + def _check_alias_conflict( + self, alias: str, param_name: str, alias_map: dict + ) -> None: """Check if a aliase has conflict in the alias map. This function checks if an alias already have record in the alias_map. @@ -3831,7 +3833,7 @@ def _check_alias_conflict(self, alias, param_name, alias_map): return @lazyproperty - def _param_alias_map(self): + def _param_alias_map(self) -> Dict[str, str]: """Return the aliases map of all parameters The returned map includes: 1. alias to PINT parameter name. 2. PINT @@ -3862,7 +3864,7 @@ def _param_alias_map(self): return alias @lazyproperty - def _param_unit_map(self): + def _param_unit_map(self) -> Dict[str, u.Unit]: """A dictionary to map parameter names to their units This excludes prefix parameters and aliases. Use :func:`param_to_unit` to handle those. @@ -3881,7 +3883,7 @@ def _param_unit_map(self): return units @lazyproperty - def repeatable_param(self): + def repeatable_param(self) -> Set[str]: """Return the repeatable parameter map.""" repeatable = [] for k, cp in self.components.items(): @@ -3894,7 +3896,7 @@ def repeatable_param(self): return set(repeatable) @lazyproperty - def category_component_map(self): + def category_component_map(self) -> Dict[str, str]: """A dictionary mapping category to a list of component names. Return @@ -3911,7 +3913,7 @@ def category_component_map(self): return category @lazyproperty - def component_category_map(self): + def component_category_map(self) -> Dict[str, str]: """A dictionary mapping component name to its category name. Return @@ -3923,7 +3925,7 @@ def component_category_map(self): return {k: cp.category for k, cp in self.components.items()} @lazyproperty - def component_unique_params(self): + def component_unique_params(self) -> Dict[str, List[str]]: """Return the parameters that are only present in one component. Return @@ -3943,7 +3945,7 @@ def component_unique_params(self): component_special_params[cps[0]].append(param) return component_special_params - def search_binary_components(self, system_name): + def search_binary_components(self, system_name: str) -> "Component": """Search the pulsar binary component based on given name. Parameters @@ -3991,7 +3993,7 @@ def search_binary_components(self, system_name): f"Pulsar system/Binary model component" f" {system_name} is not provided." ) - def alias_to_pint_param(self, alias): + def alias_to_pint_param(self, alias: str) -> Tuple[str, str]: """Translate a alias to a PINT parameter name. This is a wrapper function over the property ``_param_alias_map``. It @@ -4008,10 +4010,10 @@ def alias_to_pint_param(self, alias): ------- pint_par : str PINT parameter name the given alias maps to. If there is no matching - PINT parameters, it will raise a `UnknownParameter` error. + PINT parameters, it will raise an `UnknownParameter` error. first_init_par : str The parameter name that is first initialized in a component. If the - paramere is non-indexable, it is the same as ``pint_par``, otherwrise + paramere is non-indexable, it is the same as ``pint_par``, otherwise it returns the parameter with the first index. For example, the ``first_init_par`` for 'T2EQUAD25' is 'EQUAD1' @@ -4078,7 +4080,7 @@ def alias_to_pint_param(self, alias): return pint_par, first_init_par - def param_to_unit(self, name): + def param_to_unit(self, name: str) -> u.Unit: """Return the unit associated with a parameter This is a wrapper function over the property ``_param_unit_map``. It @@ -4095,7 +4097,7 @@ def param_to_unit(self, name): Returns ------- - astropy.u.Unit + astropy.units.Unit """ pintname, firstname = self.alias_to_pint_param(name) if pintname == firstname: