From f62e9689127c60221eb2aad2df64acf306b9f8bf Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 22:58:12 +0000 Subject: [PATCH 01/50] Add parallel test deps --- requirements/test.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/test.txt b/requirements/test.txt index c635aa51..aaa038b4 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,2 +1,3 @@ pytest==8.3.4 -ruff==0.8.4 +pytest-xdist==3.6.1 +ruff==0.8.4 \ No newline at end of file From 7e08ed364d4be2a011be4579f7630bce0dc710bf Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 22:58:25 +0000 Subject: [PATCH 02/50] Update signature --- segmentation_models_pytorch/encoders/inceptionv4.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/encoders/inceptionv4.py b/segmentation_models_pytorch/encoders/inceptionv4.py index 83adf003..96540f9a 100644 --- a/segmentation_models_pytorch/encoders/inceptionv4.py +++ b/segmentation_models_pytorch/encoders/inceptionv4.py @@ -49,7 +49,7 @@ def __init__(self, stage_idxs, out_channels, depth=5, **kwargs): # remove linear layers del self.last_linear - def make_dilated(self, stage_list, dilation_list): + def make_dilated(self, *args, **kwargs): raise ValueError( "InceptionV4 encoder does not support dilated mode " "due to pooling operation for downsampling!" From 053db256f5f3af47be58d20fee68b58657b791df Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 22:59:08 +0000 Subject: [PATCH 03/50] Add encoders tests --- tests/config.py | 25 +++ tests/encoders/__init__.py | 0 tests/encoders/base.py | 207 ++++++++++++++++++ .../test_pretrainedmodels_encoders.py | 71 ++++++ tests/encoders/test_smp_encoders.py | 41 ++++ tests/encoders/test_timm_ported_encoders.py | 131 +++++++++++ tests/encoders/test_timm_universal.py | 6 + tests/encoders/test_torchvision_encoders.py | 24 ++ 8 files changed, 505 insertions(+) create mode 100644 tests/config.py create mode 100644 tests/encoders/__init__.py create mode 100644 tests/encoders/base.py create mode 100644 tests/encoders/test_pretrainedmodels_encoders.py create mode 100644 tests/encoders/test_smp_encoders.py create mode 100644 tests/encoders/test_timm_ported_encoders.py create mode 100644 tests/encoders/test_timm_universal.py create mode 100644 tests/encoders/test_torchvision_encoders.py diff --git a/tests/config.py b/tests/config.py new file mode 100644 index 00000000..d53eebad --- /dev/null +++ b/tests/config.py @@ -0,0 +1,25 @@ +import os + + +def get_commit_message(): + # Get commit message from CI environment variables + # Common CI systems store commit messages in different env vars + commit_msg = os.getenv("COMMIT_MESSAGE", "") # Generic + if not commit_msg: + commit_msg = os.getenv("CI_COMMIT_MESSAGE", "") # GitLab CI + if not commit_msg: + commit_msg = os.getenv("GITHUB_EVENT_HEAD_COMMIT_MESSAGE", "") # GitHub Actions + return commit_msg.lower() + + +# Check both environment variables and commit message +commit_message = get_commit_message() +RUN_ALL_ENCODERS = ( + os.getenv("RUN_ALL_ENCODERS", "false").lower() in ["true", "1", "y", "yes"] + or "run-all-encoders" in commit_message +) + +RUN_SLOW = ( + os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"] + or "run-slow" in commit_message +) diff --git a/tests/encoders/__init__.py b/tests/encoders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/encoders/base.py b/tests/encoders/base.py new file mode 100644 index 00000000..ec75d79a --- /dev/null +++ b/tests/encoders/base.py @@ -0,0 +1,207 @@ +import unittest +import torch +import segmentation_models_pytorch as smp + +from functools import lru_cache + + +class BaseEncoderTester(unittest.TestCase): + encoder_names = [] + + # standard encoder configuration + num_output_features = 6 + output_strides = [1, 2, 4, 8, 16, 32] + supports_dilated = True + + # test sample configuration + default_batch_size = 1 + default_num_channels = 3 + default_height = 64 + default_width = 64 + + # test configurations + in_channels_to_test = [1, 3, 4] + depth_to_test = [3, 4, 5] + strides_to_test = [8, 16] # 32 is a default one + + @lru_cache + def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + return torch.rand(batch_size, num_channels, height, width) + + def get_features_output_strides(self, sample, features): + height, width = sample.shape[2:] + height_strides = [height // f.shape[2] for f in features] + width_strides = [width // f.shape[3] for f in features] + return height_strides, width_strides + + def test_forward_backward(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + for encoder_name in self.encoder_names: + with self.subTest(encoder_name=encoder_name): + # init encoder + encoder = smp.encoders.get_encoder( + encoder_name, in_channels=3, encoder_weights=None + ) + + # forward + features = encoder.forward(sample) + self.assertEqual( + len(features), + self.num_output_features, + f"Encoder `{encoder_name}` should have {self.num_output_features} output feature maps, but has {len(features)}", + ) + + # backward + features[-1].mean().backward() + + def test_in_channels(self): + cases = [ + (encoder_name, in_channels) + for encoder_name in self.encoder_names + for in_channels in self.in_channels_to_test + ] + + for encoder_name, in_channels in cases: + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=in_channels, + height=self.default_height, + width=self.default_width, + ) + + with self.subTest(encoder_name=encoder_name, in_channels=in_channels): + encoder = smp.encoders.get_encoder( + encoder_name, in_channels=in_channels, encoder_weights=None + ) + encoder.eval() + + # forward + with torch.no_grad(): + encoder.forward(sample) + + def test_depth(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + + cases = [ + (encoder_name, depth) + for encoder_name in self.encoder_names + for depth in self.depth_to_test + ] + + for encoder_name, depth in cases: + with self.subTest(encoder_name=encoder_name, depth=depth): + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + depth=depth, + ) + encoder.eval() + + # forward + with torch.no_grad(): + features = encoder.forward(sample) + + # check number of features + self.assertEqual( + len(features), + depth + 1, + f"Encoder `{encoder_name}` should have {depth + 1} output feature maps, but has {len(features)}", + ) + + # check feature strides + height_strides, width_strides = self.get_features_output_strides( + sample, features + ) + self.assertEqual( + height_strides, + self.output_strides[: depth + 1], + f"Encoder `{encoder_name}` should have output strides {self.output_strides[:depth + 1]}, but has {height_strides}", + ) + self.assertEqual( + width_strides, + self.output_strides[: depth + 1], + f"Encoder `{encoder_name}` should have output strides {self.output_strides[:depth + 1]}, but has {width_strides}", + ) + + # check encoder output stride property + self.assertEqual( + encoder.output_stride, + self.output_strides[depth], + f"Encoder `{encoder_name}` last feature map should have output stride {self.output_strides[depth]}, but has {encoder.output_stride}", + ) + + # check out channels also have proper length + self.assertEqual( + len(encoder.out_channels), + depth + 1, + f"Encoder `{encoder_name}` should have {depth + 1} out_channels, but has {len(encoder.out_channels)}", + ) + + def test_dilated(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + + cases = [ + (encoder_name, stride) + for encoder_name in self.encoder_names + for stride in self.strides_to_test + ] + + # special case for encoders that do not support dilated model + # just check proper error is raised + if not self.supports_dilated: + with self.assertRaises(ValueError, msg="not support dilated mode"): + encoder_name, stride = cases[0] + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + output_stride=stride, + ) + return + + for encoder_name, stride in cases: + with self.subTest(encoder_name=encoder_name, stride=stride): + encoder = smp.encoders.get_encoder( + encoder_name, + in_channels=self.default_num_channels, + encoder_weights=None, + output_stride=stride, + ) + encoder.eval() + + # forward + with torch.no_grad(): + features = encoder.forward(sample) + + height_strides, width_strides = self.get_features_output_strides( + sample, features + ) + expected_height_strides = [min(stride, s) for s in height_strides] + expected_width_strides = [min(stride, s) for s in width_strides] + + self.assertEqual( + height_strides, + expected_height_strides, + f"Encoder `{encoder_name}` should have height output strides {expected_height_strides}, but has {height_strides}", + ) + self.assertEqual( + width_strides, + expected_width_strides, + f"Encoder `{encoder_name}` should have width output strides {expected_width_strides}, but has {width_strides}", + ) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py new file mode 100644 index 00000000..b1bf1e14 --- /dev/null +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -0,0 +1,71 @@ +from tests.encoders import base +from tests.config import RUN_ALL_ENCODERS + + +class TestDenseNetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["densenet121"] + if not RUN_ALL_ENCODERS + else ["densenet121", "densenet169", "densenet161"] + ) + + +class TestDPNEncoder(base.BaseEncoderTester): + encoder_names = ( + ["dpn68"] + if not RUN_ALL_ENCODERS + else ["dpn68", "dpn68b", "dpn92", "dpn98", "dpn107", "dpn131"] + ) + + +class TestInceptionResNetV2Encoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"] + ) + + +class TestInceptionV4Encoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"] + + +class TestResNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["resnet18"] + if not RUN_ALL_ENCODERS + else [ + "resnet18", + "resnet34", + "resnet50", + "resnet101", + "resnet152", + "resnext50_32x4d", + "resnext101_32x4d", + "resnext101_32x8d", + "resnext101_32x16d", + "resnext101_32x32d", + "resnext101_32x48d", + ] + ) + + +class TestSeNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["se_resnet50"] + if not RUN_ALL_ENCODERS + else [ + "se_resnet50", + "se_resnet101", + "se_resnet152", + "se_resnext50_32x4d", + "se_resnext101_32x4d", + # "senet154", # extra large model + ] + ) + + +class TestXceptionEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ["xception"] if not RUN_ALL_ENCODERS else ["xception"] diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py new file mode 100644 index 00000000..c248f1cb --- /dev/null +++ b/tests/encoders/test_smp_encoders.py @@ -0,0 +1,41 @@ +from tests.encoders import base +from tests.config import RUN_ALL_ENCODERS + + +class TestMobileoneEncoder(base.BaseEncoderTester): + encoder_names = ( + ["mobileone_s0"] + if not RUN_ALL_ENCODERS + else [ + "mobileone_s0", + "mobileone_s1", + "mobileone_s2", + "mobileone_s3", + "mobileone_s4", + ] + ) + + +class TestMixTransformerEncoder(base.BaseEncoderTester): + encoder_names = ( + ["mit_b0"] + if not RUN_ALL_ENCODERS + else ["mit_b0", "mit_b1", "mit_b2", "mit_b3", "mit_b4", "mit_b5"] + ) + + +class TestEfficientNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["efficientnet-b0"] + if not RUN_ALL_ENCODERS + else [ + "efficientnet-b0", + "efficientnet-b1", + "efficientnet-b2", + "efficientnet-b3", + "efficientnet-b4", + "efficientnet-b5", + "efficientnet-b6", + # "efficientnet-b7", # extra large model + ] + ) diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py new file mode 100644 index 00000000..d24f22ad --- /dev/null +++ b/tests/encoders/test_timm_ported_encoders.py @@ -0,0 +1,131 @@ +from tests.encoders import base +from tests.config import RUN_ALL_ENCODERS + + +class TestTimmEfficientNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-efficientnet-b0"] + if not RUN_ALL_ENCODERS + else [ + "timm-efficientnet-b0", + "timm-efficientnet-b1", + "timm-efficientnet-b2", + "timm-efficientnet-b3", + "timm-efficientnet-b4", + "timm-efficientnet-b5", + "timm-efficientnet-b6", + "timm-efficientnet-b7", + "timm-efficientnet-b8", + "timm-efficientnet-l2", + "timm-tf_efficientnet_lite0", + "timm-tf_efficientnet_lite1", + "timm-tf_efficientnet_lite2", + "timm-tf_efficientnet_lite3", + "timm-tf_efficientnet_lite4", + ] + ) + + +class TestTimmGERNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-gernet_s"] + if not RUN_ALL_ENCODERS + else ["timm-gernet_s", "timm-gernet_m", "timm-gernet_l"] + ) + + +class TestTimmMobileNetV3Encoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-mobilenetv3_small_100"] + if not RUN_ALL_ENCODERS + else [ + "timm-mobilenetv3_large_075", + "timm-mobilenetv3_large_100", + "timm-mobilenetv3_large_minimal_100", + "timm-mobilenetv3_small_075", + "timm-mobilenetv3_small_100", + "timm-mobilenetv3_small_minimal_100", + ] + ) + + +class TestTimmRegNetEncoder(base.BaseEncoderTester): + encoder_names = ( + ["timm-regnetx_002", "timm-regnety_002"] + if not RUN_ALL_ENCODERS + else [ + "timm-regnetx_002", + "timm-regnetx_004", + "timm-regnetx_006", + "timm-regnetx_008", + "timm-regnetx_016", + "timm-regnetx_032", + "timm-regnetx_040", + "timm-regnetx_064", + "timm-regnetx_080", + "timm-regnetx_120", + "timm-regnetx_160", + "timm-regnetx_320", + "timm-regnety_002", + "timm-regnety_004", + "timm-regnety_006", + "timm-regnety_008", + "timm-regnety_016", + "timm-regnety_032", + "timm-regnety_040", + "timm-regnety_064", + "timm-regnety_080", + "timm-regnety_120", + "timm-regnety_160", + "timm-regnety_320", + ] + ) + + +class TestTimmRes2NetEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["timm-res2net50_26w_4s"] + if not RUN_ALL_ENCODERS + else [ + "timm-res2net50_26w_4s", + "timm-res2net101_26w_4s", + "timm-res2net50_26w_6s", + "timm-res2net50_26w_8s", + "timm-res2net50_48w_2s", + "timm-res2net50_14w_8s", + "timm-res2next50", + ] + ) + + +class TestTimmResnestEncoder(base.BaseEncoderTester): + default_batch_size = 2 + supports_dilated = False + encoder_names = ( + ["timm-resnest14d"] + if not RUN_ALL_ENCODERS + else [ + "timm-resnest14d", + "timm-resnest26d", + "timm-resnest50d", + "timm-resnest101e", + "timm-resnest200e", + "timm-resnest269e", + "timm-resnest50d_4s2x40d", + "timm-resnest50d_1s4x24d", + ] + ) + + +class TestTimmSkNetEncoder(base.BaseEncoderTester): + default_batch_size = 2 + encoder_names = ( + ["timm-skresnet18"] + if not RUN_ALL_ENCODERS + else [ + "timm-skresnet18", + "timm-skresnet34", + "timm-skresnext50_32x4d", + ] + ) diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py new file mode 100644 index 00000000..2f40ac4d --- /dev/null +++ b/tests/encoders/test_timm_universal.py @@ -0,0 +1,6 @@ +from tests.encoders import base +from tests.config import RUN_ALL_ENCODERS + + +class TestTimmUniversalEncoder(base.BaseEncoderTester): + encoder_names = ["tu-resnet18"] if not RUN_ALL_ENCODERS else ["tu-resnet18"] diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py new file mode 100644 index 00000000..893a17e2 --- /dev/null +++ b/tests/encoders/test_torchvision_encoders.py @@ -0,0 +1,24 @@ +from tests.encoders import base +from tests.config import RUN_ALL_ENCODERS + + +class TestMobileoneEncoder(base.BaseEncoderTester): + encoder_names = ["mobilenet_v2"] if not RUN_ALL_ENCODERS else ["mobilenet_v2"] + + +class TestVggEncoder(base.BaseEncoderTester): + supports_dilated = False + encoder_names = ( + ["vgg11"] + if not RUN_ALL_ENCODERS + else [ + "vgg11", + "vgg11_bn", + "vgg13", + "vgg13_bn", + "vgg16", + "vgg16_bn", + "vgg19", + "vgg19_bn", + ] + ) From ebe261add3dea0bf55eafd5d2507e2634b46efcb Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:17:13 +0000 Subject: [PATCH 04/50] Update gitignore --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 99a7807d..33db579f 100644 --- a/.gitignore +++ b/.gitignore @@ -93,6 +93,7 @@ venv/ ENV/ env.bak/ venv.bak/ +.vscode/ # Spyder project settings .spyderproject From e8f7a4e16119f706ad9a43fb0b5676bf7b658ae5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:17:30 +0000 Subject: [PATCH 05/50] Update encoders for timm-universal --- tests/encoders/test_timm_universal.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 2f40ac4d..9fcfa351 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -1,6 +1,9 @@ from tests.encoders import base -from tests.config import RUN_ALL_ENCODERS class TestTimmUniversalEncoder(base.BaseEncoderTester): - encoder_names = ["tu-resnet18"] if not RUN_ALL_ENCODERS else ["tu-resnet18"] + encoder_names = [ + "tu-resnet18", # for timm universal traditional-like encoder + "tu-convnext_atto", # for timm universal transformer-like encoder + "tu-darknet17", # for timm universal vgg-like encoder + ] From 988c85fab54b8c3ad9ba6c010c7e5ed1c2bde6ee Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:19:51 +0000 Subject: [PATCH 06/50] Add parallel tests run --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index e2767433..1049570d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -43,7 +43,7 @@ jobs: - name: Install dependencies run: python -m pip install -r requirements/required.txt -r requirements/test.txt - name: Test with pytest - run: pytest + run: pytest -v -n 2 minimum: runs-on: ubuntu-latest @@ -56,4 +56,4 @@ jobs: - name: Install dependencies run: python -m pip install -r requirements/minimum.old -r requirements/test.txt - name: Test with pytest - run: pytest + run: pytest -v -n 2 From ae0977e64579bee4a08352240a9685926e8deaa0 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:20:02 +0000 Subject: [PATCH 07/50] Disable models tests --- tests/test_models.py | 388 +++++++++++++++++++++---------------------- 1 file changed, 194 insertions(+), 194 deletions(-) diff --git a/tests/test_models.py b/tests/test_models.py index 460dcdf2..c2b71585 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -1,194 +1,194 @@ -import pytest -import torch - -import segmentation_models_pytorch as smp # noqa - - -def get_encoders(): - exclude_encoders = [ - "senet154", - "resnext101_32x16d", - "resnext101_32x32d", - "resnext101_32x48d", - ] - encoders = smp.encoders.get_encoder_names() - encoders = [e for e in encoders if e not in exclude_encoders] - encoders.append("tu-resnet34") # for timm universal traditional-like encoder - encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder - encoders.append("tu-darknet17") # for timm universal vgg-like encoder - return encoders - - -ENCODERS = get_encoders() -DEFAULT_ENCODER = "resnet18" - - -def get_sample(model_class): - if model_class in [ - smp.FPN, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.Segformer, - ]: - sample = torch.ones([1, 3, 64, 64]) - elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]: - sample = torch.ones([2, 3, 128, 128]) - elif model_class in [smp.PSPNet, smp.UPerNet]: - # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input - # from PSPModule pooling in PSPNet/UPerNet. - sample = torch.ones([2, 3, 64, 64]) - else: - raise ValueError("Not supported model class {}".format(model_class)) - return sample - - -def _test_forward(model, sample, test_shape=False): - with torch.no_grad(): - out = model(sample) - if test_shape: - assert out.shape[2:] == sample.shape[2:] - - -def _test_forward_backward(model, sample, test_shape=False): - out = model(sample) - out.mean().backward() - if test_shape: - assert out.shape[2:] == sample.shape[2:] - - -@pytest.mark.parametrize("encoder_name", ENCODERS) -@pytest.mark.parametrize("encoder_depth", [3, 5]) -@pytest.mark.parametrize( - "model_class", - [ - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.UPerNet, - smp.Segformer, - ], -) -def test_forward(model_class, encoder_name, encoder_depth, **kwargs): - if ( - model_class is smp.Unet - or model_class is smp.UnetPlusPlus - or model_class is smp.MAnet - ): - kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] - if model_class in [smp.UnetPlusPlus, smp.Linknet]: - if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): - return # skip transformer-like model* - if model_class is smp.FPN and encoder_depth != 5: - if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): - return # skip transformer-like model* - model = model_class( - encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs - ) - sample = get_sample(model_class) - model.eval() - if encoder_depth == 5 and model_class != smp.PSPNet: - test_shape = True - else: - test_shape = False - - _test_forward(model, sample, test_shape) - - -@pytest.mark.parametrize( - "model_class", - [ - smp.PAN, - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.DeepLabV3, - smp.DeepLabV3Plus, - smp.UPerNet, - smp.Segformer, - ], -) -def test_forward_backward(model_class): - sample = get_sample(model_class) - model = model_class(DEFAULT_ENCODER, encoder_weights=None) - _test_forward_backward(model, sample) - - -@pytest.mark.parametrize( - "model_class", - [ - smp.PAN, - smp.FPN, - smp.PSPNet, - smp.Linknet, - smp.Unet, - smp.UnetPlusPlus, - smp.MAnet, - smp.DeepLabV3, - smp.DeepLabV3Plus, - smp.UPerNet, - smp.Segformer, - ], -) -def test_aux_output(model_class): - model = model_class( - DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2) - ) - sample = get_sample(model_class) - label_size = (sample.shape[0], 2) - mask, label = model(sample) - assert label.size() == label_size - - -@pytest.mark.parametrize("upsampling", [2, 4, 8]) -@pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet]) -def test_upsample(model_class, upsampling): - default_upsampling = 4 if model_class is smp.FPN else 8 - model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling) - sample = get_sample(model_class) - mask = model(sample) - assert mask.size()[-1] / 64 == upsampling / default_upsampling - - -@pytest.mark.parametrize("model_class", [smp.FPN]) -@pytest.mark.parametrize("in_channels", [1, 2, 4]) -def test_in_channels(model_class, in_channels): - sample = torch.ones([1, in_channels, 64, 64]) - model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) - model.eval() - with torch.no_grad(): - model(sample) - - assert model.encoder._in_channels == in_channels - - -@pytest.mark.parametrize("encoder_name", ENCODERS) -def test_dilation(encoder_name): - if ( - encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"] - or encoder_name.startswith("vgg") - or encoder_name.startswith("densenet") - or encoder_name.startswith("timm-res") - ): - return - - encoder = smp.encoders.get_encoder(encoder_name, output_stride=16) - - encoder.eval() - with torch.no_grad(): - sample = torch.ones([1, 3, 64, 64]) - output = encoder(sample) - - shapes = [out.shape[-1] for out in output] - assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation - - -if __name__ == "__main__": - pytest.main([__file__]) +# import pytest +# import torch + +# import segmentation_models_pytorch as smp # noqa + + +# def get_encoders(): +# exclude_encoders = [ +# "senet154", +# "resnext101_32x16d", +# "resnext101_32x32d", +# "resnext101_32x48d", +# ] +# encoders = smp.encoders.get_encoder_names() +# encoders = [e for e in encoders if e not in exclude_encoders] +# encoders.append("tu-resnet34") # for timm universal traditional-like encoder +# encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder +# encoders.append("tu-darknet17") # for timm universal vgg-like encoder +# return encoders + + +# ENCODERS = get_encoders() +# DEFAULT_ENCODER = "resnet18" + + +# def get_sample(model_class): +# if model_class in [ +# smp.FPN, +# smp.Linknet, +# smp.Unet, +# smp.UnetPlusPlus, +# smp.MAnet, +# smp.Segformer, +# ]: +# sample = torch.ones([1, 3, 64, 64]) +# elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]: +# sample = torch.ones([2, 3, 128, 128]) +# elif model_class in [smp.PSPNet, smp.UPerNet]: +# # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input +# # from PSPModule pooling in PSPNet/UPerNet. +# sample = torch.ones([2, 3, 64, 64]) +# else: +# raise ValueError("Not supported model class {}".format(model_class)) +# return sample + + +# def _test_forward(model, sample, test_shape=False): +# with torch.no_grad(): +# out = model(sample) +# if test_shape: +# assert out.shape[2:] == sample.shape[2:] + + +# def _test_forward_backward(model, sample, test_shape=False): +# out = model(sample) +# out.mean().backward() +# if test_shape: +# assert out.shape[2:] == sample.shape[2:] + + +# @pytest.mark.parametrize("encoder_name", ENCODERS) +# @pytest.mark.parametrize("encoder_depth", [3, 5]) +# @pytest.mark.parametrize( +# "model_class", +# [ +# smp.FPN, +# smp.PSPNet, +# smp.Linknet, +# smp.Unet, +# smp.UnetPlusPlus, +# smp.MAnet, +# smp.UPerNet, +# smp.Segformer, +# ], +# ) +# def test_forward(model_class, encoder_name, encoder_depth, **kwargs): +# if ( +# model_class is smp.Unet +# or model_class is smp.UnetPlusPlus +# or model_class is smp.MAnet +# ): +# kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] +# if model_class in [smp.UnetPlusPlus, smp.Linknet]: +# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): +# return # skip transformer-like model* +# if model_class is smp.FPN and encoder_depth != 5: +# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): +# return # skip transformer-like model* +# model = model_class( +# encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs +# ) +# sample = get_sample(model_class) +# model.eval() +# if encoder_depth == 5 and model_class != smp.PSPNet: +# test_shape = True +# else: +# test_shape = False + +# _test_forward(model, sample, test_shape) + + +# @pytest.mark.parametrize( +# "model_class", +# [ +# smp.PAN, +# smp.FPN, +# smp.PSPNet, +# smp.Linknet, +# smp.Unet, +# smp.UnetPlusPlus, +# smp.MAnet, +# smp.DeepLabV3, +# smp.DeepLabV3Plus, +# smp.UPerNet, +# smp.Segformer, +# ], +# ) +# def test_forward_backward(model_class): +# sample = get_sample(model_class) +# model = model_class(DEFAULT_ENCODER, encoder_weights=None) +# _test_forward_backward(model, sample) + + +# @pytest.mark.parametrize( +# "model_class", +# [ +# smp.PAN, +# smp.FPN, +# smp.PSPNet, +# smp.Linknet, +# smp.Unet, +# smp.UnetPlusPlus, +# smp.MAnet, +# smp.DeepLabV3, +# smp.DeepLabV3Plus, +# smp.UPerNet, +# smp.Segformer, +# ], +# ) +# def test_aux_output(model_class): +# model = model_class( +# DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2) +# ) +# sample = get_sample(model_class) +# label_size = (sample.shape[0], 2) +# mask, label = model(sample) +# assert label.size() == label_size + + +# @pytest.mark.parametrize("upsampling", [2, 4, 8]) +# @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet]) +# def test_upsample(model_class, upsampling): +# default_upsampling = 4 if model_class is smp.FPN else 8 +# model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling) +# sample = get_sample(model_class) +# mask = model(sample) +# assert mask.size()[-1] / 64 == upsampling / default_upsampling + + +# @pytest.mark.parametrize("model_class", [smp.FPN]) +# @pytest.mark.parametrize("in_channels", [1, 2, 4]) +# def test_in_channels(model_class, in_channels): +# sample = torch.ones([1, in_channels, 64, 64]) +# model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) +# model.eval() +# with torch.no_grad(): +# model(sample) + +# assert model.encoder._in_channels == in_channels + + +# @pytest.mark.parametrize("encoder_name", ENCODERS) +# def test_dilation(encoder_name): +# if ( +# encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"] +# or encoder_name.startswith("vgg") +# or encoder_name.startswith("densenet") +# or encoder_name.startswith("timm-res") +# ): +# return + +# encoder = smp.encoders.get_encoder(encoder_name, output_stride=16) + +# encoder.eval() +# with torch.no_grad(): +# sample = torch.ones([1, 3, 64, 64]) +# output = encoder(sample) + +# shapes = [out.shape[-1] for out in output] +# assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation + + +# if __name__ == "__main__": +# pytest.main([__file__]) From f590d9d3722116cb0171a2196f934cede2ec4cc4 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:34:03 +0000 Subject: [PATCH 08/50] Add uv to CI --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 1049570d..4c67eb31 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -41,7 +41,9 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: python -m pip install -r requirements/required.txt -r requirements/test.txt + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt - name: Test with pytest run: pytest -v -n 2 From bd48d9e4d2b8836bf6057c858a412a13e30dd857 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:37:56 +0000 Subject: [PATCH 09/50] Add uv to minimum --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4c67eb31..4a0c36b0 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -56,6 +56,8 @@ jobs: with: python-version: "3.9" - name: Install dependencies - run: python -m pip install -r requirements/minimum.old -r requirements/test.txt + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt - name: Test with pytest run: pytest -v -n 2 From cc79c8ee93b7e64b761587cc30844f799fbf53d9 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:39:33 +0000 Subject: [PATCH 10/50] Add show-install-packages --- .github/workflows/tests.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4a0c36b0..f640b3bc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,6 +40,9 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Show installed packages + run: | + python -m pip list - name: Install dependencies run: | python -m pip install uv @@ -55,6 +58,9 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.9" + - name: Show installed packages + run: | + python -m pip list - name: Install dependencies run: | python -m pip install uv From dba99f8809051ea7e12f97cf9430b600d1dec290 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:40:00 +0000 Subject: [PATCH 11/50] Increase to 3 workers --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index f640b3bc..d5dc1f28 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: python -m pip install uv python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt - name: Test with pytest - run: pytest -v -n 2 + run: pytest -v -n 3 minimum: runs-on: ubuntu-latest @@ -66,4 +66,4 @@ jobs: python -m pip install uv python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt - name: Test with pytest - run: pytest -v -n 2 + run: pytest -v -n 3 From 2bb41b42bb3319d08e22dde5791314aa9f564f8f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:41:15 +0000 Subject: [PATCH 12/50] Fix show-packages --- .github/workflows/tests.yml | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index d5dc1f28..81cc0459 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -40,13 +40,13 @@ jobs: uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Show installed packages - run: | - python -m pip list - name: Install dependencies run: | python -m pip install uv python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages + run: | + python -m pip list - name: Test with pytest run: pytest -v -n 3 @@ -58,12 +58,12 @@ jobs: uses: actions/setup-python@v5 with: python-version: "3.9" - - name: Show installed packages - run: | - python -m pip list - name: Install dependencies run: | python -m pip install uv python -m uv pip install --system -r requirements/minimum.old -r requirements/test.txt + - name: Show installed packages + run: | + python -m pip list - name: Test with pytest run: pytest -v -n 3 From 527799ecee347cb4360e597f66d9ea943dac89c5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Sun, 22 Dec 2024 23:43:45 +0000 Subject: [PATCH 13/50] Change back for 2 workers --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 81cc0459..9278c84b 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -n 3 + run: pytest -v -n 2 minimum: runs-on: ubuntu-latest @@ -66,4 +66,4 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -n 3 + run: pytest -v -n 2 From 0474e774599950550f6db37491ed758b2398ea5e Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 08:28:40 +0000 Subject: [PATCH 14/50] Add coverage --- .github/workflows/tests.yml | 8 +++++++- requirements/test.txt | 1 + 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9278c84b..4f87bb60 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,13 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -n 2 + run: pytest -v -n 2 --cov=segmentation_models_pytorch --cov-report=xml + - name: Upload coverage reports to Codecov + uses: codecov/codecov-action@v5 + with: + token: ${{ secrets.CODECOV_TOKEN }} + slug: qubvel-org/segmentation_models.pytorch + if: matrix.os == 'macos-latest' && matrix.python-version == '3.12' minimum: runs-on: ubuntu-latest diff --git a/requirements/test.txt b/requirements/test.txt index aaa038b4..ca126ece 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,3 +1,4 @@ pytest==8.3.4 pytest-xdist==3.6.1 +pytest-cov==6.0.0 ruff==0.8.4 \ No newline at end of file From 61a54960a2460691a4aaeab8856b4a3764f05f5c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 09:28:37 +0000 Subject: [PATCH 15/50] Basic model test --- pyproject.toml | 15 +++++++ tests/encoders/test_timm_universal.py | 1 + tests/models/__init__.py | 0 tests/models/base.py | 56 +++++++++++++++++++++++++++ tests/models/test_deeplab.py | 16 ++++++++ tests/models/test_fpn.py | 7 ++++ tests/models/test_linknet.py | 7 ++++ tests/models/test_manet.py | 7 ++++ tests/models/test_pan.py | 12 ++++++ tests/models/test_psp.py | 9 +++++ tests/models/test_segformer.py | 7 ++++ tests/models/test_unet.py | 7 ++++ tests/models/test_unetplusplus.py | 7 ++++ tests/models/test_upernet.py | 8 ++++ 14 files changed, 159 insertions(+) create mode 100644 tests/models/__init__.py create mode 100644 tests/models/base.py create mode 100644 tests/models/test_deeplab.py create mode 100644 tests/models/test_fpn.py create mode 100644 tests/models/test_linknet.py create mode 100644 tests/models/test_manet.py create mode 100644 tests/models/test_pan.py create mode 100644 tests/models/test_psp.py create mode 100644 tests/models/test_segformer.py create mode 100644 tests/models/test_unet.py create mode 100644 tests/models/test_unetplusplus.py create mode 100644 tests/models/test_upernet.py diff --git a/pyproject.toml b/pyproject.toml index 3df76c0f..c6b48009 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,3 +55,18 @@ version = {attr = 'segmentation_models_pytorch.__version__.__version__'} [tool.setuptools.packages.find] include = ['segmentation_models_pytorch*'] + +[tool.pytest.ini_options] +markers = [ + "deeplabv3", + "deeplabv3plus", + "fpn", + "linknet", + "manet", + "pan", + "psp", + "segformer", + "unet", + "unetplusplus", + "upernet", +] diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 9fcfa351..73f97edf 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -3,6 +3,7 @@ class TestTimmUniversalEncoder(base.BaseEncoderTester): encoder_names = [ + "tu-test_resnet.r160_in1k", "tu-resnet18", # for timm universal traditional-like encoder "tu-convnext_atto", # for timm universal transformer-like encoder "tu-darknet17", # for timm universal vgg-like encoder diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/models/base.py b/tests/models/base.py new file mode 100644 index 00000000..8ccdd3aa --- /dev/null +++ b/tests/models/base.py @@ -0,0 +1,56 @@ +import unittest +from functools import lru_cache + +import torch +import segmentation_models_pytorch as smp + + +class BaseModelTester(unittest.TestCase): + test_encoder_name = "tu-test_resnet.r160_in1k" + + # should be overriden + test_model_type = None + + # test sample configuration + default_batch_size = 1 + default_num_channels = 3 + default_height = 64 + default_width = 64 + + @property + def model_type(self): + if self.test_model_type is None: + raise ValueError("test_model_type is not set") + return self.test_model_type + + @lru_cache + def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): + return torch.rand(batch_size, num_channels, height, width) + + def test_forward_backward(self): + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + model = smp.create_model(arch=self.model_type) + + # check default in_channels=3 + output = model(sample) + + # check default output number of classes = 1 + expected_number_of_classes = 1 + result_number_of_classes = output.shape[1] + self.assertEqual( + result_number_of_classes, + expected_number_of_classes, + f"Default output number of classes should be {expected_number_of_classes}, but got {result_number_of_classes}", + ) + + # check backward pass + output.mean().backward() + + def test_encoder_params_are_set(self): + model = smp.create_model(arch=self.model_type) + self.assertEqual(model.encoder.name, self.test_encoder_name) diff --git a/tests/models/test_deeplab.py b/tests/models/test_deeplab.py new file mode 100644 index 00000000..d3d350e9 --- /dev/null +++ b/tests/models/test_deeplab.py @@ -0,0 +1,16 @@ +import pytest +from tests.models import base + + +@pytest.mark.deeplabv3 +class TestDeeplabV3Model(base.BaseModelTester): + test_model_type = "deeplabv3" + + default_batch_size = 2 + + +@pytest.mark.deeplabv3plus +class TestDeeplabV3PlusModel(base.BaseModelTester): + test_model_type = "deeplabv3plus" + + default_batch_size = 2 diff --git a/tests/models/test_fpn.py b/tests/models/test_fpn.py new file mode 100644 index 00000000..15ae1f6a --- /dev/null +++ b/tests/models/test_fpn.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.fpn +class TestFpnModel(base.BaseModelTester): + test_model_type = "fpn" diff --git a/tests/models/test_linknet.py b/tests/models/test_linknet.py new file mode 100644 index 00000000..1ab5eb4e --- /dev/null +++ b/tests/models/test_linknet.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.linknet +class TestLinknetModel(base.BaseModelTester): + test_model_type = "linknet" diff --git a/tests/models/test_manet.py b/tests/models/test_manet.py new file mode 100644 index 00000000..33a8ae3b --- /dev/null +++ b/tests/models/test_manet.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.manet +class TestManetModel(base.BaseModelTester): + test_model_type = "manet" diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py new file mode 100644 index 00000000..f3cddc58 --- /dev/null +++ b/tests/models/test_pan.py @@ -0,0 +1,12 @@ +import pytest +from tests.models import base + + +@pytest.mark.pan +class TestPanModel(base.BaseModelTester): + test_model_type = "pan" + test_encoder_name = "resnet-18" + + default_batch_size = 2 + default_height = 128 + default_width = 128 diff --git a/tests/models/test_psp.py b/tests/models/test_psp.py new file mode 100644 index 00000000..2603cdda --- /dev/null +++ b/tests/models/test_psp.py @@ -0,0 +1,9 @@ +import pytest +from tests.models import base + + +@pytest.mark.psp +class TestPspModel(base.BaseModelTester): + test_model_type = "pspnet" + + default_batch_size = 2 diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py new file mode 100644 index 00000000..3317acd2 --- /dev/null +++ b/tests/models/test_segformer.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.segformer +class TestSegformerModel(base.BaseModelTester): + test_model_type = "segformer" diff --git a/tests/models/test_unet.py b/tests/models/test_unet.py new file mode 100644 index 00000000..54c69bf0 --- /dev/null +++ b/tests/models/test_unet.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.unet +class TestUnetModel(base.BaseModelTester): + test_model_type = "unet" diff --git a/tests/models/test_unetplusplus.py b/tests/models/test_unetplusplus.py new file mode 100644 index 00000000..9e67f2ed --- /dev/null +++ b/tests/models/test_unetplusplus.py @@ -0,0 +1,7 @@ +import pytest +from tests.models import base + + +@pytest.mark.unetplusplus +class TestUnetPlusPlusModel(base.BaseModelTester): + test_model_type = "unetplusplus" diff --git a/tests/models/test_upernet.py b/tests/models/test_upernet.py new file mode 100644 index 00000000..71d703f9 --- /dev/null +++ b/tests/models/test_upernet.py @@ -0,0 +1,8 @@ +import pytest +from tests.models import base + + +@pytest.mark.upernet +class TestUnetModel(base.BaseModelTester): + test_model_type = "upernet" + default_batch_size = 2 From e97ce926d6f5e68b884a46668e41dfc44cc24189 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 09:30:35 +0000 Subject: [PATCH 16/50] Fix --- tests/models/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 8ccdd3aa..5ceb3c99 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -50,7 +50,3 @@ def test_forward_backward(self): # check backward pass output.mean().backward() - - def test_encoder_params_are_set(self): - model = smp.create_model(arch=self.model_type) - self.assertEqual(model.encoder.name, self.test_encoder_name) From 7c947f824c1620d81dbbf1f811c667fcdf38c06b Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:04:46 +0000 Subject: [PATCH 17/50] Move model archs --- segmentation_models_pytorch/__init__.py | 33 +++++++++++++------------ 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/segmentation_models_pytorch/__init__.py b/segmentation_models_pytorch/__init__.py index 5cde6004..f1807836 100644 --- a/segmentation_models_pytorch/__init__.py +++ b/segmentation_models_pytorch/__init__.py @@ -30,6 +30,21 @@ "ignore", message=r'"is" with \'str\' literal.*', category=SyntaxWarning ) # for python >= 3.12 +_MODEL_ARCHITECTURES = [ + Unet, + UnetPlusPlus, + MAnet, + Linknet, + FPN, + PSPNet, + DeepLabV3, + DeepLabV3Plus, + PAN, + UPerNet, + Segformer, +] +MODEL_ARCHITECTURES_MAPPING = {a.__name__.lower(): a for a in _MODEL_ARCHITECTURES} + def create_model( arch: str, @@ -43,26 +58,12 @@ def create_model( parameters, without using its class """ - archs = [ - Unet, - UnetPlusPlus, - MAnet, - Linknet, - FPN, - PSPNet, - DeepLabV3, - DeepLabV3Plus, - PAN, - UPerNet, - Segformer, - ] - archs_dict = {a.__name__.lower(): a for a in archs} try: - model_class = archs_dict[arch.lower()] + model_class = MODEL_ARCHITECTURES_MAPPING[arch.lower()] except KeyError: raise KeyError( "Wrong architecture type `{}`. Available options are: {}".format( - arch, list(archs_dict.keys()) + arch, list(MODEL_ARCHITECTURES_MAPPING.keys()) ) ) return model_class( From 1d5e1ea9f885513bbfd2b3ec001b3584e34363f2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:05:00 +0000 Subject: [PATCH 18/50] Add base params test --- tests/models/base.py | 39 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/models/base.py b/tests/models/base.py index 5ceb3c99..4468cbed 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -1,4 +1,5 @@ import unittest +import inspect from functools import lru_cache import torch @@ -23,6 +24,18 @@ def model_type(self): raise ValueError("test_model_type is not set") return self.test_model_type + @property + def model_class(self): + return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type] + + @property + def decoder_channels(self): + signature = inspect.signature(self.model_class) + # check if decoder_channels is in the signature + if "decoder_channels" in signature.parameters: + return signature.parameters["decoder_channels"].default + return None + @lru_cache def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32): return torch.rand(batch_size, num_channels, height, width) @@ -50,3 +63,29 @@ def test_forward_backward(self): # check backward pass output.mean().backward() + + def test_base_params_are_set(self, in_channels=1, depth=3, classes=7): + kwargs = {} + + if self.model_type in ["unet", "unetplusplus", "manet"]: + kwargs = {"decoder_channels": self.decoder_channels[:depth]} + + model = smp.create_model( + arch=self.model_type, + encoder_depth=depth, + in_channels=in_channels, + classes=classes, + **kwargs, + ) + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=in_channels, + height=self.default_height, + width=self.default_width, + ) + + # check in channels correctly set + with torch.no_grad(): + output = model(sample) + + self.assertEqual(output.shape[1], classes) From ede9942ec48e18c96233adebd0d18d342fcc3ead Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:10:42 +0000 Subject: [PATCH 19/50] Fix timm test for minimum version --- tests/encoders/test_timm_universal.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 73f97edf..72adaf05 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -1,10 +1,17 @@ +import timm from tests.encoders import base +from packaging.version import Version + +# check if timm >= 1.0.12 +timm_encoders = [ + "tu-resnet18", # for timm universal traditional-like encoder + "tu-convnext_atto", # for timm universal transformer-like encoder + "tu-darknet17", # for timm universal vgg-like encoder +] + +if Version(timm.__version__) >= Version("1.0.12"): + timm_encoders.append("tu-test_resnet.r160_in1k") class TestTimmUniversalEncoder(base.BaseEncoderTester): - encoder_names = [ - "tu-test_resnet.r160_in1k", - "tu-resnet18", # for timm universal traditional-like encoder - "tu-convnext_atto", # for timm universal transformer-like encoder - "tu-darknet17", # for timm universal vgg-like encoder - ] + encoder_names = timm_encoders From b0c928920a8a934de28bfcece45f6fb64edc7b8e Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:21:34 +0000 Subject: [PATCH 20/50] Remove deprecated utils from coverage --- .github/workflows/tests.yml | 2 +- pyproject.toml | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4f87bb60..83cb0edc 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,7 +48,7 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -n 2 --cov=segmentation_models_pytorch --cov-report=xml + run: pytest -v -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: diff --git a/pyproject.toml b/pyproject.toml index c6b48009..1d2f992c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,3 +70,6 @@ markers = [ "unetplusplus", "upernet", ] + +[tool.coverage.run] +omit = segmentation_models_pytorch/utils/* From d58d1fded3efd1d0aa9cce6e96bd48e764feecbf Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:22:33 +0000 Subject: [PATCH 21/50] Fix --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 1d2f992c..6a38c97b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,4 +72,4 @@ markers = [ ] [tool.coverage.run] -omit = segmentation_models_pytorch/utils/* +omit = "segmentation_models_pytorch/utils/*" From 18ee80df8ea7668bdf9e984cfc33c709f76feaab Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:23:33 +0000 Subject: [PATCH 22/50] Fix --- pyproject.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6a38c97b..6556391e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,4 +72,6 @@ markers = [ ] [tool.coverage.run] -omit = "segmentation_models_pytorch/utils/*" +omit = [ + "segmentation_models_pytorch/utils/*" +] From 2b113f0055bdb7b973c8d10c20945e644fb7d8d2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:30:08 +0000 Subject: [PATCH 23/50] Exclude conversion script --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6556391e..bb4a107d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,5 +73,6 @@ markers = [ [tool.coverage.run] omit = [ - "segmentation_models_pytorch/utils/*" + "segmentation_models_pytorch/utils/*", + "**/convert_*", ] From b7ce422b672670d064db5755fa45907be102aeda Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 11:42:43 +0000 Subject: [PATCH 24/50] Add save-load test, add aux head test --- tests/models/base.py | 58 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index 4468cbed..f81c7275 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -1,5 +1,6 @@ -import unittest import inspect +import tempfile +import unittest from functools import lru_cache import torch @@ -89,3 +90,58 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7): output = model(sample) self.assertEqual(output.shape[1], classes) + + def test_aux_params(self): + model = smp.create_model( + arch=self.model_type, + aux_params={ + "pooling": "avg", + "classes": 10, + "dropout": 0.5, + "activation": "sigmoid", + }, + ) + + self.assertIsNotNone(model.classification_head) + self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d) + self.assertIsInstance(model.classification_head[1], torch.nn.Flatten) + self.assertIsInstance(model.classification_head[2], torch.nn.Dropout) + self.assertEqual(model.classification_head[2].p, 0.5) + self.assertIsInstance(model.classification_head[3], torch.nn.Linear) + self.assertIsInstance(model.classification_head[4].activation, torch.nn.Sigmoid) + + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + + with torch.no_grad(): + _, cls_probs = model(sample) + + self.assertEqual(cls_probs.shape[1], 10) + + def test_save_load(self): + # instantiate model + model = smp.create_model(arch=self.model_type) + + # save model + with tempfile.TemporaryDirectory() as tmpdir: + model.save_pretrained(tmpdir) + restored_model = model.from_pretrained(tmpdir) + + # check inference is correct + sample = self._get_sample( + batch_size=self.default_batch_size, + num_channels=self.default_num_channels, + height=self.default_height, + width=self.default_width, + ) + + with torch.no_grad(): + output = model(sample) + restored_output = restored_model(sample) + + self.assertEqual(output.shape, restored_output.shape) + self.assertEqual(output.shape[1], 1) From 23bfce6716df6d6eed3f470bf67d5d1aaae356a5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:15:07 +0000 Subject: [PATCH 25/50] Remove custom encoder --- tests/models/test_pan.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/models/test_pan.py b/tests/models/test_pan.py index f3cddc58..d66fefe0 100644 --- a/tests/models/test_pan.py +++ b/tests/models/test_pan.py @@ -5,7 +5,6 @@ @pytest.mark.pan class TestPanModel(base.BaseModelTester): test_model_type = "pan" - test_encoder_name = "resnet-18" default_batch_size = 2 default_height = 128 From 4d7fed06bf84e299e316e9dd5ad50bae253c47ea Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:15:25 +0000 Subject: [PATCH 26/50] Set encoder for models tests --- tests/models/base.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index f81c7275..f9eb9a52 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -48,7 +48,9 @@ def test_forward_backward(self): height=self.default_height, width=self.default_width, ) - model = smp.create_model(arch=self.model_type) + model = smp.create_model( + arch=self.model_type, encoder_name=self.test_encoder_name + ) # check default in_channels=3 output = model(sample) @@ -65,7 +67,9 @@ def test_forward_backward(self): # check backward pass output.mean().backward() - def test_base_params_are_set(self, in_channels=1, depth=3, classes=7): + def test_in_channels_and_depth_and_out_classes( + self, in_channels=1, depth=3, classes=7 + ): kwargs = {} if self.model_type in ["unet", "unetplusplus", "manet"]: @@ -73,6 +77,7 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7): model = smp.create_model( arch=self.model_type, + encoder_name=self.test_encoder_name, encoder_depth=depth, in_channels=in_channels, classes=classes, @@ -91,9 +96,10 @@ def test_base_params_are_set(self, in_channels=1, depth=3, classes=7): self.assertEqual(output.shape[1], classes) - def test_aux_params(self): + def test_classification_head(self): model = smp.create_model( arch=self.model_type, + encoder_name=self.test_encoder_name, aux_params={ "pooling": "avg", "classes": 10, @@ -122,14 +128,16 @@ def test_aux_params(self): self.assertEqual(cls_probs.shape[1], 10) - def test_save_load(self): + def test_save_load_with_hub_mixin(self): # instantiate model - model = smp.create_model(arch=self.model_type) + model = smp.create_model( + arch=self.model_type, encoder_name=self.test_encoder_name + ) # save model with tempfile.TemporaryDirectory() as tmpdir: model.save_pretrained(tmpdir) - restored_model = model.from_pretrained(tmpdir) + restored_model = smp.from_pretrained(tmpdir) # check inference is correct sample = self._get_sample( From eadfe1f034d95002047ed7a0424b4cd182d7f6f5 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:15:41 +0000 Subject: [PATCH 27/50] Docs + flag for anyres --- segmentation_models_pytorch/base/model.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/segmentation_models_pytorch/base/model.py b/segmentation_models_pytorch/base/model.py index 29c7dd2a..6d7bf643 100644 --- a/segmentation_models_pytorch/base/model.py +++ b/segmentation_models_pytorch/base/model.py @@ -5,6 +5,12 @@ class SegmentationModel(torch.nn.Module, SMPHubMixin): + """Base class for all segmentation models.""" + + # if model supports shape not divisible by 2 ^ n + # set to False + requires_divisible_input_shape = True + def initialize(self): init.initialize_decoder(self.decoder) init.initialize_head(self.segmentation_head) @@ -12,6 +18,9 @@ def initialize(self): init.initialize_head(self.classification_head) def check_input_shape(self, x): + """Check if the input shape is divisible by the output stride. + If not, raise a RuntimeError. + """ h, w = x.shape[-2:] output_stride = self.encoder.output_stride if h % output_stride != 0 or w % output_stride != 0: @@ -33,7 +42,7 @@ def check_input_shape(self, x): def forward(self, x): """Sequentially pass `x` trough model`s encoder, decoder and heads""" - if not torch.jit.is_tracing(): + if not torch.jit.is_tracing() or self.requires_divisible_input_shape: self.check_input_shape(x) features = self.encoder(x) From 66bc2c892f27bbb065c3b0bac000aa03354d219f Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:25:24 +0000 Subject: [PATCH 28/50] Fix loading from config --- segmentation_models_pytorch/base/hub_mixin.py | 12 ++++++++++++ .../decoders/deeplabv3/model.py | 3 +++ segmentation_models_pytorch/decoders/fpn/model.py | 2 ++ .../decoders/linknet/model.py | 2 ++ segmentation_models_pytorch/decoders/manet/model.py | 2 ++ segmentation_models_pytorch/decoders/pan/model.py | 2 ++ segmentation_models_pytorch/decoders/pspnet/model.py | 2 ++ .../decoders/segformer/model.py | 2 ++ segmentation_models_pytorch/decoders/unet/model.py | 8 +++++--- .../decoders/unetplusplus/model.py | 2 ++ .../decoders/upernet/model.py | 2 ++ 11 files changed, 36 insertions(+), 3 deletions(-) diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index 8095c5b8..affb7079 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -136,3 +136,15 @@ def from_pretrained(pretrained_model_name_or_path: str, *args, **kwargs): model_class = getattr(smp, model_class_name) return model_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) + + +def supports_config_loading(func): + """Decorator to filter special config kwargs""" + + @wraps(func) + def wrapper(self, *args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")} + return func(self, *args, **kwargs) + + return wrapper +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading \ No newline at end of file diff --git a/segmentation_models_pytorch/decoders/deeplabv3/model.py b/segmentation_models_pytorch/decoders/deeplabv3/model.py index 830906cb..654e38d4 100644 --- a/segmentation_models_pytorch/decoders/deeplabv3/model.py +++ b/segmentation_models_pytorch/decoders/deeplabv3/model.py @@ -8,6 +8,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import DeepLabV3Decoder, DeepLabV3PlusDecoder @@ -54,6 +55,7 @@ class DeepLabV3(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", @@ -163,6 +165,7 @@ class DeepLabV3Plus(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/fpn/model.py b/segmentation_models_pytorch/decoders/fpn/model.py index 373269c5..7420b289 100644 --- a/segmentation_models_pytorch/decoders/fpn/model.py +++ b/segmentation_models_pytorch/decoders/fpn/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import FPNDecoder @@ -51,6 +52,7 @@ class FPN(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/linknet/model.py b/segmentation_models_pytorch/decoders/linknet/model.py index 708ea562..356468ed 100644 --- a/segmentation_models_pytorch/decoders/linknet/model.py +++ b/segmentation_models_pytorch/decoders/linknet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import LinknetDecoder @@ -53,6 +54,7 @@ class Linknet(SegmentationModel): https://arxiv.org/abs/1707.03718 """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/manet/model.py b/segmentation_models_pytorch/decoders/manet/model.py index 6651dee6..6ed59207 100644 --- a/segmentation_models_pytorch/decoders/manet/model.py +++ b/segmentation_models_pytorch/decoders/manet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import MAnetDecoder @@ -56,6 +57,7 @@ class MAnet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/pan/model.py b/segmentation_models_pytorch/decoders/pan/model.py index 712541a5..6d5e78c2 100644 --- a/segmentation_models_pytorch/decoders/pan/model.py +++ b/segmentation_models_pytorch/decoders/pan/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PANDecoder @@ -53,6 +54,7 @@ class PAN(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/pspnet/model.py b/segmentation_models_pytorch/decoders/pspnet/model.py index dbf04ea4..8b99b3da 100644 --- a/segmentation_models_pytorch/decoders/pspnet/model.py +++ b/segmentation_models_pytorch/decoders/pspnet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import PSPDecoder @@ -54,6 +55,7 @@ class PSPNet(SegmentationModel): https://arxiv.org/abs/1612.01105 """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/segformer/model.py b/segmentation_models_pytorch/decoders/segformer/model.py index 2a5e8dba..45805de7 100644 --- a/segmentation_models_pytorch/decoders/segformer/model.py +++ b/segmentation_models_pytorch/decoders/segformer/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import SegformerDecoder @@ -46,6 +47,7 @@ class Segformer(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/unet/model.py b/segmentation_models_pytorch/decoders/unet/model.py index 0ac7b5bd..547581eb 100644 --- a/segmentation_models_pytorch/decoders/unet/model.py +++ b/segmentation_models_pytorch/decoders/unet/model.py @@ -1,4 +1,4 @@ -from typing import Any, List, Optional, Union +from typing import Any, Optional, Union, Tuple, Callable from segmentation_models_pytorch.base import ( ClassificationHead, @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UnetDecoder @@ -55,17 +56,18 @@ class Unet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", encoder_depth: int = 5, encoder_weights: Optional[str] = "imagenet", decoder_use_batchnorm: bool = True, - decoder_channels: List[int] = (256, 128, 64, 32, 16), + decoder_channels: Tuple[int, ...] = (256, 128, 64, 32, 16), decoder_attention_type: Optional[str] = None, in_channels: int = 3, classes: int = 1, - activation: Optional[Union[str, callable]] = None, + activation: Optional[Union[str, Callable]] = None, aux_params: Optional[dict] = None, **kwargs: dict[str, Any], ): diff --git a/segmentation_models_pytorch/decoders/unetplusplus/model.py b/segmentation_models_pytorch/decoders/unetplusplus/model.py index 9ba72321..9d4a1e35 100644 --- a/segmentation_models_pytorch/decoders/unetplusplus/model.py +++ b/segmentation_models_pytorch/decoders/unetplusplus/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UnetPlusPlusDecoder @@ -55,6 +56,7 @@ class UnetPlusPlus(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", diff --git a/segmentation_models_pytorch/decoders/upernet/model.py b/segmentation_models_pytorch/decoders/upernet/model.py index de30a7bb..076ed2de 100644 --- a/segmentation_models_pytorch/decoders/upernet/model.py +++ b/segmentation_models_pytorch/decoders/upernet/model.py @@ -6,6 +6,7 @@ SegmentationModel, ) from segmentation_models_pytorch.encoders import get_encoder +from segmentation_models_pytorch.base.hub_mixin import supports_config_loading from .decoder import UPerNetDecoder @@ -47,6 +48,7 @@ class UPerNet(SegmentationModel): """ + @supports_config_loading def __init__( self, encoder_name: str = "resnet34", From d4b82b385167b0c80f260ea923c996732f2ea5d3 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:28:20 +0000 Subject: [PATCH 29/50] Bump min hf-hub to 0.25.0 --- requirements/minimum.old | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/minimum.old b/requirements/minimum.old index 3f687871..12b21bb6 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -1,5 +1,5 @@ efficientnet-pytorch==0.6.1 -huggingface-hub==0.24.0 +huggingface-hub==0.25.0 numpy==1.19.3 pillow==8.0.0 pretrainedmodels==0.7.1 From 634640512cacb7e99257ec03e65ad5e5fbf86283 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 12:40:07 +0000 Subject: [PATCH 30/50] Fix minimal --- segmentation_models_pytorch/base/hub_mixin.py | 1 - tests/config.py | 5 +++++ tests/encoders/test_timm_universal.py | 5 ++--- tests/models/base.py | 17 +++++++++++++++-- 4 files changed, 22 insertions(+), 6 deletions(-) diff --git a/segmentation_models_pytorch/base/hub_mixin.py b/segmentation_models_pytorch/base/hub_mixin.py index affb7079..360aa521 100644 --- a/segmentation_models_pytorch/base/hub_mixin.py +++ b/segmentation_models_pytorch/base/hub_mixin.py @@ -147,4 +147,3 @@ def wrapper(self, *args, **kwargs): return func(self, *args, **kwargs) return wrapper -from segmentation_models_pytorch.base.hub_mixin import supports_config_loading \ No newline at end of file diff --git a/tests/config.py b/tests/config.py index d53eebad..f4b342b9 100644 --- a/tests/config.py +++ b/tests/config.py @@ -1,4 +1,9 @@ import os +import timm +from packaging.version import Version + + +has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") def get_commit_message(): diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index 72adaf05..fdd93e78 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -1,6 +1,5 @@ -import timm from tests.encoders import base -from packaging.version import Version +from tests.config import has_timm_test_models # check if timm >= 1.0.12 timm_encoders = [ @@ -9,7 +8,7 @@ "tu-darknet17", # for timm universal vgg-like encoder ] -if Version(timm.__version__) >= Version("1.0.12"): +if has_timm_test_models: timm_encoders.append("tu-test_resnet.r160_in1k") diff --git a/tests/models/base.py b/tests/models/base.py index f9eb9a52..7634040f 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -1,3 +1,4 @@ +import os import inspect import tempfile import unittest @@ -6,9 +7,13 @@ import torch import segmentation_models_pytorch as smp +from tests.config import has_timm_test_models + class BaseModelTester(unittest.TestCase): - test_encoder_name = "tu-test_resnet.r160_in1k" + test_encoder_name = ( + "tu-test_resnet.r160_in1k" if has_timm_test_models else "resnet18" + ) # should be overriden test_model_type = None @@ -136,8 +141,12 @@ def test_save_load_with_hub_mixin(self): # save model with tempfile.TemporaryDirectory() as tmpdir: - model.save_pretrained(tmpdir) + model.save_pretrained( + tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99} + ) restored_model = smp.from_pretrained(tmpdir) + with open(os.path.join(tmpdir, "README.md"), "r") as f: + readme = f.read() # check inference is correct sample = self._get_sample( @@ -153,3 +162,7 @@ def test_save_load_with_hub_mixin(self): self.assertEqual(output.shape, restored_output.shape) self.assertEqual(output.shape[1], 1) + + # check dataset and metrics are saved in readme + self.assertIn("test_dataset", readme) + self.assertIn("my_awesome_metric", readme) From 8d06cba886ac32209b0eafc24b2dfb2ba0673833 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:40:05 +0000 Subject: [PATCH 31/50] Add test with hub checkpoint --- misc/generate_test_models.py | 41 +++++++++++++++++++ .../test_pretrainedmodels_encoders.py | 2 +- tests/encoders/test_smp_encoders.py | 2 +- tests/encoders/test_timm_ported_encoders.py | 2 +- tests/encoders/test_timm_universal.py | 2 +- tests/encoders/test_torchvision_encoders.py | 2 +- tests/models/base.py | 30 +++++++++++++- tests/models/test_segformer.py | 34 +++++++++++++++ tests/{config.py => utils.py} | 23 +++++++++++ 9 files changed, 132 insertions(+), 6 deletions(-) create mode 100644 misc/generate_test_models.py rename tests/{config.py => utils.py} (59%) diff --git a/misc/generate_test_models.py b/misc/generate_test_models.py new file mode 100644 index 00000000..61d6bfd0 --- /dev/null +++ b/misc/generate_test_models.py @@ -0,0 +1,41 @@ +import os +import torch +import tempfile +import huggingface_hub +import segmentation_models_pytorch as smp + +HUB_REPO = "smp-test-models" +ENCODER_NAME = "tu-resnet18" + +api = huggingface_hub.HfApi(token=os.getenv("HF_TOKEN")) + +for model_name, model_class in smp.MODEL_ARCHITECTURES_MAPPING.items(): + model = model_class(encoder_name=ENCODER_NAME) + model = model.eval() + + # generate test sample + torch.manual_seed(423553) + sample = torch.rand(1, 3, 256, 256) + + with torch.no_grad(): + output = model(sample) + + with tempfile.TemporaryDirectory() as tmpdir: + # save model + model.save_pretrained(f"{tmpdir}") + + # save input and output + torch.save(sample, f"{tmpdir}/input-tensor.pth") + torch.save(output, f"{tmpdir}/output-tensor.pth") + + # create repo + repo_id = f"{HUB_REPO}/{model_name}-{ENCODER_NAME}" + if not api.repo_exists(repo_id=repo_id): + api.create_repo(repo_id=repo_id, repo_type="model") + + # upload to hub + api.upload_folder( + folder_path=tmpdir, + repo_id=f"{HUB_REPO}/{model_name}-{ENCODER_NAME}", + repo_type="model", + ) diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py index b1bf1e14..bbde576c 100644 --- a/tests/encoders/test_pretrainedmodels_encoders.py +++ b/tests/encoders/test_pretrainedmodels_encoders.py @@ -1,5 +1,5 @@ from tests.encoders import base -from tests.config import RUN_ALL_ENCODERS +from tests.utils import RUN_ALL_ENCODERS class TestDenseNetEncoder(base.BaseEncoderTester): diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py index c248f1cb..863537bf 100644 --- a/tests/encoders/test_smp_encoders.py +++ b/tests/encoders/test_smp_encoders.py @@ -1,5 +1,5 @@ from tests.encoders import base -from tests.config import RUN_ALL_ENCODERS +from tests.utils import RUN_ALL_ENCODERS class TestMobileoneEncoder(base.BaseEncoderTester): diff --git a/tests/encoders/test_timm_ported_encoders.py b/tests/encoders/test_timm_ported_encoders.py index d24f22ad..b467c968 100644 --- a/tests/encoders/test_timm_ported_encoders.py +++ b/tests/encoders/test_timm_ported_encoders.py @@ -1,5 +1,5 @@ from tests.encoders import base -from tests.config import RUN_ALL_ENCODERS +from tests.utils import RUN_ALL_ENCODERS class TestTimmEfficientNetEncoder(base.BaseEncoderTester): diff --git a/tests/encoders/test_timm_universal.py b/tests/encoders/test_timm_universal.py index fdd93e78..753ee4de 100644 --- a/tests/encoders/test_timm_universal.py +++ b/tests/encoders/test_timm_universal.py @@ -1,5 +1,5 @@ from tests.encoders import base -from tests.config import has_timm_test_models +from tests.utils import has_timm_test_models # check if timm >= 1.0.12 timm_encoders = [ diff --git a/tests/encoders/test_torchvision_encoders.py b/tests/encoders/test_torchvision_encoders.py index 893a17e2..99b8b9d5 100644 --- a/tests/encoders/test_torchvision_encoders.py +++ b/tests/encoders/test_torchvision_encoders.py @@ -1,5 +1,5 @@ from tests.encoders import base -from tests.config import RUN_ALL_ENCODERS +from tests.utils import RUN_ALL_ENCODERS class TestMobileoneEncoder(base.BaseEncoderTester): diff --git a/tests/models/base.py b/tests/models/base.py index 7634040f..2521b9ae 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -7,7 +7,7 @@ import torch import segmentation_models_pytorch as smp -from tests.config import has_timm_test_models +from tests.utils import has_timm_test_models, slow_test class BaseModelTester(unittest.TestCase): @@ -30,6 +30,10 @@ def model_type(self): raise ValueError("test_model_type is not set") return self.test_model_type + @property + def hub_checkpoint(self): + return f"smp-test-models/{self.model_type}-tu-resnet18" + @property def model_class(self): return smp.MODEL_ARCHITECTURES_MAPPING[self.model_type] @@ -166,3 +170,27 @@ def test_save_load_with_hub_mixin(self): # check dataset and metrics are saved in readme self.assertIn("test_dataset", readme) self.assertIn("my_awesome_metric", readme) + + @slow_test + def test_preserve_forward_output(self): + from huggingface_hub import hf_hub_download + + model = smp.from_pretrained(self.hub_checkpoint).eval() + + input_tensor_path = hf_hub_download( + repo_id=self.hub_checkpoint, filename="input-tensor.pth" + ) + output_tensor_path = hf_hub_download( + repo_id=self.hub_checkpoint, filename="output-tensor.pth" + ) + + input_tensor = torch.load(input_tensor_path, weights_only=True) + output_tensor = torch.load(output_tensor_path, weights_only=True) + + with torch.no_grad(): + output = model(input_tensor) + + self.assertEqual(output.shape, output_tensor.shape) + is_close = torch.allclose(output, output_tensor, atol=1e-3) + max_diff = torch.max(torch.abs(output - output_tensor)) + self.assertTrue(is_close, f"Max diff: {max_diff}") diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3317acd2..3d073763 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -1,7 +1,41 @@ +import torch import pytest +import segmentation_models_pytorch as smp + from tests.models import base +from tests.utils import slow_test, default_device @pytest.mark.segformer class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" + + @slow_test + def test_load_pretrained(self): + hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" + + model = smp.from_pretrained(hub_checkpoint) + model = model.eval().to(default_device) + + sample = torch.ones([1, 3, 512, 512]).to(default_device) + + with torch.no_grad(): + output = model(sample) + + self.assertEqual(output.shape, (1, 150, 512, 512)) + + expected_logits_slice = torch.tensor( + [-4.4172, -4.4723, -4.5273, -4.5824, -4.6375, -4.7157] + ) + resulted_logits_slice = output[0, 0, 256, :6].cpu() + is_equal = torch.allclose( + expected_logits_slice, resulted_logits_slice, atol=1e-2 + ) + max_diff = torch.max(torch.abs(expected_logits_slice - resulted_logits_slice)) + self.assertTrue( + is_equal, + f"Expected logits slice and resulted logits slice are not equal.\n" + f"Max diff: {max_diff}\n" + f"Expected: {expected_logits_slice}\n" + f"Resulted: {resulted_logits_slice}\n", + ) diff --git a/tests/config.py b/tests/utils.py similarity index 59% rename from tests/config.py rename to tests/utils.py index f4b342b9..ed418ffd 100644 --- a/tests/config.py +++ b/tests/utils.py @@ -1,9 +1,13 @@ import os import timm +import torch +import unittest + from packaging.version import Version has_timm_test_models = Version(timm.__version__) >= Version("1.0.12") +default_device = "cuda" if torch.cuda.is_available() else "cpu" def get_commit_message(): @@ -28,3 +32,22 @@ def get_commit_message(): os.getenv("RUN_SLOW", "false").lower() in ["true", "1", "y", "yes"] or "run-slow" in commit_message ) + + +def slow_test(test_case): + """ + Decorator marking a test as slow. + + Slow tests are skipped by default. Set the RUN_SLOW environment variable to a truthy value to run them. + + """ + return unittest.skipUnless(RUN_SLOW, "test is slow")(test_case) + + +def requires_torch_greater_or_equal(version: str): + torch_version = Version(torch.__version__) + provided_version = Version(version) + return unittest.skipUnless( + torch_version >= provided_version, + f"torch version {torch_version} is less than {provided_version}", + ) From a83a5e95f1717750aee9ea081a9e6261f361f086 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:42:56 +0000 Subject: [PATCH 32/50] Fixing minimum --- requirements/minimum.old | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/minimum.old b/requirements/minimum.old index 12b21bb6..16293f52 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -1,5 +1,5 @@ efficientnet-pytorch==0.6.1 -huggingface-hub==0.25.0 +huggingface-hub==0.24.0 numpy==1.19.3 pillow==8.0.0 pretrainedmodels==0.7.1 @@ -8,3 +8,4 @@ timm==0.9.0 torch==1.9.0 torchvision==0.10.0 tqdm==4.42.1 +Jinja2=3.0.0 From 10993bbbece08e93dfbe1d27af7ba05d7086e7d0 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:44:23 +0000 Subject: [PATCH 33/50] Fix --- requirements/minimum.old | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements/minimum.old b/requirements/minimum.old index 16293f52..1080bdb4 100644 --- a/requirements/minimum.old +++ b/requirements/minimum.old @@ -8,4 +8,4 @@ timm==0.9.0 torch==1.9.0 torchvision==0.10.0 tqdm==4.42.1 -Jinja2=3.0.0 +Jinja2==3.0.0 From 444673048efdbc072e3de83c309c81899cc4b981 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:47:12 +0000 Subject: [PATCH 34/50] Fix torch for minimum tests --- tests/models/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index 2521b9ae..4006fa28 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -7,7 +7,7 @@ import torch import segmentation_models_pytorch as smp -from tests.utils import has_timm_test_models, slow_test +from tests.utils import has_timm_test_models, slow_test, requires_torch_greater_or_equal class BaseModelTester(unittest.TestCase): @@ -172,6 +172,7 @@ def test_save_load_with_hub_mixin(self): self.assertIn("my_awesome_metric", readme) @slow_test + @requires_torch_greater_or_equal("2.0.0") def test_preserve_forward_output(self): from huggingface_hub import hf_hub_download From 914fc6ee73e71fce04d8d9439812549c02f78e68 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:52:30 +0000 Subject: [PATCH 35/50] Update torch version and run-slow --- tests/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index 4006fa28..15ca6751 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -172,7 +172,7 @@ def test_save_load_with_hub_mixin(self): self.assertIn("my_awesome_metric", readme) @slow_test - @requires_torch_greater_or_equal("2.0.0") + @requires_torch_greater_or_equal("2.0.1") def test_preserve_forward_output(self): from huggingface_hub import hf_hub_download From 3eb990061001147c7982ac4e942831fef8250c2c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 13:59:49 +0000 Subject: [PATCH 36/50] run-slow --- .github/workflows/tests.yml | 2 ++ tests/utils.py | 8 +------- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 83cb0edc..88fba586 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -48,6 +48,8 @@ jobs: run: | python -m pip list - name: Test with pytest + env: + COMMIT_MESSAGE: ${{ github.event.head_commit.message }} run: pytest -v -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 diff --git a/tests/utils.py b/tests/utils.py index ed418ffd..e8bce88e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -11,13 +11,7 @@ def get_commit_message(): - # Get commit message from CI environment variables - # Common CI systems store commit messages in different env vars - commit_msg = os.getenv("COMMIT_MESSAGE", "") # Generic - if not commit_msg: - commit_msg = os.getenv("CI_COMMIT_MESSAGE", "") # GitLab CI - if not commit_msg: - commit_msg = os.getenv("GITHUB_EVENT_HEAD_COMMIT_MESSAGE", "") # GitHub Actions + commit_msg = os.getenv("COMMIT_MESSAGE", "") return commit_msg.lower() From 15b0658862466a614266e71d6d3a943a9d4a8c9c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:06:22 +0000 Subject: [PATCH 37/50] Show skipped --- .github/workflows/tests.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 88fba586..61874416 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -50,7 +50,7 @@ jobs: - name: Test with pytest env: COMMIT_MESSAGE: ${{ github.event.head_commit.message }} - run: pytest -v -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml + run: pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: @@ -74,4 +74,4 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -n 2 + run: pytest -v -rsx -n 2 From 238dd11631dc466aaaf27728ae0a5f6da456eaa2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:13:18 +0000 Subject: [PATCH 38/50] [run-slow] Fixing minimum --- tests/models/base.py | 3 ++- tests/models/test_segformer.py | 3 ++- tests/utils.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 15ca6751..739f40eb 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -137,6 +137,7 @@ def test_classification_head(self): self.assertEqual(cls_probs.shape[1], 10) + @requires_torch_greater_or_equal("2.0.0") def test_save_load_with_hub_mixin(self): # instantiate model model = smp.create_model( @@ -172,7 +173,7 @@ def test_save_load_with_hub_mixin(self): self.assertIn("my_awesome_metric", readme) @slow_test - @requires_torch_greater_or_equal("2.0.1") + @requires_torch_greater_or_equal("2.0.0") def test_preserve_forward_output(self): from huggingface_hub import hf_hub_download diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index 3d073763..f59a0fcc 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -3,7 +3,7 @@ import segmentation_models_pytorch as smp from tests.models import base -from tests.utils import slow_test, default_device +from tests.utils import slow_test, default_device, requires_torch_greater_or_equal @pytest.mark.segformer @@ -11,6 +11,7 @@ class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" @slow_test + @requires_torch_greater_or_equal("2.0.0") def test_load_pretrained(self): hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" diff --git a/tests/utils.py b/tests/utils.py index e8bce88e..e87874f3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,6 @@ def requires_torch_greater_or_equal(version: str): torch_version = Version(torch.__version__) provided_version = Version(version) return unittest.skipUnless( - torch_version >= provided_version, + torch_version < provided_version, f"torch version {torch_version} is less than {provided_version}", ) From 63ba1c4a7efe51b7a5f317324449fafd3e1b7aab Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:15:11 +0000 Subject: [PATCH 39/50] [run-slow] Fixing minimum --- tests/models/base.py | 4 ++-- tests/models/test_segformer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/models/base.py b/tests/models/base.py index 739f40eb..9fd23902 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -137,7 +137,7 @@ def test_classification_head(self): self.assertEqual(cls_probs.shape[1], 10) - @requires_torch_greater_or_equal("2.0.0") + @requires_torch_greater_or_equal("2.0.1") def test_save_load_with_hub_mixin(self): # instantiate model model = smp.create_model( @@ -173,7 +173,7 @@ def test_save_load_with_hub_mixin(self): self.assertIn("my_awesome_metric", readme) @slow_test - @requires_torch_greater_or_equal("2.0.0") + @requires_torch_greater_or_equal("2.0.1") def test_preserve_forward_output(self): from huggingface_hub import hf_hub_download diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index f59a0fcc..abe0cae3 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -11,7 +11,7 @@ class TestSegformerModel(base.BaseModelTester): test_model_type = "segformer" @slow_test - @requires_torch_greater_or_equal("2.0.0") + @requires_torch_greater_or_equal("2.0.1") def test_load_pretrained(self): hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" From 81818885d268fa37b458b038688765c7ef06a7db Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:16:12 +0000 Subject: [PATCH 40/50] Fix decorator --- tests/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utils.py b/tests/utils.py index e87874f3..e8bce88e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -42,6 +42,6 @@ def requires_torch_greater_or_equal(version: str): torch_version = Version(torch.__version__) provided_version = Version(version) return unittest.skipUnless( - torch_version < provided_version, + torch_version >= provided_version, f"torch version {torch_version} is less than {provided_version}", ) From e41d480750bcbc2bb16b94f8ae57b17d18e7ddb2 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:16:50 +0000 Subject: [PATCH 41/50] Raise error --- tests/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utils.py b/tests/utils.py index e8bce88e..1e989e36 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,6 +12,7 @@ def get_commit_message(): commit_msg = os.getenv("COMMIT_MESSAGE", "") + raise ValueError(commit_msg) return commit_msg.lower() From e1caeb9c9ad5bd78dd51000e146a7ea76df1cb70 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:25:15 +0000 Subject: [PATCH 42/50] [run-slow] Fixing run slow --- .github/workflows/tests.yml | 17 ++++++++++++++--- tests/utils.py | 1 - 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 61874416..aa3f9ce1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,21 +36,32 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install dependencies run: | python -m pip install uv python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt + - name: Show installed packages run: | python -m pip list - - name: Test with pytest + + - name: Fetch latest commit message + id: get_commit_message + run: echo "COMMIT_MESSAGE=$(git log -1 --pretty=%B)" >> $GITHUB_ENV + + - name: Test with PyTest env: - COMMIT_MESSAGE: ${{ github.event.head_commit.message }} - run: pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml + COMMIT_MESSAGE: ${{ env.COMMIT_MESSAGE }} + run: | + echo "COMMIT_MESSAGE: ${{ env.COMMIT_MESSAGE }}" + pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 with: diff --git a/tests/utils.py b/tests/utils.py index 1e989e36..e8bce88e 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -12,7 +12,6 @@ def get_commit_message(): commit_msg = os.getenv("COMMIT_MESSAGE", "") - raise ValueError(commit_msg) return commit_msg.lower() From 101b48720954602a24f3f6e98772f6b2aeaca15c Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 14:31:38 +0000 Subject: [PATCH 43/50] [run-slow] Fixing run slow --- .github/workflows/tests.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aa3f9ce1..aa484500 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,6 +36,8 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 + with: + fetch-depth: 1 # Fetch the latest commit to inspect its message - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -53,7 +55,7 @@ jobs: - name: Fetch latest commit message id: get_commit_message - run: echo "COMMIT_MESSAGE=$(git log -1 --pretty=%B)" >> $GITHUB_ENV + run: echo "COMMIT_MESSAGE=$(git log -2 --pretty=%B)" >> $GITHUB_ENV - name: Test with PyTest env: From 13563a6084a813f59a57cbf7975a2136b2f39a0b Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:02:39 +0000 Subject: [PATCH 44/50] Run slow tests in separate job --- .github/workflows/tests.yml | 30 ++++++++++++++++++------------ pyproject.toml | 1 + tests/models/base.py | 5 +++-- tests/models/test_segformer.py | 1 + 4 files changed, 23 insertions(+), 14 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index aa484500..a829dbcd 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -36,8 +36,6 @@ jobs: runs-on: ${{ matrix.os }} steps: - uses: actions/checkout@v4 - with: - fetch-depth: 1 # Fetch the latest commit to inspect its message - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 @@ -53,16 +51,9 @@ jobs: run: | python -m pip list - - name: Fetch latest commit message - id: get_commit_message - run: echo "COMMIT_MESSAGE=$(git log -2 --pretty=%B)" >> $GITHUB_ENV - - name: Test with PyTest - env: - COMMIT_MESSAGE: ${{ env.COMMIT_MESSAGE }} run: | - echo "COMMIT_MESSAGE: ${{ env.COMMIT_MESSAGE }}" - pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml + pytest -v -rsx -n 2 --cov=segmentation_models_pytorch --cov-report=xml --cov-config=pyproject.toml -k "not logits_match" - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v5 @@ -71,11 +62,26 @@ jobs: slug: qubvel-org/segmentation_models.pytorch if: matrix.os == 'macos-latest' && matrix.python-version == '3.12' + test_logits_match: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.9" + - name: Install dependencies + run: | + python -m pip install uv + python -m uv pip install --system -r requirements/required.txt -r requirements/test.txt + - name: Test with PyTest + run: RUN_SLOW=1 pytest -v -rsx -n 2 -k "logits_match" + minimum: runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} + - name: Set up Python uses: actions/setup-python@v5 with: python-version: "3.9" @@ -87,4 +93,4 @@ jobs: run: | python -m pip list - name: Test with pytest - run: pytest -v -rsx -n 2 + run: pytest -v -rsx -n 2 -k "not logits_match" diff --git a/pyproject.toml b/pyproject.toml index bb4a107d..05259ec4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,6 +69,7 @@ markers = [ "unet", "unetplusplus", "upernet", + "logits_match", ] [tool.coverage.run] diff --git a/tests/models/base.py b/tests/models/base.py index 9fd23902..32ff16f4 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -1,8 +1,10 @@ import os +import pytest import inspect import tempfile import unittest from functools import lru_cache +from huggingface_hub import hf_hub_download import torch import segmentation_models_pytorch as smp @@ -174,9 +176,8 @@ def test_save_load_with_hub_mixin(self): @slow_test @requires_torch_greater_or_equal("2.0.1") + @pytest.mark.logits_match def test_preserve_forward_output(self): - from huggingface_hub import hf_hub_download - model = smp.from_pretrained(self.hub_checkpoint).eval() input_tensor_path = hf_hub_download( diff --git a/tests/models/test_segformer.py b/tests/models/test_segformer.py index abe0cae3..3ca5016c 100644 --- a/tests/models/test_segformer.py +++ b/tests/models/test_segformer.py @@ -12,6 +12,7 @@ class TestSegformerModel(base.BaseModelTester): @slow_test @requires_torch_greater_or_equal("2.0.1") + @pytest.mark.logits_match def test_load_pretrained(self): hub_checkpoint = "smp-hub/segformer-b0-512x512-ade-160k" From d4d85083e7ae72c5d4208897f7a82a327c41f1ed Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:03:51 +0000 Subject: [PATCH 45/50] FIx --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a829dbcd..e9b34d73 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,7 +69,7 @@ jobs: - name: Set up Python uses: actions/setup-python@v5 with: - python-version: "3.9" + python-version: "3.10" - name: Install dependencies run: | python -m pip install uv From a17bb6ad247664f2560bb96023d7dfc26b9db325 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:09:22 +0000 Subject: [PATCH 46/50] Fixes --- Makefile | 5 +- tests/test_models.py | 194 ------------------------------------------- 2 files changed, 4 insertions(+), 195 deletions(-) delete mode 100644 tests/test_models.py diff --git a/Makefile b/Makefile index 9e974026..cacb22ad 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,10 @@ install_dev: .venv .venv/bin/pip install -e ".[test]" test: .venv - .venv/bin/pytest -p no:cacheprovider tests/ + .venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match" + +test_all: .venv + .venv/bin/pytest -v -rsx -n 2 tests/ table: .venv/bin/python misc/generate_table.py diff --git a/tests/test_models.py b/tests/test_models.py deleted file mode 100644 index c2b71585..00000000 --- a/tests/test_models.py +++ /dev/null @@ -1,194 +0,0 @@ -# import pytest -# import torch - -# import segmentation_models_pytorch as smp # noqa - - -# def get_encoders(): -# exclude_encoders = [ -# "senet154", -# "resnext101_32x16d", -# "resnext101_32x32d", -# "resnext101_32x48d", -# ] -# encoders = smp.encoders.get_encoder_names() -# encoders = [e for e in encoders if e not in exclude_encoders] -# encoders.append("tu-resnet34") # for timm universal traditional-like encoder -# encoders.append("tu-convnext_atto") # for timm universal transformer-like encoder -# encoders.append("tu-darknet17") # for timm universal vgg-like encoder -# return encoders - - -# ENCODERS = get_encoders() -# DEFAULT_ENCODER = "resnet18" - - -# def get_sample(model_class): -# if model_class in [ -# smp.FPN, -# smp.Linknet, -# smp.Unet, -# smp.UnetPlusPlus, -# smp.MAnet, -# smp.Segformer, -# ]: -# sample = torch.ones([1, 3, 64, 64]) -# elif model_class in [smp.PAN, smp.DeepLabV3, smp.DeepLabV3Plus]: -# sample = torch.ones([2, 3, 128, 128]) -# elif model_class in [smp.PSPNet, smp.UPerNet]: -# # Batch size 2 needed due to nn.BatchNorm2d not supporting (1, C, 1, 1) input -# # from PSPModule pooling in PSPNet/UPerNet. -# sample = torch.ones([2, 3, 64, 64]) -# else: -# raise ValueError("Not supported model class {}".format(model_class)) -# return sample - - -# def _test_forward(model, sample, test_shape=False): -# with torch.no_grad(): -# out = model(sample) -# if test_shape: -# assert out.shape[2:] == sample.shape[2:] - - -# def _test_forward_backward(model, sample, test_shape=False): -# out = model(sample) -# out.mean().backward() -# if test_shape: -# assert out.shape[2:] == sample.shape[2:] - - -# @pytest.mark.parametrize("encoder_name", ENCODERS) -# @pytest.mark.parametrize("encoder_depth", [3, 5]) -# @pytest.mark.parametrize( -# "model_class", -# [ -# smp.FPN, -# smp.PSPNet, -# smp.Linknet, -# smp.Unet, -# smp.UnetPlusPlus, -# smp.MAnet, -# smp.UPerNet, -# smp.Segformer, -# ], -# ) -# def test_forward(model_class, encoder_name, encoder_depth, **kwargs): -# if ( -# model_class is smp.Unet -# or model_class is smp.UnetPlusPlus -# or model_class is smp.MAnet -# ): -# kwargs["decoder_channels"] = (16, 16, 16, 16, 16)[-encoder_depth:] -# if model_class in [smp.UnetPlusPlus, smp.Linknet]: -# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): -# return # skip transformer-like model* -# if model_class is smp.FPN and encoder_depth != 5: -# if encoder_name.startswith("mit_b") or encoder_name.startswith("tu-convnext"): -# return # skip transformer-like model* -# model = model_class( -# encoder_name, encoder_depth=encoder_depth, encoder_weights=None, **kwargs -# ) -# sample = get_sample(model_class) -# model.eval() -# if encoder_depth == 5 and model_class != smp.PSPNet: -# test_shape = True -# else: -# test_shape = False - -# _test_forward(model, sample, test_shape) - - -# @pytest.mark.parametrize( -# "model_class", -# [ -# smp.PAN, -# smp.FPN, -# smp.PSPNet, -# smp.Linknet, -# smp.Unet, -# smp.UnetPlusPlus, -# smp.MAnet, -# smp.DeepLabV3, -# smp.DeepLabV3Plus, -# smp.UPerNet, -# smp.Segformer, -# ], -# ) -# def test_forward_backward(model_class): -# sample = get_sample(model_class) -# model = model_class(DEFAULT_ENCODER, encoder_weights=None) -# _test_forward_backward(model, sample) - - -# @pytest.mark.parametrize( -# "model_class", -# [ -# smp.PAN, -# smp.FPN, -# smp.PSPNet, -# smp.Linknet, -# smp.Unet, -# smp.UnetPlusPlus, -# smp.MAnet, -# smp.DeepLabV3, -# smp.DeepLabV3Plus, -# smp.UPerNet, -# smp.Segformer, -# ], -# ) -# def test_aux_output(model_class): -# model = model_class( -# DEFAULT_ENCODER, encoder_weights=None, aux_params=dict(classes=2) -# ) -# sample = get_sample(model_class) -# label_size = (sample.shape[0], 2) -# mask, label = model(sample) -# assert label.size() == label_size - - -# @pytest.mark.parametrize("upsampling", [2, 4, 8]) -# @pytest.mark.parametrize("model_class", [smp.FPN, smp.PSPNet]) -# def test_upsample(model_class, upsampling): -# default_upsampling = 4 if model_class is smp.FPN else 8 -# model = model_class(DEFAULT_ENCODER, encoder_weights=None, upsampling=upsampling) -# sample = get_sample(model_class) -# mask = model(sample) -# assert mask.size()[-1] / 64 == upsampling / default_upsampling - - -# @pytest.mark.parametrize("model_class", [smp.FPN]) -# @pytest.mark.parametrize("in_channels", [1, 2, 4]) -# def test_in_channels(model_class, in_channels): -# sample = torch.ones([1, in_channels, 64, 64]) -# model = model_class(DEFAULT_ENCODER, encoder_weights=None, in_channels=in_channels) -# model.eval() -# with torch.no_grad(): -# model(sample) - -# assert model.encoder._in_channels == in_channels - - -# @pytest.mark.parametrize("encoder_name", ENCODERS) -# def test_dilation(encoder_name): -# if ( -# encoder_name in ["inceptionresnetv2", "xception", "inceptionv4"] -# or encoder_name.startswith("vgg") -# or encoder_name.startswith("densenet") -# or encoder_name.startswith("timm-res") -# ): -# return - -# encoder = smp.encoders.get_encoder(encoder_name, output_stride=16) - -# encoder.eval() -# with torch.no_grad(): -# sample = torch.ones([1, 3, 64, 64]) -# output = encoder(sample) - -# shapes = [out.shape[-1] for out in output] -# assert shapes == [64, 32, 16, 8, 4, 4] # last downsampling replaced with dilation - - -# if __name__ == "__main__": -# pytest.main([__file__]) From 58afe1b3b3abc61f84f99c604cc1819d4995af71 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:14:58 +0000 Subject: [PATCH 47/50] Add device --- Makefile | 2 +- pyproject.toml | 2 ++ tests/models/base.py | 31 +++++++++++++++++++------------ 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/Makefile b/Makefile index cacb22ad..a58d230f 100644 --- a/Makefile +++ b/Makefile @@ -10,7 +10,7 @@ test: .venv .venv/bin/pytest -v -rsx -n 2 tests/ -k "not logits_match" test_all: .venv - .venv/bin/pytest -v -rsx -n 2 tests/ + RUN_SLOW=1 .venv/bin/pytest -v -rsx -n 2 tests/ table: .venv/bin/python misc/generate_table.py diff --git a/pyproject.toml b/pyproject.toml index 05259ec4..5df18bc7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,6 +40,8 @@ docs = [ ] test = [ 'pytest', + 'pytest-cov', + 'pytest-xdist', 'ruff', ] diff --git a/tests/models/base.py b/tests/models/base.py index 32ff16f4..8b224ed1 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -9,7 +9,12 @@ import torch import segmentation_models_pytorch as smp -from tests.utils import has_timm_test_models, slow_test, requires_torch_greater_or_equal +from tests.utils import ( + has_timm_test_models, + default_device, + slow_test, + requires_torch_greater_or_equal, +) class BaseModelTester(unittest.TestCase): @@ -58,10 +63,10 @@ def test_forward_backward(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) model = smp.create_model( arch=self.model_type, encoder_name=self.test_encoder_name - ) + ).to(default_device) # check default in_channels=3 output = model(sample) @@ -93,13 +98,13 @@ def test_in_channels_and_depth_and_out_classes( in_channels=in_channels, classes=classes, **kwargs, - ) + ).to(default_device) sample = self._get_sample( batch_size=self.default_batch_size, num_channels=in_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) # check in channels correctly set with torch.no_grad(): @@ -117,7 +122,7 @@ def test_classification_head(self): "dropout": 0.5, "activation": "sigmoid", }, - ) + ).to(default_device) self.assertIsNotNone(model.classification_head) self.assertIsInstance(model.classification_head[0], torch.nn.AdaptiveAvgPool2d) @@ -132,7 +137,7 @@ def test_classification_head(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) with torch.no_grad(): _, cls_probs = model(sample) @@ -144,14 +149,14 @@ def test_save_load_with_hub_mixin(self): # instantiate model model = smp.create_model( arch=self.model_type, encoder_name=self.test_encoder_name - ) + ).to(default_device) # save model with tempfile.TemporaryDirectory() as tmpdir: model.save_pretrained( tmpdir, dataset="test_dataset", metrics={"my_awesome_metric": 0.99} ) - restored_model = smp.from_pretrained(tmpdir) + restored_model = smp.from_pretrained(tmpdir).to(default_device) with open(os.path.join(tmpdir, "README.md"), "r") as f: readme = f.read() @@ -161,7 +166,7 @@ def test_save_load_with_hub_mixin(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) with torch.no_grad(): output = model(sample) @@ -178,7 +183,7 @@ def test_save_load_with_hub_mixin(self): @requires_torch_greater_or_equal("2.0.1") @pytest.mark.logits_match def test_preserve_forward_output(self): - model = smp.from_pretrained(self.hub_checkpoint).eval() + model = smp.from_pretrained(self.hub_checkpoint).eval().to(default_device) input_tensor_path = hf_hub_download( repo_id=self.hub_checkpoint, filename="input-tensor.pth" @@ -188,12 +193,14 @@ def test_preserve_forward_output(self): ) input_tensor = torch.load(input_tensor_path, weights_only=True) + input_tensor = input_tensor.to(default_device) output_tensor = torch.load(output_tensor_path, weights_only=True) + output_tensor = output_tensor.to(default_device) with torch.no_grad(): output = model(input_tensor) self.assertEqual(output.shape, output_tensor.shape) - is_close = torch.allclose(output, output_tensor, atol=1e-3) + is_close = torch.allclose(output, output_tensor, atol=1e-2) max_diff = torch.max(torch.abs(output - output_tensor)) self.assertTrue(is_close, f"Max diff: {max_diff}") From 4d51fac9154cb845e6a583a5c968c00e1cb1e0ab Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:17:38 +0000 Subject: [PATCH 48/50] Bum tolerance --- tests/models/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/base.py b/tests/models/base.py index 8b224ed1..02e17303 100644 --- a/tests/models/base.py +++ b/tests/models/base.py @@ -201,6 +201,6 @@ def test_preserve_forward_output(self): output = model(input_tensor) self.assertEqual(output.shape, output_tensor.shape) - is_close = torch.allclose(output, output_tensor, atol=1e-2) + is_close = torch.allclose(output, output_tensor, atol=5e-2) max_diff = torch.max(torch.abs(output - output_tensor)) self.assertTrue(is_close, f"Max diff: {max_diff}") From aa2cf991f56f11df0312fce63c8ca368e142c912 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:17:45 +0000 Subject: [PATCH 49/50] Add device --- tests/encoders/base.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index ec75d79a..a14cb36b 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -3,7 +3,7 @@ import segmentation_models_pytorch as smp from functools import lru_cache - +from tests.utils import default_device class BaseEncoderTester(unittest.TestCase): encoder_names = [] @@ -40,13 +40,13 @@ def test_forward_backward(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) for encoder_name in self.encoder_names: with self.subTest(encoder_name=encoder_name): # init encoder encoder = smp.encoders.get_encoder( encoder_name, in_channels=3, encoder_weights=None - ) + ).to(default_device) # forward features = encoder.forward(sample) @@ -72,12 +72,12 @@ def test_in_channels(self): num_channels=in_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) with self.subTest(encoder_name=encoder_name, in_channels=in_channels): encoder = smp.encoders.get_encoder( encoder_name, in_channels=in_channels, encoder_weights=None - ) + ).to(default_device) encoder.eval() # forward @@ -90,7 +90,7 @@ def test_depth(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) cases = [ (encoder_name, depth) @@ -105,7 +105,7 @@ def test_depth(self): in_channels=self.default_num_channels, encoder_weights=None, depth=depth, - ) + ).to(default_device) encoder.eval() # forward @@ -154,7 +154,7 @@ def test_dilated(self): num_channels=self.default_num_channels, height=self.default_height, width=self.default_width, - ) + ).to(default_device) cases = [ (encoder_name, stride) @@ -172,7 +172,7 @@ def test_dilated(self): in_channels=self.default_num_channels, encoder_weights=None, output_stride=stride, - ) + ).to(default_device) return for encoder_name, stride in cases: @@ -182,7 +182,7 @@ def test_dilated(self): in_channels=self.default_num_channels, encoder_weights=None, output_stride=stride, - ) + ).to(default_device) encoder.eval() # forward From 3e8d7032f7ab16c82997abd7c06ca7adfa5a6233 Mon Sep 17 00:00:00 2001 From: Pavel Iakubovskii Date: Mon, 23 Dec 2024 15:18:39 +0000 Subject: [PATCH 50/50] Fixup --- tests/encoders/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/encoders/base.py b/tests/encoders/base.py index a14cb36b..39cd4164 100644 --- a/tests/encoders/base.py +++ b/tests/encoders/base.py @@ -5,6 +5,7 @@ from functools import lru_cache from tests.utils import default_device + class BaseEncoderTester(unittest.TestCase): encoder_names = []