Disable export tests for dpn and inceptionv4
qubvel committed Jan 13, 2025
1 parent a806147 commit aa5b088
Showing 3 changed files with 41 additions and 42 deletions.
67 changes: 29 additions & 38 deletions tests/encoders/base.py
@@ -4,7 +4,11 @@
 import segmentation_models_pytorch as smp
 
 from functools import lru_cache
-from tests.utils import default_device, check_run_test_on_diff_or_main
+from tests.utils import (
+    default_device,
+    check_run_test_on_diff_or_main,
+    requires_torch_greater_or_equal,
+)
 
 
 class BaseEncoderTester(unittest.TestCase):
@@ -29,11 +33,19 @@ class BaseEncoderTester(unittest.TestCase):
     depth_to_test = [3, 4, 5]
     strides_to_test = [8, 16]  # 32 is a default one
 
+    # enable/disable tests
+    do_test_torch_compile = True
+    do_test_torch_export = True
+
     def get_tiny_encoder(self):
         return smp.encoders.get_encoder(self.encoder_names[0], encoder_weights=None)
 
     @lru_cache
-    def _get_sample(self, batch_size=1, num_channels=3, height=32, width=32):
+    def _get_sample(self, batch_size=None, num_channels=None, height=None, width=None):
+        batch_size = batch_size or self.default_batch_size
+        num_channels = num_channels or self.default_num_channels
+        height = height or self.default_height
+        width = width or self.default_width
         return torch.rand(batch_size, num_channels, height, width)
 
     def get_features_output_strides(self, sample, features):
@@ -43,12 +55,7 @@ def get_features_output_strides(self, sample, features):
         return height_strides, width_strides
 
     def test_forward_backward(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
         for encoder_name in self.encoder_names:
             with self.subTest(encoder_name=encoder_name):
                 # init encoder
@@ -75,12 +82,7 @@ def test_in_channels(self):
         ]
 
         for encoder_name, in_channels in cases:
-            sample = self._get_sample(
-                batch_size=self.default_batch_size,
-                num_channels=in_channels,
-                height=self.default_height,
-                width=self.default_width,
-            ).to(default_device)
+            sample = self._get_sample(num_channels=in_channels).to(default_device)
 
             with self.subTest(encoder_name=encoder_name, in_channels=in_channels):
                 encoder = smp.encoders.get_encoder(
@@ -93,12 +95,7 @@ def test_in_channels(self):
                 encoder.forward(sample)
 
     def test_depth(self):
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         cases = [
             (encoder_name, depth)
Expand Down Expand Up @@ -157,12 +154,7 @@ def test_depth(self):
)

def test_dilated(self):
sample = self._get_sample(
batch_size=self.default_batch_size,
num_channels=self.default_num_channels,
height=self.default_height,
width=self.default_width,
).to(default_device)
sample = self._get_sample().to(default_device)

cases = [
(encoder_name, stride)
@@ -216,15 +208,15 @@
 
     @pytest.mark.compile
     def test_compile(self):
+        if not self.do_test_torch_compile:
+            self.skipTest(
+                f"torch_compile test is disabled for {self.encoder_names[0]}."
+            )
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder().eval().to(default_device)
         compiled_encoder = torch.compile(encoder, fullgraph=True, dynamic=True)
@@ -233,16 +225,15 @@ def test_compile(self):
             compiled_encoder(sample)
 
     @pytest.mark.torch_export
+    @requires_torch_greater_or_equal("2.4.0")
     def test_torch_export(self):
+        if not self.do_test_torch_export:
+            self.skipTest(f"torch_export test is disabled for {self.encoder_names[0]}.")
+
         if not check_run_test_on_diff_or_main(self.files_for_diff):
             self.skipTest("No diff and not on `main`.")
 
-        sample = self._get_sample(
-            batch_size=self.default_batch_size,
-            num_channels=self.default_num_channels,
-            height=self.default_height,
-            width=self.default_width,
-        ).to(default_device)
+        sample = self._get_sample().to(default_device)
 
         encoder = self.get_tiny_encoder()
         encoder = encoder.eval().to(default_device)
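Note: the requires_torch_greater_or_equal decorator imported above lives in tests/utils, which is not part of this diff. A minimal sketch of what such a version guard could look like (the name comes from the import; the implementation below is an assumption, not the repository's actual helper, and assumes release-style version strings like "2.5.1+cu121"):

import unittest

import torch


def requires_torch_greater_or_equal(version: str):
    # Sketch only: skip the decorated test when the installed torch is
    # older than `version`. The real helper in tests/utils may differ.
    installed = tuple(int(p) for p in torch.__version__.split("+")[0].split(".")[:3])
    required = tuple(int(p) for p in version.split("."))
    return unittest.skipIf(
        installed < required,
        f"torch >= {version} is required, found {torch.__version__}",
    )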
12 changes: 10 additions & 2 deletions tests/encoders/test_pretrainedmodels_encoders.py
@@ -12,6 +12,10 @@ class TestDPNEncoder(base.BaseEncoderTester):
     )
     files_for_diff = ["encoders/dpn.py"]
 
+    # works with torch 2.4.0, but not with torch 2.5.1
+    # dynamo error, probably on Sequential + OrderedDict
+    do_test_torch_export = False
+
     def get_tiny_encoder(self):
         params = {
             "stage_idxs": (2, 3, 4, 5),
@@ -29,17 +33,21 @@ def get_tiny_encoder(self):
 
 
 class TestInceptionResNetV2Encoder(base.BaseEncoderTester):
+    supports_dilated = False
     encoder_names = (
         ["inceptionresnetv2"] if not RUN_ALL_ENCODERS else ["inceptionresnetv2"]
     )
     files_for_diff = ["encoders/inceptionresnetv2.py"]
-    supports_dilated = False
 
 
 class TestInceptionV4Encoder(base.BaseEncoderTester):
+    supports_dilated = False
     encoder_names = ["inceptionv4"] if not RUN_ALL_ENCODERS else ["inceptionv4"]
     files_for_diff = ["encoders/inceptionv4.py"]
-    supports_dilated = False
+
+    # works with torch 2.4.0, but not with torch 2.5.1
+    # dynamo error, probably on Sequential + OrderedDict
+    do_test_torch_export = False
 
 
 class TestSeNetEncoder(base.BaseEncoderTester):
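The suspected dynamo failure mentioned in the comments above concerns blocks built as nn.Sequential over an OrderedDict, a pattern the pretrainedmodels implementations of these encoders use. Roughly this shape of code (illustrative only; not the encoders' actual layers, and the root cause is itself a guess in the commit comments):

from collections import OrderedDict

import torch.nn as nn

# pretrainedmodels-style block: named submodules registered via an
# OrderedDict passed to nn.Sequential
block = nn.Sequential(
    OrderedDict(
        [
            ("conv", nn.Conv2d(3, 16, kernel_size=3, padding=1)),
            ("bn", nn.BatchNorm2d(16)),
            ("relu", nn.ReLU(inplace=True)),
        ]
    )
)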
4 changes: 2 additions & 2 deletions tests/encoders/test_smp_encoders.py
@@ -63,5 +63,5 @@ class TestEfficientNetEncoder(base.BaseEncoderTester):
     )
     files_for_diff = ["encoders/efficientnet.py"]
 
-    def test_compile(self):
-        self.skipTest("compile fullgraph is not supported for efficientnet encoders")
+    # torch_compile is not supported for efficientnet encoders
+    do_test_torch_compile = False

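With the base-class flags in place, opting an encoder out of either guarded test becomes a one-line declarative change instead of a test-method override; BaseEncoderTester.test_compile and test_torch_export check the flags and call skipTest before doing any work. A hypothetical subclass (encoder name and file path are illustrative, not real entries):

class TestMyEncoder(base.BaseEncoderTester):
    encoder_names = ["my_encoder"]  # hypothetical
    files_for_diff = ["encoders/my_encoder.py"]

    # opt out declaratively; the base class skips these tests itself
    do_test_torch_compile = False
    do_test_torch_export = False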