Skip to content

Commit

Permalink
🔧 Fix dimension indexing in SoftmaxND + add unit tests for SoftmaxND …
Browse files Browse the repository at this point in the history
…and LogSoftmaxND
  • Loading branch information
jejon committed Nov 22, 2024
1 parent b0487de commit 58f5f92
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/landmarker/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
class SoftmaxND(nn.Module):
    """Softmax over the spatial dimensions of an N-D feature map.

    Normalizes each (batch, channel) slice so that its spatial elements sum
    to 1, unlike ``nn.Softmax`` which normalizes over a single dimension.

    Args:
        spatial_dims (int): 2 for ``(B, C, H, W)`` inputs; any other value is
            treated as 3 for ``(B, C, D, H, W)`` inputs.
    """

    def __init__(self, spatial_dims):
        super().__init__()
        # Last two dims for 2D inputs, last three for 3D inputs.
        self.dim = (-2, -1) if spatial_dims == 2 else (-3, -2, -1)

    def forward(self, x):
        # torch.amax accepts a tuple of dims (torch.max does not, which is why
        # the original reduced one dim at a time). Subtracting the per-slice
        # max keeps exp() numerically stable for large activations.
        max_val = torch.amax(x, dim=self.dim, keepdim=True)
        out = torch.exp(x - max_val)
        return out / torch.sum(out, dim=self.dim, keepdim=True)


Expand Down
35 changes: 35 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
ProbSpatialConfigurationNet,
SpatialConfigurationNet,
)
from landmarker.models.utils import LogSoftmaxND, SoftmaxND


def test_original_spatial_configuration_net():
Expand Down Expand Up @@ -267,3 +268,37 @@ def test_coord_conv_layer_coord_channels_range():
# check that the output values are within the range [-1, 1]
assert out.shape == torch.Size([2, 5, 64, 64])
assert (-1 <= out[:, 3:]).all() and (out[:, 3:] <= 1).all()


def test_softmax_nd():
    """SoftmaxND produces spatial probability maps that sum to one."""
    # 2D case: normalization runs over the last two (H, W) dims.
    heatmap_2d = torch.randn(1, 3, 4, 4)
    out_2d = SoftmaxND(spatial_dims=2)(heatmap_2d)
    assert out_2d.shape == heatmap_2d.shape
    assert torch.allclose(out_2d.sum(dim=(-2, -1)), torch.ones(1, 3))

    # 3D case: normalization runs over the last three (D, H, W) dims.
    heatmap_3d = torch.randn(1, 3, 4, 4, 4)
    out_3d = SoftmaxND(spatial_dims=3)(heatmap_3d)
    assert out_3d.shape == heatmap_3d.shape
    assert torch.allclose(out_3d.sum(dim=(-3, -2, -1)), torch.ones(1, 3))


def test_log_softmax_nd():
    """LogSoftmaxND outputs log-probabilities whose exp sums to one spatially."""
    # 2D case: exponentiating the output recovers a spatial distribution.
    heatmap_2d = torch.randn(1, 3, 4, 4)
    log_out_2d = LogSoftmaxND(spatial_dims=2)(heatmap_2d)
    assert log_out_2d.shape == heatmap_2d.shape
    assert torch.allclose(log_out_2d.exp().sum(dim=(-2, -1)), torch.ones(1, 3))

    # 3D case: same property over the last three dims.
    heatmap_3d = torch.randn(1, 3, 4, 4, 4)
    log_out_3d = LogSoftmaxND(spatial_dims=3)(heatmap_3d)
    assert log_out_3d.shape == heatmap_3d.shape
    assert torch.allclose(log_out_3d.exp().sum(dim=(-3, -2, -1)), torch.ones(1, 3))

0 comments on commit 58f5f92

Please sign in to comment.