Skip to content

Commit

Permalink
Make max_windows a required argument for RandomWindowGeoDataset (#2223)
Browse files Browse the repository at this point in the history
* Makes max_windows a required argument and make all args other than scene keyword args.

* Updated tutorials to reflect change to RandomWindowGeoDataset.
  • Loading branch information
keves1 authored and AdeelH committed Aug 30, 2024
1 parent f3440cc commit 702eee3
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 14 deletions.
3 changes: 2 additions & 1 deletion docs/usage/tutorials/sampling_training_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,8 @@
" # resize chips to 256x256 before returning\n",
" out_size=256,\n",
" # allow windows to overflow the extent by 100 pixels\n",
" padding=100\n",
" padding=100,\n",
" max_windows=10\n",
")\n",
"\n",
"img_full = ds.scene.raster_source[:, :]\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/usage/tutorials/temporal.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@
"source": [
"scene = Scene(id='test_scene', raster_source=raster_source)\n",
"ds = SemanticSegmentationRandomWindowGeoDataset(\n",
" scene=scene, size_lims=(256, 256 + 1), out_size=256, return_window=True)"
" scene=scene, size_lims=(256, 256 + 1), out_size=256, max_windows=10, return_window=True)"
]
},
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -283,12 +283,13 @@ class RandomWindowGeoDataset(GeoDataset):
def __init__(
self,
scene: Scene,
*,
out_size: PosInt | tuple[PosInt, PosInt] | None,
size_lims: tuple[PosInt, PosInt] | None = None,
h_lims: tuple[PosInt, PosInt] | None = None,
w_lims: tuple[PosInt, PosInt] | None = None,
padding: NonNegInt | tuple[NonNegInt, NonNegInt] | None = None,
max_windows: NonNegInt | None = None,
max_windows: NonNegInt,
max_sample_attempts: PosInt = 100,
efficient_aoi_sampling: bool = True,
within_aoi: bool = True,
Expand Down Expand Up @@ -316,8 +317,7 @@ def __init__(
sides of the raster source. If ``None``, ``padding = size``.
Defaults to ``None``.
max_windows: Max allowed reads. Will raise ``StopIteration`` on
further read attempts. If None, will be set to ``np.inf``.
Defaults to ``None``.
further read attempts.
transform: Albumentations
transform to apply to the windows. Defaults to ``None``.
Each transform in Albumentations takes images of type uint8, and
Expand Down Expand Up @@ -384,9 +384,6 @@ def __init__(
padding = (max_h // 2, max_w // 2)
padding: tuple[NonNegInt, NonNegInt] = ensure_tuple(padding)

if max_windows is None:
max_windows = np.iinfo('int').max

self.size_lims = size_lims
self.h_lims = h_lims
self.w_lims = w_lims
Expand Down
27 changes: 21 additions & 6 deletions tests/pytorch_learner/dataset/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,26 +213,29 @@ def test_sample_window_within_aoi(self):

ds = RandomWindowGeoDataset(
scene,
10,
(5, 6),
out_size=10,
size_lims=(5, 6),
max_windows=10,
within_aoi=True,
transform_type=TransformType.noop,
)
self.assertNoError(ds.sample_window)

ds = RandomWindowGeoDataset(
scene,
10,
(12, 13),
out_size=10,
size_lims=(12, 13),
max_windows=10,
within_aoi=True,
transform_type=TransformType.noop,
)
self.assertRaises(StopIteration, ds.sample_window)

ds = RandomWindowGeoDataset(
scene,
10,
(12, 13),
out_size=10,
size_lims=(12, 13),
max_windows=10,
within_aoi=False,
transform_type=TransformType.noop,
)
Expand All @@ -245,6 +248,7 @@ def test_init_validation(self):
args = dict(
scene=scene,
out_size=10,
max_windows=10,
transform_type=TransformType.noop,
)
self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args))
Expand All @@ -255,6 +259,7 @@ def test_init_validation(self):
out_size=10,
size_lims=(10, 11),
h_lims=(10, 11),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args))
Expand All @@ -266,6 +271,7 @@ def test_init_validation(self):
size_lims=(10, 11),
h_lims=(10, 11),
w_lims=(10, 11),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args))
Expand All @@ -275,6 +281,7 @@ def test_init_validation(self):
scene=scene,
out_size=10,
w_lims=(10, 11),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertRaises(ValueError, lambda: RandomWindowGeoDataset(**args))
Expand All @@ -284,6 +291,7 @@ def test_init_validation(self):
scene,
out_size=None,
size_lims=(12, 13),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertFalse(ds.normalize)
Expand All @@ -295,6 +303,7 @@ def test_init_validation(self):
out_size=None,
h_lims=(10, 11),
w_lims=(10, 11),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertTupleEqual(ds.padding, (5, 5))
Expand All @@ -305,6 +314,7 @@ def test_min_max_size(self):
scene,
out_size=None,
size_lims=(10, 15),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertTupleEqual(ds.min_size, (10, 10))
Expand All @@ -315,6 +325,7 @@ def test_min_max_size(self):
out_size=None,
h_lims=(10, 15),
w_lims=(8, 12),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertTupleEqual(ds.min_size, (10, 8))
Expand All @@ -326,6 +337,7 @@ def test_sample_window_size(self):
scene,
out_size=None,
size_lims=(10, 15),
max_windows=10,
transform_type=TransformType.noop,
)
sampled_h, sampled_w = ds.sample_window_size()
Expand All @@ -337,6 +349,7 @@ def test_sample_window_size(self):
out_size=None,
h_lims=(10, 15),
w_lims=(8, 12),
max_windows=10,
transform_type=TransformType.noop,
)
sampled_h, sampled_w = ds.sample_window_size()
Expand All @@ -360,6 +373,7 @@ def test_return_window(self):
scene,
out_size=10,
size_lims=(5, 6),
max_windows=10,
transform_type=TransformType.noop,
return_window=True,
)
Expand All @@ -376,6 +390,7 @@ def test_triangle_missing(self):
scene=scene,
out_size=10,
size_lims=(5, 6),
max_windows=10,
transform_type=TransformType.noop,
)
self.assertNoError(lambda: RandomWindowGeoDataset(**args))
Expand Down

0 comments on commit 702eee3

Please sign in to comment.