
Restormer Implementation #8312

Open · wants to merge 24 commits into base: dev
Conversation

phisanti

Fixes # .

Description

This PR implements the Restormer architecture for high-resolution image restoration in MONAI following the discussion in issue #8261. The implementation supports both 2D and 3D images using MONAI's convolution as the base. Key additions include:

  • Downsample class for efficient downsampling operations
  • pixel_unshuffle operation complementing existing pixel_shuffle
  • Channel Attention Block (CABlock) with FeedForward layer
  • Multi-DConv Head Transposed Self-Attention (MDTA)
  • OverlapPatchEmbed class
  • Comprehensive unit tests for all new components

The implementation follows MONAI's coding patterns and includes performance validations against native PyTorch operations where applicable.
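For context on the `pixel_unshuffle` addition: a minimal 2D sketch of the operation (MONAI's version generalizes it to 3D; the function name and shape handling here are illustrative), checked against PyTorch's built-in `torch.nn.functional.pixel_unshuffle`:

```python
import torch
import torch.nn.functional as F

def pixel_unshuffle_2d(x: torch.Tensor, factor: int) -> torch.Tensor:
    """Rearrange (B, C, H, W) into (B, C*factor**2, H/factor, W/factor)."""
    b, c, h, w = x.shape
    if h % factor or w % factor:
        raise ValueError(f"Spatial dims {(h, w)} must be divisible by {factor}.")
    # Split each spatial dim into (size // factor, factor), then move the
    # factor axes next to the channel axis before flattening them into it.
    x = x.reshape(b, c, h // factor, factor, w // factor, factor)
    x = x.permute(0, 1, 3, 5, 2, 4)
    return x.reshape(b, c * factor**2, h // factor, w // factor)

x = torch.randn(2, 3, 8, 8)
assert torch.equal(pixel_unshuffle_2d(x, 2), F.pixel_unshuffle(x, 2))
```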

Types of changes

  • Non-breaking change (fix or new feature that would not break existing functionality).
  • Breaking change (fix or new feature that would cause existing functionality to change).
  • New tests added to cover the changes.
  • Integration tests passed locally by running ./runtests.sh -f -u --net --coverage.
  • Quick tests passed locally by running ./runtests.sh --quick --unittests --disttests.
  • In-line docstrings updated.
  • Documentation updated, tested make html command in the docs/ folder.

@ericspod (Member) left a comment:
Looks good overall but I had a few inline comments, and we should have full docstrings everywhere appropriate. For any classes meant for general-purpose use (i.e. not just by Restormer) please ensure they have docstring descriptions for the arguments (at the very least for constructor args). Thanks!

See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

Args:
    x: Input tensor

Here we should specifically state that x has shape BCHW[D].


if any(d % factor != 0 for d in input_size[2:]):
    raise ValueError(
        f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
    )
Suggested change:
-        f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
+        f"All spatial dimensions must be divisible by {factor}, spatial shape is: {input_size[2:]}"

Maybe a little shorter?
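The validation being discussed can be exercised in isolation; a minimal sketch (the helper name `check_divisible` is hypothetical, but the check itself mirrors the diff, taking spatial dims as everything after batch and channel so it covers both 2D and 3D inputs):

```python
import torch

def check_divisible(x: torch.Tensor, factor: int) -> None:
    # Spatial dims are everything after batch and channel (works for 2D and 3D).
    spatial = tuple(x.shape[2:])
    if any(d % factor != 0 for d in spatial):
        raise ValueError(f"All spatial dimensions must be divisible by {factor}, spatial shape is: {spatial}")

check_divisible(torch.empty(1, 4, 8, 8), 2)      # ok: 8 and 8 both divisible by 2
try:
    check_divisible(torch.empty(1, 4, 8, 9), 2)  # width 9 is not divisible by 2
except ValueError as e:
    print(e)
```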

kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
padding = tuple((k - 1) // 2 for k in kernel_size_)

if down_mode == "conv":

Suggested change:
-        if down_mode == "conv":
+        if down_mode == DownsampleMode.CONV:

bias=bias,
),
)
elif down_mode == "convgroup":

Suggested change:
-        elif down_mode == "convgroup":
+        elif down_mode == DownsampleMode.CONVGROUP:

if post_conv:
self.add_module("postconv", post_conv)

elif down_mode == "pixelunshuffle":

Suggested change:
-        elif down_mode == "pixelunshuffle":
+        elif down_mode == DownsampleMode.PIXELUNSHUFFLE:
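The reason these suggestions are drop-in replacements: MONAI's mode enums subclass both `str` and `Enum`, so comparing against the enum member still accepts callers who pass a plain string. A minimal sketch (the `DownsampleMode` values here are assumed from the string literals in the diff, not taken from MONAI's source):

```python
from enum import Enum

class DownsampleMode(str, Enum):
    # Hypothetical members mirroring the strings used in the PR.
    CONV = "conv"
    CONVGROUP = "convgroup"
    PIXELUNSHUFFLE = "pixelunshuffle"

# Because the enum subclasses str, equality works in both directions,
# whether callers pass a bare string or the enum member itself.
assert DownsampleMode.CONV == "conv"
assert "pixelunshuffle" == DownsampleMode.PIXELUNSHUFFLE
```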

Comment on lines +68 to +72
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>"""
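The linear-complexity claim in this docstring comes from taking attention over channels rather than spatial positions. A minimal sketch of just the transposed-attention core, following the Restormer paper (the depth-wise convolutions, projections, and learnable temperature of the actual MDTA block are omitted):

```python
import torch
import torch.nn.functional as F

def transposed_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         temperature: float = 1.0) -> torch.Tensor:
    """Attention across channels: the attention map is (C_head x C_head), so the
    cost grows linearly in the number of spatial positions N rather than as N^2."""
    # q, k, v: (B, heads, C_head, N), where N = H*W (or H*W*D) flattened positions.
    q = F.normalize(q, dim=-1)
    k = F.normalize(k, dim=-1)
    attn = (q @ k.transpose(-2, -1)) * temperature  # (B, heads, C_head, C_head)
    attn = attn.softmax(dim=-1)
    return attn @ v                                  # (B, heads, C_head, N)

b, heads, c_head, n = 2, 4, 16, 64 * 64
q, k, v = (torch.randn(b, heads, c_head, n) for _ in range(3))
out = transposed_attention(q, k, v)
assert out.shape == (b, heads, c_head, n)
```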

We should have a full docstring here describing the arguments for the constructor, and in the previous class.

Comment on lines +51 to +70
class OverlapPatchEmbed(nn.Module):
    """Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions."""

    def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_c,
            out_channels=embed_dim,
            kernel_size=3,
            strides=1,
            padding=1,
            bias=bias,
            conv_only=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
Suggested change:
-class OverlapPatchEmbed(nn.Module):
-    """Initial feature extraction using overlapped convolutions.
-    Unlike standard patch embeddings that use non-overlapping patches,
-    this approach maintains spatial continuity through 3x3 convolutions."""
-
-    def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
-        super().__init__()
-        self.proj = Convolution(
-            spatial_dims=spatial_dims,
-            in_channels=in_c,
-            out_channels=embed_dim,
-            kernel_size=3,
-            strides=1,
-            padding=1,
-            bias=bias,
-            conv_only=True,
-        )
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.proj(x)
+class OverlapPatchEmbed(Convolution):
+    """
+    Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings
+    that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.
+    """
+
+    def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
+        super().__init__(
+            spatial_dims=spatial_dims,
+            in_channels=in_c,
+            out_channels=embed_dim,
+            kernel_size=3,
+            strides=1,
+            padding=1,
+            bias=bias,
+            conv_only=True,
+        )

Would it work to inherit directly from Convolution?
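The two patterns in the suggestion above can be compared in miniature, using `nn.Conv2d` as a stand-in for MONAI's `Convolution` (the class names below are illustrative). In the inheritance version `forward` is inherited from the parent, so the wrapper method disappears entirely:

```python
import torch
import torch.nn as nn

# Wrapper pattern (what the PR currently does): hold the conv as a submodule.
class EmbedWrapped(nn.Module):
    def __init__(self, in_c: int = 3, embed_dim: int = 48):
        super().__init__()
        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)

# Inheritance pattern (the reviewer's suggestion): the class *is* the conv,
# and forward() is inherited, so no delegation boilerplate is needed.
class EmbedInherited(nn.Conv2d):
    def __init__(self, in_c: int = 3, embed_dim: int = 48):
        super().__init__(in_c, embed_dim, kernel_size=3, stride=1, padding=1, bias=False)

x = torch.randn(1, 3, 32, 32)
assert EmbedWrapped()(x).shape == EmbedInherited()(x).shape == (1, 48, 32, 32)
```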

@ericspod ericspod requested a review from aylward January 24, 2025 13:35