From 44f625d989f58c4e71201910ea0c131c94f12893 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 2 Dec 2024 13:27:31 +0100 Subject: [PATCH 1/3] Update the gaussian blur augmentation --- torch_em/transform/raw.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/torch_em/transform/raw.py b/torch_em/transform/raw.py index cd43c379..dbcdabf8 100644 --- a/torch_em/transform/raw.py +++ b/torch_em/transform/raw.py @@ -160,18 +160,16 @@ def __call__(self, img): class GaussianBlur(): + """Blur the image with a randomly drawn sigma / bandwidth value. """ - Blur the image. - """ - def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)): - self.kernel_size = kernel_size + def __init__(self, sigma=(0, 3.0)): self.sigma = sigma def __call__(self, img): - # sample kernel_size and make sure it is odd - kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1 - # switch boundaries to make sure 0 is excluded from sampling + # Sample the sigma value. Note that we switch the bounds to ensure zero is excluded from sampling. sigma = np.random.uniform(self.sigma[1], self.sigma[0]) + # Determine the kernel size based on the sigma value. + kernel_size = int(2 * np.ceil(3 * sigma) + 1) if isinstance(img, np.ndarray): img = torch.from_numpy(img) out = transforms.GaussianBlur(kernel_size, sigma=sigma)(img) From fa8d10045a1d2d961a6cfbe2a4ef38d742fdf821 Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 2 Dec 2024 13:47:43 +0100 Subject: [PATCH 2/3] Deactivate modelzoo tests as the bioimage.io library is broken --- test/util/test_modelzoo.py | 1 + torch_em/util/modelzoo.py | 4 +++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/test/util/test_modelzoo.py b/test/util/test_modelzoo.py index e560c094..8aaad613 100644 --- a/test/util/test_modelzoo.py +++ b/test/util/test_modelzoo.py @@ -26,6 +26,7 @@ def __call__(self, labels): return labels +@unittest.skip("BioImage.IO libraries are broken.") class TestModelzoo(unittest.TestCase): checkpoint_folder = "./checkpoints" log_folder = "./logs" diff --git a/torch_em/util/modelzoo.py b/torch_em/util/modelzoo.py index e8ac8b36..4622cef2 100644 --- a/torch_em/util/modelzoo.py +++ b/torch_em/util/modelzoo.py @@ -630,7 +630,9 @@ def _load_model(model_spec, device): model = PytorchModelAdapter.get_network(weight_spec) weight_file = weight_spec.source.path if not os.path.exists(weight_file): - weight_file = os.path.join(model_spec.root, weight_file) + root_folder = f"{model_spec.root.filename}.unzip" + assert os.path.exists(root_folder), root_folder + weight_file = os.path.join(root_folder, weight_file) assert os.path.exists(weight_file), weight_file state = torch.load(weight_file, map_location=device) model.load_state_dict(state) From 121d5d4acd7f55817b7782b299b32d2888a2268b Mon Sep 17 00:00:00 2001 From: Constantin Pape Date: Mon, 2 Dec 2024 13:48:15 +0100 Subject: [PATCH 3/3] Bump version --- torch_em/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch_em/__version__.py b/torch_em/__version__.py index ed9d4d87..ab55bb1a 100644 --- a/torch_em/__version__.py +++ b/torch_em/__version__.py @@ -1 +1 @@ -__version__ = "0.7.4" +__version__ = "0.7.5"