Skip to content

Commit

Permalink
Run lint/black
Browse files Browse the repository at this point in the history
  • Loading branch information
tcromartie committed Apr 20, 2024
1 parent 097e621 commit c117fcf
Showing 1 changed file with 24 additions and 77 deletions.
101 changes: 24 additions & 77 deletions enterprise/pulsar.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,7 @@
try:
import libstempo as t2
except ImportError:
logger.warning(
"libstempo not installed. Will use PINT instead."
) # pragma: no cover
logger.warning("libstempo not installed. Will use PINT instead.") # pragma: no cover
t2 = None

try:
Expand All @@ -34,9 +32,7 @@
from pint.residuals import Residuals as resids
from pint.toa import TOAs
except ImportError:
logger.warning(
"PINT not installed. Will use libstempo instead."
) # pragma: no cover
logger.warning("PINT not installed. Will use libstempo instead.") # pragma: no cover
pint = None

if pint is None and t2 is None:
Expand Down Expand Up @@ -136,9 +132,7 @@ def filter_data(self, start_time=None, end_time=None):
if start_time is None and end_time is None:
mask = np.ones(self._toas.shape, dtype=bool)
else:
mask = np.logical_and(
self._toas >= start_time * 86400, self._toas <= end_time * 86400
)
mask = np.logical_and(self._toas >= start_time * 86400, self._toas <= end_time * 86400)

self._toas = self._toas[mask]
self._toaerrs = self._toaerrs[mask]
Expand Down Expand Up @@ -244,11 +238,7 @@ def dmx(self):
def flags(self):
"""Return a dictionary of tim-file flags."""

flagnames = (
self._flags.dtype.names
if isinstance(self._flags, np.ndarray)
else self._flags.keys()
)
flagnames = self._flags.dtype.names if isinstance(self._flags, np.ndarray) else self._flags.keys()

return {flag: self._flags[flag][self._isort] for flag in flagnames}

