diff --git a/.travis.yml b/.travis.yml index 2781fff..434797e 100644 --- a/.travis.yml +++ b/.travis.yml @@ -115,40 +115,46 @@ matrix: before_install: + # Make data dir + - mkdir $TRAVIS_BUILD_DIR/test_data + - TEST_DATA_PATH=$TRAVIS_BUILD_DIR/test_data + # Download stips data - - wget https://stsci.box.com/shared/static/iufbhsiu0lts16wmdsi12cun25888nrb.gz -O /tmp/stips_data_current.tar.gz - - tar -xzvf /tmp/stips_data_current.tar.gz - - export stips_data=$PWD/stips_data + - wget https://stsci.box.com/shared/static/iufbhsiu0lts16wmdsi12cun25888nrb.gz -O $TEST_DATA_PATH/stips_data_current.tar.gz + - tar -xzvf $TEST_DATA_PATH/stips_data_current.tar.gz -C $TEST_DATA_PATH/ + - export stips_data=$TEST_DATA_PATH/stips_data # download synphot data - - mkdir synphot + - mkdir $TEST_DATA_PATH/synphot - - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot1.tar.gz -O /tmp/synphot1.tar.gz - - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot2.tar.gz -O /tmp/synphot2.tar.gz - - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot5.tar.gz -O /tmp/synphot5.tar.gz + - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot1.tar.gz -O $TEST_DATA_PATH/synphot1.tar.gz + - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot2.tar.gz -O $TEST_DATA_PATH/synphot2.tar.gz + - wget -qO- http://ssb.stsci.edu/cdbs/tarfiles/synphot5.tar.gz -O $TEST_DATA_PATH/synphot5.tar.gz - - tar -xzvf /tmp/synphot1.tar.gz -C $PWD/synphot/ - - tar -xzvf /tmp/synphot2.tar.gz -C $PWD/synphot/ - - tar -xzvf /tmp/synphot5.tar.gz -C $PWD/synphot/ + - tar -xzvf $TEST_DATA_PATH/synphot1.tar.gz -C $TEST_DATA_PATH/synphot/ + - tar -xzvf $TEST_DATA_PATH/synphot2.tar.gz -C $TEST_DATA_PATH/synphot/ + - tar -xzvf $TEST_DATA_PATH/synphot5.tar.gz -C $TEST_DATA_PATH/synphot/ - - export PYSYN_CDBS=$PWD/synphot/grp/hst/cdbs + - export PYSYN_CDBS=$TEST_DATA_PATH/synphot/grp/hst/cdbs # download pandeia_data data - - wget -qO- https://stsci.box.com/shared/static/5j506xzg9tem2l7ymaqzwqtxne7ts3sr.gz -O /tmp/pandeia_data-1.5_wfirst.gz + - wget -qO- https://stsci.box.com/shared/static/5j506xzg9tem2l7ymaqzwqtxne7ts3sr.gz -O $TEST_DATA_PATH/pandeia_data-1.5_wfirst.gz - - tar -xzvf /tmp/pandeia_data-1.5_wfirst.gz -C $PWD/ - - export pandeia_refdata=$PWD/pandeia_data-1.5_wfirst + - tar -xzvf $TEST_DATA_PATH/pandeia_data-1.5_wfirst.gz -C $TEST_DATA_PATH/ + - export pandeia_refdata=$TEST_DATA_PATH/pandeia_data-1.5_wfirst # download webbpsf data - - wget -qO- https://stsci.box.com/shared/static/qcptcokkbx7fgi3c00w2732yezkxzb99.gz -O /tmp/webbpsf_data.gz - - tar -xzvf /tmp/webbpsf_data.gz -C $PWD/ - - export WEBBPSF_PATH=$PWD/webbpsf-data - + - wget -qO- https://stsci.box.com/shared/static/qcptcokkbx7fgi3c00w2732yezkxzb99.gz -O $TEST_DATA_PATH/webbpsf_data.gz + - tar -xzvf $TEST_DATA_PATH/webbpsf_data.gz -C $TEST_DATA_PATH/ + - export WEBBPSF_PATH=$TEST_DATA_PATH/webbpsf-data + # If VERYSLOW is set to false for a test, add args to pytest + # The format (... "..") is because there are multiple quotation marks in bash + # They variable will be constructed into a command via "${extra_args[@]}" - | if [[ $VERYSLOW == 'false' && $SETUP_CMD == 'test' ]]; then echo "Testing functions without veryslow marker" - extra_args=(-a " -m \'veryslow\' "); + extra_args=(-a "-m 'not veryslow'"); else extra_args=(); fi diff --git a/setup.cfg b/setup.cfg index 297c8a9..5604a9b 100644 --- a/setup.cfg +++ b/setup.cfg @@ -49,6 +49,11 @@ minversion = 3.0 norecursedirs = build docs/_build doctest_plus = enabled addopts = -p no:warnings +markers = + slow: takes more than 12.5 minutes. Deselect with '-m "not slow"' + veryslow: takes more than 1 hour. Deselect with '-m "not veryslow"' + network: requires internet connection. Deselect with '-m "not network"' + [ah_bootstrap] auto_use = True diff --git a/stips/astro_image/astro_image.py b/stips/astro_image/astro_image.py index 6a3834c..6844766 100644 --- a/stips/astro_image/astro_image.py +++ b/stips/astro_image/astro_image.py @@ -7,9 +7,11 @@ import numpy as np from astropy import wcs -from astropy.io import fits as pyfits +from astropy.io import fits from astropy.table import Table, Column +from copy import deepcopy from photutils import CircularAperture, aperture_photometry +from photutils.psf.models import GriddedPSFModel from scipy.ndimage.interpolation import zoom, rotate @@ -19,7 +21,14 @@ from cStringIO import StringIO #Local Modules -from ..utilities import OffsetPosition, overlapadd2, overlapaddparallel, read_table, ImageData, Percenter, StipsDataTable +from stips.version import __version__ as stips_version +from ..utilities import OffsetPosition +from ..utilities import overlapadd2 +from ..utilities import overlapaddparallel +from ..utilities import read_table +from ..utilities import ImageData +from ..utilities import Percenter +from ..utilities import StipsDataTable from ..errors import GetCrProbs, GetCrTemplate, MakeCosmicRay @@ -41,41 +50,102 @@ def __init__(self, **kwargs): Astronomical image. The __init__ function creates an empty image with all other data values set to zero. """ - - if 'logger' in kwargs: - self.logger = kwargs['logger'] + default = self.INSTRUMENT_DEFAULT + + if 'parent' in kwargs: + self.parent = kwargs['parent'] + self.logger = self.parent.logger + self.out_path = self.parent.out_path + self.prefix = self.parent.prefix + self.seed = self.parent.seed + self.telescope = self.parent.TELESCOPE.lower() + self.instrument = self.parent.PSF_INSTRUMENT + self.filter = self.parent.filter + self.oversample = self.parent.oversample + self.shape = np.array(self.parent.DETECTOR_SIZE)*self.oversample + self._scale = np.array(self.parent.SCALE)/self.oversample + self.zeropoint = self.parent.zeropoint + self.photflam = self.parent.photflam + self.photplam = self.parent.PHOTPLAM[self.filter] + background = self.parent.background + self.psf_grid_size = self.parent.psf_grid_size + self.psf_commands = self.parent.psf_commands + small_subarray = self.parent.small_subarray + self.cat_type = self.parent.cat_type + self.set_celery = self.parent.set_celery + self.get_celery = self.parent.get_celery + self.convolve_size = self.parent.convolve_size + self.memmap = self.parent.memmap else: - self.logger = logging.getLogger('__stips__') - self.logger.setLevel(logging.INFO) - if not len(self.logger.handlers): - stream_handler = logging.StreamHandler(sys.stderr) - stream_handler.setFormatter(logging.Formatter('%(asctime)s %(levelname)s: %(message)s'))# [in %(pathname)s:%(lineno)d]')) - self.logger.addHandler(stream_handler) - + self.parent = None + if 'logger' in kwargs: + self.logger = kwargs['logger'] + else: + self.logger = logging.getLogger('__stips__') + self.logger.setLevel(logging.INFO) + if not len(self.logger.handlers): + stream_handler = logging.StreamHandler(sys.stderr) + format = '%(asctime)s %(levelname)s: %(message)s' + stream_handler.setFormatter(logging.Formatter(format)) + self.logger.addHandler(stream_handler) + self.out_path = kwargs.get('out_path', os.getcwd()) + self.oversample = kwargs.get('oversample', default['oversample']) + self.shape = kwargs.get('shape', default['shape']) + self.shape = np.array(self.shape) * self.oversample + self._scale = kwargs.get('scale', np.array(default['scale'])) + self.prefix = kwargs.get('prefix', '') + self.cat_type = kwargs.get('cat_type', 'fits') + self.set_celery = kwargs.get('set_celery', None) + self.get_celery = kwargs.get('get_celery', None) + self.seed = kwargs.get('seed', 1234) + small_subarray = kwargs.get('small_subarray', False) + self.zeropoint = kwargs.get('zeropoint', default['zeropoint']) + self.photflam = kwargs.get('photflam', default['photflam']) + self.photplam = kwargs.get('photplam', default['photplam']) + background = kwargs.get('background', default['background']) + self.telescope = kwargs.get('telescope', default['telescope']) + self.instrument = kwargs.get('instrument', default['instrument']) + self.filter = kwargs.get('filter', default['filter']) + self.psf_grid_size = kwargs.get('psf_grid_size', + default['psf_grid_size']) + self.psf_commands = kwargs.get('psf_commands', '') + self.convolve_size = kwargs.get('convolve_size', 8192) + self.memmap = kwargs.get('memmap', True) + + if self.get_celery is None: + self.get_celery = lambda: "" + if self.set_celery is None: + self.set_celery = lambda x: None + #Set unique ID and figure out where the numpy memmap will be stored - self.out_path = kwargs.get('out_path', os.getcwd()) - self.prefix = kwargs.get('prefix', '') - self.cat_type = kwargs.get('cat_type', 'fits') - self.name = kwargs.get('detname', "") - self.fname = os.path.join(self.out_path, self.prefix+"_"+uuid.uuid4().hex+"_"+self.name+".tmp") - self.set_celery = kwargs.get('set_celery', None) - self.get_celery = kwargs.get('get_celery', None) - self.seed = kwargs.get('seed', 1234) + self.name = kwargs.get('detname', default['detector'][self.instrument]) + self.detector = self.name + if self.memmap: + fname = self.prefix+"_"+uuid.uuid4().hex+"_"+self.name+".tmp" + self.fname = os.path.join(self.out_path, fname) + + if self.psf_commands is None: + self.psf_commands = '' + psf = kwargs.get('psf', True) + if psf: + self.make_psf() - self.oversample = kwargs.get('oversample', 1) - psf_shape = kwargs.get('psf_shape', (0, 0)) data = kwargs.get('data', None) if data is not None: - base_shape = data.shape + base_shape = np.array(data.shape) else: - if kwargs.get('small_subarray', False): #restrict data size to PSF size - base_shape = psf_shape + #restrict data size to PSF size + if small_subarray: + if not hasattr(self, 'psf'): + msg = "{}: Unable to set image size to PSF size when image " + msg += "has no valid PSF." + raise ValueError(msg.format(self.name)) + base_shape = self.psf_shape else: - base_shape = kwargs.get('shape', (1, 1)) - self._init_dat(base_shape, psf_shape, data) + base_shape = self.shape + self._init_dat(base_shape, self.psf_shape, data) #Get WCS values if present, or set up a default - self._scale = kwargs.get('scale', [0., 0.]) self.wcs = self._getWcs(**kwargs) self._prepRaDec() @@ -95,11 +165,8 @@ def __init__(self, **kwargs): #Special values for Sersic profile generation self.profile_multiplier = kwargs.get('profile_multiplier', 100.) - self.noise_floor = max(kwargs.get('background', 0.), kwargs.get('noise_floor', 1.)) - - #Zero Point. Necessary for output catalogues. - self.zeropoint = kwargs.get('zeropoint', 0.) - self.photflam = kwargs.get('photflam', 0.) + self.noise_floor = max(background, kwargs.get('noise_floor', 1.)) + def __del__(self): if os.path.exists(self.fname): @@ -128,12 +195,12 @@ def initFromFits(cls, file, **kwargs): img = cls(**kwargs) if file != '': try: - with pyfits.open(file) as fits: + with fits.open(file) as inf: ext = kwargs.get('ext', 0) - dat = fits[ext].data - img._init_dat(base_shape=dat.shape, data=dat) - my_wcs = wcs.WCS(fits[ext].header) - for k,v in fits[ext].header.items(): + dat = inf[ext].data + img._init_dat(base_shape=np.array(dat.shape), data=dat) + my_wcs = wcs.WCS(inf[ext].header) + for k,v in inf[ext].header.items(): if k != '' and k not in my_wcs.wcs.to_header(): img.header[k] = v img.wcs = img._normalizeWCS(my_wcs) @@ -155,10 +222,10 @@ def initDataFromFits(cls, file, **kwargs): img = cls(**kwargs) if file != '': try: - with pyfits.open(file) as fits: + with fits.open(file) as inf: ext = kwargs.get('ext', 0) - dat = fits[ext].data - img._init_dat(base_shape=dat.shape, data=dat) + dat = inf[ext].data + img._init_dat(base_shape=np.array(dat.shape), data=dat) img.wcs = img._getWcs(**kwargs) img._prepRaDec() img._prepHeader() @@ -206,7 +273,7 @@ def xsize(self): def xsize(self, size, offset=0): """Change the horizontal size. The offset will be applied to the new image before adding""" self.crop(size, self.ysize, offset, 0) - self.shape = (self.ysize, size) + self.shape = np.array((self.ysize, size)) @property def ysize(self): @@ -216,7 +283,7 @@ def ysize(self): def ysize(self, size, offset=0): """Change the vertical size. The offset will be applied to the new image before adding""" self.crop(self.xsize, size, 0, offset) - self.shape = (size, self.xsize) + self.shape = np.array((size, self.xsize)) @property def xscale(self): @@ -293,12 +360,25 @@ def pa(self,pa): spa = np.sin(np.radians(pa%360.)) self.wcs.wcs.pc = np.array([[cpa, -spa], [spa, cpa]]) self.addHistory("Set PA to %f" % (pa)) + + + @property + def celery_state(self): + if self.get_celery is not None: + return self.get_celery() + return "" + + @celery_state.setter + def celery_state(self, state): + if self.set_celery is not None: + self.set_celery(state) + @property def hdu(self): """Output AstroImage as a FITS Primary HDU""" - with ImageData(self.fname, self.shape, mode='r+') as dat: - hdu = pyfits.PrimaryHDU(dat, header=self.wcs.to_header(relax=True)) + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: + hdu = fits.PrimaryHDU(dat, header=self.wcs.to_header(relax=True)) hdu.header['CDELT1'] = self.scale[0]/3600. hdu.header['CDELT2'] = self.scale[0]/3600. if sys.version_info[0] >= 3: @@ -318,8 +398,8 @@ def hdu(self): def imageHdu(self): """Output AstroImage as a FITS Extension HDU""" self._log("info","Creating Extension HDU from AstroImage %s" % (self.name)) - with ImageData(self.fname, self.shape, mode='r+') as dat: - hdu = pyfits.ImageHDU(dat, header=self.wcs.to_header(relax=True), name=self.name) + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: + hdu = fits.ImageHDU(dat, header=self.wcs.to_header(relax=True), name=self.name) if sys.version_info[0] >= 3: for k,v in self.header.items(): hdu.header[k] = v @@ -330,11 +410,25 @@ def imageHdu(self): hdu.header.add_history(item) self._log("info","Created Extension HDU from AstroImage %s" % (self.name)) return hdu + + @property + def psf_constructor(self): + import webbpsf + return getattr(getattr(webbpsf, self.telescope), self.instrument)() + + + @property + def psf_shape(self): + if hasattr(self, 'psf'): + sampled_shape = self.psf.data.shape + return np.array([sampled_shape[1], sampled_shape[2]]) + return (0, 0) + def toFits(self, outFile): """Create a FITS file from the current state of the AstroImage data.""" self._log("info","Writing AstroImage %s to FITS" % (self.name)) - hdulist = pyfits.HDUList([self.hdu]) + hdulist = fits.HDUList([self.hdu]) hdulist.writeto(outFile, overwrite=True) def updateHeader(self,k,v): @@ -492,7 +586,7 @@ def addPoints(self, xs, ys, rates, *args, **kwargs): self._log("info","Adding %d point sources to AstroImage %s" % (len(xs),self.name)) xs = np.floor(xs).astype(int) ys = np.floor(ys).astype(int) - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: dat[ys, xs] += rates def addSersicProfile(self, posX, posY, flux, n, re, phi, axialRatio, *args, **kwargs): @@ -565,71 +659,257 @@ def addSersicProfile(self, posX, posY, flux, n, re, phi, axialRatio, *args, **kw self._addArrayWithOffset(img, offset_x, offset_y) return central_flux - def convolve(self, other, max=4095, do_convolution=True, parallel=False, cores=None, state_setter=None, base_state=""): - """Convolves the AstroImage with another (provided) AstroImage, e.g. for PSF convolution.""" - self.addHistory("Convolving with file %s" % (other.name)) - self._log("info","Convolving AstroImage %s with %s" % (self.name,other.name)) - f, g = os.path.join(self.out_path, uuid.uuid4().hex+"_convolve_01.tmp"), os.path.join(self.out_path, uuid.uuid4().hex+"_convolve_02.tmp") + def make_psf(self): + from webbpsf import __version__ as psf_version + have_psf = False + psf_name = "psf_{}_{}_{}_{}_{}_{}.fits".format(self.instrument, + stips_version, + self.filter, + self.oversample, + self.psf_grid_size, + self.detector) + if os.path.exists(os.path.join(self.out_path, "psf_cache")): + psf_file = os.path.join(self.out_path, "psf_cache", psf_name) + if os.path.exists(psf_file): + from webbpsf.utils import to_griddedpsfmodel + if (self.psf_commands is None or self.psf_commands == ''): + self.psf = to_griddedpsfmodel(psf_file) + have_psf = True + if not have_psf: + base_state = self.celery_state + update_state = "
Generating PSF" + self.celery_state = base_state + update_state + ins = self.psf_constructor + if self.psf_commands != '': + for attribute,value in self.psf_commands.iteritems(): + setattr(ins,attribute,value) + ins.filter = self.filter + ins.detector = self.detector + scale = self.scale[0] + # First limit -- PSF no larger than detector + ins_size = max(self.xsize, self.ysize) * self.oversample + # Second limit -- PSF no larger than half of max convolution area. + conv_size = self.convolve_size // (2*self.oversample) + # Third limit -- prevent aliasing + safe_size = int(np.floor(30. * self.photplam / (2 * self.scale[0]))) + if safe_size <= 0: + safe_size = max(self.xsize, self.ysize) + msg = "PSF choosing between {}, {} and {}" + self._log("info", msg.format(ins_size, conv_size, safe_size)) + fov_pix = min(ins_size, conv_size, safe_size) + if fov_pix%2 != 0: + fov_pix += 1 + num_psfs = self.psf_grid_size*self.psf_grid_size + if os.path.exists(os.path.join(self.out_path, "psf_cache")): + save = True + overwrite = True + psf_dir = os.path.join(self.out_path, "psf_cache") + psf_file = "psf_{}_{}_{}_{}_{}".format(self.instrument, + stips_version, + self.filter, + self.oversample, + self.psf_grid_size) + else: + save = False + overwrite = False + psf_dir = None + psf_file = None + + msg = "{}: Starting {}x{} PSF Grid creation at {}" + self._log("info", msg.format(self.name, self.psf_grid_size, + self.psf_grid_size, time.ctime())) + self.psf = ins.psf_grid(all_detectors=False, num_psfs=num_psfs, + fov_pixels=fov_pix, normalize='last', + oversample=self.oversample, save=save, + outdir=psf_dir, outfile=psf_file, + overwrite=overwrite) + msg = "{}: Finished PSF Grid creation at {}" + self._log("info", msg.format(self.name, time.ctime())) + self.celery_state = base_state + + def convolve_psf(self, max_size=None, parallel=False, cores=None): + """ + Convolve the current AstroImage state with the generated PSF (if there + is a generated PSF). Otherwise, do nothing. + + Parameters + ---------- + max_size : int, default=4095 + The maximum size chunk to use in convolution chunks. + parallel : bool, default=False + Whether to perform convolution chunks in parallel + cores : int, default=None + How many CPU cores to use for parallel computation (used only if + parallel=True) + """ + if hasattr(self, 'psf'): + if max_size is None: + max_size = self.convolve_size + self.convolve(self.psf, max_size=max_size, parallel=parallel, + cores=cores, crop=True) + + + def convolve(self, other, max_size=None, parallel=False, cores=None, + crop=True): + """ + Convolve the AstroImage with another image. This image can be + - another AstroImage + - a numpy NDData array + - a FITS file + - a webbpsf PSF grid + + Parameters + ---------- + other : object + The other image to convolve. See above for possible formats. + max_size : int, default=4095 + The maximum size chunk to use in convolution chunks. + parallel : bool, default=False + Whether to perform convolution chunks in parallel + cores : int, default=None + How many CPU cores to use for parallel computation (used only if + parallel=True) + crop : bool, default=True + After convolving, should the AstroImage be cropped down to no + longer include the PSF overlap region. + """ + if max_size is None: + max_size = self.convolve_size + + other_type = "" + if isinstance(other, AstroImage): + other_type = "astro_image" + other_data = ImageData(other.fname, other.shape, mode='r', + memmap=other.memmap) + other_img = other_data.data + other_shape = other_img.shape + other_name = other + elif isinstance(other, np.ndarray): + other_type = "ndarray" + other_img = other + other_shape = other_img.shape + other_name = "{}x{} array".format(other.shape[0], other.shape[1]) + elif isinstance(other, GriddedPSFModel): + other_type = "psfgrid" + other_img = other + other_shape = self.psf_shape + other_name = "PSF grid" + else: + # There are so darn many fits HDU classes that I can't + # figure out how to do this with them, so FITS file is now + # a diagnosis of last resort. + other_type = 'fits_file' + if 'primary' in other: + other_img = other['primary'].data + elif len(other) > 1: + other_img = other[1].data + elif hasattr(other, 'data'): + other_img = other.data + else: + other_img = other[0].data + other_shape = other_img.shape + other_name = "FITS file" + + self.addHistory("Convolving {} with {}".format(self.name, other_name)) + self._log("info", "Convolving {} with {}".format(self.name, other_name)) + + f = os.path.join(self.out_path, uuid.uuid4().hex+"_convolve_01.tmp") + g = os.path.join(self.out_path, uuid.uuid4().hex+"_convolve_02.tmp") try: - sub_shape = (min(max - other.shape[0], self.shape[0] + other.shape[0] - 1), min(max - other.shape[1], self.shape[1] + other.shape[1] - 1)) - with ImageData(self.fname, self.shape, mode='r') as dat, ImageData(other.fname, other.shape, mode='r') as psf: - fp_result = np.memmap(f, dtype='float32', mode='w+', shape=(self.shape[0]+psf.shape[0]-1, self.shape[1]+psf.shape[1]-1)) - centre = (fp_result.shape[0]//2, fp_result.shape[1]//2) - if do_convolution: - sub_shape = (min(max - psf.shape[0], self.shape[0] + psf.shape[0] - 1), min(max - psf.shape[1], self.shape[1] + psf.shape[1] - 1)) - self._log('info', "PSF Shape: {}; Current Shape: {}".format(psf.shape, self.shape)) - self._log('info', "Choosing between {}-{}={} and {}+{}-1={}".format(max, psf.shape, max-psf.shape[0], psf.shape, self.shape, psf.shape[0]+self.shape[0]-1)) - self._log('info', "Using overlapping arrays of size {}".format(sub_shape)) - self._log('info', "Starting Convolution at {}".format(time.ctime())) - if parallel: - self._log('info', 'Convolving in parallel') - del fp_result - overlapaddparallel(dat, psf, sub_shape, y=f, verbose=True, logger=self.logger, base_state=base_state, state_setter=state_setter, path=self.out_path, cores=cores) - fp_result = np.memmap(f, dtype='float32', mode='r+', shape=(self.shape[0]+psf.shape[0]-1, self.shape[1]+psf.shape[1]-1)) + self_y, self_x = self.shape + other_y, other_x = other_shape + max_y = min(max_size - other_y, self_y + other_y - 1) + max_x = min(max_size - other_x, self_x + other_x - 1) + sub_shape = (max_y, max_x) + with ImageData(self.fname, self.shape, mode='r', memmap=self.memmap) as dat: + shape = (self_y + other_y - 1, self_x + other_x - 1) + if self.memmap: + fp_res = np.memmap(f, dtype='float32', mode='w+', shape=shape) + else: + fp_res = np.zeros(shape, dtype='float32') + centre = (fp_res.shape[0]//2, fp_res.shape[1]//2) + max_y = min(max_size - other_y, self_y + other_y - 1) + max_x = min(max_size - other_x, self_x + other_x - 1) + sub_shape = (max_y, max_x) + msg = "PSF Shape: {}; Current Shape: {}" + self._log('info', msg.format(other_shape, self.shape)) + msg = "Choosing between {}-{}={} and {}+{}-1={}" + self._log('info', msg.format(max_size, other_y, + max_size-other_y, other_y, + self_y, other_y+self_y-1)) + msg = "Using overlapping arrays of size {}" + self._log('info', msg.format(sub_shape)) + msg = "{}: Starting Convolution at {}" + self._log('info', msg.format(self.name, time.ctime())) + if parallel: + self._log('info', 'Convolving in parallel') + if self.memmap: + del fp_res else: - overlapadd2(dat, psf, sub_shape, y=fp_result, verbose=True, logger=self.logger, base_state=base_state, state_setter=state_setter) - self._log('info', "Finished Convolution at {}".format(time.ctime())) + f = fp_res + overlapaddparallel(dat, dat.shape, other_img, other_shape, + sub_shape, y=f, verbose=True, + logger=self.logger, + base_state=self.get_celery(), + state_setter=self.set_celery, + path=self.out_path, cores=cores, + memmap=self.memmap) + if self.memmap: + fp_res = np.memmap(f, dtype='float32', mode='r+', + shape=shape) else: - bordered_half = self.shape[0]//2, self.shape[1]//2 - ly, hy, lx, hx = centre[0]-bordered_half[0], centre[0]+bordered_half[0], centre[1]-bordered_half[1], centre[1]+bordered_half[1] - if hx-lx < self.shape[1]: - hx += 1 - elif hx-lx > self.shape[1]: - hx -= 1 - if hy-ly < self.shape[0]: - hy += 1 - elif hy-ly > self.shape[0]: - hy -= 1 - fp_result[ly:hy, lx:hx] += dat[:,:] - self._log('info', "Cropping convolved image down to detector size") - half = (self.base_shape[0]//2, self.base_shape[1]//2) - self._log('info', "Image Centre: {}; Image Half-size: {}".format(centre, half)) - ly, hy, lx, hx = centre[0]-half[0], centre[0]+half[0], centre[1]-half[1], centre[1]+half[1] - if hx-lx < self.base_shape[1]: - hx += 1 - elif hx-lx > self.base_shape[1]: - hx -= 1 - if hy-ly < self.base_shape[0]: - hy += 1 - elif hy-ly > self.base_shape[0]: - hy -= 1 - self._log('info', "Taking [{}:{}, {}:{}]".format(ly, hy, lx, hx)) - fp_crop = np.memmap(g, dtype='float32', mode='w+', shape=self.base_shape) - fp_crop[:,:] = fp_result[ly:hy, lx:hx] - crpix = [half[0], half[1]] - if self.wcs.sip is not None: - sip = wcs.Sip(self.wcs.sip.a, self.wcs.sip.b, None, None, crpix) - else: - sip = None - self.wcs = self._wcs(self.ra, self.dec, self.pa, self.scale, crpix=crpix, sip=sip) - del fp_result - del fp_crop - if os.path.exists(self.fname): - os.remove(self.fname) + overlapadd2(dat, dat.shape, other_img, other_shape, + sub_shape, y=fp_res, verbose=True, + logger=self.logger, + base_state=self.get_celery(), + state_setter=self.set_celery) + msg = "{}: Finished Convolution at {}" + self._log('info', msg.format(self.name, time.ctime())) + + if crop: + msg = "Cropping convolved image down to detector size" + self._log('info', msg) + half = (self.base_shape[0]//2, self.base_shape[1]//2) + msg = "Image Centre: {}; Image Half-size: {}" + self._log('info', msg.format(centre, half)) + ly, hy = centre[0]-half[0], centre[0]+half[0] + lx, hx = centre[1]-half[1], centre[1]+half[1] + if hx-lx < self.base_shape[1]: + hx += 1 + elif hx-lx > self.base_shape[1]: + hx -= 1 + if hy-ly < self.base_shape[0]: + hy += 1 + elif hy-ly > self.base_shape[0]: + hy -= 1 + msg = "Taking [{}:{}, {}:{}]" + self._log('info', msg.format(ly, hy, lx, hx)) + if self.memmap: + fp_crop = np.memmap(g, dtype='float32', mode='w+', + shape=tuple(self.base_shape)) + else: + fp_crop = np.zeros(tuple(self.base_shape), dtype='float32') + fp_crop[:,:] = fp_res[ly:hy, lx:hx] + crpix = [half[0], half[1]] + if self.wcs.sip is not None: + sip = wcs.Sip(self.wcs.sip.a, self.wcs.sip.b, None, None, + crpix) + else: + sip = None + self.wcs = self._wcs(self.ra, self.dec, self.pa, self.scale, + crpix=crpix, sip=sip) + del fp_res + if self.memmap: + del fp_crop + if os.path.exists(self.fname): + os.remove(self.fname) + self.fname = g + else: + del self.fname + self.fname = fp_crop + self.shape = self.base_shape if os.path.exists(f): os.remove(f) - self.fname = g - self.shape = self.base_shape except Exception as e: if os.path.exists(f): os.remove(f) @@ -650,13 +930,21 @@ def rotate(self,angle,reshape=False): self.pa = (self.pa + angle)%360.%360. f = os.path.join(self.out_path, uuid.uuid4().hex+"_rotate.tmp") try: - fp_result = np.memmap(f, dtype='float32', mode='w+', shape=self.shape) - with ImageData(self.fname, self.shape, mode='r') as dat: + t_shape = tuple(self.shape) + if self.memmap: + fp_result = np.memmap(f, dtype='float32', mode='w+', shape=t_shape) + else: + fp_result = np.zeros(t_shape, dtype='float32') + with ImageData(self.fname, self.shape, mode='r', memmap=self.memmap) as dat: rotate(dat, angle, order=5, reshape=reshape, output=fp_result) - del fp_result - if os.path.exists(self.fname): - os.remove(self.fname) - self.fname = f + if self.memmap: + del fp_result + if os.path.exists(self.fname): + os.remove(self.fname) + self.fname = f + else: + del self.fname + self.fname = fp_result except Exception as e: if os.path.exists(f): os.remove(f) @@ -738,7 +1026,7 @@ def _addArrayWithOffset(self, other, offset_x, offset_y): #If any of these are false, the images are disjoint. if low_x < self.xsize and low_y < self.ysize and high_x > 0 and high_y > 0: - with ImageData(self.fname, self.shape, mode='r+') as dat: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: d = dat[low_y:high_y, low_x:high_x] d = other[ly:hy, lx:hx] dat[low_y:high_y, low_x:high_x] += other[ly:hy, lx:hx] @@ -756,7 +1044,7 @@ def _addWithOffset(self, other, offset_x, offset_y): #If any of these are false, the images are disjoint. if low_x < self.xsize and low_y < self.ysize and high_x > 0 and high_y > 0: - with ImageData(self.fname, self.shape, mode='r+') as dat, ImageData(other.fname, other.shape, mode='r') as other_data: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat, ImageData(other.fname, other.shape, mode='r', memmap=other.memmap) as other_data: dat[low_y:high_y, low_x:high_x] += other_data[ly:hy, lx:hx] else: self.addHistory("Added image is disjoint") @@ -794,18 +1082,25 @@ def rescale(self,scale): shape_y = int(round(self.shape[0] * self.scale[1] / scale[1])) new_shape = (shape_y, shape_x) self._log("info","New shape will be {}".format(new_shape)) - fp_result = np.memmap(f, dtype='float32', mode='w+', shape=new_shape) - with ImageData(self.fname, self.shape, mode='r+') as dat: + if self.memmap: + fp_result = np.memmap(f, dtype='float32', mode='w+', shape=new_shape) + else: + fp_result = np.zeros(new_shape, dtype='float32') + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: flux = dat.sum() self._log("info", "Max flux is {}, sum is {}".format(np.max(dat), flux)) zoom(dat, (np.array(self.scale)/np.array(scale)), fp_result) factor = flux / fp_result.sum() fp_result *= factor self._log("info", "Max flux is {}, sum is {}".format(np.max(fp_result), np.sum(fp_result))) - del fp_result - if os.path.exists(self.fname): - os.remove(self.fname) - self.fname = f + if self.memmap: + del fp_result + if os.path.exists(self.fname): + os.remove(self.fname) + self.fname = f + else: + del self.fname + self.fname = fp_result self.shape = new_shape except Exception as e: if os.path.exists(f): @@ -835,14 +1130,18 @@ def bin(self,binx,biny=None): f = os.path.join(self.out_path, uuid.uuid4().hex+"_bin.tmp") try: shape_x, shape_y = int(self.shape[1] // binx), int(self.shape[0] // biny) - with ImageData(self.fname, self.shape, mode='r+') as dat: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: binned = dat.reshape(shape_y, biny, shape_x, binx).sum(axis=(1, 3)) - mapped = np.memmap(f, dtype='float32', mode='w+', shape=binned.shape) - mapped[:] = binned[:] - if os.path.exists(self.fname): - os.remove(self.fname) - self.fname = f - self.shape = (shape_y, shape_x) + if self.memmap: + mapped = np.memmap(f, dtype='float32', mode='w+', shape=binned.shape) + mapped[:] = binned[:] + if os.path.exists(self.fname): + os.remove(self.fname) + self.fname = f + else: + del self.fname + self.fname = binned + self.shape = np.array((shape_y, shape_x)) except Exception as e: if os.path.exists(f): os.remove(f) @@ -858,7 +1157,7 @@ def setExptime(self, exptime): Set the exposure time. Multiply data by new_exptime / old_exptime. """ factor = exptime / self.exptime - with ImageData(self.fname, self.shape, mode='r+') as dat: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: dat *= factor self.exptime = exptime self.updateHeader('exptime', self.exptime) @@ -870,7 +1169,7 @@ def addBackground(self, background): per_pixel_background = background / (self.oversample*self.oversample) self.addHistory("Added background of {} counts/s/detector pixel ({} counts/s/oversampled pixel)".format(background, per_pixel_background)) self._log("info", "Added background of {} counts/s/detector pixel ({} counts/s/oversampled pixel)".format(background, per_pixel_background)) - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: dat += per_pixel_background def introducePoissonNoise(self,absVal=False): @@ -887,11 +1186,18 @@ def introducePoissonNoise(self,absVal=False): """ a, n = os.path.join(self.out_path, uuid.uuid4().hex+"_poisson_a.tmp"), os.path.join(self.out_path, uuid.uuid4().hex+"_poisson_n.tmp") try: - with ImageData(self.fname, self.shape, mode='r+') as dat: - abs_data = np.memmap(a, dtype='float32', mode='w+', shape=self.shape) + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: + t_shape = tuple(self.shape) + if self.memmap: + abs_data = np.memmap(a, dtype='float32', mode='w+', shape=t_shape) + else: + abs_data = np.zeros(t_shape, dtype='float32') np.absolute(dat, abs_data) - noise_data = np.memmap(n, dtype='float32', mode='w+', shape=self.shape) + if self.memmap: + noise_data = np.memmap(n, dtype='float32', mode='w+', shape=t_shape) + else: + noise_data = np.zeros(t_shape, dtype='float32') noise_data[:,:] = np.random.RandomState(seed=self.seed).normal(size=self.shape) * np.sqrt(abs_data) del abs_data if absVal: @@ -924,8 +1230,12 @@ def introduceReadnoise(self,readnoise): """ n = os.path.join(self.out_path, uuid.uuid4().hex+"_readnoise.tmp") try: - with ImageData(self.fname, self.shape, mode='r+') as dat: - noise_data = np.memmap(n, dtype='float32', mode='w+', shape=self.shape) + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: + t_shape = tuple(self.shape) + if self.memmap: + noise_data = np.memmap(n, dtype='float32', mode='w+', shape=t_shape) + else: + noise_data = np.zeros(t_shape, dtype='float32') noise_data[:,:] = readnoise * np.random.RandomState(seed=self.seed).randn(self.ysize,self.xsize) mean, std = noise_data.mean(), noise_data.std() dat += noise_data @@ -949,7 +1259,7 @@ def introduceFlatfieldResidual(self,flat): returns: mean, std. Mean and standard deviation of used portion of error array. """ - with ImageData(self.fname, self.shape, mode='r+') as dat, ImageData(flat.fname, flat.shape, mode='r') as flat_data: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat, ImageData(flat.fname, flat.shape, mode='r', memmap=flat.memmap) as flat_data: err = flat_data[:self.ysize,:self.xsize] mean, std = err.mean(), err.std() dat *= err @@ -964,7 +1274,7 @@ def introduceDarkResidual(self,dark): returns: mean,std: mean and standard deviation of dark error array. """ - with ImageData(self.fname, self.shape, mode='r+') as dat, ImageData(dark.fname, dark.shape, mode='r') as dark_data: + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat, ImageData(dark.fname, dark.shape, mode='r', memmap=dark.memmap) as dark_data: err = dark_data[:self.ysize,:self.xsize] mean, std = err.mean(), err.std() dat += err @@ -987,9 +1297,13 @@ def introduceCosmicRayResidual(self, pixel_size): n = os.path.join(self.out_path, uuid.uuid4().hex+"_cosmic.tmp") try: - with ImageData(self.fname, self.shape, mode='r+') as dat: - noise_data = np.memmap(n, dtype='float32', mode='w+', shape=self.shape) - noise_data.fill(0.) + with ImageData(self.fname, self.shape, mode='r+', memmap=self.memmap) as dat: + t_shape = tuple(self.shape) + if self.memmap: + noise_data = np.memmap(n, dtype='float32', mode='w+', shape=t_shape) + noise_data.fill(0.) + else: + noise_data = np.zeros(t_shape, dtype='float32') for i in range(len(energies)): noise_data += MakeCosmicRay(self.shape[1], self.shape[0], probs[i], energies[i], cr_size, cr_psf, self.seed, verbose=False) noise_data *= 0.01 @@ -1068,20 +1382,20 @@ def _getWcs(self,**kwargs): wcs = self._normalizeWCS(wcs) else: #get pixel scale (if available) - scale = kwargs.get('scale', [1., 1.]) + scale = self._scale ra = kwargs.get('ra', 0.) offset_ra = kwargs.get('offset_ra', 0.) dec = kwargs.get('dec', 0.) offset_dec = kwargs.get('offset_dec', 0.) - ra,dec = OffsetPosition(ra,dec,offset_ra,offset_dec) + ra,dec = OffsetPosition(ra, dec, offset_ra, offset_dec) pa = kwargs.get('pa', 0.) - wcs = self._wcs(ra,dec,pa,scale,sip=sip) + wcs = self._wcs(ra, dec, pa, scale, sip=sip) return wcs - def _sip(self,da,db,dap,dbp): + def _sip(self, da, db, dap, dbp): """Create a SIP distortion model from the distortion arrays""" - crpix = [int(np.floor(self.xsize/2.)),int(np.floor(self.ysize/2.))] - sip = wcs.Sip(da,db,dap,dbp,crpix) + crpix = [self.xsize//2, self.ysize//2] + sip = wcs.Sip(da, db, dap, dbp, crpix) return sip def _wcs(self, ra, dec, pa, scale, crpix=None, sip=None, ranum=0, decnum=1): @@ -1090,10 +1404,12 @@ def _wcs(self, ra, dec, pa, scale, crpix=None, sip=None, ranum=0, decnum=1): ra is right ascension in decimal degrees dec is declination in decimal degrees - pa is the angle between north and east on the tangent plane, with the quadrant between north - and east being positive - scale is the pixel scale in arcseconds/pixel, for the (x, y) axes of the image - crpix is the (x, y) location on the image of the reference pixel (default is centre) + pa is the angle between north and east on the tangent plane, with the + quadrant between northand east being positive + scale is the pixel scale in arcseconds/pixel, for the (x, y) axes of the + image + crpix is the (x, y) location on the image of the reference pixel + (default is centre) sip is the simple imaging polynomial distortion object (if present) ranum indicates which image axis represents RA (default 0) decnum indicates which image axis represents DEC (default 1) @@ -1103,7 +1419,7 @@ def _wcs(self, ra, dec, pa, scale, crpix=None, sip=None, ranum=0, decnum=1): w.wcs.ctype[ranum] = "RA---TAN" w.wcs.ctype[decnum] = "DEC--TAN" if crpix is None: - w.wcs.crpix = [int(np.floor(self.xsize/2.)),int(np.floor(self.ysize/2.))] + w.wcs.crpix = [self.xsize//2, self.ysize//2] else: w.wcs.crpix = crpix w.wcs.crval = [0.,0.] @@ -1126,8 +1442,10 @@ def _wcs(self, ra, dec, pa, scale, crpix=None, sip=None, ranum=0, decnum=1): w.wcs.ctype[decnum] = "DEC--TAN-SIP" w.sip = sip self._scale = scale - message = "{}: (RA, DEC, PA) := ({}, {}, {}), detected as ({}, {}, {})" - self._log("info", message.format(self.name, ra, dec, pa, w.wcs.crval[ranum], w.wcs.crval[decnum], self._getPA(w, scale, decnum))) + msg = "{}: (RA, DEC, PA) := ({}, {}, {}), detected as ({}, {}, {})" + msg = msg.format(self.name, ra, dec, pa, w.wcs.crval[ranum], + w.wcs.crval[decnum], self._getPA(w, scale, decnum)) + self._log("info", msg) return w def _normalizeWCS(self,w): @@ -1178,7 +1496,7 @@ def __iadd__(self,other): """Adds an AstroImage to the current one. (i.e. the '+=' operator)""" if isinstance(other, int) or isinstance(other, float): self.addHistory("Added constant %f/pixel" % (other)) - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: dat += other return self else: @@ -1198,14 +1516,14 @@ def __radd__(self,other): .. warning:: Assumes constant value per-pixel """ self.addHistory("Added %f/pixel" % (other)) - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: dat += other return self def __mul__(self, other): """Multiples an integer or floating-point constant to the AstroImage""" result = self.copy() - with ImageData(result.fname, result.shape) as dat: + with ImageData(result.fname, result.shape, memmap=result.memmap) as dat: dat *= other result.addHistory("Multiplied by %f" % (other)) return result @@ -1213,14 +1531,14 @@ def __mul__(self, other): def __imul__(self,other): """Multiples an integer or floating-point constant to the AstroImage""" self.addHistory("Multiplied by %f" % (other)) - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: dat *= other return self def __rmul__(self,other): """Multiples an integer or floating-point constant to the AstroImage""" result = self.copy() - with ImageData(result.fname, result.shape) as dat: + with ImageData(result.fname, result.shape, memmap=result.memmap) as dat: dat *= other result.addHistory("Multiplied by %f" % (other)) return result @@ -1228,7 +1546,7 @@ def __rmul__(self,other): @property def sum(self): """Returns the sum of the flux in the current image""" - with ImageData(self.fname, self.shape) as dat: + with ImageData(self.fname, self.shape, memmap=self.memmap) as dat: sum = np.sum(dat) return sum @@ -1242,28 +1560,33 @@ def _log(self,mtype,message): sys.stderr.write("%s: %s\n" % (mtype,message)) def _init_dat(self, base_shape, psf_shape=(0,0), data=None): - if os.path.exists(self.fname): - os.remove(self.fname) self.base_shape = base_shape - self.shape = tuple(np.array(base_shape) + np.array(psf_shape)) - fp = np.memmap(self.fname, dtype='float32', mode='w+', shape=self.shape) - fp.fill(0.) + self.shape = np.array(base_shape) + np.array(psf_shape) + t_shape = tuple(self.shape) + if self.memmap: + if os.path.exists(self.fname): + os.remove(self.fname) + fp = np.memmap(self.fname, dtype='float32', mode='w+', + shape=t_shape) + fp.fill(0.) + else: + fp = np.zeros(t_shape, dtype='float32') + self.fname = fp if data is not None: - centre = tuple(np.array(self.shape)//2) - half = tuple(np.array(base_shape)//2) - fp[centre[0]-half[0]:centre[0]+self.base_shape[0]-half[0],centre[1]-half[1]:centre[1]+self.base_shape[1]-half[1]] = data - del fp + centre = self.shape//2 + cy, cx = centre + half = base_shape//2 + hy, hx = half + fp[cy-hy:cy+base_shape[0]-hy, cx-hx:cx+base_shape[1]-hx] = data + if self.memmap: + del fp def _remap(self, xs, ys): # Step 1 -- compensate for PSF adjustments - adj_x = (self.shape[1] - self.base_shape[1]) // 2 - adj_y = (self.shape[0] - self.base_shape[0]) // 2 - x_outs = xs - adj_x - y_outs = ys - adj_y - # Step 2 -- handle oversample - x_outs /= self.oversample - y_outs /= self.oversample - return x_outs, y_outs + adj = ((self.shape - self.base_shape)/2).astype(np.int32) + out_y = (ys - adj[0])/self.oversample + out_x = (xs - adj[1])/self.oversample + return out_x, out_y def updateState(self, state): if self.set_celery is not None: @@ -1273,3 +1596,23 @@ def getState(self): if self.get_celery is not None: return self.get_celery() return "" + + + INSTRUMENT_DEFAULT = { + 'telescope': 'wfirst', + 'instrument': 'WFI', + 'filter': 'F062', + 'detector': { + 'WFI': 'SCA01', + 'NIRCam': 'A1', + 'MIRI': 'MIRI' + }, + 'shape': (4096, 4096), + 'scale': [0.11,0.11], + 'zeropoint': 21.0, + 'photflam': 0., + 'photplam': 0.6700, + 'background': 0., + 'oversample': 1, + 'psf_grid_size': 1 + } diff --git a/stips/astro_image/tests/test_AstroImage.py b/stips/astro_image/tests/test_AstroImage.py index e352453..e78576c 100644 --- a/stips/astro_image/tests/test_AstroImage.py +++ b/stips/astro_image/tests/test_AstroImage.py @@ -132,7 +132,12 @@ def test_convolve(input,kernel,result): ] @pytest.mark.parametrize(("input","angle","output"), rotate_data) def test_rotate(input,angle,output): - image = AstroImage(data=input) + # If psf=True, then the image will create a default PSF, and pad its + # borders (and thus increase its size) sufficiently to include off-image + # regions of half the PSF width on each side. That will, in turn, result + # in the image sizes not matching the pre-made output arrays. Because this + # test is not related to PSFs in any way, no PSF is created. + image = AstroImage(data=input, psf=False) image.rotate(angle) verifyData(image.hdu.data,output) @@ -146,6 +151,11 @@ def test_rotate(input,angle,output): ] @pytest.mark.parametrize(("input","bin","result"), bin_data) def test_bin(input,bin,result): - im1 = AstroImage(data=input) + # If psf=True, then the image will create a default PSF, and pad its + # borders (and thus increase its size) sufficiently to include off-image + # regions of half the PSF width on each side. That will, in turn, result + # in the image sizes not matching the pre-made output arrays. Because this + # test is not related to PSFs in any way, no PSF is created. + im1 = AstroImage(data=input, psf=False) im1.bin(bin[0],bin[1]) - verifyData(im1.hdu.data,result) \ No newline at end of file + verifyData(im1.hdu.data,result) diff --git a/stips/astro_image/tests/test_wcs.py b/stips/astro_image/tests/test_wcs.py index 3dbc237..e1a5c48 100644 --- a/stips/astro_image/tests/test_wcs.py +++ b/stips/astro_image/tests/test_wcs.py @@ -1,4 +1,4 @@ -import os +import os, pytest import numpy as np @@ -6,6 +6,7 @@ from stips.astro_image import AstroImage +@pytest.mark.veryslow def test_astro_image_rotation(data_base): fits_path = os.path.join(data_base, "test", "wcs_test.fits") diff --git a/stips/instruments/instrument.py b/stips/instruments/instrument.py index d815d36..a22ccb2 100644 --- a/stips/instruments/instrument.py +++ b/stips/instruments/instrument.py @@ -73,10 +73,27 @@ def __init__(self, **kwargs): self.background_value = kwargs.get('background', 'none') self.custom_background = kwargs.get('custom_background', 0.) self.CENTRAL_OFFSET = (0., 0., 0.) - self.convolve_size = kwargs.get('convolve_size', 4096) + self.convolve_size = kwargs.get('convolve_size', 8192) self.set_celery = kwargs.get('set_celery', None) self.get_celery = kwargs.get('get_celery', None) self.use_local_cache = kwargs.get('use_local_cache', False) + self.memmap = kwargs.get('memmap', True) + + #Adjust # of detectors based on keyword: + n_detectors = int(kwargs.get('detectors', len(self.DETECTOR_OFFSETS))) + self.DETECTOR_OFFSETS = self.DETECTOR_OFFSETS[:n_detectors] + self.OFFSET_NAMES = self.OFFSET_NAMES[:n_detectors] + self.CENTRAL_OFFSET = self.N_OFFSET[n_detectors] + msg = "{} with {} detectors. Central offset {}" + self._log('info', msg.format(self.DETECTOR, n_detectors, + self.CENTRAL_OFFSET)) + + #Set oversampling + self.oversample = kwargs.get('oversample', self.OVERSAMPLE_DEFAULT) + + #Set PSF grid points + self.psf_grid_size = kwargs.get('psf_grid_size', + self.PSF_GRID_SIZE_DEFAULT) @classmethod def initFromImage(cls, image, **kwargs): @@ -133,50 +150,45 @@ def reset(self, ra, dec, pa, filter, obs_count, psf=True, detectors=True, celery self.pa = pa self.obs_count = obs_count if filter != self.filter: + if filter not in self.FILTERS: + msg = "Filter {} is not a valid {} filter" + raise ValueError(msg.format(filter, self.instrument)) self.filter = filter - if psf: - self.resetPSF() self.background = self.pixel_background self.photfnu = self.PHOTFNU[self.filter] self.photplam = self.PHOTPLAM[self.filter] + if hasattr(self, "_bp"): + del self._bp if detectors: - self.resetDetectors() - - def resetPSF(self): - pass + self.resetDetectors(psf=psf) - def resetDetectors(self): + + def resetDetectors(self, psf=True): if self.detectors is not None: del self.detectors #Create Detectors self.detectors = [] - for offset,name in zip(self.DETECTOR_OFFSETS,self.OFFSET_NAMES): + for offset, name in zip(self.DETECTOR_OFFSETS, self.OFFSET_NAMES): distortion = None if self.distortion: distortion = self.DISTORTION[name] - (delta_ra,delta_dec,delta_pa) = offset - delta_ra -= self.CENTRAL_OFFSET[0] - delta_dec -= self.CENTRAL_OFFSET[1] - delta_pa -= self.CENTRAL_OFFSET[2] - ra,dec = OffsetPosition(self.ra,self.dec,delta_ra/3600.,delta_dec/3600.) + (delta_ra, delta_dec, delta_pa) = offset + delta_ra = (delta_ra - self.CENTRAL_OFFSET[0])/3600. + delta_dec = (delta_dec - self.CENTRAL_OFFSET[1])/3600. + delta_pa = delta_pa - self.CENTRAL_OFFSET[2] + ra,dec = OffsetPosition(self.ra, self.dec, delta_ra, delta_dec) pa = (self.pa + delta_pa)%360. - hdr = {"DETECTOR":name,"FILTER":self.filter} - hist = ["Initialized %s Detector %s, filter %s" % (self.instrument,name,self.filter)] - xsize = self.DETECTOR_SIZE[0]*self.oversample - ysize = self.DETECTOR_SIZE[1]*self.oversample - scale = [self.SCALE[0]/self.oversample,self.SCALE[1]/self.oversample] - self._log("info","Creating Detector with (RA,DEC,PA) = (%f,%f,%f)" % (ra,dec,pa)) - self._log("info","Creating Detector with pixel offset ({},{})".format(delta_ra/scale[0], delta_dec/scale[1])) - detector = AstroImage(out_path=self.out_path, shape=(ysize, xsize), scale=scale, ra=ra, - dec=dec, pa=pa, exptime=1., header=hdr, history=hist, - psf_shape=self.psf.shape, zeropoint=self.zeropoint, - background=self.background, noise_floor=1./self.exptime, - photflam=self.photflam, detname=name, logger=self.logger, - oversample=self.oversample, small_subarray=self.small_subarray, - distortion=distortion, prefix=self.prefix, seed=self.seed, - set_celery=self.set_celery, get_celery=self.get_celery, - cat_type=self.cat_type) - self._log("info", "Detector created") + hdr = {"DETECTOR":name, "FILTER":self.filter} + msg = "Initialized {} Detector {} with filter {}" + hist = [msg.format(self.instrument, name, self.filter)] + msg = "Creating Detector {} with (RA,DEC,PA) = ({},{},{})" + self._log("info", msg.format(name, ra, dec, pa)) + msg = "Creating Detector {} with offset ({},{})" + self._log("info", msg.format(name, delta_ra, delta_dec)) + detector = AstroImage(parent=self, ra=ra, dec=dec, pa=pa, psf=psf, + header=hdr, history=hist, detname=name, + distortion=distortion) + self._log("info", "Detector {} created".format(name)) self.detectors.append(detector) def toFits(self,outfile): @@ -781,7 +793,7 @@ def addError(self, convolve=True, poisson=True, readnoise=True, flat=True, dark= self._log("info","Convolving with PSF") convolve_state = base_state + "
Detector {}: Convolving PSF".format(detector.name) self.updateState(convolve_state) - detector.convolve(self.psf, max=self.convolve_size-1, do_convolution=convolve, parallel=parallel, cores=cores, state_setter=self.updateState, base_state=convolve_state) + detector.convolve_psf(max_size=self.convolve_size-1, parallel=parallel, cores=cores) if 'convolve' in snapshots or 'all' in snapshots: detector.toFits(self.imgbase+"_{}_{}_snapshot_convolve.fits".format(self.obs_count, detector.name)) if self.oversample != 1: @@ -993,4 +1005,16 @@ def getState(self): if self.get_celery is not None: return self.get_celery() return "" + + # Simulation defaults + + # Default detector oversample. + # - it is not actually recommended to use this for real simulations. + # - an oversample of at least 5 is preferable, 10 is recommended. + # - that said, it's a balance between size, speed, and accuracy. + OVERSAMPLE_DEFAULT = 1 + + # Size of the side of the grid. e.g. 5 = 5X5 grid = 25 PSFs. + # - using the current default will result in a non-varying PSF. + PSF_GRID_SIZE_DEFAULT = 1 diff --git a/stips/instruments/miri.py b/stips/instruments/miri.py index 13c1e97..4b2fb24 100644 --- a/stips/instruments/miri.py +++ b/stips/instruments/miri.py @@ -38,44 +38,7 @@ def __init__(self, **kwargs): self.k = self.K[kwargs.get('miri_mods', 'fast')] self.a = self.A[kwargs.get('miri_mods', 'fast')] - def resetPSF(self): - import webbpsf - if self.filter not in self.FILTERS: - raise ValueError("Filter %s is not a valid MIRI filter" % (self.filter)) - have_psf = False - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - if os.path.exists(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("MIRI", self.filter, self.oversample))): - with pyfits.open(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("MIRI", self.filter, self.oversample))) as psf: - if psf[0].header['VERSION'] >= webbpsf.__version__ and (self.psf_commands is None or self.psf_commands == ''): - self.psf = AstroImage(data=psf[0].data, detname="MIRI {} PSF".format(self.filter), logger=self.logger) - have_psf = True - if not have_psf: - base_state = self.getState() - self.updateState(base_state+"
Generating PSF") - self._log("info", "Creating PSF") - ins = webbpsf.MIRI() - self._log("info", "Setting PSF attributes") - if self.psf_commands is not None and self.psf_commands != '': - for attribute,value in self.psf_commands.iteritems(): - self._log("info", "Setting PSF attribute {} to {}".format(attribute, value)) - setattr(ins,attribute,value) - self._log("info", "Setting PSF filter to '{}'".format(self.filter)) - ins.filter = self.filter - max_safe_size = int(np.floor(30. * self.PHOTPLAM[self.filter] / (2. * self.SCALE[0]))) - max_ins_size = max(self.DETECTOR_SIZE) * self.oversample - max_conv_size = int(np.floor(self.convolve_size / (2*self.oversample))) - psf_size = min(max_safe_size, max_ins_size, max_conv_size) - self._log("info", "PSF choosing between {}, {}, and {}, chose {}".format(max_safe_size, max_ins_size, max_conv_size, psf_size)) - if hasattr(ins, 'calc_psf'): - ins.calcPSF = ins.calc_psf - psf = ins.calcPSF(oversample=self.oversample, fov_pixels=psf_size, normalize='last') - psf[0].header['VERSION'] = webbpsf.__version__ - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - dest = os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("MIRI", self.filter, self.oversample)) - pyfits.writeto(dest, psf[0].data, header=psf[0].header, overwrite=True) - self.psf = AstroImage(data=psf[0].data,detname="MIRI %s PSF" % (self.filter),logger=self.logger) - self.updateState(base_state) - + def generateReadnoise(self): """ Readnoise formula that is similar to JWST ETC. @@ -151,6 +114,12 @@ def handleDithers(cls,form): } FILTERS = ('F560W','F770W','F1000W','F1130W','F1280W','F1500W','F1800W','F2100W','F2550W') DEFAULT_FILTER = 'F1000W' + + # Simulation Values + OVERSAMPLE_DEFAULT = 1 # by default sample at detector size. + PSF_GRID_SIZE_DEFAULT = 5 # 5X5 grid = 25 PSFs + PSF_INSTRUMENT = "MIRI" + FLATFILE = 'err_flat_test.fits' DARKFILE = 'err_rdrk_miri_im.fits' # ETC short BACKGROUND = { 'none': { 'F560W': 0., 'F770W': 0., 'F1000W' :0., 'F1130W': 0., 'F1280W': 0., 'F1500W': 0., 'F1800W':0., diff --git a/stips/instruments/nircam.py b/stips/instruments/nircam.py index 53bd7dd..e3604dd 100644 --- a/stips/instruments/nircam.py +++ b/stips/instruments/nircam.py @@ -19,67 +19,7 @@ class NIRCamBase(JwstInstrument): Base class with data common to NIRCam in all instances. Uses NIRCamShort dither patterns. """ - def __init__(self, **kwargs): - """ - Init does the following: - - super().__init__() - - looks for self.oversample (if present) - - finds filter and verifies that it is a valid filter - - finds target co-ordinates and PA (if present) - - looks for input and output paths - - sets flat residual reference file - - sets dark residual reference file - - determines instrument/filter zeropoint - - creates PSF - - creates detectors and specifies their relative offset - """ - self.classname = self.__class__.__name__ - #Initialize superclass - super(NIRCamBase,self).__init__(**kwargs) - #Set oversampling - self.oversample = kwargs.get('oversample', 1) - - #Adjust # of detectors based on keyword: - n_detectors = int(kwargs.get('detectors', len(self.DETECTOR_OFFSETS))) - self.DETECTOR_OFFSETS = self.DETECTOR_OFFSETS[:n_detectors] - self.OFFSET_NAMES = self.OFFSET_NAMES[:n_detectors] - self.CENTRAL_OFFSET = self.N_OFFSET[n_detectors] - self._log('info', "{} with {} detectors. Central offset {}".format(self.DETECTOR, n_detectors, self.CENTRAL_OFFSET)) - - def resetPSF(self): - import webbpsf - if self.filter not in self.FILTERS: - raise ValueError("Filter %s is not a valid %s filter" % (self.filter,self.classname)) - have_psf = False - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - if os.path.exists(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("NIRCam", self.filter, self.oversample))): - with pyfits.open(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("NIRCam", self.filter, self.oversample))) as psf: - if psf[0].header['VERSION'] >= webbpsf.__version__ and (self.psf_commands is None or self.psf_commands == ''): - self.psf = AstroImage(data=psf[0].data, detname="NIRCam {} PSF".format(self.filter), logger=self.logger) - have_psf = True - if not have_psf: - base_state = self.getState() - self.updateState(base_state+"
Generating PSF") - ins = webbpsf.NIRCam() - if self.psf_commands is not None and self.psf_commands != '': - for attribute,value in self.psf_commands.iteritems(): - setattr(ins,attribute,value) - ins.filter = self.filter - max_safe_size = int(np.floor(30. * self.PHOTPLAM[self.filter] / (2. * self.SCALE[0]))) - max_ins_size = max(self.DETECTOR_SIZE) * self.oversample - max_conv_size = int(np.floor(2048 / self.oversample)) - self._log("info", "PSF choosing between {}, {}, and {}".format(max_safe_size, max_ins_size, max_conv_size)) - if hasattr(ins, 'calc_psf'): - ins.calcPSF = ins.calc_psf - psf = ins.calcPSF(oversample=self.oversample, fov_pixels=min(max_safe_size, max_ins_size, max_conv_size), normalize='last') - psf[0].header['VERSION'] = webbpsf.__version__ - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - dest = os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("NIRCam", self.filter, self.oversample)) - pyfits.writeto(dest, psf[0].data, header=psf[0].header, overwrite=True) - self.psf = AstroImage(data=psf[0].data,detname="NIRCam %s PSF" % (self.filter),logger=self.logger) - self.updateState(base_state) - def generateReadnoise(self): """ Readnoise formula that is similar to JWST ETC. @@ -126,6 +66,12 @@ def handleDithers(cls,form): # Offsets are in (arcseconds_ra,arcseconds_dec,degrees_angle) DETECTOR_SIZE = (2040,2040) #pixels PIXEL_SIZE = 18.0 #um + + # Simulation Values + OVERSAMPLE_DEFAULT = 1 # by default sample at detector size. + PSF_GRID_SIZE_DEFAULT = 5 # 5X5 grid = 25 PSFs + PSF_INSTRUMENT = "NIRCam" + DISTORTION = { 'A1': { 'DIST_A': [[0., 0., 0., 0., 0., 0.], [3.11450E-2, 0., 0., 0., 0., 0.], diff --git a/stips/instruments/wfc3ir.py b/stips/instruments/wfc3ir.py index 3a63529..85085b2 100644 --- a/stips/instruments/wfc3ir.py +++ b/stips/instruments/wfc3ir.py @@ -31,12 +31,14 @@ def __init__(self, **kwargs): #For WFC3 at the moment, there is no way to oversample the PSF. Thus no oversample. # self.oversample = kwargs.get('oversample', 1) + def resetPSF(self): if self.filter not in self.FILTERS: raise ValueError("Filter %s is not a valid WFC3IR filter" % (self.filter)) psf_path = GetStipsData(os.path.join('psf_data', 'PSF_WFC3IR_{}.fits'.format(self.filter))) self.psf = AstroImage.initDataFromFits(psf_path, detname="WFC3IRPSF", logger=self.logger) + @property def bandpass(self): import pysynphot as ps diff --git a/stips/instruments/wfi.py b/stips/instruments/wfi.py index 9b226e7..a9f6811 100644 --- a/stips/instruments/wfi.py +++ b/stips/instruments/wfi.py @@ -15,6 +15,8 @@ from .wfirst_instrument import WfirstInstrument from ..utilities import OffsetPosition +from stips.version import __version__ as stips_version + class WFI(WfirstInstrument): __classtype__ = "detector" """ @@ -32,60 +34,15 @@ def __init__(self, **kwargs): """ self.classname = self.__class__.__name__ #Initialize superclass - super(WFI,self).__init__(**kwargs) + super(WFI, self).__init__(**kwargs) #Set oversampling - self.oversample = kwargs.get('oversample', 1) - - def resetPSF(self): - import webbpsf - if self.filter not in self.FILTERS: - raise ValueError("Filter %s is not a valid WFI filter" % (self.filter)) - have_psf = False - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - if os.path.exists(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("WFI", self.filter, self.oversample))): - with pyfits.open(os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("WFI", self.filter, self.oversample))) as psf: - if psf[0].header['VERSION'] >= webbpsf.__version__ and (self.psf_commands is None or self.psf_commands == ''): - psf_scale = [psf[0].header["PIXELSCL"], psf[0].header["PIXELSCL"]] - self.psf = AstroImage(data=psf[0].data, - scale=psf_scale, - ra=self.ra, - dec=self.dec, - pa=self.pa, - detname="WFI {} PSF".format(self.filter), - logger=self.logger) - have_psf = True - if not have_psf: - base_state = self.getState() - self.updateState(base_state+"
Generating PSF") - from webbpsf import wfirst - ins = wfirst.WFI() - if self.psf_commands is not None and self.psf_commands != '': - for attribute,value in self.psf_commands.iteritems(): - setattr(ins,attribute,value) - ins.filter = self.filter - max_safe_size = int(np.floor(30. * self.PHOTPLAM[self.filter] / (2. * self.SCALE[0]))) - max_ins_size = max(self.DETECTOR_SIZE) * self.oversample - max_conv_size = int(np.floor(2048 / self.oversample)) - self._log("info", "PSF choosing between {}, {}, and {}".format(max_safe_size, max_ins_size, max_conv_size)) - if hasattr(ins, 'calc_psf'): - ins.calcPSF = ins.calc_psf - psf = ins.calcPSF(oversample=self.oversample, fov_pixels=min(max_safe_size, max_ins_size, max_conv_size), normalize='last') - self._log("info", "PSF Total Flux: {}".format(np.sum(psf[0].data))) - psf[0].header['VERSION'] = webbpsf.__version__ - if os.path.exists(os.path.join(self.out_path, "psf_cache")): - dest = os.path.join(self.out_path, "psf_cache", "psf_{}_{}_{}.fits".format("WFI", self.filter, self.oversample)) - pyfits.writeto(dest, psf[0].data, header=psf[0].header, overwrite=True) - psf_scale = [psf[0].header["PIXELSCL"], psf[0].header["PIXELSCL"]] - self.psf = AstroImage(data=psf[0].data, - scale=psf_scale, - ra=self.ra, - dec=self.dec, - pa=self.pa, - detname="WFI %s PSF" % (self.filter), - logger=self.logger) - self.updateState(base_state) + self.oversample = kwargs.get('oversample', self.OVERSAMPLE_DEFAULT) + #Set PSF grid points + self.grid_size = kwargs.get('grid_size', self.PSF_GRID_SIZE_DEFAULT) + + def generateReadnoise(self): """ Readnoise formula that is similar to HST ETC. @@ -114,31 +71,39 @@ def handleDithers(cls,form): INSTRUMENT = "WFI" DETECTOR = "WFI" + # Offsets are in (arcseconds_ra,arcseconds_dec,degrees_angle) - DETECTOR_OFFSETS = ((0.,0.,0.),) #There will be 18, but simulate as single - OFFSET_NAMES = (("WFIRST-WFI"),) - # N_DETECTORS is a set of options on how many of the instrument's detectors you want to use + DETECTOR_OFFSETS = ((0.,0.,0.), (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), + (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), + (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), + (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), (0.,0.,0.), + (0.,0.,0.), (0.,0.,0.)) + OFFSET_NAMES = ("SCA01", "SCA02", "SCA03", "SCA04", "SCA05", "SCA06", + "SCA07", "SCA08", "SCA09", "SCA10", "SCA11", "SCA12", + "SCA13", "SCA14", "SCA15", "SCA16", "SCA17", "SCA18") + N_OFFSET = {1: (0., 0., 0.), 2: (0., 0., 0.), 8: (0., 0., 0.), + 4: (0., 0., 0.), 5: (0., 0., 0.), 6: (0., 0., 0.), + 7: (0., 0., 0.), 8: (0., 0., 0.), 9: (0., 0., 0.), + 10: (0., 0., 0.), 11: (0., 0., 0.), 12: (0., 0., 0.), + 13: (0., 0., 0.), 14: (0., 0., 0.), 15: (0., 0., 0.), + 16: (0., 0., 0.), 17: (0., 0., 0.), 18: (0., 0., 0.)} + + # N_DETECTORS is a set of options on how many of the instrument's detectors you want to use N_DETECTORS = [1] INSTRUMENT_OFFSET = (0.,0.,0.) #Presumably there is one, but not determined DETECTOR_SIZE = (4096,4096) #pixels PIXEL_SIZE = 18.0 #um (Assume for now) SCALE = [0.11,0.11] #Assume for now - DIST_A = [[ 0., 0., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]] - DIST_B = [[ 0., 0., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]] - DIST_AP = [[ 0., 0., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]] - DIST_BP = [[ 0., 0., 0.], - [ 0., 0., 0.], - [ 0., 0., 0.]] FILTERS = ('F062', 'F087', 'F106', 'F129', 'F158', 'F184', 'F146', 'F149') #W149 needs to go away at some point. DEFAULT_FILTER = 'F184' #Assume for now + + PSF_INSTRUMENT = "WFI" + + # Reference Files FLATFILE = 'err_flat_wfi.fits' #Use for the moment DARKFILE = 'err_rdrk_wfi.fits' # IREF, IHB (use for the moment) + + # Background Values BACKGROUND = { 'none': {'F062': 0., 'F087': 0.,'F106': 0.,'F129': 0.,'F158': 0.,'F184': 0., 'F146': 0., 'F149': 0.}, 'avg': {'F062': 1.401E+00, 'F087': 1.401E+00, 'F106': 1.401E+00, 'F129': 7.000E-01, 'F158': 7.521E-01, 'F184': 8.500E-01, 'F146': 7.000E-01, 'F149': 7.000E-01} diff --git a/stips/observation_module/observation_module.py b/stips/observation_module/observation_module.py index c749894..02ec42d 100644 --- a/stips/observation_module/observation_module.py +++ b/stips/observation_module/observation_module.py @@ -1,7 +1,7 @@ from __future__ import absolute_import # External modules -import logging, os, sys +import glob, logging, os, sys # Local modules from ..utilities import GetStipsData, InstrumentList, OffsetPosition, StipsDataTable @@ -70,7 +70,7 @@ def __init__(self, obs, **kwargs): self.cat_path = kwargs.get('cat_path', os.getcwd()) self.out_path = kwargs.get('out_path', os.getcwd()) self.cat_type = kwargs.get('cat_type', 'fits') - self.convolve_size = kwargs.get('convolve_size', 4096) + self.convolve_size = kwargs.get('convolve_size', 8192) self.parallel = kwargs.get('parallel', False) self.cores = kwargs.get('cores', None) @@ -101,6 +101,7 @@ def __init__(self, obs, **kwargs): self.poisson = kwargs.get('poisson', True) self.readnoise = kwargs.get('readnoise', True) self.version = kwargs.get('version', '0.0') + self.memmap = kwargs.get('memmap', True) self.set_celery = kwargs.get('set_celery', None) self.get_celery = kwargs.get('get_celery', None) @@ -277,16 +278,20 @@ def addError(self, *args, **kwargs): self: obj Class instance. """ - psf_name = "%s_%d_psf.fits" % (self.imgbase, self.obs_count) - self.instrument.psf.toFits(psf_name) + psf_names = [] + psf_path = os.path.join(self.out_path, "psf_cache") + if os.path.exists(psf_path): + psf_names = glob.glob(os.path.join(psf_path, "*.fits")) self._log("info","Adding Error") if 'parallel' not in kwargs: kwargs['parallel'] = self.parallel if 'cores' not in kwargs: kwargs['cores'] = self.cores - self.instrument.addError(poisson=self.poisson, readnoise=self.readnoise, flat=self.flat, dark=self.dark, cosmic=self.cosmic, *args, **kwargs) + self.instrument.addError(poisson=self.poisson, readnoise=self.readnoise, + flat=self.flat, dark=self.dark, + cosmic=self.cosmic, *args, **kwargs) self._log("info","Finished Adding Error") - return psf_name + return psf_names #----------- def finalize(self, mosaic=True, *args, **kwargs): diff --git a/stips/stellar_module/star_generator.py b/stips/stellar_module/star_generator.py index ed019bf..8b9d8d4 100644 --- a/stips/stellar_module/star_generator.py +++ b/stips/stellar_module/star_generator.py @@ -269,7 +269,8 @@ def make_cluster_rates(self,masses,instrument,filter,bandpass=None,refs=None): countrates[np.where(mags < mags_min)] = countrates_min[np.where(mags < mags_min)] countrates[np.where(mags > mags_max)] = countrates_max[np.where(mags > mags_max)] else: - self.log('warning', 'Could not find result file "result_{}_{}.npy"'.format(instrument.lower(), filter.lower())) + self.log('warning', 'Could not find result file "result_{}_{}.npy" from {}'.format(instrument.lower(), filter.lower(), self.gridpath)) +# raise FileNotFoundError('Could not find result file "result_{}_{}.npy" from {}'.format(instrument.lower(), filter.lower(), self.gridpath)) import pysynphot as ps countrates = np.array(()) ps.setref(**refs) diff --git a/stips/tests/test_wfirst.py b/stips/tests/test_wfirst.py index a8554b0..8a05320 100644 --- a/stips/tests/test_wfirst.py +++ b/stips/tests/test_wfirst.py @@ -1,56 +1,139 @@ from stips.scene_module import SceneModule from stips.observation_module import ObservationModule -def test_wfirst_observation(): - scm = SceneModule() +import pytest - stellar = { - 'n_stars': 100, - 'age_low': 1.0e12, 'age_high': 1.0e12, - 'z_low': -2.0, 'z_high': -2.0, - 'imf': 'salpeter', 'alpha': -2.35, - 'binary_fraction': 0.1, - 'distribution': 'invpow', 'clustered': True, - 'radius': 100.0, 'radius_units': 'pc', - 'distance_low': 20.0, 'distance_high': 20.0, - 'offset_ra': 0.0, 'offset_dec': 0.0 - } +def create_catalogues(): + star_data = { + 'n_stars': 100, + 'age_low': 1.0e12, 'age_high': 1.0e12, + 'z_low': -2.0, 'z_high': -2.0, + 'imf': 'salpeter', 'alpha': -2.35, + 'binary_fraction': 0.1, + 'distribution': 'invpow', 'clustered': True, + 'radius': 100.0, 'radius_units': 'pc', + 'distance_low': 10.0, 'distance_high': 10.0, + 'offset_ra': 0.0, 'offset_dec': 0.0 + } - stellar_cat_file = scm.CreatePopulation(stellar) + galaxy_data = { + 'n_gals': 25, + 'z_low': 0.0, 'z_high': 0.2, + 'rad_low': 0.01, 'rad_high': 2.0, + 'sb_v_low': 30.0, 'sb_v_high': 25.0, + 'distribution': 'uniform', 'clustered': False, + 'radius': 200.0, 'radius_units': 'arcsec', + 'offset_ra': 0.0, 'offset_dec': 0.0 + } - galaxy = { - 'n_gals': 10, - 'z_low': 0.0, 'z_high': 0.2, - 'rad_low': 0.01, 'rad_high': 2.0, - 'sb_v_low': 30.0, 'sb_v_high': 25.0, - 'distribution': 'uniform', 'clustered': False, - 'radius': 200.0, 'radius_units': 'arcsec', - 'offset_ra': 0.0, 'offset_dec': 0.0 - } - - galaxy_cat_file = scm.CreateGalaxies(galaxy) + scm = SceneModule() + stellar_cat_file = scm.CreatePopulation(star_data) + galaxy_cat_file = scm.CreateGalaxies(galaxy_data) + + return stellar_cat_file, galaxy_cat_file +def get_default_obs(): obs = { 'instrument': 'WFI', 'filters': ['F129'], 'detectors': 1, 'distortion': False, - 'oversample': 5, + 'oversample': 1, + 'psf_grid_size': 1, 'pupil_mask': '', 'background': 'avg', 'observations_id': 1, 'exptime': 1000, - 'offsets': [{'offset_id': 1, 'offset_centre': False, 'offset_ra': 0.5, 'offset_dec': 0.0, 'offset_pa': 27.0}] + 'memmap': True, + 'parallel': True, + 'convolve_size': 4096, + 'offsets': [ + { + 'offset_id': 1, + 'offset_centre': False, + 'offset_ra': 0.5, + 'offset_dec': 0.0, + 'offset_pa': 27.0 + } + ] } + return obs + + +def test_wfirst_observation(): + + stellar_cat_file, galaxy_cat_file = create_catalogues() + + obs = get_default_obs() obm = ObservationModule(obs) obm.nextObservation() - output_stellar_catalogues = obm.addCatalogue(stellar_cat_file) output_galaxy_catalogues = obm.addCatalogue(galaxy_cat_file) + psf_file = obm.addError() + fits_file, mosaic_file, params = obm.finalize(mosaic=False) + + +@pytest.mark.veryslow +def test_wfirst_observation_deluxe(): + + stellar_cat_file, galaxy_cat_file = create_catalogues() + + obs = get_default_obs() + obs['psf_grid_size'] = 3 + obs['oversample'] = 5 + obm = ObservationModule(obs) + obm.nextObservation() + output_stellar_catalogues = obm.addCatalogue(stellar_cat_file) + output_galaxy_catalogues = obm.addCatalogue(galaxy_cat_file) psf_file = obm.addError() + fits_file, mosaic_file, params = obm.finalize(mosaic=False) + + +obs_data = [ + ( + { + 'psf_grid_size': 3, + }, + ), + ( + { + 'oversample': 3, + }, + ), + ( + { + 'memmap': False, + }, + ), + ( + { + 'filters': ['F106'], + }, + ), + ( + { + 'parallel': False, + }, + ) +] + +@pytest.mark.veryslow +@pytest.mark.parametrize(("obs_changes"), obs_data) +def test_obs_parameters(obs_changes): + + stellar_cat_file, galaxy_cat_file = create_catalogues() - fits_file, mosaic_file, params = obm.finalize(mosaic=False) \ No newline at end of file + obs = get_default_obs() + for key in obs_changes[0]: + obs[key] = obs_changes[0][key] + + obm = ObservationModule(obs) + obm.nextObservation() + output_stellar_catalogues = obm.addCatalogue(stellar_cat_file) + output_galaxy_catalogues = obm.addCatalogue(galaxy_cat_file) + psf_file = obm.addError() + fits_file, mosaic_file, params = obm.finalize(mosaic=False) diff --git a/stips/utilities/testing.py b/stips/utilities/testing.py index 89a8f24..ca00ddf 100644 --- a/stips/utilities/testing.py +++ b/stips/utilities/testing.py @@ -28,7 +28,8 @@ def makeWCS(coords=[0.,0.],coord_types=["RA---TAN","DEC--TAN"],xsize=1,ysize=1,p def verifyData(dat1,dat2): - assert dat1.shape == dat2.shape + assert dat1.shape[0] == dat2.shape[0] + assert dat1.shape[1] == dat2.shape[1] # np.testing.assert_allclose(dat1,dat2,atol=1e-3) for y in range(dat1.shape[0]): for x in range(dat1.shape[1]): @@ -76,4 +77,4 @@ def verifyParameters(image,results): np.testing.assert_allclose((image.header['RA_APER']),(results['ra_aper']),atol=1e-3) np.testing.assert_allclose((image.header['DEC_APER']),(results['dec_aper']),atol=1e-3) assert image.header['NAXIS1'] == results['naxis1'] - assert image.header['NAXIS2'] == results['naxis2'] \ No newline at end of file + assert image.header['NAXIS2'] == results['naxis2'] diff --git a/stips/utilities/utilities.py b/stips/utilities/utilities.py index e252c26..bd17d90 100644 --- a/stips/utilities/utilities.py +++ b/stips/utilities/utilities.py @@ -16,6 +16,7 @@ from astropy.io import ascii from astropy.table import Table from jwst_backgrounds.jbt import background +from photutils.psf.models import GriddedPSFModel from stips.version import __version__ as __stips__version__ @@ -44,14 +45,26 @@ def __stips__version__(self): #----------- class ImageData(object): - def __init__(self, fname, shape, mode='r+'): - self.fp = np.memmap(fname, dtype='float32', mode=mode, shape=shape) + def __init__(self, fname, shape, mode='r+', memmap=True): + self.shape = tuple(shape) + if memmap: + self.fp = np.memmap(fname, dtype='float32', mode=mode, + shape=self.shape) + else: + if isinstance(fname, np.ndarray): + self.fp = fname + else: + self.fp = np.ndarray(self.shape, dtype='float32') def __enter__(self): return self.fp def __exit__(self, exc_type, exc_value, traceback): del self.fp + + @property + def data(self): + return self.fp #----------- class CachedJbtBackground(background, object): @@ -345,17 +358,32 @@ def incr(self): -def computation(arr, Hf, pos, Nfft, y, ys, adjust, lock, path): +def computation(arr, Hf, pos, Nfft, y, ys, shape, lock, path, memmap): start_y, end_y, start_x, end_x, thisend_y, thisend_x = pos - conv = adjust(ifft2(Hf * fft2(arr, Nfft))) + if isinstance(Hf, GriddedPSFModel): + # Generate PSF + y, x = np.mgrid[start_y:end_y, start_x:end_x] + y_0, x_0 = (start_y+end_y)//2, (start_x+end_x)//2 + # Make a grid that's the size of the PSF fov_pix around its centre. + ys2, xs2 = shape[0]//2, shape[1]//2 + y, x = np.mgrid[y_0-ys2:y_0+ys2, x_0-xs2:x_0+xs2] + psf = Hf.evaluate(x=x, y=y, flux=1, x_0=x_0, y_0=y_0) + fftHf = fft2(psf, Nfft) + conv = np.real(ifft2(fftHf * fft2(arr, Nfft))) + else: + conv = np.real(ifft2(Hf * fft2(arr, Nfft))) lock.acquire() - with ImageData(y, ys) as dat: + with ImageData(y, ys, memmap=memmap) as dat: dat[start_y:thisend_y, start_x:thisend_x] += (conv[:(thisend_y-start_y), :(thisend_x-start_x)]) lock.release() return "[{}, {}]".format(start_y, start_x) -def overlapaddparallel(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, logger=None, state_setter=None, base_state="", path=None, cores=None): +def overlapaddparallel(Amat, amat_shape, + Hmat, hmat_shape, + L=None, Nfft=None, y=None, verbose=False, logger=None, + state_setter=None, base_state="", path=None, cores=None, + memmap=True): """ Fast two-dimensional linear convolution via the overlap-add method. The overlap-add method is well-suited to convolving a very large array, @@ -409,10 +437,15 @@ def overlapaddparallel(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, log import multiprocessing.dummy as multiprocessing - M = np.array(Hmat.shape) - Na = np.array(Amat.shape) + if amat_shape is None: + amat_shape = Amat.shape + if hmat_shape is None: + hmat_shape = Hmat.shape - ys = (Amat.shape[0]+Hmat.shape[0]-1, Amat.shape[1]+Hmat.shape[1]-1) + M = np.array(hmat_shape) + Na = np.array(amat_shape) + + ys = (amat_shape[0] + hmat_shape[0] - 1, amat_shape[1] + hmat_shape[1] - 1) if path is None: path = os.getcwd() @@ -430,11 +463,17 @@ def overlapaddparallel(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, log if not (np.all(L > 0) and L.size == 2): raise ValueError('L must have two positive elements') if not (np.all(Nfft >= L + M - 1) and Nfft.size == 2): - raise ValueError('Nfft must have two elements >= L + M - 1 where M = Hmat.shape') - if not (Amat.ndim <= 2 and Hmat.ndim <= 2): - raise ValueError('Amat and Hmat must be 2D arrays') - - Hf = fft2(Hmat, Nfft) + msg = 'Nfft must have two elements >= L + M - 1 where M = Hmat.shape' + raise ValueError(msg) + if not (Amat.ndim <= 2): + raise ValueError('Amat must be a 2D array') + if hasattr(Hmat, 'ndim') and not (Hmat.ndim <= 2): + raise ValueError('Hmat must be a 2D array') + + if isinstance(Hmat, GriddedPSFModel): + Hf = Hmat + else: + Hf = fft2(Hmat, Nfft) pool = multiprocessing.Pool(processes=cores) m = multiprocessing.Manager() @@ -445,8 +484,6 @@ def overlapaddparallel(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, log (XDIM, YDIM) = (1, 0) adjust = lambda x: x # no adjuster - if np.isrealobj(Amat) and np.isrealobj(Hmat): # unless inputs are real - adjust = np.real # then ensure real start = [0, 0] endd = [0, 0] @@ -472,7 +509,7 @@ def closing_log(pos): pos = (start[YDIM], endd[YDIM], start[XDIM], endd[XDIM], thisend[YDIM], thisend[XDIM]) sub_arr = np.empty_like(Amat[start[YDIM]:endd[YDIM], start[XDIM]:endd[XDIM]]) sub_arr[:,:] = Amat[start[YDIM]:endd[YDIM], start[XDIM]:endd[XDIM]] - res = pool.apply_async(computation, args=(sub_arr, Hf, pos, Nfft, y, ys, adjust, lock, path), callback=closing_log) + res = pool.apply_async(computation, args=(sub_arr, Hf, pos, Nfft, y, ys, adjust, lock, path, memmap), callback=closing_log) results.append(res) start[YDIM] += L[YDIM] start[XDIM] += L[XDIM] @@ -483,7 +520,10 @@ def closing_log(pos): # logger.info("Success: {}".format(result.successful())) -def overlapadd2(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, logger=None, state_setter=None, base_state=""): +def overlapadd2(Amat, amat_shape, + Hmat, hmat_shape, + L=None, Nfft=None, y=None, verbose=False, logger=None, + state_setter=None, base_state=""): """ Fast two-dimensional linear convolution via the overlap-add method. The overlap-add method is well-suited to convolving a very large array, @@ -534,8 +574,13 @@ def overlapadd2(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, logger=Non ---------- Wikipedia is only semi-unhelpful on this topic: see "Overlap-add method". """ - M = np.array(Hmat.shape) - Na = np.array(Amat.shape) + if amat_shape is None: + amat_shape = Amat.shape + if hmat_shape is None: + hmat_shape = Hmat.shape + + M = np.array(hmat_shape) + Na = np.array(amat_shape) if y is None: y = np.zeros(M + Na - 1, dtype=Amat.dtype) @@ -556,15 +601,17 @@ def overlapadd2(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, logger=Non raise ValueError('L must have two positive elements') if not (np.all(Nfft >= L + M - 1) and Nfft.size == 2): raise ValueError('Nfft must have two elements >= L + M - 1 where M = Hmat.shape') - if not (Amat.ndim <= 2 and Hmat.ndim <= 2): - raise ValueError('Amat and Hmat must be 2D arrays') + if not (Amat.ndim <= 2): + raise ValueError('Amat must be a 2D array') + if hasattr(Hmat, 'ndim') and not (Hmat.ndim <= 2): + raise ValueError('Hmat must be a 2D array') - Hf = fft2(Hmat, Nfft) + if isinstance(Hmat, GriddedPSFModel): + Hf = Hmat + else: + Hf = fft2(Hmat, Nfft) (XDIM, YDIM) = (1, 0) - adjust = lambda x: x # no adjuster - if np.isrealobj(Amat) and np.isrealobj(Hmat): # unless inputs are real - adjust = np.real # then ensure real start = [0, 0] endd = [0, 0] total_boxes = (Na[XDIM] // L[XDIM] + 1) * (Na[YDIM] // L[YDIM] + 1) @@ -576,11 +623,27 @@ def overlapadd2(Amat, Hmat, L=None, Nfft=None, y=None, verbose=False, logger=Non if verbose and logger is not None: logger.info("Starting box {}".format(start)) if verbose and state_setter is not None: - state_setter(base_state + " {:.2f}% done".format((current_box/total_boxes)*100.)) + new_state = base_state + " {:.2f}% done" + state_setter(new_state.format((current_box/total_boxes)*100.)) endd[YDIM] = min(start[YDIM] + L[YDIM], Na[YDIM]) thisend = np.minimum(Na + M - 1, start + Nfft) - yt = adjust(ifft2(Hf * fft2(Amat[start[YDIM] : endd[YDIM], start[XDIM] : endd[XDIM]], Nfft))) - y[start[YDIM] : thisend[YDIM], start[XDIM] : thisend[XDIM]] += (yt[:(thisend[YDIM] - start[YDIM]), :(thisend[XDIM] - start[XDIM])]) + Asub = Amat[start[YDIM]:endd[YDIM], start[XDIM]:endd[XDIM]] + Af = fft2(Asub, Nfft) + + if isinstance(Hf, GriddedPSFModel): + # Generate PSF + yg, xg = np.mgrid[start[YDIM]:endd[YDIM],start[XDIM]:endd[XDIM]] + y_0 = (start[YDIM]+endd[YDIM])//2 + x_0 = (start[XDIM]+endd[XDIM])//2 + # Make a grid that's the size of the PSF fov_pix around its centre. + ys2, xs2 = hmat_shape[0]//2, hmat_shape[1]//2 + yg, xg = np.mgrid[y_0-ys2:y_0+ys2, x_0-xs2:x_0+xs2] + psf = Hf.evaluate(x=xg, y=yg, flux=1, x_0=x_0, y_0=y_0) + yt = np.real(ifft2(fft2(psf, Nfft) * Af)) + else: + yt = np.real(ifft2(Hf * Af)) + ys = yt[:(thisend[YDIM]-start[YDIM]), :(thisend[XDIM]-start[XDIM])] + y[start[YDIM]:thisend[YDIM], start[XDIM]:thisend[XDIM]] += ys[:,:] start[YDIM] += L[YDIM] current_box += 1 start[XDIM] += L[XDIM]