Implementing Channel-Wise Transforms #8311

Open
LemuelPuglisi opened this issue Jan 23, 2025 · 0 comments
@LemuelPuglisi

Is your feature request related to a problem? Please describe.

It is common practice to concatenate different images along the channel axis before feeding them into a model (early fusion). However, to the best of my knowledge, there is currently no straightforward way to apply a data augmentation independently to each channel. A helper class that wraps a MONAI transform and applies it along a specified axis could be useful.
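For readers less familiar with early fusion, the layout can be sketched with NumPy: co-registered 3D volumes of identical shape are stacked along a new leading channel axis, giving a channel-first C x H x W x D array (the convention MONAI uses). The `early_fuse` helper and the modality names are hypothetical, chosen only for illustration.

```python
import numpy as np


def early_fuse(volumes):
    """Stack same-shaped 3D volumes (H x W x D) into one C x H x W x D array.

    Hypothetical helper for illustration; channel order follows input order.
    """
    shapes = {v.shape for v in volumes}
    if len(shapes) != 1:
        raise ValueError(f"All volumes must share one shape, got {shapes}")
    # np.stack inserts a new axis 0, so each input volume becomes one channel.
    return np.stack(volumes, axis=0)


# Three hypothetical co-registered modalities of a 32 x 32 x 16 scan.
rng = np.random.default_rng(0)
t1, t2, flair = (rng.random((32, 32, 16), dtype=np.float32) for _ in range(3))

fused = early_fuse([t1, t2, flair])
print(fused.shape)  # (3, 32, 32, 16)
```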

Describe the solution you'd like

In my case, I had different 3D volumes concatenated along the channel axis, resulting in a shape of C x H x W x D. My solution was to create a wrapper transform as shown below.

P.S.: Apologies for not adhering to the MONAI coding guidelines; this was a quick prototype.

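```python
from monai.transforms import MapTransform, RandomizableTransform


class RandChannelWiseApply(RandomizableTransform):
    """Apply `transform_to_wrap` independently to each channel of a C x H x W x D tensor."""

    def __init__(self, transform_to_wrap, prob=1.0):
        RandomizableTransform.__init__(self, prob)
        self.transform_to_wrap = transform_to_wrap

    def __call__(self, x):
        if x.ndim != 4:
            raise ValueError(f"Input tensor must be of shape C x H x W x D, got {tuple(x.shape)}")
        # Draw once per call whether to augment, so that `prob` is actually honoured.
        self.randomize(None)
        if not self._do_transform:
            return x
        x = x.clone()  # expects a torch.Tensor / MetaTensor; avoids modifying the input
        for ch in range(x.shape[0]):
            # MONAI array transforms expect channel-first input, so pass each
            # channel as 1 x H x W x D rather than H x W x D. The wrapped
            # transform must preserve the spatial shape.
            x[ch : ch + 1] = self.transform_to_wrap(x[ch : ch + 1])
        return x


class RandChannelWiseApplyD(MapTransform, RandomizableTransform):
    """Dictionary-based wrapper around RandChannelWiseApply."""

    def __init__(self, keys, transform_to_wrap, prob=1.0, allow_missing_keys=False):
        MapTransform.__init__(self, keys, allow_missing_keys)
        # The probability check is delegated to the wrapped array transform.
        RandomizableTransform.__init__(self, 1.0)
        self.transform = RandChannelWiseApply(transform_to_wrap, prob)

    def __call__(self, data):
        d = dict(data)  # shallow copy, so the caller's dictionary is not mutated
        # key_iterator yields only the keys present in `d`; it raises a KeyError
        # for a missing key unless allow_missing_keys=True.
        for key in self.key_iterator(d):
            d[key] = self.transform(d[key])
        return d
```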

Let me know if you think implementing this in MONAI would be useful. I’d be happy to contribute to the library.
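The prototype hard-codes the channel axis to 0 and requires a 4D input, while the request mentions applying a transform along a specified axis. A framework-agnostic NumPy sketch of that generalization follows, assuming the wrapped function is channel-first and shape-preserving. `apply_along_channel`, its `axis`, `prob` and `rng` parameters, and the `random_scale` augmentation are hypothetical names for illustration, not MONAI API. Each channel is passed with a singleton channel dimension, and each channel gets its own random draw.

```python
import numpy as np


def apply_along_channel(x, fn, axis=0, prob=1.0, rng=None):
    """Apply `fn` independently to each slice of `x` along `axis`.

    Hypothetical helper: `fn` receives a channel-first array with a singleton
    leading channel (1 x spatial...) and must return an array of the same shape.
    With probability 1 - prob the input is returned unchanged.
    """
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() >= prob:
        return x
    # Move the requested axis to the front (copying, so `x` is never modified).
    moved = np.moveaxis(x, axis, 0).copy()
    for ch in range(moved.shape[0]):
        out = fn(moved[ch : ch + 1])
        if out.shape != moved[ch : ch + 1].shape:
            raise ValueError("fn must preserve the shape of each channel")
        moved[ch : ch + 1] = out
    # Restore the original axis order.
    return np.moveaxis(moved, 0, axis)


def random_scale(rng):
    """Hypothetical augmentation: multiply a channel by a random factor in [0.5, 1.5)."""
    return lambda c: c * rng.uniform(0.5, 1.5)


rng = np.random.default_rng(42)
x = np.ones((3, 4, 4, 4), dtype=np.float32)
y = apply_along_channel(x, random_scale(rng), axis=0, rng=rng)
# Each channel stays constant but is scaled by its own random factor.
print([float(y[c, 0, 0, 0]) for c in range(3)])
```

Drawing the `prob` decision once per call means either all channels or none are augmented; drawing it once per channel would be an equally reasonable design choice.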

Describe alternatives you've considered

  1. An alternative is to augment each image separately and concatenate them afterwards. However, this becomes cumbersome when the images are already saved in concatenated form, since they first have to be split back into separate volumes.

  2. It’s quite possible that MONAI already offers an alternative way to achieve this that I’m not aware of.
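For data that is already stored concatenated, alternative 1 amounts to splitting the fused array back into single-channel volumes, augmenting each one, and re-concatenating, which is essentially what a channel-wise wrapper does internally. The NumPy sketch below assumes shape-preserving augmentations; `split_augment_merge` and the flip used in the example are hypothetical illustrations.

```python
import numpy as np


def split_augment_merge(fused, augment_fns):
    """Split a C x H x W x D array into C volumes of shape 1 x H x W x D,
    apply one augmentation per channel, and concatenate them back on axis 0.

    Hypothetical helper: augment_fns must hold one callable per channel,
    and each callable must preserve its input's shape.
    """
    channels = np.split(fused, fused.shape[0], axis=0)
    if len(augment_fns) != len(channels):
        raise ValueError("Need exactly one augmentation per channel")
    augmented = [fn(c) for fn, c in zip(augment_fns, channels)]
    return np.concatenate(augmented, axis=0)


fused = np.arange(2 * 2 * 3 * 1, dtype=np.float32).reshape(2, 2, 3, 1)
# Flip channel 0 along H (axis 1 of its 1 x H x W x D volume); leave channel 1 as-is.
out = split_augment_merge(fused, [lambda c: np.flip(c, axis=1), lambda c: c])
print(out.shape)  # (2, 2, 3, 1)
```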
