Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow for index arrays in ShermanMorrison #356

Merged
merged 11 commits into from
Nov 17, 2023
Merged
69 changes: 36 additions & 33 deletions enterprise/signals/signal_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from enterprise.signals.parameter import function # noqa: F401
from enterprise.signals.parameter import ConstantParameter
from enterprise.signals.utils import KernelMatrix
from enterprise.signals.utils import indices_from_slice

from enterprise import __version__
from sys import version
Expand Down Expand Up @@ -1118,6 +1119,7 @@
def __init__(self, blocks, slices, nvec=0):
self._blocks = blocks
self._slices = slices
self._idxs = [indices_from_slice(slc) for slc in slices]
self._nvec = nvec

if np.any(nvec != 0):
Expand Down Expand Up @@ -1152,15 +1154,15 @@
ZNXr = np.dot(Z[self._idx, :].T, X[self._idx, :] / self._nvec[self._idx, None])
else:
ZNXr = 0
for slc, block in zip(self._slices, self._blocks):
Zblock = Z[slc, :]
Xblock = X[slc, :]
for idx, block in zip(self._idxs, self._blocks):
Zblock = Z[idx, :]
Xblock = X[idx, :]

if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
bx = sl.cho_solve(cf, Xblock)
else:
bx = Xblock / self._nvec[slc][:, None]
bx = Xblock / self._nvec[idx][:, None]

Check warning on line 1165 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1165

Added line #L1165 was not covered by tests
ZNX += np.dot(Zblock.T, bx)
ZNX += ZNXr
return ZNX.squeeze() if len(ZNX) > 1 else float(ZNX)
Expand All @@ -1173,11 +1175,11 @@
X = X.reshape(X.shape[0], 1)

NX = X / self._nvec[:, None]
for slc, block in zip(self._slices, self._blocks):
Xblock = X[slc, :]
if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
NX[slc] = sl.cho_solve(cf, Xblock)
for idx, block in zip(self._idxs, self._blocks):
Xblock = X[idx, :]
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
NX[idx] = sl.cho_solve(cf, Xblock)
return NX.squeeze()

def _get_logdet(self):
Expand All @@ -1188,12 +1190,12 @@
logdet = np.sum(np.log(self._nvec[self._idx]))
else:
logdet = 0
for slc, block in zip(self._slices, self._blocks):
if slc.stop - slc.start > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[slc]))
for idx, block in zip(self._idxs, self._blocks):
if len(idx) > 1:
cf = sl.cho_factor(block + np.diag(self._nvec[idx]))
logdet += np.sum(2 * np.log(np.diag(cf[0])))
else:
logdet += np.sum(np.log(self._nvec[slc]))
logdet += np.sum(np.log(self._nvec[idx]))

Check warning on line 1198 in enterprise/signals/signal_base.py

View check run for this annotation

Codecov / codecov/patch

enterprise/signals/signal_base.py#L1198

Added line #L1198 was not covered by tests
AaronDJohnson marked this conversation as resolved.
Show resolved Hide resolved
return logdet

def solve(self, other, left_array=None, logdet=False):
Expand All @@ -1218,6 +1220,7 @@
def __init__(self, jvec, slices, nvec=0.0):
self._jvec = jvec
self._slices = slices
self._idxs = [indices_from_slice(slc) for slc in slices]
self._nvec = nvec

def __add__(self, other):
Expand All @@ -1235,12 +1238,12 @@
"""Solves :math:`N^{-1}x` where :math:`x` is a vector."""

Nx = x / self._nvec
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
rblock = x[slc]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
rblock = x[idx]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
Nx[slc] -= beta * np.dot(niblock, rblock) * niblock
Nx[idx] -= beta * np.dot(niblock, rblock) * niblock
return Nx

def _solve_1D1(self, x, y):
Expand All @@ -1250,11 +1253,11 @@

Nx = x / self._nvec
yNx = np.dot(y, Nx)
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
xblock = x[slc]
yblock = y[slc]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
xblock = x[idx]
yblock = y[idx]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
yNx -= beta * np.dot(niblock, xblock) * np.dot(niblock, yblock)
return yNx
Expand All @@ -1265,11 +1268,11 @@
"""

ZNX = np.dot(Z.T / self._nvec, X)
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
Zblock = Z[slc, :]
Xblock = X[slc, :]
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
Zblock = Z[idx, :]
Xblock = X[idx, :]
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
zn = np.dot(niblock, Zblock)
xn = np.dot(niblock, Xblock)
Expand All @@ -1281,9 +1284,9 @@
is a quantization matrix.
"""
logdet = np.einsum("i->", np.log(self._nvec))
for slc, jv in zip(self._slices, self._jvec):
if slc.stop - slc.start > 1:
niblock = 1 / self._nvec[slc]
for idx, jv in zip(self._idxs, self._jvec):
if len(idx) > 1:
niblock = 1 / self._nvec[idx]
beta = 1.0 / (np.einsum("i->", niblock) + 1.0 / jv)
logdet += np.log(jv) - np.log(beta)
return logdet
Expand Down
25 changes: 18 additions & 7 deletions enterprise/signals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,26 +767,37 @@ def create_quantization_matrix(toas, dt=1, nmin=2):
return U, weights


