From e8f40f23bab915b39db6d1e01acfd42327acf297 Mon Sep 17 00:00:00 2001 From: Abhimanyu Susobhanan Date: Fri, 3 Jan 2025 10:40:11 +0100 Subject: [PATCH] type hints --- src/pint/models/timing_model.py | 66 +++++++++++++++++++++------------ 1 file changed, 42 insertions(+), 24 deletions(-) diff --git a/src/pint/models/timing_model.py b/src/pint/models/timing_model.py index a650ea666..f5f798ce1 100644 --- a/src/pint/models/timing_model.py +++ b/src/pint/models/timing_model.py @@ -34,7 +34,7 @@ import contextlib from collections import OrderedDict, defaultdict from functools import wraps -from typing import Callable, Dict, List, Literal +from typing import Callable, Dict, List, Literal, Optional, Tuple, Union from warnings import warn from uncertainties import ufloat @@ -842,12 +842,17 @@ def is_binary(self) -> bool: return any(isinstance(x, PulsarBinary) for x in self.components.values()) - def orbital_phase(self, barytimes, anom="mean", radians=True): + def orbital_phase( + self, + barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + anom: Literal["mean", "eccentric", "true"] = "mean", + radians: bool = True, + ) -> np.ndarray: """Return orbital phase (in radians) at barycentric MJD times. Parameters ---------- - barytimes: Time, TOAs, array-like, or float + barytimes: Time, TOAs, array-like, MJDParameter, or float MJD barycentric time(s). The times to compute the orbital phases. Needs to be a barycentric time in TDB. If a TOAs instance is passed, the barycentering will happen @@ -911,12 +916,14 @@ def orbital_phase(self, barytimes, anom="mean", radians=True): # return with radian units or return as unitless cycles from 0-1 return anoms * u.rad if radians else anoms / (2 * np.pi) - def pulsar_radial_velocity(self, barytimes): + def pulsar_radial_velocity( + self, barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter + ) -> np.ndarray: """Return line-of-sight velocity of the pulsar relative to the system barycenter at barycentric MJD times. Parameters ---------- - barytimes: Time, TOAs, array-like, or float + barytimes: Time, TOAs, array-like, MJDParameter, or float MJD barycentric time(s). The times to compute the orbital phases. Needs to be a barycentric time in TDB. If a TOAs instance is passed, the barycentering will happen @@ -957,7 +964,11 @@ def pulsar_radial_velocity(self, barytimes): * (np.cos(psi) + bbi.ecc() * np.cos(bbi.omega())) ).cgs - def companion_radial_velocity(self, barytimes, massratio): + def companion_radial_velocity( + self, + barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter, + massratio: float, + ) -> np.ndarray: """Return line-of-sight velocity of the companion relative to the system barycenter at barycentric MJD times. Parameters @@ -993,7 +1004,7 @@ def companion_radial_velocity(self, barytimes, massratio): """ return -self.pulsar_radial_velocity(barytimes) * massratio - def conjunction(self, baryMJD): + def conjunction(self, baryMJD: float | time.Time) -> float | np.ndarray: """Return the time(s) of the first superior conjunction(s) after baryMJD. Args @@ -1054,7 +1065,7 @@ def funct(t): return scs[0] if len(scs) == 1 else np.asarray(scs) @property_exists - def dm_funcs(self): + def dm_funcs(self) -> List[Callable]: """List of all dm value functions.""" dmfs = [] for cp in self.components.values(): @@ -1065,7 +1076,7 @@ def dm_funcs(self): return dmfs @property_exists - def has_correlated_errors(self): + def has_correlated_errors(self) -> bool: """Whether or not this model has correlated errors.""" return ( @@ -1081,7 +1092,7 @@ def has_correlated_errors(self): ) @property_exists - def has_time_correlated_errors(self): + def has_time_correlated_errors(self) -> bool: """Whether or not this model has time-correlated errors.""" return ( @@ -1097,7 +1108,7 @@ def has_time_correlated_errors(self): ) @property_exists - def covariance_matrix_funcs(self): + def covariance_matrix_funcs(self) -> List[Callable]: """List of covariance matrix functions.""" cvfs = [] if "NoiseComponent" in self.component_types: @@ -1106,7 +1117,7 @@ def covariance_matrix_funcs(self): return cvfs @property_exists - def dm_covariance_matrix_funcs(self): + def dm_covariance_matrix_funcs(self) -> List[Callable]: """List of covariance matrix functions.""" cvfs = [] if "NoiseComponent" in self.component_types: @@ -1116,7 +1127,7 @@ def dm_covariance_matrix_funcs(self): # Change sigma to uncertainty to avoid name conflict. @property_exists - def scaled_toa_uncertainty_funcs(self): + def scaled_toa_uncertainty_funcs(self) -> List[Callable]: """List of scaled toa uncertainty functions.""" ssfs = [] if "NoiseComponent" in self.component_types: @@ -1126,7 +1137,7 @@ def scaled_toa_uncertainty_funcs(self): # Change sigma to uncertainty to avoid name conflict. @property_exists - def scaled_dm_uncertainty_funcs(self): + def scaled_dm_uncertainty_funcs(self) -> List[Callable]: """List of scaled dm uncertainty functions.""" ssfs = [] if "NoiseComponent" in self.component_types: @@ -1136,7 +1147,7 @@ def scaled_dm_uncertainty_funcs(self): return ssfs @property_exists - def basis_funcs(self): + def basis_funcs(self) -> List[Callable]: """List of scaled uncertainty functions.""" bfs = [] if "NoiseComponent" in self.component_types: @@ -1145,34 +1156,39 @@ def basis_funcs(self): return bfs @property_exists - def phase_deriv_funcs(self): + def phase_deriv_funcs(self) -> List[Callable]: """List of derivative functions for phase components.""" return self.get_deriv_funcs("PhaseComponent") @property_exists - def delay_deriv_funcs(self): + def delay_deriv_funcs(self) -> List[Callable]: """List of derivative functions for delay components.""" return self.get_deriv_funcs("DelayComponent") @property_exists - def dm_derivs(self): # TODO need to be careful about the name here. + def dm_derivs(self) -> List[Callable]: + # TODO need to be careful about the name here. """List of DM derivative functions.""" return self.get_deriv_funcs("DelayComponent", "dm") @property_exists - def toasigma_derivs(self): + def toasigma_derivs(self) -> List[Callable]: """List of scaled TOA uncertainty derivative functions""" return self.get_deriv_funcs("NoiseComponent", "toasigma") @property_exists - def d_phase_d_delay_funcs(self): + def d_phase_d_delay_funcs(self) -> List[Callable]: """List of d_phase_d_delay functions.""" Dphase_Ddelay = [] for cp in self.PhaseComponent_list: Dphase_Ddelay += cp.phase_derivs_wrt_delay return Dphase_Ddelay - def get_deriv_funcs(self, component_type, derivative_type=""): + def get_deriv_funcs( + self, + component_type: Literal["PhaseComponent", "DelayComponent", "NoiseComponent"], + derivative_type: Literal["", "dm", "toasigma"] = "", + ) -> Dict[str, Callable]: """Return a dictionary of derivative functions. Parameters @@ -1197,7 +1213,7 @@ def get_deriv_funcs(self, component_type, derivative_type=""): deriv_funcs[k] += v return dict(deriv_funcs) - def search_cmp_attr(self, name): + def search_cmp_attr(self, name: str) -> Optional["Component"]: """Search for an attribute in all components. Return the component, or None. @@ -1210,7 +1226,7 @@ def search_cmp_attr(self, name): return cp raise AttributeError(f"{name} not found in any component") - def get_component_type(self, component): + def get_component_type(self, component: "Component") -> str: """Identify the component object's type. Parameters @@ -1243,7 +1259,9 @@ def get_component_type(self, component): comp_type = comp_base[-3].__name__ return comp_type - def map_component(self, component): + def map_component( + self, component: Union[str, "Component"] + ) -> Tuple["Component", int, List["Component"], str]: """Get the location of component. Parameters