Skip to content

Commit

Permalink
🐛 Fix MapDe dist_filter Shape (#914)
Browse files Browse the repository at this point in the history
- Fix `dist_filter` in `MapDe` model for multi-class output.

Explanation:
Previously, if we set `num_class` to more than 1, the model would still output 1 channel. This was because the `dist_filter` always had size of 1 in its first dimension, however the first dimension determines the number of output channels in the tensor produced by `torch.functional.F.conv2d`.
This PR changes this by repeating the filters the match the number of output classes.
  • Loading branch information
Jiaqi-Lv authored Mar 3, 2025
1 parent 9021b57 commit 95e70fa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
9 changes: 9 additions & 0 deletions tests/models/test_arch_mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,12 @@ def test_functionality(remote_sample: Callable) -> None:
output = model.infer_batch(model, batch, device=select_device(on_gpu=ON_GPU))
output = model.postproc(output[0])
assert np.all(output[0:2] == [[19, 171], [53, 89]])


def test_multiclass_output() -> None:
"""Test the architecture for multi-class output."""
multiclass_model = MapDe(num_input_channels=3, num_classes=3)
test_input = torch.rand((1, 3, 252, 252))

output = multiclass_model(test_input)
assert output.shape == (1, 3, 252, 252)
5 changes: 4 additions & 1 deletion tiatoolbox/models/architecture/mapde.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,11 @@ def __init__(
dtype=np.float32,
)

dist_filter = np.expand_dims(dist_filter, axis=(0, 1)) # NCHW
# For conv2d, filter shape = (out_channels, in_channels//groups, H, W)
dist_filter = np.expand_dims(dist_filter, axis=(0, 1))
dist_filter = np.repeat(dist_filter, repeats=num_classes * 2, axis=1)
# Need to repeat for out_channels
dist_filter = np.repeat(dist_filter, repeats=num_classes, axis=0)

self.min_distance = min_distance
self.threshold_abs = threshold_abs
Expand Down

0 comments on commit 95e70fa

Please sign in to comment.