-
Notifications
You must be signed in to change notification settings - Fork 27
/
Copy pathdemo2d.py
92 lines (77 loc) · 3.12 KB
/
demo2d.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# -*- coding: utf-8 -*-
# Author: Xiangde Luo
# Date: 2 Sep., 2021
# Implementation of MIDeepSeg for interactive medical image segmentation and annotation.
# This file was borrowed from [GeodisTK](https://github.com/taigw/GeodisTK)
# Reference:
# [1] X. Luo and G. Wang et al. MIDeepSeg: Minimally interactive segmentation of unseen objects
# from medical images using deep learning. Medical Image Analysis, 2021. DOI:https://doi.org/10.1016/j.media.2021.102102.
# [2] Wang, Guotai, et al. "DeepIGeoS: A deep interactive geodesic framework for medical image segmentation." TPAMI, 2018.
import GeodisTK
import numpy as np
import time
from PIL import Image
import matplotlib.pyplot as plt
def geodesic_distance_2d(I, S, lamb, iter):
    """Compute a 2D geodesic distance map using raster scanning.

    Thin wrapper around ``GeodisTK.geodesic2d_raster_scan``.

    Args:
        I: input image; may have multiple channels. Expected dtype np.float32.
        S: binary seed image; non-zero pixels are used as seeds.
            Expected dtype np.uint8.
        lamb: weighting between 0.0 and 1.0.
            lamb == 0.0 gives the spatial Euclidean distance, ignoring gradients;
            lamb == 1.0 gives a gradient-only distance, ignoring spatial distance.
        iter: number of raster-scanning iterations.

    Returns:
        Whatever ``GeodisTK.geodesic2d_raster_scan`` returns for these inputs
        (the distance map).
    """
    # `iter` shadows the builtin, but it is part of the public signature;
    # alias it locally instead of renaming the parameter.
    n_passes = iter
    distance_map = GeodisTK.geodesic2d_raster_scan(I, S, lamb, n_passes)
    return distance_map
def demo_geodesic_distance2d(img, seed_pos, save_path="demo_dataset/egd_vis.png"):
    """Compute and visualise several distance maps from a single seed pixel.

    Panels drawn (left to right):
      (a) the input image with the seed marked,
      (b) geodesic distance via GeodisTK fast marching,
      (c) geodesic distance via raster scan (lamb=1.0),
      (d) Euclidean distance via raster scan (lamb=0.0),
      (e) a mixture via raster scan (lamb=0.5),
      (f) exp(-D) of the fast-marching map.
    The runtimes of (b) and (c) are printed.

    Args:
        img: image accepted by ``np.asanyarray``; converted to float32 for
            the distance computations and shown unmodified in panel (a).
        seed_pos: indexable ``(row, col)`` position of the seed pixel.
        save_path: file the figure is saved to before being shown. The
            default keeps the original hard-coded location; pass ``None``
            to skip saving.
    """
    I = np.asanyarray(img, np.float32)
    # Seed mask covers the first two (spatial) dimensions of the image.
    S = np.zeros(I.shape[:2], np.uint8)
    row, col = seed_pos[0], seed_pos[1]
    S[row, col] = 1

    # perf_counter is monotonic and high-resolution, suited to measuring intervals.
    t0 = time.perf_counter()
    D1 = GeodisTK.geodesic2d_fast_marching(I, S)
    t1 = time.perf_counter()
    D2 = geodesic_distance_2d(I, S, 1.0, 2)
    t2 = time.perf_counter()
    D3 = geodesic_distance_2d(I, S, 0.0, 2)
    D4 = geodesic_distance_2d(I, S, 0.5, 2)
    print("runtime(s) of fast marching {0:}".format(t1 - t0))
    print("runtime(s) of raster scan {0:}".format(t2 - t1))

    plt.figure(figsize=(18, 6))

    # Panel (a): the input with the seed marked; matplotlib's plot takes (x, y) = (col, row).
    plt.subplot(1, 6, 1)
    plt.imshow(img, "gray")
    plt.autoscale(False)
    plt.plot([col], [row], 'ro')
    plt.axis('off')
    plt.title('(a) input image \n with a seed point')

    panels = [
        (D1, '(b) Geodesic distance \n based on fast marching'),
        (D2, '(c) Geodesic distance \n based on raster scan'),
        (D3, '(d) Euclidean distance'),
        (D4, '(e) Mixture of Geodesic \n and Euclidean distance'),
        (np.exp(-D1), '(f) Exponential Geodesic distance'),
    ]
    for position, (data, title) in enumerate(panels, start=2):
        plt.subplot(1, 6, position)
        plt.imshow(data)
        plt.axis('off')
        plt.title(title)

    if save_path is not None:
        plt.savefig(save_path, bbox_inches='tight', dpi=500, pad_inches=0.0)
    plt.show()
def demo_geodesic_distance2d_gray_scale_image():
    """Run the distance-map demo on a cropped, standardised grayscale slice.

    Loads 'demo_dataset/pancreas.png' as an 8-bit grayscale image ('L' mode)
    and crops rows/cols 100:400. It then standardises the crop to zero mean
    and unit standard deviation and runs the demo with the seed at (121, 182).
    """
    gray = Image.open('demo_dataset/pancreas.png').convert('L')
    # Region of interest: at most 300x300 pixels starting at (100, 100).
    roi = np.array(gray)[100:400, 100:400]
    mean, std = roi.mean(), roi.std()
    standardized = (roi - mean) / std
    demo_geodesic_distance2d(standardized, [121, 182])
if __name__ == '__main__':
    # Script entry point: run the grayscale-image demo.
    demo_geodesic_distance2d_gray_scale_image()