Skip to content

Commit

Permalink
improved slicing capabilities, added mask util functions
Browse files Browse the repository at this point in the history
  • Loading branch information
menegon committed Jan 24, 2018
1 parent eea1aa8 commit bc36a98
Showing 1 changed file with 79 additions and 12 deletions.
91 changes: 79 additions & 12 deletions rectifiedgrid/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from itertools import izip
import matplotlib.pyplot as plt
from matplotlib import colors
from matplotlib.ticker import LogFormatter

BASEMAP = False

try:
from mpl_toolkits import basemap
BASEMAP = True
Expand Down Expand Up @@ -164,21 +167,34 @@ def __array_finalize__(self, obj):
if callable(getattr(super(SubRectifiedGrid, self),
'__array_finalize__', None)):
super(SubRectifiedGrid, self).__array_finalize__(obj)

self.proj = copy.deepcopy(getattr(obj, 'proj', None))
self.gtransform = copy.deepcopy(getattr(obj, 'gtransform', None))
# self.proj = getattr(obj, 'proj', None)
# self.gtransform = getattr(obj, 'gtransform', None)
return

def copy(self, *args, **kwargs):
obj = super(SubRectifiedGrid, self).copy(*args, **kwargs)
obj.proj = copy.deepcopy(getattr(self, 'proj', None))
obj.gtransform = copy.deepcopy(getattr(self, 'gtransform', None))
return obj

def __getitem__(self, *args, **kwargs):
rslice = None
cslice = None
rstart = 0
cstart = 0
if isinstance(args[0], slice):
rslice = args[0]
cslice = slice(None, None, None)
if isinstance(args[0], tuple) and isinstance(args[0][0], slice) and isinstance(args[0][1], slice):
rslice, cslice = args[0]
if isinstance(args[0], tuple) and len(args[0]) == 2:
if isinstance(args[0], tuple) and isinstance(args[0][0], slice) and isinstance(args[0][1], slice):
rslice, cslice = args[0]
if isinstance(args[0], tuple) and isinstance(args[0][0], np.ndarray) and isinstance(args[0][1], np.ndarray):
rstart = args[0][0].min()
cstart = args[0][1].min()
obj = super(SubRectifiedGrid, self).__getitem__(*args, **kwargs)
rstart = 0
cstart = 0
if rslice is not None and rslice.start is not None:
rstart = rslice.start
if cslice is not None and cslice.start is not None:
Expand All @@ -187,6 +203,7 @@ def __getitem__(self, *args, **kwargs):
g = obj.gtransform
xmax, ymax = g * [cstart, rstart]
obj.gtransform = Affine(g.a, g.b, xmax, g.d, g.e, ymax)
# self.gtransform = Affine(g.a, g.b, xmax, g.d, g.e, ymax)
return obj

# def __add__(self, other):
Expand Down Expand Up @@ -251,7 +268,7 @@ def rasterize_features_area(self, features):

