feat: Enhance figure processing and image dictionary creation in visualization utilities
rhoadesScholar committed Jan 14, 2025
1 parent e73cddc commit bd9a967
Showing 4 changed files with 74 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/cellmap_data/multidataset.py
@@ -81,7 +81,7 @@ def class_counts(self) -> dict[str, float]:
except AttributeError:
class_counts = {c: 0.0 for c in self.classes}
class_counts.update({c + "_bg": 0.0 for c in self.classes})
print("Gathering class counts")
print("Gathering class counts...")
for ds in tqdm(self.datasets):
for c in self.classes:
if c in ds.class_counts["totals"]:
5 changes: 3 additions & 2 deletions src/cellmap_data/transforms/targets/cellpose.py
@@ -1,5 +1,3 @@
- from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows
- from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d
import torch


@@ -15,6 +13,9 @@ class CellposeFlow:

def __init__(self, ndim: int, device: str | None = None) -> None:
UserWarning("This is still in development and may not work as expected")
+ from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows
+ from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d
+
self.ndim = ndim
if device is None:
if torch.cuda.is_available():
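Aside: the hunk above moves the cellpose imports from module scope into `CellposeFlow.__init__`. Below is a minimal sketch of that deferred-import pattern, assuming the intent is to postpone loading cellpose until an instance is actually constructed (the commit message does not state the motivation); the class body is illustrative, not the library's real implementation.

```python
import torch


class CellposeFlow:
    """Illustrative stand-in; only the import placement mirrors the diff."""

    def __init__(self, ndim: int, device: str | None = None) -> None:
        # cellpose is imported lazily, so importing the module that defines this
        # class no longer requires cellpose to be installed or importable.
        # masks_to_flows is imported to mirror the diff; unused in this toy sketch.
        from cellpose.dynamics import masks_to_flows_gpu_3d, masks_to_flows
        from cellpose.dynamics import masks_to_flows_gpu as masks_to_flows_gpu_2d

        self.ndim = ndim
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device
        # Hypothetical selection for illustration only; the real logic sits in
        # the collapsed lines of the diff.
        self._flows_fn = masks_to_flows_gpu_3d if ndim == 3 else masks_to_flows_gpu_2d
```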
8 changes: 7 additions & 1 deletion src/cellmap_data/utils/__init__.py
@@ -1,4 +1,10 @@
- from .figs import get_image_grid, get_image_dict, get_image_grid_numpy
+ from .figs import (
+ get_image_grid,
+ get_image_dict,
+ get_image_grid_numpy,
+ fig_to_image,
+ get_fig_dict,
+ )
from .dtype import torch_max_value
from .metadata import (
create_multiscale_metadata,
82 changes: 63 additions & 19 deletions src/cellmap_data/utils/figs.py
@@ -82,6 +82,17 @@ def get_image_grid(
return fig


+ def fig_to_image(fig: plt.Figure) -> np.ndarray: # type: ignore
+ with io.BytesIO() as buff:
+ fig.savefig(buff, format="raw", dpi=fig.dpi)
+ buff.seek(0)
+ data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
+ w, h = fig.canvas.get_width_height()
+ im = data.reshape((int(h), int(w), -1))
+ plt.close("all")
+ return im


def get_image_grid_numpy(
input_data: torch.Tensor,
target_data: torch.Tensor,
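For reference, here is a small usage sketch of the fig_to_image helper added above. Per the code, it rasterizes a Matplotlib figure to a uint8 RGBA array through an in-memory "raw" savefig and closes all open figures as a side effect. The toy figure below is an illustrative assumption, not part of the commit.

```python
import matplotlib.pyplot as plt
import numpy as np

from cellmap_data.utils import fig_to_image  # re-exported by the __init__.py change above

# Any Matplotlib figure works; this one is a throwaway example.
fig, ax = plt.subplots(figsize=(3, 3))
ax.imshow(np.random.rand(32, 32), cmap="gray")
ax.set_title("example")

im = fig_to_image(fig)
print(im.shape, im.dtype)  # typically (height, width, 4), uint8
```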
@@ -117,22 +128,10 @@ def get_image_grid_numpy(
clim=clim,
cmap=cmap,
)
- # fig.tight_layout(pad=0)
- # fig.canvas.draw()
- # im = np.frombuffer(fig.canvas.buffer_rgba(), dtype=np.uint8)
- # im = im.reshape(fig.canvas.get_width_height()[::-1] + (4,))
- # plt.close(fig)
- with io.BytesIO() as buff:
- fig.savefig(buff, format="raw", dpi=fig.dpi)
- buff.seek(0)
- data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
- w, h = fig.canvas.get_width_height()
- im = data.reshape((int(h), int(w), -1))
- plt.close("all")
- return im
+ return fig_to_image(fig)


- def get_image_dict(
+ def get_fig_dict(
input_data: torch.Tensor,
target_data: torch.Tensor,
outputs: torch.Tensor,
@@ -143,7 +142,7 @@ def get_image_dict(
colorbar: bool = True,
) -> dict:
"""
- Create a dictionary of images for input, target, and output data.
+ Create a dictionary of figures for input, target, and output data.
Args:
input_data (torch.Tensor): Input data.
target_data (torch.Tensor): Target data.
@@ -161,10 +160,15 @@ def get_image_dict(
batch_size = input_data.shape[0]
image_dict = {}
for c, label in enumerate(classes):
+ if colorbar:
+ grid_spec_kw = {"width_ratios": [1, 1, 1, 1, 0.2]}
+ else:
+ grid_spec_kw = {}
fig, ax = plt.subplots(
batch_size,
4 + colorbar,
- figsize=(fig_size * (4 + colorbar), fig_size * batch_size),
+ figsize=(fig_size * (4 + colorbar * 0.2), fig_size * batch_size),
+ gridspec_kw=grid_spec_kw,
)
if len(ax.shape) == 1:
ax = ax[None, :]
@@ -181,13 +185,14 @@ def get_image_dict(
im = ax[b, 3].imshow(output, clim=clim)
ax[b, 3].axis("off")
ax[b, 3].set_title(f"Pred. {label}")
- if colorbar and clim is None:
+ if colorbar:
orientation = "vertical"
location = "right"
- fig.colorbar(
+ cbar = fig.colorbar(
im, orientation=orientation, location=location, cax=ax[b, 4]
)
- ax[b, 4].aspect = 10
+ cbar.ax.set_aspect(40)
+ ax[b, 4].set_title("Intensity")
input_img = input_data[b][0].squeeze().cpu().detach().numpy()
if len(input_img.shape) == 3:
input_mid = input_img.shape[0] // 2
@@ -217,3 +222,42 @@ def get_image_dict(
fig.tight_layout()
image_dict[label] = fig
return image_dict


+ def get_image_dict(
+ input_data: torch.Tensor,
+ target_data: torch.Tensor,
+ outputs: torch.Tensor,
+ classes: Sequence[str],
+ batch_size: Optional[int] = None,
+ fig_size: int = 3,
+ clim: Optional[Sequence] = None,
+ colorbar: bool = True,
+ ) -> dict:
+ """
+ Create a dictionary of images for input, target, and output data.
+ Args:
+ input_data (torch.Tensor): Input data.
+ target_data (torch.Tensor): Target data.
+ outputs (torch.Tensor): Model outputs.
+ classes (list): List of class labels.
+ batch_size (int, optional): Number of images to display. Defaults to the length of the first axis of 'input_data'.
+ fig_size (int, optional): Size of the figure. Defaults to 3.
+ clim (tuple, optional): Color limits for the images. Defaults to be scaled by the image's intensity.
+ colorbar (bool, optional): Whether to display a colorbar for the model outputs. Defaults to True.
+ Returns:
+ image_dict (dict): Dictionary of image data.
+ """
+ # TODO: Get list of figs for the batches, instead of one fig per class
+ fig_dict = get_fig_dict(
+ input_data=input_data,
+ target_data=target_data,
+ outputs=outputs,
+ classes=classes,
+ batch_size=batch_size,
+ fig_size=fig_size,
+ clim=clim,
+ colorbar=colorbar,
+ )
+ return {k: fig_to_image(v) for k, v in fig_dict.items()}
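With these changes, get_fig_dict returns one Matplotlib figure per class and the rewritten get_image_dict converts those figures to arrays with fig_to_image. A hedged usage sketch based only on the signatures shown above; the tensor shapes, class names, and the commented-out TensorBoard call are assumptions, not part of this commit.

```python
import torch

from cellmap_data.utils import get_fig_dict, get_image_dict

# Assumed layout: (batch, channel, H, W), with one target/output channel per class.
inputs = torch.rand(2, 1, 64, 64)
targets = torch.rand(2, 2, 64, 64)
outputs = torch.rand(2, 2, 64, 64)
classes = ["class_a", "class_b"]  # placeholder labels

figs = get_fig_dict(inputs, targets, outputs, classes)      # {label: matplotlib Figure}
images = get_image_dict(inputs, targets, outputs, classes)  # {label: uint8 RGBA array}

# One possible consumer, e.g. a TensorBoard SummaryWriter:
# for label, im in images.items():
#     writer.add_image(f"val/{label}", im, global_step, dataformats="HWC")
```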
