Skip to content

Commit

Permalink
Merge pull request #423 from constantinpape/release-prep
Browse files Browse the repository at this point in the history
Bump to 0.7.5
  • Loading branch information
constantinpape authored Dec 2, 2024
2 parents af35ff9 + 121d5d4 commit bca7842
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 9 deletions.
1 change: 1 addition & 0 deletions test/util/test_modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __call__(self, labels):
return labels


@unittest.skip("BioImage.IO libraries are broken.")
class TestModelzoo(unittest.TestCase):
checkpoint_folder = "./checkpoints"
log_folder = "./logs"
Expand Down
2 changes: 1 addition & 1 deletion torch_em/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.4"
__version__ = "0.7.5"
12 changes: 5 additions & 7 deletions torch_em/transform/raw.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,18 +160,16 @@ def __call__(self, img):


class GaussianBlur():
"""Blur the image with a randomly drawn sigma / bandwidth value.
"""
Blur the image.
"""
def __init__(self, kernel_size=(2, 12), sigma=(0, 2.5)):
self.kernel_size = kernel_size
def __init__(self, sigma=(0, 3.0)):
self.sigma = sigma

def __call__(self, img):
# sample kernel_size and make sure it is odd
kernel_size = 2 * (np.random.randint(self.kernel_size[0], self.kernel_size[1]) // 2) + 1
# switch boundaries to make sure 0 is excluded from sampling
# Sample the sigma value. Note that we switch the bounds to ensure zero is excluded from sampling.
sigma = np.random.uniform(self.sigma[1], self.sigma[0])
# Determine the kernel size based on the sigma value.
kernel_size = int(2 * np.ceil(3 * sigma) + 1)
if isinstance(img, np.ndarray):
img = torch.from_numpy(img)
out = transforms.GaussianBlur(kernel_size, sigma=sigma)(img)
Expand Down
4 changes: 3 additions & 1 deletion torch_em/util/modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -630,7 +630,9 @@ def _load_model(model_spec, device):
model = PytorchModelAdapter.get_network(weight_spec)
weight_file = weight_spec.source.path
if not os.path.exists(weight_file):
weight_file = os.path.join(model_spec.root, weight_file)
root_folder = f"{model_spec.root.filename}.unzip"
assert os.path.exists(root_folder), root_folder
weight_file = os.path.join(root_folder, weight_file)
assert os.path.exists(weight_file), weight_file
state = torch.load(weight_file, map_location=device)
model.load_state_dict(state)
Expand Down

0 comments on commit bca7842

Please sign in to comment.