diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index d2585daf63..ff313b7135 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -2199,6 +2199,12 @@ Utility (Dict) :members: :special-members: __call__ +`SubtractItemsd` +"""""""""""""""" +.. autoclass:: SubtractItemsd + :members: + :special-members: __call__ + `ConcatItemsd` """""""""""""" .. autoclass:: ConcatItemsd diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index d15042181b..80f54975b1 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -647,6 +647,9 @@ SqueezeDimd, SqueezeDimD, SqueezeDimDict, + SubtractItemsd, + SubtractItemsD, + SubtractItemsDict, ToCupyd, ToCupyD, ToCupyDict, diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 7dd2397a74..4144094d6b 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -161,6 +161,9 @@ "SqueezeDimD", "SqueezeDimDict", "SqueezeDimd", + "SubtractItemsD", + "SubtractItemsDict", + "SubtractItemsd", "ToCupyD", "ToCupyDict", "ToCupyd", @@ -957,6 +960,57 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N return d +class SubtractItemsd(MapTransform): + """ + Subtract specified items from data dictionary elementwise. + Expect all the items are numpy array or PyTorch Tensor or MetaTensor. + Return the first input's meta information when items are MetaTensor. + """ + + backend = [TransformBackends.TORCH, TransformBackends.NUMPY] + + def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False) -> None: + """ + Args: + keys: keys of the corresponding items to be subtracted. + See also: :py:class:`monai.transforms.compose.MapTransform` + name: the name corresponding to the key to store the resulting data. + allow_missing_keys: don't raise exception if key is missing. + """ + super().__init__(keys, allow_missing_keys) + self.name = name + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]: + """ + Raises: + TypeError: When items in ``data`` differ in type. + TypeError: When the item type is not in ``Union[numpy.ndarray, torch.Tensor, MetaTensor]``. + + """ + d = dict(data) + output = [] + data_type = None + for key in self.key_iterator(d): + if data_type is None: + data_type = type(d[key]) + elif not isinstance(d[key], data_type): + raise TypeError("All items in data must have the same type.") + output.append(d[key]) + + if len(output) == 0: + return d + + if data_type is np.ndarray: + d[self.name] = np.subtract(output[0], output[1]) + elif issubclass(data_type, torch.Tensor): # type: ignore + d[self.name] = torch.sub(output[0], output[1]) # type: ignore + else: + raise TypeError( + f"Unsupported data type: {data_type}, available options are (numpy.ndarray, torch.Tensor, MetaTensor)." + ) + return d + + class ConcatItemsd(MapTransform): """ Concatenate specified items from data dictionary together on the first dim to construct a big array. @@ -1927,6 +1981,7 @@ def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch DataStatsD = DataStatsDict = DataStatsd SimulateDelayD = SimulateDelayDict = SimulateDelayd CopyItemsD = CopyItemsDict = CopyItemsd +SubtractItemsD = SubtractItemsDict = SubtractItemsd ConcatItemsD = ConcatItemsDict = ConcatItemsd LambdaD = LambdaDict = Lambdad LabelToMaskD = LabelToMaskDict = LabelToMaskd diff --git a/tests/test_module_list.py b/tests/test_module_list.py index 833441cbca..761a5016fd 100644 --- a/tests/test_module_list.py +++ b/tests/test_module_list.py @@ -42,7 +42,9 @@ def test_transform_api(self): """monai subclasses of MapTransforms must have alias names ending with 'd', 'D', 'Dict'""" to_exclude = {"MapTransform"} # except for these transforms to_exclude_docs = {"Decollate", "Ensemble", "Invert", "SaveClassification", "RandTorchVision", "RandCrop"} - to_exclude_docs.update({"DeleteItems", "SelectItems", "FlattenSubKeys", "CopyItems", "ConcatItems"}) + to_exclude_docs.update( + {"DeleteItems", "SelectItems", "FlattenSubKeys", "CopyItems", "ConcatItems", "SubtractItems"} + ) to_exclude_docs.update({"ToMetaTensor", "FromMetaTensor"}) xforms = { name: obj diff --git a/tests/test_subtract_itemsd.py b/tests/test_subtract_itemsd.py new file mode 100644 index 0000000000..d754acbc30 --- /dev/null +++ b/tests/test_subtract_itemsd.py @@ -0,0 +1,63 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import unittest + +import numpy as np +import torch + +from monai.data import MetaTensor +from monai.transforms import SubtractItemsd +from tests.utils import assert_allclose + + +class TestSubtractItemsd(unittest.TestCase): + + def test_tensor_values(self): + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") + input_data = { + "img1": torch.tensor([[0, 1], [1, 2]], device=device), + "img2": torch.tensor([[0, 1], [1, 2]], device=device), + "name": "key_name", + } + result = SubtractItemsd(keys=["img1", "img2"], name="sub_img")(input_data) + self.assertIn("sub_img", result) + result["sub_img"] += 1 + assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) + assert_allclose(result["sub_img"], torch.tensor([[1, 1], [1, 1]], device=device)) + + def test_metatensor_values(self): + device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu:0") + input_data = { + "img1": MetaTensor([[0, 1], [1, 2]], device=device), + "img2": MetaTensor([[0, 1], [1, 2]], device=device), + } + result = SubtractItemsd(keys=["img1", "img2"], name="sub_img")(input_data) + self.assertIn("sub_img", result) + self.assertIsInstance(result["sub_img"], MetaTensor) + self.assertEqual(result["img1"].meta, result["sub_img"].meta) + result["sub_img"] += 1 + assert_allclose(result["img1"], torch.tensor([[0, 1], [1, 2]], device=device)) + assert_allclose(result["sub_img"], torch.tensor([[1, 1], [1, 1]], device=device)) + + def test_numpy_values(self): + input_data = {"img1": np.array([[0, 1], [1, 2]]), "img2": np.array([[0, 1], [1, 2]])} + result = SubtractItemsd(keys=["img1", "img2"], name="sub_img")(input_data) + self.assertIn("sub_img", result) + result["sub_img"] += 1 + np.testing.assert_allclose(result["img1"], np.array([[0, 1], [1, 2]])) + np.testing.assert_allclose(result["sub_img"], np.array([[1, 1], [1, 1]])) + + +if __name__ == "__main__": + unittest.main()