Restormer Implementation #8312
base: dev
Conversation
…nsample class alias
…pass ./runtests.sh -f -u --net --coverage
for more information, see https://pre-commit.ci
Looks good overall, but I had a few inline comments, and we should have full docstrings everywhere appropriate. For any classes meant for general-purpose use (i.e. not just by Restormer), please ensure they have docstring descriptions for the arguments (at the very least for constructor args). Thanks!
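As an illustration of the requested style, a constructor docstring might look like the sketch below. The class and argument names here are hypothetical, chosen only to show the shape of the docstring, not taken from the PR:

```python
class DownsampleBlock:
    """Reduce the spatial resolution of a feature map.

    Note: this class and its arguments are illustrative stand-ins,
    shown only to demonstrate the requested docstring style.

    Args:
        spatial_dims: number of spatial dimensions of the input (2 or 3).
        in_channels: number of channels in the input feature map.
        factor: integer downsampling factor applied to every spatial dimension.
        bias: whether the internal convolution uses a bias term.
    """

    def __init__(self, spatial_dims: int, in_channels: int, factor: int = 2, bias: bool = False):
        self.spatial_dims = spatial_dims
        self.in_channels = in_channels
        self.factor = factor
        self.bias = bias
```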
See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

Args:
    x: Input tensor
Here we should specifically state that x has shape BCHW[D].
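A hedged sketch of what the requested shape note could look like in the docstring (the function here is a trivial stand-in, not the PR's code):

```python
def describe_input(x):
    """Illustrative forward-style docstring for the shape convention.

    Args:
        x: input tensor with shape BCHW[D], i.e. (batch, channels, H, W)
           for 2D inputs or (batch, channels, H, W, D) for 3D inputs.
    """
    # Stand-in body; a real forward() would transform x.
    return x
```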
if any(d % factor != 0 for d in input_size[2:]):
    raise ValueError(
        f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
Suggested change:
-        f"All spatial dimensions must be divisible by factor {factor}. " f"Got spatial dimensions: {input_size[2:]}"
+        f"All spatial dimensions must be divisible by {factor}, spatial shape is: {input_size[2:]}"
Maybe a little shorter?
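The divisibility check under discussion can be exercised in isolation; here is a minimal, self-contained sketch using the reviewer's shorter message (the wrapper function name is illustrative):

```python
def check_unshuffle_divisible(input_size, factor):
    """Raise if any spatial dimension of input_size (B, C, *spatial) is not divisible by factor.

    Illustrative helper mirroring the validation in the PR's pixel-unshuffle path.
    """
    if any(d % factor != 0 for d in input_size[2:]):
        raise ValueError(
            f"All spatial dimensions must be divisible by {factor}, "
            f"spatial shape is: {input_size[2:]}"
        )

# A 64x64 input unshuffles cleanly by a factor of 2; a 63-wide input does not.
check_unshuffle_divisible((1, 3, 64, 64), 2)
```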
kernel_size_ = ensure_tuple_rep(kernel_size, spatial_dims)
padding = tuple((k - 1) // 2 for k in kernel_size_)

if down_mode == "conv":
Suggested change:
-        if down_mode == "conv":
+        if down_mode == DownsampleMode.CONV:
bias=bias,
    ),
)
elif down_mode == "convgroup":
Suggested change:
-        elif down_mode == "convgroup":
+        elif down_mode == DownsampleMode.CONVGROUP:
if post_conv:
    self.add_module("postconv", post_conv)

elif down_mode == "pixelunshuffle":
Suggested change:
-        elif down_mode == "pixelunshuffle":
+        elif down_mode == DownsampleMode.PIXELSHUFFLE:
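The suggested enum comparisons keep working against plain user-supplied strings because MONAI's mode enums are string-based. A minimal stand-in demonstrating the pattern (this DownsampleMode is a local toy, not the MONAI import, and lists only two of the modes):

```python
from enum import Enum

class DownsampleMode(str, Enum):
    # Toy stand-in mirroring a str-based mode enum; not MONAI's actual class.
    CONV = "conv"
    CONVGROUP = "convgroup"

down_mode = "conv"  # a caller may pass a plain string
# Because the enum subclasses str, equality against the string still holds.
assert down_mode == DownsampleMode.CONV
```

This is why comparing against the enum member (rather than a string literal) is a pure readability and typo-safety win: it does not change runtime behavior for string inputs.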
"""Multi-DConv Head Transposed Self-Attention (MDTA): Differs from standard self-attention
by operating on feature channels instead of spatial dimensions. Incorporates depth-wise
convolutions for local mixing before attention, achieving linear complexity vs quadratic
in vanilla attention. Based on SW Zamir, et al., 2022 <https://arxiv.org/abs/2111.09881>"""
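To make the "attention over channels" idea concrete: the attention matrix in MDTA is C x C (channel-to-channel) rather than N x N over spatial positions, so its cost scales with channel count instead of quadratically with resolution. A rough numpy sketch, with illustrative shapes and names (not the PR's implementation, which also applies depth-wise convolutions and multiple heads):

```python
import numpy as np

def channel_attention(q, k, v, temperature=1.0):
    """Transposed self-attention over channels.

    q, k, v: arrays of shape (C, N), where N is the flattened spatial size H*W[*D].
    Returns an array of shape (C, N).
    """
    # L2-normalize each channel's spatial vector, as MDTA does before the dot product.
    qn = q / (np.linalg.norm(q, axis=1, keepdims=True) + 1e-8)
    kn = k / (np.linalg.norm(k, axis=1, keepdims=True) + 1e-8)
    attn = qn @ kn.T * temperature                 # (C, C): channel-to-channel affinities
    attn = np.exp(attn - attn.max(axis=1, keepdims=True))
    attn = attn / attn.sum(axis=1, keepdims=True)  # softmax across channels
    return attn @ v                                # (C, N): reweighted channel features

C, N = 8, 64 * 64
rng = np.random.default_rng(0)
out = channel_attention(rng.random((C, N)), rng.random((C, N)), rng.random((C, N)))
```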
We should have a full docstring here describing the arguments for the constructor, and in the previous class.
class OverlapPatchEmbed(nn.Module):
    """Initial feature extraction using overlapped convolutions.
    Unlike standard patch embeddings that use non-overlapping patches,
    this approach maintains spatial continuity through 3x3 convolutions."""

    def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
        super().__init__()
        self.proj = Convolution(
            spatial_dims=spatial_dims,
            in_channels=in_c,
            out_channels=embed_dim,
            kernel_size=3,
            strides=1,
            padding=1,
            bias=bias,
            conv_only=True,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
Suggested change:
- class OverlapPatchEmbed(nn.Module):
-     """Initial feature extraction using overlapped convolutions.
-     Unlike standard patch embeddings that use non-overlapping patches,
-     this approach maintains spatial continuity through 3x3 convolutions."""
-     def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
-         super().__init__()
-         self.proj = Convolution(
-             spatial_dims=spatial_dims,
-             in_channels=in_c,
-             out_channels=embed_dim,
-             kernel_size=3,
-             strides=1,
-             padding=1,
-             bias=bias,
-             conv_only=True,
-         )
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         return self.proj(x)
+ class OverlapPatchEmbed(Convolution):
+     """
+     Initial feature extraction using overlapped convolutions. Unlike standard patch embeddings
+     that use non-overlapping patches, this approach maintains spatial continuity through 3x3 convolutions.
+     """
+     def __init__(self, spatial_dims: int, in_c: int = 3, embed_dim: int = 48, bias: bool = False):
+         super().__init__(
+             spatial_dims=spatial_dims,
+             in_channels=in_c,
+             out_channels=embed_dim,
+             kernel_size=3,
+             strides=1,
+             padding=1,
+             bias=bias,
+             conv_only=True,
+         )
Would it work to inherit directly from Convolution?
Fixes # .
Description
This PR implements the Restormer architecture for high-resolution image restoration in MONAI following the discussion in issue #8261. The implementation supports both 2D and 3D images using MONAI's convolution as the base. Key additions include:
The implementation follows MONAI's coding patterns and includes performance validations against native PyTorch operations where applicable.
Types of changes
./runtests.sh -f -u --net --coverage
./runtests.sh --quick --unittests --disttests
Documentation tested with the make html command in the docs/ folder.