diff --git a/tests/encoders/base.py b/tests/encoders/base.py
index c1858cdb..28b12ab8 100644
--- a/tests/encoders/base.py
+++ b/tests/encoders/base.py
@@ -4,7 +4,11 @@
 import segmentation_models_pytorch as smp
 
 from functools import lru_cache
-from tests.utils import default_device, check_run_test_on_diff_or_main
+from tests.utils import (
+    default_device,
+    check_run_test_on_diff_or_main,
+    requires_torch_greater_or_equal,
+)
 
 
 class BaseEncoderTester(unittest.TestCase):
@@ -29,11 +33,19 @@ class BaseEncoderTester(unittest.TestCase):
     depth_to_test = [3, 4, 5]
     strides_to_test = [8, 16]  # 32 is a default one
 
+    # enable/disable tests
+    do_test_torch_compile = True
+    do_test_torch_export = True
+
     def get_tiny_encoder(self):
         return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None)
 
     @lru_cache
-    def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
+    def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
+        batch_size = batch_size or self.default_batch_size
+        num_channels = num_channels or self.default_num_channels
+        height = height or self.default_height
+        width = width or self.default_width
         return torch.rand(batch_size, num_channels, height, width)
 
     def get_features_output_strides(self, sample, features):
@@ -43,12 +55,7 @@ def get_features_output_strides(self, sample, features):
         return height_strides, width_strides
 
     def test_forward_backward(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
         for encoder_name in self.encoder_names:
             with self.subTest(encoder_name=encoder_name):
                 # init encoder
@@ -75,12 +82,7 @@ def test_in_channels(self):
         ]
 
         for encoder_name, in_channels in cases:
-            sample = self._get_sample(
-                batch_size=self.default_batch_size,
-                num_channels=in_channels,
-                height=self.default_height,
-                width=self.default_width,
-            ).to(default_device)
+            sample = self._get_sample(num_channels=in_channels).to(default_device)
 
             with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
                 encoder = smp.encoders.get_encoder(
@@ -93,12 +95,7 @@ def test_in_channels(self):
                 encoder.forward(sample)
 
     def test_depth(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         cases = [
            (encoder_name, depth)
@@ -157,12 +154,7 @@
                 )
 
     def test_dilated(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         cases = [
            (encoder_name, stride)
@@ -216,15 +208,15 @@
 
     @pytest.mark.compile
     def test_compile(self):
+        if not self.do_test_torch_compile:
+            self.skipTest(
+                f"torch_compile test is disabled for {self.encoder_names[0]}."
+            )
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder().eval().to(default_device)
         compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True)
@@ -233,16 +225,15 @@
             compiled_encoder(sample)
 
     @pytest.mark.torch_export
+    @requires_torch_greater_or_equal("2.4.0")
     def test_torch_export(self):
+        if not self.do_test_torch_export:
+            self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.")
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder()
         encoder = encoder.eval().to(default_device)
diff --git a/tests/encoders/test_pretrainedmodels_encoders.py b/tests/encoders/test_pretrainedmodels_encoders.py
index 868f686d..e77c3652 100644
--- a/tests/encoders/test_pretrainedmodels_encoders.py
+++ b/tests/encoders/test_pretrainedmodels_encoders.py
@@ -12,6 +12,10 @@ class TestDPNEncoder(base.BaseEncoderTester):
     )
     files_for_diff = ["encoders/dpn.py"]
 
+    # works with torch 2.4.0, but not with torch 2.5.1
+    # dynamo error, probably on Sequential + OrderedDict
+    do_test_torch_export = False
+
     def get_tiny_encoder(self):
         params = {
             "stage_idxs": (2, 3, 4, 5),
@@ -29,17 +33,21 @@
 
 
 class TestInceptionResNetV2Encoder(base.BaseEncoderTester):
-    supports_dilated = False
     encoder_names = (
         ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"]
     )
     files_for_diff = ["encoders/inceptionresnetv2.py"]
+    supports_dilated = False
 
 
 class TestInceptionV4Encoder(base.BaseEncoderTester):
-    supports_dilated = False
     encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"]
     files_for_diff = ["encoders/inceptionv4.py"]
+    supports_dilated = False
+
+    # works with torch 2.4.0, but not with torch 2.5.1
+    # dynamo error, probably on Sequential + OrderedDict
+    do_test_torch_export = False
 
 
 class TestSeNetEncoder(base.BaseEncoderTester):
diff --git a/tests/encoders/test_smp_encoders.py b/tests/encoders/test_smp_encoders.py
index 876d9266..f65a61b8 100644
--- a/tests/encoders/test_smp_encoders.py
+++ b/tests/encoders/test_smp_encoders.py
@@ -63,5 +63,5 @@ class TestEfficientNetEncoder(base.BaseEncoderTester):
     )
     files_for_diff = ["encoders/efficientnet.py"]
 
-    def test_compile(self):
-        self.skipTest("compile fullgraph is not supported for efficientnet encoders")
+    # torch_compile is not supported for efficientnet encoders
+    do_test_torch_compile = False
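The patch imports and applies a `requires_torch_greater_or_equal("2.4.0")` decorator from `tests.utils`, whose definition is not part of this diff. Below is a minimal sketch of how such a version gate could look, assuming it is built on `unittest.skipUnless` and `packaging.version`; the real helper in `tests/utils.py` may be implemented differently.

```python
# Hypothetical sketch of the version gate applied to test_torch_export above;
# the actual implementation lives in tests/utils.py and is not shown in this diff.
import unittest

import torch
from packaging import version


def requires_torch_greater_or_equal(min_version: str):
    # Compare only the release tuple (e.g. "2.5.1+cu121" -> (2, 5, 1)) so local
    # build suffixes do not affect the comparison, then skip when torch is too old.
    installed = version.parse(torch.__version__).release
    required = version.parse(min_version).release
    return unittest.skipUnless(
        installed >= required,
        f"requires torch >= {min_version}, found {torch.__version__}",
    )
```

Stacked under `@pytest.mark.torch_export`, a gate like this skips `test_torch_export` entirely on older torch builds, while the new `do_test_torch_export` / `do_test_torch_compile` class flags handle per-encoder opt-outs such as DPN, InceptionV4, and EfficientNet.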