Skip to content

Commit

Permalink
STYLE: docstring and typehints for plot_dice
Browse files Browse the repository at this point in the history
  • Loading branch information
daniellepace committed Nov 14, 2023
1 parent 3967219 commit 86a9803
Showing 1 changed file with 24 additions and 1 deletion.
25 changes: 24 additions & 1 deletion ml4h/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2737,7 +2737,30 @@ def plot_precision_recalls(predictions, truth, labels, title, prefix="./figures/
plt.savefig(figure_path)
logging.info("Saved Precision Recall curve at: {}".format(figure_path))

def plot_dice(predictions, truth, labels, paths, title, prefix="./figures/", dpi=300, width=3, height=3):
def plot_dice(
predictions: Dict[str, List[np.ndarray]],
truth: np.ndarray,
labels: Dict[str, int],
paths: List[str],
title: str,
prefix: str = "./figures/",
dpi: int = 300,
width: int = 3,
height: int = 3,
) -> None:
"""
Produces boxplots of dice score distributions and .tsv files of dice scores for individual images and structures.
:param predictions: dictionary of predicted segmentations for each model, in which keys are model names and values are lists of arrays with shape (height, width, num_labels)
:param truth: ground truth segmentations, with shape (num images, height, width, num labels)
:param labels: channel map dictionary mapping label names to integer values
:param paths: paths of input hd5 files
:param title: name for the output files
:param prefix: directory that the outputs will be written to
:param dpi: dots per inch of the plot
:param width: width of the plot
:param height: height of the plot
:return: None
"""
label_names = labels.keys()
label_vals = [labels[k] for k in label_names]
batch_size = truth.shape[0]
Expand Down

0 comments on commit 86a9803

Please sign in to comment.