Skip to content

Commit

Permalink
type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
abhisrkckl committed Jan 3, 2025
1 parent c96c950 commit e8f40f2
Showing 1 changed file with 42 additions and 24 deletions.
66 changes: 42 additions & 24 deletions src/pint/models/timing_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import contextlib
from collections import OrderedDict, defaultdict
from functools import wraps
from typing import Callable, Dict, List, Literal
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
from warnings import warn
from uncertainties import ufloat

Expand Down Expand Up @@ -842,12 +842,17 @@ def is_binary(self) -> bool:

return any(isinstance(x, PulsarBinary) for x in self.components.values())

def orbital_phase(self, barytimes, anom="mean", radians=True):
def orbital_phase(
self,
barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter,
anom: Literal["mean", "eccentric", "true"] = "mean",
radians: bool = True,
) -> np.ndarray:
"""Return orbital phase (in radians) at barycentric MJD times.
Parameters
----------
barytimes: Time, TOAs, array-like, or float
barytimes: Time, TOAs, array-like, MJDParameter, or float
MJD barycentric time(s). The times to compute the
orbital phases. Needs to be a barycentric time in TDB.
If a TOAs instance is passed, the barycentering will happen
Expand Down Expand Up @@ -911,12 +916,14 @@ def orbital_phase(self, barytimes, anom="mean", radians=True):
# return with radian units or return as unitless cycles from 0-1
return anoms * u.rad if radians else anoms / (2 * np.pi)

def pulsar_radial_velocity(self, barytimes):
def pulsar_radial_velocity(
self, barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter
) -> np.ndarray:
"""Return line-of-sight velocity of the pulsar relative to the system barycenter at barycentric MJD times.
Parameters
----------
barytimes: Time, TOAs, array-like, or float
barytimes: Time, TOAs, array-like, MJDParameter, or float
MJD barycentric time(s). The times to compute the
orbital phases. Needs to be a barycentric time in TDB.
If a TOAs instance is passed, the barycentering will happen
Expand Down Expand Up @@ -957,7 +964,11 @@ def pulsar_radial_velocity(self, barytimes):
* (np.cos(psi) + bbi.ecc() * np.cos(bbi.omega()))
).cgs

def companion_radial_velocity(self, barytimes, massratio):
def companion_radial_velocity(
self,
barytimes: time.Time | TOAs | np.ndarray | float | MJDParameter,
massratio: float,
) -> np.ndarray:
"""Return line-of-sight velocity of the companion relative to the system barycenter at barycentric MJD times.
Parameters
Expand Down Expand Up @@ -993,7 +1004,7 @@ def companion_radial_velocity(self, barytimes, massratio):
"""
return -self.pulsar_radial_velocity(barytimes) * massratio