Expand All @@ -271,25 +261,16 @@ def backend_flags(self):
"""

# collect flag names
flagnames = (
self._flags.dtype.names
if isinstance(self._flags, np.ndarray)
else list(self._flags.keys())
)
flagnames = self._flags.dtype.names if isinstance(self._flags, np.ndarray) else list(self._flags.keys())

# allocate array with widest dtype
ret = np.zeros(
len(self._toas), dtype=max([self._flags[name].dtype for name in flagnames])
)
ret = np.zeros(len(self._toas), dtype=max([self._flags[name].dtype for name in flagnames]))

# go through the flags in reverse order of preference
# setting or replacing values for each TOA

if "fe" in flagnames and "be" in flagnames:
ret[:] = [
(a + "_" + b if (a and b) else "")
for a, b in zip(self._flags["fe"], self._flags["be"])
]
ret[:] = [(a + "_" + b if (a and b) else "") for a, b in zip(self._flags["fe"], self._flags["be"])]

for flag in ["f", "i", "sys", "g", "group"]:
if flag in flagnames:
Expand Down Expand Up @@ -345,23 +326,17 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):

# these are TDB but not barycentered
# self._toas = np.array(toas.table["tdbld"], dtype="float64") * 86400
self._toas = (
np.array(model.get_barycentric_toas(toas).value, dtype="float64") * 86400
)
self._toas = np.array(model.get_barycentric_toas(toas).value, dtype="float64") * 86400
# saving also stoas (e.g., for DMX comparisons)
self._stoas = np.array(toas.get_mjds().value, dtype="float64") * 86400
self._residuals = np.array(
resids(toas, model).time_resids.to(u.s), dtype="float64"
)
self._residuals = np.array(resids(toas, model).time_resids.to(u.s), dtype="float64")
self._toaerrs = np.array(toas.get_errors().to(u.s), dtype="float64")
self._designmatrix = model.designmatrix(toas)[0]
self._ssbfreqs = np.array(model.barycentric_radio_freq(toas), dtype="float64")
self._telescope = np.array(toas.get_obss())

# fitted parameters
self.fitpars = ["Offset"] + [
par for par in model.params if not getattr(model, par).frozen
]
self.fitpars = ["Offset"] + [par for par in model.params if not getattr(model, par).frozen]

# gather DM/DMX information if available
self._set_dm(model)
Expand Down Expand Up @@ -393,9 +368,7 @@ def __init__(self, toas, model, sort=True, drop_pintpsr=True, planets=True):
self._sunssb = self._get_sunssb(toas, model)

which_astrometry = (
"AstrometryEquatorial"
if "AstrometryEquatorial" in model.components
else "AstrometryEcliptic"
"AstrometryEquatorial" if "AstrometryEquatorial" in model.components else "AstrometryEcliptic"
)

self._pos_t = (
Expand All @@ -415,9 +388,7 @@ def _set_dm(self, model):
dmx = {
par: {
"DMX": float(model[par].value),
"DMXerr": None
if model[par].uncertainty_value is None
else float(model[par].uncertainty_value),
"DMXerr": None if model[par].uncertainty_value is None else float(model[par].uncertainty_value),
"DMXR1": float(model[par[:3] + "R1" + par[3:]].value),
"DMXR2": float(model[par[:3] + "R2" + par[3:]].value),
"fit": par in pars,
Expand Down Expand Up @@ -511,9 +482,7 @@ def __init__(self, t2pulsar, sort=True, drop_t2pulsar=True, planets=True):
flags[key] = t2pulsar.flagvals(key)

# new-style storage of flags as a numpy record array (previously, psr._flags = flags)
self._flags = np.zeros(
len(self._toas), dtype=[(key, val.dtype) for key, val in flags.items()]
)
self._flags = np.zeros(len(self._toas), dtype=[(key, val.dtype) for key, val in flags.items()])
for key, val in flags.items():
self._flags[key] = val

Expand All @@ -527,9 +496,7 @@ def __init__(self, t2pulsar, sort=True, drop_t2pulsar=True, planets=True):
self._set_dm(t2pulsar)

self._pos_t = t2pulsar.psrPos.copy()
if "ELONG" and "ELAT" in np.concatenate(
(t2pulsar.pars(which="fit"), t2pulsar.pars(which="set"))
):
if "ELONG" and "ELAT" in np.concatenate((t2pulsar.pars(which="fit"), t2pulsar.pars(which="set"))):
self._pos_t = utils.ecl2eq_vec(self._pos_t)

self.sort_data()
Expand Down Expand Up @@ -560,9 +527,7 @@ def _set_dm(self, t2pulsar):
self._dmx = dmx

def _get_radec(self, t2pulsar):
if "RAJ" in np.concatenate(
(t2pulsar.pars(which="fit"), t2pulsar.pars(which="set"))
):
if "RAJ" in np.concatenate((t2pulsar.pars(which="fit"), t2pulsar.pars(which="set"))):
return (np.double(t2pulsar["RAJ"].val), np.double(t2pulsar["DECJ"].val))

else:
Expand All @@ -589,9 +554,7 @@ def _get_planetssb(self, t2pulsar):
planetssb[:, 7, :] = self.t2pulsar.neptune_ssb
planetssb[:, 8, :] = self.t2pulsar.pluto_ssb

if "ELONG" and "ELAT" in np.concatenate(
(t2pulsar.pars(), t2pulsar.pars(which="set"))
):
if "ELONG" and "ELAT" in np.concatenate((t2pulsar.pars(), t2pulsar.pars(which="set"))):
for ii in range(9):
planetssb[:, ii, :3] = utils.ecl2eq_vec(planetssb[:, ii, :3])
planetssb[:, ii, 3:] = utils.ecl2eq_vec(planetssb[:, ii, 3:])
Expand All @@ -607,9 +570,7 @@ def _get_sunssb(self, t2pulsar):
sunssb = np.zeros((len(self._toas), 6))
sunssb[:, :] = self.t2pulsar.sun_ssb

if "ELONG" and "ELAT" in np.concatenate(
(t2pulsar.pars(), t2pulsar.pars(which="set"))
):
if "ELONG" and "ELAT" in np.concatenate((t2pulsar.pars(), t2pulsar.pars(which="set"))):
sunssb[:, :3] = utils.ecl2eq_vec(sunssb[:, :3])
sunssb[:, 3:] = utils.ecl2eq_vec(sunssb[:, 3:])
return sunssb
Expand Down Expand Up @@ -666,24 +627,16 @@ def Pulsar(*args, **kwargs):
t2pulsar = [x for x in args if isinstance(x, t2.tempopulsar)]

parfile = [x for x in args if isinstance(x, str) and x.split(".")[-1] == "par"]
timfile = [
x for x in args if isinstance(x, str) and x.split(".")[-1] in ["tim", "toa"]
]
timfile = [x for x in args if isinstance(x, str) and x.split(".")[-1] in ["tim", "toa"]]

if pint and toas and model:
return PintPulsar(
toas[0], model[0], sort=sort, drop_pintpsr=drop_pintpsr, planets=planets
)
return PintPulsar(toas[0], model[0], sort=sort, drop_pintpsr=drop_pintpsr, planets=planets)
elif t2 and t2pulsar:
return Tempo2Pulsar(
t2pulsar[0], sort=sort, drop_t2pulsar=drop_t2pulsar, planets=planets
)
return Tempo2Pulsar(t2pulsar[0], sort=sort, drop_t2pulsar=drop_t2pulsar, planets=planets)
elif parfile and timfile:
# Check whether the two files exist
if not os.path.isfile(parfile[0]) or not os.path.isfile(timfile[0]):
msg = "Cannot find parfile {0} or timfile {1}!".format(
parfile[0], timfile[0]
)
msg = "Cannot find parfile {0} or timfile {1}!".format(parfile[0], timfile[0])
raise IOError(msg)

# Obtain the directory name of the timfile, and change to it
Expand All @@ -710,19 +663,13 @@ def Pulsar(*args, **kwargs):
planets=planets,
)
os.chdir(cwd)
return PintPulsar(
toas, model, sort=sort, drop_pintpsr=drop_pintpsr, planets=planets
)
return PintPulsar(toas, model, sort=sort, drop_pintpsr=drop_pintpsr, planets=planets)

elif timing_package.lower() == "tempo2":
# hack to set maxobs
maxobs = get_maxobs(reltimfile) + 100
t2pulsar = t2.tempopulsar(
relparfile, reltimfile, maxobs=maxobs, ephem=ephem, clk=clk
)
t2pulsar = t2.tempopulsar(relparfile, reltimfile, maxobs=maxobs, ephem=ephem, clk=clk)
os.chdir(cwd)
return Tempo2Pulsar(
t2pulsar, sort=sort, drop_t2pulsar=drop_t2pulsar, planets=planets
)
return Tempo2Pulsar(t2pulsar, sort=sort, drop_t2pulsar=drop_t2pulsar, planets=planets)

raise ValueError("Unknown arguments {}".format(args))

0 comments on commit c117fcf

Please sign in to comment.