def quant2ind(U):
def quant2ind(U, as_slice=False):
"""
Use quantization matrix to return slices of non-zero elements.
Use quantization matrix to return indices of non-zero elements.

:param U: quantization matrix
:param as_slice: whether to return a slice object

:return: list of `slice`s for non-zero elements of U
:return: list of `slice`s or indices for non-zero elements of U

.. note:: This function assumes that the pulsar TOAs were sorted by time.
.. note:: For slice objects the TOAs need to be sorted by time

"""
inds = []
for cc, col in enumerate(U.T):
epinds = np.flatnonzero(col)
if epinds[-1] - epinds[0] + 1 != len(epinds):
raise ValueError("ERROR: TOAs not sorted properly!")
inds.append(slice(epinds[0], epinds[-1] + 1))
if epinds[-1] - epinds[0] + 1 != len(epinds) or not as_slice:
inds.append(epinds)
else:
inds.append(slice(epinds[0], epinds[-1] + 1))
AaronDJohnson marked this conversation as resolved.
Show resolved Hide resolved
return inds


def indices_from_slice(slc):
"""Given a slice object, return an index arrays"""

if isinstance(slc, np.ndarray):
return slc
else:
return np.arange(*slc.indices(slc.stop))
AaronDJohnson marked this conversation as resolved.
Show resolved Hide resolved


def linear_interp_basis(toas, dt=30 * 86400):
"""Provides a basis for linear interpolation.

Expand Down
37 changes: 20 additions & 17 deletions enterprise/signals/white_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from enterprise.signals import parameter, selections, signal_base, utils
from enterprise.signals.parameter import function
from enterprise.signals.selections import Selection
from enterprise.signals.utils import indices_from_slice

try:
import fastshermanmorrison.fastshermanmorrison as fastshermanmorrison
Expand Down Expand Up @@ -217,13 +218,18 @@ def __init__(self, psr):
nepoch = sum(U.shape[1] for U in Umats)
U = np.zeros((len(psr.toas), nepoch))
self._slices = {}
self._idxs = {}
netot = 0
for ct, (key, mask) in enumerate(zip(keys, masks)):
nn = Umats[ct].shape[1]
U[mask, netot : nn + netot] = Umats[ct]
self._slices.update({key: utils.quant2ind(U[:, netot : nn + netot])})
netot += nn

self._idxs.update(
{key: [indices_from_slice(slc) for slc in slices] for (key, slices) in self._slices.items()}
)

# initialize sparse matrix
self._setup(psr)

Expand Down Expand Up @@ -252,17 +258,17 @@ def _setup(self, psr):

def _setup_sparse(self, psr):
Ns = scipy.sparse.csc_matrix((len(psr.toas), len(psr.toas)))
for key, slices in self._slices.items():
for slc in slices:
if slc.stop - slc.start > 1:
Ns[slc, slc] = 1.0
for key, idxs in self._idxs.items():
for idx in idxs:
if len(idx) > 1:
Ns[np.ix_(idx, idx)] = 1.0
self._Ns = signal_base.csc_matrix_alt(Ns)

def _get_ndiag_sparse(self, params):
for p in self._params:
for slc in self._slices[p]:
if slc.stop - slc.start > 1:
self._Ns[slc, slc] = 10 ** (2 * self.get(p, params))
for idx in self._idxs[p]:
if len(idx) > 1:
self._Ns[np.ix_(idx, idx)] = 10 ** (2 * self.get(p, params))
return self._Ns

def _get_ndiag_sherman_morrison(self, params):
Expand All @@ -274,21 +280,18 @@ def _get_ndiag_fast_sherman_morrison(self, params):
return fastshermanmorrison.FastShermanMorrison(jvec, slices)

def _get_ndiag_block(self, params):
slices, jvec = self._get_jvecs(params)
idxs, jvec = self._get_jvecs(params)
blocks = []
for jv, slc in zip(jvec, slices):
nb = slc.stop - slc.start
for jv, idx in zip(jvec, idxs):
nb = len(idx)
blocks.append(np.ones((nb, nb)) * jv)
return signal_base.BlockMatrix(blocks, slices)
return signal_base.BlockMatrix(blocks, idxs)

def _get_jvecs(self, params):
slices = sum([self._slices[key] for key in sorted(self._slices.keys())], [])
idxs = sum([self._idxs[key] for key in sorted(self._idxs.keys())], [])
jvec = np.concatenate(
[
np.ones(len(self._slices[key])) * 10 ** (2 * self.get(key, params))
for key in sorted(self._slices.keys())
]
[np.ones(len(self._idxs[key])) * 10 ** (2 * self.get(key, params)) for key in sorted(self._idxs.keys())]
)
return (slices, jvec)
return (idxs, jvec)

return EcorrKernelNoise
21 changes: 21 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,27 @@ def test_quantization_matrix(self):
assert U.shape == (4005, 235), msg1
assert all(np.sum(U, axis=0) > 1), msg2

inds = utils.quant2ind(U, as_slice=False)
slcs = utils.quant2ind(U, as_slice=True)
inds_check = [utils.indices_from_slice(slc) for slc in slcs]

msg3 = "Quantization Matrix slice not equal to quantization indices"
for ind, ind_c in zip(inds, inds_check):
assert np.all(ind == ind_c), msg3

def test_indices_from_slice(self):
"""Test conversion of slices to numpy indices"""
ind_np = np.array([2, 4, 6, 8])
ind_np_check = utils.indices_from_slice(ind_np)

msg1 = "Numpy indices not left as-is by indices_from_slice"
assert np.all(ind_np == ind_np_check), msg1

slc = slice(2, 10, 2)
ind_np_check = utils.indices_from_slice(slc)
msg2 = "Slice not converted properly by indices_from_slice"
assert np.all(ind_np == ind_np_check), msg2

def test_psd(self):
"""Test PSD functions."""
Tmax = self.psr.toas.max() - self.psr.toas.min()
Expand Down
Loading