def conjunction(self, baryMJD):
def conjunction(self, baryMJD: float | time.Time) -> float | np.ndarray:
"""Return the time(s) of the first superior conjunction(s) after baryMJD.
Args
Expand Down Expand Up @@ -1054,7 +1065,7 @@ def funct(t):
return scs[0] if len(scs) == 1 else np.asarray(scs)

@property_exists
def dm_funcs(self):
def dm_funcs(self) -> List[Callable]:
"""List of all dm value functions."""
dmfs = []
for cp in self.components.values():
Expand All @@ -1065,7 +1076,7 @@ def dm_funcs(self):
return dmfs

@property_exists
def has_correlated_errors(self):
def has_correlated_errors(self) -> bool:
"""Whether or not this model has correlated errors."""

return (
Expand All @@ -1081,7 +1092,7 @@ def has_correlated_errors(self):
)

@property_exists
def has_time_correlated_errors(self):
def has_time_correlated_errors(self) -> bool:
"""Whether or not this model has time-correlated errors."""

return (
Expand All @@ -1097,7 +1108,7 @@ def has_time_correlated_errors(self):
)

@property_exists
def covariance_matrix_funcs(self):
def covariance_matrix_funcs(self) -> List[Callable]:
"""List of covariance matrix functions."""
cvfs = []
if "NoiseComponent" in self.component_types:
Expand All @@ -1106,7 +1117,7 @@ def covariance_matrix_funcs(self):
return cvfs

@property_exists
def dm_covariance_matrix_funcs(self):
def dm_covariance_matrix_funcs(self) -> List[Callable]:
"""List of covariance matrix functions."""
cvfs = []
if "NoiseComponent" in self.component_types:
Expand All @@ -1116,7 +1127,7 @@ def dm_covariance_matrix_funcs(self):

# Change sigma to uncertainty to avoid name conflict.
@property_exists
def scaled_toa_uncertainty_funcs(self):
def scaled_toa_uncertainty_funcs(self) -> List[Callable]:
"""List of scaled toa uncertainty functions."""
ssfs = []
if "NoiseComponent" in self.component_types:
Expand All @@ -1126,7 +1137,7 @@ def scaled_toa_uncertainty_funcs(self):

# Change sigma to uncertainty to avoid name conflict.
@property_exists
def scaled_dm_uncertainty_funcs(self):
def scaled_dm_uncertainty_funcs(self) -> List[Callable]:
"""List of scaled dm uncertainty functions."""
ssfs = []
if "NoiseComponent" in self.component_types:
Expand All @@ -1136,7 +1147,7 @@ def scaled_dm_uncertainty_funcs(self):
return ssfs

@property_exists
def basis_funcs(self):
def basis_funcs(self) -> List[Callable]:
"""List of scaled uncertainty functions."""
bfs = []
if "NoiseComponent" in self.component_types:
Expand All @@ -1145,34 +1156,39 @@ def basis_funcs(self):
return bfs

@property_exists
def phase_deriv_funcs(self):
def phase_deriv_funcs(self) -> List[Callable]:
"""List of derivative functions for phase components."""
return self.get_deriv_funcs("PhaseComponent")

@property_exists
def delay_deriv_funcs(self):
def delay_deriv_funcs(self) -> List[Callable]:
"""List of derivative functions for delay components."""
return self.get_deriv_funcs("DelayComponent")

@property_exists
def dm_derivs(self): # TODO need to be careful about the name here.
def dm_derivs(self) -> List[Callable]:
# TODO need to be careful about the name here.
"""List of DM derivative functions."""
return self.get_deriv_funcs("DelayComponent", "dm")

@property_exists
def toasigma_derivs(self):
def toasigma_derivs(self) -> List[Callable]:
"""List of scaled TOA uncertainty derivative functions"""
return self.get_deriv_funcs("NoiseComponent", "toasigma")

@property_exists
def d_phase_d_delay_funcs(self):
def d_phase_d_delay_funcs(self) -> List[Callable]:
"""List of d_phase_d_delay functions."""
Dphase_Ddelay = []
for cp in self.PhaseComponent_list:
Dphase_Ddelay += cp.phase_derivs_wrt_delay
return Dphase_Ddelay

def get_deriv_funcs(self, component_type, derivative_type=""):
def get_deriv_funcs(
self,
component_type: Literal["PhaseComponent", "DelayComponent", "NoiseComponent"],
derivative_type: Literal["", "dm", "toasigma"] = "",
) -> Dict[str, Callable]:
"""Return a dictionary of derivative functions.
Parameters
Expand All @@ -1197,7 +1213,7 @@ def get_deriv_funcs(self, component_type, derivative_type=""):
deriv_funcs[k] += v
return dict(deriv_funcs)

def search_cmp_attr(self, name):
def search_cmp_attr(self, name: str) -> Optional["Component"]:
"""Search for an attribute in all components.
Return the component, or None.
Expand All @@ -1210,7 +1226,7 @@ def search_cmp_attr(self, name):
return cp
raise AttributeError(f"{name} not found in any component")

def get_component_type(self, component):
def get_component_type(self, component: "Component") -> str:
"""Identify the component object's type.
Parameters
Expand Down Expand Up @@ -1243,7 +1259,9 @@ def get_component_type(self, component):
comp_type = comp_base[-3].__name__
return comp_type

def map_component(self, component):
def map_component(
self, component: Union[str, "Component"]
) -> Tuple["Component", int, List["Component"], str]:
"""Get the location of component.
Parameters
Expand Down

0 comments on commit e8f40f2

Please sign in to comment.