-
Notifications
You must be signed in to change notification settings - Fork 420
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9398cac
commit 537c65d
Showing
4 changed files
with
380 additions
and
42 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -1204,9 +1204,9 @@ def suppfig_specialist(folder, save_fig=True): | |
|
||
il = 0 | ||
|
||
fig = plt.figure(figsize=(9, 5), dpi=100) | ||
fig = plt.figure(figsize=(9, 9), dpi=100) | ||
yratio = 9 / 5 | ||
grid = plt.GridSpec(2, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1, | ||
grid = plt.GridSpec(3, 4, figure=fig, left=0.02, right=0.96, top=0.96, bottom=0.1, | ||
wspace=0.15, hspace=0.2) | ||
|
||
titles = ["train - clean", "train - noisy", "test - noisy"] | ||
|
@@ -1265,32 +1265,46 @@ def suppfig_specialist(folder, save_fig=True): | |
ax.set_xticks(np.arange(0.5, 1.05, 0.1)) | ||
ax.set_xlim([0.5, 1.0]) | ||
|
||
transl = mtransforms.ScaledTranslation(-10 / 72, 20 / 72, fig.dpi_scale_trans) | ||
grid1 = matplotlib.gridspec.GridSpecFromSubplotSpec(2, 5, subplot_spec=grid[1:, :], wspace=0.05, | ||
hspace=0.1) | ||
|
||
kk = [2, 3, 4, 10] | ||
transl = mtransforms.ScaledTranslation(-10 / 72, 25 / 72, fig.dpi_scale_trans) | ||
|
||
kk = [2, 3, 4, 6, 10] | ||
iex = 8 | ||
ylim = [10, 310] | ||
xlim = [100, 500] | ||
ylim = [125, 512] # [0, 350] | ||
xlim = [50, 325] # [100, 500] | ||
legstr0[-1] = u"\u2013 Cellpose3 (per. + seg.)" | ||
for j, k in enumerate(kk): | ||
ax = plt.subplot(grid[1, j]) | ||
pos = ax.get_position().bounds | ||
ax.set_position([pos[0], pos[1] - 0.07, pos[2], pos[3]]) | ||
img0 = imgs_all[k][iex].squeeze() | ||
img0 *= 1.1 | ||
img0 = np.clip(img0, 0, 1) | ||
outlines_gt = utils.outlines_list(masks_all[0][iex].T.copy(), multiprocessing=False) | ||
for ii in range(2): | ||
ax = plt.subplot(grid1[ii, j]) | ||
pos = ax.get_position().bounds | ||
ax.set_position([pos[0], pos[1] - 0.07 + ii*0.03, pos[2], pos[3]]) | ||
img0 = imgs_all[k][iex].squeeze().T | ||
masks0 = masks_all[k][iex].squeeze().T | ||
img0 *= 1. | ||
img0 = np.clip(img0, 0, 1) | ||
|
||
ax.imshow(img0, cmap="gray", vmin=0, vmax=1) | ||
ax.axis("off") | ||
ax.set_ylim(ylim) | ||
ax.set_xlim(xlim) | ||
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium") | ||
ax.text(1, -0.04, f"[email protected] = {aps[k,iex,0] : 0.2f}", va="top", ha="right", | ||
transform=ax.transAxes) | ||
if j == 0: | ||
il = plot_label(ltr, il, ax, transl, fs_title) | ||
ax.text(0.02, 1.2, "Denoised test image", fontsize="large", | ||
fontstyle="italic", transform=ax.transAxes) | ||
ax.imshow(img0, cmap="gray", vmin=0, vmax=1) | ||
if ii==1: | ||
outlines = utils.outlines_list(masks0, multiprocessing=False) | ||
for o in outlines_gt: | ||
ax.plot(o[:, 0], o[:, 1], color=[0.7,0.4,1], lw=2) | ||
for o in outlines: | ||
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, ls="--") | ||
ax.axis("off") | ||
ax.set_ylim(ylim) | ||
ax.set_xlim(xlim) | ||
if ii==0: | ||
ax.set_title(legstr0[k][2:], color=cols0[k], fontsize="medium") | ||
else: | ||
ax.text(1, -0.04, f"[email protected] = {aps[k,iex,0] : 0.2f}", va="top", ha="right", | ||
transform=ax.transAxes) | ||
if j == 0 and ii==0: | ||
il = plot_label(ltr, il, ax, transl, fs_title) | ||
ax.text(0.02, 1.15, "Denoised test image", fontsize="large", | ||
fontstyle="italic", transform=ax.transAxes) | ||
|
||
print(aps.mean(axis=1)[:, [0, 5, 8]]) | ||
|
||
|
@@ -1493,9 +1507,9 @@ def fig6(folder, save_fig=True): | |
|
||
diams = [utils.diameters(lbl)[0] for lbl in lbls] | ||
|
||
gen_model = "/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039" | ||
gen_model = "oneclick_cyto3" #"/home/carsen/dm11_string/datasets_cellpose/models/per_1.00_seg_1.50_rec_0.00_poisson_blur_downsample_2024_08_20_11_46_25.557039" | ||
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, | ||
pretrained_model=gen_model) | ||
model_type=gen_model) | ||
seg_model = models.CellposeModel(gpu=True, model_type="cyto3") | ||
pscales = [1.5, 20., 1.5, 1., 5., 40., 3.] | ||
denoise.deterministic() | ||
|
@@ -1561,6 +1575,7 @@ def fig6(folder, save_fig=True): | |
legstr0 = ["", u"\u2013 noisy image", u"\u2013 original", | ||
u"\u2013 noise-specific", "\u2013 data-specific", u"-- one-click"] | ||
theight = [0, 0,4,3,2,1] | ||
cstr = ["noisy\nimage", "blurry\nimage", "bilinear\nupsampled"] | ||
for i in range(6): | ||
ctype = "cellpose test set" if i < 3 else "nuclei test set" | ||
noise_type = ["denoising", "deblurring", "upsampling"][i % 3] | ||
|
@@ -1580,7 +1595,7 @@ def fig6(folder, save_fig=True): | |
if i == 1 or i == 4: | ||
ax.text(0.5, 1.18, ctype, transform=ax.transAxes, ha="center", | ||
fontsize="large") | ||
|
||
ax.text(0.03, 0.03, cstr[i%3], transform=ax.transAxes, fontsize="small") | ||
ax.set_ylim([0, 0.72]) | ||
ax.set_xticks(np.arange(0.5, 1.05, 0.25)) | ||
ax.set_xlim([0.5, 1.0]) | ||
|
@@ -1593,9 +1608,98 @@ def fig6(folder, save_fig=True): | |
] | ||
colsj = cols0[[0, 1, -1]] | ||
|
||
ly0 = 250 | ||
generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api, | ||
titlesj, colsj, titlesi, j0=0, il=il) | ||
|
||
if save_fig: | ||
os.makedirs("figs/", exist_ok=True) | ||
fig.savefig("figs/fig6.pdf", dpi=150) | ||
|
||
def suppfig_generalist_examples(folder, save_fig=True): | ||
cols0 = np.array([[0, 0, 0], [0, 0, 0], [0, 128, 0], [180, 229, 162], | ||
[246, 198, 173], [192, 71, 29], ]) | ||
cols0 = cols0 / 255 | ||
titlesi = [ | ||
"Tissuenet", "Livecell", "Yeaz bright-field", "YeaZ phase-contrast", | ||
"Omnipose phase-contrast", "Omnipose fluorescent", "DeepBacs" | ||
] | ||
colsj = cols0[[0, 1, -1]] | ||
folders = [ | ||
"cyto2", "nuclei", "tissuenet", "livecell", "yeast_BF", "yeast_PhC", | ||
"bact_phase", "bact_fluor", "deepbacs" | ||
] | ||
diam_mean = 30. | ||
|
||
#iexs = [340, 50, 10, 5, 70, 2, 33] | ||
iexs = [305, 1071, 0, 3, 70, 9, 31] | ||
imgs, lbls = [[], [], []], [] | ||
masks = [[], [], []] | ||
for f, iex in zip(folders[2:], iexs): | ||
dat = np.load(Path(folder) / f"{f}_generalist_masks.npy", | ||
allow_pickle=True).item() | ||
img = dat["imgs"][iex].copy() | ||
img = img[:1] if img.ndim > 2 else img | ||
img = np.maximum(0, transforms.normalize99(img)) | ||
imgs[0].append(img) | ||
masks[0].append(dat["masks_pred"][iex]) | ||
lbls.append(dat["masks"][iex].astype("uint16")) | ||
|
||
diams = [utils.diameters(lbl)[0] for lbl in lbls] | ||
|
||
transl = mtransforms.ScaledTranslation(-15 / 72, 30 / 72, fig.dpi_scale_trans) | ||
gen_model = "oneclick_cyto3" | ||
model = denoise.DenoiseModel(gpu=True, nchan=1, diam_mean=diam_mean, | ||
model_type=gen_model) | ||
seg_model = models.CellposeModel(gpu=True, model_type="cyto3") | ||
|
||
fig = plt.figure(figsize=(14, 8), dpi=100) | ||
grid = plt.GridSpec(4, 14, figure=fig, left=0.02, right=0.97, top=0.97, bottom=0.03) | ||
|
||
for ii in range(2): | ||
if ii==0: | ||
titlesj = ["clean", "blurry", "deblurred (one-click)"] | ||
else: | ||
titlesj = ["clean", "downsampled", "upsampled (one-click)"] | ||
masks[1] = [] | ||
masks[2] = [] | ||
imgs[1] = [] | ||
imgs[2] = [] | ||
sigmas = [5., 3., 7., 12., 5., 5., 3.] | ||
ds = [6,4,8,8,6,6,6] | ||
denoise.deterministic() | ||
for i, img in tqdm(enumerate(imgs[0])): | ||
img0 = torch.from_numpy(img.copy()).squeeze().unsqueeze(0).unsqueeze(0) | ||
img0 = img0.float() | ||
noisy0 = denoise.add_noise(img0, poisson=0., downsample=1. if ii==1 else 0, | ||
blur=1., ds=ds[i] if ii==1 else 0, | ||
sigma0 = sigmas[i] if ii==0 else sigmas[i]/2, | ||
sigma1 = sigmas[i] if ii==0 else sigmas[i]/2, | ||
pscale=120.).numpy().squeeze() | ||
denoised0 = model.eval(noisy0, diameter=diams[i], normalize=True) | ||
|
||
imgs[1].append(noisy0) | ||
imgs[2].append(denoised0) | ||
for j in range(1, 3): | ||
masks[j].append( | ||
seg_model.eval( | ||
imgs[j][i], diameter=diams[i], channels=[0, 0], tile_overlap=0.5, | ||
flow_threshold=0.4, augment=True, bsize=224, | ||
niter=2000 if folders[i - 2] == "bact_phase" else None)[0]) | ||
api = np.array( | ||
[metrics.average_precision(lbls, masks[i])[0][:, 0] for i in range(3)]) | ||
|
||
generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api, | ||
titlesj, colsj, titlesi, j0=-1 + 2*ii, letter=True) | ||
if save_fig: | ||
os.makedirs("figs/", exist_ok=True) | ||
fig.savefig("figs/suppfig_genex.pdf", dpi=150) | ||
|
||
def generalist_restoration_panels(fig, grid, imgs, lbls, masks, diams, api, | ||
titlesj, colsj, titlesi, j0=0, ly0=250, letter=False, il=0): | ||
if letter: | ||
il = j0>0 | ||
transl = mtransforms.ScaledTranslation(-20 / 72, 15 / 72, fig.dpi_scale_trans) | ||
else: | ||
transl = mtransforms.ScaledTranslation(-20 / 72, 5 / 72, fig.dpi_scale_trans) | ||
for i in range(len(imgs[0])): | ||
ratio = diams[i] / 30. | ||
d = utils.diameters(lbls[i])[0] | ||
|
@@ -1608,20 +1712,18 @@ def fig6(folder, save_fig=True): | |
for j in range(1, 3): | ||
img = np.clip(transforms.normalize99(imgs[j][i].copy().squeeze()), 0, 1) | ||
for k in range(2): | ||
ax = plt.subplot(grid[j, 2 * i + k]) | ||
ax = plt.subplot(grid[j+j0, 2 * i + k]) | ||
pos = ax.get_position().bounds | ||
ax.set_position([ | ||
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.07, | ||
pos[0] + 0.003 * i - 0.00 * k, pos[1] - (2 - j) * 0.025 - 0.08*(j0==0), | ||
pos[2], pos[3] | ||
]) | ||
if 1: | ||
ax.imshow(img, cmap="gray", vmin=0, | ||
vmax=0.35 if j == 1 and i == 2 else 1.0) | ||
vmax=0.35 if j == 1 and i == 2 and j0==0 else 1.0) | ||
if k == 1: | ||
outlines = utils.outlines_list(masks[j][i], | ||
multiprocessing=False) | ||
#for o in outlines_gt: | ||
# ax.plot(o[:,0], o[:,1], color=[0.7,0.4,1], lw=1, ls="-") | ||
for o in outlines: | ||
ax.plot(o[:, 0], o[:, 1], color=[1, 1, 0.3], lw=1.5, | ||
ls="--") | ||
|
@@ -1638,17 +1740,19 @@ def fig6(folder, save_fig=True): | |
if k == 0 and i == 0: | ||
ax.text(-0.22, 0.5, titlesj[j], transform=ax.transAxes, va="center", | ||
rotation=90, color=colsj[j], fontsize="medium") | ||
if j == 0: | ||
if j==1: | ||
il = plot_label(ltr, il, ax, transl, fs_title) | ||
ax.text(-0.0, 1.22, "Denoising examples from other datasets", | ||
ax.text(-0.02, 1.05, "Denoising examples from other datasets", | ||
fontstyle="italic", transform=ax.transAxes, | ||
fontsize="large") | ||
if k == 0 and j == 0: | ||
ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes, | ||
fontsize="medium") | ||
if save_fig: | ||
os.makedirs("figs/", exist_ok=True) | ||
fig.savefig("figs/fig6.pdf", dpi=150) | ||
if j==1 and letter: | ||
ax.text(-0.0, 1.11, "Deblurring examples from other datasets" if j0==-1 else "Upsampling examples from other datasets", | ||
fontstyle="italic", transform=ax.transAxes, | ||
fontsize="large") | ||
il = plot_label(ltr, il, ax, transl, fs_title) | ||
#if k == 0 and (j == 0 or (j==1 and j0==0)): | ||
#ax.text(0.0, 1.05, titlesi[i], transform=ax.transAxes, | ||
# fontsize="medium") | ||
|
||
def load_seg_generalist(folder): | ||
folders = [ | ||
|
Oops, something went wrong.