From 43e31648fc9a32459b7b3aa4d2c203b3fad01eee Mon Sep 17 00:00:00 2001 From: Adeel Hassan Date: Thu, 12 Oct 2023 09:29:09 -0400 Subject: [PATCH] fix bug in visualizer when plotting temporal data w/ batch size 1 --- .../pytorch_learner/dataset/visualizer/visualizer.py | 3 ++- .../visualizer/test_semantic_segmentation_visualizer.py | 4 ++++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py index 82491106a..0866da8b2 100644 --- a/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py +++ b/rastervision_pytorch_learner/rastervision/pytorch_learner/dataset/visualizer/visualizer.py @@ -119,7 +119,8 @@ def plot_batch(self, batch_sz, T, *_ = x.shape params['fig_args']['figsize'][1] *= T fig = plt.figure(**params['fig_args']) - subfigs = fig.subfigures(nrows=batch_sz, ncols=1, hspace=0.0) + subfigs = fig.subfigures( + nrows=batch_sz, ncols=1, hspace=0.0, squeeze=False) subfig_axs = [ subfig.subplots( nrows=T, ncols=params['subplot_args']['ncols']) diff --git a/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py index e7e44272e..edd73fa81 100644 --- a/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py +++ b/tests/pytorch_learner/dataset/visualizer/test_semantic_segmentation_visualizer.py @@ -40,6 +40,8 @@ def test_plot_batch_temporal(self): x = torch.randn(size=(2, 3, 4, 256, 256)) y = (torch.randn(size=(2, 256, 256)) > 0).long() self.assertNoError(lambda: viz.plot_batch(x, y)) + # w/o z, batch size = 1 + self.assertNoError(lambda: viz.plot_batch(x[[0]], y[[0]])) # w/ z viz = SemanticSegmentationVisualizer( @@ -50,3 +52,5 @@ def test_plot_batch_temporal(self): y = (torch.randn(size=(2, 256, 256)) > 0).long() z = torch.randn(size=(2, num_classes, 256, 256)).softmax(dim=-3) self.assertNoError(lambda: viz.plot_batch(x, y, z=z)) + # w/ z, batch size = 1 + self.assertNoError(lambda: viz.plot_batch(x[[0]], y[[0]]))