Skip to content

Commit

Permalink
Fixed oversampled PSF binning to bin around and out from the centre p…
Browse files Browse the repository at this point in the history
…ixel
  • Loading branch information
york-stsci committed Sep 29, 2021
1 parent 7b2e2e0 commit 89ad032
Show file tree
Hide file tree
Showing 4 changed files with 125 additions and 50 deletions.
38 changes: 25 additions & 13 deletions stips/astro_image/astro_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,13 +611,14 @@ def makePoint(id, x, y, rate, psf_params, overrides, logger):
"""
ix, iy = int(np.floor(x)), int(np.floor(y))
if psf_params['has_psf']:
psf_img = getReducedPsf(psf_params, rate, x, y)
residual_ipc = SelectParameter("residual_ipc", overrides)
psf_type = SelectParameter("psf_type", overrides)
if residual_ipc and psf_type == "orig":
psf_img = apply_ipc(psf_img, 1)
psf_img = PSFGrid.reduced_psf_from_fits(psf_params['psf_file'],
psf_params['psf_ext'],
rate,
x,
y)
else:
psf_img = np.array([[0., 0., 0.], [0., rate, 0.], [0., 0., 0.]], dtype=np.float64)
psf_img = np.array([[0., 0., 0.], [0., rate, 0.], [0., 0., 0.]],
dtype=np.float64)
return psf_img,ix,iy


Expand All @@ -630,8 +631,7 @@ def addPoints(self, ids, xs, ys, rates, *args, **kwargs):
self._log("info", msg.format(len(xs),self.name))

psf_file = os.path.join(self.psf_dir, self.psf_name)
psf_params = {'psf_file': psf_file, 'psf_oversample': self.psf_oversample,
'has_psf': self.has_psf}
psf_params = {'psf_file': psf_file, 'has_psf': self.has_psf}
psf_type = SelectParameter('psf_type', self.overrides)
if SelectParameter('residual_ipc', self.overrides):
if 'ipc' not in psf_type:
Expand Down Expand Up @@ -842,7 +842,11 @@ def galsimSersic(id, gal_params, psf_params, xsize, ysize, dir, overrides, logge
central_flux = img[iyc, ixc]

if psf_params['has_psf']:
psf_img = getReducedPsf(psf_params, 1., gal_params['x'], gal_params['y'])
psf_img = PSFGrid.reduced_psf_from_fits(psf_params['psf_file'],
psf_params['psf_ext'],
1.,
gal_params['x'],
gal_params['y'])
result = convolve_fft(img, psf_img, allow_huge=True)
else:
result = img
Expand Down Expand Up @@ -951,7 +955,11 @@ def astropySersic(id, gal_params, psf_params, xsize, ysize, dir, overrides, logg
img *= flux / np.sum(img)

if psf_params['has_psf']:
psf_img = getReducedPsf(psf_params, 1., gal_params['x'], gal_params['y'])
psf_img = PSFGrid.reduced_psf_from_fits(psf_params['psf_file'],
psf_params['psf_ext'],
1.,
gal_params['x'],
gal_params['y'])
result = convolve_fft(img, psf_img, allow_huge=True)
else:
result = img
Expand Down Expand Up @@ -1057,7 +1065,11 @@ def pandeiaSersic(galnum, gal_params, psf_params, xsize, ysize, dir, overrides,

if psf_params['has_psf']:
# print("\tGenerating PSF ({})".format(time.ctime()))
psf_img = getReducedPsf(psf_params, 1., gal_params['x'], gal_params['y'])
psf_img = PSFGrid.reduced_psf_from_fits(psf_params['psf_file'],
psf_params['psf_ext'],
1.,
gal_params['x'],
gal_params['y'])
# print("\tPSF image has size {} dtype {}".format(psf_img.shape, psf_img.dtype))
# print("\tConvolving by PSF ({})".format(time.ctime()))
result = convolve_fft(img, psf_img)
Expand Down Expand Up @@ -1129,8 +1141,8 @@ def addSersics(self, ids, xs, ys, fluxes, ns, res, phis, ratios, *args, **kwargs

central_fluxes = []
psf_file = os.path.join(self.psf_dir, self.psf_name)
psf_params = {'psf_file': psf_file, 'psf_oversample': self.psf_oversample,
'has_psf': self.has_psf, 'psf_ext': self.psf_ext}
psf_params = {'psf_file': psf_file, 'has_psf': self.has_psf,
'psf_ext': self.psf_ext}
if self.parallel_enable:
self._log("info", "Adding sersic profiles in parallel.")

Expand Down
133 changes: 98 additions & 35 deletions stips/utilities/PSFGrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""

import logging
import numpy as np
import os
import shutil
import sys
import webbpsf

from collections import defaultdict
Expand Down Expand Up @@ -50,7 +52,7 @@ def make_epsf_kernel(oversample):
if oversample%2 == 0:
kernel_size += 1

kernel = np.ones((kernel_size, kernel_size), dtype=np.int64)
kernel = np.ones((kernel_size, kernel_size), dtype=np.float64)

if oversample%2 == 0:
kernel[0,:] = kernel[0,:] * 0.5
Expand Down Expand Up @@ -116,6 +118,15 @@ def __init__(self, data, **kwargs):
Slight override to the parent class __init__ method, adding in some PSF
information and parameters to the metadata attribute.
"""
# if self.logger.level == logging.DEBUG:
self.logger = logging.getLogger('__stips__')
log_level = SelectParameter("log_level")
self.logger.setLevel(getattr(logging, log_level))
if not len(self.logger.handlers):
stream_handler = logging.StreamHandler(sys.stderr)
format = '%(asctime)s %(levelname)s: %(message)s'
stream_handler.setFormatter(logging.Formatter(format))
self.logger.addHandler(stream_handler)
self.data = {}
if 'meta' not in kwargs:
meta = {
Expand Down Expand Up @@ -148,7 +159,6 @@ def create_from_parameters(cls, **kwargs):
- oversample
- grid size
"""
logger = kwargs.get('logger', None)
data = {}
meta = {}
meta['psf_fov_pixels'] = SelectParameter('psf_fov_pixels', kwargs)
Expand All @@ -158,7 +168,6 @@ def create_from_parameters(cls, **kwargs):
meta['instrument'] = kwargs.get('instrument', 'wfi').lower()
meta['detector'] = kwargs.get('detector', 'sca01').lower()
meta['filter'] = kwargs.get('filter', 'f067').lower()
oversample_by_scale = kwargs.get('oversample_by_scale', True)
pixel_scale = kwargs.get('pixel_scale', 0.11)
psf_grid_size = SelectParameter('psf_grid_default_size', kwargs)
num_psfs = psf_grid_size*psf_grid_size
Expand All @@ -174,28 +183,21 @@ def create_from_parameters(cls, **kwargs):
psf_xgrid = np.linspace(pix_limit['low'], pix_limit['high'], psf_grid_size)
psf_ygrid = np.linspace(pix_limit['low'], pix_limit['high'], psf_grid_size)
i = 0
for xp in psf_xgrid:
for yp in psf_ygrid:
if logger is not None:
logger.info("Calculating PSF at ({},{})".format(xp, yp))
for yp in psf_ygrid:
for xp in psf_xgrid:
self.logger.info("Calculating PSF at ({},{})".format(xp, yp))
meta['grid_xypos'].append((xp, yp))
ins = cls.get_instrument(meta['telescope'], meta['instrument'])
ins.options['parity'] = 'odd'
ins.filter = meta['filter']
ins.detector = meta['detector']
ins.detector_position = (xp, yp)
if oversample_by_scale:
ins.pixelscale = pixel_scale/meta['oversampling']
psf_fits = ins.calc_psf(fov_pixels=fov,
nlambda=meta['psf_nlambda'],
oversample=1)
psf = psf_fits["DET_SAMP"].data
else:
ins.pixelscale = pixel_scale
psf_fits = ins.calc_psf(fov_pixels=meta['psf_fov_pixels'],
nlambda=meta['psf_nlambda'],
oversample=meta['oversampling'])
psf = psf_fits['OVERSAMP'].data
ins.pixelscale = pixel_scale/meta['oversampling']
psf_fits = ins.calc_psf(fov_pixels=fov,
nlambda=meta['psf_nlambda'],
oversample=1)
psf = psf_fits["DET_SAMP"].data
self.logger.info("\tPSF size is {}".format(psf.shape))
data['orig'][i][:] = deepcopy(psf) * meta['oversampling']**2
data['orig_ipc'][i][:] = deepcopy(data['orig'][i][:])
data['epsf'][i][:] = make_epsf(deepcopy(psf), meta['oversampling'])
Expand Down Expand Up @@ -243,6 +245,16 @@ def create_from_fits(cls, file_name):
i += 1
data[psf_type] = GriddedPSFModel(NDData(psf_data, meta=meta))
return cls(data, meta=meta)


@classmethod
def reduced_psf_from_fits(cls, file_name, psf_ext, flux, cx, cy):
"""
Instantiates from a FITS file, then loads and bins a PSF at a given centre, then
softly and silently vanishes away.
"""
model = cls.create_from_fits(file_name)
return model.get_reduced_psf(psf_ext, flux, cx, cy)


def from_fits(self, file_name):
Expand Down Expand Up @@ -279,22 +291,6 @@ def to_fits(self, file_name):
hdul.writeto(file_name, overwrite=True)


@staticmethod
def pix_limit(telescope, instrument, detector):
roman_det = defaultdict(lambda: {'low': 4, 'high': 4092})
roman_ins = defaultdict(lambda: roman_det)
miri_det = defaultdict(lambda: {'low': 4, 'high': 1024})
nircam_det = defaultdict(lambda: {'low': 0, 'high': 2040})
detector_dict = {
'roman': roman_ins,
'jwst': {
'nircam': nircam_det,
'miri': miri_det
}
}
return detector_dict[telescope][instrument][detector]


def make_gridded_model(self, data, **kwargs):
"""
Given data (which may be a GriddedPSFModel, an astropy NDData, or a numpy array),
Expand Down Expand Up @@ -323,7 +319,74 @@ def make_gridded_model(self, data, **kwargs):
for yp in psf_ygrid:
meta['grid_xypos'].append((xp,yp))
return GriddedPSFModel(NDData(data, meta=meta), **kwargs)


def make_bin(self, x_size, y_size):
"""
Create X and Y lists that will result in creating lists of centre pixels from
which to create the lovely binned PSF model.
"""
oversample = self.meta['oversampling']
xl = np.arange(x_size//2, oversample//2, -oversample)
xh = np.arange(x_size//2+oversample, x_size-oversample//2, oversample)
x_arr = list(np.append(np.sort(xl), np.sort(xh)))
yl = np.arange(y_size//2, oversample//2, -oversample)
yh = np.arange(y_size//2+oversample, y_size-oversample//2, oversample)
y_arr = list(np.append(np.sort(yl), np.sort(yh)))
return x_arr, y_arr


def get_reduced_psf(self, psf_ext, flux, cx, cy):
"""
Evaluate the appropriate (chosen based on psf_ext) GriddedPSFModel PSF, and then
bin it down to detector sampling, bearing in mind that the centre pixel always
needs to be the centre of a bin.
"""
psf_data = self.data[psf_ext]
ix, iy = int(np.floor(cx)), int(np.floor(cy))
oversample = self.meta['oversampling']
psf_pix = self.meta['psf_fov_pixels']

psf_y, psf_x = psf_data.data.shape[1], psf_data.data.shape[2]
ly, hy = iy - psf_y//2, iy + psf_y//2 + 1
lx, hx = ix - psf_x//2, ix + psf_x//2 + 1
y_arr,x_arr = np.mgrid[ly:hy, lx:hx]
psf_img = psf_data.evaluate(x_arr, y_arr, flux=flux, x_0=cx, y_0=cy)

kernel = make_epsf_kernel(oversample)
# print("PSF Image Size: {}".format(psf_img.shape))
x_list, y_list = self.make_bin(psf_x, psf_y)
binned_psf = np.zeros((len(y_list), len(x_list)), dtype=np.float64)
bx, by = 0, 0
for y in y_list:
for x in x_list:
kyl, kyh = y - oversample//2, y + oversample//2 + 1
kxl, kxh = x - oversample//2, x + oversample//2 + 1
# print("(x,y) = ({},{})".format(x,y))
# print("\tUsing img[{}:{},{}:{}]".format(kyl, kyh, kxl, kxh))
binned_psf[by,bx] = np.sum(kernel[:,:]*psf_img[kyl:kyh,kxl:kxh])
bx += 1
by += 1
bx = 0

return binned_psf


@staticmethod
def pix_limit(telescope, instrument, detector):
roman_det = defaultdict(lambda: {'low': 4, 'high': 4092})
roman_ins = defaultdict(lambda: roman_det)
miri_det = defaultdict(lambda: {'low': 4, 'high': 1024})
nircam_det = defaultdict(lambda: {'low': 0, 'high': 2040})
detector_dict = {
'roman': roman_ins,
'jwst': {
'nircam': nircam_det,
'miri': miri_det
}
}
return detector_dict[telescope][instrument][detector]


@staticmethod
def get_instrument(telescope, instrument):
Expand Down
2 changes: 1 addition & 1 deletion stips/utilities/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,6 @@ def getReducedPsf(psf_params, flux, cx, cy):
psf_img = psf_data.evaluate(x_arr, y_arr, flux=flux, x_0=cx, y_0=cy)
if np.sum(psf_img) == 0.:
print("ERROR: ZERO FLUX PSF: {} ({},{})".format(psf_params, cx, cy))
psf_img = psf_img.reshape(psf_y//os, os, psf_x//os, os).sum(axis=(1, 3))[1:-1,1:-1]
psf_img = psf_img.reshape(psf_y//os, os, psf_x//os, os).sum(axis=(1, 3))

return psf_img
2 changes: 1 addition & 1 deletion stips/version.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
__version__ = '2.0.0dev7'
__version__ = '2.0.0dev8'
__data__version__ = '2.0.0'

0 comments on commit 89ad032

Please sign in to comment.