if len(features) != 1:
# create a spatialindex
print "create spatial index"
# print "create spatial index"
stream = ((i, geo.bounds, value) for i, (geo, value) in
enumerate(features))
sindex = RTreeIndex(stream)
Expand Down Expand Up @@ -329,6 +346,8 @@ def write_raster(self, filepath, dtype=None, driver='GTiff', nodata=None, compre

if dtype is None:
dtype = self.dtype
if dtype == 'float64':
dtype = 'float32'

profile = {
'count': count,
Expand Down Expand Up @@ -369,11 +388,26 @@ def write_raster(self, filepath, dtype=None, driver='GTiff', nodata=None, compre
# dst.write_mask(255 * (~self.mask).astype('uint8'))
# dst.close()

def masked_mask(self, mask, copy=False):
raster = self
if copy:
raster = self.copy()
raster[mask] = np.ma.masked
return raster

def masked_equal(self, value, copy=False):
raster = self
if copy:
raster = self.copy()
raster[:] = np.ma.masked_equal(raster, value)
raster[:] = np.ma.masked_equal(raster, value, copy=True)
return raster

def masked_not_equal(self, value, copy=False):
raster = self
if copy:
raster = self.copy()
raster[:] = np.ma.masked_not_equal(raster, value, copy=True)
# raster[raster != 3.] = np.ma.masked
return raster

def masked_values(self, value, copy=False):
Expand Down Expand Up @@ -422,6 +456,13 @@ def lognorm(self, copy=False):
raster.norm()
return raster

def replace_value(self, oldvalue, value, copy=False):
raster = self
if copy:
raster = self.copy()
raster[raster == oldvalue] = value
return raster

def log(self, copy=False):
raster = self
if copy:
Expand All @@ -446,16 +487,28 @@ def gaussian_filter(self, sigma, mode="constant", copy=False, **kwargs):
raster[:] = ndimage.gaussian_filter(raster, sigma, mode=mode, **kwargs)
return raster

def fill_underlying_data(self, fill_value=None):
def fill_underlying_data(self, fill_value=None, copy=False):
raster = self
if copy:
raster = self.copy()
self.data[:] = self.filled(fill_value)
return raster

def unmask(self, fill_value=None, copy=False):
raster = self
if copy:
raster = self.copy()
raster.data[raster.mask] = fill_value
raster.mask = False
return raster

def to_srs_like(self, rgrid, src_nodata=None, dst_nodata=None,
resampling=Resampling.bilinear):
if src_nodata is None:
src_nodata = self.fill_value
if dst_nodata is None:
dst_nodata = rgrid.fill_value
print src_nodata, dst_nodata
# print src_nodata, dst_nodata
# TODO: actually this modify the original data
self.fill_underlying_data(src_nodata)

Expand Down Expand Up @@ -527,7 +580,8 @@ def zoom(self, zoom, resampling=Resampling.bilinear):
def plot(self, cmap='Greys'):
if isinstance(cmap, str):
cmap = plt.get_cmap(cmap)
plt.imshow(self, cmap=cmap)
mapimg = plt.imshow(self, cmap=cmap)
plt.colorbar(mapimg, orientation='vertical')

def get_basemap(self, ax=None):
minx, miny, maxx, maxy = self.geollur
Expand All @@ -540,7 +594,8 @@ def get_basemap(self, ax=None):
def plotmap(self, legend=False, arcgis=False, coast=False, countries=False,
rivers=False, grid=False, gridrange=2, bluemarble=False, etopo=False,
maptype=None, cmap=None, norm=None, logcolor=False, vmin=None,
vmax=None, ax=None, basemap=None):
vmax=None, ax=None, basemap=None, ticks=None, minor_thresholds=None,
arcgisxpixels=1000):
if not BASEMAP:
raise ImportError("Cannot load mpl_toolkits module")
if maptype == 'minimal':
Expand Down Expand Up @@ -573,7 +628,7 @@ def plotmap(self, legend=False, arcgis=False, coast=False, countries=False,

if arcgis:
m.arcgisimage(service='ESRI_Imagery_World_2D',
xpixels=2000, verbose= True)
xpixels=arcgisxpixels, verbose= True)

mapimg = m.imshow(np.flipud(self), cmap=cmap, norm=norm,
vmin=vmin, vmax=vmax)
Expand All @@ -588,7 +643,11 @@ def plotmap(self, legend=False, arcgis=False, coast=False, countries=False,
m.drawparallels(np.arange(-90,90,gridrange),labels=[1,0,0,0],fontsize=10)
m.drawmeridians(np.arange(-90,90,gridrange),labels=[0,0,0,1],fontsize=10)
if legend:
plt.colorbar(mapimg, orientation='vertical', ax=ax)
if logcolor:
formatter = LogFormatter(10, labelOnlyBase=False, minor_thresholds=minor_thresholds)
cbar = plt.colorbar(mapimg, orientation='vertical', ax=ax, ticks=ticks, format=formatter)
else:
cbar = plt.colorbar(mapimg, orientation='vertical', ax=ax, ticks=ticks)

return m, mapimg

Expand All @@ -603,3 +662,11 @@ def griddata(self, x, y, z, method='nearest', copy=False):
method=method)
raster[np.isnan(raster)] = np.ma.masked
return raster

def crop(self, value=None):
if value is None:
m = ~self.mask
else:
m = self != value
return self[np.ix_(m.any(1),
m.any(0))]

0 comments on commit bc36a98

Please sign in to comment.