Skip to content

Commit

Permalink
allow visualizers to plot predictions without ground truth (#1987)
Browse files Browse the repository at this point in the history
Co-authored-by: Adeel Hassan <[email protected]>
  • Loading branch information
AdeelH and AdeelH authored Nov 8, 2023
1 parent e882bfe commit d173194
Show file tree
Hide file tree
Showing 9 changed files with 238 additions and 101 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (Sequence, Optional)
from typing import TYPE_CHECKING, Optional, Sequence
from textwrap import wrap

import torch
Expand All @@ -7,14 +7,17 @@
from rastervision.pytorch_learner.utils import (plot_channel_groups,
channel_groups_to_imgs)

if TYPE_CHECKING:
from matplotlib.pyplot import Axes


class ClassificationVisualizer(Visualizer):
"""Plots samples from image classification Datasets."""

def plot_xyz(self,
axs: Sequence,
axs: Sequence['Axes'],
x: torch.Tensor,
y: int,
y: Optional[int] = None,
z: Optional[int] = None,
plot_title: bool = True) -> None:
channel_groups = self.get_channel_display_groups(x.shape[1])
Expand All @@ -30,46 +33,61 @@ def plot_xyz(self,
# plot label
class_names = self.class_names
class_names = ['-\n-'.join(wrap(c, width=16)) for c in class_names]
if z is None:
# just display the class name as text
class_name = class_names[y]
label_ax.text(
.5,
.5,
class_name,
ha='center',
va='center',
fontdict={
'size': 20,
'family': 'sans-serif'
})
label_ax.set_xlim((0, 1))
label_ax.set_ylim((0, 1))
label_ax.axis('off')
else:
# display predicted class probabilities as a horizontal bar plot
# legend: green = ground truth, dark-red = wrong prediction,
# light-gray = other. In case predicted class matches ground truth,
# only one bar will be green and the others will be light-gray.
class_probabilities = z.softmax(dim=-1)
class_index_pred = z.argmax(dim=-1)
if y is not None and z is None:
self.plot_gt(label_ax, class_names, y)
elif z is not None:
self.plot_pred(label_ax, class_names, z, y=y)
if plot_title:
label_ax.set_title('Prediction')

def plot_gt(self, ax: 'Axes', class_names: Sequence[str], y: torch.Tensor):
"""Display ground truth class names as text."""
class_name = class_names[y]
ax.text(
x=.5,
y=.5,
s=class_name,
ha='center',
va='center',
fontdict={
'size': 20,
'family': 'sans-serif'
})
ax.set_xlim((0, 1))
ax.set_ylim((0, 1))
ax.axis('off')

def plot_pred(self,
ax: 'Axes',
class_names: Sequence[str],
z: torch.Tensor,
y: Optional[torch.Tensor] = None):
"""Plot predictions.
Plots predicted class probabilities as a horizontal bar plot. If ground
truth, y, is provided, the bar colors represent: green = ground truth,
dark-red = wrong prediction, light-gray = other. In case predicted
class matches ground truth, only one bar will be green and the others
will be light-gray.
"""
class_probabilities = z.softmax(dim=-1)
class_index_pred = z.argmax(dim=-1)
bar_colors = ['lightgray'] * len(z)
if y is not None:
class_index_gt = y
bar_colors = ['lightgray'] * len(z)
if class_index_pred == class_index_gt:
bar_colors[class_index_pred] = 'green'
else:
bar_colors[class_index_pred] = 'darkred'
bar_colors[class_index_gt] = 'green'
label_ax.barh(
y=class_names,
width=class_probabilities,
color=bar_colors,
edgecolor='black')
label_ax.set_xlim((0, 1))
label_ax.xaxis.grid(linestyle='--', alpha=1)
label_ax.set_xlabel('Probability')
if plot_title:
label_ax.set_title('Prediction')
ax.barh(
y=class_names,
width=class_probabilities,
color=bar_colors,
edgecolor='black')
ax.set_xlim((0, 1))
ax.xaxis.grid(linestyle='--', alpha=1)
ax.set_xlabel('Probability')

def get_plot_ncols(self, **kwargs) -> int:
x = kwargs['x']
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ def get_collate_fn(self):
def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: BoxList,
y: Optional[BoxList] = None,
z: Optional[BoxList] = None,
plot_title: bool = True) -> None:
y = y if z is None else z
channel_groups = self.get_channel_display_groups(x.shape[1])
imgs = channel_groups_to_imgs(x, channel_groups)

class_names = self.class_names
class_colors = self.class_colors
if y is not None or z is not None:
y = y if z is None else z
class_names = self.class_names
class_colors = self.class_colors
imgs = [
draw_boxes(img, y, class_names, class_colors) for img in imgs
]

imgs = channel_groups_to_imgs(x, channel_groups)
imgs = [draw_boxes(img, y, class_names, class_colors) for img in imgs]
plot_channel_groups(axs, imgs, channel_groups, plot_title=plot_title)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (Sequence, Optional)
from typing import TYPE_CHECKING, Optional, Sequence
from textwrap import wrap

import torch
Expand All @@ -9,6 +9,9 @@
from rastervision.pytorch_learner.utils import (plot_channel_groups,
channel_groups_to_imgs)

if TYPE_CHECKING:
from matplotlib.pyplot import Axes


class RegressionVisualizer(Visualizer):
"""Plots samples from image regression Datasets."""
Expand All @@ -22,7 +25,7 @@ def plot_xyz(self,
channel_groups = self.get_channel_display_groups(x.shape[1])

img_axes = axs[:-1]
label_ax = axs[-1]
label_ax: 'Axes' = axs[-1]

# plot image
imgs = channel_groups_to_imgs(x, channel_groups)
Expand All @@ -32,44 +35,63 @@ def plot_xyz(self,
# plot label
class_names = self.class_names
class_names = ['-\n-'.join(wrap(c, width=8)) for c in class_names]
if z is None:
# display targets as a horizontal bar plot
bars_gt = label_ax.barh(
y=class_names, width=y, color='lightgray', edgecolor='black')
# show values on the end of bars
label_ax.bar_label(bars_gt, fmt='%.3f', padding=3)

if y is not None and z is None:
self.plot_gt(label_ax, class_names, y)
if plot_title:
label_ax.set_title('Ground truth')
else:
# display targets and predictions as a grouped horizontal bar plot
bar_thickness = 0.35
y_tick_locs = np.arange(len(class_names))
bars_gt = label_ax.barh(
elif z is not None:
self.plot_pred(label_ax, class_names, z, y=y)

def plot_gt(self, ax: 'Axes', class_names: Sequence[str], y: torch.Tensor):
"""Plot targets as a horizontal bar plot with values at the tips."""
bars_gt = ax.barh(
y=class_names, width=y, color='lightgray', edgecolor='black')
# show values on the end of bars
ax.bar_label(bars_gt, fmt='%.3f', padding=3)

ax.xaxis.grid(linestyle='--', alpha=1)
ax.set_xlabel('Value')
ax.spines['right'].set_visible(False)
ax.get_yaxis().tick_left()

def plot_pred(self,
ax: 'Axes',
class_names: Sequence[str],
z: torch.Tensor,
y: Optional[torch.Tensor] = None):
"""Plot targets and predictions as a grouped horizontal bar plot."""
# display targets and predictions as a grouped horizontal bar plot
bar_thickness = 0.35 if y is not None else 0.70
y_tick_locs = np.arange(len(class_names))
if y is not None:
bars_gt = ax.barh(
y=y_tick_locs + bar_thickness / 2,
width=y,
height=bar_thickness,
color='lightgray',
edgecolor='black',
label='true')
bars_pred = label_ax.barh(
y=y_tick_locs - bar_thickness / 2,
width=z,
height=bar_thickness,
color=plt.get_cmap('tab10')(0),
edgecolor='black',
label='pred')
# show values on the end of bars
label_ax.bar_label(bars_gt, fmt='%.3f', padding=3)
label_ax.bar_label(bars_pred, fmt='%.3f', padding=3)
ax.bar_label(bars_gt, fmt='%.3f', padding=3)

bars_pred = ax.barh(
y=y_tick_locs - bar_thickness / 2,
width=z,
height=bar_thickness,
color=plt.get_cmap('tab10')(0),
edgecolor='black',
label='pred')
# show values on the end of bars
ax.bar_label(bars_pred, fmt='%.3f', padding=3)

label_ax.set_yticks(ticks=y_tick_locs, labels=class_names)
label_ax.legend(
ncol=2, loc='lower center', bbox_to_anchor=(0.5, 1.0))
ax.set_yticks(ticks=y_tick_locs, labels=class_names)
ax.legend(ncol=2, loc='lower center', bbox_to_anchor=(0.5, 1.0))

label_ax.xaxis.grid(linestyle='--', alpha=1)
label_ax.set_xlabel('Target value')
label_ax.spines['right'].set_visible(False)
label_ax.get_yaxis().tick_left()
ax.xaxis.grid(linestyle='--', alpha=1)
ax.set_xlabel('Value')
ax.spines['right'].set_visible(False)
ax.get_yaxis().tick_left()

def get_plot_ncols(self, **kwargs) -> int:
x = kwargs['x']
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import (Sequence, Optional, Union)
from typing import TYPE_CHECKING, Optional, Sequence, Union

import torch
import numpy as np
Expand All @@ -9,26 +9,32 @@
from rastervision.pytorch_learner.utils import (
color_to_triple, plot_channel_groups, channel_groups_to_imgs)

if TYPE_CHECKING:
from matplotlib.pyplot import Axes
from matplotlib.colors import Colormap


class SemanticSegmentationVisualizer(Visualizer):
"""Plots samples from semantic segmentation Datasets."""

def plot_xyz(self,
axs: Sequence,
x: torch.Tensor,
y: Union[torch.Tensor, np.ndarray],
y: Optional[Union[torch.Tensor, np.ndarray]] = None,
z: Optional[torch.Tensor] = None,
plot_title: bool = True) -> None:
channel_groups = self.get_channel_display_groups(x.shape[1])

img_axes = axs[:len(channel_groups)]
label_ax = axs[len(channel_groups)]

# plot image
imgs = channel_groups_to_imgs(x, channel_groups)
plot_channel_groups(
img_axes, imgs, channel_groups, plot_title=plot_title)

if y is None and z is None:
return

# plot labels
class_colors = self.class_colors
colors = [
Expand All @@ -38,27 +44,17 @@ def plot_xyz(self,
colors = np.array(colors) / 255.
cmap = mcolors.ListedColormap(colors)

label_ax.imshow(
y, vmin=0, vmax=len(colors), cmap=cmap, interpolation='none')
if plot_title:
label_ax.set_title(f'Ground truth')
label_ax.set_xticks([])
label_ax.set_yticks([])
if y is not None:
label_ax: 'Axes' = axs[len(channel_groups)]
self.plot_gt(label_ax, y, num_classes=len(colors), cmap=cmap)
if plot_title:
label_ax.set_title('Ground truth')

# plot predictions
if z is not None:
pred_ax = axs[-1]
preds = z.argmax(dim=0)
pred_ax.imshow(
preds,
vmin=0,
vmax=len(colors),
cmap=cmap,
interpolation='none')
self.plot_pred(pred_ax, z, num_classes=len(colors), cmap=cmap)
if plot_title:
pred_ax.set_title(f'Predicted labels')
pred_ax.set_xticks([])
pred_ax.set_yticks([])
pred_ax.set_title('Predicted labels')

# add a legend to the rightmost subplot
class_names = self.class_names
Expand All @@ -72,11 +68,30 @@ def plot_xyz(self,
loc='center left',
bbox_to_anchor=(1., 0.5))

def plot_gt(self, ax: 'Axes', y: Union[torch.Tensor, np.ndarray],
num_classes: int, cmap: 'Colormap', **kwargs):
ax.imshow(
y,
vmin=0,
vmax=num_classes,
cmap=cmap,
interpolation='none',
**kwargs)
ax.set_xticks([])
ax.set_yticks([])

def plot_pred(self, ax: 'Axes', z: Union[torch.Tensor, np.ndarray],
num_classes: int, cmap: 'Colormap', **kwargs):
if z.ndim == 3:
z = z.argmax(dim=0)
self.plot_gt(ax, y=z, num_classes=num_classes, cmap=cmap, **kwargs)

def get_plot_ncols(self, **kwargs) -> int:
x = kwargs['x']
nb_img_channels = x.shape[1]
ncols = len(self.get_channel_display_groups(nb_img_channels)) + 1
z = kwargs.get('z')
if z is not None:
if kwargs.get('y') is not None:
ncols += 1
if kwargs.get('z') is not None:
ncols += 1
return ncols
Loading

0 comments on commit d173194

Please sign in to comment.